TriangularMatrixVector.h
Go to the documentation of this file.
00001 // This file is part of Eigen, a lightweight C++ template library
00002 // for linear algebra.
00003 //
00004 // Copyright (C) 2009 Gael Guennebaud <gael.guennebaud@inria.fr>
00005 //
00006 // Eigen is free software; you can redistribute it and/or
00007 // modify it under the terms of the GNU Lesser General Public
00008 // License as published by the Free Software Foundation; either
00009 // version 3 of the License, or (at your option) any later version.
00010 //
00011 // Alternatively, you can redistribute it and/or
00012 // modify it under the terms of the GNU General Public License as
00013 // published by the Free Software Foundation; either version 2 of
00014 // the License, or (at your option) any later version.
00015 //
00016 // Eigen is distributed in the hope that it will be useful, but WITHOUT ANY
00017 // WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
00018 // FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License or the
00019 // GNU General Public License for more details.
00020 //
00021 // You should have received a copy of the GNU Lesser General Public
00022 // License and a copy of the GNU General Public License along with
00023 // Eigen. If not, see <http://www.gnu.org/licenses/>.
00024 
00025 #ifndef EIGEN_TRIANGULARMATRIXVECTOR_H
00026 #define EIGEN_TRIANGULARMATRIXVECTOR_H
00027 
00028 namespace Eigen { 
00029 
00030 namespace internal {
00031 
00032 template<typename Index, int Mode, typename LhsScalar, bool ConjLhs, typename RhsScalar, bool ConjRhs, int StorageOrder, int Version=Specialized>
00033 struct triangular_matrix_vector_product;
00034 
00035 template<typename Index, int Mode, typename LhsScalar, bool ConjLhs, typename RhsScalar, bool ConjRhs, int Version>
00036 struct triangular_matrix_vector_product<Index,Mode,LhsScalar,ConjLhs,RhsScalar,ConjRhs,ColMajor,Version>
00037 {
00038   typedef typename scalar_product_traits<LhsScalar, RhsScalar>::ReturnType ResScalar;
00039   enum {
00040     IsLower = ((Mode&Lower)==Lower),
00041     HasUnitDiag = (Mode & UnitDiag)==UnitDiag,
00042     HasZeroDiag = (Mode & ZeroDiag)==ZeroDiag
00043   };
00044   static EIGEN_DONT_INLINE  void run(Index _rows, Index _cols, const LhsScalar* _lhs, Index lhsStride,
00045                                      const RhsScalar* _rhs, Index rhsIncr, ResScalar* _res, Index resIncr, ResScalar alpha)
00046   {
00047     static const Index PanelWidth = EIGEN_TUNE_TRIANGULAR_PANEL_WIDTH;
00048     Index size = (std::min)(_rows,_cols);
00049     Index rows = IsLower ? _rows : (std::min)(_rows,_cols);
00050     Index cols = IsLower ? (std::min)(_rows,_cols) : _cols;
00051 
00052     typedef Map<const Matrix<LhsScalar,Dynamic,Dynamic,ColMajor>, 0, OuterStride<> > LhsMap;
00053     const LhsMap lhs(_lhs,rows,cols,OuterStride<>(lhsStride));
00054     typename conj_expr_if<ConjLhs,LhsMap>::type cjLhs(lhs);
00055     
00056     typedef Map<const Matrix<RhsScalar,Dynamic,1>, 0, InnerStride<> > RhsMap;
00057     const RhsMap rhs(_rhs,cols,InnerStride<>(rhsIncr));
00058     typename conj_expr_if<ConjRhs,RhsMap>::type cjRhs(rhs);
00059 
00060     typedef Map<Matrix<ResScalar,Dynamic,1> > ResMap;
00061     ResMap res(_res,rows);
00062 
00063     for (Index pi=0; pi<size; pi+=PanelWidth)
00064     {
00065       Index actualPanelWidth = (std::min)(PanelWidth, size-pi);
00066       for (Index k=0; k<actualPanelWidth; ++k)
00067       {
00068         Index i = pi + k;
00069         Index s = IsLower ? ((HasUnitDiag||HasZeroDiag) ? i+1 : i ) : pi;
00070         Index r = IsLower ? actualPanelWidth-k : k+1;
00071         if ((!(HasUnitDiag||HasZeroDiag)) || (--r)>0)
00072           res.segment(s,r) += (alpha * cjRhs.coeff(i)) * cjLhs.col(i).segment(s,r);
00073         if (HasUnitDiag)
00074           res.coeffRef(i) += alpha * cjRhs.coeff(i);
00075       }
00076       Index r = IsLower ? rows - pi - actualPanelWidth : pi;
00077       if (r>0)
00078       {
00079         Index s = IsLower ? pi+actualPanelWidth : 0;
00080         general_matrix_vector_product<Index,LhsScalar,ColMajor,ConjLhs,RhsScalar,ConjRhs,BuiltIn>::run(
00081             r, actualPanelWidth,
00082             &lhs.coeffRef(s,pi), lhsStride,
00083             &rhs.coeffRef(pi), rhsIncr,
00084             &res.coeffRef(s), resIncr, alpha);
00085       }
00086     }
00087     if((!IsLower) && cols>size)
00088     {
00089       general_matrix_vector_product<Index,LhsScalar,ColMajor,ConjLhs,RhsScalar,ConjRhs>::run(
00090           rows, cols-size,
00091           &lhs.coeffRef(0,size), lhsStride,
00092           &rhs.coeffRef(size), rhsIncr,
00093           _res, resIncr, alpha);
00094     }
00095   }
00096 };
00097 
00098 template<typename Index, int Mode, typename LhsScalar, bool ConjLhs, typename RhsScalar, bool ConjRhs,int Version>
00099 struct triangular_matrix_vector_product<Index,Mode,LhsScalar,ConjLhs,RhsScalar,ConjRhs,RowMajor,Version>
00100 {
00101   typedef typename scalar_product_traits<LhsScalar, RhsScalar>::ReturnType ResScalar;
00102   enum {
00103     IsLower = ((Mode&Lower)==Lower),
00104     HasUnitDiag = (Mode & UnitDiag)==UnitDiag,
00105     HasZeroDiag = (Mode & ZeroDiag)==ZeroDiag
00106   };
00107   static void run(Index _rows, Index _cols, const LhsScalar* _lhs, Index lhsStride,
00108                   const RhsScalar* _rhs, Index rhsIncr, ResScalar* _res, Index resIncr, ResScalar alpha)
00109   {
00110     static const Index PanelWidth = EIGEN_TUNE_TRIANGULAR_PANEL_WIDTH;
00111     Index diagSize = (std::min)(_rows,_cols);
00112     Index rows = IsLower ? _rows : diagSize;
00113     Index cols = IsLower ? diagSize : _cols;
00114 
00115     typedef Map<const Matrix<LhsScalar,Dynamic,Dynamic,RowMajor>, 0, OuterStride<> > LhsMap;
00116     const LhsMap lhs(_lhs,rows,cols,OuterStride<>(lhsStride));
00117     typename conj_expr_if<ConjLhs,LhsMap>::type cjLhs(lhs);
00118 
00119     typedef Map<const Matrix<RhsScalar,Dynamic,1> > RhsMap;
00120     const RhsMap rhs(_rhs,cols);
00121     typename conj_expr_if<ConjRhs,RhsMap>::type cjRhs(rhs);
00122 
00123     typedef Map<Matrix<ResScalar,Dynamic,1>, 0, InnerStride<> > ResMap;
00124     ResMap res(_res,rows,InnerStride<>(resIncr));
00125     
00126     for (Index pi=0; pi<diagSize; pi+=PanelWidth)
00127     {
00128       Index actualPanelWidth = (std::min)(PanelWidth, diagSize-pi);
00129       for (Index k=0; k<actualPanelWidth; ++k)
00130       {
00131         Index i = pi + k;
00132         Index s = IsLower ? pi  : ((HasUnitDiag||HasZeroDiag) ? i+1 : i);
00133         Index r = IsLower ? k+1 : actualPanelWidth-k;
00134         if ((!(HasUnitDiag||HasZeroDiag)) || (--r)>0)
00135           res.coeffRef(i) += alpha * (cjLhs.row(i).segment(s,r).cwiseProduct(cjRhs.segment(s,r).transpose())).sum();
00136         if (HasUnitDiag)
00137           res.coeffRef(i) += alpha * cjRhs.coeff(i);
00138       }
00139       Index r = IsLower ? pi : cols - pi - actualPanelWidth;
00140       if (r>0)
00141       {
00142         Index s = IsLower ? 0 : pi + actualPanelWidth;
00143         general_matrix_vector_product<Index,LhsScalar,RowMajor,ConjLhs,RhsScalar,ConjRhs,BuiltIn>::run(
00144             actualPanelWidth, r,
00145             &lhs.coeffRef(pi,s), lhsStride,
00146             &rhs.coeffRef(s), rhsIncr,
00147             &res.coeffRef(pi), resIncr, alpha);
00148       }
00149     }
00150     if(IsLower && rows>diagSize)
00151     {
00152       general_matrix_vector_product<Index,LhsScalar,RowMajor,ConjLhs,RhsScalar,ConjRhs>::run(
00153             rows-diagSize, cols,
00154             &lhs.coeffRef(diagSize,0), lhsStride,
00155             &rhs.coeffRef(0), rhsIncr,
00156             &res.coeffRef(diagSize), resIncr, alpha);
00157     }
00158   }
00159 };
00160 
00161 /***************************************************************************
00162 * Wrapper to product_triangular_vector
00163 ***************************************************************************/
00164 
00165 template<int Mode, bool LhsIsTriangular, typename Lhs, typename Rhs>
00166 struct traits<TriangularProduct<Mode,LhsIsTriangular,Lhs,false,Rhs,true> >
00167  : traits<ProductBase<TriangularProduct<Mode,LhsIsTriangular,Lhs,false,Rhs,true>, Lhs, Rhs> >
00168 {};
00169 
00170 template<int Mode, bool LhsIsTriangular, typename Lhs, typename Rhs>
00171 struct traits<TriangularProduct<Mode,LhsIsTriangular,Lhs,true,Rhs,false> >
00172  : traits<ProductBase<TriangularProduct<Mode,LhsIsTriangular,Lhs,true,Rhs,false>, Lhs, Rhs> >
00173 {};
00174 
00175 
00176 template<int StorageOrder>
00177 struct trmv_selector;
00178 
00179 } // end namespace internal
00180 
00181 template<int Mode, typename Lhs, typename Rhs>
00182 struct TriangularProduct<Mode,true,Lhs,false,Rhs,true>
00183   : public ProductBase<TriangularProduct<Mode,true,Lhs,false,Rhs,true>, Lhs, Rhs >
00184 {
00185   EIGEN_PRODUCT_PUBLIC_INTERFACE(TriangularProduct)
00186 
00187   TriangularProduct(const Lhs& lhs, const Rhs& rhs) : Base(lhs,rhs) {}
00188 
00189   template<typename Dest> void scaleAndAddTo(Dest& dst, Scalar alpha) const
00190   {
00191     eigen_assert(dst.rows()==m_lhs.rows() && dst.cols()==m_rhs.cols());
00192   
00193     internal::trmv_selector<(int(internal::traits<Lhs>::Flags)&RowMajorBit) ? RowMajor : ColMajor>::run(*this, dst, alpha);
00194   }
00195 };
00196 
00197 template<int Mode, typename Lhs, typename Rhs>
00198 struct TriangularProduct<Mode,false,Lhs,true,Rhs,false>
00199   : public ProductBase<TriangularProduct<Mode,false,Lhs,true,Rhs,false>, Lhs, Rhs >
00200 {
00201   EIGEN_PRODUCT_PUBLIC_INTERFACE(TriangularProduct)
00202 
00203   TriangularProduct(const Lhs& lhs, const Rhs& rhs) : Base(lhs,rhs) {}
00204 
00205   template<typename Dest> void scaleAndAddTo(Dest& dst, Scalar alpha) const
00206   {
00207     eigen_assert(dst.rows()==m_lhs.rows() && dst.cols()==m_rhs.cols());
00208 
00209     typedef TriangularProduct<(Mode & (UnitDiag|ZeroDiag)) | ((Mode & Lower) ? Upper : Lower),true,Transpose<const Rhs>,false,Transpose<const Lhs>,true> TriangularProductTranspose;
00210     Transpose<Dest> dstT(dst);
00211     internal::trmv_selector<(int(internal::traits<Rhs>::Flags)&RowMajorBit) ? ColMajor : RowMajor>::run(
00212       TriangularProductTranspose(m_rhs.transpose(),m_lhs.transpose()), dstT, alpha);
00213   }
00214 };
00215 
00216 namespace internal {
00217 
00218 // TODO: find a way to factorize this piece of code with gemv_selector since the logic is exactly the same.
00219   
00220 template<> struct trmv_selector<ColMajor>
00221 {
00222   template<int Mode, typename Lhs, typename Rhs, typename Dest>
00223   static void run(const TriangularProduct<Mode,true,Lhs,false,Rhs,true>& prod, Dest& dest, typename TriangularProduct<Mode,true,Lhs,false,Rhs,true>::Scalar alpha)
00224   {
00225     typedef TriangularProduct<Mode,true,Lhs,false,Rhs,true> ProductType;
00226     typedef typename ProductType::Index Index;
00227     typedef typename ProductType::LhsScalar   LhsScalar;
00228     typedef typename ProductType::RhsScalar   RhsScalar;
00229     typedef typename ProductType::Scalar      ResScalar;
00230     typedef typename ProductType::RealScalar  RealScalar;
00231     typedef typename ProductType::ActualLhsType ActualLhsType;
00232     typedef typename ProductType::ActualRhsType ActualRhsType;
00233     typedef typename ProductType::LhsBlasTraits LhsBlasTraits;
00234     typedef typename ProductType::RhsBlasTraits RhsBlasTraits;
00235     typedef Map<Matrix<ResScalar,Dynamic,1>, Aligned> MappedDest;
00236 
00237     typename internal::add_const_on_value_type<ActualLhsType>::type actualLhs = LhsBlasTraits::extract(prod.lhs());
00238     typename internal::add_const_on_value_type<ActualRhsType>::type actualRhs = RhsBlasTraits::extract(prod.rhs());
00239 
00240     ResScalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(prod.lhs())
00241                                   * RhsBlasTraits::extractScalarFactor(prod.rhs());
00242 
00243     enum {
00244       // FIXME find a way to allow an inner stride on the result if packet_traits<Scalar>::size==1
00245       // on, the other hand it is good for the cache to pack the vector anyways...
00246       EvalToDestAtCompileTime = Dest::InnerStrideAtCompileTime==1,
00247       ComplexByReal = (NumTraits<LhsScalar>::IsComplex) && (!NumTraits<RhsScalar>::IsComplex),
00248       MightCannotUseDest = (Dest::InnerStrideAtCompileTime!=1) || ComplexByReal
00249     };
00250 
00251     gemv_static_vector_if<ResScalar,Dest::SizeAtCompileTime,Dest::MaxSizeAtCompileTime,MightCannotUseDest> static_dest;
00252 
00253     bool alphaIsCompatible = (!ComplexByReal) || (imag(actualAlpha)==RealScalar(0));
00254     bool evalToDest = EvalToDestAtCompileTime && alphaIsCompatible;
00255     
00256     RhsScalar compatibleAlpha = get_factor<ResScalar,RhsScalar>::run(actualAlpha);
00257 
00258     ei_declare_aligned_stack_constructed_variable(ResScalar,actualDestPtr,dest.size(),
00259                                                   evalToDest ? dest.data() : static_dest.data());
00260 
00261     if(!evalToDest)
00262     {
00263       #ifdef EIGEN_DENSE_STORAGE_CTOR_PLUGIN
00264       int size = dest.size();
00265       EIGEN_DENSE_STORAGE_CTOR_PLUGIN
00266       #endif
00267       if(!alphaIsCompatible)
00268       {
00269         MappedDest(actualDestPtr, dest.size()).setZero();
00270         compatibleAlpha = RhsScalar(1);
00271       }
00272       else
00273         MappedDest(actualDestPtr, dest.size()) = dest;
00274     }
00275     
00276     internal::triangular_matrix_vector_product
00277       <Index,Mode,
00278        LhsScalar, LhsBlasTraits::NeedToConjugate,
00279        RhsScalar, RhsBlasTraits::NeedToConjugate,
00280        ColMajor>
00281       ::run(actualLhs.rows(),actualLhs.cols(),
00282             actualLhs.data(),actualLhs.outerStride(),
00283             actualRhs.data(),actualRhs.innerStride(),
00284             actualDestPtr,1,compatibleAlpha);
00285 
00286     if (!evalToDest)
00287     {
00288       if(!alphaIsCompatible)
00289         dest += actualAlpha * MappedDest(actualDestPtr, dest.size());
00290       else
00291         dest = MappedDest(actualDestPtr, dest.size());
00292     }
00293   }
00294 };
00295 
00296 template<> struct trmv_selector<RowMajor>
00297 {
00298   template<int Mode, typename Lhs, typename Rhs, typename Dest>
00299   static void run(const TriangularProduct<Mode,true,Lhs,false,Rhs,true>& prod, Dest& dest, typename TriangularProduct<Mode,true,Lhs,false,Rhs,true>::Scalar alpha)
00300   {
00301     typedef TriangularProduct<Mode,true,Lhs,false,Rhs,true> ProductType;
00302     typedef typename ProductType::LhsScalar LhsScalar;
00303     typedef typename ProductType::RhsScalar RhsScalar;
00304     typedef typename ProductType::Scalar    ResScalar;
00305     typedef typename ProductType::Index Index;
00306     typedef typename ProductType::ActualLhsType ActualLhsType;
00307     typedef typename ProductType::ActualRhsType ActualRhsType;
00308     typedef typename ProductType::_ActualRhsType _ActualRhsType;
00309     typedef typename ProductType::LhsBlasTraits LhsBlasTraits;
00310     typedef typename ProductType::RhsBlasTraits RhsBlasTraits;
00311 
00312     typename add_const<ActualLhsType>::type actualLhs = LhsBlasTraits::extract(prod.lhs());
00313     typename add_const<ActualRhsType>::type actualRhs = RhsBlasTraits::extract(prod.rhs());
00314 
00315     ResScalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(prod.lhs())
00316                                   * RhsBlasTraits::extractScalarFactor(prod.rhs());
00317 
00318     enum {
00319       DirectlyUseRhs = _ActualRhsType::InnerStrideAtCompileTime==1
00320     };
00321 
00322     gemv_static_vector_if<RhsScalar,_ActualRhsType::SizeAtCompileTime,_ActualRhsType::MaxSizeAtCompileTime,!DirectlyUseRhs> static_rhs;
00323 
00324     ei_declare_aligned_stack_constructed_variable(RhsScalar,actualRhsPtr,actualRhs.size(),
00325         DirectlyUseRhs ? const_cast<RhsScalar*>(actualRhs.data()) : static_rhs.data());
00326 
00327     if(!DirectlyUseRhs)
00328     {
00329       #ifdef EIGEN_DENSE_STORAGE_CTOR_PLUGIN
00330       int size = actualRhs.size();
00331       EIGEN_DENSE_STORAGE_CTOR_PLUGIN
00332       #endif
00333       Map<typename _ActualRhsType::PlainObject>(actualRhsPtr, actualRhs.size()) = actualRhs;
00334     }
00335     
00336     internal::triangular_matrix_vector_product
00337       <Index,Mode,
00338        LhsScalar, LhsBlasTraits::NeedToConjugate,
00339        RhsScalar, RhsBlasTraits::NeedToConjugate,
00340        RowMajor>
00341       ::run(actualLhs.rows(),actualLhs.cols(),
00342             actualLhs.data(),actualLhs.outerStride(),
00343             actualRhsPtr,1,
00344             dest.data(),dest.innerStride(),
00345             actualAlpha);
00346   }
00347 };
00348 
00349 } // end namespace internal
00350 
00351 } // end namespace Eigen
00352 
00353 #endif // EIGEN_TRIANGULARMATRIXVECTOR_H