Skip to content

Commit a55cb3b

Browse files
Finish random variable port to backend
1 parent 3f3062b commit a55cb3b

File tree

7 files changed

+29
-10
lines changed

7 files changed

+29
-10
lines changed

src/probnum/backend/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
"is_floating",
1919
"finfo",
2020
# Shape Arithmetic
21+
"reshape",
2122
"atleast_1d",
2223
"atleast_2d",
2324
"broadcast_arrays",

src/probnum/backend/_core/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
finfo = _core.finfo
2828

2929
# Shape Arithmetic
30+
reshape = _core.reshape
3031
atleast_1d = _core.atleast_1d
3132
atleast_2d = _core.atleast_2d
3233
broadcast_arrays = _core.broadcast_arrays

src/probnum/backend/_core/_jax.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
ones_like,
3434
pi,
3535
promote_types,
36+
reshape,
3637
sin,
3738
single,
3839
sqrt,

src/probnum/backend/_core/_numpy.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
ones_like,
3333
pi,
3434
promote_types,
35+
reshape,
3536
sin,
3637
single,
3738
sqrt,

src/probnum/backend/_core/_torch.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
maximum,
2727
pi,
2828
promote_types,
29+
reshape,
2930
sin,
3031
sqrt,
3132
)

src/probnum/randvars/_arithmetic.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -123,19 +123,21 @@ def _rv_binary_op(rv1: _RandomVariable, rv2: _RandomVariable) -> _RandomVariable
123123

124124

125125
def _make_rv_binary_op_result_shape_dtype_sample_fn(op_fn, rv1, rv2):
126-
seed = backend.random.seed(1)
127-
sample_fn = lambda sample_shape: op_fn(
128-
rv1.sample(seed=seed, sample_shape=sample_shape),
129-
rv2.sample(seed=seed, sample_shape=sample_shape),
130-
)
126+
def sample(seed, sample_shape):
127+
seed1, seed2, _ = backend.random.split(seed, 3)
128+
129+
return op_fn(
130+
rv1.sample(seed=seed1, sample_shape=sample_shape),
131+
rv2.sample(seed=seed2, sample_shape=sample_shape),
132+
)
131133

132134
# Infer shape and dtype
133-
infer_sample = sample_fn(())
135+
infer_sample = sample(backend.random.seed(1), ())
134136

135137
shape = infer_sample.shape
136138
dtype = infer_sample.dtype
137139

138-
return shape, dtype, sample_fn
140+
return shape, dtype, sample
139141

140142

141143
def _generic_rv_add(rv1: _RandomVariable, rv2: _RandomVariable) -> _RandomVariable:

src/probnum/randvars/_random_variable.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -313,7 +313,10 @@ def var(self) -> ArrayType:
313313
"""
314314
if self.__var is None:
315315
try:
316-
var = np.diag(self.cov).reshape(self.__shape).copy()
316+
var = backend.reshape(
317+
backend.diag(self.cov),
318+
self.__shape,
319+
).copy()
317320
except NotImplementedError as exc:
318321
raise NotImplementedError from exc
319322
else:
@@ -514,8 +517,12 @@ def quantile(self, p: ArrayType) -> ArrayType:
514517
return quantile
515518

516519
def __getitem__(self, key: ArrayLikeGetitemArgType) -> "RandomVariable":
520+
# Shape inference
521+
# For simplicity, this should not be computed using backend, but rather in numpy
522+
shape = np.broadcast_to(np.empty(()), self.shape)[key].shape
523+
517524
return RandomVariable(
518-
shape=np.empty(shape=self.shape)[key].shape,
525+
shape=shape,
519526
dtype=self.dtype,
520527
sample=lambda rng, size: self.sample(rng, size)[key],
521528
mode=lambda: self.mode[key],
@@ -557,8 +564,13 @@ def transpose(self, *axes: int) -> "RandomVariable":
557564
axes :
558565
See documentation of :meth:`numpy.ndarray.transpose`.
559566
"""
567+
568+
# Shape inference
569+
# For simplicity, this should not be computed using backend, but rather in numpy
570+
shape = np.broadcast_to(np.empty(()), self.shape).transpose(*axes).shape
571+
560572
return RandomVariable(
561-
shape=np.empty(shape=self.shape).transpose(*axes).shape,
573+
shape=shape,
562574
dtype=self.dtype,
563575
sample=lambda rng, size: self.sample(rng, size).transpose(*axes),
564576
mode=lambda: self.mode.transpose(*axes),

0 commit comments

Comments
 (0)