@@ -37,10 +37,9 @@ struct BlockDiagonalLDLT {
3737 solve (const Eigen::Matrix<_Scalar, _Rows, _Cols> &rhs,
3838 ThreadPool *pool) const ;
3939
40- template <class _Scalar , int _Rows, int _Cols>
41- Eigen::Matrix<_Scalar, _Rows, _Cols>
42- sqrt_solve (const Eigen::Matrix<_Scalar, _Rows, _Cols> &rhs,
43- ThreadPool *pool) const ;
40+ template <typename Derived>
41+ Eigen::MatrixXd sqrt_solve (const Eigen::DenseBase<Derived> &rhs,
42+ ThreadPool *pool) const ;
4443
4544 BlockDiagonal sqrt_transpose () const ;
4645
@@ -51,6 +50,8 @@ struct BlockDiagonalLDLT {
5150 Eigen::Index rows () const ;
5251
5352 Eigen::Index cols () const ;
53+
54+ bool operator ==(const BlockDiagonalLDLT &other) const ;
5455};
5556
5657struct BlockDiagonal {
@@ -141,15 +142,16 @@ BlockDiagonalLDLT::solve(const Eigen::Matrix<_Scalar, _Rows, _Cols> &rhs,
141142 return output;
142143}
143144
144- template <class _Scalar , int _Rows, int _Cols >
145- inline Eigen::Matrix<_Scalar, _Rows, _Cols>
146- BlockDiagonalLDLT::sqrt_solve (const Eigen::Matrix<_Scalar, _Rows, _Cols > &rhs,
145+ template <typename Derived >
146+ inline Eigen::MatrixXd
147+ BlockDiagonalLDLT::sqrt_solve (const Eigen::DenseBase<Derived > &rhs,
147148 ThreadPool *pool) const {
148149 ALBATROSS_ASSERT (cols () == rhs.rows ());
149- Eigen::Matrix<_Scalar, _Rows, _Cols> output (rows (), rhs.cols ());
150+ Eigen::MatrixXd output (rows (), rhs.cols ());
150151
151152 auto solve_and_fill_one_block = [&](const size_t i, const Eigen::Index row) {
152- const auto rhs_chunk = rhs.block (row, 0 , blocks[i].rows (), rhs.cols ());
153+ const auto rhs_chunk =
154+ rhs.derived ().block (row, 0 , blocks[i].rows (), rhs.cols ());
153155 output.block (row, 0 , blocks[i].rows (), rhs.cols ()) =
154156 blocks[i].sqrt_solve (rhs_chunk);
155157 };
@@ -182,6 +184,10 @@ inline Eigen::Index BlockDiagonalLDLT::cols() const {
182184 return n;
183185}
184186
187+ inline bool
188+ BlockDiagonalLDLT::operator ==(const BlockDiagonalLDLT &other) const {
189+ return blocks == other.blocks ;
190+ }
185191/*
186192 * Block Diagonal
187193 */
0 commit comments