Skip to content

Commit 3c8203d

Browse files
Refactor Normal._cdf to use backend.Dispatcher
1 parent f9cdf13 commit 3c8203d

File tree

3 files changed

+11
-7
lines changed

3 files changed

+11
-7
lines changed

src/probnum/backend/_dispatcher.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,10 @@ def torch(self, impl: Callable) -> Callable:
4747
return impl
4848

4949
def __call__(self, *args, **kwargs):
50+
if BACKEND not in self._impl:
51+
raise NotImplementedError(
52+
f"This function is not implemented for the backend `{BACKEND.name}`"
53+
)
5054
return self._impl[BACKEND](*args, **kwargs)
5155

5256
def __get__(self, obj, objtype=None):

src/probnum/randvars/_normal.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -501,10 +501,10 @@ def _logpdf(self, x: ArrayType) -> ArrayType:
501501

502502
return res
503503

504-
def _cdf(self, x: ArrayType) -> ArrayType:
505-
if backend.BACKEND is not backend.Backend.NUMPY:
506-
raise NotImplementedError()
504+
_cdf = backend.Dispatcher()
507505

506+
@_cdf.numpy
507+
def _cdf_numpy(self, x: ArrayType) -> ArrayType:
508508
import scipy.stats # pylint: disable=import-outside-toplevel
509509

510510
return scipy.stats.multivariate_normal.cdf(
@@ -513,10 +513,10 @@ def _cdf(self, x: ArrayType) -> ArrayType:
513513
cov=self.dense_cov,
514514
)
515515

516-
def _logcdf(self, x: ArrayType) -> ArrayType:
517-
if backend.BACKEND is not backend.Backend.NUMPY:
518-
raise NotImplementedError()
516+
_logcdf = backend.Dispatcher()
519517

518+
@_logcdf.numpy
519+
def _logcdf_numpy(self, x: ArrayType) -> ArrayType:
520520
import scipy.stats # pylint: disable=import-outside-toplevel
521521

522522
return scipy.stats.multivariate_normal.logcdf(

src/probnum/randvars/_random_variable.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ def ndim(self) -> int:
162162
def size(self) -> int:
163163
"""Size of realizations of the random variable, defined as the product over all
164164
components of :attr:`shape`."""
165-
return functools.reduce(operator.mul, self.__shape, initial=1)
165+
return functools.reduce(operator.mul, self.__shape, 1)
166166

167167
@property
168168
def dtype(self) -> backend.dtype:

0 commit comments

Comments
 (0)