00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025 #ifndef EIGEN_SPARSEDENSEPRODUCT_H
00026 #define EIGEN_SPARSEDENSEPRODUCT_H
00027
00028 namespace Eigen {
00029
00030 template<typename Lhs, typename Rhs, int InnerSize> struct SparseDenseProductReturnType
00031 {
00032 typedef SparseTimeDenseProduct<Lhs,Rhs> Type;
00033 };
00034
00035 template<typename Lhs, typename Rhs> struct SparseDenseProductReturnType<Lhs,Rhs,1>
00036 {
00037 typedef SparseDenseOuterProduct<Lhs,Rhs,false> Type;
00038 };
00039
00040 template<typename Lhs, typename Rhs, int InnerSize> struct DenseSparseProductReturnType
00041 {
00042 typedef DenseTimeSparseProduct<Lhs,Rhs> Type;
00043 };
00044
00045 template<typename Lhs, typename Rhs> struct DenseSparseProductReturnType<Lhs,Rhs,1>
00046 {
00047 typedef SparseDenseOuterProduct<Rhs,Lhs,true> Type;
00048 };
00049
00050 namespace internal {
00051
00052 template<typename Lhs, typename Rhs, bool Tr>
00053 struct traits<SparseDenseOuterProduct<Lhs,Rhs,Tr> >
00054 {
00055 typedef Sparse StorageKind;
00056 typedef typename scalar_product_traits<typename traits<Lhs>::Scalar,
00057 typename traits<Rhs>::Scalar>::ReturnType Scalar;
00058 typedef typename Lhs::Index Index;
00059 typedef typename Lhs::Nested LhsNested;
00060 typedef typename Rhs::Nested RhsNested;
00061 typedef typename remove_all<LhsNested>::type _LhsNested;
00062 typedef typename remove_all<RhsNested>::type _RhsNested;
00063
00064 enum {
00065 LhsCoeffReadCost = traits<_LhsNested>::CoeffReadCost,
00066 RhsCoeffReadCost = traits<_RhsNested>::CoeffReadCost,
00067
00068 RowsAtCompileTime = Tr ? int(traits<Rhs>::RowsAtCompileTime) : int(traits<Lhs>::RowsAtCompileTime),
00069 ColsAtCompileTime = Tr ? int(traits<Lhs>::ColsAtCompileTime) : int(traits<Rhs>::ColsAtCompileTime),
00070 MaxRowsAtCompileTime = Tr ? int(traits<Rhs>::MaxRowsAtCompileTime) : int(traits<Lhs>::MaxRowsAtCompileTime),
00071 MaxColsAtCompileTime = Tr ? int(traits<Lhs>::MaxColsAtCompileTime) : int(traits<Rhs>::MaxColsAtCompileTime),
00072
00073 Flags = Tr ? RowMajorBit : 0,
00074
00075 CoeffReadCost = LhsCoeffReadCost + RhsCoeffReadCost + NumTraits<Scalar>::MulCost
00076 };
00077 };
00078
00079 }
00080
00081 template<typename Lhs, typename Rhs, bool Tr>
00082 class SparseDenseOuterProduct
00083 : public SparseMatrixBase<SparseDenseOuterProduct<Lhs,Rhs,Tr> >
00084 {
00085 public:
00086
00087 typedef SparseMatrixBase<SparseDenseOuterProduct> Base;
00088 EIGEN_DENSE_PUBLIC_INTERFACE(SparseDenseOuterProduct)
00089 typedef internal::traits<SparseDenseOuterProduct> Traits;
00090
00091 private:
00092
00093 typedef typename Traits::LhsNested LhsNested;
00094 typedef typename Traits::RhsNested RhsNested;
00095 typedef typename Traits::_LhsNested _LhsNested;
00096 typedef typename Traits::_RhsNested _RhsNested;
00097
00098 public:
00099
00100 class InnerIterator;
00101
00102 EIGEN_STRONG_INLINE SparseDenseOuterProduct(const Lhs& lhs, const Rhs& rhs)
00103 : m_lhs(lhs), m_rhs(rhs)
00104 {
00105 EIGEN_STATIC_ASSERT(!Tr,YOU_MADE_A_PROGRAMMING_MISTAKE);
00106 }
00107
00108 EIGEN_STRONG_INLINE SparseDenseOuterProduct(const Rhs& rhs, const Lhs& lhs)
00109 : m_lhs(lhs), m_rhs(rhs)
00110 {
00111 EIGEN_STATIC_ASSERT(Tr,YOU_MADE_A_PROGRAMMING_MISTAKE);
00112 }
00113
00114 EIGEN_STRONG_INLINE Index rows() const { return Tr ? m_rhs.rows() : m_lhs.rows(); }
00115 EIGEN_STRONG_INLINE Index cols() const { return Tr ? m_lhs.cols() : m_rhs.cols(); }
00116
00117 EIGEN_STRONG_INLINE const _LhsNested& lhs() const { return m_lhs; }
00118 EIGEN_STRONG_INLINE const _RhsNested& rhs() const { return m_rhs; }
00119
00120 protected:
00121 LhsNested m_lhs;
00122 RhsNested m_rhs;
00123 };
00124
00125 template<typename Lhs, typename Rhs, bool Transpose>
00126 class SparseDenseOuterProduct<Lhs,Rhs,Transpose>::InnerIterator : public _LhsNested::InnerIterator
00127 {
00128 typedef typename _LhsNested::InnerIterator Base;
00129 public:
00130 EIGEN_STRONG_INLINE InnerIterator(const SparseDenseOuterProduct& prod, Index outer)
00131 : Base(prod.lhs(), 0), m_outer(outer), m_factor(prod.rhs().coeff(outer))
00132 {
00133 }
00134
00135 inline Index outer() const { return m_outer; }
00136 inline Index row() const { return Transpose ? Base::row() : m_outer; }
00137 inline Index col() const { return Transpose ? m_outer : Base::row(); }
00138
00139 inline Scalar value() const { return Base::value() * m_factor; }
00140
00141 protected:
00142 int m_outer;
00143 Scalar m_factor;
00144 };
00145
00146 namespace internal {
00147 template<typename Lhs, typename Rhs>
00148 struct traits<SparseTimeDenseProduct<Lhs,Rhs> >
00149 : traits<ProductBase<SparseTimeDenseProduct<Lhs,Rhs>, Lhs, Rhs> >
00150 {
00151 typedef Dense StorageKind;
00152 typedef MatrixXpr XprKind;
00153 };
00154
00155 template<typename SparseLhsType, typename DenseRhsType, typename DenseResType,
00156 int LhsStorageOrder = ((SparseLhsType::Flags&RowMajorBit)==RowMajorBit) ? RowMajor : ColMajor,
00157 bool ColPerCol = ((DenseRhsType::Flags&RowMajorBit)==0) || DenseRhsType::ColsAtCompileTime==1>
00158 struct sparse_time_dense_product_impl;
00159
00160 template<typename SparseLhsType, typename DenseRhsType, typename DenseResType>
00161 struct sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType, RowMajor, true>
00162 {
00163 typedef typename internal::remove_all<SparseLhsType>::type Lhs;
00164 typedef typename internal::remove_all<DenseRhsType>::type Rhs;
00165 typedef typename internal::remove_all<DenseResType>::type Res;
00166 typedef typename Lhs::Index Index;
00167 typedef typename Lhs::InnerIterator LhsInnerIterator;
00168 static void run(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, typename Res::Scalar alpha)
00169 {
00170 for(Index c=0; c<rhs.cols(); ++c)
00171 {
00172 int n = lhs.outerSize();
00173 for(Index j=0; j<n; ++j)
00174 {
00175 typename Res::Scalar tmp(0);
00176 for(LhsInnerIterator it(lhs,j); it ;++it)
00177 tmp += it.value() * rhs.coeff(it.index(),c);
00178 res.coeffRef(j,c) = alpha * tmp;
00179 }
00180 }
00181 }
00182 };
00183
00184 template<typename SparseLhsType, typename DenseRhsType, typename DenseResType>
00185 struct sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType, ColMajor, true>
00186 {
00187 typedef typename internal::remove_all<SparseLhsType>::type Lhs;
00188 typedef typename internal::remove_all<DenseRhsType>::type Rhs;
00189 typedef typename internal::remove_all<DenseResType>::type Res;
00190 typedef typename Lhs::InnerIterator LhsInnerIterator;
00191 typedef typename Lhs::Index Index;
00192 static void run(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, typename Res::Scalar alpha)
00193 {
00194 for(Index c=0; c<rhs.cols(); ++c)
00195 {
00196 for(Index j=0; j<lhs.outerSize(); ++j)
00197 {
00198 typename Res::Scalar rhs_j = alpha * rhs.coeff(j,c);
00199 for(LhsInnerIterator it(lhs,j); it ;++it)
00200 res.coeffRef(it.index(),c) += it.value() * rhs_j;
00201 }
00202 }
00203 }
00204 };
00205
00206 template<typename SparseLhsType, typename DenseRhsType, typename DenseResType>
00207 struct sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType, RowMajor, false>
00208 {
00209 typedef typename internal::remove_all<SparseLhsType>::type Lhs;
00210 typedef typename internal::remove_all<DenseRhsType>::type Rhs;
00211 typedef typename internal::remove_all<DenseResType>::type Res;
00212 typedef typename Lhs::InnerIterator LhsInnerIterator;
00213 typedef typename Lhs::Index Index;
00214 static void run(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, typename Res::Scalar alpha)
00215 {
00216 for(Index j=0; j<lhs.outerSize(); ++j)
00217 {
00218 typename Res::RowXpr res_j(res.row(j));
00219 for(LhsInnerIterator it(lhs,j); it ;++it)
00220 res_j += (alpha*it.value()) * rhs.row(it.index());
00221 }
00222 }
00223 };
00224
00225 template<typename SparseLhsType, typename DenseRhsType, typename DenseResType>
00226 struct sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType, ColMajor, false>
00227 {
00228 typedef typename internal::remove_all<SparseLhsType>::type Lhs;
00229 typedef typename internal::remove_all<DenseRhsType>::type Rhs;
00230 typedef typename internal::remove_all<DenseResType>::type Res;
00231 typedef typename Lhs::InnerIterator LhsInnerIterator;
00232 typedef typename Lhs::Index Index;
00233 static void run(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, typename Res::Scalar alpha)
00234 {
00235 for(Index j=0; j<lhs.outerSize(); ++j)
00236 {
00237 typename Rhs::ConstRowXpr rhs_j(rhs.row(j));
00238 for(LhsInnerIterator it(lhs,j); it ;++it)
00239 res.row(it.index()) += (alpha*it.value()) * rhs_j;
00240 }
00241 }
00242 };
00243
00244 template<typename SparseLhsType, typename DenseRhsType, typename DenseResType,typename AlphaType>
00245 inline void sparse_time_dense_product(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, const AlphaType& alpha)
00246 {
00247 sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType>::run(lhs, rhs, res, alpha);
00248 }
00249
00250 }
00251
00252 template<typename Lhs, typename Rhs>
00253 class SparseTimeDenseProduct
00254 : public ProductBase<SparseTimeDenseProduct<Lhs,Rhs>, Lhs, Rhs>
00255 {
00256 public:
00257 EIGEN_PRODUCT_PUBLIC_INTERFACE(SparseTimeDenseProduct)
00258
00259 SparseTimeDenseProduct(const Lhs& lhs, const Rhs& rhs) : Base(lhs,rhs)
00260 {}
00261
00262 template<typename Dest> void scaleAndAddTo(Dest& dest, Scalar alpha) const
00263 {
00264 internal::sparse_time_dense_product(m_lhs, m_rhs, dest, alpha);
00265 }
00266
00267 private:
00268 SparseTimeDenseProduct& operator=(const SparseTimeDenseProduct&);
00269 };
00270
00271
00272
00273 namespace internal {
00274 template<typename Lhs, typename Rhs>
00275 struct traits<DenseTimeSparseProduct<Lhs,Rhs> >
00276 : traits<ProductBase<DenseTimeSparseProduct<Lhs,Rhs>, Lhs, Rhs> >
00277 {
00278 typedef Dense StorageKind;
00279 };
00280 }
00281
00282 template<typename Lhs, typename Rhs>
00283 class DenseTimeSparseProduct
00284 : public ProductBase<DenseTimeSparseProduct<Lhs,Rhs>, Lhs, Rhs>
00285 {
00286 public:
00287 EIGEN_PRODUCT_PUBLIC_INTERFACE(DenseTimeSparseProduct)
00288
00289 DenseTimeSparseProduct(const Lhs& lhs, const Rhs& rhs) : Base(lhs,rhs)
00290 {}
00291
00292 template<typename Dest> void scaleAndAddTo(Dest& dest, Scalar alpha) const
00293 {
00294 Transpose<const _LhsNested> lhs_t(m_lhs);
00295 Transpose<const _RhsNested> rhs_t(m_rhs);
00296 Transpose<Dest> dest_t(dest);
00297 internal::sparse_time_dense_product(rhs_t, lhs_t, dest_t, alpha);
00298 }
00299
00300 private:
00301 DenseTimeSparseProduct& operator=(const DenseTimeSparseProduct&);
00302 };
00303
00304
00305 template<typename Derived>
00306 template<typename OtherDerived>
00307 inline const typename SparseDenseProductReturnType<Derived,OtherDerived>::Type
00308 SparseMatrixBase<Derived>::operator*(const MatrixBase<OtherDerived> &other) const
00309 {
00310 return typename SparseDenseProductReturnType<Derived,OtherDerived>::Type(derived(), other.derived());
00311 }
00312
00313 }
00314
00315 #endif // EIGEN_SPARSEDENSEPRODUCT_H