SparseDenseProduct.h
Go to the documentation of this file.
00001 // This file is part of Eigen, a lightweight C++ template library
00002 // for linear algebra.
00003 //
00004 // Copyright (C) 2008-2010 Gael Guennebaud <gael.guennebaud@inria.fr>
00005 //
00006 // Eigen is free software; you can redistribute it and/or
00007 // modify it under the terms of the GNU Lesser General Public
00008 // License as published by the Free Software Foundation; either
00009 // version 3 of the License, or (at your option) any later version.
00010 //
00011 // Alternatively, you can redistribute it and/or
00012 // modify it under the terms of the GNU General Public License as
00013 // published by the Free Software Foundation; either version 2 of
00014 // the License, or (at your option) any later version.
00015 //
00016 // Eigen is distributed in the hope that it will be useful, but WITHOUT ANY
00017 // WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
00018 // FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License or the
00019 // GNU General Public License for more details.
00020 //
00021 // You should have received a copy of the GNU Lesser General Public
00022 // License and a copy of the GNU General Public License along with
00023 // Eigen. If not, see <http://www.gnu.org/licenses/>.
00024 
00025 #ifndef EIGEN_SPARSEDENSEPRODUCT_H
00026 #define EIGEN_SPARSEDENSEPRODUCT_H
00027 
00028 namespace Eigen { 
00029 
00030 template<typename Lhs, typename Rhs, int InnerSize> struct SparseDenseProductReturnType
00031 {
00032   typedef SparseTimeDenseProduct<Lhs,Rhs> Type;
00033 };
00034 
00035 template<typename Lhs, typename Rhs> struct SparseDenseProductReturnType<Lhs,Rhs,1>
00036 {
00037   typedef SparseDenseOuterProduct<Lhs,Rhs,false> Type;
00038 };
00039 
00040 template<typename Lhs, typename Rhs, int InnerSize> struct DenseSparseProductReturnType
00041 {
00042   typedef DenseTimeSparseProduct<Lhs,Rhs> Type;
00043 };
00044 
00045 template<typename Lhs, typename Rhs> struct DenseSparseProductReturnType<Lhs,Rhs,1>
00046 {
00047   typedef SparseDenseOuterProduct<Rhs,Lhs,true> Type;
00048 };
00049 
00050 namespace internal {
00051 
00052 template<typename Lhs, typename Rhs, bool Tr>
00053 struct traits<SparseDenseOuterProduct<Lhs,Rhs,Tr> >
00054 {
00055   typedef Sparse StorageKind;
00056   typedef typename scalar_product_traits<typename traits<Lhs>::Scalar,
00057                                             typename traits<Rhs>::Scalar>::ReturnType Scalar;
00058   typedef typename Lhs::Index Index;
00059   typedef typename Lhs::Nested LhsNested;
00060   typedef typename Rhs::Nested RhsNested;
00061   typedef typename remove_all<LhsNested>::type _LhsNested;
00062   typedef typename remove_all<RhsNested>::type _RhsNested;
00063 
00064   enum {
00065     LhsCoeffReadCost = traits<_LhsNested>::CoeffReadCost,
00066     RhsCoeffReadCost = traits<_RhsNested>::CoeffReadCost,
00067 
00068     RowsAtCompileTime    = Tr ? int(traits<Rhs>::RowsAtCompileTime)     : int(traits<Lhs>::RowsAtCompileTime),
00069     ColsAtCompileTime    = Tr ? int(traits<Lhs>::ColsAtCompileTime)     : int(traits<Rhs>::ColsAtCompileTime),
00070     MaxRowsAtCompileTime = Tr ? int(traits<Rhs>::MaxRowsAtCompileTime)  : int(traits<Lhs>::MaxRowsAtCompileTime),
00071     MaxColsAtCompileTime = Tr ? int(traits<Lhs>::MaxColsAtCompileTime)  : int(traits<Rhs>::MaxColsAtCompileTime),
00072 
00073     Flags = Tr ? RowMajorBit : 0,
00074 
00075     CoeffReadCost = LhsCoeffReadCost + RhsCoeffReadCost + NumTraits<Scalar>::MulCost
00076   };
00077 };
00078 
00079 } // end namespace internal
00080 
00081 template<typename Lhs, typename Rhs, bool Tr>
00082 class SparseDenseOuterProduct
00083  : public SparseMatrixBase<SparseDenseOuterProduct<Lhs,Rhs,Tr> >
00084 {
00085   public:
00086 
00087     typedef SparseMatrixBase<SparseDenseOuterProduct> Base;
00088     EIGEN_DENSE_PUBLIC_INTERFACE(SparseDenseOuterProduct)
00089     typedef internal::traits<SparseDenseOuterProduct> Traits;
00090 
00091   private:
00092 
00093     typedef typename Traits::LhsNested LhsNested;
00094     typedef typename Traits::RhsNested RhsNested;
00095     typedef typename Traits::_LhsNested _LhsNested;
00096     typedef typename Traits::_RhsNested _RhsNested;
00097 
00098   public:
00099 
00100     class InnerIterator;
00101 
00102     EIGEN_STRONG_INLINE SparseDenseOuterProduct(const Lhs& lhs, const Rhs& rhs)
00103       : m_lhs(lhs), m_rhs(rhs)
00104     {
00105       EIGEN_STATIC_ASSERT(!Tr,YOU_MADE_A_PROGRAMMING_MISTAKE);
00106     }
00107 
00108     EIGEN_STRONG_INLINE SparseDenseOuterProduct(const Rhs& rhs, const Lhs& lhs)
00109       : m_lhs(lhs), m_rhs(rhs)
00110     {
00111       EIGEN_STATIC_ASSERT(Tr,YOU_MADE_A_PROGRAMMING_MISTAKE);
00112     }
00113 
00114     EIGEN_STRONG_INLINE Index rows() const { return Tr ? m_rhs.rows() : m_lhs.rows(); }
00115     EIGEN_STRONG_INLINE Index cols() const { return Tr ? m_lhs.cols() : m_rhs.cols(); }
00116 
00117     EIGEN_STRONG_INLINE const _LhsNested& lhs() const { return m_lhs; }
00118     EIGEN_STRONG_INLINE const _RhsNested& rhs() const { return m_rhs; }
00119 
00120   protected:
00121     LhsNested m_lhs;
00122     RhsNested m_rhs;
00123 };
00124 
00125 template<typename Lhs, typename Rhs, bool Transpose>
00126 class SparseDenseOuterProduct<Lhs,Rhs,Transpose>::InnerIterator : public _LhsNested::InnerIterator
00127 {
00128     typedef typename _LhsNested::InnerIterator Base;
00129   public:
00130     EIGEN_STRONG_INLINE InnerIterator(const SparseDenseOuterProduct& prod, Index outer)
00131       : Base(prod.lhs(), 0), m_outer(outer), m_factor(prod.rhs().coeff(outer))
00132     {
00133     }
00134 
00135     inline Index outer() const { return m_outer; }
00136     inline Index row() const { return Transpose ? Base::row() : m_outer; }
00137     inline Index col() const { return Transpose ? m_outer : Base::row(); }
00138 
00139     inline Scalar value() const { return Base::value() * m_factor; }
00140 
00141   protected:
00142     int m_outer;
00143     Scalar m_factor;
00144 };
00145 
00146 namespace internal {
00147 template<typename Lhs, typename Rhs>
00148 struct traits<SparseTimeDenseProduct<Lhs,Rhs> >
00149  : traits<ProductBase<SparseTimeDenseProduct<Lhs,Rhs>, Lhs, Rhs> >
00150 {
00151   typedef Dense StorageKind;
00152   typedef MatrixXpr XprKind;
00153 };
00154 
00155 template<typename SparseLhsType, typename DenseRhsType, typename DenseResType,
00156          int LhsStorageOrder = ((SparseLhsType::Flags&RowMajorBit)==RowMajorBit) ? RowMajor : ColMajor,
00157          bool ColPerCol = ((DenseRhsType::Flags&RowMajorBit)==0) || DenseRhsType::ColsAtCompileTime==1>
00158 struct sparse_time_dense_product_impl;
00159 
00160 template<typename SparseLhsType, typename DenseRhsType, typename DenseResType>
00161 struct sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType, RowMajor, true>
00162 {
00163   typedef typename internal::remove_all<SparseLhsType>::type Lhs;
00164   typedef typename internal::remove_all<DenseRhsType>::type Rhs;
00165   typedef typename internal::remove_all<DenseResType>::type Res;
00166   typedef typename Lhs::Index Index;
00167   typedef typename Lhs::InnerIterator LhsInnerIterator;
00168   static void run(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, typename Res::Scalar alpha)
00169   {
00170     for(Index c=0; c<rhs.cols(); ++c)
00171     {
00172       int n = lhs.outerSize();
00173       for(Index j=0; j<n; ++j)
00174       {
00175         typename Res::Scalar tmp(0);
00176         for(LhsInnerIterator it(lhs,j); it ;++it)
00177           tmp += it.value() * rhs.coeff(it.index(),c);
00178         res.coeffRef(j,c) = alpha * tmp;
00179       }
00180     }
00181   }
00182 };
00183 
00184 template<typename SparseLhsType, typename DenseRhsType, typename DenseResType>
00185 struct sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType, ColMajor, true>
00186 {
00187   typedef typename internal::remove_all<SparseLhsType>::type Lhs;
00188   typedef typename internal::remove_all<DenseRhsType>::type Rhs;
00189   typedef typename internal::remove_all<DenseResType>::type Res;
00190   typedef typename Lhs::InnerIterator LhsInnerIterator;
00191   typedef typename Lhs::Index Index;
00192   static void run(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, typename Res::Scalar alpha)
00193   {
00194     for(Index c=0; c<rhs.cols(); ++c)
00195     {
00196       for(Index j=0; j<lhs.outerSize(); ++j)
00197       {
00198         typename Res::Scalar rhs_j = alpha * rhs.coeff(j,c);
00199         for(LhsInnerIterator it(lhs,j); it ;++it)
00200           res.coeffRef(it.index(),c) += it.value() * rhs_j;
00201       }
00202     }
00203   }
00204 };
00205 
00206 template<typename SparseLhsType, typename DenseRhsType, typename DenseResType>
00207 struct sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType, RowMajor, false>
00208 {
00209   typedef typename internal::remove_all<SparseLhsType>::type Lhs;
00210   typedef typename internal::remove_all<DenseRhsType>::type Rhs;
00211   typedef typename internal::remove_all<DenseResType>::type Res;
00212   typedef typename Lhs::InnerIterator LhsInnerIterator;
00213   typedef typename Lhs::Index Index;
00214   static void run(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, typename Res::Scalar alpha)
00215   {
00216     for(Index j=0; j<lhs.outerSize(); ++j)
00217     {
00218       typename Res::RowXpr res_j(res.row(j));
00219       for(LhsInnerIterator it(lhs,j); it ;++it)
00220         res_j += (alpha*it.value()) * rhs.row(it.index());
00221     }
00222   }
00223 };
00224 
00225 template<typename SparseLhsType, typename DenseRhsType, typename DenseResType>
00226 struct sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType, ColMajor, false>
00227 {
00228   typedef typename internal::remove_all<SparseLhsType>::type Lhs;
00229   typedef typename internal::remove_all<DenseRhsType>::type Rhs;
00230   typedef typename internal::remove_all<DenseResType>::type Res;
00231   typedef typename Lhs::InnerIterator LhsInnerIterator;
00232   typedef typename Lhs::Index Index;
00233   static void run(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, typename Res::Scalar alpha)
00234   {
00235     for(Index j=0; j<lhs.outerSize(); ++j)
00236     {
00237       typename Rhs::ConstRowXpr rhs_j(rhs.row(j));
00238       for(LhsInnerIterator it(lhs,j); it ;++it)
00239         res.row(it.index()) += (alpha*it.value()) * rhs_j;
00240     }
00241   }
00242 };
00243 
00244 template<typename SparseLhsType, typename DenseRhsType, typename DenseResType,typename AlphaType>
00245 inline void sparse_time_dense_product(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, const AlphaType& alpha)
00246 {
00247   sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType>::run(lhs, rhs, res, alpha);
00248 }
00249 
00250 } // end namespace internal
00251 
00252 template<typename Lhs, typename Rhs>
00253 class SparseTimeDenseProduct
00254   : public ProductBase<SparseTimeDenseProduct<Lhs,Rhs>, Lhs, Rhs>
00255 {
00256   public:
00257     EIGEN_PRODUCT_PUBLIC_INTERFACE(SparseTimeDenseProduct)
00258 
00259     SparseTimeDenseProduct(const Lhs& lhs, const Rhs& rhs) : Base(lhs,rhs)
00260     {}
00261 
00262     template<typename Dest> void scaleAndAddTo(Dest& dest, Scalar alpha) const
00263     {
00264       internal::sparse_time_dense_product(m_lhs, m_rhs, dest, alpha);
00265     }
00266 
00267   private:
00268     SparseTimeDenseProduct& operator=(const SparseTimeDenseProduct&);
00269 };
00270 
00271 
00272 // dense = dense * sparse
00273 namespace internal {
00274 template<typename Lhs, typename Rhs>
00275 struct traits<DenseTimeSparseProduct<Lhs,Rhs> >
00276  : traits<ProductBase<DenseTimeSparseProduct<Lhs,Rhs>, Lhs, Rhs> >
00277 {
00278   typedef Dense StorageKind;
00279 };
00280 } // end namespace internal
00281 
00282 template<typename Lhs, typename Rhs>
00283 class DenseTimeSparseProduct
00284   : public ProductBase<DenseTimeSparseProduct<Lhs,Rhs>, Lhs, Rhs>
00285 {
00286   public:
00287     EIGEN_PRODUCT_PUBLIC_INTERFACE(DenseTimeSparseProduct)
00288 
00289     DenseTimeSparseProduct(const Lhs& lhs, const Rhs& rhs) : Base(lhs,rhs)
00290     {}
00291 
00292     template<typename Dest> void scaleAndAddTo(Dest& dest, Scalar alpha) const
00293     {
00294       Transpose<const _LhsNested> lhs_t(m_lhs);
00295       Transpose<const _RhsNested> rhs_t(m_rhs);
00296       Transpose<Dest> dest_t(dest);
00297       internal::sparse_time_dense_product(rhs_t, lhs_t, dest_t, alpha);
00298     }
00299 
00300   private:
00301     DenseTimeSparseProduct& operator=(const DenseTimeSparseProduct&);
00302 };
00303 
00304 // sparse * dense
00305 template<typename Derived>
00306 template<typename OtherDerived>
00307 inline const typename SparseDenseProductReturnType<Derived,OtherDerived>::Type
00308 SparseMatrixBase<Derived>::operator*(const MatrixBase<OtherDerived> &other) const
00309 {
00310   return typename SparseDenseProductReturnType<Derived,OtherDerived>::Type(derived(), other.derived());
00311 }
00312 
00313 } // end namespace Eigen
00314 
00315 #endif // EIGEN_SPARSEDENSEPRODUCT_H