Skip to content

Commit 0e1c379

Browse files
Bugfixes for Normal distribution in torch backend
1 parent 38908ae commit 0e1c379

File tree

2 files changed

+8
-6
lines changed

2 files changed

+8
-6
lines changed

src/probnum/backend/linalg/_torch.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,15 +26,15 @@ def solve_triangular(
2626
upper=not lower,
2727
transpose=transpose,
2828
unitriangular=unit_diagonal,
29-
)[:, 0]
29+
).solution[:, 0]
3030

3131
return torch.triangular_solve(
3232
b,
3333
A,
3434
upper=not lower,
3535
transpose=transpose,
3636
unitriangular=unit_diagonal,
37-
)
37+
).solution
3838

3939

4040
def solve_cholesky(

src/probnum/randvars/_normal.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -204,9 +204,9 @@ def cov_cholesky(self) -> ArrayType:
204204

205205
return self._cov_cholesky
206206

207-
@property
207+
@functools.cached_property
208208
def _cov_matrix_cholesky(self) -> ArrayType:
209-
return self._cov_op_cholesky.todense()
209+
return backend.asarray(self._cov_op_cholesky.todense())
210210

211211
@property
212212
def _cov_op_cholesky(self) -> ArrayType:
@@ -453,7 +453,9 @@ def _sample(self, seed: SeedLike, sample_shape: ShapeType = ()) -> ArrayType:
453453
dtype=self.dtype,
454454
)
455455

456-
samples = self._cov_op_cholesky(backend.to_numpy(samples), axis=-1)
456+
samples = backend.asarray(
457+
self._cov_op_cholesky(backend.to_numpy(samples), axis=-1)
458+
)
457459
samples += self.dense_mean
458460

459461
return samples.reshape(sample_shape + self.shape)
@@ -489,7 +491,7 @@ def _logpdf(self, x: ArrayType) -> ArrayType:
489491
# TODO (#569): Replace `solve_triangular` with:
490492
# self._cov_op_cholesky.inv() @ x_centered[..., None]
491493
x_whitened = backend.linalg.solve_triangular(
492-
self._cov_matrix_cholesky,
494+
backend.asarray(self._cov_matrix_cholesky),
493495
x_centered[..., None],
494496
lower=True,
495497
)[..., 0]

0 commit comments

Comments
 (0)