TriangularSolver.h
Go to the documentation of this file.
00001 // This file is part of Eigen, a lightweight C++ template library
00002 // for linear algebra.
00003 //
00004 // Copyright (C) 2008 Gael Guennebaud <gael.guennebaud@inria.fr>
00005 //
00006 // Eigen is free software; you can redistribute it and/or
00007 // modify it under the terms of the GNU Lesser General Public
00008 // License as published by the Free Software Foundation; either
00009 // version 3 of the License, or (at your option) any later version.
00010 //
00011 // Alternatively, you can redistribute it and/or
00012 // modify it under the terms of the GNU General Public License as
00013 // published by the Free Software Foundation; either version 2 of
00014 // the License, or (at your option) any later version.
00015 //
00016 // Eigen is distributed in the hope that it will be useful, but WITHOUT ANY
00017 // WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
00018 // FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License or the
00019 // GNU General Public License for more details.
00020 //
00021 // You should have received a copy of the GNU Lesser General Public
00022 // License and a copy of the GNU General Public License along with
00023 // Eigen. If not, see <http://www.gnu.org/licenses/>.
00024 
00025 #ifndef EIGEN_SPARSETRIANGULARSOLVER_H
00026 #define EIGEN_SPARSETRIANGULARSOLVER_H
00027 
00028 namespace Eigen { 
00029 
00030 namespace internal {
00031 
00032 template<typename Lhs, typename Rhs, int Mode,
00033   int UpLo = (Mode & Lower)
00034            ? Lower
00035            : (Mode & Upper)
00036            ? Upper
00037            : -1,
00038   int StorageOrder = int(traits<Lhs>::Flags) & RowMajorBit>
00039 struct sparse_solve_triangular_selector;
00040 
00041 // forward substitution, row-major
00042 template<typename Lhs, typename Rhs, int Mode>
00043 struct sparse_solve_triangular_selector<Lhs,Rhs,Mode,Lower,RowMajor>
00044 {
00045   typedef typename Rhs::Scalar Scalar;
00046   static void run(const Lhs& lhs, Rhs& other)
00047   {
00048     for(int col=0 ; col<other.cols() ; ++col)
00049     {
00050       for(int i=0; i<lhs.rows(); ++i)
00051       {
00052         Scalar tmp = other.coeff(i,col);
00053         Scalar lastVal(0);
00054         int lastIndex = 0;
00055         for(typename Lhs::InnerIterator it(lhs, i); it; ++it)
00056         {
00057           lastVal = it.value();
00058           lastIndex = it.index();
00059           if(lastIndex==i)
00060             break;
00061           tmp -= lastVal * other.coeff(lastIndex,col);
00062         }
00063         if (Mode & UnitDiag)
00064           other.coeffRef(i,col) = tmp;
00065         else
00066         {
00067           eigen_assert(lastIndex==i);
00068           other.coeffRef(i,col) = tmp/lastVal;
00069         }
00070       }
00071     }
00072   }
00073 };
00074 
00075 // backward substitution, row-major
00076 template<typename Lhs, typename Rhs, int Mode>
00077 struct sparse_solve_triangular_selector<Lhs,Rhs,Mode,Upper,RowMajor>
00078 {
00079   typedef typename Rhs::Scalar Scalar;
00080   static void run(const Lhs& lhs, Rhs& other)
00081   {
00082     for(int col=0 ; col<other.cols() ; ++col)
00083     {
00084       for(int i=lhs.rows()-1 ; i>=0 ; --i)
00085       {
00086         Scalar tmp = other.coeff(i,col);
00087         Scalar l_ii = 0;
00088         typename Lhs::InnerIterator it(lhs, i);
00089         while(it && it.index()<i)
00090           ++it;
00091         if(!(Mode & UnitDiag))
00092         {
00093           eigen_assert(it && it.index()==i);
00094           l_ii = it.value();
00095           ++it;
00096         }
00097         else if (it && it.index() == i)
00098           ++it;
00099         for(; it; ++it)
00100         {
00101           tmp -= it.value() * other.coeff(it.index(),col);
00102         }
00103 
00104         if (Mode & UnitDiag)
00105           other.coeffRef(i,col) = tmp;
00106         else
00107           other.coeffRef(i,col) = tmp/l_ii;
00108       }
00109     }
00110   }
00111 };
00112 
00113 // forward substitution, col-major
00114 template<typename Lhs, typename Rhs, int Mode>
00115 struct sparse_solve_triangular_selector<Lhs,Rhs,Mode,Lower,ColMajor>
00116 {
00117   typedef typename Rhs::Scalar Scalar;
00118   static void run(const Lhs& lhs, Rhs& other)
00119   {
00120     for(int col=0 ; col<other.cols() ; ++col)
00121     {
00122       for(int i=0; i<lhs.cols(); ++i)
00123       {
00124         Scalar& tmp = other.coeffRef(i,col);
00125         if (tmp!=Scalar(0)) // optimization when other is actually sparse
00126         {
00127           typename Lhs::InnerIterator it(lhs, i);
00128           while(it && it.index()<i)
00129             ++it;
00130           if(!(Mode & UnitDiag))
00131           {
00132             eigen_assert(it && it.index()==i);
00133             tmp /= it.value();
00134           }
00135           if (it && it.index()==i)
00136             ++it;
00137           for(; it; ++it)
00138             other.coeffRef(it.index(), col) -= tmp * it.value();
00139         }
00140       }
00141     }
00142   }
00143 };
00144 
00145 // backward substitution, col-major
00146 template<typename Lhs, typename Rhs, int Mode>
00147 struct sparse_solve_triangular_selector<Lhs,Rhs,Mode,Upper,ColMajor>
00148 {
00149   typedef typename Rhs::Scalar Scalar;
00150   static void run(const Lhs& lhs, Rhs& other)
00151   {
00152     for(int col=0 ; col<other.cols() ; ++col)
00153     {
00154       for(int i=lhs.cols()-1; i>=0; --i)
00155       {
00156         Scalar& tmp = other.coeffRef(i,col);
00157         if (tmp!=Scalar(0)) // optimization when other is actually sparse
00158         {
00159           if(!(Mode & UnitDiag))
00160           {
00161             // TODO replace this by a binary search. make sure the binary search is safe for partially sorted elements
00162             typename Lhs::ReverseInnerIterator it(lhs, i);
00163             while(it && it.index()!=i)
00164               --it;
00165             eigen_assert(it && it.index()==i);
00166             other.coeffRef(i,col) /= it.value();
00167           }
00168           typename Lhs::InnerIterator it(lhs, i);
00169           for(; it && it.index()<i; ++it)
00170             other.coeffRef(it.index(), col) -= tmp * it.value();
00171         }
00172       }
00173     }
00174   }
00175 };
00176 
00177 } // end namespace internal
00178 
00179 template<typename ExpressionType,int Mode>
00180 template<typename OtherDerived>
00181 void SparseTriangularView<ExpressionType,Mode>::solveInPlace(MatrixBase<OtherDerived>& other) const
00182 {
00183   eigen_assert(m_matrix.cols() == m_matrix.rows() && m_matrix.cols() == other.rows());
00184   eigen_assert((!(Mode & ZeroDiag)) && bool(Mode & (Upper|Lower)));
00185 
00186   enum { copy = internal::traits<OtherDerived>::Flags & RowMajorBit };
00187 
00188   typedef typename internal::conditional<copy,
00189     typename internal::plain_matrix_type_column_major<OtherDerived>::type, OtherDerived&>::type OtherCopy;
00190   OtherCopy otherCopy(other.derived());
00191 
00192   internal::sparse_solve_triangular_selector<ExpressionType, typename internal::remove_reference<OtherCopy>::type, Mode>::run(m_matrix, otherCopy);
00193 
00194   if (copy)
00195     other = otherCopy;
00196 }
00197 
00198 template<typename ExpressionType,int Mode>
00199 template<typename OtherDerived>
00200 typename internal::plain_matrix_type_column_major<OtherDerived>::type
00201 SparseTriangularView<ExpressionType,Mode>::solve(const MatrixBase<OtherDerived>& other) const
00202 {
00203   typename internal::plain_matrix_type_column_major<OtherDerived>::type res(other);
00204   solveInPlace(res);
00205   return res;
00206 }
00207 
00208 // pure sparse path
00209 
00210 namespace internal {
00211 
00212 template<typename Lhs, typename Rhs, int Mode,
00213   int UpLo = (Mode & Lower)
00214            ? Lower
00215            : (Mode & Upper)
00216            ? Upper
00217            : -1,
00218   int StorageOrder = int(Lhs::Flags) & (RowMajorBit)>
00219 struct sparse_solve_triangular_sparse_selector;
00220 
00221 // forward substitution, col-major
00222 template<typename Lhs, typename Rhs, int Mode, int UpLo>
00223 struct sparse_solve_triangular_sparse_selector<Lhs,Rhs,Mode,UpLo,ColMajor>
00224 {
00225   typedef typename Rhs::Scalar Scalar;
00226   typedef typename promote_index_type<typename traits<Lhs>::Index,
00227                                          typename traits<Rhs>::Index>::type Index;
00228   static void run(const Lhs& lhs, Rhs& other)
00229   {
00230     const bool IsLower = (UpLo==Lower);
00231     AmbiVector<Scalar,Index> tempVector(other.rows()*2);
00232     tempVector.setBounds(0,other.rows());
00233 
00234     Rhs res(other.rows(), other.cols());
00235     res.reserve(other.nonZeros());
00236 
00237     for(int col=0 ; col<other.cols() ; ++col)
00238     {
00239       // FIXME estimate number of non zeros
00240       tempVector.init(.99/*float(other.col(col).nonZeros())/float(other.rows())*/);
00241       tempVector.setZero();
00242       tempVector.restart();
00243       for (typename Rhs::InnerIterator rhsIt(other, col); rhsIt; ++rhsIt)
00244       {
00245         tempVector.coeffRef(rhsIt.index()) = rhsIt.value();
00246       }
00247 
00248       for(int i=IsLower?0:lhs.cols()-1;
00249           IsLower?i<lhs.cols():i>=0;
00250           i+=IsLower?1:-1)
00251       {
00252         tempVector.restart();
00253         Scalar& ci = tempVector.coeffRef(i);
00254         if (ci!=Scalar(0))
00255         {
00256           // find
00257           typename Lhs::InnerIterator it(lhs, i);
00258           if(!(Mode & UnitDiag))
00259           {
00260             if (IsLower)
00261             {
00262               eigen_assert(it.index()==i);
00263               ci /= it.value();
00264             }
00265             else
00266               ci /= lhs.coeff(i,i);
00267           }
00268           tempVector.restart();
00269           if (IsLower)
00270           {
00271             if (it.index()==i)
00272               ++it;
00273             for(; it; ++it)
00274               tempVector.coeffRef(it.index()) -= ci * it.value();
00275           }
00276           else
00277           {
00278             for(; it && it.index()<i; ++it)
00279               tempVector.coeffRef(it.index()) -= ci * it.value();
00280           }
00281         }
00282       }
00283 
00284 
00285       int count = 0;
00286       // FIXME compute a reference value to filter zeros
00287       for (typename AmbiVector<Scalar,Index>::Iterator it(tempVector/*,1e-12*/); it; ++it)
00288       {
00289         ++ count;
00290 //         std::cerr << "fill " << it.index() << ", " << col << "\n";
00291 //         std::cout << it.value() << "  ";
00292         // FIXME use insertBack
00293         res.insert(it.index(), col) = it.value();
00294       }
00295 //       std::cout << "tempVector.nonZeros() == " << int(count) << " / " << (other.rows()) << "\n";
00296     }
00297     res.finalize();
00298     other = res.markAsRValue();
00299   }
00300 };
00301 
00302 } // end namespace internal
00303 
00304 template<typename ExpressionType,int Mode>
00305 template<typename OtherDerived>
00306 void SparseTriangularView<ExpressionType,Mode>::solveInPlace(SparseMatrixBase<OtherDerived>& other) const
00307 {
00308   eigen_assert(m_matrix.cols() == m_matrix.rows() && m_matrix.cols() == other.rows());
00309   eigen_assert( (!(Mode & ZeroDiag)) && bool(Mode & (Upper|Lower)));
00310 
00311 //   enum { copy = internal::traits<OtherDerived>::Flags & RowMajorBit };
00312 
00313 //   typedef typename internal::conditional<copy,
00314 //     typename internal::plain_matrix_type_column_major<OtherDerived>::type, OtherDerived&>::type OtherCopy;
00315 //   OtherCopy otherCopy(other.derived());
00316 
00317   internal::sparse_solve_triangular_sparse_selector<ExpressionType, OtherDerived, Mode>::run(m_matrix, other.derived());
00318 
00319 //   if (copy)
00320 //     other = otherCopy;
00321 }
00322 
00323 #ifdef EIGEN2_SUPPORT
00324 
00325 // deprecated stuff:
00326 
/** \deprecated Eigen2 compatibility shim (only compiled with EIGEN2_SUPPORT).
  * Solves the triangular system in place, with \a other holding b on entry and
  * x on exit. It forwards to the triangular view selected by
  * Flags&(Upper|Lower) and that view's solveInPlace().
  * NOTE(review): in Eigen 3, Upper/Lower are view-mode constants rather than
  * dedicated Flags bits, so this mask may overlap unrelated flag bits (such as
  * the storage-order bit) — confirm it selects the intended triangle.
  */
template<typename Derived>
template<typename OtherDerived>
void SparseMatrixBase<Derived>::solveTriangularInPlace(MatrixBase<OtherDerived>& other) const
{
  this->template triangular<Flags&(Upper|Lower)>().solveInPlace(other);
}
00334 
00336 template<typename Derived>
00337 template<typename OtherDerived>
00338 typename internal::plain_matrix_type_column_major<OtherDerived>::type
00339 SparseMatrixBase<Derived>::solveTriangular(const MatrixBase<OtherDerived>& other) const
00340 {
00341   typename internal::plain_matrix_type_column_major<OtherDerived>::type res(other);
00342   derived().solveTriangularInPlace(res);
00343   return res;
00344 }
00345 #endif // EIGEN2_SUPPORT
00346 
00347 } // end namespace Eigen
00348 
00349 #endif // EIGEN_SPARSETRIANGULARSOLVER_H