00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026 #ifndef EIGEN_DIAGONALPRODUCT_H
00027 #define EIGEN_DIAGONALPRODUCT_H
00028
00029 namespace Eigen {
00030
00031 namespace internal {
00032 template<typename MatrixType, typename DiagonalType, int ProductOrder>
00033 struct traits<DiagonalProduct<MatrixType, DiagonalType, ProductOrder> >
00034 : traits<MatrixType>
00035 {
00036 typedef typename scalar_product_traits<typename MatrixType::Scalar, typename DiagonalType::Scalar>::ReturnType Scalar;
00037 enum {
00038 RowsAtCompileTime = MatrixType::RowsAtCompileTime,
00039 ColsAtCompileTime = MatrixType::ColsAtCompileTime,
00040 MaxRowsAtCompileTime = MatrixType::MaxRowsAtCompileTime,
00041 MaxColsAtCompileTime = MatrixType::MaxColsAtCompileTime,
00042
00043 _StorageOrder = MatrixType::Flags & RowMajorBit ? RowMajor : ColMajor,
00044 _PacketOnDiag = !((int(_StorageOrder) == RowMajor && int(ProductOrder) == OnTheLeft)
00045 ||(int(_StorageOrder) == ColMajor && int(ProductOrder) == OnTheRight)),
00046 _SameTypes = is_same<typename MatrixType::Scalar, typename DiagonalType::Scalar>::value,
00047
00048
00049 _Vectorizable = bool(int(MatrixType::Flags)&PacketAccessBit) && _SameTypes && ((!_PacketOnDiag) || (bool(int(DiagonalType::Flags)&PacketAccessBit))),
00050
00051 Flags = (HereditaryBits & (unsigned int)(MatrixType::Flags)) | (_Vectorizable ? PacketAccessBit : 0),
00052 CoeffReadCost = NumTraits<Scalar>::MulCost + MatrixType::CoeffReadCost + DiagonalType::DiagonalVectorType::CoeffReadCost
00053 };
00054 };
00055 }
00056
00057 template<typename MatrixType, typename DiagonalType, int ProductOrder>
00058 class DiagonalProduct : internal::no_assignment_operator,
00059 public MatrixBase<DiagonalProduct<MatrixType, DiagonalType, ProductOrder> >
00060 {
00061 public:
00062
00063 typedef MatrixBase<DiagonalProduct> Base;
00064 EIGEN_DENSE_PUBLIC_INTERFACE(DiagonalProduct)
00065
00066 inline DiagonalProduct(const MatrixType& matrix, const DiagonalType& diagonal)
00067 : m_matrix(matrix), m_diagonal(diagonal)
00068 {
00069 eigen_assert(diagonal.diagonal().size() == (ProductOrder == OnTheLeft ? matrix.rows() : matrix.cols()));
00070 }
00071
00072 inline Index rows() const { return m_matrix.rows(); }
00073 inline Index cols() const { return m_matrix.cols(); }
00074
00075 const Scalar coeff(Index row, Index col) const
00076 {
00077 return m_diagonal.diagonal().coeff(ProductOrder == OnTheLeft ? row : col) * m_matrix.coeff(row, col);
00078 }
00079
00080 template<int LoadMode>
00081 EIGEN_STRONG_INLINE PacketScalar packet(Index row, Index col) const
00082 {
00083 enum {
00084 StorageOrder = Flags & RowMajorBit ? RowMajor : ColMajor
00085 };
00086 const Index indexInDiagonalVector = ProductOrder == OnTheLeft ? row : col;
00087
00088 return packet_impl<LoadMode>(row,col,indexInDiagonalVector,typename internal::conditional<
00089 ((int(StorageOrder) == RowMajor && int(ProductOrder) == OnTheLeft)
00090 ||(int(StorageOrder) == ColMajor && int(ProductOrder) == OnTheRight)), internal::true_type, internal::false_type>::type());
00091 }
00092
00093 protected:
00094 template<int LoadMode>
00095 EIGEN_STRONG_INLINE PacketScalar packet_impl(Index row, Index col, Index id, internal::true_type) const
00096 {
00097 return internal::pmul(m_matrix.template packet<LoadMode>(row, col),
00098 internal::pset1<PacketScalar>(m_diagonal.diagonal().coeff(id)));
00099 }
00100
00101 template<int LoadMode>
00102 EIGEN_STRONG_INLINE PacketScalar packet_impl(Index row, Index col, Index id, internal::false_type) const
00103 {
00104 enum {
00105 InnerSize = (MatrixType::Flags & RowMajorBit) ? MatrixType::ColsAtCompileTime : MatrixType::RowsAtCompileTime,
00106 DiagonalVectorPacketLoadMode = (LoadMode == Aligned && ((InnerSize%16) == 0)) ? Aligned : Unaligned
00107 };
00108 return internal::pmul(m_matrix.template packet<LoadMode>(row, col),
00109 m_diagonal.diagonal().template packet<DiagonalVectorPacketLoadMode>(id));
00110 }
00111
00112 typename MatrixType::Nested m_matrix;
00113 typename DiagonalType::Nested m_diagonal;
00114 };
00115
00118 template<typename Derived>
00119 template<typename DiagonalDerived>
00120 inline const DiagonalProduct<Derived, DiagonalDerived, OnTheRight>
00121 MatrixBase<Derived>::operator*(const DiagonalBase<DiagonalDerived> &diagonal) const
00122 {
00123 return DiagonalProduct<Derived, DiagonalDerived, OnTheRight>(derived(), diagonal.derived());
00124 }
00125
00128 template<typename DiagonalDerived>
00129 template<typename MatrixDerived>
00130 inline const DiagonalProduct<MatrixDerived, DiagonalDerived, OnTheLeft>
00131 DiagonalBase<DiagonalDerived>::operator*(const MatrixBase<MatrixDerived> &matrix) const
00132 {
00133 return DiagonalProduct<MatrixDerived, DiagonalDerived, OnTheLeft>(matrix.derived(), derived());
00134 }
00135
00136 }
00137
00138 #endif // EIGEN_DIAGONALPRODUCT_H