00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025 #ifndef EIGEN_MATRIX_SQUARE_ROOT
00026 #define EIGEN_MATRIX_SQUARE_ROOT
00027
00028 namespace Eigen {
00029
/** \brief Square root of a real matrix via its real Schur decomposition.
  *
  * compute() factors the stored matrix as U T U^T with RealSchur, builds
  * sqrt(T) block by block (1x1 and 2x2 diagonal blocks first, then the
  * blocks above the diagonal) and returns U sqrt(T) U^T.
  *
  * \tparam MatrixType  real matrix type of the input (used with RealSchur).
  */
template <typename MatrixType>
class MatrixSquareRootQuasiTriangular
{
  private:
    typedef typename MatrixType::Index Index;
    typedef typename MatrixType::Scalar Scalar;

  public:
    /** \brief Stores a reference to \p A; nothing is computed yet.
      * \param[in] A  square matrix; it must outlive this object.
      */
    MatrixSquareRootQuasiTriangular(const MatrixType& A)
      : m_A(A)
    {
      eigen_assert(A.rows() == A.cols());
    }

    /** \brief Writes the square root of the stored matrix into \p result. */
    template <typename ResultType> void compute(ResultType &result);

  private:
    // Square roots of the 1x1 / 2x2 diagonal blocks of T.
    void computeDiagonalPartOfSqrt(MatrixType& sqrtT, const MatrixType& T);
    // Blocks strictly above the block diagonal, derived from already-known blocks.
    void computeOffDiagonalPartOfSqrt(MatrixType& sqrtT, const MatrixType& T);

    // Block kernels: i / j are the first row / first column of the blocks involved.
    void compute2x2diagonalBlock(MatrixType& sqrtT, const MatrixType& T,
                                 typename MatrixType::Index i);
    void compute1x1offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T,
                                    typename MatrixType::Index i, typename MatrixType::Index j);
    void compute1x2offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T,
                                    typename MatrixType::Index i, typename MatrixType::Index j);
    void compute2x1offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T,
                                    typename MatrixType::Index i, typename MatrixType::Index j);
    void compute2x2offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T,
                                    typename MatrixType::Index i, typename MatrixType::Index j);

    // Solves A X + X B = C for 2x2 matrices (see definition).
    template <typename SmallMatrixType>
    static void solveAuxiliaryEquation(SmallMatrixType& X, const SmallMatrixType& A,
                                       const SmallMatrixType& B, const SmallMatrixType& C);

    // Matrix whose square root is computed; held by reference, not copied.
    const MatrixType& m_A;
};
00092
00093 template <typename MatrixType>
00094 template <typename ResultType>
00095 void MatrixSquareRootQuasiTriangular<MatrixType>::compute(ResultType &result)
00096 {
00097
00098 const RealSchur<MatrixType> schurOfA(m_A);
00099 const MatrixType& T = schurOfA.matrixT();
00100 const MatrixType& U = schurOfA.matrixU();
00101
00102
00103 MatrixType sqrtT = MatrixType::Zero(m_A.rows(), m_A.rows());
00104 computeDiagonalPartOfSqrt(sqrtT, T);
00105 computeOffDiagonalPartOfSqrt(sqrtT, T);
00106
00107
00108 result = U * sqrtT * U.adjoint();
00109 }
00110
00111
00112
00113 template <typename MatrixType>
00114 void MatrixSquareRootQuasiTriangular<MatrixType>::computeDiagonalPartOfSqrt(MatrixType& sqrtT,
00115 const MatrixType& T)
00116 {
00117 const Index size = m_A.rows();
00118 for (Index i = 0; i < size; i++) {
00119 if (i == size - 1 || T.coeff(i+1, i) == 0) {
00120 eigen_assert(T(i,i) > 0);
00121 sqrtT.coeffRef(i,i) = internal::sqrt(T.coeff(i,i));
00122 }
00123 else {
00124 compute2x2diagonalBlock(sqrtT, T, i);
00125 ++i;
00126 }
00127 }
00128 }
00129
00130
00131
00132 template <typename MatrixType>
00133 void MatrixSquareRootQuasiTriangular<MatrixType>::computeOffDiagonalPartOfSqrt(MatrixType& sqrtT,
00134 const MatrixType& T)
00135 {
00136 const Index size = m_A.rows();
00137 for (Index j = 1; j < size; j++) {
00138 if (T.coeff(j, j-1) != 0)
00139 continue;
00140 for (Index i = j-1; i >= 0; i--) {
00141 if (i > 0 && T.coeff(i, i-1) != 0)
00142 continue;
00143 bool iBlockIs2x2 = (i < size - 1) && (T.coeff(i+1, i) != 0);
00144 bool jBlockIs2x2 = (j < size - 1) && (T.coeff(j+1, j) != 0);
00145 if (iBlockIs2x2 && jBlockIs2x2)
00146 compute2x2offDiagonalBlock(sqrtT, T, i, j);
00147 else if (iBlockIs2x2 && !jBlockIs2x2)
00148 compute2x1offDiagonalBlock(sqrtT, T, i, j);
00149 else if (!iBlockIs2x2 && jBlockIs2x2)
00150 compute1x2offDiagonalBlock(sqrtT, T, i, j);
00151 else if (!iBlockIs2x2 && !jBlockIs2x2)
00152 compute1x1offDiagonalBlock(sqrtT, T, i, j);
00153 }
00154 }
00155 }
00156
00157
00158
00159 template <typename MatrixType>
00160 void MatrixSquareRootQuasiTriangular<MatrixType>
00161 ::compute2x2diagonalBlock(MatrixType& sqrtT, const MatrixType& T, typename MatrixType::Index i)
00162 {
00163
00164
00165 Matrix<Scalar,2,2> block = T.template block<2,2>(i,i);
00166 EigenSolver<Matrix<Scalar,2,2> > es(block);
00167 sqrtT.template block<2,2>(i,i)
00168 = (es.eigenvectors() * es.eigenvalues().cwiseSqrt().asDiagonal() * es.eigenvectors().inverse()).real();
00169 }
00170
00171
00172
00173
00174 template <typename MatrixType>
00175 void MatrixSquareRootQuasiTriangular<MatrixType>
00176 ::compute1x1offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T,
00177 typename MatrixType::Index i, typename MatrixType::Index j)
00178 {
00179 Scalar tmp = (sqrtT.row(i).segment(i+1,j-i-1) * sqrtT.col(j).segment(i+1,j-i-1)).value();
00180 sqrtT.coeffRef(i,j) = (T.coeff(i,j) - tmp) / (sqrtT.coeff(i,i) + sqrtT.coeff(j,j));
00181 }
00182
00183
00184 template <typename MatrixType>
00185 void MatrixSquareRootQuasiTriangular<MatrixType>
00186 ::compute1x2offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T,
00187 typename MatrixType::Index i, typename MatrixType::Index j)
00188 {
00189 Matrix<Scalar,1,2> rhs = T.template block<1,2>(i,j);
00190 if (j-i > 1)
00191 rhs -= sqrtT.block(i, i+1, 1, j-i-1) * sqrtT.block(i+1, j, j-i-1, 2);
00192 Matrix<Scalar,2,2> A = sqrtT.coeff(i,i) * Matrix<Scalar,2,2>::Identity();
00193 A += sqrtT.template block<2,2>(j,j).transpose();
00194 sqrtT.template block<1,2>(i,j).transpose() = A.fullPivLu().solve(rhs.transpose());
00195 }
00196
00197
00198 template <typename MatrixType>
00199 void MatrixSquareRootQuasiTriangular<MatrixType>
00200 ::compute2x1offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T,
00201 typename MatrixType::Index i, typename MatrixType::Index j)
00202 {
00203 Matrix<Scalar,2,1> rhs = T.template block<2,1>(i,j);
00204 if (j-i > 2)
00205 rhs -= sqrtT.block(i, i+2, 2, j-i-2) * sqrtT.block(i+2, j, j-i-2, 1);
00206 Matrix<Scalar,2,2> A = sqrtT.coeff(j,j) * Matrix<Scalar,2,2>::Identity();
00207 A += sqrtT.template block<2,2>(i,i);
00208 sqrtT.template block<2,1>(i,j) = A.fullPivLu().solve(rhs);
00209 }
00210
00211
00212 template <typename MatrixType>
00213 void MatrixSquareRootQuasiTriangular<MatrixType>
00214 ::compute2x2offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T,
00215 typename MatrixType::Index i, typename MatrixType::Index j)
00216 {
00217 Matrix<Scalar,2,2> A = sqrtT.template block<2,2>(i,i);
00218 Matrix<Scalar,2,2> B = sqrtT.template block<2,2>(j,j);
00219 Matrix<Scalar,2,2> C = T.template block<2,2>(i,j);
00220 if (j-i > 2)
00221 C -= sqrtT.block(i, i+2, 2, j-i-2) * sqrtT.block(i+2, j, j-i-2, 2);
00222 Matrix<Scalar,2,2> X;
00223 solveAuxiliaryEquation(X, A, B, C);
00224 sqrtT.template block<2,2>(i,j) = X;
00225 }
00226
00227
00228 template <typename MatrixType>
00229 template <typename SmallMatrixType>
00230 void MatrixSquareRootQuasiTriangular<MatrixType>
00231 ::solveAuxiliaryEquation(SmallMatrixType& X, const SmallMatrixType& A,
00232 const SmallMatrixType& B, const SmallMatrixType& C)
00233 {
00234 EIGEN_STATIC_ASSERT((internal::is_same<SmallMatrixType, Matrix<Scalar,2,2> >::value),
00235 EIGEN_INTERNAL_ERROR_PLEASE_FILE_A_BUG_REPORT);
00236
00237 Matrix<Scalar,4,4> coeffMatrix = Matrix<Scalar,4,4>::Zero();
00238 coeffMatrix.coeffRef(0,0) = A.coeff(0,0) + B.coeff(0,0);
00239 coeffMatrix.coeffRef(1,1) = A.coeff(0,0) + B.coeff(1,1);
00240 coeffMatrix.coeffRef(2,2) = A.coeff(1,1) + B.coeff(0,0);
00241 coeffMatrix.coeffRef(3,3) = A.coeff(1,1) + B.coeff(1,1);
00242 coeffMatrix.coeffRef(0,1) = B.coeff(1,0);
00243 coeffMatrix.coeffRef(0,2) = A.coeff(0,1);
00244 coeffMatrix.coeffRef(1,0) = B.coeff(0,1);
00245 coeffMatrix.coeffRef(1,3) = A.coeff(0,1);
00246 coeffMatrix.coeffRef(2,0) = A.coeff(1,0);
00247 coeffMatrix.coeffRef(2,3) = B.coeff(1,0);
00248 coeffMatrix.coeffRef(3,1) = A.coeff(1,0);
00249 coeffMatrix.coeffRef(3,2) = B.coeff(0,1);
00250
00251 Matrix<Scalar,4,1> rhs;
00252 rhs.coeffRef(0) = C.coeff(0,0);
00253 rhs.coeffRef(1) = C.coeff(0,1);
00254 rhs.coeffRef(2) = C.coeff(1,0);
00255 rhs.coeffRef(3) = C.coeff(1,1);
00256
00257 Matrix<Scalar,4,1> result;
00258 result = coeffMatrix.fullPivLu().solve(rhs);
00259
00260 X.coeffRef(0,0) = result.coeff(0);
00261 X.coeffRef(0,1) = result.coeff(1);
00262 X.coeffRef(1,0) = result.coeff(2);
00263 X.coeffRef(1,1) = result.coeff(3);
00264 }
00265
00266
/** \brief Square root of a matrix via its complex Schur decomposition.
  *
  * compute() reduces the stored matrix to upper triangular form with
  * ComplexSchur, computes the square root of the triangular factor with a
  * column-by-column recurrence and transforms the result back.
  *
  * \tparam MatrixType  type of the input matrix (used with ComplexSchur).
  */
template <typename MatrixType>
class MatrixSquareRootTriangular
{
  public:
    /** \brief Stores a reference to \p A; nothing is computed yet.
      * \param[in] A  square matrix; it must outlive this object.
      */
    MatrixSquareRootTriangular(const MatrixType& A)
      : m_A(A)
    {
      eigen_assert(A.rows() == A.cols());
    }

    /** \brief Writes the square root of the stored matrix into \p result.
      * \param[out] result  resized to the input's size; also used as
      *                     workspace while the square root is built.
      */
    template <typename ResultType> void compute(ResultType &result);

  private:
    // Matrix whose square root is computed; held by reference, not copied.
    const MatrixType& m_A;
};
00302
00303 template <typename MatrixType>
00304 template <typename ResultType>
00305 void MatrixSquareRootTriangular<MatrixType>::compute(ResultType &result)
00306 {
00307
00308 const ComplexSchur<MatrixType> schurOfA(m_A);
00309 const MatrixType& T = schurOfA.matrixT();
00310 const MatrixType& U = schurOfA.matrixU();
00311
00312
00313
00314 result.resize(m_A.rows(), m_A.cols());
00315 typedef typename MatrixType::Index Index;
00316 for (Index i = 0; i < m_A.rows(); i++) {
00317 result.coeffRef(i,i) = internal::sqrt(T.coeff(i,i));
00318 }
00319 for (Index j = 1; j < m_A.cols(); j++) {
00320 for (Index i = j-1; i >= 0; i--) {
00321 typedef typename MatrixType::Scalar Scalar;
00322
00323 Scalar tmp = (result.row(i).segment(i+1,j-i-1) * result.col(j).segment(i+1,j-i-1)).value();
00324
00325 result.coeffRef(i,j) = (T.coeff(i,j) - tmp) / (result.coeff(i,i) + result.coeff(j,j));
00326 }
00327 }
00328
00329
00330 MatrixType tmp;
00331 tmp.noalias() = U * result.template triangularView<Upper>();
00332 result.noalias() = tmp * U.adjoint();
00333 }
00334
00335
00343 template <typename MatrixType, int IsComplex = NumTraits<typename internal::traits<MatrixType>::Scalar>::IsComplex>
00344 class MatrixSquareRoot
00345 {
00346 public:
00347
00355 MatrixSquareRoot(const MatrixType& A);
00356
00364 template <typename ResultType> void compute(ResultType &result);
00365 };
00366
00367
00368
00369
00370 template <typename MatrixType>
00371 class MatrixSquareRoot<MatrixType, 0>
00372 {
00373 public:
00374
00375 MatrixSquareRoot(const MatrixType& A)
00376 : m_A(A)
00377 {
00378 eigen_assert(A.rows() == A.cols());
00379 }
00380
00381 template <typename ResultType> void compute(ResultType &result)
00382 {
00383
00384 const RealSchur<MatrixType> schurOfA(m_A);
00385 const MatrixType& T = schurOfA.matrixT();
00386 const MatrixType& U = schurOfA.matrixU();
00387
00388
00389 MatrixSquareRootQuasiTriangular<MatrixType> tmp(T);
00390 MatrixType sqrtT = MatrixType::Zero(m_A.rows(), m_A.rows());
00391 tmp.compute(sqrtT);
00392
00393
00394 result = U * sqrtT * U.adjoint();
00395 }
00396
00397 private:
00398 const MatrixType& m_A;
00399 };
00400
00401
00402
00403
00404 template <typename MatrixType>
00405 class MatrixSquareRoot<MatrixType, 1>
00406 {
00407 public:
00408
00409 MatrixSquareRoot(const MatrixType& A)
00410 : m_A(A)
00411 {
00412 eigen_assert(A.rows() == A.cols());
00413 }
00414
00415 template <typename ResultType> void compute(ResultType &result)
00416 {
00417
00418 const ComplexSchur<MatrixType> schurOfA(m_A);
00419 const MatrixType& T = schurOfA.matrixT();
00420 const MatrixType& U = schurOfA.matrixU();
00421
00422
00423 MatrixSquareRootTriangular<MatrixType> tmp(T);
00424 MatrixType sqrtT = MatrixType::Zero(m_A.rows(), m_A.rows());
00425 tmp.compute(sqrtT);
00426
00427
00428 result = U * sqrtT * U.adjoint();
00429 }
00430
00431 private:
00432 const MatrixType& m_A;
00433 };
00434
00435
00448 template<typename Derived> class MatrixSquareRootReturnValue
00449 : public ReturnByValue<MatrixSquareRootReturnValue<Derived> >
00450 {
00451 typedef typename Derived::Index Index;
00452 public:
00458 MatrixSquareRootReturnValue(const Derived& src) : m_src(src) { }
00459
00465 template <typename ResultType>
00466 inline void evalTo(ResultType& result) const
00467 {
00468 const typename Derived::PlainObject srcEvaluated = m_src.eval();
00469 MatrixSquareRoot<typename Derived::PlainObject> me(srcEvaluated);
00470 me.compute(result);
00471 }
00472
00473 Index rows() const { return m_src.rows(); }
00474 Index cols() const { return m_src.cols(); }
00475
00476 protected:
00477 const Derived& m_src;
00478 private:
00479 MatrixSquareRootReturnValue& operator=(const MatrixSquareRootReturnValue&);
00480 };
00481
00482 namespace internal {
00483 template<typename Derived>
00484 struct traits<MatrixSquareRootReturnValue<Derived> >
00485 {
00486 typedef typename Derived::PlainObject ReturnType;
00487 };
00488 }
00489
00490 template <typename Derived>
00491 const MatrixSquareRootReturnValue<Derived> MatrixBase<Derived>::sqrt() const
00492 {
00493 eigen_assert(rows() == cols());
00494 return MatrixSquareRootReturnValue<Derived>(derived());
00495 }
00496
00497 }
00498
00499 #endif // EIGEN_MATRIX_SQUARE_ROOT