BlasUtil.h
Go to the documentation of this file.
00001 // This file is part of Eigen, a lightweight C++ template library
00002 // for linear algebra.
00003 //
00004 // Copyright (C) 2009-2010 Gael Guennebaud <gael.guennebaud@inria.fr>
00005 //
00006 // Eigen is free software; you can redistribute it and/or
00007 // modify it under the terms of the GNU Lesser General Public
00008 // License as published by the Free Software Foundation; either
00009 // version 3 of the License, or (at your option) any later version.
00010 //
00011 // Alternatively, you can redistribute it and/or
00012 // modify it under the terms of the GNU General Public License as
00013 // published by the Free Software Foundation; either version 2 of
00014 // the License, or (at your option) any later version.
00015 //
00016 // Eigen is distributed in the hope that it will be useful, but WITHOUT ANY
00017 // WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
00018 // FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License or the
00019 // GNU General Public License for more details.
00020 //
00021 // You should have received a copy of the GNU Lesser General Public
00022 // License and a copy of the GNU General Public License along with
00023 // Eigen. If not, see <http://www.gnu.org/licenses/>.
00024 
00025 #ifndef EIGEN_BLASUTIL_H
00026 #define EIGEN_BLASUTIL_H
00027 
00028 // This file contains many lightweight helper classes used to
00029 // implement and control fast level 2 and level 3 BLAS-like routines.
00030 
00031 namespace Eigen {
00032 
00033 namespace internal {
00034 
00035 // forward declarations
00036 template<typename LhsScalar, typename RhsScalar, typename Index, int mr, int nr, bool ConjugateLhs=false, bool ConjugateRhs=false>
00037 struct gebp_kernel;
00038 
00039 template<typename Scalar, typename Index, int nr, int StorageOrder, bool Conjugate = false, bool PanelMode=false>
00040 struct gemm_pack_rhs;
00041 
00042 template<typename Scalar, typename Index, int Pack1, int Pack2, int StorageOrder, bool Conjugate = false, bool PanelMode = false>
00043 struct gemm_pack_lhs;
00044 
00045 template<
00046   typename Index,
00047   typename LhsScalar, int LhsStorageOrder, bool ConjugateLhs,
00048   typename RhsScalar, int RhsStorageOrder, bool ConjugateRhs,
00049   int ResStorageOrder>
00050 struct general_matrix_matrix_product;
00051 
00052 template<typename Index, typename LhsScalar, int LhsStorageOrder, bool ConjugateLhs, typename RhsScalar, bool ConjugateRhs, int Version=Specialized>
00053 struct general_matrix_vector_product;
00054 
00055 
00056 template<bool Conjugate> struct conj_if;
00057 
00058 template<> struct conj_if<true> {
00059   template<typename T>
00060   inline T operator()(const T& x) { return conj(x); }
00061   template<typename T>
00062   inline T pconj(const T& x) { return internal::pconj(x); }
00063 };
00064 
00065 template<> struct conj_if<false> {
00066   template<typename T>
00067   inline const T& operator()(const T& x) { return x; }
00068   template<typename T>
00069   inline const T& pconj(const T& x) { return x; }
00070 };
00071 
00072 template<typename Scalar> struct conj_helper<Scalar,Scalar,false,false>
00073 {
00074   EIGEN_STRONG_INLINE Scalar pmadd(const Scalar& x, const Scalar& y, const Scalar& c) const { return internal::pmadd(x,y,c); }
00075   EIGEN_STRONG_INLINE Scalar pmul(const Scalar& x, const Scalar& y) const { return internal::pmul(x,y); }
00076 };
00077 
00078 template<typename RealScalar> struct conj_helper<std::complex<RealScalar>, std::complex<RealScalar>, false,true>
00079 {
00080   typedef std::complex<RealScalar> Scalar;
00081   EIGEN_STRONG_INLINE Scalar pmadd(const Scalar& x, const Scalar& y, const Scalar& c) const
00082   { return c + pmul(x,y); }
00083 
00084   EIGEN_STRONG_INLINE Scalar pmul(const Scalar& x, const Scalar& y) const
00085   { return Scalar(real(x)*real(y) + imag(x)*imag(y), imag(x)*real(y) - real(x)*imag(y)); }
00086 };
00087 
00088 template<typename RealScalar> struct conj_helper<std::complex<RealScalar>, std::complex<RealScalar>, true,false>
00089 {
00090   typedef std::complex<RealScalar> Scalar;
00091   EIGEN_STRONG_INLINE Scalar pmadd(const Scalar& x, const Scalar& y, const Scalar& c) const
00092   { return c + pmul(x,y); }
00093 
00094   EIGEN_STRONG_INLINE Scalar pmul(const Scalar& x, const Scalar& y) const
00095   { return Scalar(real(x)*real(y) + imag(x)*imag(y), real(x)*imag(y) - imag(x)*real(y)); }
00096 };
00097 
00098 template<typename RealScalar> struct conj_helper<std::complex<RealScalar>, std::complex<RealScalar>, true,true>
00099 {
00100   typedef std::complex<RealScalar> Scalar;
00101   EIGEN_STRONG_INLINE Scalar pmadd(const Scalar& x, const Scalar& y, const Scalar& c) const
00102   { return c + pmul(x,y); }
00103 
00104   EIGEN_STRONG_INLINE Scalar pmul(const Scalar& x, const Scalar& y) const
00105   { return Scalar(real(x)*real(y) - imag(x)*imag(y), - real(x)*imag(y) - imag(x)*real(y)); }
00106 };
00107 
00108 template<typename RealScalar,bool Conj> struct conj_helper<std::complex<RealScalar>, RealScalar, Conj,false>
00109 {
00110   typedef std::complex<RealScalar> Scalar;
00111   EIGEN_STRONG_INLINE Scalar pmadd(const Scalar& x, const RealScalar& y, const Scalar& c) const
00112   { return padd(c, pmul(x,y)); }
00113   EIGEN_STRONG_INLINE Scalar pmul(const Scalar& x, const RealScalar& y) const
00114   { return conj_if<Conj>()(x)*y; }
00115 };
00116 
00117 template<typename RealScalar,bool Conj> struct conj_helper<RealScalar, std::complex<RealScalar>, false,Conj>
00118 {
00119   typedef std::complex<RealScalar> Scalar;
00120   EIGEN_STRONG_INLINE Scalar pmadd(const RealScalar& x, const Scalar& y, const Scalar& c) const
00121   { return padd(c, pmul(x,y)); }
00122   EIGEN_STRONG_INLINE Scalar pmul(const RealScalar& x, const Scalar& y) const
00123   { return x*conj_if<Conj>()(y); }
00124 };
00125 
00126 template<typename From,typename To> struct get_factor {
00127   static EIGEN_STRONG_INLINE To run(const From& x) { return x; }
00128 };
00129 
00130 template<typename Scalar> struct get_factor<Scalar,typename NumTraits<Scalar>::Real> {
00131   static EIGEN_STRONG_INLINE typename NumTraits<Scalar>::Real run(const Scalar& x) { return real(x); }
00132 };
00133 
// Lightweight helper class to access matrix coefficients.
// This is somewhat redundant with Map<>, but this version is much lighter,
// which should give better compilation performance (both compile time and code quality).
00137 template<typename Scalar, typename Index, int StorageOrder>
00138 class blas_data_mapper
00139 {
00140   public:
00141     blas_data_mapper(Scalar* data, Index stride) : m_data(data), m_stride(stride) {}
00142     EIGEN_STRONG_INLINE Scalar& operator()(Index i, Index j)
00143     { return m_data[StorageOrder==RowMajor ? j + i*m_stride : i + j*m_stride]; }
00144   protected:
00145     Scalar* EIGEN_RESTRICT m_data;
00146     Index m_stride;
00147 };
00148 
00149 // lightweight helper class to access matrix coefficients (const version)
00150 template<typename Scalar, typename Index, int StorageOrder>
00151 class const_blas_data_mapper
00152 {
00153   public:
00154     const_blas_data_mapper(const Scalar* data, Index stride) : m_data(data), m_stride(stride) {}
00155     EIGEN_STRONG_INLINE const Scalar& operator()(Index i, Index j) const
00156     { return m_data[StorageOrder==RowMajor ? j + i*m_stride : i + j*m_stride]; }
00157   protected:
00158     const Scalar* EIGEN_RESTRICT m_data;
00159     Index m_stride;
00160 };
00161 
00162 
/* Helper class to analyze the factors of a Product expression.
 * In particular, it makes it possible to pop out unary minus (operator-),
 * scalar multiples, and conjugation. */
00166 template<typename XprType> struct blas_traits
00167 {
00168   typedef typename traits<XprType>::Scalar Scalar;
00169   typedef const XprType& ExtractType;
00170   typedef XprType _ExtractType;
00171   enum {
00172     IsComplex = NumTraits<Scalar>::IsComplex,
00173     IsTransposed = false,
00174     NeedToConjugate = false,
00175     HasUsableDirectAccess = (    (int(XprType::Flags)&DirectAccessBit)
00176                               && (   bool(XprType::IsVectorAtCompileTime)
00177                                   || int(inner_stride_at_compile_time<XprType>::ret) == 1)
00178                              ) ?  1 : 0
00179   };
00180   typedef typename conditional<bool(HasUsableDirectAccess),
00181     ExtractType,
00182     typename _ExtractType::PlainObject
00183     >::type DirectLinearAccessType;
00184   static inline ExtractType extract(const XprType& x) { return x; }
00185   static inline const Scalar extractScalarFactor(const XprType&) { return Scalar(1); }
00186 };
00187 
00188 // pop conjugate
00189 template<typename Scalar, typename NestedXpr>
00190 struct blas_traits<CwiseUnaryOp<scalar_conjugate_op<Scalar>, NestedXpr> >
00191  : blas_traits<NestedXpr>
00192 {
00193   typedef blas_traits<NestedXpr> Base;
00194   typedef CwiseUnaryOp<scalar_conjugate_op<Scalar>, NestedXpr> XprType;
00195   typedef typename Base::ExtractType ExtractType;
00196 
00197   enum {
00198     IsComplex = NumTraits<Scalar>::IsComplex,
00199     NeedToConjugate = Base::NeedToConjugate ? 0 : IsComplex
00200   };
00201   static inline ExtractType extract(const XprType& x) { return Base::extract(x.nestedExpression()); }
00202   static inline Scalar extractScalarFactor(const XprType& x) { return conj(Base::extractScalarFactor(x.nestedExpression())); }
00203 };
00204 
00205 // pop scalar multiple
00206 template<typename Scalar, typename NestedXpr>
00207 struct blas_traits<CwiseUnaryOp<scalar_multiple_op<Scalar>, NestedXpr> >
00208  : blas_traits<NestedXpr>
00209 {
00210   typedef blas_traits<NestedXpr> Base;
00211   typedef CwiseUnaryOp<scalar_multiple_op<Scalar>, NestedXpr> XprType;
00212   typedef typename Base::ExtractType ExtractType;
00213   static inline ExtractType extract(const XprType& x) { return Base::extract(x.nestedExpression()); }
00214   static inline Scalar extractScalarFactor(const XprType& x)
00215   { return x.functor().m_other * Base::extractScalarFactor(x.nestedExpression()); }
00216 };
00217 
00218 // pop opposite
00219 template<typename Scalar, typename NestedXpr>
00220 struct blas_traits<CwiseUnaryOp<scalar_opposite_op<Scalar>, NestedXpr> >
00221  : blas_traits<NestedXpr>
00222 {
00223   typedef blas_traits<NestedXpr> Base;
00224   typedef CwiseUnaryOp<scalar_opposite_op<Scalar>, NestedXpr> XprType;
00225   typedef typename Base::ExtractType ExtractType;
00226   static inline ExtractType extract(const XprType& x) { return Base::extract(x.nestedExpression()); }
00227   static inline Scalar extractScalarFactor(const XprType& x)
00228   { return - Base::extractScalarFactor(x.nestedExpression()); }
00229 };
00230 
00231 // pop/push transpose
00232 template<typename NestedXpr>
00233 struct blas_traits<Transpose<NestedXpr> >
00234  : blas_traits<NestedXpr>
00235 {
00236   typedef typename NestedXpr::Scalar Scalar;
00237   typedef blas_traits<NestedXpr> Base;
00238   typedef Transpose<NestedXpr> XprType;
00239   typedef Transpose<const typename Base::_ExtractType>  ExtractType; // const to get rid of a compile error; anyway blas traits are only used on the RHS
00240   typedef Transpose<const typename Base::_ExtractType> _ExtractType;
00241   typedef typename conditional<bool(Base::HasUsableDirectAccess),
00242     ExtractType,
00243     typename ExtractType::PlainObject
00244     >::type DirectLinearAccessType;
00245   enum {
00246     IsTransposed = Base::IsTransposed ? 0 : 1
00247   };
00248   static inline ExtractType extract(const XprType& x) { return Base::extract(x.nestedExpression()); }
00249   static inline Scalar extractScalarFactor(const XprType& x) { return Base::extractScalarFactor(x.nestedExpression()); }
00250 };
00251 
00252 template<typename T>
00253 struct blas_traits<const T>
00254      : blas_traits<T>
00255 {};
00256 
00257 template<typename T, bool HasUsableDirectAccess=blas_traits<T>::HasUsableDirectAccess>
00258 struct extract_data_selector {
00259   static const typename T::Scalar* run(const T& m)
00260   {
00261     return blas_traits<T>::extract(m).data();
00262   }
00263 };
00264 
00265 template<typename T>
00266 struct extract_data_selector<T,false> {
00267   static typename T::Scalar* run(const T&) { return 0; }
00268 };
00269 
00270 template<typename T> const typename T::Scalar* extract_data(const T& m)
00271 {
00272   return extract_data_selector<T>::run(m);
00273 }
00274 
00275 } // end namespace internal
00276 
00277 } // end namespace Eigen
00278 
00279 #endif // EIGEN_BLASUTIL_H