SparseCwiseBinaryOp.h
Go to the documentation of this file.
00001 // This file is part of Eigen, a lightweight C++ template library
00002 // for linear algebra.
00003 //
00004 // Copyright (C) 2008 Gael Guennebaud <gael.guennebaud@inria.fr>
00005 //
00006 // Eigen is free software; you can redistribute it and/or
00007 // modify it under the terms of the GNU Lesser General Public
00008 // License as published by the Free Software Foundation; either
00009 // version 3 of the License, or (at your option) any later version.
00010 //
00011 // Alternatively, you can redistribute it and/or
00012 // modify it under the terms of the GNU General Public License as
00013 // published by the Free Software Foundation; either version 2 of
00014 // the License, or (at your option) any later version.
00015 //
00016 // Eigen is distributed in the hope that it will be useful, but WITHOUT ANY
00017 // WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
00018 // FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License or the
00019 // GNU General Public License for more details.
00020 //
00021 // You should have received a copy of the GNU Lesser General Public
00022 // License and a copy of the GNU General Public License along with
00023 // Eigen. If not, see <http://www.gnu.org/licenses/>.
00024 
00025 #ifndef EIGEN_SPARSE_CWISE_BINARY_OP_H
00026 #define EIGEN_SPARSE_CWISE_BINARY_OP_H
00027 
00028 namespace Eigen { 
00029 
00030 // Here we have to handle 3 cases:
00031 //  1 - sparse op dense
00032 //  2 - dense op sparse
00033 //  3 - sparse op sparse
00034 // We also need to implement a 4th iterator for:
00035 //  4 - dense op dense
// Finally, we also need to distinguish between the product and the other operations:
00037 //                configuration      returned mode
00038 //  1 - sparse op dense    product      sparse
00039 //                         generic      dense
00040 //  2 - dense op sparse    product      sparse
00041 //                         generic      dense
00042 //  3 - sparse op sparse   product      sparse
00043 //                         generic      sparse
00044 //  4 - dense op dense     product      dense
00045 //                         generic      dense
00046 
00047 namespace internal {
00048 
00049 template<> struct promote_storage_type<Dense,Sparse>
00050 { typedef Sparse ret; };
00051 
00052 template<> struct promote_storage_type<Sparse,Dense>
00053 { typedef Sparse ret; };
00054 
00055 template<typename BinaryOp, typename Lhs, typename Rhs, typename Derived,
00056   typename _LhsStorageMode = typename traits<Lhs>::StorageKind,
00057   typename _RhsStorageMode = typename traits<Rhs>::StorageKind>
00058 class sparse_cwise_binary_op_inner_iterator_selector;
00059 
00060 } // end namespace internal
00061 
00062 template<typename BinaryOp, typename Lhs, typename Rhs>
00063 class CwiseBinaryOpImpl<BinaryOp, Lhs, Rhs, Sparse>
00064   : public SparseMatrixBase<CwiseBinaryOp<BinaryOp, Lhs, Rhs> >
00065 {
00066   public:
00067     class InnerIterator;
00068     class ReverseInnerIterator;
00069     typedef CwiseBinaryOp<BinaryOp, Lhs, Rhs> Derived;
00070     EIGEN_SPARSE_PUBLIC_INTERFACE(Derived)
00071     CwiseBinaryOpImpl()
00072     {
00073       typedef typename internal::traits<Lhs>::StorageKind LhsStorageKind;
00074       typedef typename internal::traits<Rhs>::StorageKind RhsStorageKind;
00075       EIGEN_STATIC_ASSERT((
00076                 (!internal::is_same<LhsStorageKind,RhsStorageKind>::value)
00077             ||  ((Lhs::Flags&RowMajorBit) == (Rhs::Flags&RowMajorBit))),
00078             THE_STORAGE_ORDER_OF_BOTH_SIDES_MUST_MATCH);
00079     }
00080 };
00081 
00082 template<typename BinaryOp, typename Lhs, typename Rhs>
00083 class CwiseBinaryOpImpl<BinaryOp,Lhs,Rhs,Sparse>::InnerIterator
00084   : public internal::sparse_cwise_binary_op_inner_iterator_selector<BinaryOp,Lhs,Rhs,typename CwiseBinaryOpImpl<BinaryOp,Lhs,Rhs,Sparse>::InnerIterator>
00085 {
00086   public:
00087     typedef typename Lhs::Index Index;
00088     typedef internal::sparse_cwise_binary_op_inner_iterator_selector<
00089       BinaryOp,Lhs,Rhs, InnerIterator> Base;
00090 
00091     EIGEN_STRONG_INLINE InnerIterator(const CwiseBinaryOpImpl& binOp, typename CwiseBinaryOpImpl::Index outer)
00092       : Base(binOp.derived(),outer)
00093     {}
00094 };
00095 
00096 /***************************************************************************
00097 * Implementation of inner-iterators
00098 ***************************************************************************/
00099 
00100 // template<typename T> struct internal::func_is_conjunction { enum { ret = false }; };
00101 // template<typename T> struct internal::func_is_conjunction<internal::scalar_product_op<T> > { enum { ret = true }; };
00102 
// TODO: generalize the internal::scalar_product_op specializations to all conjunction-like functors, if there are any.
00104 
00105 namespace internal {
00106 
00107 // sparse - sparse  (generic)
00108 template<typename BinaryOp, typename Lhs, typename Rhs, typename Derived>
00109 class sparse_cwise_binary_op_inner_iterator_selector<BinaryOp, Lhs, Rhs, Derived, Sparse, Sparse>
00110 {
00111     typedef CwiseBinaryOp<BinaryOp, Lhs, Rhs> CwiseBinaryXpr;
00112     typedef typename traits<CwiseBinaryXpr>::Scalar Scalar;
00113     typedef typename traits<CwiseBinaryXpr>::_LhsNested _LhsNested;
00114     typedef typename traits<CwiseBinaryXpr>::_RhsNested _RhsNested;
00115     typedef typename _LhsNested::InnerIterator LhsIterator;
00116     typedef typename _RhsNested::InnerIterator RhsIterator;
00117     typedef typename Lhs::Index Index;
00118 
00119   public:
00120 
00121     EIGEN_STRONG_INLINE sparse_cwise_binary_op_inner_iterator_selector(const CwiseBinaryXpr& xpr, Index outer)
00122       : m_lhsIter(xpr.lhs(),outer), m_rhsIter(xpr.rhs(),outer), m_functor(xpr.functor())
00123     {
00124       this->operator++();
00125     }
00126 
00127     EIGEN_STRONG_INLINE Derived& operator++()
00128     {
00129       if (m_lhsIter && m_rhsIter && (m_lhsIter.index() == m_rhsIter.index()))
00130       {
00131         m_id = m_lhsIter.index();
00132         m_value = m_functor(m_lhsIter.value(), m_rhsIter.value());
00133         ++m_lhsIter;
00134         ++m_rhsIter;
00135       }
00136       else if (m_lhsIter && (!m_rhsIter || (m_lhsIter.index() < m_rhsIter.index())))
00137       {
00138         m_id = m_lhsIter.index();
00139         m_value = m_functor(m_lhsIter.value(), Scalar(0));
00140         ++m_lhsIter;
00141       }
00142       else if (m_rhsIter && (!m_lhsIter || (m_lhsIter.index() > m_rhsIter.index())))
00143       {
00144         m_id = m_rhsIter.index();
00145         m_value = m_functor(Scalar(0), m_rhsIter.value());
00146         ++m_rhsIter;
00147       }
00148       else
00149       {
00150         m_value = 0; // this is to avoid a compilation warning
00151         m_id = -1;
00152       }
00153       return *static_cast<Derived*>(this);
00154     }
00155 
00156     EIGEN_STRONG_INLINE Scalar value() const { return m_value; }
00157 
00158     EIGEN_STRONG_INLINE Index index() const { return m_id; }
00159     EIGEN_STRONG_INLINE Index row() const { return Lhs::IsRowMajor ? m_lhsIter.row() : index(); }
00160     EIGEN_STRONG_INLINE Index col() const { return Lhs::IsRowMajor ? index() : m_lhsIter.col(); }
00161 
00162     EIGEN_STRONG_INLINE operator bool() const { return m_id>=0; }
00163 
00164   protected:
00165     LhsIterator m_lhsIter;
00166     RhsIterator m_rhsIter;
00167     const BinaryOp& m_functor;
00168     Scalar m_value;
00169     Index m_id;
00170 };
00171 
00172 // sparse - sparse  (product)
00173 template<typename T, typename Lhs, typename Rhs, typename Derived>
00174 class sparse_cwise_binary_op_inner_iterator_selector<scalar_product_op<T>, Lhs, Rhs, Derived, Sparse, Sparse>
00175 {
00176     typedef scalar_product_op<T> BinaryFunc;
00177     typedef CwiseBinaryOp<BinaryFunc, Lhs, Rhs> CwiseBinaryXpr;
00178     typedef typename CwiseBinaryXpr::Scalar Scalar;
00179     typedef typename traits<CwiseBinaryXpr>::_LhsNested _LhsNested;
00180     typedef typename _LhsNested::InnerIterator LhsIterator;
00181     typedef typename traits<CwiseBinaryXpr>::_RhsNested _RhsNested;
00182     typedef typename _RhsNested::InnerIterator RhsIterator;
00183     typedef typename Lhs::Index Index;
00184   public:
00185 
00186     EIGEN_STRONG_INLINE sparse_cwise_binary_op_inner_iterator_selector(const CwiseBinaryXpr& xpr, Index outer)
00187       : m_lhsIter(xpr.lhs(),outer), m_rhsIter(xpr.rhs(),outer), m_functor(xpr.functor())
00188     {
00189       while (m_lhsIter && m_rhsIter && (m_lhsIter.index() != m_rhsIter.index()))
00190       {
00191         if (m_lhsIter.index() < m_rhsIter.index())
00192           ++m_lhsIter;
00193         else
00194           ++m_rhsIter;
00195       }
00196     }
00197 
00198     EIGEN_STRONG_INLINE Derived& operator++()
00199     {
00200       ++m_lhsIter;
00201       ++m_rhsIter;
00202       while (m_lhsIter && m_rhsIter && (m_lhsIter.index() != m_rhsIter.index()))
00203       {
00204         if (m_lhsIter.index() < m_rhsIter.index())
00205           ++m_lhsIter;
00206         else
00207           ++m_rhsIter;
00208       }
00209       return *static_cast<Derived*>(this);
00210     }
00211 
00212     EIGEN_STRONG_INLINE Scalar value() const { return m_functor(m_lhsIter.value(), m_rhsIter.value()); }
00213 
00214     EIGEN_STRONG_INLINE Index index() const { return m_lhsIter.index(); }
00215     EIGEN_STRONG_INLINE Index row() const { return m_lhsIter.row(); }
00216     EIGEN_STRONG_INLINE Index col() const { return m_lhsIter.col(); }
00217 
00218     EIGEN_STRONG_INLINE operator bool() const { return (m_lhsIter && m_rhsIter); }
00219 
00220   protected:
00221     LhsIterator m_lhsIter;
00222     RhsIterator m_rhsIter;
00223     const BinaryFunc& m_functor;
00224 };
00225 
00226 // sparse - dense  (product)
00227 template<typename T, typename Lhs, typename Rhs, typename Derived>
00228 class sparse_cwise_binary_op_inner_iterator_selector<scalar_product_op<T>, Lhs, Rhs, Derived, Sparse, Dense>
00229 {
00230     typedef scalar_product_op<T> BinaryFunc;
00231     typedef CwiseBinaryOp<BinaryFunc, Lhs, Rhs> CwiseBinaryXpr;
00232     typedef typename CwiseBinaryXpr::Scalar Scalar;
00233     typedef typename traits<CwiseBinaryXpr>::_LhsNested _LhsNested;
00234     typedef typename traits<CwiseBinaryXpr>::RhsNested RhsNested;
00235     typedef typename _LhsNested::InnerIterator LhsIterator;
00236     typedef typename Lhs::Index Index;
00237     enum { IsRowMajor = (int(Lhs::Flags)&RowMajorBit)==RowMajorBit };
00238   public:
00239 
00240     EIGEN_STRONG_INLINE sparse_cwise_binary_op_inner_iterator_selector(const CwiseBinaryXpr& xpr, Index outer)
00241       : m_rhs(xpr.rhs()), m_lhsIter(xpr.lhs(),outer), m_functor(xpr.functor()), m_outer(outer)
00242     {}
00243 
00244     EIGEN_STRONG_INLINE Derived& operator++()
00245     {
00246       ++m_lhsIter;
00247       return *static_cast<Derived*>(this);
00248     }
00249 
00250     EIGEN_STRONG_INLINE Scalar value() const
00251     { return m_functor(m_lhsIter.value(),
00252                        m_rhs.coeff(IsRowMajor?m_outer:m_lhsIter.index(),IsRowMajor?m_lhsIter.index():m_outer)); }
00253 
00254     EIGEN_STRONG_INLINE Index index() const { return m_lhsIter.index(); }
00255     EIGEN_STRONG_INLINE Index row() const { return m_lhsIter.row(); }
00256     EIGEN_STRONG_INLINE Index col() const { return m_lhsIter.col(); }
00257 
00258     EIGEN_STRONG_INLINE operator bool() const { return m_lhsIter; }
00259 
00260   protected:
00261     RhsNested m_rhs;
00262     LhsIterator m_lhsIter;
00263     const BinaryFunc m_functor;
00264     const Index m_outer;
00265 };
00266 
00267 // sparse - dense  (product)
00268 template<typename T, typename Lhs, typename Rhs, typename Derived>
00269 class sparse_cwise_binary_op_inner_iterator_selector<scalar_product_op<T>, Lhs, Rhs, Derived, Dense, Sparse>
00270 {
00271     typedef scalar_product_op<T> BinaryFunc;
00272     typedef CwiseBinaryOp<BinaryFunc, Lhs, Rhs> CwiseBinaryXpr;
00273     typedef typename CwiseBinaryXpr::Scalar Scalar;
00274     typedef typename traits<CwiseBinaryXpr>::_RhsNested _RhsNested;
00275     typedef typename _RhsNested::InnerIterator RhsIterator;
00276     typedef typename Lhs::Index Index;
00277 
00278     enum { IsRowMajor = (int(Rhs::Flags)&RowMajorBit)==RowMajorBit };
00279   public:
00280 
00281     EIGEN_STRONG_INLINE sparse_cwise_binary_op_inner_iterator_selector(const CwiseBinaryXpr& xpr, Index outer)
00282       : m_xpr(xpr), m_rhsIter(xpr.rhs(),outer), m_functor(xpr.functor()), m_outer(outer)
00283     {}
00284 
00285     EIGEN_STRONG_INLINE Derived& operator++()
00286     {
00287       ++m_rhsIter;
00288       return *static_cast<Derived*>(this);
00289     }
00290 
00291     EIGEN_STRONG_INLINE Scalar value() const
00292     { return m_functor(m_xpr.lhs().coeff(IsRowMajor?m_outer:m_rhsIter.index(),IsRowMajor?m_rhsIter.index():m_outer), m_rhsIter.value()); }
00293 
00294     EIGEN_STRONG_INLINE Index index() const { return m_rhsIter.index(); }
00295     EIGEN_STRONG_INLINE Index row() const { return m_rhsIter.row(); }
00296     EIGEN_STRONG_INLINE Index col() const { return m_rhsIter.col(); }
00297 
00298     EIGEN_STRONG_INLINE operator bool() const { return m_rhsIter; }
00299 
00300   protected:
00301     const CwiseBinaryXpr& m_xpr;
00302     RhsIterator m_rhsIter;
00303     const BinaryFunc& m_functor;
00304     const Index m_outer;
00305 };
00306 
00307 } // end namespace internal
00308 
00309 /***************************************************************************
00310 * Implementation of SparseMatrixBase and SparseCwise functions/operators
00311 ***************************************************************************/
00312 
00313 template<typename Derived>
00314 template<typename OtherDerived>
00315 EIGEN_STRONG_INLINE Derived &
00316 SparseMatrixBase<Derived>::operator-=(const SparseMatrixBase<OtherDerived> &other)
00317 {
00318   return *this = derived() - other.derived();
00319 }
00320 
00321 template<typename Derived>
00322 template<typename OtherDerived>
00323 EIGEN_STRONG_INLINE Derived &
00324 SparseMatrixBase<Derived>::operator+=(const SparseMatrixBase<OtherDerived>& other)
00325 {
00326   return *this = derived() + other.derived();
00327 }
00328 
00329 template<typename Derived>
00330 template<typename OtherDerived>
00331 EIGEN_STRONG_INLINE const EIGEN_SPARSE_CWISE_PRODUCT_RETURN_TYPE
00332 SparseMatrixBase<Derived>::cwiseProduct(const MatrixBase<OtherDerived> &other) const
00333 {
00334   return EIGEN_SPARSE_CWISE_PRODUCT_RETURN_TYPE(derived(), other.derived());
00335 }
00336 
00337 } // end namespace Eigen
00338 
00339 #endif // EIGEN_SPARSE_CWISE_BINARY_OP_H