00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025 #ifndef EIGEN_BLASUTIL_H
00026 #define EIGEN_BLASUTIL_H
00027
00028
00029
00030
00031 namespace Eigen {
00032
00033 namespace internal {
00034
00035
// Forward declaration only: the definition is not part of this header.
// Both conjugation flags default to false.
template<typename LhsScalar,
         typename RhsScalar,
         typename Index,
         int mr,
         int nr,
         bool ConjugateLhs = false,
         bool ConjugateRhs = false>
struct gebp_kernel;
00038
// Forward declaration only: the definition is not part of this header.
// Conjugate and PanelMode both default to false.
template<typename Scalar,
         typename Index,
         int nr,
         int StorageOrder,
         bool Conjugate = false,
         bool PanelMode = false>
struct gemm_pack_rhs;
00041
// Forward declaration only: the definition is not part of this header.
// Conjugate and PanelMode both default to false.
template<typename Scalar,
         typename Index,
         int Pack1,
         int Pack2,
         int StorageOrder,
         bool Conjugate = false,
         bool PanelMode = false>
struct gemm_pack_lhs;
00044
// Forward declaration only: the definition is not part of this header.
// Parameters come as (scalar type, storage order, conjugation flag) triples
// for the left and right operands, followed by the result storage order.
template<typename Index,
         typename LhsScalar, int LhsStorageOrder, bool ConjugateLhs,
         typename RhsScalar, int RhsStorageOrder, bool ConjugateRhs,
         int ResStorageOrder>
struct general_matrix_matrix_product;
00051
00052 template<typename Index, typename LhsScalar, int LhsStorageOrder, bool ConjugateLhs, typename RhsScalar, bool ConjugateRhs, int Version=Specialized>
00053 struct general_matrix_vector_product;
00054
00055
// Primary template is declared but left undefined; only the explicit
// specializations for `true` and `false` provide an implementation.
template<bool Conjugate>
struct conj_if;
00057
00058 template<> struct conj_if<true> {
00059 template<typename T>
00060 inline T operator()(const T& x) { return conj(x); }
00061 template<typename T>
00062 inline T pconj(const T& x) { return internal::pconj(x); }
00063 };
00064
00065 template<> struct conj_if<false> {
00066 template<typename T>
00067 inline const T& operator()(const T& x) { return x; }
00068 template<typename T>
00069 inline const T& pconj(const T& x) { return x; }
00070 };
00071
00072 template<typename Scalar> struct conj_helper<Scalar,Scalar,false,false>
00073 {
00074 EIGEN_STRONG_INLINE Scalar pmadd(const Scalar& x, const Scalar& y, const Scalar& c) const { return internal::pmadd(x,y,c); }
00075 EIGEN_STRONG_INLINE Scalar pmul(const Scalar& x, const Scalar& y) const { return internal::pmul(x,y); }
00076 };
00077
00078 template<typename RealScalar> struct conj_helper<std::complex<RealScalar>, std::complex<RealScalar>, false,true>
00079 {
00080 typedef std::complex<RealScalar> Scalar;
00081 EIGEN_STRONG_INLINE Scalar pmadd(const Scalar& x, const Scalar& y, const Scalar& c) const
00082 { return c + pmul(x,y); }
00083
00084 EIGEN_STRONG_INLINE Scalar pmul(const Scalar& x, const Scalar& y) const
00085 { return Scalar(real(x)*real(y) + imag(x)*imag(y), imag(x)*real(y) - real(x)*imag(y)); }
00086 };
00087
00088 template<typename RealScalar> struct conj_helper<std::complex<RealScalar>, std::complex<RealScalar>, true,false>
00089 {
00090 typedef std::complex<RealScalar> Scalar;
00091 EIGEN_STRONG_INLINE Scalar pmadd(const Scalar& x, const Scalar& y, const Scalar& c) const
00092 { return c + pmul(x,y); }
00093
00094 EIGEN_STRONG_INLINE Scalar pmul(const Scalar& x, const Scalar& y) const
00095 { return Scalar(real(x)*real(y) + imag(x)*imag(y), real(x)*imag(y) - imag(x)*real(y)); }
00096 };
00097
00098 template<typename RealScalar> struct conj_helper<std::complex<RealScalar>, std::complex<RealScalar>, true,true>
00099 {
00100 typedef std::complex<RealScalar> Scalar;
00101 EIGEN_STRONG_INLINE Scalar pmadd(const Scalar& x, const Scalar& y, const Scalar& c) const
00102 { return c + pmul(x,y); }
00103
00104 EIGEN_STRONG_INLINE Scalar pmul(const Scalar& x, const Scalar& y) const
00105 { return Scalar(real(x)*real(y) - imag(x)*imag(y), - real(x)*imag(y) - imag(x)*real(y)); }
00106 };
00107
00108 template<typename RealScalar,bool Conj> struct conj_helper<std::complex<RealScalar>, RealScalar, Conj,false>
00109 {
00110 typedef std::complex<RealScalar> Scalar;
00111 EIGEN_STRONG_INLINE Scalar pmadd(const Scalar& x, const RealScalar& y, const Scalar& c) const
00112 { return padd(c, pmul(x,y)); }
00113 EIGEN_STRONG_INLINE Scalar pmul(const Scalar& x, const RealScalar& y) const
00114 { return conj_if<Conj>()(x)*y; }
00115 };
00116
00117 template<typename RealScalar,bool Conj> struct conj_helper<RealScalar, std::complex<RealScalar>, false,Conj>
00118 {
00119 typedef std::complex<RealScalar> Scalar;
00120 EIGEN_STRONG_INLINE Scalar pmadd(const RealScalar& x, const Scalar& y, const Scalar& c) const
00121 { return padd(c, pmul(x,y)); }
00122 EIGEN_STRONG_INLINE Scalar pmul(const RealScalar& x, const Scalar& y) const
00123 { return x*conj_if<Conj>()(y); }
00124 };
00125
00126 template<typename From,typename To> struct get_factor {
00127 static EIGEN_STRONG_INLINE To run(const From& x) { return x; }
00128 };
00129
00130 template<typename Scalar> struct get_factor<Scalar,typename NumTraits<Scalar>::Real> {
00131 static EIGEN_STRONG_INLINE typename NumTraits<Scalar>::Real run(const Scalar& x) { return real(x); }
00132 };
00133
00134
00135
00136
00137 template<typename Scalar, typename Index, int StorageOrder>
00138 class blas_data_mapper
00139 {
00140 public:
00141 blas_data_mapper(Scalar* data, Index stride) : m_data(data), m_stride(stride) {}
00142 EIGEN_STRONG_INLINE Scalar& operator()(Index i, Index j)
00143 { return m_data[StorageOrder==RowMajor ? j + i*m_stride : i + j*m_stride]; }
00144 protected:
00145 Scalar* EIGEN_RESTRICT m_data;
00146 Index m_stride;
00147 };
00148
00149
00150 template<typename Scalar, typename Index, int StorageOrder>
00151 class const_blas_data_mapper
00152 {
00153 public:
00154 const_blas_data_mapper(const Scalar* data, Index stride) : m_data(data), m_stride(stride) {}
00155 EIGEN_STRONG_INLINE const Scalar& operator()(Index i, Index j) const
00156 { return m_data[StorageOrder==RowMajor ? j + i*m_stride : i + j*m_stride]; }
00157 protected:
00158 const Scalar* EIGEN_RESTRICT m_data;
00159 Index m_stride;
00160 };
00161
00162
00163
00164
00165
00166 template<typename XprType> struct blas_traits
00167 {
00168 typedef typename traits<XprType>::Scalar Scalar;
00169 typedef const XprType& ExtractType;
00170 typedef XprType _ExtractType;
00171 enum {
00172 IsComplex = NumTraits<Scalar>::IsComplex,
00173 IsTransposed = false,
00174 NeedToConjugate = false,
00175 HasUsableDirectAccess = ( (int(XprType::Flags)&DirectAccessBit)
00176 && ( bool(XprType::IsVectorAtCompileTime)
00177 || int(inner_stride_at_compile_time<XprType>::ret) == 1)
00178 ) ? 1 : 0
00179 };
00180 typedef typename conditional<bool(HasUsableDirectAccess),
00181 ExtractType,
00182 typename _ExtractType::PlainObject
00183 >::type DirectLinearAccessType;
00184 static inline ExtractType extract(const XprType& x) { return x; }
00185 static inline const Scalar extractScalarFactor(const XprType&) { return Scalar(1); }
00186 };
00187
00188
00189 template<typename Scalar, typename NestedXpr>
00190 struct blas_traits<CwiseUnaryOp<scalar_conjugate_op<Scalar>, NestedXpr> >
00191 : blas_traits<NestedXpr>
00192 {
00193 typedef blas_traits<NestedXpr> Base;
00194 typedef CwiseUnaryOp<scalar_conjugate_op<Scalar>, NestedXpr> XprType;
00195 typedef typename Base::ExtractType ExtractType;
00196
00197 enum {
00198 IsComplex = NumTraits<Scalar>::IsComplex,
00199 NeedToConjugate = Base::NeedToConjugate ? 0 : IsComplex
00200 };
00201 static inline ExtractType extract(const XprType& x) { return Base::extract(x.nestedExpression()); }
00202 static inline Scalar extractScalarFactor(const XprType& x) { return conj(Base::extractScalarFactor(x.nestedExpression())); }
00203 };
00204
00205
00206 template<typename Scalar, typename NestedXpr>
00207 struct blas_traits<CwiseUnaryOp<scalar_multiple_op<Scalar>, NestedXpr> >
00208 : blas_traits<NestedXpr>
00209 {
00210 typedef blas_traits<NestedXpr> Base;
00211 typedef CwiseUnaryOp<scalar_multiple_op<Scalar>, NestedXpr> XprType;
00212 typedef typename Base::ExtractType ExtractType;
00213 static inline ExtractType extract(const XprType& x) { return Base::extract(x.nestedExpression()); }
00214 static inline Scalar extractScalarFactor(const XprType& x)
00215 { return x.functor().m_other * Base::extractScalarFactor(x.nestedExpression()); }
00216 };
00217
00218
00219 template<typename Scalar, typename NestedXpr>
00220 struct blas_traits<CwiseUnaryOp<scalar_opposite_op<Scalar>, NestedXpr> >
00221 : blas_traits<NestedXpr>
00222 {
00223 typedef blas_traits<NestedXpr> Base;
00224 typedef CwiseUnaryOp<scalar_opposite_op<Scalar>, NestedXpr> XprType;
00225 typedef typename Base::ExtractType ExtractType;
00226 static inline ExtractType extract(const XprType& x) { return Base::extract(x.nestedExpression()); }
00227 static inline Scalar extractScalarFactor(const XprType& x)
00228 { return - Base::extractScalarFactor(x.nestedExpression()); }
00229 };
00230
00231
00232 template<typename NestedXpr>
00233 struct blas_traits<Transpose<NestedXpr> >
00234 : blas_traits<NestedXpr>
00235 {
00236 typedef typename NestedXpr::Scalar Scalar;
00237 typedef blas_traits<NestedXpr> Base;
00238 typedef Transpose<NestedXpr> XprType;
00239 typedef Transpose<const typename Base::_ExtractType> ExtractType;
00240 typedef Transpose<const typename Base::_ExtractType> _ExtractType;
00241 typedef typename conditional<bool(Base::HasUsableDirectAccess),
00242 ExtractType,
00243 typename ExtractType::PlainObject
00244 >::type DirectLinearAccessType;
00245 enum {
00246 IsTransposed = Base::IsTransposed ? 0 : 1
00247 };
00248 static inline ExtractType extract(const XprType& x) { return Base::extract(x.nestedExpression()); }
00249 static inline Scalar extractScalarFactor(const XprType& x) { return Base::extractScalarFactor(x.nestedExpression()); }
00250 };
00251
00252 template<typename T>
00253 struct blas_traits<const T>
00254 : blas_traits<T>
00255 {};
00256
00257 template<typename T, bool HasUsableDirectAccess=blas_traits<T>::HasUsableDirectAccess>
00258 struct extract_data_selector {
00259 static const typename T::Scalar* run(const T& m)
00260 {
00261 return blas_traits<T>::extract(m).data();
00262 }
00263 };
00264
00265 template<typename T>
00266 struct extract_data_selector<T,false> {
00267 static typename T::Scalar* run(const T&) { return 0; }
00268 };
00269
00270 template<typename T> const typename T::Scalar* extract_data(const T& m)
00271 {
00272 return extract_data_selector<T>::run(m);
00273 }
00274
00275 }
00276
00277 }
00278
00279 #endif // EIGEN_BLASUTIL_H