00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025 #ifndef EIGEN_SPARSESPARSEPRODUCTWITHPRUNING_H
00026 #define EIGEN_SPARSESPARSEPRODUCTWITHPRUNING_H
00027
00028 namespace Eigen {
00029
00030 namespace internal {
00031
00032
00033
00034 template<typename Lhs, typename Rhs, typename ResultType>
00035 static void sparse_sparse_product_with_pruning_impl(const Lhs& lhs, const Rhs& rhs, ResultType& res, typename ResultType::RealScalar tolerance)
00036 {
00037
00038
00039 typedef typename remove_all<Lhs>::type::Scalar Scalar;
00040 typedef typename remove_all<Lhs>::type::Index Index;
00041
00042
00043 Index rows = lhs.innerSize();
00044 Index cols = rhs.outerSize();
00045
00046 eigen_assert(lhs.outerSize() == rhs.innerSize());
00047
00048
00049 AmbiVector<Scalar,Index> tempVector(rows);
00050
00051
00052
00053
00054
00055
00056
00057 Index estimated_nnz_prod = lhs.nonZeros() + rhs.nonZeros();
00058
00059
00060 if(ResultType::IsRowMajor)
00061 res.resize(cols, rows);
00062 else
00063 res.resize(rows, cols);
00064
00065 res.reserve(estimated_nnz_prod);
00066 double ratioColRes = double(estimated_nnz_prod)/double(lhs.rows()*rhs.cols());
00067 for (Index j=0; j<cols; ++j)
00068 {
00069
00070
00071
00072 tempVector.init(ratioColRes);
00073 tempVector.setZero();
00074 for (typename Rhs::InnerIterator rhsIt(rhs, j); rhsIt; ++rhsIt)
00075 {
00076
00077 tempVector.restart();
00078 Scalar x = rhsIt.value();
00079 for (typename Lhs::InnerIterator lhsIt(lhs, rhsIt.index()); lhsIt; ++lhsIt)
00080 {
00081 tempVector.coeffRef(lhsIt.index()) += lhsIt.value() * x;
00082 }
00083 }
00084 res.startVec(j);
00085 for (typename AmbiVector<Scalar,Index>::Iterator it(tempVector,tolerance); it; ++it)
00086 res.insertBackByOuterInner(j,it.index()) = it.value();
00087 }
00088 res.finalize();
00089 }
00090
00091 template<typename Lhs, typename Rhs, typename ResultType,
00092 int LhsStorageOrder = traits<Lhs>::Flags&RowMajorBit,
00093 int RhsStorageOrder = traits<Rhs>::Flags&RowMajorBit,
00094 int ResStorageOrder = traits<ResultType>::Flags&RowMajorBit>
00095 struct sparse_sparse_product_with_pruning_selector;
00096
00097 template<typename Lhs, typename Rhs, typename ResultType>
00098 struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,ColMajor,ColMajor,ColMajor>
00099 {
00100 typedef typename traits<typename remove_all<Lhs>::type>::Scalar Scalar;
00101 typedef typename ResultType::RealScalar RealScalar;
00102
00103 static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, RealScalar tolerance)
00104 {
00105 typename remove_all<ResultType>::type _res(res.rows(), res.cols());
00106 internal::sparse_sparse_product_with_pruning_impl<Lhs,Rhs,ResultType>(lhs, rhs, _res, tolerance);
00107 res.swap(_res);
00108 }
00109 };
00110
00111 template<typename Lhs, typename Rhs, typename ResultType>
00112 struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,ColMajor,ColMajor,RowMajor>
00113 {
00114 typedef typename ResultType::RealScalar RealScalar;
00115 static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, RealScalar tolerance)
00116 {
00117
00118 typedef SparseMatrix<typename ResultType::Scalar> SparseTemporaryType;
00119 SparseTemporaryType _res(res.rows(), res.cols());
00120 internal::sparse_sparse_product_with_pruning_impl<Lhs,Rhs,SparseTemporaryType>(lhs, rhs, _res, tolerance);
00121 res = _res;
00122 }
00123 };
00124
00125 template<typename Lhs, typename Rhs, typename ResultType>
00126 struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,RowMajor,RowMajor,RowMajor>
00127 {
00128 typedef typename ResultType::RealScalar RealScalar;
00129 static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, RealScalar tolerance)
00130 {
00131
00132 typename remove_all<ResultType>::type _res(res.rows(), res.cols());
00133 internal::sparse_sparse_product_with_pruning_impl<Rhs,Lhs,ResultType>(rhs, lhs, _res, tolerance);
00134 res.swap(_res);
00135 }
00136 };
00137
00138 template<typename Lhs, typename Rhs, typename ResultType>
00139 struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,RowMajor,RowMajor,ColMajor>
00140 {
00141 typedef typename ResultType::RealScalar RealScalar;
00142 static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, RealScalar tolerance)
00143 {
00144 typedef SparseMatrix<typename ResultType::Scalar,ColMajor> ColMajorMatrix;
00145 ColMajorMatrix colLhs(lhs);
00146 ColMajorMatrix colRhs(rhs);
00147 internal::sparse_sparse_product_with_pruning_impl<ColMajorMatrix,ColMajorMatrix,ResultType>(colLhs, colRhs, res, tolerance);
00148
00149
00150
00151
00152
00153
00154 }
00155 };
00156
00157
00158
00159
00160 }
00161
00162 }
00163
00164 #endif // EIGEN_SPARSESPARSEPRODUCTWITHPRUNING_H