Skip to content

Commit c33d436

Browse files
committed
add block diag methods
1 parent 884903e commit c33d436

File tree

1 file changed

+15
-9
lines changed

1 file changed

+15
-9
lines changed

include/albatross/src/linalg/block_diagonal.hpp

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,9 @@ struct BlockDiagonalLDLT {
3737
solve(const Eigen::Matrix<_Scalar, _Rows, _Cols> &rhs,
3838
ThreadPool *pool) const;
3939

40-
template <class _Scalar, int _Rows, int _Cols>
41-
Eigen::Matrix<_Scalar, _Rows, _Cols>
42-
sqrt_solve(const Eigen::Matrix<_Scalar, _Rows, _Cols> &rhs,
43-
ThreadPool *pool) const;
40+
template <typename Derived>
41+
Eigen::MatrixXd sqrt_solve(const Eigen::DenseBase<Derived> &rhs,
42+
ThreadPool *pool) const;
4443

4544
BlockDiagonal sqrt_transpose() const;
4645

@@ -51,6 +50,8 @@ struct BlockDiagonalLDLT {
5150
Eigen::Index rows() const;
5251

5352
Eigen::Index cols() const;
53+
54+
bool operator==(const BlockDiagonalLDLT &other) const;
5455
};
5556

5657
struct BlockDiagonal {
@@ -141,15 +142,16 @@ BlockDiagonalLDLT::solve(const Eigen::Matrix<_Scalar, _Rows, _Cols> &rhs,
141142
return output;
142143
}
143144

144-
template <class _Scalar, int _Rows, int _Cols>
145-
inline Eigen::Matrix<_Scalar, _Rows, _Cols>
146-
BlockDiagonalLDLT::sqrt_solve(const Eigen::Matrix<_Scalar, _Rows, _Cols> &rhs,
145+
template <typename Derived>
146+
inline Eigen::MatrixXd
147+
BlockDiagonalLDLT::sqrt_solve(const Eigen::DenseBase<Derived> &rhs,
147148
ThreadPool *pool) const {
148149
ALBATROSS_ASSERT(cols() == rhs.rows());
149-
Eigen::Matrix<_Scalar, _Rows, _Cols> output(rows(), rhs.cols());
150+
Eigen::MatrixXd output(rows(), rhs.cols());
150151

151152
auto solve_and_fill_one_block = [&](const size_t i, const Eigen::Index row) {
152-
const auto rhs_chunk = rhs.block(row, 0, blocks[i].rows(), rhs.cols());
153+
const auto rhs_chunk =
154+
rhs.derived().block(row, 0, blocks[i].rows(), rhs.cols());
153155
output.block(row, 0, blocks[i].rows(), rhs.cols()) =
154156
blocks[i].sqrt_solve(rhs_chunk);
155157
};
@@ -182,6 +184,10 @@ inline Eigen::Index BlockDiagonalLDLT::cols() const {
182184
return n;
183185
}
184186

187+
inline bool
188+
BlockDiagonalLDLT::operator==(const BlockDiagonalLDLT &other) const {
189+
return blocks == other.blocks;
190+
}
185191
/*
186192
* Block Diagonal
187193
*/

0 commit comments

Comments
 (0)