MatrixSquareRoot.h
00001 // This file is part of Eigen, a lightweight C++ template library
00002 // for linear algebra.
00003 //
00004 // Copyright (C) 2011 Jitse Niesen <jitse@maths.leeds.ac.uk>
00005 //
00006 // Eigen is free software; you can redistribute it and/or
00007 // modify it under the terms of the GNU Lesser General Public
00008 // License as published by the Free Software Foundation; either
00009 // version 3 of the License, or (at your option) any later version.
00010 //
00011 // Alternatively, you can redistribute it and/or
00012 // modify it under the terms of the GNU General Public License as
00013 // published by the Free Software Foundation; either version 2 of
00014 // the License, or (at your option) any later version.
00015 //
00016 // Eigen is distributed in the hope that it will be useful, but WITHOUT ANY
00017 // WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
00018 // FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License or the
00019 // GNU General Public License for more details.
00020 //
00021 // You should have received a copy of the GNU Lesser General Public
00022 // License and a copy of the GNU General Public License along with
00023 // Eigen. If not, see <http://www.gnu.org/licenses/>.
00024 
00025 #ifndef EIGEN_MATRIX_SQUARE_ROOT
00026 #define EIGEN_MATRIX_SQUARE_ROOT
00027 
00028 namespace Eigen { 
00029 
// Computes a square root of a square matrix by way of its real Schur
// decomposition (see compute()). The input matrix is held by reference and is
// not copied, so it has to stay alive for as long as this object is used.
00041 template <typename MatrixType>
00042 class MatrixSquareRootQuasiTriangular
00043 {
00044   public:
00045 
          // Stores a reference to A; eigen_assert checks that A is square.
00054     MatrixSquareRootQuasiTriangular(const MatrixType& A) 
00055       : m_A(A) 
00056     {
00057       eigen_assert(A.rows() == A.cols());
00058     }
00059     
          // Computes the square root of the stored matrix and assigns it to result.
00068     template <typename ResultType> void compute(ResultType &result);    
00069     
00070   private:
00071     typedef typename MatrixType::Index Index;
00072     typedef typename MatrixType::Scalar Scalar;
00073     
          // Fill the diagonal blocks (1x1 or 2x2) of sqrtT from those of T, then
          // the blocks above the diagonal.
00074     void computeDiagonalPartOfSqrt(MatrixType& sqrtT, const MatrixType& T);
00075     void computeOffDiagonalPartOfSqrt(MatrixType& sqrtT, const MatrixType& T);
          // Per-block helpers. i (and j) is the row (and column) index of the
          // top-left coefficient of the block of sqrtT that is written.
00076     void compute2x2diagonalBlock(MatrixType& sqrtT, const MatrixType& T, typename MatrixType::Index i);
00077     void compute1x1offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T, 
00078                                   typename MatrixType::Index i, typename MatrixType::Index j);
00079     void compute1x2offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T, 
00080                                   typename MatrixType::Index i, typename MatrixType::Index j);
00081     void compute2x1offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T, 
00082                                   typename MatrixType::Index i, typename MatrixType::Index j);
00083     void compute2x2offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T, 
00084                                   typename MatrixType::Index i, typename MatrixType::Index j);
00085   
          // Solves A X + X B = C for X, where all four matrices are 2-by-2.
00086     template <typename SmallMatrixType>
00087     static void solveAuxiliaryEquation(SmallMatrixType& X, const SmallMatrixType& A, 
00088                                      const SmallMatrixType& B, const SmallMatrixType& C);
00089   
          // Matrix passed to the constructor; referenced, not owned.
00090     const MatrixType& m_A;
00091 };
00092 
// Computes a square root of m_A and assigns it to result.
// m_A is first decomposed with RealSchur into T and U. The square root of T is
// built in sqrtT (diagonal blocks first, then the blocks above them), and the
// final value is formed as U * sqrtT * U.adjoint().
00093 template <typename MatrixType>
00094 template <typename ResultType> 
00095 void MatrixSquareRootQuasiTriangular<MatrixType>::compute(ResultType &result)
00096 {
00097   // Compute Schur decomposition of m_A
00098   const RealSchur<MatrixType> schurOfA(m_A);  
00099   const MatrixType& T = schurOfA.matrixT();
00100   const MatrixType& U = schurOfA.matrixU();
00101 
00102   // Compute square root of T
        // sqrtT starts out as a zero matrix, which is the documented
        // precondition of computeDiagonalPartOfSqrt().
00103   MatrixType sqrtT = MatrixType::Zero(m_A.rows(), m_A.rows());
00104   computeDiagonalPartOfSqrt(sqrtT, T);
00105   computeOffDiagonalPartOfSqrt(sqrtT, T);
00106 
00107   // Compute square root of m_A
00108   result = U * sqrtT * U.adjoint();
00109 }
00110 
// Walks down the diagonal of T and writes the square root of every 1x1 or 2x2
// diagonal block into the same position of sqrtT.
00111 // pre:  T is quasi-upper-triangular and sqrtT is a zero matrix of the same size
00112 // post: the diagonal blocks of sqrtT are the square roots of the diagonal blocks of T
00113 template <typename MatrixType>
00114 void MatrixSquareRootQuasiTriangular<MatrixType>::computeDiagonalPartOfSqrt(MatrixType& sqrtT, 
00115                                                                           const MatrixType& T)
00116 {
        // The loop bound is taken from m_A, not from T; this relies on T having
        // the same size as m_A.
00117   const Index size = m_A.rows();
00118   for (Index i = 0; i < size; i++) {
          // Row i is a 1x1 block when it is the last row or when the entry just
          // below the diagonal is exactly zero.
00119     if (i == size - 1 || T.coeff(i+1, i) == 0) {
            // A 1x1 block must be strictly positive for the scalar square root
            // taken on the next line.
00120       eigen_assert(T(i,i) > 0);
00121       sqrtT.coeffRef(i,i) = internal::sqrt(T.coeff(i,i));
00122     }
00123     else {
            // Rows i and i+1 form a 2x2 block: handle both and skip row i+1.
00124       compute2x2diagonalBlock(sqrtT, T, i);
00125       ++i;
00126     }
00127   }
00128 }
00129 
// Fills the blocks of sqrtT above the diagonal. Columns are visited from left
// to right and, within a column, rows from the diagonal upwards; each block is
// delegated to the helper that matches its shape (1x1, 1x2, 2x1 or 2x2).
00130 // pre:  T is quasi-upper-triangular and diagonal blocks of sqrtT are square root of diagonal blocks of T.
00131 // post: sqrtT is the square root of T.
00132 template <typename MatrixType>
00133 void MatrixSquareRootQuasiTriangular<MatrixType>::computeOffDiagonalPartOfSqrt(MatrixType& sqrtT, 
00134                                                                              const MatrixType& T)
00135 {
00136   const Index size = m_A.rows();
00137   for (Index j = 1; j < size; j++) {
          // Column j is the second column of a 2x2 diagonal block: it was (or
          // will be) handled together with column j-1, so skip it.
00138       if (T.coeff(j, j-1) != 0)  // if T(j-1:j, j-1:j) is a 2-by-2 block
00139         continue;
00140     for (Index i = j-1; i >= 0; i--) {
            // Likewise skip a row that is the second row of a 2x2 diagonal block.
00141       if (i > 0 && T.coeff(i, i-1) != 0)  // if T(i-1:i, i-1:i) is a 2-by-2 block
00142         continue;
            // A diagonal block is 2x2 when its subdiagonal entry is non-zero.
00143       bool iBlockIs2x2 = (i < size - 1) && (T.coeff(i+1, i) != 0);
00144       bool jBlockIs2x2 = (j < size - 1) && (T.coeff(j+1, j) != 0);
00145       if (iBlockIs2x2 && jBlockIs2x2) 
00146         compute2x2offDiagonalBlock(sqrtT, T, i, j);
00147       else if (iBlockIs2x2 && !jBlockIs2x2) 
00148         compute2x1offDiagonalBlock(sqrtT, T, i, j);
00149       else if (!iBlockIs2x2 && jBlockIs2x2) 
00150         compute1x2offDiagonalBlock(sqrtT, T, i, j);
            // Only remaining combination; this condition is always true here.
00151       else if (!iBlockIs2x2 && !jBlockIs2x2) 
00152         compute1x1offDiagonalBlock(sqrtT, T, i, j);
00153     }
00154   }
00155 }
00156 
// Writes the square root of the 2x2 diagonal block of T starting at (i,i) into
// the same block of sqrtT, using an eigendecomposition of that block.
00157 // pre:  T.block(i,i,2,2) has complex conjugate eigenvalues
00158 // post: sqrtT.block(i,i,2,2) is square root of T.block(i,i,2,2)
00159 template <typename MatrixType>
00160 void MatrixSquareRootQuasiTriangular<MatrixType>
00161      ::compute2x2diagonalBlock(MatrixType& sqrtT, const MatrixType& T, typename MatrixType::Index i)
00162 {
00163   // TODO: This case (2-by-2 blocks with complex conjugate eigenvalues) is probably hidden somewhere
00164   //       in EigenSolver. If we expose it, we could call it directly from here.
00165   Matrix<Scalar,2,2> block = T.template block<2,2>(i,i);
00166   EigenSolver<Matrix<Scalar,2,2> > es(block);
        // With eigenvectors V and eigenvalues d this forms V * diag(sqrt(d)) * V^{-1}
        // and stores only its real part.
00167   sqrtT.template block<2,2>(i,i)
00168     = (es.eigenvectors() * es.eigenvalues().cwiseSqrt().asDiagonal() * es.eigenvectors().inverse()).real();
00169 }
00170 
// Computes the single coefficient sqrtT(i,j) for i < j when both the i-th and
// the j-th diagonal blocks are 1x1.
00171 // pre:  block structure of T is such that (i,j) is a 1x1 block,
00172 //       all blocks of sqrtT to left of and below (i,j) are correct
00173 // post: sqrtT(i,j) has the correct value
00174 template <typename MatrixType>
00175 void MatrixSquareRootQuasiTriangular<MatrixType>
00176      ::compute1x1offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T, 
00177                                   typename MatrixType::Index i, typename MatrixType::Index j)
00178 {
        // tmp = sum over k in (i, j) of sqrtT(i,k) * sqrtT(k,j); the segments have
        // length j-i-1, which is 0 when j == i+1.
00179   Scalar tmp = (sqrtT.row(i).segment(i+1,j-i-1) * sqrtT.col(j).segment(i+1,j-i-1)).value();
        // NOTE(review): the denominator is not checked against zero here.
00180   sqrtT.coeffRef(i,j) = (T.coeff(i,j) - tmp) / (sqrtT.coeff(i,i) + sqrtT.coeff(j,j));
00181 }
00182 
// Computes the 1x2 block of sqrtT at (i,j): the i-th diagonal block is 1x1 and
// the j-th diagonal block is 2x2.
00183 // similar to compute1x1offDiagonalBlock()
00184 template <typename MatrixType>
00185 void MatrixSquareRootQuasiTriangular<MatrixType>
00186      ::compute1x2offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T, 
00187                                   typename MatrixType::Index i, typename MatrixType::Index j)
00188 {
00189   Matrix<Scalar,1,2> rhs = T.template block<1,2>(i,j);
        // Subtract the contribution of the already computed blocks lying between
        // the two diagonal blocks (there are none when j == i+1).
00190   if (j-i > 1)
00191     rhs -= sqrtT.block(i, i+1, 1, j-i-1) * sqrtT.block(i+1, j, j-i-1, 2);
        // The unknown row x satisfies x * (s*I + B) = rhs with s = sqrtT(i,i) and
        // B = sqrtT.block<2,2>(j,j); the transposed system is solved below.
00192   Matrix<Scalar,2,2> A = sqrtT.coeff(i,i) * Matrix<Scalar,2,2>::Identity();
00193   A += sqrtT.template block<2,2>(j,j).transpose();
00194   sqrtT.template block<1,2>(i,j).transpose() = A.fullPivLu().solve(rhs.transpose());
00195 }
00196 
// Computes the 2x1 block of sqrtT at (i,j): the i-th diagonal block is 2x2 and
// the j-th diagonal block is 1x1.
00197 // similar to compute1x1offDiagonalBlock()
00198 template <typename MatrixType>
00199 void MatrixSquareRootQuasiTriangular<MatrixType>
00200      ::compute2x1offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T, 
00201                                   typename MatrixType::Index i, typename MatrixType::Index j)
00202 {
00203   Matrix<Scalar,2,1> rhs = T.template block<2,1>(i,j);
        // Subtract the contribution of the already computed blocks lying between
        // the two diagonal blocks (there are none when j == i+2).
00204   if (j-i > 2)
00205     rhs -= sqrtT.block(i, i+2, 2, j-i-2) * sqrtT.block(i+2, j, j-i-2, 1);
        // The unknown column x satisfies (s*I + B) * x = rhs with s = sqrtT(j,j)
        // and B = sqrtT.block<2,2>(i,i).
00206   Matrix<Scalar,2,2> A = sqrtT.coeff(j,j) * Matrix<Scalar,2,2>::Identity();
00207   A += sqrtT.template block<2,2>(i,i);
00208   sqrtT.template block<2,1>(i,j) = A.fullPivLu().solve(rhs);
00209 }
00210 
// Computes the 2x2 block of sqrtT at (i,j): both the i-th and the j-th diagonal
// blocks are 2x2.
00211 // similar to compute1x1offDiagonalBlock()
00212 template <typename MatrixType>
00213 void MatrixSquareRootQuasiTriangular<MatrixType>
00214      ::compute2x2offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T, 
00215                                   typename MatrixType::Index i, typename MatrixType::Index j)
00216 {
        // A and B are the two diagonal blocks of sqrtT; C is the right-hand side.
00217   Matrix<Scalar,2,2> A = sqrtT.template block<2,2>(i,i);
00218   Matrix<Scalar,2,2> B = sqrtT.template block<2,2>(j,j);
00219   Matrix<Scalar,2,2> C = T.template block<2,2>(i,j);
        // Subtract the contribution of the already computed blocks lying between
        // the two diagonal blocks (there are none when j == i+2).
00220   if (j-i > 2)
00221     C -= sqrtT.block(i, i+2, 2, j-i-2) * sqrtT.block(i+2, j, j-i-2, 2);
        // The wanted block X solves A X + X B = C.
00222   Matrix<Scalar,2,2> X;
00223   solveAuxiliaryEquation(X, A, B, C);
00224   sqrtT.template block<2,2>(i,j) = X;
00225 }
00226 
00227 // solves the equation A X + X B = C where all matrices are 2-by-2
// The equation is rewritten as a 4x4 linear system in the unknown vector
// (X(0,0), X(0,1), X(1,0), X(1,1)) and solved with a full-pivoting LU
// decomposition. The solution is written to X.
00228 template <typename MatrixType>
00229 template <typename SmallMatrixType>
00230 void MatrixSquareRootQuasiTriangular<MatrixType>
00231      ::solveAuxiliaryEquation(SmallMatrixType& X, const SmallMatrixType& A,
00232                               const SmallMatrixType& B, const SmallMatrixType& C)
00233 {
        // Compile-time check: this routine is only written for Matrix<Scalar,2,2>.
00234   EIGEN_STATIC_ASSERT((internal::is_same<SmallMatrixType, Matrix<Scalar,2,2> >::value),
00235                       EIGEN_INTERNAL_ERROR_PLEASE_FILE_A_BUG_REPORT);
00236 
        // Row r of coeffMatrix holds the coefficients of the four unknowns in the
        // r-th entry of A X + X B, entries taken in the order (0,0), (0,1), (1,0),
        // (1,1). Entries not set below stay zero.
00237   Matrix<Scalar,4,4> coeffMatrix = Matrix<Scalar,4,4>::Zero();
00238   coeffMatrix.coeffRef(0,0) = A.coeff(0,0) + B.coeff(0,0);
00239   coeffMatrix.coeffRef(1,1) = A.coeff(0,0) + B.coeff(1,1);
00240   coeffMatrix.coeffRef(2,2) = A.coeff(1,1) + B.coeff(0,0);
00241   coeffMatrix.coeffRef(3,3) = A.coeff(1,1) + B.coeff(1,1);
00242   coeffMatrix.coeffRef(0,1) = B.coeff(1,0);
00243   coeffMatrix.coeffRef(0,2) = A.coeff(0,1);
00244   coeffMatrix.coeffRef(1,0) = B.coeff(0,1);
00245   coeffMatrix.coeffRef(1,3) = A.coeff(0,1);
00246   coeffMatrix.coeffRef(2,0) = A.coeff(1,0);
00247   coeffMatrix.coeffRef(2,3) = B.coeff(1,0);
00248   coeffMatrix.coeffRef(3,1) = A.coeff(1,0);
00249   coeffMatrix.coeffRef(3,2) = B.coeff(0,1);
00250   
        // Right-hand side: the entries of C in the same order as the unknowns.
00251   Matrix<Scalar,4,1> rhs;
00252   rhs.coeffRef(0) = C.coeff(0,0);
00253   rhs.coeffRef(1) = C.coeff(0,1);
00254   rhs.coeffRef(2) = C.coeff(1,0);
00255   rhs.coeffRef(3) = C.coeff(1,1);
00256   
00257   Matrix<Scalar,4,1> result;
00258   result = coeffMatrix.fullPivLu().solve(rhs);
00259 
        // Copy the solution vector back into the 2x2 matrix X.
00260   X.coeffRef(0,0) = result.coeff(0);
00261   X.coeffRef(0,1) = result.coeff(1);
00262   X.coeffRef(1,0) = result.coeff(2);
00263   X.coeffRef(1,1) = result.coeff(3);
00264 }
00265 
00266 
// Computes a square root of a square matrix by way of its complex Schur
// decomposition (see compute()). The input matrix is held by reference and is
// not copied, so it has to stay alive for as long as this object is used.
00278 template <typename MatrixType>
00279 class MatrixSquareRootTriangular
00280 {
00281   public:
          // Stores a reference to A; eigen_assert checks that A is square.
00282     MatrixSquareRootTriangular(const MatrixType& A) 
00283       : m_A(A) 
00284     {
00285       eigen_assert(A.rows() == A.cols());
00286     }
00287 
          // Computes the square root of the stored matrix and writes it to result.
00297     template <typename ResultType> void compute(ResultType &result);    
00298 
00299  private:
          // Matrix passed to the constructor; referenced, not owned.
00300     const MatrixType& m_A;
00301 };
00302 
// Computes a square root of m_A and writes it to result.
// m_A is first decomposed with ComplexSchur into T and U. The square root of
// the triangular factor T is built directly in the upper triangle of result,
// and result is then overwritten with U * (that triangle) * U.adjoint().
// result is resized to the size of m_A and accessed coefficient-wise, so
// ResultType has to support resize(), coeffRef(), row()/col() and
// triangularView().
00303 template <typename MatrixType>
00304 template <typename ResultType> 
00305 void MatrixSquareRootTriangular<MatrixType>::compute(ResultType &result)
00306 {
00307   // Compute Schur decomposition of m_A
00308   const ComplexSchur<MatrixType> schurOfA(m_A);  
00309   const MatrixType& T = schurOfA.matrixT();
00310   const MatrixType& U = schurOfA.matrixU();
00311 
00312   // Compute square root of T and store it in upper triangular part of result
00313   // This uses that the square root of triangular matrices can be computed directly.
00314   result.resize(m_A.rows(), m_A.cols());
00315   typedef typename MatrixType::Index Index;
        // Diagonal first: each diagonal entry is the scalar square root of T(i,i).
00316   for (Index i = 0; i < m_A.rows(); i++) {
00317     result.coeffRef(i,i) = internal::sqrt(T.coeff(i,i));
00318   }
        // Then the strict upper triangle, column by column and from the diagonal
        // upwards, so every entry read by the dot product below is already set.
00319   for (Index j = 1; j < m_A.cols(); j++) {
00320     for (Index i = j-1; i >= 0; i--) {
00321       typedef typename MatrixType::Scalar Scalar;
00322       // if i = j-1, then segment has length 0 so tmp = 0
00323       Scalar tmp = (result.row(i).segment(i+1,j-i-1) * result.col(j).segment(i+1,j-i-1)).value();
00324       // denominator may be zero if original matrix is singular
00325       result.coeffRef(i,j) = (T.coeff(i,j) - tmp) / (result.coeff(i,i) + result.coeff(j,j));
00326     }
00327   }
00328 
00329   // Compute square root of m_A as U * result * U.adjoint()
        // Only the upper triangle of result was written above; triangularView<Upper>
        // keeps the entries below the diagonal out of the product.
00330   MatrixType tmp;
00331   tmp.noalias() = U * result.template triangularView<Upper>();
00332   result.noalias() = tmp * U.adjoint();
00333 }
00334 
00335 
// Primary template of the matrix square root class. The second parameter
// defaults to NumTraits<Scalar>::IsComplex and selects one of the two partial
// specializations below (0: real scalar, 1: complex scalar). The members
// declared here have no definition in this part of the file.
00343 template <typename MatrixType, int IsComplex = NumTraits<typename internal::traits<MatrixType>::Scalar>::IsComplex>
00344 class MatrixSquareRoot
00345 {
00346   public:
00347 
          // Takes the matrix whose square root is wanted.
00355     MatrixSquareRoot(const MatrixType& A); 
00356     
          // Computes the square root and stores it in result.
00364     template <typename ResultType> void compute(ResultType &result);    
00365 };
00366 
00367 
00368 // ********** Partial specialization for real matrices **********
00369 
// Matrix square root for real scalar types, based on the real Schur
// decomposition. The input matrix is held by reference and is not copied.
00370 template <typename MatrixType>
00371 class MatrixSquareRoot<MatrixType, 0>
00372 {
00373   public:
00374 
          // Stores a reference to A; eigen_assert checks that A is square.
00375     MatrixSquareRoot(const MatrixType& A) 
00376       : m_A(A) 
00377     {  
00378       eigen_assert(A.rows() == A.cols());
00379     }
00380   
          // Computes a square root of m_A and assigns it to result as
          // U * sqrtT * U.adjoint(), where T and U come from RealSchur and sqrtT
          // is obtained from MatrixSquareRootQuasiTriangular applied to T.
00381     template <typename ResultType> void compute(ResultType &result)
00382     {
00383       // Compute Schur decomposition of m_A
00384       const RealSchur<MatrixType> schurOfA(m_A);  
00385       const MatrixType& T = schurOfA.matrixT();
00386       const MatrixType& U = schurOfA.matrixU();
00387     
00388       // Compute square root of T
            // NOTE(review): MatrixSquareRootQuasiTriangular::compute(), as defined
            // in this file, runs its own RealSchur on the matrix it is given, so T
            // is decomposed a second time here. The Zero initialisation of sqrtT
            // is also overwritten by the assignment inside that compute().
00389       MatrixSquareRootQuasiTriangular<MatrixType> tmp(T);
00390       MatrixType sqrtT = MatrixType::Zero(m_A.rows(), m_A.rows());
00391       tmp.compute(sqrtT);
00392     
00393       // Compute square root of m_A
00394       result = U * sqrtT * U.adjoint();
00395     }
00396     
00397   private:
          // Matrix passed to the constructor; referenced, not owned.
00398     const MatrixType& m_A;
00399 };
00400 
00401 
00402 // ********** Partial specialization for complex matrices **********
00403 
// Matrix square root for complex scalar types, based on the complex Schur
// decomposition. The input matrix is held by reference and is not copied.
00404 template <typename MatrixType>
00405 class MatrixSquareRoot<MatrixType, 1>
00406 {
00407   public:
00408 
          // Stores a reference to A; eigen_assert checks that A is square.
00409     MatrixSquareRoot(const MatrixType& A) 
00410       : m_A(A) 
00411     {  
00412       eigen_assert(A.rows() == A.cols());
00413     }
00414   
          // Computes a square root of m_A and assigns it to result as
          // U * sqrtT * U.adjoint(), where T and U come from ComplexSchur and
          // sqrtT is obtained from MatrixSquareRootTriangular applied to T.
00415     template <typename ResultType> void compute(ResultType &result)
00416     {
00417       // Compute Schur decomposition of m_A
00418       const ComplexSchur<MatrixType> schurOfA(m_A);  
00419       const MatrixType& T = schurOfA.matrixT();
00420       const MatrixType& U = schurOfA.matrixU();
00421     
00422       // Compute square root of T
            // NOTE(review): MatrixSquareRootTriangular::compute(), as defined in
            // this file, runs its own ComplexSchur on the matrix it is given, so T
            // is decomposed a second time here. It also resizes and overwrites
            // sqrtT, so the Zero initialisation below is not relied upon.
00423       MatrixSquareRootTriangular<MatrixType> tmp(T);
00424       MatrixType sqrtT = MatrixType::Zero(m_A.rows(), m_A.rows());
00425       tmp.compute(sqrtT);
00426     
00427       // Compute square root of m_A
00428       result = U * sqrtT * U.adjoint();
00429     }
00430     
00431   private:
          // Matrix passed to the constructor; referenced, not owned.
00432     const MatrixType& m_A;
00433 };
00434 
00435 
// Return-by-value proxy for the matrix square root of an expression of type
// Derived. It only keeps a reference to the expression; the square root is
// computed in evalTo().
00448 template<typename Derived> class MatrixSquareRootReturnValue
00449 : public ReturnByValue<MatrixSquareRootReturnValue<Derived> >
00450 {
00451     typedef typename Derived::Index Index;
00452   public:
          // Stores a reference to src; src is not copied.
00458     MatrixSquareRootReturnValue(const Derived& src) : m_src(src) { }
00459 
          // Evaluates m_src into a plain object, then lets
          // MatrixSquareRoot<PlainObject>::compute() write the square root to result.
00465     template <typename ResultType>
00466     inline void evalTo(ResultType& result) const
00467     {
00468       const typename Derived::PlainObject srcEvaluated = m_src.eval();
00469       MatrixSquareRoot<typename Derived::PlainObject> me(srcEvaluated);
00470       me.compute(result);
00471     }
00472 
          // Dimensions of the result, forwarded from the wrapped expression.
00473     Index rows() const { return m_src.rows(); }
00474     Index cols() const { return m_src.cols(); }
00475 
00476   protected:
          // Expression whose square root is requested; referenced, not owned.
00477     const Derived& m_src;
00478   private:
          // Declared private with no definition visible here; presumably to
          // forbid assignment of the proxy.
00479     MatrixSquareRootReturnValue& operator=(const MatrixSquareRootReturnValue&);
00480 };
00481 
00482 namespace internal {
// Traits specialization for the square-root proxy: its ReturnType is the plain
// object type of the wrapped expression.
00483 template<typename Derived>
00484 struct traits<MatrixSquareRootReturnValue<Derived> >
00485 {
00486   typedef typename Derived::PlainObject ReturnType;
00487 };
00488 }
00489 
// Returns a MatrixSquareRootReturnValue proxy wrapping this expression.
// eigen_assert checks that the matrix is square. No square root is computed in
// this function itself; that happens in the proxy's evalTo().
00490 template <typename Derived>
00491 const MatrixSquareRootReturnValue<Derived> MatrixBase<Derived>::sqrt() const
00492 {
00493   eigen_assert(rows() == cols());
00494   return MatrixSquareRootReturnValue<Derived>(derived());
00495 }
00496 
00497 } // end namespace Eigen
00498 
00499 #endif // EIGEN_MATRIX_SQUARE_ROOT