SparseSparseProductWithPruning.h
Go to the documentation of this file.
00001 // This file is part of Eigen, a lightweight C++ template library
00002 // for linear algebra.
00003 //
00004 // Copyright (C) 2008-2011 Gael Guennebaud <gael.guennebaud@inria.fr>
00005 //
00006 // Eigen is free software; you can redistribute it and/or
00007 // modify it under the terms of the GNU Lesser General Public
00008 // License as published by the Free Software Foundation; either
00009 // version 3 of the License, or (at your option) any later version.
00010 //
00011 // Alternatively, you can redistribute it and/or
00012 // modify it under the terms of the GNU General Public License as
00013 // published by the Free Software Foundation; either version 2 of
00014 // the License, or (at your option) any later version.
00015 //
00016 // Eigen is distributed in the hope that it will be useful, but WITHOUT ANY
00017 // WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
00018 // FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License or the
00019 // GNU General Public License for more details.
00020 //
00021 // You should have received a copy of the GNU Lesser General Public
00022 // License and a copy of the GNU General Public License along with
00023 // Eigen. If not, see <http://www.gnu.org/licenses/>.
00024 
00025 #ifndef EIGEN_SPARSESPARSEPRODUCTWITHPRUNING_H
00026 #define EIGEN_SPARSESPARSEPRODUCTWITHPRUNING_H
00027 
00028 namespace Eigen { 
00029 
00030 namespace internal {
00031 
00032 
00033 // perform a pseudo in-place sparse * sparse product assuming all matrices are col major
00034 template<typename Lhs, typename Rhs, typename ResultType>
00035 static void sparse_sparse_product_with_pruning_impl(const Lhs& lhs, const Rhs& rhs, ResultType& res, typename ResultType::RealScalar tolerance)
00036 {
00037   // return sparse_sparse_product_with_pruning_impl2(lhs,rhs,res);
00038 
00039   typedef typename remove_all<Lhs>::type::Scalar Scalar;
00040   typedef typename remove_all<Lhs>::type::Index Index;
00041 
00042   // make sure to call innerSize/outerSize since we fake the storage order.
00043   Index rows = lhs.innerSize();
00044   Index cols = rhs.outerSize();
00045   //int size = lhs.outerSize();
00046   eigen_assert(lhs.outerSize() == rhs.innerSize());
00047 
00048   // allocate a temporary buffer
00049   AmbiVector<Scalar,Index> tempVector(rows);
00050 
00051   // estimate the number of non zero entries
00052   // given a rhs column containing Y non zeros, we assume that the respective Y columns
00053   // of the lhs differs in average of one non zeros, thus the number of non zeros for
00054   // the product of a rhs column with the lhs is X+Y where X is the average number of non zero
00055   // per column of the lhs.
00056   // Therefore, we have nnz(lhs*rhs) = nnz(lhs) + nnz(rhs)
00057   Index estimated_nnz_prod = lhs.nonZeros() + rhs.nonZeros();
00058 
00059   // mimics a resizeByInnerOuter:
00060   if(ResultType::IsRowMajor)
00061     res.resize(cols, rows);
00062   else
00063     res.resize(rows, cols);
00064 
00065   res.reserve(estimated_nnz_prod);
00066   double ratioColRes = double(estimated_nnz_prod)/double(lhs.rows()*rhs.cols());
00067   for (Index j=0; j<cols; ++j)
00068   {
00069     // FIXME:
00070     //double ratioColRes = (double(rhs.innerVector(j).nonZeros()) + double(lhs.nonZeros())/double(lhs.cols()))/double(lhs.rows());
00071     // let's do a more accurate determination of the nnz ratio for the current column j of res
00072     tempVector.init(ratioColRes);
00073     tempVector.setZero();
00074     for (typename Rhs::InnerIterator rhsIt(rhs, j); rhsIt; ++rhsIt)
00075     {
00076       // FIXME should be written like this: tmp += rhsIt.value() * lhs.col(rhsIt.index())
00077       tempVector.restart();
00078       Scalar x = rhsIt.value();
00079       for (typename Lhs::InnerIterator lhsIt(lhs, rhsIt.index()); lhsIt; ++lhsIt)
00080       {
00081         tempVector.coeffRef(lhsIt.index()) += lhsIt.value() * x;
00082       }
00083     }
00084     res.startVec(j);
00085     for (typename AmbiVector<Scalar,Index>::Iterator it(tempVector,tolerance); it; ++it)
00086       res.insertBackByOuterInner(j,it.index()) = it.value();
00087   }
00088   res.finalize();
00089 }
00090 
00091 template<typename Lhs, typename Rhs, typename ResultType,
00092   int LhsStorageOrder = traits<Lhs>::Flags&RowMajorBit,
00093   int RhsStorageOrder = traits<Rhs>::Flags&RowMajorBit,
00094   int ResStorageOrder = traits<ResultType>::Flags&RowMajorBit>
00095 struct sparse_sparse_product_with_pruning_selector;
00096 
00097 template<typename Lhs, typename Rhs, typename ResultType>
00098 struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,ColMajor,ColMajor,ColMajor>
00099 {
00100   typedef typename traits<typename remove_all<Lhs>::type>::Scalar Scalar;
00101   typedef typename ResultType::RealScalar RealScalar;
00102 
00103   static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, RealScalar tolerance)
00104   {
00105     typename remove_all<ResultType>::type _res(res.rows(), res.cols());
00106     internal::sparse_sparse_product_with_pruning_impl<Lhs,Rhs,ResultType>(lhs, rhs, _res, tolerance);
00107     res.swap(_res);
00108   }
00109 };
00110 
00111 template<typename Lhs, typename Rhs, typename ResultType>
00112 struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,ColMajor,ColMajor,RowMajor>
00113 {
00114   typedef typename ResultType::RealScalar RealScalar;
00115   static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, RealScalar tolerance)
00116   {
00117     // we need a col-major matrix to hold the result
00118     typedef SparseMatrix<typename ResultType::Scalar> SparseTemporaryType;
00119     SparseTemporaryType _res(res.rows(), res.cols());
00120     internal::sparse_sparse_product_with_pruning_impl<Lhs,Rhs,SparseTemporaryType>(lhs, rhs, _res, tolerance);
00121     res = _res;
00122   }
00123 };
00124 
00125 template<typename Lhs, typename Rhs, typename ResultType>
00126 struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,RowMajor,RowMajor,RowMajor>
00127 {
00128   typedef typename ResultType::RealScalar RealScalar;
00129   static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, RealScalar tolerance)
00130   {
00131     // let's transpose the product to get a column x column product
00132     typename remove_all<ResultType>::type _res(res.rows(), res.cols());
00133     internal::sparse_sparse_product_with_pruning_impl<Rhs,Lhs,ResultType>(rhs, lhs, _res, tolerance);
00134     res.swap(_res);
00135   }
00136 };
00137 
00138 template<typename Lhs, typename Rhs, typename ResultType>
00139 struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,RowMajor,RowMajor,ColMajor>
00140 {
00141   typedef typename ResultType::RealScalar RealScalar;
00142   static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, RealScalar tolerance)
00143   {
00144     typedef SparseMatrix<typename ResultType::Scalar,ColMajor> ColMajorMatrix;
00145     ColMajorMatrix colLhs(lhs);
00146     ColMajorMatrix colRhs(rhs);
00147     internal::sparse_sparse_product_with_pruning_impl<ColMajorMatrix,ColMajorMatrix,ResultType>(colLhs, colRhs, res, tolerance);
00148 
00149     // let's transpose the product to get a column x column product
00150 //     typedef SparseMatrix<typename ResultType::Scalar> SparseTemporaryType;
00151 //     SparseTemporaryType _res(res.cols(), res.rows());
00152 //     sparse_sparse_product_with_pruning_impl<Rhs,Lhs,SparseTemporaryType>(rhs, lhs, _res);
00153 //     res = _res.transpose();
00154   }
00155 };
00156 
// NOTE: the two other cases (col row *) must never occur, since they are caught
// by ProductReturnType, which transforms them into (col col *) by evaluating the rhs.
00159 
00160 } // end namespace internal
00161 
00162 } // end namespace Eigen
00163 
00164 #endif // EIGEN_SPARSESPARSEPRODUCTWITHPRUNING_H