diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS
index 86e75a7ae..c2bcb2118 100644
--- a/.github/CODEOWNERS
+++ b/.github/CODEOWNERS
@@ -1,11 +1,18 @@
 * @probabilistic-numerics/probnum-global-codeowners
 
+# Compute Backends
+/src/probnum/backend/ @marvinpfoertner @JonathanWenger
+/tests/probnum/backend/ @marvinpfoertner @JonathanWenger
+
+# Compatibility Functions
+/src/probnum/compat/ @marvinpfoertner @JonathanWenger
+
 # Differential Equations
 /src/probnum/diffeq/ @pnkraemer @schmidtjonathan
 /src/probnum/problems/zoo/diffeq/ @pnkraemer @schmidtjonathan
 /tests/test_diffeq/ @pnkraemer @schmidtjonathan
-/tests/test_problems/test_zoo/test_diffeq/ @pnkraemer @schmidtjonathan
+/tests/probnum/problems/zoo/diffeq/ @pnkraemer @schmidtjonathan
 
 /benchmarks/ivpsolvers.py @pnkraemer @schmidtjonathan
 
@@ -14,7 +21,7 @@
 /src/probnum/problems/zoo/filtsmooth/ @pnkraemer @schmidtjonathan
 /tests/test_filtsmooth/ @pnkraemer @schmidtjonathan
-/tests/test_problems/test_zoo/test_filtsmooth/ @pnkraemer @schmidtjonathan
+/tests/probnum/problems/zoo/filtsmooth/ @pnkraemer @schmidtjonathan
 
 /benchmarks/filtsmooth.py @pnkraemer @schmidtjonathan
 
@@ -22,14 +29,14 @@
 /src/probnum/linalg/ @JonathanWenger @marvinpfoertner
 /src/probnum/problems/zoo/linalg/ @JonathanWenger @marvinpfoertner
-/tests/test_linalg/ @JonathanWenger @marvinpfoertner
-/tests/test_problems/test_zoo/test_linalg/ @JonathanWenger @marvinpfoertner
+/tests/probnum/linalg/ @JonathanWenger @marvinpfoertner
+/tests/probnum/problems/zoo/linalg/ @JonathanWenger @marvinpfoertner
 
 /benchmarks/linearsolvers.py @JonathanWenger @marvinpfoertner
 
 # Linear Operators
 /src/probnum/linops/ @marvinpfoertner @JonathanWenger
-/tests/test_linops/ @marvinpfoertner @JonathanWenger
+/tests/probnum/linops/ @marvinpfoertner @JonathanWenger
 
 /benchmarks/linops.py @marvinpfoertner @JonathanWenger
 
@@ -38,24 +45,20 @@
 /src/probnum/problems/zoo/quad/ @mmahsereci @tskarvone
 /tests/test_quad/ @mmahsereci @tskarvone
-/tests/test_problems/test_zoo/test_quad/ @mmahsereci @tskarvone
+/tests/probnum/problems/zoo/quad/ @mmahsereci @tskarvone
 
 # Random Processes & Kernels
 /src/probnum/randprocs/ @marvinpfoertner @JonathanWenger
-/tests/test_randprocs/ @marvinpfoertner @JonathanWenger
+/tests/probnum/randprocs/ @marvinpfoertner @JonathanWenger
 
 /benchmarks/randprocs.py @marvinpfoertner @JonathanWenger
 /benchmarks/kernels.py @marvinpfoertner @JonathanWenger
 
 /src/probnum/randprocs/markov/ @pnkraemer @schmidtjonathan
-/tests/test_randprocs/test_markov/ @pnkraemer @schmidtjonathan
+/tests/probnum/randprocs/markov/ @pnkraemer @schmidtjonathan
 
 # Random Variables
 /src/probnum/randvars/ @marvinpfoertner @JonathanWenger
-/tests/test_randvars/ @marvinpfoertner @JonathanWenger
+/tests/probnum/randvars/ @marvinpfoertner @JonathanWenger
 /benchmarks/random_variables.py @marvinpfoertner @JonathanWenger
-
-# Utils
-/src/probnum/utils/linalg/_cholesky_updates.py @pnkraemer
-/tests/test_utils/test_linalg/test_cholesky_updates.py @pnkraemer
diff --git a/.github/workflows/CI-build.yml b/.github/workflows/CI-build.yml
index e5297ee4f..199f0b633 100644
--- a/.github/workflows/CI-build.yml
+++ b/.github/workflows/CI-build.yml
@@ -16,6 +16,7 @@ jobs:
       matrix:
         platform: [ubuntu-latest, macos-latest, windows-latest]
         python: ["3.8", "3.9", "3.10"]
+        backend: ["numpy", "jax", "torch"]
 
     steps:
       - uses: actions/checkout@v2
@@ -28,11 +29,11 @@
       - name: Install Tox and any other packages
         run: pip install tox
       - name: Run Tox
-        # Run tox using the version of Python in `PATH`
-        run: tox -e py3
+        # Run tox using the version of Python in `PATH` and the corresponding compute backend
+        run: tox -e py3-${{ matrix.backend }}
      - name: Upload coverage report to Codecov
        uses: codecov/codecov-action@v2
-        if: startsWith(matrix.platform,'ubuntu') && matrix.python == '3.8'
+        if: startsWith(matrix.platform,'ubuntu') && matrix.python == '3.8' && matrix.backend == 'numpy'
 
   documentation:
     runs-on: ubuntu-latest
diff --git a/benchmarks/linearsolvers.py b/benchmarks/linearsolvers.py
index f79befb4e..3268b29fe 100644
--- a/benchmarks/linearsolvers.py
+++ b/benchmarks/linearsolvers.py
@@ -1,7 +1,7 @@
 """Benchmarks for linear solvers."""
 import numpy as np
 
-from probnum import linops, problems, randvars
+from probnum import backend, linops, problems, randvars
 from probnum.linalg import problinsolve
 from probnum.problems.zoo.linalg import random_sparse_spd_matrix, random_spd_matrix
@@ -11,26 +11,36 @@
 def get_linear_system(name: str, dim: int):
-    rng = np.random.default_rng(0)
+    rng_state = backend.random.rng_state(42)
 
     if name == "dense":
         if dim > 1000:
             raise NotImplementedError()
-        A = random_spd_matrix(rng=rng, dim=dim)
+        rng_state, rng_state_A = backend.random.split(rng_state, 2)
+        A = random_spd_matrix(rng_state=rng_state_A, shape=(dim, dim))
     elif name == "sparse":
+        rng_state, rng_state_A_sparse = backend.random.split(rng_state, 2)
         A = random_sparse_spd_matrix(
-            rng=rng, dim=dim, density=np.minimum(1.0, 1000 / dim**2)
+            rng_state=rng_state_A_sparse,
+            shape=(dim, dim),
+            density=backend.minimum(1.0, 1000 / dim**2),
         )
     elif name == "linop":
         if dim > 100:
             # TODO: Larger benchmarks currently fail. Remove once the PLS refactor
             # (https://github.com/probabilistic-numerics/probnum/issues/51) is resolved.
             raise NotImplementedError()
-        A = linops.Scaling(factors=rng.normal(size=(dim,)))
+        rng_state, rng_state_A_linop = backend.random.split(rng_state, 2)
+        A = linops.Scaling(
+            factors=backend.random.standard_normal(
+                rng_state=rng_state_A_linop, shape=(dim,)
+            )
+        )
     else:
         raise NotImplementedError()
 
-    solution = rng.normal(size=(dim,))
+    rng_state, rng_state_solution = backend.random.split(rng_state, 2)
+    solution = backend.random.standard_normal(rng_state_solution, shape=(dim,))
 
     b = A @ solution
 
     return problems.LinearSystem(A=A, b=b, solution=solution)
@@ -72,14 +82,16 @@
     def peakmem_solve(self, linsys, dim):
         problinsolve(A=self.linsys.A, b=self.linsys.b)
 
     def track_residual_norm(self, linsys, dim):
-        return np.linalg.norm(self.linsys.b - self.linsys.A @ self.xhat.mean)
+        return backend.linalg.vector_norm(
+            self.linsys.b - self.linsys.A @ self.xhat.mean
+        ).item()
 
     def track_error_2norm(self, linsys, dim):
-        return np.linalg.norm(self.linsys.solution - self.xhat.mean)
+        return backend.linalg.vector_norm(self.linsys.solution - self.xhat.mean).item()
 
     def track_error_Anorm(self, linsys, dim):
         diff = self.linsys.solution - self.xhat.mean
-        return np.sqrt(np.inner(diff, self.linsys.A @ diff))
+        return backend.sqrt(diff @ (self.linsys.A @ diff)).item()
 
 
 class PosteriorBelief:
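Note on the RNG pattern introduced above: `backend.random` replaces NumPy's stateful `Generator` with explicit, splittable RNG states. A minimal sketch of the intended usage (the function names, `rng_state`, `split`, and `standard_normal`, are taken from this diff; anything beyond that is an assumption):

```python
from probnum import backend

# Create an RNG state from an integer seed.
rng_state = backend.random.rng_state(42)

# Split off an independent state for each sampling operation
# (JAX-style functional RNG handling): the first state is carried
# forward, the second is consumed by the draw below.
rng_state, rng_state_x = backend.random.split(rng_state, 2)

x = backend.random.standard_normal(rng_state=rng_state_x, shape=(10,))
```

Because states are split rather than mutated in place, each draw is reproducible regardless of the order in which the benchmark cases run.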
"multivar_normal": - randvar = rvs.Normal(mean=mean_1d, cov=cov_1d) + randvar = randvars.Normal(mean=mean_1d, cov=cov_1d) elif rv_name == "matrixvar_normal": - randvar = rvs.Normal(mean=mean_2d_mat, cov=cov_2d_kron) + randvar = randvars.Normal(mean=mean_2d_mat, cov=cov_2d_kron) elif rv_name == "symmatrixvar_normal": - randvar = rvs.Normal(mean=mean_2d_mat, cov=cov_2d_symkron) + randvar = randvars.Normal(mean=mean_2d_mat, cov=cov_2d_symkron) elif rv_name == "operatorvar_normal": - randvar = rvs.Normal(mean=mean_2d_linop, cov=cov_2d_symkron) + randvar = randvars.Normal(mean=mean_2d_linop, cov=cov_2d_symkron) else: raise ValueError("Random variable not found.") @@ -87,14 +87,14 @@ class Sampling: params = [RV_NAMES] def setup(self, randvar): - self.rng = np.random.default_rng(seed=2) + self.rng_state = backend.random.rng_state(23529) self.n_samples = 1000 self.randvar = get_randvar(rv_name=randvar) def time_sample(self, randvar): """Times sampling from this distribution.""" - self.randvar.sample(rng=self.rng, size=self.n_samples) + self.randvar.sample(rng_state=self.rng_state, sample_shape=self.n_samples) def peakmem_sample(self, randvar): """Peak memory of sampling process.""" - self.randvar.sample(rng=self.rng, size=self.n_samples) + self.randvar.sample(rng_state=self.rng_state, sample_shape=self.n_samples) diff --git a/docs/source/api.rst b/docs/source/api.rst index 0bfbf8d68..cbef2811d 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -7,6 +7,10 @@ API Reference +-------------------------------------------------+--------------------------------------------------------------+ | **Subpackage** | **Description** | +-------------------------------------------------+--------------------------------------------------------------+ + | :mod:`~probnum.backend` | Generic computation backend. | + +-------------------------------------------------+--------------------------------------------------------------+ + | :mod:`~probnum.compat` | Compatibility functions. | + +-------------------------------------------------+--------------------------------------------------------------+ | :class:`config ` | Global configuration options. | +-------------------------------------------------+--------------------------------------------------------------+ | :mod:`~probnum.diffeq` | Probabilistic solvers for ordinary differential equations. | @@ -29,8 +33,6 @@ API Reference +-------------------------------------------------+--------------------------------------------------------------+ | :mod:`~probnum.typing` | Type aliases. | +-------------------------------------------------+--------------------------------------------------------------+ - | :mod:`~probnum.utils` | Utility functions. | - +-------------------------------------------------+--------------------------------------------------------------+ .. toctree:: @@ -38,6 +40,8 @@ API Reference :hidden: api/probnum + api/backend + api/compat api/config api/diffeq api/filtsmooth @@ -49,4 +53,3 @@ API Reference api/randprocs api/randvars api/typing - api/utils diff --git a/docs/source/api/backend.rst b/docs/source/api/backend.rst new file mode 100644 index 000000000..06b361ca3 --- /dev/null +++ b/docs/source/api/backend.rst @@ -0,0 +1,100 @@ +*************** +probnum.backend +*************** + +.. automodule:: probnum.backend + +.. currentmodule:: probnum.backend + +Classes +------- + +.. autosummary:: + + ~probnum.backend.Dispatcher + + +.. toctree:: + :hidden: + + backend/array_object + +.. toctree:: + :hidden: + + backend/data_types + +.. 
diff --git a/docs/source/api.rst b/docs/source/api.rst
index 0bfbf8d68..cbef2811d 100644
--- a/docs/source/api.rst
+++ b/docs/source/api.rst
@@ -7,6 +7,10 @@ API Reference
    +-------------------------------------------------+--------------------------------------------------------------+
    | **Subpackage**                                  | **Description**                                              |
    +-------------------------------------------------+--------------------------------------------------------------+
+   | :mod:`~probnum.backend`                         | Generic computation backend.                                 |
+   +-------------------------------------------------+--------------------------------------------------------------+
+   | :mod:`~probnum.compat`                          | Compatibility functions.                                     |
+   +-------------------------------------------------+--------------------------------------------------------------+
    | :class:`config `                                | Global configuration options.                                |
    +-------------------------------------------------+--------------------------------------------------------------+
    | :mod:`~probnum.diffeq`                          | Probabilistic solvers for ordinary differential equations.   |
@@ -29,8 +33,6 @@ API Reference
    +-------------------------------------------------+--------------------------------------------------------------+
    | :mod:`~probnum.typing`                          | Type aliases.                                                |
    +-------------------------------------------------+--------------------------------------------------------------+
-   | :mod:`~probnum.utils`                           | Utility functions.                                           |
-   +-------------------------------------------------+--------------------------------------------------------------+
 
 
 .. toctree::
@@ -38,6 +40,8 @@ API Reference
    :hidden:
 
    api/probnum
+   api/backend
+   api/compat
    api/config
    api/diffeq
@@ -49,4 +53,3 @@ API Reference
    api/randprocs
    api/randvars
    api/typing
-   api/utils
diff --git a/docs/source/api/backend.rst b/docs/source/api/backend.rst
new file mode 100644
index 000000000..06b361ca3
--- /dev/null
+++ b/docs/source/api/backend.rst
@@ -0,0 +1,100 @@
+***************
+probnum.backend
+***************
+
+.. automodule:: probnum.backend
+
+.. currentmodule:: probnum.backend
+
+Classes
+-------
+
+.. autosummary::
+
+    ~probnum.backend.Dispatcher
+
+
+.. toctree::
+    :hidden:
+
+    backend/array_object
+
+.. toctree::
+    :hidden:
+
+    backend/data_types
+
+.. toctree::
+    :hidden:
+
+    backend/creation_functions
+
+.. toctree::
+    :hidden:
+
+    backend/elementwise_functions
+
+.. toctree::
+    :hidden:
+
+    backend/logic_functions
+
+.. toctree::
+    :hidden:
+
+    backend/manipulation_functions
+
+.. toctree::
+    :hidden:
+
+    backend/searching_functions
+
+.. toctree::
+    :hidden:
+
+    backend/sorting_functions
+
+.. toctree::
+    :hidden:
+
+    backend/statistical_functions
+
+.. toctree::
+    :hidden:
+
+    backend/jit_compilation
+
+.. toctree::
+    :hidden:
+
+    backend/vectorization
+
+.. toctree::
+    :hidden:
+
+    backend/probnum.backend.Dispatcher
+
+.. toctree::
+    :hidden:
+
+    backend/autodiff
+
+.. toctree::
+    :hidden:
+
+    backend/linalg
+
+.. toctree::
+    :hidden:
+
+    backend/random
+
+.. toctree::
+    :hidden:
+
+    backend/special
+
+.. toctree::
+    :hidden:
+
+    backend/typing
diff --git a/docs/source/api/backend/array_object.rst b/docs/source/api/backend/array_object.rst
new file mode 100644
index 000000000..80e31aab2
--- /dev/null
+++ b/docs/source/api/backend/array_object.rst
@@ -0,0 +1,40 @@
+Array Object
+============
+
+The basic object representing a multi-dimensional array, along with closely related functionality.
+
+.. currentmodule:: probnum.backend
+
+Functions
+---------
+
+.. autosummary::
+
+    ~probnum.backend.asshape
+    ~probnum.backend.isarray
+    ~probnum.backend.ndim
+    ~probnum.backend.to_numpy
+
+Classes
+-------
+
++----------------------------------+---------------------------------------------------------------------------------------------------------------------------------+
+| :class:`~probnum.backend.Array`  | Object representing a multi-dimensional array stored on a :class:`~probnum.backend.Device` and containing elements of the same |
+|                                  | :class:`~probnum.backend.DType`.                                                                                                |
++----------------------------------+---------------------------------------------------------------------------------------------------------------------------------+
+| :class:`~probnum.backend.Scalar` | Object representing a scalar with a :class:`~probnum.backend.DType`.                                                            |
++----------------------------------+---------------------------------------------------------------------------------------------------------------------------------+
+| :class:`~probnum.backend.Device` | Device, such as a CPU or GPU, on which an :class:`~probnum.backend.Array` is located.                                           |
++----------------------------------+---------------------------------------------------------------------------------------------------------------------------------+
+
+
+.. toctree::
+    :hidden:
+
+    array_object/probnum.backend.asshape
+    array_object/probnum.backend.isarray
+    array_object/probnum.backend.ndim
+    array_object/probnum.backend.to_numpy
+    array_object/probnum.backend.Array
+    array_object/probnum.backend.Device
+    array_object/probnum.backend.Scalar
diff --git a/docs/source/api/backend/array_object/probnum.backend.Array.rst b/docs/source/api/backend/array_object/probnum.backend.Array.rst
new file mode 100644
index 000000000..2c05391c9
--- /dev/null
+++ b/docs/source/api/backend/array_object/probnum.backend.Array.rst
@@ -0,0 +1,11 @@
+Array
+=====
+
+.. currentmodule:: probnum.backend
+
+.. autoclass:: Array
+
+Object representing a multi-dimensional array stored on a :class:`~probnum.backend.Device` and containing elements of the same :class:`~probnum.backend.DType`.
+
+Depending on the chosen backend, :class:`~probnum.backend.Array` is an alias of
+:class:`numpy.ndarray`, :class:`jax.numpy.ndarray`, or :class:`torch.Tensor`.
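To illustrate how user code stays backend-agnostic against this alias, a small sketch using the array-object helpers documented above (`asarray` is one of the creation functions listed later in this diff; the concrete array type depends on the selected backend):

```python
from probnum import backend

x = backend.asarray([[1.0, 2.0], [3.0, 4.0]])

assert backend.isarray(x)    # x is an Array of the active backend
assert backend.ndim(x) == 2  # backend-generic replacement for np.ndim

# Hand the data back to NumPy, e.g. for plotting or SciPy interop.
x_np = backend.to_numpy(x)
```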
diff --git a/docs/source/api/backend/array_object/probnum.backend.Device.rst b/docs/source/api/backend/array_object/probnum.backend.Device.rst
new file mode 100644
index 000000000..85f27b3bc
--- /dev/null
+++ b/docs/source/api/backend/array_object/probnum.backend.Device.rst
@@ -0,0 +1,12 @@
+Device
+======
+
+.. currentmodule:: probnum.backend
+
+.. autoclass:: Device
+
+Device, such as a CPU or GPU, on which an :class:`~probnum.backend.Array` is located.
+
+.. note::
+
+    Currently, the NumPy backend only supports the CPU.
diff --git a/docs/source/api/backend/array_object/probnum.backend.Scalar.rst b/docs/source/api/backend/array_object/probnum.backend.Scalar.rst
new file mode 100644
index 000000000..8bd6d1045
--- /dev/null
+++ b/docs/source/api/backend/array_object/probnum.backend.Scalar.rst
@@ -0,0 +1,11 @@
+Scalar
+======
+
+.. currentmodule:: probnum.backend
+
+.. autoclass:: Scalar
+
+Object representing a scalar with a :class:`~probnum.backend.DType`.
+
+Depending on the chosen backend, :class:`~probnum.backend.Scalar` is an alias of
+:class:`numpy.generic`, :class:`jax.numpy.ndarray`, or :class:`torch.Tensor`.
diff --git a/docs/source/api/backend/array_object/probnum.backend.asshape.rst b/docs/source/api/backend/array_object/probnum.backend.asshape.rst
new file mode 100644
index 000000000..4472417cf
--- /dev/null
+++ b/docs/source/api/backend/array_object/probnum.backend.asshape.rst
@@ -0,0 +1,6 @@
+asshape
+=======
+
+.. currentmodule:: probnum.backend
+
+.. autofunction:: asshape
diff --git a/docs/source/api/backend/array_object/probnum.backend.isarray.rst b/docs/source/api/backend/array_object/probnum.backend.isarray.rst
new file mode 100644
index 000000000..749542d1a
--- /dev/null
+++ b/docs/source/api/backend/array_object/probnum.backend.isarray.rst
@@ -0,0 +1,6 @@
+isarray
+=======
+
+.. currentmodule:: probnum.backend
+
+.. autofunction:: isarray
diff --git a/docs/source/api/backend/array_object/probnum.backend.ndim.rst b/docs/source/api/backend/array_object/probnum.backend.ndim.rst
new file mode 100644
index 000000000..665aeb793
--- /dev/null
+++ b/docs/source/api/backend/array_object/probnum.backend.ndim.rst
@@ -0,0 +1,6 @@
+ndim
+====
+
+.. currentmodule:: probnum.backend
+
+.. autofunction:: ndim
diff --git a/docs/source/api/backend/array_object/probnum.backend.to_numpy.rst b/docs/source/api/backend/array_object/probnum.backend.to_numpy.rst
new file mode 100644
index 000000000..2455c44af
--- /dev/null
+++ b/docs/source/api/backend/array_object/probnum.backend.to_numpy.rst
@@ -0,0 +1,6 @@
+to_numpy
+========
+
+.. currentmodule:: probnum.backend
+
+.. autofunction:: to_numpy
diff --git a/docs/source/api/backend/autodiff.rst b/docs/source/api/backend/autodiff.rst
new file mode 100644
index 000000000..63b0346ea
--- /dev/null
+++ b/docs/source/api/backend/autodiff.rst
@@ -0,0 +1,5 @@
+probnum.backend.autodiff
+------------------------
+.. automodapi:: probnum.backend.autodiff
+    :no-heading:
+    :headings: "*"
diff --git a/docs/source/api/backend/creation_functions.rst b/docs/source/api/backend/creation_functions.rst
new file mode 100644
index 000000000..cb8407f1b
--- /dev/null
+++ b/docs/source/api/backend/creation_functions.rst
@@ -0,0 +1,51 @@
+Array Creation Functions
+========================
+
+Functions for creating arrays.
+
+.. currentmodule:: probnum.backend
+
+Functions
+---------
+
+.. 
autosummary:: + + ~probnum.backend.arange + ~probnum.backend.asarray + ~probnum.backend.asscalar + ~probnum.backend.diag + ~probnum.backend.empty + ~probnum.backend.empty_like + ~probnum.backend.eye + ~probnum.backend.full + ~probnum.backend.full_like + ~probnum.backend.linspace + ~probnum.backend.meshgrid + ~probnum.backend.ones + ~probnum.backend.ones_like + ~probnum.backend.tril + ~probnum.backend.triu + ~probnum.backend.zeros + ~probnum.backend.zeros_like + + +.. toctree:: + :hidden: + + creation_functions/probnum.backend.arange + creation_functions/probnum.backend.asarray + creation_functions/probnum.backend.asscalar + creation_functions/probnum.backend.diag + creation_functions/probnum.backend.empty + creation_functions/probnum.backend.empty_like + creation_functions/probnum.backend.eye + creation_functions/probnum.backend.full + creation_functions/probnum.backend.full_like + creation_functions/probnum.backend.linspace + creation_functions/probnum.backend.meshgrid + creation_functions/probnum.backend.ones + creation_functions/probnum.backend.ones_like + creation_functions/probnum.backend.tril + creation_functions/probnum.backend.triu + creation_functions/probnum.backend.zeros + creation_functions/probnum.backend.zeros_like diff --git a/docs/source/api/backend/creation_functions/probnum.backend.arange.rst b/docs/source/api/backend/creation_functions/probnum.backend.arange.rst new file mode 100644 index 000000000..a9ee929b8 --- /dev/null +++ b/docs/source/api/backend/creation_functions/probnum.backend.arange.rst @@ -0,0 +1,6 @@ +arange +====== + +.. currentmodule:: probnum.backend + +.. autofunction:: arange diff --git a/docs/source/api/backend/creation_functions/probnum.backend.asarray.rst b/docs/source/api/backend/creation_functions/probnum.backend.asarray.rst new file mode 100644 index 000000000..01ac3ce3f --- /dev/null +++ b/docs/source/api/backend/creation_functions/probnum.backend.asarray.rst @@ -0,0 +1,6 @@ +asarray +======= + +.. currentmodule:: probnum.backend + +.. autofunction:: asarray diff --git a/docs/source/api/backend/creation_functions/probnum.backend.asscalar.rst b/docs/source/api/backend/creation_functions/probnum.backend.asscalar.rst new file mode 100644 index 000000000..48ad95b5c --- /dev/null +++ b/docs/source/api/backend/creation_functions/probnum.backend.asscalar.rst @@ -0,0 +1,6 @@ +asscalar +======== + +.. currentmodule:: probnum.backend + +.. autofunction:: asscalar diff --git a/docs/source/api/backend/creation_functions/probnum.backend.diag.rst b/docs/source/api/backend/creation_functions/probnum.backend.diag.rst new file mode 100644 index 000000000..f3e2cc50d --- /dev/null +++ b/docs/source/api/backend/creation_functions/probnum.backend.diag.rst @@ -0,0 +1,6 @@ +diag +==== + +.. currentmodule:: probnum.backend + +.. autofunction:: diag diff --git a/docs/source/api/backend/creation_functions/probnum.backend.empty.rst b/docs/source/api/backend/creation_functions/probnum.backend.empty.rst new file mode 100644 index 000000000..51f924d91 --- /dev/null +++ b/docs/source/api/backend/creation_functions/probnum.backend.empty.rst @@ -0,0 +1,6 @@ +empty +===== + +.. currentmodule:: probnum.backend + +.. autofunction:: empty diff --git a/docs/source/api/backend/creation_functions/probnum.backend.empty_like.rst b/docs/source/api/backend/creation_functions/probnum.backend.empty_like.rst new file mode 100644 index 000000000..6480d0e5a --- /dev/null +++ b/docs/source/api/backend/creation_functions/probnum.backend.empty_like.rst @@ -0,0 +1,6 @@ +empty_like +========== + +.. 
currentmodule:: probnum.backend + +.. autofunction:: empty_like diff --git a/docs/source/api/backend/creation_functions/probnum.backend.eye.rst b/docs/source/api/backend/creation_functions/probnum.backend.eye.rst new file mode 100644 index 000000000..986532ad1 --- /dev/null +++ b/docs/source/api/backend/creation_functions/probnum.backend.eye.rst @@ -0,0 +1,6 @@ +eye +=== + +.. currentmodule:: probnum.backend + +.. autofunction:: eye diff --git a/docs/source/api/backend/creation_functions/probnum.backend.full.rst b/docs/source/api/backend/creation_functions/probnum.backend.full.rst new file mode 100644 index 000000000..982d7cec9 --- /dev/null +++ b/docs/source/api/backend/creation_functions/probnum.backend.full.rst @@ -0,0 +1,6 @@ +full +==== + +.. currentmodule:: probnum.backend + +.. autofunction:: full diff --git a/docs/source/api/backend/creation_functions/probnum.backend.full_like.rst b/docs/source/api/backend/creation_functions/probnum.backend.full_like.rst new file mode 100644 index 000000000..386bee2c6 --- /dev/null +++ b/docs/source/api/backend/creation_functions/probnum.backend.full_like.rst @@ -0,0 +1,6 @@ +full_like +========= + +.. currentmodule:: probnum.backend + +.. autofunction:: full_like diff --git a/docs/source/api/backend/creation_functions/probnum.backend.linspace.rst b/docs/source/api/backend/creation_functions/probnum.backend.linspace.rst new file mode 100644 index 000000000..f7080f72f --- /dev/null +++ b/docs/source/api/backend/creation_functions/probnum.backend.linspace.rst @@ -0,0 +1,6 @@ +linspace +======== + +.. currentmodule:: probnum.backend + +.. autofunction:: linspace diff --git a/docs/source/api/backend/creation_functions/probnum.backend.meshgrid.rst b/docs/source/api/backend/creation_functions/probnum.backend.meshgrid.rst new file mode 100644 index 000000000..087766f3e --- /dev/null +++ b/docs/source/api/backend/creation_functions/probnum.backend.meshgrid.rst @@ -0,0 +1,6 @@ +meshgrid +======== + +.. currentmodule:: probnum.backend + +.. autofunction:: meshgrid diff --git a/docs/source/api/backend/creation_functions/probnum.backend.ones.rst b/docs/source/api/backend/creation_functions/probnum.backend.ones.rst new file mode 100644 index 000000000..1cef92351 --- /dev/null +++ b/docs/source/api/backend/creation_functions/probnum.backend.ones.rst @@ -0,0 +1,6 @@ +ones +==== + +.. currentmodule:: probnum.backend + +.. autofunction:: ones diff --git a/docs/source/api/backend/creation_functions/probnum.backend.ones_like.rst b/docs/source/api/backend/creation_functions/probnum.backend.ones_like.rst new file mode 100644 index 000000000..703cf0a5d --- /dev/null +++ b/docs/source/api/backend/creation_functions/probnum.backend.ones_like.rst @@ -0,0 +1,6 @@ +ones_like +========== + +.. currentmodule:: probnum.backend + +.. autofunction:: ones_like diff --git a/docs/source/api/backend/creation_functions/probnum.backend.tril.rst b/docs/source/api/backend/creation_functions/probnum.backend.tril.rst new file mode 100644 index 000000000..b11aa2265 --- /dev/null +++ b/docs/source/api/backend/creation_functions/probnum.backend.tril.rst @@ -0,0 +1,6 @@ +tril +==== + +.. currentmodule:: probnum.backend + +.. autofunction:: tril diff --git a/docs/source/api/backend/creation_functions/probnum.backend.triu.rst b/docs/source/api/backend/creation_functions/probnum.backend.triu.rst new file mode 100644 index 000000000..2f1aab4c4 --- /dev/null +++ b/docs/source/api/backend/creation_functions/probnum.backend.triu.rst @@ -0,0 +1,6 @@ +triu +==== + +.. 
currentmodule:: probnum.backend + +.. autofunction:: triu diff --git a/docs/source/api/backend/creation_functions/probnum.backend.zeros.rst b/docs/source/api/backend/creation_functions/probnum.backend.zeros.rst new file mode 100644 index 000000000..4c722eda5 --- /dev/null +++ b/docs/source/api/backend/creation_functions/probnum.backend.zeros.rst @@ -0,0 +1,6 @@ +zeros +===== + +.. currentmodule:: probnum.backend + +.. autofunction:: zeros diff --git a/docs/source/api/backend/creation_functions/probnum.backend.zeros_like.rst b/docs/source/api/backend/creation_functions/probnum.backend.zeros_like.rst new file mode 100644 index 000000000..16a4e3b00 --- /dev/null +++ b/docs/source/api/backend/creation_functions/probnum.backend.zeros_like.rst @@ -0,0 +1,6 @@ +zeros_like +========== + +.. currentmodule:: probnum.backend + +.. autofunction:: zeros_like diff --git a/docs/source/api/backend/data_types.rst b/docs/source/api/backend/data_types.rst new file mode 100644 index 000000000..31f9185eb --- /dev/null +++ b/docs/source/api/backend/data_types.rst @@ -0,0 +1,68 @@ +Data Types +========== + +Fundamental (array) data types. + +.. currentmodule:: probnum.backend + +Functions +--------- + +.. autosummary:: + + ~probnum.backend.asdtype + ~probnum.backend.can_cast + ~probnum.backend.cast + ~probnum.backend.finfo + ~probnum.backend.iinfo + ~probnum.backend.is_floating_dtype + ~probnum.backend.promote_types + ~probnum.backend.result_type + + +Classes +------- + ++--------------------------------------+-------------------------------------------------------------------------------------------------------------------------+ +| :class:`~probnum.backend.DType` | Data type of an :class:`~probnum.backend.Array`. | ++--------------------------------------+-------------------------------------------------------------------------------------------------------------------------+ +| :class:`~probnum.backend.bool` | Boolean (``True`` or ``False``). | ++--------------------------------------+-------------------------------------------------------------------------------------------------------------------------+ +| :class:`~probnum.backend.int32` | A 32-bit signed integer. | ++--------------------------------------+-------------------------------------------------------------------------------------------------------------------------+ +| :class:`~probnum.backend.int64` | A 64-bit signed integer. | ++--------------------------------------+-------------------------------------------------------------------------------------------------------------------------+ +| :class:`~probnum.backend.float16` | IEEE 754 half-precision (16-bit) binary floating-point number. | ++--------------------------------------+-------------------------------------------------------------------------------------------------------------------------+ +| :class:`~probnum.backend.float32` | IEEE 754 single-precision (32-bit) binary floating-point number. | ++--------------------------------------+-------------------------------------------------------------------------------------------------------------------------+ +| :class:`~probnum.backend.float64` | IEEE 754 double-precision (64-bit) binary floating-point number. | ++--------------------------------------+-------------------------------------------------------------------------------------------------------------------------+ +| :class:`~probnum.backend.complex64` | Single-precision complex number represented by two :class:`~probnum.backend.float32`\s (real and imaginary components). 
| ++--------------------------------------+-------------------------------------------------------------------------------------------------------------------------+ +| :class:`~probnum.backend.complex128` | Double-precision complex number represented by two :class:`~probnum.backend.float64`\s (real and imaginary components). | ++--------------------------------------+-------------------------------------------------------------------------------------------------------------------------+ + + +.. toctree:: + :hidden: + + data_types/probnum.backend.DType + data_types/probnum.backend.bool + data_types/probnum.backend.int32 + data_types/probnum.backend.int64 + data_types/probnum.backend.float16 + data_types/probnum.backend.float32 + data_types/probnum.backend.float64 + data_types/probnum.backend.complex64 + data_types/probnum.backend.complex128 + data_types/probnum.backend.MachineLimitsFloatingPoint + data_types/probnum.backend.MachineLimitsInteger + data_types/probnum.backend.asdtype + data_types/probnum.backend.can_cast + data_types/probnum.backend.cast + data_types/probnum.backend.finfo + data_types/probnum.backend.iinfo + data_types/probnum.backend.is_floating_dtype + data_types/probnum.backend.promote_types + data_types/probnum.backend.result_type diff --git a/docs/source/api/backend/data_types/probnum.backend.DType.rst b/docs/source/api/backend/data_types/probnum.backend.DType.rst new file mode 100644 index 000000000..27058e4e6 --- /dev/null +++ b/docs/source/api/backend/data_types/probnum.backend.DType.rst @@ -0,0 +1,8 @@ +DType +===== + +.. currentmodule:: probnum.backend + +.. autoclass:: DType + +Data type of an :class:`~probnum.backend.Array`. diff --git a/docs/source/api/backend/data_types/probnum.backend.MachineLimitsFloatingPoint.rst b/docs/source/api/backend/data_types/probnum.backend.MachineLimitsFloatingPoint.rst new file mode 100644 index 000000000..5c10daf27 --- /dev/null +++ b/docs/source/api/backend/data_types/probnum.backend.MachineLimitsFloatingPoint.rst @@ -0,0 +1,6 @@ +MachineLimitsFloatingPoint +========================== + +.. currentmodule:: probnum.backend + +.. autoclass:: MachineLimitsFloatingPoint diff --git a/docs/source/api/backend/data_types/probnum.backend.MachineLimitsInteger.rst b/docs/source/api/backend/data_types/probnum.backend.MachineLimitsInteger.rst new file mode 100644 index 000000000..4d121e211 --- /dev/null +++ b/docs/source/api/backend/data_types/probnum.backend.MachineLimitsInteger.rst @@ -0,0 +1,6 @@ +MachineLimitsInteger +==================== + +.. currentmodule:: probnum.backend + +.. autoclass:: MachineLimitsInteger diff --git a/docs/source/api/backend/data_types/probnum.backend.asdtype.rst b/docs/source/api/backend/data_types/probnum.backend.asdtype.rst new file mode 100644 index 000000000..436d837da --- /dev/null +++ b/docs/source/api/backend/data_types/probnum.backend.asdtype.rst @@ -0,0 +1,6 @@ +asdtype +======= + +.. currentmodule:: probnum.backend + +.. autofunction:: asdtype diff --git a/docs/source/api/backend/data_types/probnum.backend.bool.rst b/docs/source/api/backend/data_types/probnum.backend.bool.rst new file mode 100644 index 000000000..0e2cd697c --- /dev/null +++ b/docs/source/api/backend/data_types/probnum.backend.bool.rst @@ -0,0 +1,8 @@ +bool +==== + +.. currentmodule:: probnum.backend + +.. autoclass:: bool + +Boolean (``True`` or ``False``). 
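A brief sketch of how the dtype utilities above fit together (the function names are those documented on this page; the exact `cast` signature shown here is an assumption based on its listing next to `can_cast`):

```python
from probnum import backend

# Dtype promotion follows the usual widening rules.
dtype = backend.promote_types(backend.float32, backend.float64)  # float64

# Guard a cast by checking that it is permitted first.
x = backend.asarray([1, 2, 3])
if backend.can_cast(x.dtype, backend.float64):
    x = backend.cast(x, backend.float64)  # signature assumed

assert backend.is_floating_dtype(x.dtype)
```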
diff --git a/docs/source/api/backend/data_types/probnum.backend.can_cast.rst b/docs/source/api/backend/data_types/probnum.backend.can_cast.rst new file mode 100644 index 000000000..56f3127aa --- /dev/null +++ b/docs/source/api/backend/data_types/probnum.backend.can_cast.rst @@ -0,0 +1,6 @@ +can_cast +======== + +.. currentmodule:: probnum.backend + +.. autofunction:: can_cast diff --git a/docs/source/api/backend/data_types/probnum.backend.cast.rst b/docs/source/api/backend/data_types/probnum.backend.cast.rst new file mode 100644 index 000000000..ee331169a --- /dev/null +++ b/docs/source/api/backend/data_types/probnum.backend.cast.rst @@ -0,0 +1,6 @@ +cast +==== + +.. currentmodule:: probnum.backend + +.. autofunction:: cast diff --git a/docs/source/api/backend/data_types/probnum.backend.complex128.rst b/docs/source/api/backend/data_types/probnum.backend.complex128.rst new file mode 100644 index 000000000..44b4a4443 --- /dev/null +++ b/docs/source/api/backend/data_types/probnum.backend.complex128.rst @@ -0,0 +1,9 @@ +complex128 +========== + +.. currentmodule:: probnum.backend + +.. autoclass:: complex128 + +Double-precision complex number represented by two double-precision floats (real and +imaginary components). diff --git a/docs/source/api/backend/data_types/probnum.backend.complex64.rst b/docs/source/api/backend/data_types/probnum.backend.complex64.rst new file mode 100644 index 000000000..c02f1c731 --- /dev/null +++ b/docs/source/api/backend/data_types/probnum.backend.complex64.rst @@ -0,0 +1,9 @@ +complex64 +========= + +.. currentmodule:: probnum.backend + +.. autoclass:: complex64 + +Single-precision complex number represented by two single-precision floats (real and +imaginary components). diff --git a/docs/source/api/backend/data_types/probnum.backend.finfo.rst b/docs/source/api/backend/data_types/probnum.backend.finfo.rst new file mode 100644 index 000000000..d156b6c2b --- /dev/null +++ b/docs/source/api/backend/data_types/probnum.backend.finfo.rst @@ -0,0 +1,6 @@ +finfo +===== + +.. currentmodule:: probnum.backend + +.. autofunction:: finfo diff --git a/docs/source/api/backend/data_types/probnum.backend.float16.rst b/docs/source/api/backend/data_types/probnum.backend.float16.rst new file mode 100644 index 000000000..242947519 --- /dev/null +++ b/docs/source/api/backend/data_types/probnum.backend.float16.rst @@ -0,0 +1,8 @@ +float16 +======= + +.. currentmodule:: probnum.backend + +.. autoclass:: float16 + +IEEE 754 half-precision (16-bit) binary floating-point number (see IEEE 754-2019). diff --git a/docs/source/api/backend/data_types/probnum.backend.float32.rst b/docs/source/api/backend/data_types/probnum.backend.float32.rst new file mode 100644 index 000000000..3d428a409 --- /dev/null +++ b/docs/source/api/backend/data_types/probnum.backend.float32.rst @@ -0,0 +1,8 @@ +float32 +======= + +.. currentmodule:: probnum.backend + +.. autoclass:: float32 + +IEEE 754 single-precision (32-bit) binary floating-point number (see IEEE 754-2019). diff --git a/docs/source/api/backend/data_types/probnum.backend.float64.rst b/docs/source/api/backend/data_types/probnum.backend.float64.rst new file mode 100644 index 000000000..4037fa0ec --- /dev/null +++ b/docs/source/api/backend/data_types/probnum.backend.float64.rst @@ -0,0 +1,8 @@ +float64 +======= + +.. currentmodule:: probnum.backend + +.. autoclass:: float64 + +IEEE 754 double-precision (64-bit) binary floating-point number (see IEEE 754-2019). 
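For orientation, the machine-limit helpers pair with these dtypes as follows (a sketch; the attribute names mirror :func:`numpy.finfo` / :func:`numpy.iinfo` and are assumptions here):

```python
from probnum import backend

# Floating-point limits, returned as a MachineLimitsFloatingPoint object.
f64 = backend.finfo(backend.float64)
print(f64.eps)  # ~2.22e-16, double-precision machine epsilon (assumed attribute)

# Integer limits, returned as a MachineLimitsInteger object.
i32 = backend.iinfo(backend.int32)
print(i32.min, i32.max)  # -2147483648 2147483647 (assumed attributes)
```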
diff --git a/docs/source/api/backend/data_types/probnum.backend.iinfo.rst b/docs/source/api/backend/data_types/probnum.backend.iinfo.rst new file mode 100644 index 000000000..56afb3d23 --- /dev/null +++ b/docs/source/api/backend/data_types/probnum.backend.iinfo.rst @@ -0,0 +1,6 @@ +iinfo +===== + +.. currentmodule:: probnum.backend + +.. autofunction:: iinfo diff --git a/docs/source/api/backend/data_types/probnum.backend.int32.rst b/docs/source/api/backend/data_types/probnum.backend.int32.rst new file mode 100644 index 000000000..1407256d8 --- /dev/null +++ b/docs/source/api/backend/data_types/probnum.backend.int32.rst @@ -0,0 +1,8 @@ +int32 +===== + +.. currentmodule:: probnum.backend + +.. autoclass:: int32 + +A 32-bit signed integer whose values exist on the interval ``[-2e9, +2e9]``. diff --git a/docs/source/api/backend/data_types/probnum.backend.int64.rst b/docs/source/api/backend/data_types/probnum.backend.int64.rst new file mode 100644 index 000000000..3df48aa76 --- /dev/null +++ b/docs/source/api/backend/data_types/probnum.backend.int64.rst @@ -0,0 +1,8 @@ +int64 +===== + +.. currentmodule:: probnum.backend + +.. autoclass:: int64 + +A 64-bit signed integer whose values exist on the interval ``[-9e18, +9e18]``. diff --git a/docs/source/api/backend/data_types/probnum.backend.is_floating_dtype.rst b/docs/source/api/backend/data_types/probnum.backend.is_floating_dtype.rst new file mode 100644 index 000000000..0ad407b9d --- /dev/null +++ b/docs/source/api/backend/data_types/probnum.backend.is_floating_dtype.rst @@ -0,0 +1,6 @@ +is_floating_dtype +================= + +.. currentmodule:: probnum.backend + +.. autofunction:: is_floating_dtype diff --git a/docs/source/api/backend/data_types/probnum.backend.promote_types.rst b/docs/source/api/backend/data_types/probnum.backend.promote_types.rst new file mode 100644 index 000000000..9f57202b4 --- /dev/null +++ b/docs/source/api/backend/data_types/probnum.backend.promote_types.rst @@ -0,0 +1,6 @@ +promote_types +============= + +.. currentmodule:: probnum.backend + +.. autofunction:: promote_types diff --git a/docs/source/api/backend/data_types/probnum.backend.result_type.rst b/docs/source/api/backend/data_types/probnum.backend.result_type.rst new file mode 100644 index 000000000..e12906699 --- /dev/null +++ b/docs/source/api/backend/data_types/probnum.backend.result_type.rst @@ -0,0 +1,6 @@ +result_type +=========== + +.. currentmodule:: probnum.backend + +.. autofunction:: result_type diff --git a/docs/source/api/backend/elementwise_functions.rst b/docs/source/api/backend/elementwise_functions.rst new file mode 100644 index 000000000..ee76c6d28 --- /dev/null +++ b/docs/source/api/backend/elementwise_functions.rst @@ -0,0 +1,119 @@ +Element-wise Functions +====================== + +Functions applied element-wise to arrays. + +.. currentmodule:: probnum.backend + +Functions +--------- + +.. 
autosummary:: + + ~probnum.backend.abs + ~probnum.backend.acos + ~probnum.backend.acosh + ~probnum.backend.add + ~probnum.backend.asin + ~probnum.backend.asinh + ~probnum.backend.atan + ~probnum.backend.atan2 + ~probnum.backend.atanh + ~probnum.backend.bitwise_and + ~probnum.backend.bitwise_left_shift + ~probnum.backend.bitwise_invert + ~probnum.backend.bitwise_or + ~probnum.backend.bitwise_right_shift + ~probnum.backend.bitwise_xor + ~probnum.backend.ceil + ~probnum.backend.conj + ~probnum.backend.cos + ~probnum.backend.cosh + ~probnum.backend.divide + ~probnum.backend.exp + ~probnum.backend.expm1 + ~probnum.backend.floor + ~probnum.backend.floor_divide + ~probnum.backend.imag + ~probnum.backend.isfinite + ~probnum.backend.isinf + ~probnum.backend.isnan + ~probnum.backend.log + ~probnum.backend.log1p + ~probnum.backend.log2 + ~probnum.backend.log10 + ~probnum.backend.logaddexp + ~probnum.backend.multiply + ~probnum.backend.maximum + ~probnum.backend.minimum + ~probnum.backend.negative + ~probnum.backend.positive + ~probnum.backend.pow + ~probnum.backend.real + ~probnum.backend.remainder + ~probnum.backend.round + ~probnum.backend.sign + ~probnum.backend.sin + ~probnum.backend.sinh + ~probnum.backend.square + ~probnum.backend.sqrt + ~probnum.backend.subtract + ~probnum.backend.tan + ~probnum.backend.tanh + ~probnum.backend.trunc + + +.. toctree:: + :hidden: + + elementwise_functions/probnum.backend.abs + elementwise_functions/probnum.backend.acos + elementwise_functions/probnum.backend.acosh + elementwise_functions/probnum.backend.add + elementwise_functions/probnum.backend.asin + elementwise_functions/probnum.backend.asinh + elementwise_functions/probnum.backend.atan + elementwise_functions/probnum.backend.atan2 + elementwise_functions/probnum.backend.atanh + elementwise_functions/probnum.backend.bitwise_and + elementwise_functions/probnum.backend.bitwise_left_shift + elementwise_functions/probnum.backend.bitwise_invert + elementwise_functions/probnum.backend.bitwise_or + elementwise_functions/probnum.backend.bitwise_right_shift + elementwise_functions/probnum.backend.bitwise_xor + elementwise_functions/probnum.backend.ceil + elementwise_functions/probnum.backend.conj + elementwise_functions/probnum.backend.cos + elementwise_functions/probnum.backend.cosh + elementwise_functions/probnum.backend.divide + elementwise_functions/probnum.backend.exp + elementwise_functions/probnum.backend.expm1 + elementwise_functions/probnum.backend.floor + elementwise_functions/probnum.backend.floor_divide + elementwise_functions/probnum.backend.imag + elementwise_functions/probnum.backend.isfinite + elementwise_functions/probnum.backend.isinf + elementwise_functions/probnum.backend.isnan + elementwise_functions/probnum.backend.log + elementwise_functions/probnum.backend.log1p + elementwise_functions/probnum.backend.log2 + elementwise_functions/probnum.backend.log10 + elementwise_functions/probnum.backend.logaddexp + elementwise_functions/probnum.backend.multiply + elementwise_functions/probnum.backend.maximum + elementwise_functions/probnum.backend.minimum + elementwise_functions/probnum.backend.negative + elementwise_functions/probnum.backend.positive + elementwise_functions/probnum.backend.pow + elementwise_functions/probnum.backend.real + elementwise_functions/probnum.backend.remainder + elementwise_functions/probnum.backend.round + elementwise_functions/probnum.backend.sign + elementwise_functions/probnum.backend.sin + elementwise_functions/probnum.backend.sinh + 
elementwise_functions/probnum.backend.square + elementwise_functions/probnum.backend.sqrt + elementwise_functions/probnum.backend.subtract + elementwise_functions/probnum.backend.tan + elementwise_functions/probnum.backend.tanh + elementwise_functions/probnum.backend.trunc diff --git a/docs/source/api/backend/elementwise_functions/probnum.backend.abs.rst b/docs/source/api/backend/elementwise_functions/probnum.backend.abs.rst new file mode 100644 index 000000000..3f5cb354e --- /dev/null +++ b/docs/source/api/backend/elementwise_functions/probnum.backend.abs.rst @@ -0,0 +1,6 @@ +abs +=== + +.. currentmodule:: probnum.backend + +.. autofunction:: abs diff --git a/docs/source/api/backend/elementwise_functions/probnum.backend.acos.rst b/docs/source/api/backend/elementwise_functions/probnum.backend.acos.rst new file mode 100644 index 000000000..716d99b82 --- /dev/null +++ b/docs/source/api/backend/elementwise_functions/probnum.backend.acos.rst @@ -0,0 +1,6 @@ +acos +==== + +.. currentmodule:: probnum.backend + +.. autofunction:: acos diff --git a/docs/source/api/backend/elementwise_functions/probnum.backend.acosh.rst b/docs/source/api/backend/elementwise_functions/probnum.backend.acosh.rst new file mode 100644 index 000000000..c3749154f --- /dev/null +++ b/docs/source/api/backend/elementwise_functions/probnum.backend.acosh.rst @@ -0,0 +1,6 @@ +acosh +===== + +.. currentmodule:: probnum.backend + +.. autofunction:: acosh diff --git a/docs/source/api/backend/elementwise_functions/probnum.backend.add.rst b/docs/source/api/backend/elementwise_functions/probnum.backend.add.rst new file mode 100644 index 000000000..26da9fc95 --- /dev/null +++ b/docs/source/api/backend/elementwise_functions/probnum.backend.add.rst @@ -0,0 +1,6 @@ +add +=== + +.. currentmodule:: probnum.backend + +.. autofunction:: add diff --git a/docs/source/api/backend/elementwise_functions/probnum.backend.asin.rst b/docs/source/api/backend/elementwise_functions/probnum.backend.asin.rst new file mode 100644 index 000000000..3095e776f --- /dev/null +++ b/docs/source/api/backend/elementwise_functions/probnum.backend.asin.rst @@ -0,0 +1,6 @@ +asin +==== + +.. currentmodule:: probnum.backend + +.. autofunction:: asin diff --git a/docs/source/api/backend/elementwise_functions/probnum.backend.asinh.rst b/docs/source/api/backend/elementwise_functions/probnum.backend.asinh.rst new file mode 100644 index 000000000..a5c2457c3 --- /dev/null +++ b/docs/source/api/backend/elementwise_functions/probnum.backend.asinh.rst @@ -0,0 +1,6 @@ +asinh +===== + +.. currentmodule:: probnum.backend + +.. autofunction:: asinh diff --git a/docs/source/api/backend/elementwise_functions/probnum.backend.atan.rst b/docs/source/api/backend/elementwise_functions/probnum.backend.atan.rst new file mode 100644 index 000000000..225199dc6 --- /dev/null +++ b/docs/source/api/backend/elementwise_functions/probnum.backend.atan.rst @@ -0,0 +1,6 @@ +atan +==== + +.. currentmodule:: probnum.backend + +.. autofunction:: atan diff --git a/docs/source/api/backend/elementwise_functions/probnum.backend.atan2.rst b/docs/source/api/backend/elementwise_functions/probnum.backend.atan2.rst new file mode 100644 index 000000000..60f12204b --- /dev/null +++ b/docs/source/api/backend/elementwise_functions/probnum.backend.atan2.rst @@ -0,0 +1,6 @@ +atan2 +===== + +.. currentmodule:: probnum.backend + +.. 
autofunction:: atan2 diff --git a/docs/source/api/backend/elementwise_functions/probnum.backend.atanh.rst b/docs/source/api/backend/elementwise_functions/probnum.backend.atanh.rst new file mode 100644 index 000000000..a76c030b6 --- /dev/null +++ b/docs/source/api/backend/elementwise_functions/probnum.backend.atanh.rst @@ -0,0 +1,6 @@ +atanh +===== + +.. currentmodule:: probnum.backend + +.. autofunction:: atanh diff --git a/docs/source/api/backend/elementwise_functions/probnum.backend.bitwise_and.rst b/docs/source/api/backend/elementwise_functions/probnum.backend.bitwise_and.rst new file mode 100644 index 000000000..4c04f58de --- /dev/null +++ b/docs/source/api/backend/elementwise_functions/probnum.backend.bitwise_and.rst @@ -0,0 +1,6 @@ +bitwise_and +=========== + +.. currentmodule:: probnum.backend + +.. autofunction:: bitwise_and diff --git a/docs/source/api/backend/elementwise_functions/probnum.backend.bitwise_invert.rst b/docs/source/api/backend/elementwise_functions/probnum.backend.bitwise_invert.rst new file mode 100644 index 000000000..c354a6a33 --- /dev/null +++ b/docs/source/api/backend/elementwise_functions/probnum.backend.bitwise_invert.rst @@ -0,0 +1,6 @@ +bitwise_invert +============== + +.. currentmodule:: probnum.backend + +.. autofunction:: bitwise_invert diff --git a/docs/source/api/backend/elementwise_functions/probnum.backend.bitwise_left_shift.rst b/docs/source/api/backend/elementwise_functions/probnum.backend.bitwise_left_shift.rst new file mode 100644 index 000000000..cf6dd7b98 --- /dev/null +++ b/docs/source/api/backend/elementwise_functions/probnum.backend.bitwise_left_shift.rst @@ -0,0 +1,6 @@ +bitwise_left_shift +================== + +.. currentmodule:: probnum.backend + +.. autofunction:: bitwise_left_shift diff --git a/docs/source/api/backend/elementwise_functions/probnum.backend.bitwise_or.rst b/docs/source/api/backend/elementwise_functions/probnum.backend.bitwise_or.rst new file mode 100644 index 000000000..0541ac355 --- /dev/null +++ b/docs/source/api/backend/elementwise_functions/probnum.backend.bitwise_or.rst @@ -0,0 +1,6 @@ +bitwise_or +========== + +.. currentmodule:: probnum.backend + +.. autofunction:: bitwise_or diff --git a/docs/source/api/backend/elementwise_functions/probnum.backend.bitwise_right_shift.rst b/docs/source/api/backend/elementwise_functions/probnum.backend.bitwise_right_shift.rst new file mode 100644 index 000000000..2a259bfa8 --- /dev/null +++ b/docs/source/api/backend/elementwise_functions/probnum.backend.bitwise_right_shift.rst @@ -0,0 +1,6 @@ +bitwise_right_shift +=================== + +.. currentmodule:: probnum.backend + +.. autofunction:: bitwise_right_shift diff --git a/docs/source/api/backend/elementwise_functions/probnum.backend.bitwise_xor.rst b/docs/source/api/backend/elementwise_functions/probnum.backend.bitwise_xor.rst new file mode 100644 index 000000000..20f245391 --- /dev/null +++ b/docs/source/api/backend/elementwise_functions/probnum.backend.bitwise_xor.rst @@ -0,0 +1,6 @@ +bitwise_xor +=========== + +.. currentmodule:: probnum.backend + +.. autofunction:: bitwise_xor diff --git a/docs/source/api/backend/elementwise_functions/probnum.backend.ceil.rst b/docs/source/api/backend/elementwise_functions/probnum.backend.ceil.rst new file mode 100644 index 000000000..7d56f2c9f --- /dev/null +++ b/docs/source/api/backend/elementwise_functions/probnum.backend.ceil.rst @@ -0,0 +1,6 @@ +ceil +==== + +.. currentmodule:: probnum.backend + +.. 
autofunction:: ceil diff --git a/docs/source/api/backend/elementwise_functions/probnum.backend.conj.rst b/docs/source/api/backend/elementwise_functions/probnum.backend.conj.rst new file mode 100644 index 000000000..77940070c --- /dev/null +++ b/docs/source/api/backend/elementwise_functions/probnum.backend.conj.rst @@ -0,0 +1,6 @@ +conj +==== + +.. currentmodule:: probnum.backend + +.. autofunction:: conj diff --git a/docs/source/api/backend/elementwise_functions/probnum.backend.cos.rst b/docs/source/api/backend/elementwise_functions/probnum.backend.cos.rst new file mode 100644 index 000000000..e3b9725d3 --- /dev/null +++ b/docs/source/api/backend/elementwise_functions/probnum.backend.cos.rst @@ -0,0 +1,6 @@ +cos +=== + +.. currentmodule:: probnum.backend + +.. autofunction:: cos diff --git a/docs/source/api/backend/elementwise_functions/probnum.backend.cosh.rst b/docs/source/api/backend/elementwise_functions/probnum.backend.cosh.rst new file mode 100644 index 000000000..3bb66b941 --- /dev/null +++ b/docs/source/api/backend/elementwise_functions/probnum.backend.cosh.rst @@ -0,0 +1,6 @@ +cosh +==== + +.. currentmodule:: probnum.backend + +.. autofunction:: cosh diff --git a/docs/source/api/backend/elementwise_functions/probnum.backend.divide.rst b/docs/source/api/backend/elementwise_functions/probnum.backend.divide.rst new file mode 100644 index 000000000..1d5c5a3e9 --- /dev/null +++ b/docs/source/api/backend/elementwise_functions/probnum.backend.divide.rst @@ -0,0 +1,6 @@ +divide +====== + +.. currentmodule:: probnum.backend + +.. autofunction:: divide diff --git a/docs/source/api/backend/elementwise_functions/probnum.backend.exp.rst b/docs/source/api/backend/elementwise_functions/probnum.backend.exp.rst new file mode 100644 index 000000000..9d4d55a17 --- /dev/null +++ b/docs/source/api/backend/elementwise_functions/probnum.backend.exp.rst @@ -0,0 +1,6 @@ +exp +=== + +.. currentmodule:: probnum.backend + +.. autofunction:: exp diff --git a/docs/source/api/backend/elementwise_functions/probnum.backend.expm1.rst b/docs/source/api/backend/elementwise_functions/probnum.backend.expm1.rst new file mode 100644 index 000000000..59092e229 --- /dev/null +++ b/docs/source/api/backend/elementwise_functions/probnum.backend.expm1.rst @@ -0,0 +1,6 @@ +expm1 +===== + +.. currentmodule:: probnum.backend + +.. autofunction:: expm1 diff --git a/docs/source/api/backend/elementwise_functions/probnum.backend.floor.rst b/docs/source/api/backend/elementwise_functions/probnum.backend.floor.rst new file mode 100644 index 000000000..59d7fa5c4 --- /dev/null +++ b/docs/source/api/backend/elementwise_functions/probnum.backend.floor.rst @@ -0,0 +1,6 @@ +floor +===== + +.. currentmodule:: probnum.backend + +.. autofunction:: floor diff --git a/docs/source/api/backend/elementwise_functions/probnum.backend.floor_divide.rst b/docs/source/api/backend/elementwise_functions/probnum.backend.floor_divide.rst new file mode 100644 index 000000000..7a51db315 --- /dev/null +++ b/docs/source/api/backend/elementwise_functions/probnum.backend.floor_divide.rst @@ -0,0 +1,6 @@ +floor_divide +============ + +.. currentmodule:: probnum.backend + +.. autofunction:: floor_divide diff --git a/docs/source/api/backend/elementwise_functions/probnum.backend.imag.rst b/docs/source/api/backend/elementwise_functions/probnum.backend.imag.rst new file mode 100644 index 000000000..caf6b5890 --- /dev/null +++ b/docs/source/api/backend/elementwise_functions/probnum.backend.imag.rst @@ -0,0 +1,6 @@ +imag +==== + +.. currentmodule:: probnum.backend + +.. 
autofunction:: imag diff --git a/docs/source/api/backend/elementwise_functions/probnum.backend.isfinite.rst b/docs/source/api/backend/elementwise_functions/probnum.backend.isfinite.rst new file mode 100644 index 000000000..50c11f217 --- /dev/null +++ b/docs/source/api/backend/elementwise_functions/probnum.backend.isfinite.rst @@ -0,0 +1,6 @@ +isfinite +======== + +.. currentmodule:: probnum.backend + +.. autofunction:: isfinite diff --git a/docs/source/api/backend/elementwise_functions/probnum.backend.isinf.rst b/docs/source/api/backend/elementwise_functions/probnum.backend.isinf.rst new file mode 100644 index 000000000..a6dac5d4a --- /dev/null +++ b/docs/source/api/backend/elementwise_functions/probnum.backend.isinf.rst @@ -0,0 +1,6 @@ +isinf +===== + +.. currentmodule:: probnum.backend + +.. autofunction:: isinf diff --git a/docs/source/api/backend/elementwise_functions/probnum.backend.isnan.rst b/docs/source/api/backend/elementwise_functions/probnum.backend.isnan.rst new file mode 100644 index 000000000..8ebca277e --- /dev/null +++ b/docs/source/api/backend/elementwise_functions/probnum.backend.isnan.rst @@ -0,0 +1,6 @@ +isnan +===== + +.. currentmodule:: probnum.backend + +.. autofunction:: isnan diff --git a/docs/source/api/backend/elementwise_functions/probnum.backend.log.rst b/docs/source/api/backend/elementwise_functions/probnum.backend.log.rst new file mode 100644 index 000000000..8f01cbfe1 --- /dev/null +++ b/docs/source/api/backend/elementwise_functions/probnum.backend.log.rst @@ -0,0 +1,6 @@ +log +=== + +.. currentmodule:: probnum.backend + +.. autofunction:: log diff --git a/docs/source/api/backend/elementwise_functions/probnum.backend.log10.rst b/docs/source/api/backend/elementwise_functions/probnum.backend.log10.rst new file mode 100644 index 000000000..6828cbaa8 --- /dev/null +++ b/docs/source/api/backend/elementwise_functions/probnum.backend.log10.rst @@ -0,0 +1,6 @@ +log10 +===== + +.. currentmodule:: probnum.backend + +.. autofunction:: log10 diff --git a/docs/source/api/backend/elementwise_functions/probnum.backend.log1p.rst b/docs/source/api/backend/elementwise_functions/probnum.backend.log1p.rst new file mode 100644 index 000000000..c2dd32e15 --- /dev/null +++ b/docs/source/api/backend/elementwise_functions/probnum.backend.log1p.rst @@ -0,0 +1,6 @@ +log1p +===== + +.. currentmodule:: probnum.backend + +.. autofunction:: log1p diff --git a/docs/source/api/backend/elementwise_functions/probnum.backend.log2.rst b/docs/source/api/backend/elementwise_functions/probnum.backend.log2.rst new file mode 100644 index 000000000..db9a7b7bd --- /dev/null +++ b/docs/source/api/backend/elementwise_functions/probnum.backend.log2.rst @@ -0,0 +1,6 @@ +log2 +==== + +.. currentmodule:: probnum.backend + +.. autofunction:: log2 diff --git a/docs/source/api/backend/elementwise_functions/probnum.backend.logaddexp.rst b/docs/source/api/backend/elementwise_functions/probnum.backend.logaddexp.rst new file mode 100644 index 000000000..5f4619389 --- /dev/null +++ b/docs/source/api/backend/elementwise_functions/probnum.backend.logaddexp.rst @@ -0,0 +1,6 @@ +logaddexp +========= + +.. currentmodule:: probnum.backend + +.. autofunction:: logaddexp diff --git a/docs/source/api/backend/elementwise_functions/probnum.backend.maximum.rst b/docs/source/api/backend/elementwise_functions/probnum.backend.maximum.rst new file mode 100644 index 000000000..9b10b9c53 --- /dev/null +++ b/docs/source/api/backend/elementwise_functions/probnum.backend.maximum.rst @@ -0,0 +1,6 @@ +maximum +======= + +.. 
currentmodule:: probnum.backend + +.. autofunction:: maximum diff --git a/docs/source/api/backend/elementwise_functions/probnum.backend.minimum.rst b/docs/source/api/backend/elementwise_functions/probnum.backend.minimum.rst new file mode 100644 index 000000000..dbce948a9 --- /dev/null +++ b/docs/source/api/backend/elementwise_functions/probnum.backend.minimum.rst @@ -0,0 +1,6 @@ +minimum +======= + +.. currentmodule:: probnum.backend + +.. autofunction:: minimum diff --git a/docs/source/api/backend/elementwise_functions/probnum.backend.multiply.rst b/docs/source/api/backend/elementwise_functions/probnum.backend.multiply.rst new file mode 100644 index 000000000..6813c4009 --- /dev/null +++ b/docs/source/api/backend/elementwise_functions/probnum.backend.multiply.rst @@ -0,0 +1,6 @@ +multiply +======== + +.. currentmodule:: probnum.backend + +.. autofunction:: multiply diff --git a/docs/source/api/backend/elementwise_functions/probnum.backend.negative.rst b/docs/source/api/backend/elementwise_functions/probnum.backend.negative.rst new file mode 100644 index 000000000..4ba6006a1 --- /dev/null +++ b/docs/source/api/backend/elementwise_functions/probnum.backend.negative.rst @@ -0,0 +1,6 @@ +negative +======== + +.. currentmodule:: probnum.backend + +.. autofunction:: negative diff --git a/docs/source/api/backend/elementwise_functions/probnum.backend.positive.rst b/docs/source/api/backend/elementwise_functions/probnum.backend.positive.rst new file mode 100644 index 000000000..f1f206326 --- /dev/null +++ b/docs/source/api/backend/elementwise_functions/probnum.backend.positive.rst @@ -0,0 +1,6 @@ +positive +======== + +.. currentmodule:: probnum.backend + +.. autofunction:: positive diff --git a/docs/source/api/backend/elementwise_functions/probnum.backend.pow.rst b/docs/source/api/backend/elementwise_functions/probnum.backend.pow.rst new file mode 100644 index 000000000..20bfd5e8b --- /dev/null +++ b/docs/source/api/backend/elementwise_functions/probnum.backend.pow.rst @@ -0,0 +1,6 @@ +pow +=== + +.. currentmodule:: probnum.backend + +.. autofunction:: pow diff --git a/docs/source/api/backend/elementwise_functions/probnum.backend.real.rst b/docs/source/api/backend/elementwise_functions/probnum.backend.real.rst new file mode 100644 index 000000000..6b6dc8e62 --- /dev/null +++ b/docs/source/api/backend/elementwise_functions/probnum.backend.real.rst @@ -0,0 +1,6 @@ +real +==== + +.. currentmodule:: probnum.backend + +.. autofunction:: real diff --git a/docs/source/api/backend/elementwise_functions/probnum.backend.remainder.rst b/docs/source/api/backend/elementwise_functions/probnum.backend.remainder.rst new file mode 100644 index 000000000..f07bb7f66 --- /dev/null +++ b/docs/source/api/backend/elementwise_functions/probnum.backend.remainder.rst @@ -0,0 +1,6 @@ +remainder +========= + +.. currentmodule:: probnum.backend + +.. autofunction:: remainder diff --git a/docs/source/api/backend/elementwise_functions/probnum.backend.round.rst b/docs/source/api/backend/elementwise_functions/probnum.backend.round.rst new file mode 100644 index 000000000..aa4bf6f6b --- /dev/null +++ b/docs/source/api/backend/elementwise_functions/probnum.backend.round.rst @@ -0,0 +1,6 @@ +round +===== + +.. currentmodule:: probnum.backend + +.. 
autofunction:: round diff --git a/docs/source/api/backend/elementwise_functions/probnum.backend.sign.rst b/docs/source/api/backend/elementwise_functions/probnum.backend.sign.rst new file mode 100644 index 000000000..c310faad4 --- /dev/null +++ b/docs/source/api/backend/elementwise_functions/probnum.backend.sign.rst @@ -0,0 +1,6 @@ +sign +==== + +.. currentmodule:: probnum.backend + +.. autofunction:: sign diff --git a/docs/source/api/backend/elementwise_functions/probnum.backend.sin.rst b/docs/source/api/backend/elementwise_functions/probnum.backend.sin.rst new file mode 100644 index 000000000..f9adba041 --- /dev/null +++ b/docs/source/api/backend/elementwise_functions/probnum.backend.sin.rst @@ -0,0 +1,6 @@ +sin +=== + +.. currentmodule:: probnum.backend + +.. autofunction:: sin diff --git a/docs/source/api/backend/elementwise_functions/probnum.backend.sinh.rst b/docs/source/api/backend/elementwise_functions/probnum.backend.sinh.rst new file mode 100644 index 000000000..8e004f90d --- /dev/null +++ b/docs/source/api/backend/elementwise_functions/probnum.backend.sinh.rst @@ -0,0 +1,6 @@ +sinh +==== + +.. currentmodule:: probnum.backend + +.. autofunction:: sinh diff --git a/docs/source/api/backend/elementwise_functions/probnum.backend.sqrt.rst b/docs/source/api/backend/elementwise_functions/probnum.backend.sqrt.rst new file mode 100644 index 000000000..c4750613a --- /dev/null +++ b/docs/source/api/backend/elementwise_functions/probnum.backend.sqrt.rst @@ -0,0 +1,6 @@ +sqrt +==== + +.. currentmodule:: probnum.backend + +.. autofunction:: sqrt diff --git a/docs/source/api/backend/elementwise_functions/probnum.backend.square.rst b/docs/source/api/backend/elementwise_functions/probnum.backend.square.rst new file mode 100644 index 000000000..69d725ec6 --- /dev/null +++ b/docs/source/api/backend/elementwise_functions/probnum.backend.square.rst @@ -0,0 +1,6 @@ +square +====== + +.. currentmodule:: probnum.backend + +.. autofunction:: square diff --git a/docs/source/api/backend/elementwise_functions/probnum.backend.subtract.rst b/docs/source/api/backend/elementwise_functions/probnum.backend.subtract.rst new file mode 100644 index 000000000..1f456f800 --- /dev/null +++ b/docs/source/api/backend/elementwise_functions/probnum.backend.subtract.rst @@ -0,0 +1,6 @@ +subtract +======== + +.. currentmodule:: probnum.backend + +.. autofunction:: subtract diff --git a/docs/source/api/backend/elementwise_functions/probnum.backend.tan.rst b/docs/source/api/backend/elementwise_functions/probnum.backend.tan.rst new file mode 100644 index 000000000..25c3dd8c0 --- /dev/null +++ b/docs/source/api/backend/elementwise_functions/probnum.backend.tan.rst @@ -0,0 +1,6 @@ +tan +=== + +.. currentmodule:: probnum.backend + +.. autofunction:: tan diff --git a/docs/source/api/backend/elementwise_functions/probnum.backend.tanh.rst b/docs/source/api/backend/elementwise_functions/probnum.backend.tanh.rst new file mode 100644 index 000000000..2a6c70621 --- /dev/null +++ b/docs/source/api/backend/elementwise_functions/probnum.backend.tanh.rst @@ -0,0 +1,6 @@ +tanh +==== + +.. currentmodule:: probnum.backend + +.. autofunction:: tanh diff --git a/docs/source/api/backend/elementwise_functions/probnum.backend.trunc.rst b/docs/source/api/backend/elementwise_functions/probnum.backend.trunc.rst new file mode 100644 index 000000000..a2022ab2f --- /dev/null +++ b/docs/source/api/backend/elementwise_functions/probnum.backend.trunc.rst @@ -0,0 +1,6 @@ +trunc +====== + +.. currentmodule:: probnum.backend + +.. 
autofunction:: trunc diff --git a/docs/source/api/backend/jit_compilation.rst b/docs/source/api/backend/jit_compilation.rst new file mode 100644 index 000000000..19e6a417a --- /dev/null +++ b/docs/source/api/backend/jit_compilation.rst @@ -0,0 +1,21 @@ +JIT Compilation +=============== + +Just-in-time compilation of functions. + +.. currentmodule:: probnum.backend + +Functions +--------- + +.. autosummary:: + + ~probnum.backend.jit + ~probnum.backend.jit_method + + +.. toctree:: + :hidden: + + jit_compilation/probnum.backend.jit + jit_compilation/probnum.backend.jit_method diff --git a/docs/source/api/backend/jit_compilation/probnum.backend.jit.rst b/docs/source/api/backend/jit_compilation/probnum.backend.jit.rst new file mode 100644 index 000000000..568fb3067 --- /dev/null +++ b/docs/source/api/backend/jit_compilation/probnum.backend.jit.rst @@ -0,0 +1,6 @@ +jit +=== + +.. currentmodule:: probnum.backend + +.. autofunction:: jit diff --git a/docs/source/api/backend/jit_compilation/probnum.backend.jit_method.rst b/docs/source/api/backend/jit_compilation/probnum.backend.jit_method.rst new file mode 100644 index 000000000..2b32c56b2 --- /dev/null +++ b/docs/source/api/backend/jit_compilation/probnum.backend.jit_method.rst @@ -0,0 +1,6 @@ +jit_method +========== + +.. currentmodule:: probnum.backend + +.. autofunction:: jit_method diff --git a/docs/source/api/backend/linalg.rst b/docs/source/api/backend/linalg.rst new file mode 100644 index 000000000..40ffbe597 --- /dev/null +++ b/docs/source/api/backend/linalg.rst @@ -0,0 +1,5 @@ +probnum.backend.linalg +---------------------- +.. automodapi:: probnum.backend.linalg + :no-heading: + :headings: "*" diff --git a/docs/source/api/backend/logic_functions.rst b/docs/source/api/backend/logic_functions.rst new file mode 100644 index 000000000..6074dad1a --- /dev/null +++ b/docs/source/api/backend/logic_functions.rst @@ -0,0 +1,41 @@ +Logic Functions +=============== + +Logic functions applied to arrays. + +.. currentmodule:: probnum.backend + +Functions +--------- + +.. autosummary:: + + ~probnum.backend.all + ~probnum.backend.any + ~probnum.backend.equal + ~probnum.backend.greater + ~probnum.backend.greater_equal + ~probnum.backend.less + ~probnum.backend.less_equal + ~probnum.backend.logical_and + ~probnum.backend.logical_not + ~probnum.backend.logical_or + ~probnum.backend.logical_xor + ~probnum.backend.not_equal + + +.. toctree:: + :hidden: + + logic_functions/probnum.backend.all + logic_functions/probnum.backend.any + logic_functions/probnum.backend.equal + logic_functions/probnum.backend.greater + logic_functions/probnum.backend.greater_equal + logic_functions/probnum.backend.less + logic_functions/probnum.backend.less_equal + logic_functions/probnum.backend.logical_and + logic_functions/probnum.backend.logical_not + logic_functions/probnum.backend.logical_or + logic_functions/probnum.backend.logical_xor + logic_functions/probnum.backend.not_equal diff --git a/docs/source/api/backend/logic_functions/probnum.backend.all.rst b/docs/source/api/backend/logic_functions/probnum.backend.all.rst new file mode 100644 index 000000000..0928207be --- /dev/null +++ b/docs/source/api/backend/logic_functions/probnum.backend.all.rst @@ -0,0 +1,6 @@ +all +=== + +.. currentmodule:: probnum.backend + +.. 
autofunction:: all diff --git a/docs/source/api/backend/logic_functions/probnum.backend.any.rst b/docs/source/api/backend/logic_functions/probnum.backend.any.rst new file mode 100644 index 000000000..8176f40a6 --- /dev/null +++ b/docs/source/api/backend/logic_functions/probnum.backend.any.rst @@ -0,0 +1,6 @@ +any +=== + +.. currentmodule:: probnum.backend + +.. autofunction:: any diff --git a/docs/source/api/backend/logic_functions/probnum.backend.equal.rst b/docs/source/api/backend/logic_functions/probnum.backend.equal.rst new file mode 100644 index 000000000..c21df92e7 --- /dev/null +++ b/docs/source/api/backend/logic_functions/probnum.backend.equal.rst @@ -0,0 +1,6 @@ +equal +===== + +.. currentmodule:: probnum.backend + +.. autofunction:: equal diff --git a/docs/source/api/backend/logic_functions/probnum.backend.greater.rst b/docs/source/api/backend/logic_functions/probnum.backend.greater.rst new file mode 100644 index 000000000..18be0c415 --- /dev/null +++ b/docs/source/api/backend/logic_functions/probnum.backend.greater.rst @@ -0,0 +1,6 @@ +greater +======= + +.. currentmodule:: probnum.backend + +.. autofunction:: greater diff --git a/docs/source/api/backend/logic_functions/probnum.backend.greater_equal.rst b/docs/source/api/backend/logic_functions/probnum.backend.greater_equal.rst new file mode 100644 index 000000000..58d80f768 --- /dev/null +++ b/docs/source/api/backend/logic_functions/probnum.backend.greater_equal.rst @@ -0,0 +1,6 @@ +greater_equal +============= + +.. currentmodule:: probnum.backend + +.. autofunction:: greater_equal diff --git a/docs/source/api/backend/logic_functions/probnum.backend.less.rst b/docs/source/api/backend/logic_functions/probnum.backend.less.rst new file mode 100644 index 000000000..4edbc2e23 --- /dev/null +++ b/docs/source/api/backend/logic_functions/probnum.backend.less.rst @@ -0,0 +1,6 @@ +less +==== + +.. currentmodule:: probnum.backend + +.. autofunction:: less diff --git a/docs/source/api/backend/logic_functions/probnum.backend.less_equal.rst b/docs/source/api/backend/logic_functions/probnum.backend.less_equal.rst new file mode 100644 index 000000000..3a17bda62 --- /dev/null +++ b/docs/source/api/backend/logic_functions/probnum.backend.less_equal.rst @@ -0,0 +1,6 @@ +less_equal +========== + +.. currentmodule:: probnum.backend + +.. autofunction:: less_equal diff --git a/docs/source/api/backend/logic_functions/probnum.backend.logical_and.rst b/docs/source/api/backend/logic_functions/probnum.backend.logical_and.rst new file mode 100644 index 000000000..45e0666db --- /dev/null +++ b/docs/source/api/backend/logic_functions/probnum.backend.logical_and.rst @@ -0,0 +1,6 @@ +logical_and +=========== + +.. currentmodule:: probnum.backend + +.. autofunction:: logical_and diff --git a/docs/source/api/backend/logic_functions/probnum.backend.logical_not.rst b/docs/source/api/backend/logic_functions/probnum.backend.logical_not.rst new file mode 100644 index 000000000..1ca1c9f7f --- /dev/null +++ b/docs/source/api/backend/logic_functions/probnum.backend.logical_not.rst @@ -0,0 +1,6 @@ +logical_not +=========== + +.. currentmodule:: probnum.backend + +.. autofunction:: logical_not diff --git a/docs/source/api/backend/logic_functions/probnum.backend.logical_or.rst b/docs/source/api/backend/logic_functions/probnum.backend.logical_or.rst new file mode 100644 index 000000000..5e945df29 --- /dev/null +++ b/docs/source/api/backend/logic_functions/probnum.backend.logical_or.rst @@ -0,0 +1,6 @@ +logical_or +========== + +.. currentmodule:: probnum.backend + +.. 
autofunction:: logical_or diff --git a/docs/source/api/backend/logic_functions/probnum.backend.logical_xor.rst b/docs/source/api/backend/logic_functions/probnum.backend.logical_xor.rst new file mode 100644 index 000000000..54148b2b1 --- /dev/null +++ b/docs/source/api/backend/logic_functions/probnum.backend.logical_xor.rst @@ -0,0 +1,6 @@ +logical_xor +=========== + +.. currentmodule:: probnum.backend + +.. autofunction:: logical_xor diff --git a/docs/source/api/backend/logic_functions/probnum.backend.not_equal.rst b/docs/source/api/backend/logic_functions/probnum.backend.not_equal.rst new file mode 100644 index 000000000..4027efd39 --- /dev/null +++ b/docs/source/api/backend/logic_functions/probnum.backend.not_equal.rst @@ -0,0 +1,6 @@ +not_equal +========= + +.. currentmodule:: probnum.backend + +.. autofunction:: not_equal diff --git a/docs/source/api/backend/manipulation_functions.rst b/docs/source/api/backend/manipulation_functions.rst new file mode 100644 index 000000000..81eb700f9 --- /dev/null +++ b/docs/source/api/backend/manipulation_functions.rst @@ -0,0 +1,53 @@ +Manipulation Functions +====================== + +Functions manipulating arrays. + +.. currentmodule:: probnum.backend + +Functions +--------- + +.. autosummary:: + + ~probnum.backend.atleast_1d + ~probnum.backend.atleast_2d + ~probnum.backend.broadcast_arrays + ~probnum.backend.broadcast_shapes + ~probnum.backend.broadcast_to + ~probnum.backend.concat + ~probnum.backend.expand_axes + ~probnum.backend.flip + ~probnum.backend.hstack + ~probnum.backend.move_axes + ~probnum.backend.permute_axes + ~probnum.backend.reshape + ~probnum.backend.roll + ~probnum.backend.squeeze + ~probnum.backend.stack + ~probnum.backend.swap_axes + ~probnum.backend.tile + ~probnum.backend.vstack + + +.. toctree:: + :hidden: + + manipulation_functions/probnum.backend.atleast_1d + manipulation_functions/probnum.backend.atleast_2d + manipulation_functions/probnum.backend.broadcast_arrays + manipulation_functions/probnum.backend.broadcast_shapes + manipulation_functions/probnum.backend.broadcast_to + manipulation_functions/probnum.backend.concat + manipulation_functions/probnum.backend.expand_axes + manipulation_functions/probnum.backend.flip + manipulation_functions/probnum.backend.hstack + manipulation_functions/probnum.backend.move_axes + manipulation_functions/probnum.backend.permute_axes + manipulation_functions/probnum.backend.reshape + manipulation_functions/probnum.backend.roll + manipulation_functions/probnum.backend.squeeze + manipulation_functions/probnum.backend.stack + manipulation_functions/probnum.backend.swap_axes + manipulation_functions/probnum.backend.tile + manipulation_functions/probnum.backend.vstack diff --git a/docs/source/api/backend/manipulation_functions/probnum.backend.atleast_1d.rst b/docs/source/api/backend/manipulation_functions/probnum.backend.atleast_1d.rst new file mode 100644 index 000000000..e60d4bcc8 --- /dev/null +++ b/docs/source/api/backend/manipulation_functions/probnum.backend.atleast_1d.rst @@ -0,0 +1,6 @@ +atleast_1d +========== + +.. currentmodule:: probnum.backend + +.. autofunction:: atleast_1d diff --git a/docs/source/api/backend/manipulation_functions/probnum.backend.atleast_2d.rst b/docs/source/api/backend/manipulation_functions/probnum.backend.atleast_2d.rst new file mode 100644 index 000000000..84b09fa84 --- /dev/null +++ b/docs/source/api/backend/manipulation_functions/probnum.backend.atleast_2d.rst @@ -0,0 +1,6 @@ +atleast_2d +========== + +.. currentmodule:: probnum.backend + +.. 
autofunction:: atleast_2d diff --git a/docs/source/api/backend/manipulation_functions/probnum.backend.broadcast_arrays.rst b/docs/source/api/backend/manipulation_functions/probnum.backend.broadcast_arrays.rst new file mode 100644 index 000000000..fb7e8fb4d --- /dev/null +++ b/docs/source/api/backend/manipulation_functions/probnum.backend.broadcast_arrays.rst @@ -0,0 +1,6 @@ +broadcast_arrays +================ + +.. currentmodule:: probnum.backend + +.. autofunction:: broadcast_arrays diff --git a/docs/source/api/backend/manipulation_functions/probnum.backend.broadcast_shapes.rst b/docs/source/api/backend/manipulation_functions/probnum.backend.broadcast_shapes.rst new file mode 100644 index 000000000..80d9c0923 --- /dev/null +++ b/docs/source/api/backend/manipulation_functions/probnum.backend.broadcast_shapes.rst @@ -0,0 +1,6 @@ +broadcast_shapes +================ + +.. currentmodule:: probnum.backend + +.. autofunction:: broadcast_shapes diff --git a/docs/source/api/backend/manipulation_functions/probnum.backend.broadcast_to.rst b/docs/source/api/backend/manipulation_functions/probnum.backend.broadcast_to.rst new file mode 100644 index 000000000..88fd34830 --- /dev/null +++ b/docs/source/api/backend/manipulation_functions/probnum.backend.broadcast_to.rst @@ -0,0 +1,6 @@ +broadcast_to +============ + +.. currentmodule:: probnum.backend + +.. autofunction:: broadcast_to diff --git a/docs/source/api/backend/manipulation_functions/probnum.backend.concat.rst b/docs/source/api/backend/manipulation_functions/probnum.backend.concat.rst new file mode 100644 index 000000000..d6b1db2f8 --- /dev/null +++ b/docs/source/api/backend/manipulation_functions/probnum.backend.concat.rst @@ -0,0 +1,6 @@ +concat +====== + +.. currentmodule:: probnum.backend + +.. autofunction:: concat diff --git a/docs/source/api/backend/manipulation_functions/probnum.backend.expand_axes.rst b/docs/source/api/backend/manipulation_functions/probnum.backend.expand_axes.rst new file mode 100644 index 000000000..07e165a76 --- /dev/null +++ b/docs/source/api/backend/manipulation_functions/probnum.backend.expand_axes.rst @@ -0,0 +1,6 @@ +expand_axes +=========== + +.. currentmodule:: probnum.backend + +.. autofunction:: expand_axes diff --git a/docs/source/api/backend/manipulation_functions/probnum.backend.flip.rst b/docs/source/api/backend/manipulation_functions/probnum.backend.flip.rst new file mode 100644 index 000000000..b17b199be --- /dev/null +++ b/docs/source/api/backend/manipulation_functions/probnum.backend.flip.rst @@ -0,0 +1,6 @@ +flip +==== + +.. currentmodule:: probnum.backend + +.. autofunction:: flip diff --git a/docs/source/api/backend/manipulation_functions/probnum.backend.hstack.rst b/docs/source/api/backend/manipulation_functions/probnum.backend.hstack.rst new file mode 100644 index 000000000..a6cf00572 --- /dev/null +++ b/docs/source/api/backend/manipulation_functions/probnum.backend.hstack.rst @@ -0,0 +1,6 @@ +hstack +====== + +.. currentmodule:: probnum.backend + +.. autofunction:: hstack diff --git a/docs/source/api/backend/manipulation_functions/probnum.backend.move_axes.rst b/docs/source/api/backend/manipulation_functions/probnum.backend.move_axes.rst new file mode 100644 index 000000000..7c20283fe --- /dev/null +++ b/docs/source/api/backend/manipulation_functions/probnum.backend.move_axes.rst @@ -0,0 +1,6 @@ +move_axes +========== + +.. currentmodule:: probnum.backend + +.. 
autofunction:: move_axes diff --git a/docs/source/api/backend/manipulation_functions/probnum.backend.permute_axes.rst b/docs/source/api/backend/manipulation_functions/probnum.backend.permute_axes.rst new file mode 100644 index 000000000..1e7f2de78 --- /dev/null +++ b/docs/source/api/backend/manipulation_functions/probnum.backend.permute_axes.rst @@ -0,0 +1,6 @@ +permute_axes +============ + +.. currentmodule:: probnum.backend + +.. autofunction:: permute_axes diff --git a/docs/source/api/backend/manipulation_functions/probnum.backend.reshape.rst b/docs/source/api/backend/manipulation_functions/probnum.backend.reshape.rst new file mode 100644 index 000000000..23964fd1b --- /dev/null +++ b/docs/source/api/backend/manipulation_functions/probnum.backend.reshape.rst @@ -0,0 +1,6 @@ +reshape +======= + +.. currentmodule:: probnum.backend + +.. autofunction:: reshape diff --git a/docs/source/api/backend/manipulation_functions/probnum.backend.roll.rst b/docs/source/api/backend/manipulation_functions/probnum.backend.roll.rst new file mode 100644 index 000000000..c864b0699 --- /dev/null +++ b/docs/source/api/backend/manipulation_functions/probnum.backend.roll.rst @@ -0,0 +1,6 @@ +roll +==== + +.. currentmodule:: probnum.backend + +.. autofunction:: roll diff --git a/docs/source/api/backend/manipulation_functions/probnum.backend.squeeze.rst b/docs/source/api/backend/manipulation_functions/probnum.backend.squeeze.rst new file mode 100644 index 000000000..5b4ffb914 --- /dev/null +++ b/docs/source/api/backend/manipulation_functions/probnum.backend.squeeze.rst @@ -0,0 +1,6 @@ +squeeze +======= + +.. currentmodule:: probnum.backend + +.. autofunction:: squeeze diff --git a/docs/source/api/backend/manipulation_functions/probnum.backend.stack.rst b/docs/source/api/backend/manipulation_functions/probnum.backend.stack.rst new file mode 100644 index 000000000..b453a8b03 --- /dev/null +++ b/docs/source/api/backend/manipulation_functions/probnum.backend.stack.rst @@ -0,0 +1,6 @@ +stack +===== + +.. currentmodule:: probnum.backend + +.. autofunction:: stack diff --git a/docs/source/api/backend/manipulation_functions/probnum.backend.swap_axes.rst b/docs/source/api/backend/manipulation_functions/probnum.backend.swap_axes.rst new file mode 100644 index 000000000..422bbd3cb --- /dev/null +++ b/docs/source/api/backend/manipulation_functions/probnum.backend.swap_axes.rst @@ -0,0 +1,6 @@ +swap_axes +========= + +.. currentmodule:: probnum.backend + +.. autofunction:: swap_axes diff --git a/docs/source/api/backend/manipulation_functions/probnum.backend.tile.rst b/docs/source/api/backend/manipulation_functions/probnum.backend.tile.rst new file mode 100644 index 000000000..7a6dfb84a --- /dev/null +++ b/docs/source/api/backend/manipulation_functions/probnum.backend.tile.rst @@ -0,0 +1,6 @@ +tile +==== + +.. currentmodule:: probnum.backend + +.. autofunction:: tile diff --git a/docs/source/api/backend/manipulation_functions/probnum.backend.vstack.rst b/docs/source/api/backend/manipulation_functions/probnum.backend.vstack.rst new file mode 100644 index 000000000..10d72a75a --- /dev/null +++ b/docs/source/api/backend/manipulation_functions/probnum.backend.vstack.rst @@ -0,0 +1,6 @@ +vstack +====== + +.. currentmodule:: probnum.backend + +.. 
autofunction:: vstack diff --git a/docs/source/api/backend/probnum.backend.Dispatcher.rst b/docs/source/api/backend/probnum.backend.Dispatcher.rst new file mode 100644 index 000000000..908774c5f --- /dev/null +++ b/docs/source/api/backend/probnum.backend.Dispatcher.rst @@ -0,0 +1,6 @@ +Dispatcher +========== + +.. currentmodule:: probnum.backend + +.. autoclass:: Dispatcher diff --git a/docs/source/api/backend/random.rst b/docs/source/api/backend/random.rst new file mode 100644 index 000000000..4d35e5a16 --- /dev/null +++ b/docs/source/api/backend/random.rst @@ -0,0 +1,18 @@ +probnum.backend.random +---------------------- +.. automodapi:: probnum.backend.random + :no-heading: + :headings: "*" + + +Classes +******* + ++-------------------------------------------+---------------------------------------+ +| :class:`~probnum.backend.random.RNGState` | State of the random number generator. | ++-------------------------------------------+---------------------------------------+ + +.. toctree:: + :hidden: + + random/probnum.backend.random.RNGState diff --git a/docs/source/api/backend/random/probnum.backend.random.RNGState.rst b/docs/source/api/backend/random/probnum.backend.random.RNGState.rst new file mode 100644 index 000000000..5585926ae --- /dev/null +++ b/docs/source/api/backend/random/probnum.backend.random.RNGState.rst @@ -0,0 +1,6 @@ +RNGState +======== + +.. currentmodule:: probnum.backend.random + +.. autoclass:: RNGState diff --git a/docs/source/api/backend/searching_functions.rst b/docs/source/api/backend/searching_functions.rst new file mode 100644 index 000000000..d5f6360e7 --- /dev/null +++ b/docs/source/api/backend/searching_functions.rst @@ -0,0 +1,25 @@ +Searching Functions +=================== + +Functions for searching in arrays. + +.. currentmodule:: probnum.backend + +Functions +--------- + +.. autosummary:: + + ~probnum.backend.argmax + ~probnum.backend.argmin + ~probnum.backend.nonzero + ~probnum.backend.where + + +.. toctree:: + :hidden: + + searching_functions/probnum.backend.argmax + searching_functions/probnum.backend.argmin + searching_functions/probnum.backend.nonzero + searching_functions/probnum.backend.where diff --git a/docs/source/api/backend/searching_functions/probnum.backend.argmax.rst b/docs/source/api/backend/searching_functions/probnum.backend.argmax.rst new file mode 100644 index 000000000..cf9e25d0c --- /dev/null +++ b/docs/source/api/backend/searching_functions/probnum.backend.argmax.rst @@ -0,0 +1,6 @@ +argmax +====== + +.. currentmodule:: probnum.backend + +.. autofunction:: argmax diff --git a/docs/source/api/backend/searching_functions/probnum.backend.argmin.rst b/docs/source/api/backend/searching_functions/probnum.backend.argmin.rst new file mode 100644 index 000000000..7c8645a2d --- /dev/null +++ b/docs/source/api/backend/searching_functions/probnum.backend.argmin.rst @@ -0,0 +1,6 @@ +argmin +====== + +.. currentmodule:: probnum.backend + +.. autofunction:: argmin diff --git a/docs/source/api/backend/searching_functions/probnum.backend.nonzero.rst b/docs/source/api/backend/searching_functions/probnum.backend.nonzero.rst new file mode 100644 index 000000000..44ea5df28 --- /dev/null +++ b/docs/source/api/backend/searching_functions/probnum.backend.nonzero.rst @@ -0,0 +1,6 @@ +nonzero +======= + +.. currentmodule:: probnum.backend + +.. 
autofunction:: nonzero diff --git a/docs/source/api/backend/searching_functions/probnum.backend.where.rst b/docs/source/api/backend/searching_functions/probnum.backend.where.rst new file mode 100644 index 000000000..2baacb5c2 --- /dev/null +++ b/docs/source/api/backend/searching_functions/probnum.backend.where.rst @@ -0,0 +1,6 @@ +where +===== + +.. currentmodule:: probnum.backend + +.. autofunction:: where diff --git a/docs/source/api/backend/sorting_functions.rst b/docs/source/api/backend/sorting_functions.rst new file mode 100644 index 000000000..4339bbfb5 --- /dev/null +++ b/docs/source/api/backend/sorting_functions.rst @@ -0,0 +1,21 @@ +Sorting Functions +================= + +Functions for sorting arrays. + +.. currentmodule:: probnum.backend + +Functions +--------- + +.. autosummary:: + + ~probnum.backend.argsort + ~probnum.backend.sort + + +.. toctree:: + :hidden: + + sorting_functions/probnum.backend.argsort + sorting_functions/probnum.backend.sort diff --git a/docs/source/api/backend/sorting_functions/probnum.backend.argsort.rst b/docs/source/api/backend/sorting_functions/probnum.backend.argsort.rst new file mode 100644 index 000000000..a52c6fe46 --- /dev/null +++ b/docs/source/api/backend/sorting_functions/probnum.backend.argsort.rst @@ -0,0 +1,6 @@ +argsort +======= + +.. currentmodule:: probnum.backend + +.. autofunction:: argsort diff --git a/docs/source/api/backend/sorting_functions/probnum.backend.sort.rst b/docs/source/api/backend/sorting_functions/probnum.backend.sort.rst new file mode 100644 index 000000000..8c846293c --- /dev/null +++ b/docs/source/api/backend/sorting_functions/probnum.backend.sort.rst @@ -0,0 +1,6 @@ +sort +==== + +.. currentmodule:: probnum.backend + +.. autofunction:: sort diff --git a/docs/source/api/backend/special.rst b/docs/source/api/backend/special.rst new file mode 100644 index 000000000..77515a609 --- /dev/null +++ b/docs/source/api/backend/special.rst @@ -0,0 +1,5 @@ +probnum.backend.special +----------------------- +.. automodapi:: probnum.backend.special + :no-heading: + :headings: "*" diff --git a/docs/source/api/backend/statistical_functions.rst b/docs/source/api/backend/statistical_functions.rst new file mode 100644 index 000000000..9ee0ff429 --- /dev/null +++ b/docs/source/api/backend/statistical_functions.rst @@ -0,0 +1,31 @@ +Statistical Functions +===================== + +Statistical functions on arrays. + +.. currentmodule:: probnum.backend + +Functions +--------- + +.. autosummary:: + + ~probnum.backend.max + ~probnum.backend.mean + ~probnum.backend.min + ~probnum.backend.prod + ~probnum.backend.std + ~probnum.backend.sum + ~probnum.backend.var + + +.. toctree:: + :hidden: + + statistical_functions/probnum.backend.max + statistical_functions/probnum.backend.mean + statistical_functions/probnum.backend.min + statistical_functions/probnum.backend.prod + statistical_functions/probnum.backend.std + statistical_functions/probnum.backend.sum + statistical_functions/probnum.backend.var diff --git a/docs/source/api/backend/statistical_functions/probnum.backend.max.rst b/docs/source/api/backend/statistical_functions/probnum.backend.max.rst new file mode 100644 index 000000000..a3e6cf8d6 --- /dev/null +++ b/docs/source/api/backend/statistical_functions/probnum.backend.max.rst @@ -0,0 +1,6 @@ +max +=== + +.. currentmodule:: probnum.backend + +.. 
autofunction:: max diff --git a/docs/source/api/backend/statistical_functions/probnum.backend.mean.rst b/docs/source/api/backend/statistical_functions/probnum.backend.mean.rst new file mode 100644 index 000000000..c4a2d445f --- /dev/null +++ b/docs/source/api/backend/statistical_functions/probnum.backend.mean.rst @@ -0,0 +1,6 @@ +mean +==== + +.. currentmodule:: probnum.backend + +.. autofunction:: mean diff --git a/docs/source/api/backend/statistical_functions/probnum.backend.min.rst b/docs/source/api/backend/statistical_functions/probnum.backend.min.rst new file mode 100644 index 000000000..b955df94f --- /dev/null +++ b/docs/source/api/backend/statistical_functions/probnum.backend.min.rst @@ -0,0 +1,6 @@ +min +=== + +.. currentmodule:: probnum.backend + +.. autofunction:: min diff --git a/docs/source/api/backend/statistical_functions/probnum.backend.prod.rst b/docs/source/api/backend/statistical_functions/probnum.backend.prod.rst new file mode 100644 index 000000000..87de74a83 --- /dev/null +++ b/docs/source/api/backend/statistical_functions/probnum.backend.prod.rst @@ -0,0 +1,6 @@ +prod +==== + +.. currentmodule:: probnum.backend + +.. autofunction:: prod diff --git a/docs/source/api/backend/statistical_functions/probnum.backend.std.rst b/docs/source/api/backend/statistical_functions/probnum.backend.std.rst new file mode 100644 index 000000000..38f405742 --- /dev/null +++ b/docs/source/api/backend/statistical_functions/probnum.backend.std.rst @@ -0,0 +1,6 @@ +std +=== + +.. currentmodule:: probnum.backend + +.. autofunction:: std diff --git a/docs/source/api/backend/statistical_functions/probnum.backend.sum.rst b/docs/source/api/backend/statistical_functions/probnum.backend.sum.rst new file mode 100644 index 000000000..9b7f7fcbd --- /dev/null +++ b/docs/source/api/backend/statistical_functions/probnum.backend.sum.rst @@ -0,0 +1,6 @@ +sum +=== + +.. currentmodule:: probnum.backend + +.. autofunction:: sum diff --git a/docs/source/api/backend/statistical_functions/probnum.backend.var.rst b/docs/source/api/backend/statistical_functions/probnum.backend.var.rst new file mode 100644 index 000000000..f8389b132 --- /dev/null +++ b/docs/source/api/backend/statistical_functions/probnum.backend.var.rst @@ -0,0 +1,6 @@ +var +=== + +.. currentmodule:: probnum.backend + +.. autofunction:: var diff --git a/docs/source/api/backend/typing.rst b/docs/source/api/backend/typing.rst new file mode 100644 index 000000000..65eb6665c --- /dev/null +++ b/docs/source/api/backend/typing.rst @@ -0,0 +1,6 @@ +probnum.backend.typing +---------------------- +.. automodapi:: probnum.backend.typing + :no-heading: + :headings: "*" + :include-all-objects: diff --git a/docs/source/api/backend/vectorization.rst b/docs/source/api/backend/vectorization.rst new file mode 100644 index 000000000..aa7604ae1 --- /dev/null +++ b/docs/source/api/backend/vectorization.rst @@ -0,0 +1,21 @@ +Vectorization +============= + +Vectorization of functions over arrays. + +.. currentmodule:: probnum.backend + +Functions +--------- + +.. autosummary:: + + ~probnum.backend.vectorize + ~probnum.backend.vmap + + +.. toctree:: + :hidden: + + vectorization/probnum.backend.vectorize + vectorization/probnum.backend.vmap diff --git a/docs/source/api/backend/vectorization/probnum.backend.vectorize.rst b/docs/source/api/backend/vectorization/probnum.backend.vectorize.rst new file mode 100644 index 000000000..e05cc6ff5 --- /dev/null +++ b/docs/source/api/backend/vectorization/probnum.backend.vectorize.rst @@ -0,0 +1,6 @@ +vectorize +========= + +.. 
currentmodule:: probnum.backend + +.. autofunction:: vectorize diff --git a/docs/source/api/backend/vectorization/probnum.backend.vmap.rst b/docs/source/api/backend/vectorization/probnum.backend.vmap.rst new file mode 100644 index 000000000..150da5dee --- /dev/null +++ b/docs/source/api/backend/vectorization/probnum.backend.vmap.rst @@ -0,0 +1,6 @@ +vmap +==== + +.. currentmodule:: probnum.backend + +.. autofunction:: vmap diff --git a/docs/source/api/compat.rst b/docs/source/api/compat.rst new file mode 100644 index 000000000..ba57cb3f7 --- /dev/null +++ b/docs/source/api/compat.rst @@ -0,0 +1,12 @@ +************** +probnum.compat +************** + +.. automodapi:: probnum.compat + :no-heading: + :headings: "=" + +.. toctree:: + :hidden: + + compat/testing diff --git a/docs/source/api/compat/testing.rst b/docs/source/api/compat/testing.rst new file mode 100644 index 000000000..fcbff4c5c --- /dev/null +++ b/docs/source/api/compat/testing.rst @@ -0,0 +1,5 @@ +probnum.compat.testing +---------------------- +.. automodapi:: probnum.compat.testing + :no-heading: + :headings: "*" diff --git a/docs/source/api/utils.rst b/docs/source/api/utils.rst deleted file mode 100644 index 50a0e80fa..000000000 --- a/docs/source/api/utils.rst +++ /dev/null @@ -1,13 +0,0 @@ -************* -probnum.utils -************* - -.. automodapi:: probnum.utils - :no-inheritance-diagram: - :no-heading: - :headings: "=" - -.. toctree:: - :hidden: - - utils/linalg diff --git a/docs/source/api/utils/linalg.rst b/docs/source/api/utils/linalg.rst deleted file mode 100644 index 98bf5abf6..000000000 --- a/docs/source/api/utils/linalg.rst +++ /dev/null @@ -1,6 +0,0 @@ -probnum.utils.linalg -==================== - -.. automodapi:: probnum.utils.linalg - :no-heading: - :headings: "-" diff --git a/docs/source/conf.py b/docs/source/conf.py index 6086e44ea..bceff1975 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -14,7 +14,6 @@ # serve to show the default. from datetime import datetime import os -from pathlib import Path import sys from pkg_resources import DistributionNotFound, get_distribution @@ -55,8 +54,13 @@ autodoc_typehints = "description" autodoc_typehints_description_target = "all" autodoc_typehints_format = "short" +# Ensure type aliases are correctly displayed and linked in the documentation autodoc_type_aliases = { **{type_alias: f"typing.{type_alias}" for type_alias in probnum.typing.__all__}, + **{ + type_alias: f"typing.{type_alias}" + for type_alias in probnum.backend.typing.__all__ + }, **{ type_alias: f"typing.{type_alias}" for type_alias in probnum.quad.typing.__all__ }, @@ -65,15 +69,11 @@ # Settings for napoleon napoleon_use_param = True -# Remove possible duplicate methods when using 'automodapi' -# autodoc_default_flags = ['no-members'] -numpydoc_show_class_members = True - - # Settings for automodapi automodapi_toctreedirnm = "api/automod" automodapi_writereprocessed = False automodsumm_inherited_members = True +numpydoc_show_class_members = False # The suffix(es) of source filenames. 
# You can specify multiple suffixes as a list of strings: @@ -158,6 +158,8 @@ "numpy": ("https://numpy.org/doc/stable/", None), "scipy": ("https://docs.scipy.org/doc/scipy", None), "matplotlib": ("https://matplotlib.org/stable/", None), + "torch": ("https://pytorch.org/docs/stable/", None), + "jax": ("https://jax.readthedocs.io/en/latest/", None), } # -- Options for HTML output ---------------------------------------------- @@ -238,11 +240,3 @@ # MyST configuration myst_update_mathjax = False # needed for mathjax compatibility with nbsphinx myst_enable_extensions = ["dollarmath", "amsmath"] - -# Sphinx Bibtex configuration -bibtex_bibfiles = [] -for f in Path("research/bibliography").glob("*.bib"): - bibtex_bibfiles.append(str(f)) -bibtex_default_style = "unsrtalpha" -bibtex_reference_style = "label" -bibtex_encoding = "utf-8-sig" diff --git a/docs/source/development/implementing_a_probnum_method.ipynb b/docs/source/development/implementing_a_probnum_method.ipynb index 8f27e9549..ed404a5ff 100644 --- a/docs/source/development/implementing_a_probnum_method.ipynb +++ b/docs/source/development/implementing_a_probnum_method.ipynb @@ -67,7 +67,7 @@ "\n", "import probnum as pn\n", "from probnum import randvars, linops\n", - "from probnum.typing import FloatLike, IntLike\n", + "from probnum.backend.typing import FloatLike, IntLike\n", "\n", "rng = np.random.default_rng(seed=123)" ] @@ -590,7 +590,7 @@ "ShapeLike = Union[IntLike, Iterable[IntLike]]\n", "\"\"\"Type of a public API argument for supplying a shape. Values of this type should\n", "always be converted into :class:`ShapeType` using the function\n", - ":func:`probnum.utils.as_shape` before further internal processing.\"\"\"\n", + ":func:`probnum.backend.asshape` before further internal processing.\"\"\"\n", "```\n", "\n", "As a small example we write a function which takes a shape and extends that shape with an integer. The type hinted implementation of this function would look like this." 
@@ -602,12 +602,12 @@ "metadata": {}, "outputs": [], "source": [ - "from probnum.typing import ShapeType, IntLike, ShapeLike\n", - "from probnum.utils import as_shape\n", + "from probnum.backend.typing import ShapeType, IntLike, ShapeLike\n", + "from probnum.backend import asshape\n", "\n", "\n", "def extend_shape(shape: ShapeLike, extension: IntLike) -> ShapeType:\n", - " return as_shape(shape) + as_shape(extension)" + " return asshape(shape) + asshape(extension)" ] }, { @@ -740,7 +740,7 @@ " \"\"\"\n", " observation = fun(action)\n", " try:\n", - " return utils.as_numpy_scalar(observation, dtype=np.floating)\n", + " return backend.asscalar(observation, dtype=np.floating)\n", " except TypeError as exc:\n", " raise TypeError(\n", " \"The given argument `p` can not be cast to a `np.floating` object.\"\n", diff --git a/docs/source/development/quadopt_example/_probsolve_qp.py b/docs/source/development/quadopt_example/_probsolve_qp.py index 064f13ed9..9fefd1ba1 100644 --- a/docs/source/development/quadopt_example/_probsolve_qp.py +++ b/docs/source/development/quadopt_example/_probsolve_qp.py @@ -1,12 +1,11 @@ from functools import partial -from typing import Callable, Dict, Iterable, Optional, Tuple, Union +from typing import Callable, Dict, Optional, Tuple, Union import numpy as np import probnum as pn from probnum import linops, randvars -from probnum.typing import FloatLike, IntLike -import probnum.utils as _utils +from probnum.backend.typing import FloatLike, IntLike from .belief_updates import gaussian_belief_update from .observation_operators import function_evaluation diff --git a/docs/source/development/quadopt_example/belief_updates.py b/docs/source/development/quadopt_example/belief_updates.py index 95173477a..096622800 100644 --- a/docs/source/development/quadopt_example/belief_updates.py +++ b/docs/source/development/quadopt_example/belief_updates.py @@ -7,7 +7,7 @@ import probnum as pn from probnum import linops, randvars -from probnum.typing import FloatLike +from probnum.backend.typing import FloatLike def gaussian_belief_update( diff --git a/docs/source/development/quadopt_example/observation_operators.py b/docs/source/development/quadopt_example/observation_operators.py index a08e25cf4..ac6018ec8 100644 --- a/docs/source/development/quadopt_example/observation_operators.py +++ b/docs/source/development/quadopt_example/observation_operators.py @@ -4,8 +4,8 @@ import numpy as np -from probnum import utils -from probnum.typing import FloatLike +from probnum import backend +from probnum.backend.typing import FloatLike def function_evaluation( @@ -22,7 +22,7 @@ def function_evaluation( """ observation = fun(action) try: - return utils.as_numpy_scalar(observation, dtype=np.floating) + return backend.asscalar(observation, dtype=np.floating) except TypeError as exc: raise TypeError( "The given argument `p` can not be cast to a `np.floating` object." 
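A quick illustration of the migration above: `probnum.backend.asshape` replaces `probnum.utils.as_shape` and standardizes any `ShapeLike` value to a tuple of ints, per the implementation added under `src/probnum/backend/_array_object/` further below in this diff. A minimal sketch of the migrated developer-guide example, with illustrative sample values:

```python
from probnum.backend import asshape
from probnum.backend.typing import IntLike, ShapeLike, ShapeType


def extend_shape(shape: ShapeLike, extension: IntLike) -> ShapeType:
    # asshape normalizes an int or an iterable of ints to a tuple of ints,
    # so plain tuple concatenation extends the shape.
    return asshape(shape) + asshape(extension)


extend_shape(3, 2)       # -> (3, 2)
extend_shape((3, 1), 2)  # -> (3, 1, 2)
```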
diff --git a/docs/source/development/quadopt_example/policies.py b/docs/source/development/quadopt_example/policies.py index 45e95adbe..f917d09ba 100644 --- a/docs/source/development/quadopt_example/policies.py +++ b/docs/source/development/quadopt_example/policies.py @@ -5,7 +5,7 @@ import numpy as np from probnum import randvars -from probnum.typing import FloatLike +from probnum.backend.typing import FloatLike def explore_exploit_policy( diff --git a/docs/source/development/quadopt_example/probabilistic_quadratic_optimizer.py b/docs/source/development/quadopt_example/probabilistic_quadratic_optimizer.py index 11f2fdce2..574509636 100644 --- a/docs/source/development/quadopt_example/probabilistic_quadratic_optimizer.py +++ b/docs/source/development/quadopt_example/probabilistic_quadratic_optimizer.py @@ -1,18 +1,10 @@ import collections.abc -from functools import partial from typing import Callable, Dict, Iterable, Optional, Tuple, Union import numpy as np -import probnum as pn -from probnum import linops, randvars -from probnum.typing import FloatLike, IntLike -import probnum.utils as _utils - -from .belief_updates import gaussian_belief_update -from .observation_operators import function_evaluation -from .policies import explore_exploit_policy, stochastic_policy -from .stopping_criteria import maximum_iterations, parameter_uncertainty +from probnum import randvars +from probnum.backend.typing import FloatLike, IntLike # Type aliases for quadratic optimization QuadOptPolicyType = Callable[ diff --git a/docs/source/development/quadopt_example/stopping_criteria.py b/docs/source/development/quadopt_example/stopping_criteria.py index dad3bfc04..3eae5a7a9 100644 --- a/docs/source/development/quadopt_example/stopping_criteria.py +++ b/docs/source/development/quadopt_example/stopping_criteria.py @@ -5,7 +5,7 @@ import numpy as np from probnum import randvars -from probnum.typing import FloatLike, IntLike +from probnum.backend.typing import FloatLike, IntLike def parameter_uncertainty( diff --git a/docs/source/development/styleguide.md b/docs/source/development/styleguide.md index 9ec31cecd..f7b829396 100644 --- a/docs/source/development/styleguide.md +++ b/docs/source/development/styleguide.md @@ -41,7 +41,7 @@ An exception from these rules are type-related modules, which include `typing` a Types are always imported directly. - `from typing import Optional, Callable` -- `from probnum.typing import FloatLike` +- `from probnum.backend.typing import FloatLike` Please do not abbreviate import paths unnecessarily. We do **not** use the following imports: - `import probnum.random_variables as pnrv` or `import probnum.filtsmooth as pnfs` (correct would be `from probnum import randvars, filtsmooth`) @@ -64,8 +64,7 @@ Many types representing numeric values, shapes, dtypes, random states, etc. have possible representations. For example a shape could be specified in the following ways: `n, (n,), (n, 1), [n], [n, 1]`. For this reason most types should be standardized internally to a core set of types defined -in `probnum.typing`, e.g. for numeric types `np.generic`, `np.ndarray`. Methods for input -argument standardization can be found in `probnum.utils.argutils`. +in `probnum.typing`, e.g. for numeric types `np.generic`, `np.ndarray`. ### Naming diff --git a/docs/source/tutorials.rst b/docs/source/tutorials.rst index 16f0712a2..38e6a2db5 100644 --- a/docs/source/tutorials.rst +++ b/docs/source/tutorials.rst @@ -93,6 +93,19 @@ distribution. 
A probabilistic numerical method takes random variables as inputs tutorials/prob/random_variables_quickstart +Generic Computation Backend +--------------------------- + +.. nbgallery:: + :caption: Computation Backend + + tutorials/backend/using_the_backend + + +Automatic Differentiation +------------------------- + + .. |Tutorials| image:: https://img.shields.io/badge/Tutorials-Jupyter-579ACA.svg?style=flat-square&logo=Jupyter&logoColor=white :target: https://mybinder.org/v2/gh/probabilistic-numerics/probnum/main?filepath=docs%2Fsource%2Ftutorials :alt: ProbNum's Tutorials diff --git a/docs/source/tutorials/backend/using_the_backend.ipynb b/docs/source/tutorials/backend/using_the_backend.ipynb new file mode 100644 index 000000000..253a7b787 --- /dev/null +++ b/docs/source/tutorials/backend/using_the_backend.ipynb @@ -0,0 +1,87 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Computation Backend" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "env: PROBNUM_BACKEND=jax\n" + ] + } + ], + "source": [ + "%env PROBNUM_BACKEND=jax" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "os.environ[\"PROBNUM_BACKEND\"] = \"torch\"" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n" + ] + } + ], + "source": [ + "import probnum" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "interpreter": { + "hash": "150452625984079cb361af52b4d37e7980612cc53056bcdcdd507a0bffcc8cf2" + }, + "kernelspec": { + "display_name": "Python 3.8.10 ('probnum')", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.10" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/docs/source/tutorials/linops/linear_operators_quickstart.ipynb b/docs/source/tutorials/linops/linear_operators_quickstart.ipynb index 5f533847f..f8a9effb3 100644 --- a/docs/source/tutorials/linops/linear_operators_quickstart.ipynb +++ b/docs/source/tutorials/linops/linear_operators_quickstart.ipynb @@ -20,7 +20,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -55,7 +55,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 3, "metadata": {}, "outputs": [ { @@ -102,7 +102,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 22, "metadata": {}, "outputs": [ { @@ -111,7 +111,7 @@ "" ] }, - "execution_count": 3, + "execution_count": 22, "metadata": {}, "output_type": "execute_result" } @@ -125,7 +125,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 23, "metadata": {}, "outputs": [ { @@ -134,7 +134,7 @@ "array([4., 0., 1., 2., 3.])" ] }, - "execution_count": 4, + "execution_count": 23, "metadata": {}, "output_type": "execute_result" } @@ -152,7 +152,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 24, 
"metadata": {}, "outputs": [ { @@ -161,7 +161,7 @@ "array([1., 2., 3., 4., 0.])" ] }, - "execution_count": 5, + "execution_count": 24, "metadata": {}, "output_type": "execute_result" } @@ -173,7 +173,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 25, "metadata": {}, "outputs": [ { @@ -182,7 +182,7 @@ "array([5., 2., 4., 6., 3.])" ] }, - "execution_count": 6, + "execution_count": 25, "metadata": {}, "output_type": "execute_result" } @@ -194,7 +194,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 26, "metadata": {}, "outputs": [ { @@ -203,7 +203,7 @@ "array([8., 0., 2., 4., 6.])" ] }, - "execution_count": 7, + "execution_count": 26, "metadata": {}, "output_type": "execute_result" } @@ -215,7 +215,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 27, "metadata": {}, "outputs": [ { @@ -224,7 +224,7 @@ "array([3., 4., 0., 1., 2.])" ] }, - "execution_count": 8, + "execution_count": 27, "metadata": {}, "output_type": "execute_result" } @@ -243,27 +243,32 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 28, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "" + "SumLinearOperator [\n", + "\t, \n", + "\t, \n", + "]" ] }, - "execution_count": 9, + "execution_count": 28, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "P_op + P_op" + "from probnum.linops import Identity\n", + "\n", + "P_op + Identity(shape=P_op.shape, dtype=P_op.dtype)" ] }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 29, "metadata": {}, "outputs": [ { @@ -275,7 +280,7 @@ "]" ] }, - "execution_count": 10, + "execution_count": 29, "metadata": {}, "output_type": "execute_result" } @@ -308,19 +313,19 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 30, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "array([ 1.49421769, -1.35451937, 1.05551543, -0.41823967, 0.42934955,\n", - " -0.82155968, -1.93141743, -4.31860989, -1.70475714, 4.36385187,\n", - " 2.36850628, -2.94034717, 0.39821307, -1.08656905, 0.36490375,\n", - " -0.86441656, -0.44778464, -0.44155178, 0.55687361, 0.17178464])" + "array([ 0.58568065, 0.0498713 , 0.66504443, -0.32614311, -1.68186058,\n", + " -1.37859679, -0.31337502, -1.75752552, 1.98846081, 3.38309163,\n", + " 0.7854872 , 1.71980838, 2.51819122, 0.01695391, 1.2422392 ,\n", + " 2.03598922, -1.10850474, -2.73340378, -1.02823131, -2.61539212])" ] }, - "execution_count": 11, + "execution_count": 30, "metadata": {}, "output_type": "execute_result" } @@ -336,7 +341,6 @@ "A_op = Matrix(A=A_scipy)\n", "\n", "# Some linear operator arithmetic\n", - "from probnum.linops import Identity\n", "x = np.random.randn(n)\n", "Id = Identity(shape=n)\n", "(A_op + 1.5 * Id) @ x" @@ -358,7 +362,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 31, "metadata": {}, "outputs": [ { @@ -367,7 +371,7 @@ "" ] }, - "execution_count": 12, + "execution_count": 31, "metadata": {}, "output_type": "execute_result" } @@ -388,7 +392,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 32, "metadata": {}, "outputs": [ { @@ -397,7 +401,7 @@ "array([4., 0., 1., 2., 3.])" ] }, - "execution_count": 13, + "execution_count": 32, "metadata": {}, "output_type": "execute_result" } @@ -415,7 +419,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 33, "metadata": {}, "outputs": [ { @@ -428,7 +432,7 @@ " [0., 0., 0., 1., 0.]])" ] }, - "execution_count": 14, + "execution_count": 33, "metadata": {}, 
"output_type": "execute_result" } @@ -448,7 +452,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 34, "metadata": {}, "outputs": [], "source": [ @@ -461,7 +465,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 35, "metadata": {}, "outputs": [ { @@ -470,7 +474,7 @@ "array([3., 2., 1.])" ] }, - "execution_count": 16, + "execution_count": 35, "metadata": {}, "output_type": "execute_result" } @@ -481,7 +485,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 36, "metadata": {}, "outputs": [ { @@ -492,7 +496,7 @@ " [0., 0., 3.]])" ] }, - "execution_count": 17, + "execution_count": 36, "metadata": {}, "output_type": "execute_result" } @@ -503,7 +507,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 40, "metadata": {}, "outputs": [], "source": [ @@ -520,7 +524,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -541,7 +545,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -561,7 +565,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -582,7 +586,7 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -605,7 +609,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3.10.6 (conda)", + "display_name": "Python 3.8.10 ('probnum')", "language": "python", "name": "python3" }, @@ -619,11 +623,11 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.6" + "version": "3.8.10" }, "vscode": { "interpreter": { - "hash": "0457b12441837086dec1b475e0008c28e5fc37f4ffe0e5ee9f2b481cc28bc3c9" + "hash": "4101eb159dd6763244816f62d81b6a99465123f056480fac67c56f3a615e0198" } } }, diff --git a/pyproject.toml b/pyproject.toml index 3b05cbd90..630078d4e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -105,7 +105,8 @@ norecursedirs = [ "*.egg*", "dist", "build", - ".tox" + ".tox", + "src/probnum/backend" ] testpaths = [ "src", @@ -178,6 +179,8 @@ disable = [ # Temporary ignore, see https://github.com/probabilistic-numerics/probnum/discussions/470#discussioncomment-1998097 for an explanation "missing-return-doc", "missing-yield-doc", + # Import order is enforced via isort and customized in its configuration + # (see also https://github.com/PyCQA/pylint/issues/3817#issuecomment-687892090) ] [tool.pylint.format] @@ -219,3 +222,12 @@ extend-exclude = ''' profile = "black" combine_as_imports = true force_sort_within_sections = true +known_testing = ["pytest", "pytest_cases", "tests"] +sections = [ + "FUTURE", + "STDLIB", + "THIRDPARTY", + "FIRSTPARTY", + "LOCALFOLDER", + "TESTING", +] diff --git a/setup.py b/setup.py index 0718d30e4..512828908 100644 --- a/setup.py +++ b/setup.py @@ -11,6 +11,9 @@ extras_require["jax"] = [ "jax[cpu]<0.3.26; platform_system!='Windows'", ] +extras_require["torch"] = [ + "torch>=1.13", +] extras_require["zoo"] = [ "tqdm>=4.0", "requests>=2.0", diff --git a/src/probnum/__init__.py b/src/probnum/__init__.py index 64bd49d24..50af340ec 100644 --- a/src/probnum/__init__.py +++ b/src/probnum/__init__.py @@ -8,12 +8,18 @@ # isort: off +# Determine backend to use +from ._select_backend import BACKEND, Backend + # Global Configuration # The global configuration registry. Can be used as a context manager to create local # contexts in which configuration is temporarily overwritten. 
This object contains # unguarded global state and is hence not thread-safe! from ._config import _GLOBAL_CONFIG_SINGLETON as config +# Compute backend functionality +from . import backend + # Abstract interfaces for (components of) probabilistic numerical methods. from ._pnmethod import ( ProbabilisticNumericalMethod, @@ -21,26 +27,27 @@ LambdaStoppingCriterion, ) -# isort: on - +# Supporting packages need to be imported before compat from . import ( - diffeq, - filtsmooth, functions, - linalg, linops, - problems, - quad, randprocs, randvars, - utils, ) + +# Compatibility functionality between backend, linops and randvars +from . import compat + +# isort: on + +from . import diffeq, filtsmooth, linalg, problems, quad from ._version import version as __version__ from .randvars import asrandvar # Public classes and functions. Order is reflected in documentation. __all__ = [ "asrandvar", + "BACKEND", "ProbabilisticNumericalMethod", "StoppingCriterion", "LambdaStoppingCriterion", diff --git a/src/probnum/_config.py b/src/probnum/_config.py index 8cc6f98e9..80ce7edfd 100644 --- a/src/probnum/_config.py +++ b/src/probnum/_config.py @@ -2,6 +2,28 @@ import dataclasses from typing import Any +from . import BACKEND, Backend + +# Select default dtype. +default_floating_dtype = None +default_device = None +if BACKEND is Backend.NUMPY: + from numpy import float64 as default_floating_dtype +elif BACKEND is Backend.JAX: + import jax + from jax.numpy import float64 as default_floating_dtype + + default_device = jax.devices()[0] + jax.config.update("jax_enable_x64", True) +elif BACKEND is Backend.TORCH: + import torch + from torch import float64 as default_floating_dtype + + default_device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +__all__ = ["Configuration", "_GLOBAL_CONFIG_SINGLETON"] + class Configuration: r"""Configuration by which some mechanics of ProbNum can be controlled dynamically. @@ -18,13 +40,13 @@ class Configuration: ======== >>> import probnum - >>> probnum.config.covariance_inversion_damping - 1e-12 + >>> probnum.config.matrix_free + False >>> with probnum.config( - ... covariance_inversion_damping=1e-2, + ... matrix_free=True, ... ): - ... probnum.config.covariance_inversion_damping - 0.01 + ... probnum.config.matrix_free + True """ _NON_REGISTERED_KEY_ERR_MSG = ( @@ -115,15 +137,29 @@ def register(self, key: str, default_value: Any, description: str) -> None: _GLOBAL_CONFIG_SINGLETON = Configuration() # ... define some configuration options, and the respective default values -# (which have to be documented in the Configuration-class docstring!!), ... _DEFAULT_CONFIG_OPTIONS = [ # list of tuples (config_key, default_value) ( - "covariance_inversion_damping", - 1e-12, + "default_floating_dtype", + default_floating_dtype, + ( + r"The default floating point data type to use when creating numeric " + r"objects, such as " + r":class:`~probnum.backend.Array`\ s. One of " + r"``None``, :class:`~probnum.backend.float32`, " + r":class:`~probnum.backend.float64`. If ``None``, the default " + r"``dtype`` of the selected computation backend is used." + ), + ), + ( + "default_device", + default_device, ( - "A (typically small) value that is per default added to the diagonal " - "of covariance matrices in order to make inversion numerically stable." + r"The default device to use for numeric objects, such as " + r":class:`~probnum.backend.Array`\ s. By default uses the (first) GPU," + r" if available; if not, the CPU is used. 
If ``None``, " + r"the placement is controlled by the behavior of the selected " + r"computation backend." ), ), ( diff --git a/src/probnum/_select_backend.py b/src/probnum/_select_backend.py new file mode 100644 index 000000000..a9f5a8197 --- /dev/null +++ b/src/probnum/_select_backend.py @@ -0,0 +1,61 @@ +import enum +import json +import os +import pathlib + + +@enum.unique +class Backend(enum.Enum): + JAX = enum.auto() + TORCH = enum.auto() + NUMPY = enum.auto() + + +BACKEND_FILE = pathlib.Path.home() / ".probnum.json" +BACKEND_FILE_KEY = "backend" + +BACKEND_ENV_VAR = "PROBNUM_BACKEND" + + +def select_backend() -> Backend: + """Select the computation backend.""" + backend_str = None + + if BACKEND_ENV_VAR in os.environ: + backend_str = os.environ[BACKEND_ENV_VAR].upper() + elif BACKEND_FILE.exists() and BACKEND_FILE.is_file(): + with BACKEND_FILE.open("r") as f: + config = json.load(f) + + if BACKEND_FILE_KEY in config: + backend_str = config[BACKEND_FILE_KEY].upper() + + if backend_str is not None: + try: + return Backend[backend_str] + except KeyError as e: + # TODO + raise e from e + + return Backend.NUMPY + + +def _select_via_import() -> Backend: + try: + import jax # pylint: disable=unused-import,import-outside-toplevel + + return Backend.JAX + except ImportError: + pass + + try: + import torch # pylint: disable=unused-import,import-outside-toplevel + + return Backend.TORCH + except ImportError: + pass + + return Backend.NUMPY + + +BACKEND = select_backend() diff --git a/src/probnum/backend/__init__.py b/src/probnum/backend/__init__.py new file mode 100644 index 000000000..bbf630320 --- /dev/null +++ b/src/probnum/backend/__init__.py @@ -0,0 +1,92 @@ +"""Generic computation backend. + +ProbNum's backend implements a unified API for computations with arrays / tensors, that +allows writing generic code and the use of a custom backend library (currently NumPy, +JAX and PyTorch). + +.. note :: + + The interface provided by this module follows the `Python array API standard + `_ closely, which defines a + common API for array and tensor Python libraries. +""" + +from __future__ import annotations + +import builtins +import inspect +import sys + +# isort: off + +from ._dispatcher import Dispatcher + +from ._array_object import * +from ._data_types import * +from ._constants import * +from ._control_flow import * +from ._creation_functions import * +from ._elementwise_functions import * +from ._logic_functions import * +from ._manipulation_functions import * +from ._searching_functions import * +from ._sorting_functions import * +from ._statistical_functions import * +from ._jit_compilation import * +from ._vectorization import * + +from . 
import ( + _array_object, + _data_types, + _constants, + _control_flow, + _creation_functions, + _elementwise_functions, + _logic_functions, + _manipulation_functions, + _searching_functions, + _sorting_functions, + _statistical_functions, + _jit_compilation, + _vectorization, + autodiff, + linalg, + random, + special, +) + +# isort: on + +# Import some often used functions into probnum.backend +from .linalg import diagonal, einsum, matmul, outer, tensordot, vecdot + +# Define probnum.backend API +__all__imported_modules = ( + _array_object.__all__ + + _data_types.__all__ + + _constants.__all__ + + _control_flow.__all__ + + _creation_functions.__all__ + + _elementwise_functions.__all__ + + _logic_functions.__all__ + + _manipulation_functions.__all__ + + _searching_functions.__all__ + + _sorting_functions.__all__ + + _statistical_functions.__all__ + + _jit_compilation.__all__ + + _vectorization.__all__ +) +__all__ = [ + "Dispatcher", +] + __all__imported_modules + +# Set correct module paths. Corrects links and module paths in documentation. +member_dict = dict(inspect.getmembers(sys.modules[__name__])) +for member_name in __all__imported_modules: + if builtins.any([member_name == mn for mn in ["Array", "Scalar", "Device"]]): + continue # Avoids overriding the __module__ of aliases, which can cause bugs. + + try: + member_dict[member_name].__module__ = "probnum.backend" + except (AttributeError, TypeError): + pass diff --git a/src/probnum/backend/_array_object/__init__.py b/src/probnum/backend/_array_object/__init__.py new file mode 100644 index 000000000..799b5b412 --- /dev/null +++ b/src/probnum/backend/_array_object/__init__.py @@ -0,0 +1,101 @@ +"""Array object.""" + +from __future__ import annotations + +from typing import Any, Optional, Tuple, Union + +import numpy as np + +from ..._select_backend import BACKEND, Backend + +if BACKEND is Backend.NUMPY: + from . import _numpy as _impl +elif BACKEND is Backend.JAX: + from . import _jax as _impl +elif BACKEND is Backend.TORCH: + from . import _torch as _impl + + +__all__ = ["asshape", "isarray", "ndim", "to_numpy", "Array", "Device", "Scalar"] + +Scalar = _impl.Scalar +Array = _impl.Array +Device = _impl.Device + + +def asshape( + x: "probnum.backend.typing.ShapeLike", + ndim: Optional["probnum.backend.typing.IntLike"] = None, +) -> "probnum.backend.typing.ShapeType": + """Convert a shape representation into a shape defined as a tuple of ints. + + Parameters + ---------- + x + Shape representation. + ndim + Number of axes / dimensions of the object with shape ``x``. + + Returns + ------- + shape + The input ``x`` converted to a :class:`~probnum.backend.typing.ShapeType`. + + Raises + ------ + TypeError + If the given ``x`` cannot be converted to a shape with ``ndim`` axes. + """ + + try: + # x is an `IntLike` + shape = (int(x),) + except (TypeError, ValueError): + # x is an iterable + try: + shape = tuple(int(item) for item in x) + except (TypeError, ValueError) as err: + raise TypeError( + f"The given shape {x} must be an integer or an iterable of integers." + ) from err + + if ndim is not None: + ndim = int(ndim) + + if len(shape) != ndim: + raise TypeError(f"The given shape {shape} must have {ndim} dimensions.") + + return shape + + +def isarray(x: Any) -> bool: + """Check whether an object is an :class:`~probnum.backend.Array`. + + Parameters + ---------- + x + Object to check. + """ + return isinstance(x, (Array, Scalar)) + + +def ndim(x: Array) -> int: + """Number of dimensions (axes) of an array. 
+
+    Parameters
+    ----------
+    x
+        Array to get dimensions of.
+    """
+    return _impl.ndim(x)
+
+
+def to_numpy(*arrays: Array) -> Union[np.ndarray, Tuple[np.ndarray, ...]]:
+    """Convert an :class:`~probnum.backend.Array` to a NumPy :class:`~numpy.ndarray`.
+
+    Parameters
+    ----------
+    arrays
+        Arrays to convert.
+    """
+    return _impl.to_numpy(*arrays)
diff --git a/src/probnum/backend/_array_object/_jax.py b/src/probnum/backend/_array_object/_jax.py
new file mode 100644
index 000000000..6de75a778
--- /dev/null
+++ b/src/probnum/backend/_array_object/_jax.py
@@ -0,0 +1,19 @@
+"""Array object in JAX."""
+from typing import Tuple, Union
+
+import numpy as np
+
+try:
+    # pylint: disable=redefined-builtin, unused-import
+    import jax
+    from jax import Array, Array as Scalar
+    import jax.numpy as jnp
+    from jax.numpy import ndim
+    from jaxlib.xla_extension import Device
+except ModuleNotFoundError:
+    pass
+
+
+def to_numpy(*arrays: "jax.Array") -> Union[np.ndarray, Tuple[np.ndarray, ...]]:
+    # Transfer to host memory and convert to NumPy arrays, so that the result
+    # fulfills the generic `to_numpy` contract instead of returning JAX arrays.
+    if len(arrays) == 1:
+        return np.asarray(arrays[0])
+
+    return tuple(np.asarray(arr) for arr in arrays)
diff --git a/src/probnum/backend/_array_object/_numpy.py b/src/probnum/backend/_array_object/_numpy.py
new file mode 100644
index 000000000..dfc141b2b
--- /dev/null
+++ b/src/probnum/backend/_array_object/_numpy.py
@@ -0,0 +1,18 @@
+"""Array object in NumPy."""
+from typing import Literal, Tuple, Union
+
+import numpy as np
+from numpy import (  # pylint: disable=redefined-builtin, unused-import
+    generic as Scalar,
+    ndarray as Array,
+    ndim,
+)
+
+Device = Literal["cpu"]
+
+
+def to_numpy(*arrays: np.ndarray) -> Union[np.ndarray, Tuple[np.ndarray, ...]]:
+    if len(arrays) == 1:
+        return arrays[0]
+
+    return tuple(arrays)
diff --git a/src/probnum/backend/_array_object/_torch.py b/src/probnum/backend/_array_object/_torch.py
new file mode 100644
index 000000000..231e02832
--- /dev/null
+++ b/src/probnum/backend/_array_object/_torch.py
@@ -0,0 +1,29 @@
+"""Array object in PyTorch."""
+
+from typing import Tuple, Union
+
+import numpy as np
+
+try:
+    import torch
+    from torch import (  # pylint: disable=redefined-builtin, unused-import, reimported
+        Tensor as Array,
+        Tensor as Scalar,
+        device as Device,
+    )
+except ModuleNotFoundError:
+    pass
+
+
+def ndim(a: "torch.Tensor"):
+    try:
+        return a.ndim
+    except AttributeError:
+        return torch.as_tensor(a).ndim
+
+
+def to_numpy(*arrays: "torch.Tensor") -> Union[np.ndarray, Tuple[np.ndarray, ...]]:
+    if len(arrays) == 1:
+        return arrays[0].cpu().detach().numpy()
+
+    return tuple(arr.cpu().detach().numpy() for arr in arrays)
diff --git a/src/probnum/backend/_constants/__init__.py b/src/probnum/backend/_constants/__init__.py
new file mode 100644
index 000000000..738fdd740
--- /dev/null
+++ b/src/probnum/backend/_constants/__init__.py
@@ -0,0 +1,25 @@
+"""Numerical constants."""
+
+import numpy as np
+
+from .._creation_functions import asarray
+
+__all__ = ["inf", "nan", "e", "pi"]
+
+nan = asarray(np.nan)
+"""IEEE 754 floating-point representation of Not a Number (``NaN``)."""
+
+inf = asarray(np.inf)
+"""IEEE 754 floating-point representation of (positive) infinity."""
+
+e = asarray(np.e)
+"""IEEE 754 floating-point representation of Euler's constant.
+
+``e = 2.71828182845904523536028747135266249775724709369995...``
+"""
+
+pi = asarray(np.pi)
+"""IEEE 754 floating-point representation of the mathematical constant ``π``. 
+ +``pi = 3.1415926535897932384626433...`` +""" diff --git a/src/probnum/backend/_control_flow/__init__.py b/src/probnum/backend/_control_flow/__init__.py new file mode 100644 index 000000000..a99ce337e --- /dev/null +++ b/src/probnum/backend/_control_flow/__init__.py @@ -0,0 +1,17 @@ +from typing import Callable + +from ..._select_backend import BACKEND, Backend +from ..typing import Scalar + +if BACKEND is Backend.NUMPY: + from . import _numpy as _impl +elif BACKEND is Backend.JAX: + from . import _jax as _impl +elif BACKEND is Backend.TORCH: + from . import _torch as _impl + +__all__ = ["cond"] + + +def cond(pred: Scalar, true_fn: Callable, false_fn: Callable, *operands): + return _impl.cond(pred, true_fn, false_fn, *operands) diff --git a/src/probnum/backend/_control_flow/_jax.py b/src/probnum/backend/_control_flow/_jax.py new file mode 100644 index 000000000..2133c29de --- /dev/null +++ b/src/probnum/backend/_control_flow/_jax.py @@ -0,0 +1,4 @@ +try: + from jax.lax import cond +except ModuleNotFoundError: + pass diff --git a/src/probnum/backend/_control_flow/_numpy.py b/src/probnum/backend/_control_flow/_numpy.py new file mode 100644 index 000000000..d7fa84dce --- /dev/null +++ b/src/probnum/backend/_control_flow/_numpy.py @@ -0,0 +1,18 @@ +from typing import Callable, Union + +import numpy as np + + +def cond( + pred: Union[np.ndarray, np.generic], + true_fn: Callable, + false_fn: Callable, + *operands +): + if np.ndim(pred) != 0: + raise ValueError("`pred` must be a scalar") + + if pred: + return true_fn(*operands) + + return false_fn(*operands) diff --git a/src/probnum/backend/_control_flow/_torch.py b/src/probnum/backend/_control_flow/_torch.py new file mode 100644 index 000000000..34697bdd5 --- /dev/null +++ b/src/probnum/backend/_control_flow/_torch.py @@ -0,0 +1,18 @@ +from typing import Callable + +try: + import torch +except ModuleNotFoundError: + pass + + +def cond(pred: " torch.Tensor", true_fn: Callable, false_fn: Callable, *operands): + pred = torch.as_tensor(pred) + + if pred.ndim != 0: + raise ValueError("`pred` must be a scalar") + + if pred: + return true_fn(*operands) + + return false_fn(*operands) diff --git a/src/probnum/backend/_creation_functions/__init__.py b/src/probnum/backend/_creation_functions/__init__.py new file mode 100644 index 000000000..3f025f2b9 --- /dev/null +++ b/src/probnum/backend/_creation_functions/__init__.py @@ -0,0 +1,637 @@ +"""Array creation functions.""" + +from __future__ import annotations + +from typing import List, Optional, Union + +from .. import Array, Device, DType, Scalar, asshape, ndim +from ... import config +from ..._select_backend import BACKEND, Backend +from ..typing import ArrayLike, DTypeLike, ScalarLike, ShapeLike, ShapeType + +if BACKEND is Backend.NUMPY: + from . import _numpy as _impl +elif BACKEND is Backend.JAX: + from . import _jax as _impl +elif BACKEND is Backend.TORCH: + from . import _torch as _impl + +__all__ = [ + "arange", + "asarray", + "asscalar", + "diag", + "empty", + "empty_like", + "eye", + "full", + "full_like", + "linspace", + "meshgrid", + "ones", + "ones_like", + "tril", + "triu", + "zeros", + "zeros_like", +] +__all__.sort() + + +def asarray( + obj: Union[Array, bool, int, float, "NestedSequence", "SupportsBufferProtocol"], + /, + *, + dtype: Optional[DType] = None, + device: Optional[Device] = None, + copy: Optional[bool] = None, +) -> Array: + """Convert the input to an array. + + Parameters + ---------- + obj + Object to be converted to an array. 
May be a Python scalar, a (possibly nested)
+        sequence of Python scalars, or an object supporting the Python buffer protocol.
+    dtype
+        Output array data type.
+    device
+        Device on which to place the created array. If ``device`` is ``None`` and ``x``
+        is an array, the output array device must be inferred from ``x``.
+    copy
+        Boolean indicating whether or not to copy the input. If ``True``, the function
+        must always copy. If ``False``, the function must never copy for input which
+        supports the buffer protocol and must raise a ``ValueError`` in case a copy
+        would be necessary. If ``None``, the function must reuse the existing memory
+        buffer if possible and copy otherwise.
+
+    Returns
+    -------
+    out
+        An array containing the data from ``obj``.
+    """
+    return _impl.asarray(obj, dtype=dtype, device=device, copy=copy)
+
+
+def asscalar(
+    x: ScalarLike,
+    dtype: Optional[DType] = None,
+) -> Scalar:
+    """Convert a scalar value into a :class:`~probnum.backend.Scalar` of the selected
+    computation backend.
+
+    Parameters
+    ----------
+    x
+        Scalar value.
+    dtype
+        Data type of the scalar.
+    """
+    if ndim(x) != 0:
+        raise ValueError("The given input is not a scalar.")
+
+    return asarray(x, dtype=dtype)[()]
+
+
+def diag(x: ArrayLike, /, *, offset: int = 0) -> Array:
+    """Construct a diagonal array.
+
+    Parameters
+    ----------
+    x
+        Diagonal of the to-be-constructed array.
+    offset
+        Offset specifying the off-diagonal relative to the main diagonal.
+
+        - ``offset = 0``: the main diagonal.
+        - ``offset > 0``: off-diagonal above the main diagonal.
+        - ``offset < 0``: off-diagonal below the main diagonal.
+
+    Returns
+    -------
+    out
+        The constructed diagonal array.
+    """
+    return _impl.diag(x, k=offset)
+
+
+def tril(x: Array, /, *, offset: int = 0) -> Array:
+    """Returns the lower triangular part of a matrix (or a stack of matrices) ``x``.
+
+    .. note::
+
+        The lower triangular part of the matrix is defined as the elements on and below
+        the specified (off-)diagonal given by ``offset``.
+
+    Parameters
+    ----------
+    x
+        Input array having shape ``(..., M, N)`` and whose innermost two dimensions
+        form ``MxN`` matrices.
+    offset
+        Offset defining the (off-)diagonal above which to zero elements.
+
+        - ``offset = 0``: the main diagonal.
+        - ``offset > 0``: off-diagonal above the main diagonal.
+        - ``offset < 0``: off-diagonal below the main diagonal.
+
+    Returns
+    -------
+    out
+        An array containing the lower triangular part(s).
+    """
+    return _impl.tril(x, k=offset)
+
+
+def triu(x: Array, /, *, offset: int = 0) -> Array:
+    """Returns the upper triangular part of a matrix (or a stack of matrices) ``x``.
+
+    .. note::
+
+        The upper triangular part of the matrix is defined as the elements on and above
+        the specified (off-)diagonal given by ``offset``.
+
+    Parameters
+    ----------
+    x
+        Input array having shape ``(..., M, N)`` and whose innermost two dimensions
+        form ``MxN`` matrices.
+    offset
+        Offset defining the (off-)diagonal below which to zero elements.
+
+        - ``offset = 0``: the main diagonal.
+        - ``offset > 0``: off-diagonal above the main diagonal.
+        - ``offset < 0``: off-diagonal below the main diagonal.
+
+    Returns
+    -------
+    out
+        An array containing the upper triangular part(s).
+    """
+    return _impl.triu(x, k=offset)
+
+
+def arange(
+    start: Union[int, float],
+    /,
+    stop: Optional[Union[int, float]] = None,
+    step: Union[int, float] = 1,
+    *,
+    dtype: Optional[DType] = None,
+    device: Optional[Device] = None,
+) -> Array:
+    """Returns evenly spaced values within the half-open interval ``[start, stop)`` as
+    a one-dimensional array. 
+ + Parameters + ---------- + start + If ``stop`` is specified, the start of interval (inclusive); otherwise, the end + of the interval (exclusive). If ``stop`` is not specified, the default starting + value is ``0``. + stop + The end of the interval. + step + The distance between two adjacent elements (``out[i+1] - out[i]``). Must not be + ``0``; may be negative, this results in an empty array if ``stop >= start``. + Default: ``1``. + dtype + Output array data type. Should be a floating-point data type. If ``dtype`` is + ``None``, the output array data type must be the default floating-point data + type. + device + Device on which to place the created array. + + .. note:: + + This function cannot guarantee that the interval does not include the ``stop`` + value in those cases where ``step`` is not an integer and floating-point rounding + errors affect the length of the output array. + + Returns + ------- + out + A one-dimensional array containing evenly spaced values. The length of the + output array must be ``ceil((stop-start)/step)`` if ``stop - start`` and + ``step`` have the same sign, and length ``0`` otherwise. + """ + if dtype is None: + dtype = config.default_floating_dtype + if device is None: + device = config.default_device + return _impl.arange(start, stop, step, dtype=dtype, device=device) + + +def empty( + shape: ShapeLike, + *, + dtype: Optional[DType] = None, + device: Optional[Device] = None, +) -> Array: + """Returns an uninitialized array having a specified ``shape``. + + Parameters + ---------- + shape + Output array shape. + dtype + Output array data type. If ``dtype`` is ``None``, the output array data type + must be the default floating-point data type. + device + Device on which to place the created array. + + Returns + ------- + out + An array containing uninitialized data. + """ + if dtype is None: + dtype = config.default_floating_dtype + if device is None: + device = config.default_device + return _impl.empty(asshape(shape), dtype=dtype, device=device) + + +def empty_like( + x: Array, + /, + *, + shape: Optional[ShapeLike] = None, + dtype: Optional[DType] = None, + device: Optional[Device] = None, +) -> Array: + """Returns an uninitialized array with the same ``shape`` as an input array ``x``. + + Parameters + ---------- + x + Input array from which to derive the output array shape. + shape + Overrides the shape of the result. + dtype + Output array data type. If ``dtype`` is ``None``, the output array data type + must be inferred from ``x``. + device + Device on which to place the created array. If ``device`` is ``None``, the + output array device must be inferred from ``x``. + + Returns + ------- + out + an array having the same shape as ``x`` and containing uninitialized data. + """ + if dtype is None: + dtype = x.dtype + if shape is not None: + shape = asshape(shape) + + return _impl.empty_like(x, shape=shape, dtype=dtype, device=device) + + +def eye( + n_rows: int, + n_cols: Optional[int] = None, + /, + *, + k: int = 0, + dtype: Optional[DType] = None, + device: Optional[Device] = None, +) -> Array: + """Returns a two-dimensional array with ones on the ``k``\\ th diagonal and zeros + elsewhere. + + Parameters + ---------- + n_rows + Number of rows in the output array. + n_cols + Number of columns in the output array. If ``None``, the default number of + columns in the output array is equal to ``n_rows``. + k + Index of the diagonal. A positive value refers to an upper diagonal, a negative + value to a lower diagonal, and ``0`` to the main diagonal. 
Default: ``0``. + dtype + Output array data type. If ``dtype`` is ``None``, the output array data type + must be the default floating-point data type. + device + Device on which to place the created array. + + Returns + ------- + out + an array where all elements are equal to zero, except for the ``k``\\th + diagonal, whose values are equal to one. + """ + if dtype is None: + dtype = config.default_floating_dtype + if device is None: + device = config.default_device + return _impl.eye(n_rows, n_cols, k=k, dtype=dtype, device=device) + + +def full( + shape: ShapeType, + fill_value: Union[int, float], + *, + dtype: Optional[DType] = None, + device: Optional[Device] = None, +) -> Array: + """Returns a new array having a specified ``shape`` and filled with ``fill_value``. + + Parameters + ---------- + shape + Output array shape. + fill_value + Fill value. + dtype + Output array data type. If ``dtype`` is ``None``, the output array data type + must be inferred from ``fill_value``. If the fill value is an ``int``, the + output array data type must be the default integer data type. If the fill value + is a ``float``, the output array data type must be the default floating-point + data type. If the fill value is a ``bool``, the output array must have boolean + data type. + + .. note:: + + If the ``fill_value`` exceeds the precision of the resolved default output + array data type, behavior is left unspecified and, thus, + implementation-defined. + + device + Device on which to place the created array. + + Returns + ------- + out + an array where every element is equal to ``fill_value``. + """ + if device is None: + device = config.default_device + return _impl.full(shape, fill_value, dtype=dtype, device=device) + + +def full_like( + x: Array, + /, + fill_value: Union[int, float], + *, + shape: Optional[ShapeLike] = None, + dtype: Optional[DType] = None, + device: Optional[Device] = None, +) -> Array: + """Returns a new array filled with ``fill_value`` and having the same ``shape`` as + an input array ``x``. + + Parameters + ---------- + x + Input array from which to derive the output array shape. + fill_value + fill value. + shape + Overrides the shape of the result. + dtype + Output array data type. If ``dtype`` is ``None``, the output array data type + must be inferred from ``x``. + device + Device on which to place the created array. If ``device`` is ``None``, the + output array device must be inferred from ``x``. + + Returns + ------- + out + an array having the same shape as ``x`` and where every element is equal to + ``fill_value``. + """ + if shape is not None: + shape = asshape(shape) + if dtype is None: + dtype = x.dtype + + return _impl.full_like( + x, fill_value=fill_value, shape=shape, dtype=dtype, device=device + ) + + +def linspace( + start: Union[int, float], + stop: Union[int, float], + /, + num: int, + *, + dtype: Optional[DType] = None, + device: Optional[Device] = None, + endpoint: bool = True, +) -> Array: + """Returns evenly spaced numbers over a specified interval. + + Parameters + ---------- + start + The start of the interval. + stop + The end of the interval. If ``endpoint`` is ``False``, the function must + generate a sequence of ``num+1`` evenly spaced numbers starting with ``start`` + and ending with ``stop`` and exclude the ``stop`` from the returned array such + that the returned array consists of evenly spaced numbers over the half-open + interval ``[start, stop)``. 
If ``endpoint`` is ``True``, the output array must
+        consist of evenly spaced numbers over the closed interval ``[start, stop]``.
+        Default: ``True``.
+
+        .. note::
+
+            The step size changes when ``endpoint`` is ``False``.
+
+    num
+        Number of samples. Must be a non-negative integer value; otherwise, the
+        function must raise an exception.
+    dtype
+        Output array data type. If ``dtype`` is ``None``, the output array data type
+        must be the default floating-point data type.
+    device
+        Device on which to place the created array.
+    endpoint
+        Boolean indicating whether to include ``stop`` in the interval. Default:
+        ``True``.
+
+    Returns
+    -------
+    out
+        a one-dimensional array containing evenly spaced values.
+    """
+    if dtype is None:
+        dtype = config.default_floating_dtype
+    if device is None:
+        device = config.default_device
+    return _impl.linspace(
+        start, stop, num=num, dtype=dtype, device=device, endpoint=endpoint
+    )
+
+
+def meshgrid(*arrays: Array, indexing: str = "xy") -> List[Array]:
+    """Returns coordinate matrices from coordinate vectors.
+
+    Parameters
+    ----------
+    arrays
+        an arbitrary number of one-dimensional arrays representing grid coordinates.
+        Each array should have the same numeric data type.
+    indexing
+        Cartesian ``'xy'`` or matrix ``'ij'`` indexing of output. If provided zero or
+        one one-dimensional vector(s) (i.e., the zero- and one-dimensional cases,
+        respectively), the ``indexing`` keyword has no effect and should be ignored.
+        Default: ``'xy'``.
+
+    Returns
+    -------
+    out
+        list of N arrays, where ``N`` is the number of provided one-dimensional input
+        arrays. Each returned array must have rank ``N``. For ``N`` one-dimensional
+        arrays having lengths ``Ni = len(xi)``,
+
+        - if matrix indexing ``ij``, then each returned array must have the shape
+          ``(N1, N2, N3, ..., Nn)``.
+        - if Cartesian indexing ``xy``, then each returned array must have shape
+          ``(N2, N1, N3, ..., Nn)``.
+
+        Accordingly, for the two-dimensional case with input one-dimensional arrays of
+        length ``M`` and ``N``, if matrix indexing ``ij``, then each returned array
+        must have shape ``(M, N)``, and, if Cartesian indexing ``xy``, then each
+        returned array must have shape ``(N, M)``.
+        Similarly, for the three-dimensional case with input one-dimensional arrays of
+        length ``M``, ``N``, and ``P``, if matrix indexing ``ij``, then each returned
+        array must have shape ``(M, N, P)``, and, if Cartesian indexing ``xy``, then
+        each returned array must have shape ``(N, M, P)``.
+        Each returned array should have the same data type as the input arrays.
+    """
+    return _impl.meshgrid(*arrays, indexing=indexing)
+
+
+def ones(
+    shape: ShapeType,
+    *,
+    dtype: Optional[DType] = None,
+    device: Optional[Device] = None,
+) -> Array:
+    """Returns a new array having a specified ``shape`` and filled with ones.
+
+    Parameters
+    ----------
+    shape
+        Output array shape.
+    dtype
+        Output array data type. If ``dtype`` is ``None``, the output array data type
+        must be the default floating-point data type.
+    device
+        Device on which to place the created array.
+
+    Returns
+    -------
+    out
+        an array containing ones.
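+
+    Examples
+    --------
+    A minimal sketch of the expected behavior; the return type and its repr assume
+    the default NumPy backend (other backends return their native array types):
+
+    >>> from probnum import backend
+    >>> backend.ones((2, 2))
+    array([[1., 1.],
+           [1., 1.]])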
+ """ + if dtype is None: + dtype = config.default_floating_dtype + if device is None: + device = config.default_device + return _impl.ones(shape, dtype=dtype, device=device) + + +def ones_like( + x: Array, + /, + *, + shape: Optional[ShapeLike] = None, + dtype: Optional[DType] = None, + device: Optional[Device] = None, +) -> Array: + """Returns a new array filled with ones and having the same ``shape`` as an input + array ``x``. + + Parameters + ---------- + x + Input array from which to derive the output array shape. + shape + Overrides the shape of the result. + dtype + Output array data type. If ``dtype`` is ``None``, the output array data type + must be inferred from ``x``. + device + Device on which to place the created array. If ``device`` is ``None``, the + output array device must be inferred from ``x``. + + Returns + ------- + out + an array having the same shape as ``x`` and filled with ones. + """ + if shape is not None: + shape = asshape(shape) + if dtype is None: + dtype = x.dtype + + return _impl.ones_like(x, shape=shape, dtype=dtype, device=device) + + +def zeros( + shape: ShapeType, + *, + dtype: Optional[DType] = None, + device: Optional[Device] = None, +) -> Array: + """Returns a new array having a specified ``shape`` and filled with zeros. + + Parameters + ---------- + shape + Output array shape. + dtype + Output array data type. If ``dtype`` is ``None``, the output array data type + must be the default floating-point data type. + device + Device on which to place the created array. + + Returns + ------- + out + an array containing zeros. + """ + if dtype is None: + dtype = config.default_floating_dtype + if device is None: + device = config.default_device + return _impl.zeros(shape, dtype=dtype, device=device) + + +def zeros_like( + x: Array, + /, + *, + shape: Optional[ShapeLike] = None, + dtype: Optional[DType] = None, + device: Optional[Device] = None, +) -> Array: + """Returns a new array filled with zeros and having the same ``shape`` as an input + array ``x``. + + Parameters + ---------- + x + Input array from which to derive the output array shape. + shape + Overrides the shape of the result. + dtype + Output array data type. If ``dtype`` is ``None``, the output array data type + must be inferred from ``x``. + device + Device on which to place the created array. If ``device`` is ``None``, the + output array device must be inferred from ``x``. + + Returns + ------- + out + an array having the same shape as ``x`` and filled with zeros. + """ + if dtype is None: + dtype = x.dtype + if shape is not None: + shape = asshape(shape) + + return _impl.zeros_like(x, shape=shape, dtype=dtype, device=device) diff --git a/src/probnum/backend/_creation_functions/_jax.py b/src/probnum/backend/_creation_functions/_jax.py new file mode 100644 index 000000000..184b68314 --- /dev/null +++ b/src/probnum/backend/_creation_functions/_jax.py @@ -0,0 +1,178 @@ +"""JAX array creation functions.""" +from typing import List, Optional, Union + +try: + import jax + import jax.numpy as jnp + from jax.numpy import diag, tril, triu # pylint: unused-import +except ModuleNotFoundError: + pass + +from .. import Device, DType +from ... 
import config
+from .._data_types import is_floating_dtype
+from ..typing import ShapeType
+
+# pylint: disable=redefined-builtin
+
+
+def asarray(
+    obj: Union[
+        "jax.Array", bool, int, float, "NestedSequence", "SupportsBufferProtocol"
+    ],
+    /,
+    *,
+    dtype: Optional[DType] = None,
+    device: Optional[Device] = None,
+    copy: Optional[bool] = None,
+) -> "jax.Array":
+    if copy is None:
+        copy = True
+
+    # Only infer the target device from the input or the global default if no
+    # explicit device was requested.
+    if device is None:
+        if isinstance(obj, jax.Array):
+            device = obj.device()
+        else:
+            device = config.default_device
+
+    out = jax.device_put(jnp.array(obj, dtype=dtype, copy=copy), device=device)
+
+    # Only fall back to the default floating-point dtype if none was requested.
+    if dtype is None and is_floating_dtype(out.dtype):
+        out = out.astype(config.default_floating_dtype)
+
+    return out
+
+
+def arange(
+    start: Union[int, float],
+    /,
+    stop: Optional[Union[int, float]] = None,
+    step: Union[int, float] = 1,
+    *,
+    dtype: Optional[DType] = None,
+    device: Optional[Device] = None,
+) -> "jax.Array":
+    return jax.device_put(jnp.arange(start, stop, step, dtype=dtype), device=device)
+
+
+def empty(
+    shape: ShapeType,
+    *,
+    dtype: Optional[DType] = None,
+    device: Optional[Device] = None,
+) -> "jax.Array":
+    return jax.device_put(jnp.empty(shape, dtype=dtype), device=device)
+
+
+def empty_like(
+    x: "jax.Array",
+    /,
+    *,
+    shape: Optional[ShapeType] = None,
+    dtype: Optional[DType] = None,
+    device: Optional[Device] = None,
+) -> "jax.Array":
+    if device is None:
+        device = x.device()
+    return jax.device_put(jnp.empty_like(x, shape=shape, dtype=dtype), device=device)
+
+
+def eye(
+    n_rows: int,
+    n_cols: Optional[int] = None,
+    /,
+    *,
+    k: int = 0,
+    dtype: Optional[DType] = None,
+    device: Optional[Device] = None,
+) -> "jax.Array":
+    return jax.device_put(jnp.eye(n_rows, n_cols, k=k, dtype=dtype), device=device)
+
+
+def full(
+    shape: ShapeType,
+    fill_value: Union[int, float],
+    *,
+    dtype: Optional[DType] = None,
+    device: Optional[Device] = None,
+) -> "jax.Array":
+    return jax.device_put(jnp.full(shape, fill_value, dtype=dtype), device=device)
+
+
+def full_like(
+    x: "jax.Array",
+    /,
+    fill_value: Union[int, float],
+    *,
+    shape: Optional[ShapeType] = None,
+    dtype: Optional[DType] = None,
+    device: Optional[Device] = None,
+) -> "jax.Array":
+    if device is None:
+        device = x.device()
+    return jax.device_put(
+        jnp.full_like(x, fill_value=fill_value, shape=shape, dtype=dtype), device=device
+    )
+
+
+def linspace(
+    start: Union[int, float],
+    stop: Union[int, float],
+    /,
+    num: int,
+    *,
+    dtype: Optional[DType] = None,
+    device: Optional[Device] = None,
+    endpoint: bool = True,
+) -> "jax.Array":
+    return jax.device_put(
+        jnp.linspace(start, stop, num=num, dtype=dtype, endpoint=endpoint),
+        device=device,
+    )
+
+
+def meshgrid(*arrays: "jax.Array", indexing: str = "xy") -> List["jax.Array"]:
+    return jnp.meshgrid(*arrays, indexing=indexing)
+
+
+def ones(
+    shape: ShapeType,
+    *,
+    dtype: Optional[DType] = None,
+    device: Optional[Device] = None,
+) -> "jax.Array":
+    return jax.device_put(jnp.ones(shape, dtype=dtype), device=device)
+
+
+def ones_like(
+    x: "jax.Array",
+    /,
+    *,
+    shape: Optional[ShapeType] = None,
+    dtype: Optional[DType] = None,
+    device: Optional[Device] = None,
+) -> "jax.Array":
+    if device is None:
+        device = x.device()
+    return jax.device_put(jnp.ones_like(x, shape=shape, dtype=dtype), device=device)
+
+
+def zeros(
+    shape: ShapeType,
+    *,
+    dtype: Optional[DType] = None,
+    device: Optional[Device] = None,
+) -> "jax.Array":
+    return jax.device_put(jnp.zeros(shape, dtype=dtype), device=device)
+
+
+def zeros_like(
+    x: "jax.Array",
+    /,
+    *,
+    shape: Optional[ShapeType] = None,
+    dtype: Optional[DType] = None,
+    device: Optional[Device] = None,
+) -> "jax.Array":
+    if device is None:
+        device = x.device()
+    return jax.device_put(jnp.zeros_like(x, shape=shape, dtype=dtype), device=device)
diff --git a/src/probnum/backend/_creation_functions/_numpy.py b/src/probnum/backend/_creation_functions/_numpy.py
new file mode 100644
index 000000000..f1aa6cc51
--- /dev/null
+++ b/src/probnum/backend/_creation_functions/_numpy.py
@@ -0,0 +1,154 @@
+"""NumPy array creation functions."""
+from typing import List, Optional, Union
+
+import numpy as np
+from numpy import diag, tril, triu  # pylint: disable=unused-import
+
+from .. import Device, DType
+from ... import config
+from .._data_types import is_floating_dtype
+from ..typing import ShapeType
+
+# pylint: disable=redefined-builtin
+
+
+def asarray(
+    obj: Union[
+        np.ndarray, bool, int, float, "NestedSequence", "SupportsBufferProtocol"
+    ],
+    /,
+    *,
+    dtype: Optional[DType] = None,
+    device: Optional[Device] = None,
+    copy: Optional[bool] = None,
+) -> np.ndarray:
+    if copy is None:
+        copy = False
+    out = np.array(obj, dtype=dtype, copy=copy)
+
+    # Only fall back to the default floating-point dtype if none was requested.
+    if dtype is None and is_floating_dtype(out.dtype):
+        return out.astype(config.default_floating_dtype, copy=False)
+
+    return out
+
+
+def arange(
+    start: Union[int, float],
+    /,
+    stop: Optional[Union[int, float]] = None,
+    step: Union[int, float] = 1,
+    *,
+    dtype: Optional[DType] = None,
+    device: Optional[Device] = None,
+) -> np.ndarray:
+    return np.arange(start, stop, step, dtype=dtype)
+
+
+def empty(
+    shape: ShapeType,
+    *,
+    dtype: Optional[DType] = None,
+    device: Optional[Device] = None,
+) -> np.ndarray:
+    return np.empty(shape, dtype=dtype)
+
+
+def empty_like(
+    x: np.ndarray,
+    /,
+    *,
+    shape: Optional[ShapeType] = None,
+    dtype: Optional[DType] = None,
+    device: Optional[Device] = None,
+) -> np.ndarray:
+    return np.empty_like(x, shape=shape, dtype=dtype)
+
+
+def eye(
+    n_rows: int,
+    n_cols: Optional[int] = None,
+    /,
+    *,
+    k: int = 0,
+    dtype: Optional[DType] = None,
+    device: Optional[Device] = None,
+) -> np.ndarray:
+    return np.eye(n_rows, n_cols, k=k, dtype=dtype)
+
+
+def full(
+    shape: ShapeType,
+    fill_value: Union[int, float],
+    *,
+    dtype: Optional[DType] = None,
+    device: Optional[Device] = None,
+) -> np.ndarray:
+    return np.full(shape, fill_value, dtype=dtype)
+
+
+def full_like(
+    x: np.ndarray,
+    /,
+    fill_value: Union[int, float],
+    *,
+    shape: Optional[ShapeType] = None,
+    dtype: Optional[DType] = None,
+    device: Optional[Device] = None,
+) -> np.ndarray:
+    return np.full_like(x, fill_value=fill_value, shape=shape, dtype=dtype)
+
+
+def linspace(
+    start: Union[int, float],
+    stop: Union[int, float],
+    /,
+    num: int,
+    *,
+    dtype: Optional[DType] = None,
+    device: Optional[Device] = None,
+    endpoint: bool = True,
+) -> np.ndarray:
+    return np.linspace(start, stop, num=num, dtype=dtype, endpoint=endpoint)
+
+
+def meshgrid(*arrays: np.ndarray, indexing: str = "xy") -> List[np.ndarray]:
+    # Delegate to `np.meshgrid`; `np.ones_like` does not construct coordinate grids.
+    return np.meshgrid(*arrays, indexing=indexing)
+
+
+def ones(
+    shape: ShapeType,
+    *,
+    dtype: Optional[DType] = None,
+    device: Optional[Device] = None,
+) -> np.ndarray:
+    return np.ones(shape, dtype=dtype)
+
+
+def ones_like(
+    x: np.ndarray,
+    /,
+    *,
+    shape: Optional[ShapeType] = None,
+    dtype: Optional[DType] = None,
+    device: Optional[Device] = None,
+) -> np.ndarray:
+    return np.ones_like(x, shape=shape, dtype=dtype)
+
+
+def zeros(
+    shape: ShapeType,
+    *,
+    dtype: Optional[DType] = None,
+    device: Optional[Device] = None,
+) -> np.ndarray:
+    return np.zeros(shape, dtype=dtype)
+
+
+def zeros_like(
+    x: np.ndarray,
+    /,
+    *,
+    shape: Optional[ShapeType] = None,
+    dtype: Optional[DType] = None,
+    device: Optional[Device] = None,
+) -> np.ndarray:
+    return np.zeros_like(x, shape=shape, dtype=dtype)
diff --git a/src/probnum/backend/_creation_functions/_torch.py b/src/probnum/backend/_creation_functions/_torch.py
new file mode 100644
index 000000000..5a38d5647
--- /dev/null
+++ b/src/probnum/backend/_creation_functions/_torch.py
@@ -0,0 +1,189 @@
+"""Torch tensor creation functions."""
+from typing import List, Optional, Union
+
+try:
+    import torch
+except ModuleNotFoundError:
+    pass
+from .. import Device, DType
+from ... import config
+from .._data_types import is_floating_dtype
+from ..typing import ShapeType
+
+# pylint: disable=redefined-builtin
+
+
+def asarray(
+    obj: Union[
+        "torch.Tensor", bool, int, float, "NestedSequence", "SupportsBufferProtocol"
+    ],
+    /,
+    *,
+    dtype: Optional["torch.dtype"] = None,
+    device: Optional["torch.device"] = None,
+    copy: Optional[bool] = None,
+) -> "torch.Tensor":
+    out = torch.as_tensor(obj, dtype=dtype, device=device)
+
+    # Only fall back to the default floating-point dtype if none was requested.
+    if dtype is None and is_floating_dtype(out.dtype):
+        out = out.to(dtype=config.default_floating_dtype, copy=False)
+
+    if copy is None:
+        copy = False
+    if copy:
+        return out.clone()
+
+    return out
+
+
+def diag(x: "torch.Tensor", /, *, k: int = 0) -> "torch.Tensor":
+    return torch.diag(x, diagonal=k)
+
+
+def tril(x: "torch.Tensor", /, k: int = 0) -> "torch.Tensor":
+    # Delegate to `torch.tril`; a recursive call to this wrapper would loop forever.
+    return torch.tril(x, diagonal=k)
+
+
+def triu(x: "torch.Tensor", /, k: int = 0) -> "torch.Tensor":
+    return torch.triu(x, diagonal=k)
+
+
+def arange(
+    start: Union[int, float],
+    /,
+    stop: Optional[Union[int, float]] = None,
+    step: Union[int, float] = 1,
+    *,
+    dtype: Optional[DType] = None,
+    device: Optional[Device] = None,
+) -> "torch.Tensor":
+    # `torch.arange` calls the end of the interval `end`, not `stop`.
+    return torch.arange(start=start, end=stop, step=step, dtype=dtype, device=device)
+
+
+def empty(
+    shape: ShapeType,
+    *,
+    dtype: Optional[DType] = None,
+    device: Optional[Device] = None,
+) -> "torch.Tensor":
+    return torch.empty(shape, dtype=dtype, device=device)
+
+
+def empty_like(
+    x: "torch.Tensor",
+    /,
+    *,
+    shape: Optional[ShapeType] = None,
+    dtype: Optional[DType] = None,
+    device: Optional[Device] = None,
+) -> "torch.Tensor":
+    if device is None:
+        device = x.device
+    # `torch.empty_like` cannot override the shape (`layout` is a memory layout,
+    # not a shape), so fall back to `torch.empty` if a shape is given.
+    if shape is not None:
+        return torch.empty(shape, dtype=dtype, device=device)
+    return torch.empty_like(x, dtype=dtype, device=device)
+
+
+def eye(
+    n_rows: int,
+    n_cols: Optional[int] = None,
+    /,
+    *,
+    k: int = 0,
+    dtype: Optional[DType] = None,
+    device: Optional[Device] = None,
+) -> "torch.Tensor":
+    if k != 0:
+        raise NotImplementedError
+    if n_cols is None:
+        return torch.eye(n_rows, dtype=dtype, device=device)
+    return torch.eye(n_rows, n_cols, dtype=dtype, device=device)
+
+
+def full(
+    shape: ShapeType,
+    fill_value: Union[int, float],
+    *,
+    dtype: Optional[DType] = None,
+    device: Optional[Device] = None,
+) -> "torch.Tensor":
+    return torch.full(shape, fill_value, dtype=dtype, device=device)
+
+
+def full_like(
+    x: "torch.Tensor",
+    /,
+    fill_value: Union[int, float],
+    *,
+    shape: Optional[ShapeType] = None,
+    dtype: Optional[DType] = None,
+    device: Optional[Device] = None,
+) -> "torch.Tensor":
+    if device is None:
+        device = x.device
+    # See `empty_like`: `torch.full_like` cannot override the shape.
+    if shape is not None:
+        return torch.full(shape, fill_value, dtype=dtype, device=device)
+    return torch.full_like(x, fill_value=fill_value, dtype=dtype, device=device)
+
+
+def linspace(
+    start: Union[int, float],
+    stop: Union[int, float],
+    /,
+    num: int,
+    *,
+    dtype: Optional[DType] = None,
+    device: Optional[Device] = None,
+    endpoint: bool = True,
+) -> "torch.Tensor":
+    if not endpoint:
+        raise NotImplementedError
+
+    return torch.linspace(start=start, end=stop, steps=num, dtype=dtype, device=device)
+
+
+def meshgrid(*arrays: "torch.Tensor", indexing: str = "xy") -> List["torch.Tensor"]:
+    return torch.meshgrid(*arrays, indexing=indexing)
+
+
+def ones(
+    shape: ShapeType,
+    *,
+    dtype: Optional[DType] = None,
+    device: Optional[Device] = None,
+) -> "torch.Tensor":
+    return torch.ones(shape, dtype=dtype, device=device)
+
+
+def ones_like(
+    x: "torch.Tensor",
+    /,
+    *,
+    shape: Optional[ShapeType] = None,
+    dtype: Optional[DType] = None,
+    device: Optional[Device] = None,
+) -> "torch.Tensor":
+    if device is None:
+        device = x.device
+    # See `empty_like`: `torch.ones_like` cannot override the shape.
+    if shape is not None:
+        return torch.ones(shape, dtype=dtype, device=device)
+    return torch.ones_like(x, dtype=dtype, device=device)
+
+
+def zeros(
+    shape: ShapeType,
+    *,
+    dtype: Optional[DType] = None,
+    device: Optional[Device] = None,
+) -> "torch.Tensor":
+    return torch.zeros(shape, dtype=dtype, device=device)
+
+
+def zeros_like(
+    x: "torch.Tensor",
+    /,
+    *,
+    shape: Optional[ShapeType] = None,
+    dtype: Optional[DType] = None,
+    device: Optional[Device] = None,
+) -> "torch.Tensor":
+    if device is None:
+        device = x.device
+    # See `empty_like`: `torch.zeros_like` cannot override the shape.
+    if shape is not None:
+        return torch.zeros(shape, dtype=dtype, device=device)
+    return torch.zeros_like(x, dtype=dtype, device=device)
diff --git a/src/probnum/backend/_data_types/__init__.py b/src/probnum/backend/_data_types/__init__.py
new file mode 100644
index 000000000..513e5731d
--- /dev/null
+++ b/src/probnum/backend/_data_types/__init__.py
@@ -0,0 +1,239 @@
+"""Data types."""
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import Union
+
+from .. import Array
+from ..._select_backend import BACKEND, Backend
+from ..typing import DTypeLike
+
+if BACKEND is Backend.NUMPY:
+    from . import _numpy as _impl
+elif BACKEND is Backend.JAX:
+    from . import _jax as _impl
+elif BACKEND is Backend.TORCH:
+    from . import _torch as _impl
+
+__all__ = [
+    "DType",
+    "bool",
+    "int32",
+    "int64",
+    "float16",
+    "float32",
+    "float64",
+    "complex64",
+    "complex128",
+    "MachineLimitsFloatingPoint",
+    "MachineLimitsInteger",
+    "asdtype",
+    "can_cast",
+    "cast",
+    "finfo",
+    "iinfo",
+    "is_floating_dtype",
+    "promote_types",
+    "result_type",
+]
+
+DType = _impl.DType
+bool = _impl.bool
+int32 = _impl.int32
+int64 = _impl.int64
+float16 = _impl.float16
+float32 = _impl.float32
+float64 = _impl.float64
+complex64 = _impl.complex64
+complex128 = _impl.complex128
+
+
+@dataclass
+class MachineLimitsFloatingPoint:
+    """Machine limits for a floating point type.
+
+    Parameters
+    ----------
+    bits
+        The number of bits occupied by the type.
+    max
+        The largest representable number.
+    min
+        The smallest representable number, typically ``-max``.
+    eps
+        The difference between 1.0 and the next smallest representable float larger
+        than 1.0. For example, for 64-bit binary floats in the IEEE-754 standard,
+        ``eps = 2**-52``, approximately 2.22e-16.
+    """
+
+    bits: int
+    eps: float
+    max: float
+    min: float
+
+
+@dataclass
+class MachineLimitsInteger:
+    """Machine limits for an integer type.
+
+    Parameters
+    ----------
+    bits
+        The number of bits occupied by the type.
+    max
+        The largest representable number.
+    min
+        The smallest representable number, typically ``-max - 1`` for two's-complement
+        integer types.
+    """
+
+    bits: int
+    max: int
+    min: int
+
+
+def asdtype(x: DTypeLike, /) -> DType:
+    """Convert the input to a :class:`~probnum.backend.DType`.
+
+    Parameters
+    ----------
+    x
+        Object which can be converted to a :class:`~probnum.backend.DType`.
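+
+    Examples
+    --------
+    A small usage sketch; the returned object and its repr assume the default NumPy
+    backend (JAX and PyTorch return their native dtype objects):
+
+    >>> from probnum import backend
+    >>> backend.asdtype("float64")
+    dtype('float64')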
+ """ + return _impl.asdtype(x) + + +def cast( + x: Array, dtype: DType, /, *, casting: str = "unsafe", copy: bool = True +) -> Array: + """Copies an array to a specified data type irrespective of type-promotion rules. + + Parameters + ---------- + x + Array to cast. + dtype + Desired data type. + casting + Controls what kind of data casting may occur. + copy + Specifies whether to copy an array when the specified ``dtype`` matches the data + type of the input array ``x``. If ``True``, a newly allocated array will always + be returned. If ``False`` and the specified ``dtype`` matches the data type of + the input array, the input array will be returned; otherwise, a newly allocated + will be returned. + + Returns + ------- + out + An array having the specified data type and the same shape as ``x``. + """ + return _impl.cast(x, dtype, casting=casting, copy=copy) + + +def can_cast(from_: Union[DType, Array], to: DType, /) -> bool: + """Determines if one data type can be cast to another data type according the type + promotion rules. + + Parameters + ---------- + from_ + Input data type or array from which to cast. + to + Desired data type. + + Returns + ------- + out + ``True`` if the cast can occur according to the type promotion rules; otherwise, + ``False``. + """ + return _impl.can_cast(from_, to) + + +def finfo(type: Union[DType, Array], /) -> MachineLimitsFloatingPoint: + """Machine limits for floating-point data types. + + Parameters + ---------- + type + The kind of floating-point data-type about which to get information. If complex, + the information is about its component data type. + + Returns + ------- + out + :class:`~probnum.backend.MachineLimitsFloatingPoint` object containing + information on machine limits for floating-point data types. + """ + return MachineLimitsFloatingPoint(**_impl.finfo(type)) + + +def iinfo(type: Union[DType, Array], /) -> MachineLimitsInteger: + """Machine limits for integer data types. + + Parameters + ---------- + type + The kind of integer data-type about which to get information. + + Returns + ------- + out + :class:`~probnum.backend.MachineLimitsInteger` object containing information on + machine limits for integer data types. + """ + return MachineLimitsInteger(**_impl.iinfo(type)) + + +def is_floating_dtype(dtype: DType, /) -> bool: + """Check whether ``dtype`` is a floating point data type. + + Parameters + ---------- + dtype + DType object to check. + """ + return _impl.is_floating_dtype(dtype) + + +def promote_types(type1: DType, type2: DType, /) -> DType: + """Returns the data type with the smallest size and smallest scalar kind to which + both ``type1`` and ``type2`` may be safely cast. + + This function is symmetric, but rarely associative. + + Parameters + ---------- + dtype1 + First data type. + dtype2 + Second data type. + + Returns + ------- + out + The promoted data type. + """ + return _impl.promote_types(type1, type2) + + +def result_type(*arrays_and_dtypes: Union[Array, DType]) -> DType: + """Returns the dtype that results from applying the type promotion rules to the + arguments. + + .. note:: + If provided mixed dtypes (e.g., integer and floating-point), the returned dtype + will be implementation-specific. + + Parameters + ---------- + arrays_and_dtypes + An arbitrary number of input arrays and/or dtypes. + + Returns + ------- + out + The dtype resulting from an operation involving the input arrays and dtypes. 
+ """ + return _impl.result_type(*arrays_and_dtypes) diff --git a/src/probnum/backend/_data_types/_jax.py b/src/probnum/backend/_data_types/_jax.py new file mode 100644 index 000000000..8aff31d67 --- /dev/null +++ b/src/probnum/backend/_data_types/_jax.py @@ -0,0 +1,67 @@ +"""Data types in JAX.""" + +from typing import Dict, Union + +try: + import jax + import jax.numpy as jnp + from jax.numpy import ( # pylint: disable=redefined-builtin, unused-import + bool_ as bool, + complex64, + complex128, + dtype as DType, + float16, + float32, + float64, + int32, + int64, + ) +except ModuleNotFoundError: + pass + +from ..typing import DTypeLike + + +def asdtype(x: DTypeLike, /) -> DType: + return jnp.dtype(x) + + +def cast( + x: "jax.Array", dtype: DType, /, *, casting: str = "unsafe", copy: bool = True +) -> "jax.Array": + return x.astype(dtype=dtype) + + +def can_cast(from_: Union[DType, "jax.Array"], to: DType, /) -> bool: + return jnp.can_cast(from_, to) + + +def finfo(type: Union[DType, "jax.Array"], /) -> Dict: + floating_info = jnp.finfo(type) + return { + "bits": floating_info.bits, + "eps": floating_info.eps, + "max": floating_info.max, + "min": floating_info.min, + } + + +def iinfo(type: Union[DType, "jax.Array"], /) -> Dict: + integer_info = jnp.iinfo(type) + return { + "bits": integer_info.bits, + "max": integer_info.max, + "min": integer_info.min, + } + + +def is_floating_dtype(dtype: DType, /) -> bool: + return jnp.issubdtype(dtype, jnp.floating) + + +def promote_types(type1: DType, type2: DType, /) -> DType: + return jnp.promote_types(type1, type2) + + +def result_type(*arrays_and_dtypes: Union["jax.Array", DType]) -> DType: + return jnp.result_type(*arrays_and_dtypes) diff --git a/src/probnum/backend/_data_types/_numpy.py b/src/probnum/backend/_data_types/_numpy.py new file mode 100644 index 000000000..63bf35cdd --- /dev/null +++ b/src/probnum/backend/_data_types/_numpy.py @@ -0,0 +1,63 @@ +"""Data types in NumPy.""" + +from typing import Dict, Union + +import numpy as np +from numpy import ( # pylint: disable=redefined-builtin, unused-import + bool_ as bool, + complex64, + complex128, + dtype as DType, + float16, + float32, + float64, + int32, + int64, +) + +from ..typing import DTypeLike + + +def asdtype(x: DTypeLike, /) -> DType: + return np.dtype(x) + + +def cast( + x: np.ndarray, dtype: DType, /, *, casting: str = "unsafe", copy: bool = True +) -> np.ndarray: + return x.astype(dtype=dtype, casting=casting, copy=copy) + + +def can_cast(from_: Union[DType, np.ndarray], to: DType, /) -> bool: + return np.can_cast(from_, to) + + +def finfo(type: Union[DType, np.ndarray], /) -> Dict: + floating_info = np.finfo(type) + return { + "bits": floating_info.bits, + "eps": floating_info.eps, + "max": floating_info.max, + "min": floating_info.min, + } + + +def iinfo(type: Union[DType, np.ndarray], /) -> Dict: + integer_info = np.iinfo(type) + return { + "bits": integer_info.bits, + "max": integer_info.max, + "min": integer_info.min, + } + + +def is_floating_dtype(dtype: DType, /) -> bool: + return np.issubdtype(dtype, np.floating) + + +def promote_types(type1: DType, type2: DType, /) -> DType: + return np.promote_types(type1, type2) + + +def result_type(*arrays_and_dtypes: Union[np.ndarray, DType]) -> DType: + return np.result_type(*arrays_and_dtypes) diff --git a/src/probnum/backend/_data_types/_torch.py b/src/probnum/backend/_data_types/_torch.py new file mode 100644 index 000000000..998196915 --- /dev/null +++ b/src/probnum/backend/_data_types/_torch.py @@ -0,0 +1,76 @@ +"""Data types in 
PyTorch.""" +from typing import Dict, Union + +import numpy as np + +try: + import torch + from torch import ( # pylint: disable=redefined-builtin, unused-import + bool, + complex64, + complex128, + dtype as DType, + float16, + float32, + float64, + int32, + int64, + ) +except ModuleNotFoundError: + pass + +# from . import MachineLimitsFloatingPoint, MachineLimitsInteger +from ..typing import DTypeLike + + +def asdtype(x: DTypeLike, /) -> "DType": + if isinstance(x, torch.dtype): + return x + + return torch.as_tensor( + np.empty( + (), + dtype=np.dtype(x), + ), + ).dtype + + +def cast( + x: "torch.Tensor", dtype: "DType", /, *, casting: str = "unsafe", copy: bool = True +) -> "torch.Tensor": + return x.to(dtype=dtype, copy=copy) + + +def can_cast(from_: Union["DType", "torch.Tensor"], to: "DType", /) -> bool: + return torch.can_cast(from_, to) + + +def finfo(type: Union["DType", "torch.Tensor"], /) -> Dict: + floating_info = torch.finfo(type) + return { + "bits": floating_info.bits, + "eps": floating_info.eps, + "max": floating_info.max, + "min": floating_info.min, + } + + +def iinfo(type: Union["DType", "torch.Tensor"], /) -> Dict: + integer_info = torch.iinfo(type) + return { + "bits": integer_info.bits, + "max": integer_info.max, + "min": integer_info.min, + } + + +def is_floating_dtype(dtype: "DType", /) -> bool: + return torch.is_floating_point(torch.empty((), dtype=dtype)) + + +def promote_types(type1: "DType", type2: "DType", /) -> "DType": + return torch.promote_types(type1, type2) + + +def result_type(*arrays_and_dtypes: Union["torch.Tensor", "DType"]) -> "DType": + return torch.result_type(*arrays_and_dtypes) diff --git a/src/probnum/backend/_dispatcher.py b/src/probnum/backend/_dispatcher.py new file mode 100644 index 000000000..f10a5493c --- /dev/null +++ b/src/probnum/backend/_dispatcher.py @@ -0,0 +1,104 @@ +from types import MethodType +from typing import Callable, Optional + +from .._select_backend import BACKEND, Backend + + +class Dispatcher: + """Dispatcher for backend-specific implementations of a function. + + Defines a decorator which can be used to define a function in multiple ways + depending on the backend. This is useful, if besides the generic backend + implementation, a more efficient implementation can be defined using + functionality from a computation backend directly. + + Parameters + ---------- + generic_impl + Generic implementation. + numpy_impl + NumPy implementation. + jax_impl + JAX implementation. + torch_impl + PyTorch implementation. + + Example + ------- + >>> @backend.Dispatcher + ... def f(x): + ... raise NotImplementedError() + ... + ... @f.jax_impl + ... def _(x: jnp.ndarray) -> jnp.ndarray: + ... 
pass + """ + + def __init__( + self, + generic_impl: Optional[Callable] = None, + /, + *, + numpy_impl: Optional[Callable] = None, + jax_impl: Optional[Callable] = None, + torch_impl: Optional[Callable] = None, + ): + if generic_impl is None: + generic_impl = Dispatcher._raise_not_implemented_error + + self._impl = { + Backend.NUMPY: generic_impl if numpy_impl is None else numpy_impl, + Backend.JAX: generic_impl if jax_impl is None else jax_impl, + Backend.TORCH: generic_impl if torch_impl is None else torch_impl, + } + + def numpy_impl(self, impl: Callable) -> Callable: + self._impl[Backend.NUMPY] = impl + + return impl + + def jax_impl(self, impl: Callable) -> Callable: + self._impl[Backend.JAX] = impl + + return impl + + def torch_impl(self, impl: Callable) -> Callable: + self._impl[Backend.TORCH] = impl + + return impl + + def __call__(self, *args, **kwargs): + return self._impl[BACKEND](*args, **kwargs) + + @staticmethod + def _raise_not_implemented_error() -> None: + raise NotImplementedError( + f"This function is not implemented for the backend `{BACKEND.name}`" + ) + + def __get__(self, obj, objtype=None): + """This is necessary in order to use the :class:`Dispatcher` as a class + attribute which is then translated into a method of class instances, i.e. to + allow for. + + .. code:: + + class Foo: + @Dispatcher + def baz(self, x): + raise NotImplementedError() + + @baz.jax + def _(self, x): + return x + + bar = Foo() + bar.baz("Test") # Output: "Test" + + See https://docs.python.org/3/howto/descriptor.html?highlight=methodtype#functions-and-methods + for details. + """ + if obj is None: + return self + + return MethodType(self, obj) diff --git a/src/probnum/backend/_elementwise_functions/__init__.py b/src/probnum/backend/_elementwise_functions/__init__.py new file mode 100644 index 000000000..ef3c01ccc --- /dev/null +++ b/src/probnum/backend/_elementwise_functions/__init__.py @@ -0,0 +1,1111 @@ +"""Elementwise functions.""" + +from .. import Array +from ..._select_backend import BACKEND, Backend + +if BACKEND is Backend.NUMPY: + from . import _numpy as _impl +elif BACKEND is Backend.JAX: + from . import _jax as _impl +elif BACKEND is Backend.TORCH: + from . import _torch as _impl + +__all__ = [ + "abs", + "acos", + "acosh", + "add", + "asin", + "asinh", + "atan", + "atan2", + "atanh", + "bitwise_and", + "bitwise_left_shift", + "bitwise_invert", + "bitwise_or", + "bitwise_right_shift", + "bitwise_xor", + "ceil", + "conj", + "cos", + "cosh", + "divide", + "exp", + "expm1", + "floor", + "floor_divide", + "imag", + "isfinite", + "isinf", + "isnan", + "log", + "log1p", + "log2", + "log10", + "logaddexp", + "maximum", + "minimum", + "multiply", + "negative", + "positive", + "pow", + "real", + "remainder", + "round", + "sign", + "sin", + "sinh", + "square", + "sqrt", + "subtract", + "tan", + "tanh", + "trunc", +] +__all__.sort() + + +def abs(x: Array, /) -> Array: + """Calculates the absolute value for each element ``x_i`` of the input array ``x``. + + Parameters + ---------- + x + input array. Should have a real-valued data type. + + Returns + ------- + out + an array containing the absolute value of each element in ``x``. + """ + return _impl.abs(x) + + +def acos(x: Array, /) -> Array: + """Calculates an approximation of the principal value of the inverse cosine, having + domain ``[-1, +1]`` and codomain ``[+0, +π]``, for each element ``x_i`` of the input + array ``x``. + + Parameters + ---------- + x + input array. Should have a real-valued floating-point data type. 
+ + Returns + ------- + out + an array containing the inverse cosine of each element in ``x``. + """ + return _impl.acos(x) + + +def acosh(x: Array, /) -> Array: + """Calculates an approximation to the inverse hyperbolic cosine, having domain + ``[+1, infinity]`` and codomain ``[+0, infinity]``, for each element ``x_i`` of the + input array ``x``. + + Parameters + ---------- + x + input array whose elements each represent the area of a hyperbolic sector. + Should have a real-valued floating-point data type. + + Returns + ------- + out + an array containing the inverse hyperbolic cosine of each element in ``x``. + """ + return _impl.acosh(x) + + +def add(x1: Array, x2: Array, /) -> Array: + """Calculates the sum for each element ``x1_i`` of the input array ``x1`` with the + respective element ``x2_i`` of the input array ``x2``. + + .. note:: + + Floating-point addition is a commutative operation, but not always associative. + + + Parameters + ---------- + x1 + first input array. + x2 + second input array. Must be compatible with ``x1`` (according to the rules of + broadcasting). + + Returns + ------- + out + an array containing the element-wise sums. + """ + return _impl.add(x1, x2) + + +def asin(x: Array, /) -> Array: + """Calculates an approximation of the principal value of the inverse sine, having + domain ``[-1, +1]`` and codomain ``[-π/2, +π/2]`` for each element ``x_i`` of the + input array ``x``. + + Parameters + ---------- + x + input array. Should have a real-valued floating-point data type. + + Returns + ------- + out + an array containing the inverse sine of each element in ``x``. + """ + return _impl.asin(x) + + +def asinh(x: Array, /) -> Array: + """Calculates an approximation to the inverse hyperbolic sine, having domain + ``[-infinity, infinity]`` and codomain ``[-infinity, infinity]``, for each element + ``x_i`` in the input array ``x``. + + Parameters + ---------- + x + input array whose elements each represent the area of a hyperbolic sector. + + Returns + ------- + out + an array containing the inverse hyperbolic sine of each element in ``x``. + """ + return _impl.asinh(x) + + +def atan(x: Array, /) -> Array: + """Calculates an approximation of the principal value of the inverse tangent, having + domain ``[-infinity, infinity]`` and codomain ``[-π/2, +π/2]``, for each element + ``x_i`` of the input array ``x``. + + Parameters + ---------- + x + input array. Should have a real-valued floating-point data type. + + Returns + ------- + out + an array containing the inverse tangent of each element in ``x``. + """ + return _impl.atan(x) + + +def atan2(x1: Array, x2: Array, /) -> Array: + """Calculates an approximation of the inverse tangent of the quotient ``x1/x2``, + having domain ``[-infinity, infinity] x [-infinity, infinity]`` and codomain ``[-π, + π]``, for each pair of elements ``(x1_i, x2_i)`` of the input arrays ``x1`` and + ``x2``, respectively. + + The mathematical signs of ``x1_i`` and ``x2_i`` determine the quadrant of each + element-wise result. The quadrant (i.e., branch) is chosen such that each + element-wise result is the signed angle in radians between the ray ending at the + origin and passing through the point ``(1,0)`` and the ray ending at the origin and + passing through the point ``(x2_i, x1_i)``. + + + Parameters + ---------- + x1 + input array corresponding to the y-coordinates. + x2 + input array corresponding to the x-coordinates. + + Returns + ------- + out + an array containing the inverse tangent of the quotient ``x1/x2``. 
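+
+    Examples
+    --------
+    A small sketch of the quadrant handling; the exact scalar types assume the
+    default NumPy backend:
+
+    >>> from probnum import backend
+    >>> float(backend.atan2(backend.asarray(1.0), backend.asarray(-1.0)))  # 3π/4
+    2.356194490192345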
+ """ + return _impl.atan2(x1, x2) + + +def atanh(x: Array, /) -> Array: + """Calculates an approximation to the inverse hyperbolic tangent, having domain + ``[-1, +1]`` and codomain ``[-infinity, infinity]``, for each element ``x_i`` of the + input array ``x``. + + Parameters + ---------- + x + input array whose elements each represent the area of a hyperbolic sector. + Returns + ------- + out + an array containing the inverse hyperbolic tangent of each element in ``x``. + """ + return _impl.atanh(x) + + +def bitwise_and(x1: Array, x2: Array, /) -> Array: + """Computes the bitwise AND of the underlying binary representation of each element + ``x1_i`` of the input array ``x1`` with the respective element ``x2_i`` of the input + array ``x2``. + + Parameters + ---------- + x1 + first input array. Should have an integer or boolean data type. + x2 + second input array. Must be compatible with ``x1``. Should have an integer or + boolean data type. + + Returns + ------- + out + an array containing the element-wise results. + """ + return _impl.bitwise_and(x1, x2) + + +def bitwise_left_shift(x1: Array, x2: Array, /) -> Array: + """Shifts the bits of each element ``x1_i`` of the input array ``x1`` to the left by + appending ``x2_i`` (i.e., the respective element in the input array ``x2``) zeros to + the right of ``x1_i``. + + Parameters + ---------- + x1 + first input array. Should have an integer data type. + x2 + second input array. Must be compatible with ``x1``. Should have an integer data + type. Each element must be greater than or equal to ``0``. + + Returns + ------- + out + an array containing the element-wise results. + """ + return _impl.bitwise_left_shift(x1, x2) + + +def bitwise_invert(x: Array, /) -> Array: + """Inverts (flips) each bit for each element ``x_i`` of the input array ``x``. + + Parameters + ---------- + x + input array. Should have an integer or boolean data type. + + Returns + ------- + out + an array containing the element-wise results. + """ + return _impl.bitwise_invert(x) + + +def bitwise_or(x1: Array, x2: Array, /) -> Array: + """Computes the bitwise OR of the underlying binary representation of each element + ``x1_i`` of the input array ``x1`` with the respective element ``x2_i`` of the input + array ``x2``. + + Parameters + ---------- + x1 + first input array. Should have an integer or boolean data type. + x2 + second input array. Must be compatible with ``x1``. Should have an integer or + boolean data type. + + Returns + ------- + out + an array containing the element-wise results. + """ + return _impl.bitwise_or(x1, x2) + + +def bitwise_right_shift(x1: Array, x2: Array, /) -> Array: + """Shifts the bits of each element ``x1_i`` of the input array ``x1`` to the right + according to the respective element ``x2_i`` of the input array ``x2``. + + .. note:: + This operation must be an arithmetic shift (i.e., sign-propagating) and thus + equivalent to floor division by a power of two. + + Parameters + ---------- + x1 + first input array. Should have an integer data type. + x2 + second input array. Must be compatible with ``x1``. Should have an integer data + type. Each element must be greater than or equal to ``0``. + + Returns + ------- + out + an array containing the element-wise results. 
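# The shift operations above are equivalent to multiplication and floor division
# by powers of two. Illustrative check, assuming the NumPy backend is active;
# `backend.asarray` is an assumed array-creation helper not shown in this diff.
from probnum import backend

x = backend.asarray([1, 2, -8])
n = backend.asarray([3, 3, 1])
backend.bitwise_left_shift(x, n)   # [ 8, 16, -16], i.e. x * 2**n
backend.bitwise_right_shift(x, n)  # [ 0,  0,  -4], arithmetic (sign-propagating) shift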
+ """ + return _impl.bitwise_right_shift(x1, x2) + + +def bitwise_xor(x1: Array, x2: Array, /) -> Array: + """Computes the bitwise XOR of the underlying binary representation of each element + ``x1_i`` of the input array ``x1`` with the respective element ``x2_i`` of the input + array ``x2``. + + Parameters + ---------- + x1 + first input array. Should have an integer or boolean data type. + x2 + second input array. Must be compatible with ``x1``. Should have an integer or + boolean data type. + + Returns + ------- + out + an array containing the element-wise results. + """ + return _impl.bitwise_xor(x1, x2) + + +def ceil(x: Array, /) -> Array: + """Rounds each element ``x_i`` of the input array ``x`` to the smallest (i.e., + closest to ``-infinity``) integer-valued number that is not less than ``x_i``. + + Parameters + ---------- + x + input array. Should have a real-valued data type. + + Returns + ------- + out + an array containing the rounded result for each element in ``x``. + """ + return _impl.ceil(x) + + +def conj(x: Array, /) -> Array: + """Returns the complex conjugate for each element ``x_i`` of the input array ``x``. + + For complex numbers of the form + + .. math:: + a + bj + + the complex conjugate is defined as + + .. math:: + a - bj + + Hence, the returned complex conjugates must be computed by negating the imaginary + component of each element ``x_i``. + + Parameters + ---------- + x + input array. Should have a complex-floating point data type. + + Returns + ------- + out + an array containing the element-wise results. + """ + return _impl.conj(x) + + +def cos(x: Array, /) -> Array: + r""" + Calculates an approximation to the cosine for each element ``x_i`` of the input + array ``x``. + + Each element ``x_i`` is assumed to be expressed in radians. + + + For complex floating-point operands, special cases must be handled as if the + operation is implemented as ``cosh(x*1j)``. + + .. note:: + The cosine is an entire function on the complex plane and has no branch cuts. + + .. note:: + For complex arguments, the mathematical definition of cosine is + + .. math:: + \begin{align} \operatorname{cos}(x) &= \sum_{n=0}^\infty \frac{(-1)^n}{(2n)!} x^{2n} \\ &= \frac{e^{jx} + e^{-jx}}{2} \\ &= \operatorname{cosh}(jx) \end{align} + + where :math:`\operatorname{cosh}` is the hyperbolic cosine. + + Parameters + ---------- + x + input array whose elements are each expressed in radians. Should have a + floating-point data type. + + Returns + ------- + out + an array containing the cosine of each element in ``x``. + """ + return _impl.cos(x) + + +def cosh(x: Array, /) -> Array: + r""" + Calculates an approximation to the hyperbolic cosine for each element ``x_i`` in the + input array ``x``. + + The mathematical definition of the hyperbolic cosine is + + .. math:: + \operatorname{cosh}(x) = \frac{e^x + e^{-x}}{2} + + Parameters + ---------- + x + input array whose elements each represent a hyperbolic angle. Should have a + floating-point data type. + + Returns + ------- + out + an array containing the hyperbolic cosine of each element in ``x``. + """ + return _impl.cosh(x) + + +def divide(x1: Array, x2: Array, /) -> Array: + """Calculates the division for each element ``x1_i`` of the input array ``x1`` with + the respective element ``x2_i`` of the input array ``x2``. + + Parameters + ---------- + x1 + dividend input array. Should have a real-valued data type. + x2 + divisor input array. Must be compatible with ``x1``. Should have a real-valued + data type. 
+ + Returns + ------- + out + an array containing the element-wise results. + """ + return _impl.divide(x1, x2) + + +def exp(x: Array, /) -> Array: + """Calculates an approximation to the exponential function for each element ``x_i`` + of the input array ``x`` (``e`` raised to the power of ``x_i``, where ``e`` is the + base of the natural logarithm). + + Parameters + ---------- + x + input array. Should have a floating-point data type. + + Returns + ------- + out + an array containing the evaluated exponential function result for each element + in ``x``. + """ + return _impl.exp(x) + + +def expm1(x: Array, /) -> Array: + """Calculates an approximation to ``exp(x)-1``, having domain ``[-infinity, + infinity]`` and codomain ``[-1, infinity]``, for each element ``x_i`` of the input + array ``x``. + + .. note:: + + The purpose of this function is to calculate ``exp(x) - 1.0`` more accurately + when ``x`` is close to zero. + + + Parameters + ---------- + x + input array. Should have a real-valued floating-point data type. + + Returns + ------- + out + an array containing the evaluated result for each element in ``x``. + """ + return _impl.expm1(x) + + +def floor(x: Array, /) -> Array: + """Rounds each element ``x_i`` of the input array ``x`` to the greatest (i.e., + closest to ``infinity``) integer-valued number that is not greater than ``x_i``. + + Parameters + ---------- + x + input array. Should have a real-valued data type. + + Returns + ------- + out + an array containing the rounded result for each element in ``x``. + """ + return _impl.floor(x) + + +def floor_divide(x1: Array, x2: Array, /) -> Array: + r""" + Rounds the result of dividing each element ``x1_i`` of the input array ``x1`` by the + respective element ``x2_i`` of the input array ``x2`` to the greatest (i.e., + closest to ``infinity``) integer-valued number that is not greater than the division + result. + + Parameters + ---------- + x1 + dividend input array. Should have a real-valued data type. + x2 + divisor input array. Must be compatible with ``x1``. Should have a real-valued + data type. + + Returns + ------- + out + an array containing the element-wise results. + """ + return _impl.floor_divide(x1, x2) + + +def imag(x: Array, /) -> Array: + """Returns the imaginary component of a complex number for each element ``x_i`` of + the input array ``x``. + + Parameters + ---------- + x + input array. Should have a complex floating-point data type. + + Returns + ------- + out + an array containing the element-wise results. + """ + return _impl.imag(x) + + +def isfinite(x: Array, /) -> Array: + """Tests each element ``x_i`` of the input array ``x`` to determine if finite (i.e., + not ``NaN`` and not equal to positive or negative infinity). + + Parameters + ---------- + x + input array. Should have a real-valued data type. + + Returns + ------- + out + an array containing test results. An element ``out_i`` is ``True`` if ``x_i`` is + finite and ``False`` otherwise. + """ + return _impl.isfinite(x) + + +def isinf(x: Array, /) -> Array: + """Tests each element ``x_i`` of the input array ``x`` to determine if equal to + positive or negative infinity. + + Parameters + ---------- + x + input array. Should have a real-valued data type. + + Returns + ------- + out + an array containing test results. An element ``out_i`` is ``True`` if ``x_i`` + is either positive or negative infinity and ``False`` otherwise.
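# A quick sanity check relating `floor_divide` to `floor` and `divide`, assuming
# the NumPy backend is active; `backend.asarray` is an assumed helper not shown
# in this diff.
from probnum import backend

x1 = backend.asarray([7.0, -7.0])
x2 = backend.asarray([2.0, 2.0])

# floor_divide(x1, x2) equals floor(x1 / x2): [3., -4.] == floor([3.5, -3.5])
result = backend.floor_divide(x1, x2)
expected = backend.floor(backend.divide(x1, x2))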
+ """ + return _impl.isinf(x) + + +def isnan(x: Array, /) -> Array: + """Tests each element ``x_i`` of the input array ``x`` to determine whether the + element is ``NaN``. + + Parameters + ---------- + x + Input array. Should have a numeric data type. + + Returns + ------- + out + An array containing test results. An element ``out_i`` is ``True`` if ``x_i`` is + ``NaN`` and ``False`` otherwise. The returned array should have a data type of + ``bool``. + """ + return _impl.isnan(x) + + +def log(x: Array, /) -> Array: + """Calculates an approximation to the natural (base ``e``) logarithm, having domain + ``[0, infinity]`` and codomain ``[-infinity, infinity]``, for each element ``x_i`` + of the input array ``x``. + + **Special cases** + + For floating-point operands, + + - If ``x_i`` is ``NaN``, the result is ``NaN``. + - If ``x_i`` is less than ``0``, the result is ``NaN``. + - If ``x_i`` is either ``+0`` or ``-0``, the result is ``-infinity``. + - If ``x_i`` is ``1``, the result is ``+0``. + - If ``x_i`` is ``infinity``, the result is ``infinity``. + + Parameters + ---------- + x + input array. Should have a real-valued floating-point data type. + + Returns + ------- + out + an array containing the evaluated natural logarithm for each element in ``x``. + """ + return _impl.log(x) + + +def log1p(x: Array, /) -> Array: + """Calculates an approximation to ``log(1+x)``, where ``log`` refers to the natural + (base ``e``) logarithm, having domain ``[-1, infinity]`` and codomain ``[-infinity, + infinity]``, for each element ``x_i`` of the input array ``x``. + + .. note:: + The purpose of this function is to calculate ``log(1+x)`` more accurately + when `x` is close to zero. + + Parameters + ---------- + x + input array. Should have a real-valued floating-point data type. + + Returns + ------- + out + an array containing the evaluated result for each element in ``x``. + """ + return _impl.log1p(x) + + +def log2(x: Array, /) -> Array: + """Calculates an approximation to the base ``2`` logarithm, having domain ``[0, + infinity]`` and codomain ``[-infinity, infinity]``, for each element ``x_i`` of the + input array ``x``. + + Parameters + ---------- + x + input array. Should have a real-valued floating-point data type. + + Returns + ------- + out + an array containing the evaluated base ``2`` logarithm for each element in + ``x``. + """ + return _impl.log2(x) + + +def log10(x: Array, /) -> Array: + """Calculates an approximation to the base ``10`` logarithm, having domain ``[0, + infinity]`` and codomain ``[-infinity, infinity]``, for each element ``x_i`` of the + input array ``x``. + + Parameters + ---------- + x + input array. Should have a real-valued floating-point data type. + + Returns + ------- + out + an array containing the evaluated base ``10`` logarithm for each element in + ``x``. + """ + return _impl.log10(x) + + +def logaddexp(x1: Array, x2: Array, /) -> Array: + """Calculates the logarithm of the sum of exponentiations ``log(exp(x1) + exp(x2))`` + for each element ``x1_i`` of the input array ``x1`` with the respective element + ``x2_i`` of the input array ``x2``. + + Parameters + ---------- + x1 + first input array. Should have a real-valued floating-point data type. + x2 + second input array. Must be compatible with ``x1``. Should have a real-valued + floating-point data type. + + Returns + ------- + out + an array containing the element-wise results. + """ + return _impl.logaddexp(x1, x2) + + +def maximum(x1: Array, x2: Array, /) -> Array: + """Element-wise maximum of two arrays. 
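# Why `expm1`, `log1p`, and `logaddexp` exist: the naive formulations overflow or
# lose precision. Sketch assuming the NumPy backend; `backend.asarray` is an
# assumed helper not shown in this diff.
from probnum import backend

a = backend.asarray(1000.0)
backend.log(backend.exp(a) + backend.exp(a))  # inf, since exp(1000) overflows
backend.logaddexp(a, a)                       # ~1000.6931, computed stably

eps = backend.asarray(1e-12)
backend.log1p(eps)  # ~1e-12, accurate near zero
backend.expm1(eps)  # ~1e-12, accurate near zero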
+ + Compare two arrays and returns a new array containing the element-wise maxima. If + one of the elements being compared is a NaN, then that element is returned. If both + elements are NaNs then the first is returned. The latter distinction is important + for complex NaNs, which are defined as at least one of the real or imaginary parts + being a NaN. The net effect is that NaNs are propagated. + + Parameters + ---------- + x1 + First input array. + x2 + Second input array. Must be compatible with ``x1``. + + Returns + ------- + out + An array containing the element-wise maxima. + """ + return _impl.maximum(x1, x2) + + +def minimum(x1: Array, x2: Array, /) -> Array: + """Element-wise minimum of two arrays. + + Compare two arrays and returns a new array containing the element-wise minima. If + one of the elements being compared is a NaN, then that element is returned. If both + elements are NaNs then the first is returned. The latter distinction is important + for complex NaNs, which are defined as at least one of the real or imaginary parts + being a NaN. The net effect is that NaNs are propagated. + + Parameters + ---------- + x1 + First input array. + x2 + Second input array. Must be compatible with ``x1``. + + Returns + ------- + out + An array containing the element-wise minima. + """ + return _impl.minimum(x1, x2) + + +def multiply(x1: Array, x2: Array, /) -> Array: + """Calculates the product for each element ``x1_i`` of the input array ``x1`` with + the respective element ``x2_i`` of the input array ``x2``. + + Parameters + ---------- + x1 + first input array. Should have a real-valued data type. + x2 + second input array. Must be compatible with ``x1``. Should have a real-valued + data type. + + Returns + ------- + out + an array containing the element-wise products. + """ + return _impl.multiply(x1, x2) + + +def negative(x: Array, /) -> Array: + """ + Computes the numerical negative of each element ``x_i`` (i.e., ``y_i = -x_i``) of + the input array ``x``. + + Parameters + ---------- + x + input array. Should have a numeric data type. + + Returns + ------- + out + an array containing the evaluated result for each element in ``x``. + """ + return _impl.negative(x) + + +def positive(x: Array, /) -> Array: + """ + Computes the numerical positive of each element ``x_i`` (i.e., ``y_i = +x_i``) of + the input array ``x``. + + Parameters + ---------- + x + input array. Should have a numeric data type. + + Returns + ------- + out + an array containing the evaluated result for each element in ``x``. + """ + return _impl.positive(x) + + +def pow(x1: Array, x2: Array, /) -> Array: + """Calculates an approximation of exponentiation by raising each element ``x1_i`` + (the base) of the input array ``x1`` to the power of ``x2_i`` (the exponent), where + ``x2_i`` is the corresponding element of the input array ``x2``. + + Parameters + ---------- + x1 + first input array whose elements correspond to the exponentiation base. Should + have a real-valued data type. + x2 + second input array whose elements correspond to the exponentiation exponent. + Should have a real-valued data type. + + Returns + ------- + out + an array containing the element-wise results. + """ + return _impl.pow(x1, x2) + + +def real(x: Array, /) -> Array: + """Returns the real component of a complex number for each element ``x_i`` of the + input array ``x``. + + Parameters + ---------- + x + input array. Should have a complex floating-point data type. 
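# Two semantic details of the elementwise functions in this module, assuming the
# NumPy backend; `backend.asarray` is an assumed helper not shown in this diff:
# NaNs propagate through maximum/minimum, and remainder follows the sign
# convention of Python's `%` operator.
from probnum import backend

x = backend.asarray([1.0, float("nan")])
y = backend.asarray([2.0, 0.0])
backend.maximum(x, y)  # [2., nan] -- the NaN element propagates
backend.minimum(x, y)  # [1., nan]

backend.remainder(backend.asarray([5.0, -5.0]), backend.asarray(3.0))
# -> [2., 1.], matching 5 % 3 == 2 and -5 % 3 == 1 in Python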
+ + Returns + ------- + out + an array containing the element-wise results. + """ + return _impl.real(x) + + +def remainder(x1: Array, x2: Array, /) -> Array: + """Returns the remainder of division for each element ``x1_i`` of the input array + ``x1`` and the respective element ``x2_i`` of the input array ``x2``. + + .. note:: + This function is equivalent to the Python modulus operator ``x1_i % x2_i``. + + Parameters + ---------- + x1 + dividend input array. Should have a real-valued data type. + x2 + divisor input array. Must be compatible with ``x1``. Should have a real-valued + data type. + + Returns + ------- + out + an array containing the element-wise results. + """ + return _impl.remainder(x1, x2) + + +def round(x: Array, /) -> Array: + """Rounds each element ``x_i`` of the input array ``x`` to the nearest + integer-valued number. + + Parameters + ---------- + x + input array. Should have a numeric data type. + + Returns + ------- + out + an array containing the rounded result for each element in ``x``. + """ + return _impl.round(x) + + +def sign(x: Array, /) -> Array: + """Returns an indication of the sign of a number for each element ``x_i`` of the + input array ``x``. + + Parameters + ---------- + x + input array. Should have a real-valued data type. + + Returns + ------- + out + an array containing the evaluated result for each element in ``x``. + """ + return _impl.sign(x) + + +def sin(x: Array, /) -> Array: + r""" + Calculates an approximation to the sine for each element ``x_i`` of the input array + ``x``. + + Each element ``x_i`` is assumed to be expressed in radians. + + For complex floating-point operands, special cases must be handled as if the + operation is implemented as ``-1j * sinh(x*1j)``. + + Parameters + ---------- + x + input array whose elements are each expressed in radians. Should have a floating-point data type. + + Returns + ------- + out + an array containing the sine of each element in ``x``. + """ + return _impl.sin(x) + + +def sinh(x: Array, /) -> Array: + r""" + Calculates an approximation to the hyperbolic sine for each element ``x_i`` of the + input array ``x``. + + The mathematical definition of the hyperbolic sine is + + .. math:: + \operatorname{sinh}(x) = \frac{e^x - e^{-x}}{2} + + Parameters + ---------- + x + input array whose elements each represent a hyperbolic angle. Should have a floating-point data type. + + Returns + ------- + out + an array containing the hyperbolic sine of each element in ``x``. + """ + return _impl.sinh(x) + + +def square(x: Array, /) -> Array: + """ + Squares (``x_i * x_i``) each element ``x_i`` of the input array ``x``. + + Parameters + ---------- + x + input array. Should have a real-valued data type. + + Returns + ------- + out + an array containing the evaluated result for each element in ``x``. + """ + return _impl.square(x) + + +def sqrt(x: Array, /) -> Array: + """Calculates the square root, having domain ``[0, infinity]`` and codomain ``[0, + infinity]``, for each element ``x_i`` of the input array ``x``. + + Parameters + ---------- + x + input array. Should have a real-valued floating-point data type. + + Returns + ------- + out + an array containing the square root of each element in ``x``. + """ + return _impl.sqrt(x) + + +def subtract(x1: Array, x2: Array, /) -> Array: + """Calculates the difference for each element ``x1_i`` of the input array ``x1`` + with the respective element ``x2_i`` of the input array ``x2``. + + Parameters + ---------- + x1 + first input array. Should have a real-valued data type.
+ x2 + second input array. Must be compatible with ``x1``. Should have a real-valued data type. + + Returns + ------- + out + an array containing the element-wise differences. + """ + return _impl.subtract(x1, x2) + + +def tan(x: Array, /) -> Array: + r""" + Calculates an approximation to the tangent for each element ``x_i`` of the input + array ``x``. + + Each element ``x_i`` is assumed to be expressed in radians. + + Parameters + ---------- + x + input array whose elements are expressed in radians. Should have a floating-point data type. + + Returns + ------- + out + an array containing the tangent of each element in ``x``. + """ + return _impl.tan(x) + + +def tanh(x: Array, /) -> Array: + r""" + Calculates an approximation to the hyperbolic tangent for each element ``x_i`` of + the input array ``x``. + + Parameters + ---------- + x + input array whose elements each represent a hyperbolic angle. Should have a floating-point data type. + + Returns + ------- + out + an array containing the hyperbolic tangent of each element in ``x``. + """ + return _impl.tanh(x) + + +def trunc(x: Array, /) -> Array: + """Rounds each element ``x_i`` of the input array ``x`` toward zero to the nearest + integer-valued number, i.e., truncates the fractional part of ``x_i``. + + Parameters + ---------- + x + input array. Should have a real-valued data type. + + Returns + ------- + out + an array containing the rounded result for each element in ``x``. + """ + return _impl.trunc(x) diff --git a/src/probnum/backend/_elementwise_functions/_jax.py b/src/probnum/backend/_elementwise_functions/_jax.py new file mode 100644 index 000000000..240f29329 --- /dev/null +++ b/src/probnum/backend/_elementwise_functions/_jax.py @@ -0,0 +1,57 @@ +"""Element-wise functions on JAX arrays.""" +try: + from jax.numpy import ( # pylint: disable=unused-import + abs, + add, + arccos as acos, + arccosh as acosh, + arcsin as asin, + arcsinh as asinh, + arctan as atan, + arctan2 as atan2, + arctanh as atanh, + bitwise_and, + bitwise_or, + bitwise_xor, + ceil, + conj, + cos, + cosh, + divide, + exp, + expm1, + floor, + floor_divide, + imag, + invert as bitwise_invert, + isfinite, + isinf, + isnan, + left_shift as bitwise_left_shift, + log, + log1p, + log2, + log10, + logaddexp, + maximum, + minimum, + multiply, + negative, + positive, + power as pow, + real, + remainder, + right_shift as bitwise_right_shift, + round, + sign, + sin, + sinh, + sqrt, + square, + subtract, + tan, + tanh, + trunc, + ) +except ModuleNotFoundError: + pass diff --git a/src/probnum/backend/_elementwise_functions/_numpy.py b/src/probnum/backend/_elementwise_functions/_numpy.py new file mode 100644 index 000000000..cc481c52d --- /dev/null +++ b/src/probnum/backend/_elementwise_functions/_numpy.py @@ -0,0 +1,55 @@ +"""Element-wise functions on NumPy arrays.""" + +from numpy import ( # pylint: disable=unused-import + abs, + add, + arccos as acos, + arccosh as acosh, + arcsin as asin, + arcsinh as asinh, + arctan as atan, + arctan2 as atan2, + arctanh as atanh, + bitwise_and, + bitwise_or, + bitwise_xor, + ceil, + conj, + cos, + cosh, + divide, + exp, + expm1, + floor, + floor_divide, + imag, + invert as bitwise_invert, + isfinite, + isinf, + isnan, + left_shift as bitwise_left_shift, + log, + log1p, + log2, + log10, + logaddexp, + maximum, + minimum, + multiply, + negative, + positive, + power as pow, + real, + remainder, + right_shift as bitwise_right_shift, + round, + sign, + sin, + sinh, + sqrt, + square, + subtract, + tan, + tanh, + trunc, +)
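# trunc rounds toward zero, unlike floor, which rounds toward -infinity; the two
# only agree for non-negative inputs. Sketch assuming the NumPy backend;
# `backend.asarray` is an assumed helper not shown in this diff.
from probnum import backend

x = backend.asarray([2.7, -2.7])
backend.trunc(x)  # [ 2., -2.]
backend.floor(x)  # [ 2., -3.]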
diff --git a/src/probnum/backend/_elementwise_functions/_torch.py b/src/probnum/backend/_elementwise_functions/_torch.py new file mode 100644 index 000000000..feedc4fa5 --- /dev/null +++ b/src/probnum/backend/_elementwise_functions/_torch.py @@ -0,0 +1,57 @@ +"""Element-wise functions on torch tensors.""" +try: + from torch import ( # pylint: disable=unused-import + abs, + acos, + acosh, + add, + asin, + asinh, + atan, + atan2, + atanh, + bitwise_and, + bitwise_left_shift, + bitwise_not as bitwise_invert, + bitwise_or, + bitwise_right_shift, + bitwise_xor, + ceil, + conj, + cos, + cosh, + divide, + exp, + expm1, + floor, + floor_divide, + imag, + isfinite, + isinf, + isnan, + log, + log1p, + log2, + log10, + logaddexp, + maximum, + minimum, + multiply, + negative, + positive, + pow, + real, + remainder, + round, + sign, + sin, + sinh, + sqrt, + square, + subtract, + tan, + tanh, + trunc, + ) +except ModuleNotFoundError: + pass diff --git a/src/probnum/backend/_jit_compilation/__init__.py b/src/probnum/backend/_jit_compilation/__init__.py new file mode 100644 index 000000000..80bec6453 --- /dev/null +++ b/src/probnum/backend/_jit_compilation/__init__.py @@ -0,0 +1,84 @@ +"""Just-In-Time Compilation.""" +from typing import Callable, Iterable, Union + +from ..._select_backend import BACKEND, Backend + +if BACKEND is Backend.NUMPY: + from . import _numpy as _impl +elif BACKEND is Backend.JAX: + from . import _jax as _impl +elif BACKEND is Backend.TORCH: + from . import _torch as _impl + +__all__ = ["jit", "jit_method"] + + +def jit( + fun: Callable, + *, + static_argnums: Union[int, Iterable[int], None] = None, + static_argnames: Union[str, Iterable[str], None] = None, +): + """Set up ``fun`` for just-in-time compilation. + + Parameters + ---------- + fun + Function to be jitted. ``fun`` should be a pure function, as side-effects may + only be executed once. The arguments and return value of ``fun`` should be + arrays, scalars, or (nested) standard Python containers (tuple/list/dict) + thereof. + static_argnums + An optional int or collection of ints that specify which positional arguments to + treat as static (compile-time constant). Operations that only depend on static + arguments will be constant-folded in Python (during tracing), and so the + corresponding argument values can be any Python object. + static_argnames + An optional string or collection of strings specifying which named arguments to + treat as static (compile-time constant). + + Returns + ------- + wrapped + A wrapped version of ``fun``, set up for just-in-time compilation. + """ + return _impl.jit( + fun, static_argnums=static_argnums, static_argnames=static_argnames + ) + + +def jit_method( + method: Callable, + *, + static_argnums: Union[int, Iterable[int], None] = None, + static_argnames: Union[str, Iterable[str], None] = None, +): + """Set up a ``method`` of an object for just-in-time compilation. + + Convenience wrapper for jitting the method(s) of an object. Typically used as a + decorator. + + Parameters + ---------- + method + Method to be jitted. ``method`` should be a pure function, as side-effects may + only be executed once. The arguments and return value of ``method`` should be + arrays, scalars, or (nested) standard Python containers (tuple/list/dict) + thereof. + static_argnums + An optional int or collection of ints that specify which positional arguments to + treat as static (compile-time constant).
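# Usage sketch for the jit API defined in this module: under the NumPy and torch
# backends `jit` is a pass-through (see the `_numpy`/`_torch` stubs in this diff),
# under JAX it wraps `jax.jit`, so call sites stay backend-agnostic. `backend.sum`
# is an assumed re-export of this backend package; `p_norm` is a hypothetical
# example function.
from probnum import backend


def p_norm(x, p=2):
    return backend.sum(backend.abs(x) ** p) ** (1.0 / p)


# `p` changes the traced computation, so it is marked static for JAX.
p_norm_jitted = backend.jit(p_norm, static_argnames="p")


# `jit_method` is the analogous decorator for methods: the JAX implementation
# shifts `static_argnums` by one so that index 0 (i.e. `self`) is always static.
class ExponentialKernel:  # hypothetical example class
    def __init__(self, lengthscale):
        self.lengthscale = lengthscale

    @backend.jit_method
    def __call__(self, x0, x1):
        return backend.exp(-backend.abs(x0 - x1) / self.lengthscale)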
Operations that only depend on static + arguments will be constant-folded in Python (during tracing), and so the + corresponding argument values can be any Python object. + static_argnames + An optional string or collection of strings specifying which named arguments to + treat as static (compile-time constant). + + Returns + ------- + wrapped + A wrapped version of ``method``, set up for just-in-time compilation. + """ + return _impl.jit_method( + method, static_argnums=static_argnums, static_argnames=static_argnames + ) diff --git a/src/probnum/backend/_jit_compilation/_jax.py b/src/probnum/backend/_jit_compilation/_jax.py new file mode 100644 index 000000000..c223d095a --- /dev/null +++ b/src/probnum/backend/_jit_compilation/_jax.py @@ -0,0 +1,23 @@ +"""Just-In-Time Compilation in JAX.""" +from typing import Callable, Iterable, Union + +try: + import jax + from jax import jit # pylint: disable=unused-import +except ModuleNotFoundError: + pass + + +def jit_method( + method: Callable, + static_argnums: Union[int, Iterable[int], None] = None, + static_argnames: Union[str, Iterable[str], None] = None, +): + _static_argnums = (0,) + + if static_argnums is not None: + _static_argnums += tuple(argnum + 1 for argnum in static_argnums) + + return jax.jit( + method, static_argnums=_static_argnums, static_argnames=static_argnames + ) diff --git a/src/probnum/backend/_jit_compilation/_numpy.py b/src/probnum/backend/_jit_compilation/_numpy.py new file mode 100644 index 000000000..3f2b8dc53 --- /dev/null +++ b/src/probnum/backend/_jit_compilation/_numpy.py @@ -0,0 +1,21 @@ +"""Just-In-Time Compilation in NumPy.""" + +from typing import Callable, Iterable, Union + + +def jit( + fun: Callable, + *, + static_argnums: Union[int, Iterable[int], None] = None, + static_argnames: Union[str, Iterable[str], None] = None, +): + return fun + + +def jit_method( + method: Callable, + *, + static_argnums: Union[int, Iterable[int], None] = None, + static_argnames: Union[str, Iterable[str], None] = None, +): + return method diff --git a/src/probnum/backend/_jit_compilation/_torch.py b/src/probnum/backend/_jit_compilation/_torch.py new file mode 100644 index 000000000..571b7595f --- /dev/null +++ b/src/probnum/backend/_jit_compilation/_torch.py @@ -0,0 +1,21 @@ +"""Just-In-Time Compilation in PyTorch.""" + +from typing import Callable, Iterable, Union + + +def jit( + fun: Callable, + *, + static_argnums: Union[int, Iterable[int], None] = None, + static_argnames: Union[str, Iterable[str], None] = None, +): + return fun + + +def jit_method( + method: Callable, + *, + static_argnums: Union[int, Iterable[int], None] = None, + static_argnames: Union[str, Iterable[str], None] = None, +): + return method diff --git a/src/probnum/backend/_logic_functions/__init__.py b/src/probnum/backend/_logic_functions/__init__.py new file mode 100644 index 000000000..73592a0f1 --- /dev/null +++ b/src/probnum/backend/_logic_functions/__init__.py @@ -0,0 +1,280 @@ +"""Logic functions.""" + +from .. import Array +from ..._select_backend import BACKEND, Backend +from ..typing import ShapeType + +if BACKEND is Backend.NUMPY: + from . import _numpy as _impl +elif BACKEND is Backend.JAX: + from . import _jax as _impl +elif BACKEND is Backend.TORCH: + from . 
import _torch as _impl + +from typing import Optional, Union + +__all__ = [ + "all", + "any", + "equal", + "greater", + "greater_equal", + "less", + "less_equal", + "logical_and", + "logical_not", + "logical_or", + "logical_xor", + "not_equal", +] +__all__.sort() + + +def all( + x: Array, /, *, axis: Optional[Union[int, ShapeType]] = None, keepdims: bool = False +) -> Array: + """Tests whether all input array elements evaluate to ``True`` along a specified + axis. + + Parameters + ---------- + x + Input array. + axis + Axis or axes along which to perform a logical ``AND`` reduction. By default, the + logical ``AND`` reduction will be performed over the entire array. + keepdims + If ``True``, the reduced axes (dimensions) will be included in the result as + singleton dimensions. Otherwise, if ``False``, the reduced axes (dimensions) + will not be included in the result. + + Returns + ------- + out + If a logical ``AND`` reduction was performed over the entire array, the returned + array will be a zero-dimensional array containing the test result; otherwise, + the returned array will be a non-zero-dimensional array containing the test + results. + """ + return _impl.all(x, axis=axis, keepdims=keepdims) + + +def any( + x: Array, /, *, axis: Optional[Union[int, ShapeType]] = None, keepdims: bool = False +) -> Array: + """Tests whether any input array element evaluates to ``True`` along a specified + axis. + + Parameters + ---------- + x + Input array. + axis + Axis or axes along which to perform a logical ``OR`` reduction. By default, the + logical ``OR`` reduction will be performed over the entire array. + keepdims + If ``True``, the reduced axes (dimensions) will be included in the result as + singleton dimensions. Otherwise, if ``False``, the reduced axes (dimensions) + will not be included in the result. + + Returns + ------- + out + If a logical ``OR`` reduction was performed over the entire array, the returned + array will be a zero-dimensional array containing the test result; otherwise, + the returned array will be a non-zero-dimensional array containing the test + results. + """ + return _impl.any(x, axis=axis, keepdims=keepdims) + + +def equal(x1: Array, x2: Array, /) -> Array: + """Computes the truth value of ``x1_i == x2_i`` for each element ``x1_i`` of the + input array ``x1`` with the respective element ``x2_i`` of the input array ``x2``. + + Parameters + ---------- + x1 + first input array. May have any data type. + x2 + second input array. Must be compatible with ``x1``. May have any data type. + + Returns + ------- + out + an array containing the element-wise results. + """ + return _impl.equal(x1, x2) + + +def greater(x1: Array, x2: Array, /) -> Array: + """Computes the truth value of ``x1_i > x2_i`` for each element ``x1_i`` of the + input array ``x1`` with the respective element ``x2_i`` of the input array ``x2``. + + Parameters + ---------- + x1 + first input array. Should have a real-valued data type. + x2 + second input array. Must be compatible with ``x1``. Should have a real-valued + data type. + + Returns + ------- + out + an array containing the element-wise results. + """ + return _impl.greater(x1, x2) + + +def greater_equal(x1: Array, x2: Array, /) -> Array: + """Computes the truth value of ``x1_i >= x2_i`` for each element ``x1_i`` of the + input array ``x1`` with the respective element ``x2_i`` of the input array ``x2``. + + Parameters + ---------- + x1 + first input array. Should have a real-valued data type. + x2 + second input array. 
Must be compatible with ``x1``. Should have a real-valued + data type. + + Returns + ------- + out + an array containing the element-wise results. + """ + return _impl.greater_equal(x1, x2) + + +def less(x1: Array, x2: Array, /) -> Array: + """Computes the truth value of ``x1_i < x2_i`` for each element ``x1_i`` of the + input array ``x1`` with the respective element ``x2_i`` of the input array ``x2``. + + Parameters + ---------- + x1 + first input array. Should have a real-valued data type. + x2 + second input array. Must be compatible with ``x1``. Should have a real-valued + data type. + + Returns + ------- + out + an array containing the element-wise results. + """ + return _impl.less(x1, x2) + + +def less_equal(x1: Array, x2: Array, /) -> Array: + """Computes the truth value of ``x1_i <= x2_i`` for each element ``x1_i`` of the + input array ``x1`` with the respective element ``x2_i`` of the input array ``x2``. + + Parameters + ---------- + x1 + first input array. Should have a real-valued data type. + x2 + second input array. Must be compatible with ``x1``. Should have a real-valued + data type. + + Returns + ------- + out + an array containing the element-wise results. + """ + return _impl.less_equal(x1, x2) + + +def logical_and(x1: Array, x2: Array, /) -> Array: + """Computes the logical AND for each element ``x1_i`` of the input array ``x1`` with + the respective element ``x2_i`` of the input array ``x2``. + + Parameters + ---------- + x1 + first input array. Should have a boolean data type. + x2 + second input array. Must be compatible with ``x1``. Should have a boolean data + type. + + Returns + ------- + out + an array containing the element-wise results. + """ + return _impl.logical_and(x1, x2) + + +def logical_not(x: Array, /) -> Array: + """Computes the logical NOT for each element ``x_i`` of the input array ``x``. + + Parameters + ---------- + x + input array. Should have a boolean data type. + + Returns + ------- + out + an array containing the element-wise results. + """ + return _impl.logical_not(x) + + +def logical_or(x1: Array, x2: Array, /) -> Array: + """Computes the logical OR for each element ``x1_i`` of the input array ``x1`` with + the respective element ``x2_i`` of the input array ``x2``. + + Parameters + ---------- + x1 + first input array. Should have a boolean data type. + x2 + second input array. Must be compatible with ``x1``. Should have a boolean data type. + + Returns + ------- + out + an array containing the element-wise results. + """ + return _impl.logical_or(x1, x2) + + +def logical_xor(x1: Array, x2: Array, /) -> Array: + """Computes the logical XOR for each element ``x1_i`` of the input array ``x1`` with + the respective element ``x2_i`` of the input array ``x2``. + + Parameters + ---------- + x1 + first input array. Should have a boolean data type. + x2 + second input array. Must be compatible with ``x1``. Should have a boolean data type. + + Returns + ------- + out + an array containing the element-wise results. + """ + return _impl.logical_xor(x1, x2) + + +def not_equal(x1: Array, x2: Array, /) -> Array: + """Computes the truth value of ``x1_i != x2_i`` for each element ``x1_i`` of the + input array ``x1`` with the respective element ``x2_i`` of the input array ``x2``. + + Parameters + ---------- + x1 + first input array. May have any data type. + x2 + second input array. Must be compatible with ``x1``. + + Returns + ------- + out + an array containing the element-wise results. 
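# Combining the comparison and reduction functions of this module, e.g. for an
# equality check (NumPy backend assumed; `backend.asarray` is an assumed helper
# not shown in this diff):
from probnum import backend

x = backend.asarray([[1.0, 2.0], [3.0, 4.0]])
y = backend.asarray([[1.0, 2.0], [3.0, 5.0]])
backend.all(backend.equal(x, y))          # False
backend.any(backend.not_equal(x, y))      # True
backend.all(backend.equal(x, y), axis=0)  # [True, False], column-wise reduction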
+ """ + return _impl.not_equal(x1, x2) diff --git a/src/probnum/backend/_logic_functions/_jax.py b/src/probnum/backend/_logic_functions/_jax.py new file mode 100644 index 000000000..7b3ea08cc --- /dev/null +++ b/src/probnum/backend/_logic_functions/_jax.py @@ -0,0 +1,18 @@ +"""Logic functions on JAX arrays.""" +try: + from jax.numpy import ( # pylint: disable=unused-import + all, + any, + equal, + greater, + greater_equal, + less, + less_equal, + logical_and, + logical_not, + logical_or, + logical_xor, + not_equal, + ) +except ModuleNotFoundError: + pass diff --git a/src/probnum/backend/_logic_functions/_numpy.py b/src/probnum/backend/_logic_functions/_numpy.py new file mode 100644 index 000000000..933e71ff3 --- /dev/null +++ b/src/probnum/backend/_logic_functions/_numpy.py @@ -0,0 +1,16 @@ +"""Logic functions on NumPy arrays.""" + +from numpy import ( # pylint: disable=unused-import + all, + any, + equal, + greater, + greater_equal, + less, + less_equal, + logical_and, + logical_not, + logical_or, + logical_xor, + not_equal, +) diff --git a/src/probnum/backend/_logic_functions/_torch.py b/src/probnum/backend/_logic_functions/_torch.py new file mode 100644 index 000000000..d13b4434d --- /dev/null +++ b/src/probnum/backend/_logic_functions/_torch.py @@ -0,0 +1,71 @@ +"""Logic functions on torch tensors.""" +try: + import torch + from torch import ( # pylint: disable=unused-import + eq as equal, # torch.equal compares whole tensors; torch.eq is element-wise + greater, + greater_equal, + less, + less_equal, + logical_and, + logical_not, + logical_or, + logical_xor, + not_equal, + ) +except ModuleNotFoundError: + pass + +from typing import Optional, Union + +from probnum.backend.typing import ShapeType + + +def all( + a: "torch.Tensor", + *, + axis: Optional[Union[int, ShapeType]] = None, + keepdims: bool = False +) -> "torch.Tensor": + if axis is None: + return torch.all(a) + + if isinstance(axis, int): + return torch.all( + a, + dim=axis, + keepdim=keepdims, + ) + + axes = sorted(axis) + + res = a + + # Reduce from the highest axis down, so that the remaining axis indices stay + # valid when `keepdims` is `False`; this relies on `axes` being sorted. + for dim in reversed(axes): + res = torch.all(res, dim=dim, keepdim=keepdims) + + return res + + +def any( + a: "torch.Tensor", + *, + axis: Optional[Union[int, ShapeType]] = None, + keepdims: bool = False +) -> "torch.Tensor": + if axis is None: + return torch.any(a) + + if isinstance(axis, int): + return torch.any( + a, + dim=axis, + keepdim=keepdims, + ) + + axes = sorted(axis) + + res = a + + # Reduce from the highest axis down, so that the remaining axis indices stay + # valid when `keepdims` is `False`; this relies on `axes` being sorted. + for dim in reversed(axes): + res = torch.any(res, dim=dim, keepdim=keepdims) + + return res diff --git a/src/probnum/backend/_manipulation_functions/__init__.py b/src/probnum/backend/_manipulation_functions/__init__.py new file mode 100644 index 000000000..deb518883 --- /dev/null +++ b/src/probnum/backend/_manipulation_functions/__init__.py @@ -0,0 +1,441 @@ +"""Array manipulation functions.""" + +from typing import List, Optional, Sequence, Tuple, Union + +from .. import Array +from ..._select_backend import BACKEND, Backend + +if BACKEND is Backend.NUMPY: + from . import _numpy as _impl +elif BACKEND is Backend.JAX: + from . import _jax as _impl +elif BACKEND is Backend.TORCH: + from . import _torch as _impl + +from ..
import asshape +from ..typing import ShapeLike, ShapeType + +__all__ = [ + "atleast_1d", + "atleast_2d", + "broadcast_arrays", + "broadcast_shapes", + "broadcast_to", + "concat", + "expand_axes", + "flip", + "hstack", + "move_axes", + "permute_axes", + "reshape", + "roll", + "squeeze", + "stack", + "swap_axes", + "tile", + "vstack", +] +__all__.sort() + + +def atleast_1d(*arrays: Array): + """Convert inputs to arrays with at least one dimension. + + Scalar inputs are converted to 1-dimensional arrays, whilst + higher-dimensional inputs are preserved. + + Parameters + ---------- + arrays + One or more input arrays. + + Returns + ------- + out + An array, or list of arrays, each with ``a.ndim >= 1``. + + See Also + -------- + atleast_2d : Convert inputs to arrays with at least two dimensions. + """ + return _impl.atleast_1d(*arrays) + + +def atleast_2d(*arrays: Array): + """Convert inputs to arrays with at least two dimensions. + + Parameters + ---------- + arrays + One or more input arrays. + + Returns + ------- + out + An array, or list of arrays, each with ``a.ndim >= 2``. + + See Also + -------- + atleast_1d : Convert inputs to arrays with at least one dimension. + """ + return _impl.atleast_2d(*arrays) + + +def broadcast_arrays(*arrays: Array) -> List[Array]: + """Broadcasts one or more arrays against one another. + + Parameters + ---------- + arrays + An arbitrary number of to-be broadcasted arrays. + + Returns + ------- + out + A list of broadcasted arrays. + """ + return _impl.broadcast_arrays(*arrays) + + +def broadcast_shapes(*shapes: ShapeType) -> ShapeType: + """Broadcast the input shapes into a single shape. + + Returns the resulting shape of `broadcasting + `_ + arrays of the given ``shapes``. + + Parameters + ---------- + shapes + The shapes to be broadcast against each other. + + Returns + ------- + outshape + Broadcasted shape. + """ + return _impl.broadcast_shapes(*shapes) + + +def broadcast_to(x: Array, /, shape: ShapeLike) -> Array: + """Broadcasts an array to a specified shape. + + Parameters + ---------- + x + Array to broadcast. + shape + Array shape. Must be compatible with ``x``. + + Returns + ------- + out + An array having a specified shape. + """ + return _impl.broadcast_to(x, shape=asshape(shape)) + + +def concat( + arrays: Union[Tuple[Array, ...], List[Array]], /, *, axis: Optional[int] = 0 +) -> Array: + """Joins a sequence of arrays along an existing axis. + + Parameters + ---------- + arrays + Input arrays to join. The arrays must have the same shape, except in the + dimension specified by ``axis``. + axis + Axis along which the arrays will be joined. If ``axis`` is ``None``, arrays are + flattened before concatenation. + + Returns + ------- + out + An output array containing the concatenated values. + """ + return _impl.concat(arrays, axis=axis) + + +def expand_axes(x: Array, /, *, axis: int = 0) -> Array: + """Expands the shape of an array by inserting a new axis of size one at the position + specified by ``axis``. + + Parameters + ---------- + x + Input array. + axis + Axis position. + + Returns + ------- + out + An expanded output array having the same data type as ``x``. + """ + return _impl.expand_axes(x, axis=axis) + + +def flip(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Array: + """Reverses the order of elements in an array along the given axis. + + Parameters + ---------- + x + Input array. + axis + Axis (or axes) along which to flip. If ``axis`` is ``None``, the function will + flip all input array axes. 
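# The broadcasting helpers of this module at a glance (NumPy backend assumed;
# `backend.zeros` is an assumed array-creation helper, not shown in this diff):
from probnum import backend

backend.broadcast_shapes((3, 1), (4,))  # (3, 4)

a, b = backend.broadcast_arrays(backend.zeros((3, 1)), backend.zeros((4,)))
a.shape, b.shape  # both (3, 4)

backend.broadcast_to(backend.zeros((4,)), (2, 4)).shape  # (2, 4)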
+ + Returns + ------- + out + An output array having the same data type and shape as ``x`` and whose elements, + relative to ``x``, are reordered. + """ + return _impl.flip(x, axis=axis) + + +def permute_axes(x: Array, /, axes: Tuple[int, ...]) -> Array: + """Permutes the axes of an array ``x``. + + Parameters + ---------- + x + input array. + axes + Tuple containing a permutation of ``(0, 1, ..., N-1)`` where ``N`` is the number + of axes of ``x``. + + Returns + ------- + out + An array containing the axes permutation. + + See Also + -------- + swap_axes : Permute the axes of an array. + """ + return _impl.permute_axes(x, axes=axes) + + +def move_axes( + x: Array, + /, + source: Union[int, Sequence[int]], + destination: Union[int, Sequence[int]], +) -> Array: + """Move axes of an array to new positions. + + Other axes remain in the original order + + Parameters + ---------- + x + Array whose axes should be reordered. + source + Original positions of the axes to move. These must be unique. + destination + Destination positions for each of the original axes. These must also be unique. + + Returns + ------- + out + Array with moved axes. + """ + return _impl.move_axes(x, source=source, destination=destination) + + +def swap_axes(x: Array, /, axis1: int, axis2: int) -> Array: + """Swaps the axes of an array ``x``. + + Parameters + ---------- + x + Input array. + axis1 + First axis to be swapped. + axis2 + Second axis to be swapped. + + Returns + ------- + out + An array containing the swapped axes. + + See Also + -------- + permute_axes : Permute the axes of an array. + """ + return _impl.swap_axes(x, axis1=axis1, axis2=axis2) + + +def reshape(x: Array, /, shape: ShapeLike, *, copy: Optional[bool] = None) -> Array: + """Reshapes an array without changing its data. + + Parameters + ---------- + x + Input array to reshape. + shape + A new shape compatible with the original shape. One shape dimension is allowed + to be ``-1``. When a shape dimension is ``-1``, the corresponding output array + shape dimension will be inferred from the length of the array and the remaining + dimensions. + copy + Boolean indicating whether or not to copy the input array. If ``None``, reuses + existing memory buffer if possible and copy otherwise. + + Returns + ------- + out + An output array having the same data type and elements as ``x``. + """ + return _impl.reshape(x, shape=asshape(shape), copy=copy) + + +def roll( + x: Array, + /, + shift: Union[int, Tuple[int, ...]], + *, + axis: Optional[Union[int, Tuple[int, ...]]] = None, +) -> Array: + """Rolls array elements along a specified axis. + + Array elements that roll beyond the last position are re-introduced at the first + position. Array elements that roll beyond the first position are re-introduced at + the last position. + + Parameters + ---------- + x + Input array. + shift + Number of places by which the elements are shifted. If ``shift`` is a tuple, + then ``axis`` must be a tuple of the same size, and each of the given axes will + be shifted by the corresponding element in ``shift``. If ``shift`` is an ``int`` + and ``axis`` a tuple, then the same ``shift`` will be used for all specified + axes. + axis + Axis (or axes) along which elements to shift. If ``axis`` is ``None``, the array + will be flattened, shifted, and then restored to its original shape. + + Returns + ------- + out + An output array having the same data type as ``x`` and whose elements, relative + to ``x``, are shifted. 
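# How the three axis-reordering functions of this module relate (NumPy backend
# assumed; `backend.zeros` is an assumed helper, not shown in this diff):
from probnum import backend

x = backend.zeros((2, 3, 4))
backend.permute_axes(x, (2, 0, 1)).shape  # (4, 2, 3) -- full permutation
backend.swap_axes(x, 0, 2).shape          # (4, 3, 2) -- exchange two axes
backend.move_axes(x, 0, 2).shape          # (3, 4, 2) -- move one axis, keep the rest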
+ """ + return _impl.roll(x, shift=shift, axis=axis) + + +def squeeze(x: Array, /, axis: Union[int, Tuple[int, ...]]) -> Array: + """Removes singleton axes from ``x``. + + Parameters + ---------- + x + Input array. + axis + Axis (or axes) to squeeze. + + Returns + ------- + out + An output array having the same data type and elements as ``x``. + """ + return _impl.squeeze(x, axis=axis) + + +def stack(arrays: Union[Tuple[Array, ...], List[Array]], /, *, axis: int = 0) -> Array: + """Joins a sequence of arrays along a new axis. + + Parameters + ---------- + arrays + Input arrays to join. Each array must have the same shape. + axis + Axis along which the arrays will be joined. Providing an ``axis`` specifies the + index of the new axis in the dimensions of the result. For example, if ``axis`` + is ``0``, the new axis will be the first dimension and the output array will + have shape ``(N, A, B, C)``; if ``axis`` is ``1``, the new axis will be the + second dimension and the output array will have shape ``(A, N, B, C)``. + + Returns + -------- + out + An output array having rank ``N+1``, where ``N`` is the rank (number of + dimensions) of each array in ``arrays``. + """ + return _impl.stack(arrays, axis=axis) + + +def hstack(arrays: Union[Tuple[Array, ...], List[Array]], /) -> Array: + """Joins a sequence of arrays horizontally (column-wise). + + Parameters + ---------- + arrays + Input arrays to join. Each array must have the same shape along all but the + second axis. + + Returns + -------- + out + An output array formed by stacking the given arrays. + """ + return _impl.hstack(arrays) + + +def vstack(arrays: Union[Tuple[Array, ...], List[Array]], /) -> Array: + """Joins a sequence of arrays vertically (row-wise). + + Parameters + ---------- + arrays + Input arrays to join. Each array must have the same shape along all but the + first axis. + + Returns + -------- + out + An output array formed by stacking the given arrays. + """ + return _impl.vstack(arrays) + + +def tile(A: Array, /, reps: ShapeLike) -> Array: + """Construct an array by repeating ``A`` the number of times given by ``reps``. + + If ``reps`` has length ``d``, the result will have dimension of + ``max(d, A.ndim)``. + + If ``A.ndim < d``, ``A`` is promoted to be d-dimensional by prepending new + axes. So a shape (3,) array is promoted to (1, 3) for 2-D replication, + or shape (1, 1, 3) for 3-D replication. If this is not the desired + behavior, promote ``A`` to d-dimensions manually before calling this + function. + + If ``A.ndim > d``, ``reps`` is promoted to ``A.ndim`` by pre-pending 1's to it. + Thus for an ``A`` of shape (2, 3, 4, 5), a ``reps`` of (2, 2) is treated as + (1, 1, 2, 2). + + .. note:: + + Although tile may be used for broadcasting, it is strongly recommended to use + broadcasting operations and functionality instead. + + Parameters + ---------- + A + The input array. + reps + The number of repetitions of ``A`` along each axis. + + Returns + ------- + out + The tiled output array. + """ + return _impl.tile(A, asshape(reps))
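# Shape behavior of the joining functions of this module, plus reshape's copy
# flag (NumPy backend assumed; `backend.zeros` is an assumed helper, not shown
# in this diff):
from probnum import backend

a = backend.zeros((2, 3))
b = backend.zeros((2, 3))
backend.stack([a, b], axis=0).shape   # (2, 2, 3) -- a new axis is inserted
backend.concat([a, b], axis=0).shape  # (4, 3)    -- an existing axis grows
backend.vstack([a, b]).shape          # (4, 3)    -- row-wise
backend.hstack([a, b]).shape          # (2, 6)    -- column-wise

# reshape may return a view of its input when copy=None; copy=True forces the
# data to be copied before reshaping.
backend.reshape(a, (3, 2), copy=True).shape  # (3, 2)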
diff --git a/src/probnum/backend/_manipulation_functions/_jax.py b/src/probnum/backend/_manipulation_functions/_jax.py new file mode 100644 index 000000000..9580b10c0 --- /dev/null +++ b/src/probnum/backend/_manipulation_functions/_jax.py @@ -0,0 +1,38 @@ +"""JAX array manipulation functions.""" +from typing import Optional + +try: + import jax + import jax.numpy as jnp + from jax.numpy import ( # pylint: disable=unused-import + atleast_1d, + atleast_2d, + broadcast_arrays, + broadcast_shapes, + broadcast_to, + concatenate as concat, + expand_dims as expand_axes, + flip, + hstack, + moveaxis as move_axes, + roll, + squeeze, + stack, + swapaxes as swap_axes, + tile, + transpose as permute_axes, + vstack, + ) +except ModuleNotFoundError: + pass + +from ..typing import ShapeType + + +def reshape( + x: "jax.Array", /, shape: ShapeType, *, copy: Optional[bool] = None +) -> "jax.Array": + # `copy=None` allows returning a view; only copy when explicitly requested. + if copy: + x = jnp.copy(x) + + return jnp.reshape(x, newshape=shape) diff --git a/src/probnum/backend/_manipulation_functions/_numpy.py b/src/probnum/backend/_manipulation_functions/_numpy.py new file mode 100644 index 000000000..0e0bd7c51 --- /dev/null +++ b/src/probnum/backend/_manipulation_functions/_numpy.py @@ -0,0 +1,35 @@ +"""NumPy array manipulation functions.""" + +from typing import Optional + +import numpy as np +from numpy import ( # pylint: disable=unused-import + atleast_1d, + atleast_2d, + broadcast_arrays, + broadcast_shapes, + broadcast_to, + concatenate as concat, + expand_dims as expand_axes, + flip, + hstack, + moveaxis as move_axes, + roll, + squeeze, + stack, + swapaxes as swap_axes, + tile, + transpose as permute_axes, + vstack, +) + +from ..typing import ShapeType + + +def reshape( + x: np.ndarray, /, shape: ShapeType, *, copy: Optional[bool] = None +) -> np.ndarray: + # `copy=None` allows returning a view; only copy when explicitly requested. + if copy: + x = np.copy(x) + + return np.reshape(x, newshape=shape) diff --git a/src/probnum/backend/_manipulation_functions/_torch.py b/src/probnum/backend/_manipulation_functions/_torch.py new file mode 100644 index 000000000..239e47ad0 --- /dev/null +++ b/src/probnum/backend/_manipulation_functions/_torch.py @@ -0,0 +1,83 @@ +"""Torch tensor manipulation functions.""" + +from typing import List, Optional, Tuple, Union + +try: + import torch + from torch import ( # pylint: disable=unused-import + atleast_1d, + atleast_2d, + broadcast_shapes, + broadcast_tensors as broadcast_arrays, + hstack, + movedim as move_axes, + vstack, + ) +except ModuleNotFoundError: + pass + +from ..typing import ShapeType + + +def broadcast_to(x: "torch.Tensor", /, shape: ShapeType) -> "torch.Tensor": + return torch.broadcast_to(x, size=shape) + + +def concat( + arrays: Union[Tuple["torch.Tensor", ...], List["torch.Tensor"]], + /, + *, + axis: Optional[int] = 0, +) -> "torch.Tensor": + return torch.concat(tensors=arrays, dim=axis) + + +def expand_axes(x: "torch.Tensor", /, *, axis: int = 0) -> "torch.Tensor": + return torch.unsqueeze(input=x, dim=axis) + + +def flip( + x: "torch.Tensor", /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None +) -> "torch.Tensor": + # `torch.flip` expects a sequence of dims, so normalize `axis` first. + if axis is None: + axis = tuple(range(x.ndim)) + elif isinstance(axis, int): + axis = (axis,) + + return torch.flip(x, dims=axis) + + +def permute_axes(x: "torch.Tensor", /, axes: Tuple[int, ...]) -> "torch.Tensor": + return torch.permute(x, dims=axes) + + +def swap_axes(x: "torch.Tensor", /, axis1: int, axis2: int) -> "torch.Tensor": + return torch.swapdims(x, dim0=axis1, dim1=axis2) + + +def reshape( + x: "torch.Tensor", /, shape: ShapeType, *, copy: Optional[bool] = None +) -> "torch.Tensor": + # `copy=None` allows returning a view; only copy when explicitly requested. + if copy: + x = torch.clone(x) + + return torch.reshape(x, shape=shape) + + +def roll( + x: "torch.Tensor", + /, + shift: Union[int, Tuple[int, ...]], + *, + axis: Optional[Union[int, Tuple[int, ...]]] = None, +) -> "torch.Tensor": + return torch.roll(x, shifts=shift, dims=axis) + + +def squeeze(x: "torch.Tensor", /, axis: Union[int, Tuple[int, ...]]) -> "torch.Tensor": + return torch.squeeze(x, dim=axis) + + +def stack( + arrays: Union[Tuple["torch.Tensor", ...], List["torch.Tensor"]], /, *, axis: int = 0 +) -> "torch.Tensor": + return torch.stack(arrays, dim=axis) + + +def tile(A: "torch.Tensor", reps: ShapeType) -> "torch.Tensor": + return torch.tile(input=A, dims=reps) diff --git a/src/probnum/backend/_searching_functions/__init__.py b/src/probnum/backend/_searching_functions/__init__.py new file mode 100644 index 000000000..b04f6a425 --- /dev/null +++ b/src/probnum/backend/_searching_functions/__init__.py @@ -0,0 +1,120 @@ +"""Searching functions.""" + +from typing import Optional, Tuple + +from .. import Array +from ..._select_backend import BACKEND, Backend + +if BACKEND is Backend.NUMPY: + from . import _numpy as _impl +elif BACKEND is Backend.JAX: + from . import _jax as _impl +elif BACKEND is Backend.TORCH: + from . import _torch as _impl + +__all__ = ["argmin", "argmax", "nonzero", "where"] +__all__.sort() + + +def argmax(x: Array, /, *, axis: Optional[int] = None, keepdims: bool = False) -> Array: + """Returns the indices of the maximum values along a specified axis. + + When the maximum value occurs multiple times, only the indices corresponding to the + first occurrence are returned. + + Parameters + ---------- + x + Input array. Should have a real-valued data type. + axis + Axis along which to search. If ``None``, the function must return the index of + the maximum value of the flattened array. + keepdims + If ``True``, the reduced axes (dimensions) must be included in the result as + singleton dimensions, and, accordingly, the result must be compatible with the + input array. Otherwise, if ``False``, the reduced axes (dimensions) must not be + included in the result. + + Returns + ------- + out + If ``axis`` is ``None``, a zero-dimensional array containing the index of the + first occurrence of the maximum value; otherwise, a non-zero-dimensional array + containing the indices of the maximum values. The returned array must have + the default array index data type. + """ + return _impl.argmax(x, axis=axis, keepdims=keepdims) + + +def argmin(x: Array, /, *, axis: Optional[int] = None, keepdims: bool = False) -> Array: + """Returns the indices of the minimum values along a specified axis. + + When the minimum value occurs multiple times, only the indices corresponding to the + first occurrence are returned. + + Parameters + ---------- + x + Input array. Should have a real-valued data type. + axis + Axis along which to search. If ``None``, the function must return the index of + the minimum value of the flattened array. Default: ``None``. + keepdims + If ``True``, the reduced axes (dimensions) must be included in the result as + singleton dimensions, and, accordingly, the result must be compatible with the + input array. Otherwise, if ``False``, the reduced axes (dimensions) must not be + included in the result. Default: ``False``.
+ + Returns + ------- + out + If ``axis`` is ``None``, a zero-dimensional array containing the index of the + first occurrence of the minimum value; otherwise, a non-zero-dimensional array + containing the indices of the minimum values. The returned array must have the + default array index data type. + """ + return _impl.argmin(x, axis=axis, keepdims=keepdims) + + +def nonzero(x: Array, /) -> Tuple[Array, ...]: + """Returns the indices of the array elements which are non-zero. + + Parameters + ---------- + x + Input array. Must have a positive rank. If ``x`` is zero-dimensional, the + function will raise an exception. + + Returns + ------- + out + A tuple of ``k`` arrays, one for each dimension of ``x`` and each of size ``n`` + (where ``n`` is the total number of non-zero elements), containing the indices + of the non-zero elements in that dimension. The indices must be returned in + row-major, C-style order. The returned array must have the default array index + data type. + """ + return _impl.nonzero(x) + + +def where(condition: Array, x1: Array, x2: Array, /) -> Array: + """Returns elements chosen from ``x1`` or ``x2`` depending on ``condition``. + + Parameters + ---------- + condition + When ``True``, yield ``x1_i``; otherwise, yield ``x2_i``. Must be compatible + with ``x1`` and ``x2``. + x1 + First input array. Must be compatible with ``condition`` and ``x2``. + x2 + Second input array. Must be compatible with ``condition`` and ``x1``. + + Returns + ------- + out + An array with elements from ``x1`` where ``condition`` is ``True``, and elements + from ``x2`` elsewhere. The returned array must have a data type determined by + type promotion rules with the arrays ``x1`` and ``x2``. + """ + return _impl.where(condition, x1, x2) diff --git a/src/probnum/backend/_searching_functions/_jax.py b/src/probnum/backend/_searching_functions/_jax.py new file mode 100644 index 000000000..c8f020f7d --- /dev/null +++ b/src/probnum/backend/_searching_functions/_jax.py @@ -0,0 +1,24 @@ +"""Searching functions on JAX arrays.""" +from typing import Optional + +try: + import jax + import jax.numpy as jnp + from jax.numpy import ( # pylint: disable=redefined-builtin, unused-import + nonzero, + where, + ) +except ModuleNotFoundError: + pass + + +def argmax( + x: "jax.Array", /, *, axis: Optional[int] = None, keepdims: bool = False +) -> "jax.Array": + return jnp.argmax(a=x, axis=axis, keepdims=keepdims) + + +def argmin( + x: "jax.Array", /, *, axis: Optional[int] = None, keepdims: bool = False +) -> "jax.Array": + return jnp.argmin(a=x, axis=axis, keepdims=keepdims) diff --git a/src/probnum/backend/_searching_functions/_numpy.py b/src/probnum/backend/_searching_functions/_numpy.py new file mode 100644 index 000000000..edeff8a57 --- /dev/null +++ b/src/probnum/backend/_searching_functions/_numpy.py @@ -0,0 +1,17 @@ +"""Searching functions on NumPy arrays.""" +from typing import Optional + +import numpy as np +from numpy import nonzero, where # pylint: disable=redefined-builtin, unused-import + + +def argmax( + x: np.ndarray, /, *, axis: Optional[int] = None, keepdims: bool = False +) -> np.ndarray: + return np.argmax(a=x, axis=axis, keepdims=keepdims) + + +def argmin( + x: np.ndarray, /, *, axis: Optional[int] = None, keepdims: bool = False +) -> np.ndarray: + return np.argmin(a=x, axis=axis, keepdims=keepdims) diff --git a/src/probnum/backend/_searching_functions/_torch.py b/src/probnum/backend/_searching_functions/_torch.py new file mode 100644 index 000000000..37bf4bb8e --- /dev/null +++ 
b/src/probnum/backend/_searching_functions/_torch.py
@@ -0,0 +1,26 @@
+"""Searching functions on torch tensors."""
+from typing import Optional, Tuple
+
+try:
+    import torch
+    from torch import (  # pylint: disable=redefined-builtin, unused-import, no-name-in-module
+        where,
+    )
+except ModuleNotFoundError:
+    pass
+
+
+def argmax(
+    x: "torch.Tensor", /, *, axis: Optional[int] = None, keepdims: bool = False
+) -> "torch.Tensor":
+    return torch.argmax(input=x, dim=axis, keepdim=keepdims)
+
+
+def argmin(
+    x: "torch.Tensor", /, *, axis: Optional[int] = None, keepdims: bool = False
+) -> "torch.Tensor":
+    return torch.argmin(input=x, dim=axis, keepdim=keepdims)
+
+
+def nonzero(x: "torch.Tensor", /) -> Tuple["torch.Tensor", ...]:
+    return torch.nonzero(input=x, as_tuple=True)
diff --git a/src/probnum/backend/_sorting_functions/__init__.py b/src/probnum/backend/_sorting_functions/__init__.py
new file mode 100644
index 000000000..6fc6fc36b
--- /dev/null
+++ b/src/probnum/backend/_sorting_functions/__init__.py
@@ -0,0 +1,87 @@
+"""Sorting functions."""
+
+from .. import Array
+from ..._select_backend import BACKEND, Backend
+
+if BACKEND is Backend.NUMPY:
+    from . import _numpy as _impl
+elif BACKEND is Backend.JAX:
+    from . import _jax as _impl
+elif BACKEND is Backend.TORCH:
+    from . import _torch as _impl
+
+__all__ = ["argsort", "sort"]
+__all__.sort()
+
+
+def argsort(
+    x: Array, /, *, axis: int = -1, descending: bool = False, stable: bool = True
+) -> Array:
+    """Returns the indices that sort an array ``x`` along a specified axis.
+
+    Parameters
+    ----------
+    x
+        input array.
+    axis
+        axis along which to sort. If set to ``-1``, the function must sort along the
+        last axis. Default: ``-1``.
+    descending
+        sort order. If ``True``, the returned indices sort ``x`` in descending order
+        (by value). If ``False``, the returned indices sort ``x`` in ascending order
+        (by value). Default: ``False``.
+    stable
+        sort stability. If ``True``, the returned indices must maintain the relative
+        order of ``x`` values which compare as equal. If ``False``, the returned indices
+        may or may not maintain the relative order of ``x`` values which compare as
+        equal (i.e., the relative order of ``x`` values which compare as equal is
+        implementation-dependent). Default: ``True``.
+
+    Returns
+    -------
+    out :
+        an array of indices. The returned array must have the same shape as ``x``. The
+        returned array must have the default array index data type.
+    """
+    return _impl.argsort(x, axis=axis, descending=descending, stable=stable)
+
+
+def sort(
+    x: Array, /, *, axis: int = -1, descending: bool = False, stable: bool = True
+) -> Array:
+    """Returns a sorted copy of an input array ``x``.
+
+    Parameters
+    ----------
+    x
+        input array.
+    axis
+        axis along which to sort. If set to ``-1``, the function must sort along the
+        last axis. Default: ``-1``.
+    descending
+        sort order. If ``True``, the array must be sorted in descending order (by
+        value). If ``False``, the array must be sorted in ascending order (by value).
+        Default: ``False``.
+    stable
+        sort stability. If ``True``, the returned array must maintain the relative order
+        of ``x`` values which compare as equal. If ``False``, the returned array may or
+        may not maintain the relative order of ``x`` values which compare as equal
+        (i.e., the relative order of ``x`` values which compare as equal is
+        implementation-dependent). Default: ``True``.
+
+    Returns
+    -------
+    out :
+        a sorted array. The returned array must have the same data type and shape as
+        ``x``.
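+
+    Examples
+    --------
+    A minimal usage sketch; the printed output assumes the NumPy backend and that
+    this function is re-exported as ``backend.sort``:
+
+    >>> from probnum import backend
+    >>> backend.sort(backend.asarray([3, 1, 2]), descending=True)
+    array([3, 2, 1])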
+ """ + return _impl.sort(x, axis=axis, descending=descending, stable=stable) diff --git a/src/probnum/backend/_sorting_functions/_jax.py b/src/probnum/backend/_sorting_functions/_jax.py new file mode 100644 index 000000000..2a467fffa --- /dev/null +++ b/src/probnum/backend/_sorting_functions/_jax.py @@ -0,0 +1,48 @@ +"""Sorting functions for JAX arrays.""" + +try: + import jax + import jax.numpy as jnp + from jax.numpy import isnan # pylint: disable=redefined-builtin, unused-import +except ModuleNotFoundError: + pass + + +def sort( + x: "jax.Array", + /, + *, + axis: int = -1, + descending: bool = False, + stable: bool = True, +) -> "jax.Array": + kind = "quicksort" + if stable: + kind = "stable" + + sorted_array = jnp.sort(x, axis=axis, kind=kind) + + if descending: + return jnp.flip(sorted_array, axis=axis) + + return sorted_array + + +def argsort( + x: "jax.Array", + /, + *, + axis: int = -1, + descending: bool = False, + stable: bool = True, +) -> "jax.Array": + kind = "quicksort" + if stable: + kind = "stable" + + sort_idx = jnp.argsort(x, axis=axis, kind=kind) + + if descending: + return jnp.flip(sort_idx, axis=axis) + + return sort_idx diff --git a/src/probnum/backend/_sorting_functions/_numpy.py b/src/probnum/backend/_sorting_functions/_numpy.py new file mode 100644 index 000000000..9aba38ba3 --- /dev/null +++ b/src/probnum/backend/_sorting_functions/_numpy.py @@ -0,0 +1,43 @@ +"""/sorting functions for NumPy arrays.""" +import numpy as np +from numpy import isnan # pylint: disable=redefined-builtin, unused-import + + +def sort( + x: np.ndarray, + /, + *, + axis: int = -1, + descending: bool = False, + stable: bool = True, +) -> np.ndarray: + kind = "quicksort" + if stable: + kind = "stable" + + sorted_array = np.sort(x, axis=axis, kind=kind) + + if descending: + return np.flip(sorted_array, axis=axis) + + return sorted_array + + +def argsort( + x: np.ndarray, + /, + *, + axis: int = -1, + descending: bool = False, + stable: bool = True, +) -> np.ndarray: + kind = "quicksort" + if stable: + kind = "stable" + + sort_idx = np.argsort(x, axis=axis, kind=kind) + + if descending: + return np.flip(sort_idx, axis=axis) + + return sort_idx diff --git a/src/probnum/backend/_sorting_functions/_torch.py b/src/probnum/backend/_sorting_functions/_torch.py new file mode 100644 index 000000000..0a7f862b3 --- /dev/null +++ b/src/probnum/backend/_sorting_functions/_torch.py @@ -0,0 +1,31 @@ +"""Sorting functions for torch tensors.""" + +try: + import torch + from torch import ( # pylint: disable=redefined-builtin, unused-import, no-name-in-module + isnan, + ) +except ModuleNotFoundError: + pass + + +def sort( + x: "torch.Tensor", + /, + *, + axis: int = -1, + descending: bool = False, + stable: bool = True, +) -> "torch.Tensor": + return torch.sort(x, dim=axis, descending=descending, stable=stable)[0] + + +def argsort( + x: "torch.Tensor", + /, + *, + axis: int = -1, + descending: bool = False, + stable: bool = True, +) -> "torch.Tensor": + return torch.sort(x, dim=axis, descending=descending, stable=stable)[1] diff --git a/src/probnum/backend/_statistical_functions/__init__.py b/src/probnum/backend/_statistical_functions/__init__.py new file mode 100644 index 000000000..f87ee575f --- /dev/null +++ b/src/probnum/backend/_statistical_functions/__init__.py @@ -0,0 +1,344 @@ +"""Statistical functions.""" + +from __future__ import annotations + +from typing import Optional, Tuple, Union + +from .. 
+from ..._select_backend import BACKEND, Backend
+
+if BACKEND is Backend.NUMPY:
+    from . import _numpy as _impl
+elif BACKEND is Backend.JAX:
+    from . import _jax as _impl
+elif BACKEND is Backend.TORCH:
+    from . import _torch as _impl
+
+__all__ = ["max", "mean", "min", "prod", "std", "sum", "var"]
+__all__.sort()
+
+
+def max(
+    x: Array,
+    /,
+    *,
+    axis: Optional[Union[int, Tuple[int, ...]]] = None,
+    keepdims: bool = False,
+) -> Array:
+    """Calculates the maximum value of the input array ``x``.
+
+    **Special Cases**
+    For floating-point operands,
+
+    - If ``x_i`` is ``NaN``, the maximum value is ``NaN`` (i.e., ``NaN`` values
+      propagate).
+
+    Parameters
+    ----------
+    x
+        input array. Should have a numeric data type.
+    axis
+        axis or axes along which maximum values must be computed. By default, the
+        maximum value must be computed over the entire array. If a tuple of integers,
+        maximum values must be computed over multiple axes. Default: ``None``.
+    keepdims
+        if ``True``, the reduced axes (dimensions) must be included in the result as
+        singleton dimensions, and, accordingly, the result must be compatible with the
+        input array (see broadcasting). Otherwise, if ``False``, the reduced
+        axes (dimensions) must not be included in the result. Default: ``False``.
+
+    Returns
+    -------
+    out
+        if the maximum value was computed over the entire array, a zero-dimensional
+        array containing the maximum value; otherwise, a non-zero-dimensional array
+        containing the maximum values. The returned array must have the same data type
+        as ``x``.
+    """
+    return _impl.max(x, axis=axis, keepdims=keepdims)
+
+
+def mean(
+    x: Array,
+    /,
+    *,
+    axis: Optional[Union[int, Tuple[int, ...]]] = None,
+    keepdims: bool = False,
+) -> Array:
+    """Calculates the arithmetic mean of the input array ``x``.
+
+    **Special Cases**
+    Let ``N`` equal the number of elements over which to compute the arithmetic mean.
+
+    - If ``N`` is ``0``, the arithmetic mean is ``NaN``.
+    - If ``x_i`` is ``NaN``, the arithmetic mean is ``NaN`` (i.e., ``NaN`` values
+      propagate).
+
+    Parameters
+    ----------
+    x
+        input array. Should have a floating-point data type.
+    axis
+        axis or axes along which arithmetic means must be computed. By default, the mean
+        must be computed over the entire array. If a tuple of integers, arithmetic means
+        must be computed over multiple axes. Default: ``None``.
+    keepdims
+        if ``True``, the reduced axes (dimensions) must be included in the result as
+        singleton dimensions, and, accordingly, the result must be compatible with the
+        input array (see broadcasting). Otherwise, if ``False``, the reduced
+        axes (dimensions) must not be included in the result. Default: ``False``.
+
+    Returns
+    -------
+    out
+        if the arithmetic mean was computed over the entire array, a zero-dimensional
+        array containing the arithmetic mean; otherwise, a non-zero-dimensional array
+        containing the arithmetic means. The returned array must have the same data type
+        as ``x``.
+    """
+    return _impl.mean(x, axis=axis, keepdims=keepdims)
+
+
+def min(
+    x: Array,
+    /,
+    *,
+    axis: Optional[Union[int, Tuple[int, ...]]] = None,
+    keepdims: bool = False,
+) -> Array:
+    """Calculates the minimum value of the input array ``x``.
+
+    **Special Cases**
+    For floating-point operands,
+
+    - If ``x_i`` is ``NaN``, the minimum value is ``NaN`` (i.e., ``NaN`` values
+      propagate).
+
+    Parameters
+    ----------
+    x
+        input array. Should have a numeric data type.
+    axis
+        axis or axes along which minimum values must be computed. By default, the
+        minimum value must be computed over the entire array. If a tuple of integers,
+        minimum values must be computed over multiple axes. Default: ``None``.
+    keepdims
+        if ``True``, the reduced axes (dimensions) must be included in the result as
+        singleton dimensions, and, accordingly, the result must be compatible with the
+        input array (see broadcasting). Otherwise, if ``False``, the reduced
+        axes (dimensions) must not be included in the result. Default: ``False``.
+
+    Returns
+    -------
+    out
+        if the minimum value was computed over the entire array, a zero-dimensional
+        array containing the minimum value; otherwise, a non-zero-dimensional array
+        containing the minimum values. The returned array must have the same data type
+        as ``x``.
+    """
+    return _impl.min(x, axis=axis, keepdims=keepdims)
+
+
+def prod(
+    x: Array,
+    /,
+    *,
+    axis: Optional[Union[int, Tuple[int, ...]]] = None,
+    dtype: Optional[DType] = None,
+    keepdims: bool = False,
+) -> Array:
+    """Calculates the product of input array ``x`` elements.
+
+    **Special Cases**
+    Let ``N`` equal the number of elements over which to compute the product.
+
+    - If ``N`` is ``0``, the product is ``1`` (i.e., the empty product).
+
+    For floating-point operands,
+
+    - If ``x_i`` is ``NaN``, the product is ``NaN`` (i.e., ``NaN`` values propagate).
+
+    Parameters
+    ----------
+    x
+        input array. Should have a numeric data type.
+    axis
+        axis or axes along which products must be computed. By default, the product must
+        be computed over the entire array. If a tuple of integers, products must be
+        computed over multiple axes. Default: ``None``.
+    dtype
+        data type of the returned array.
+    keepdims
+        if ``True``, the reduced axes (dimensions) must be included in the result as
+        singleton dimensions, and, accordingly, the result must be compatible with the
+        input array (see broadcasting). Otherwise, if ``False``, the reduced
+        axes (dimensions) must not be included in the result. Default: ``False``.
+
+    Returns
+    -------
+    out
+        if the product was computed over the entire array, a zero-dimensional array
+        containing the product; otherwise, a non-zero-dimensional array containing the
+        products. The returned array must have a data type as described by the ``dtype``
+        parameter above.
+    """
+    return _impl.prod(x, axis=axis, dtype=dtype, keepdims=keepdims)
+
+
+def std(
+    x: Array,
+    /,
+    *,
+    axis: Optional[Union[int, Tuple[int, ...]]] = None,
+    correction: Union[int, float] = 0.0,
+    keepdims: bool = False,
+) -> Array:
+    """Calculates the standard deviation of the input array ``x``.
+
+    **Special Cases**
+    Let ``N`` equal the number of elements over which to compute the standard deviation.
+
+    - If ``N - correction`` is less than or equal to ``0``, the standard deviation is
+      ``NaN``.
+    - If ``x_i`` is ``NaN``, the standard deviation is ``NaN`` (i.e., ``NaN`` values
+      propagate).
+
+    Parameters
+    ----------
+    x
+        input array. Should have a floating-point data type.
+    axis
+        axis or axes along which standard deviations must be computed. By default, the
+        standard deviation must be computed over the entire array. If a tuple of
+        integers, standard deviations must be computed over multiple axes.
+        Default: ``None``.
+    correction
+        degrees of freedom adjustment. Setting this parameter to a value other than
+        ``0`` has the effect of adjusting the divisor during the calculation of the
+        standard deviation according to ``N-c`` where ``N`` corresponds to the total
+        number of elements over which the standard deviation is computed and ``c``
+        corresponds to the provided degrees of freedom adjustment. When computing the
+        standard deviation of a population, setting this parameter to ``0`` is the
+        standard choice (i.e., the provided array contains data constituting an entire
+        population). When computing the corrected sample standard deviation, setting
+        this parameter to ``1`` is the standard choice (i.e., the provided array
+        contains data sampled from a larger population; this is commonly referred to as
+        Bessel's correction). Default: ``0``.
+    keepdims
+        if ``True``, the reduced axes (dimensions) must be included in the result as
+        singleton dimensions, and, accordingly, the result must be compatible with the
+        input array (see broadcasting). Otherwise, if ``False``, the reduced
+        axes (dimensions) must not be included in the result. Default: ``False``.
+
+    Returns
+    -------
+    out
+        if the standard deviation was computed over the entire array, a zero-dimensional
+        array containing the standard deviation; otherwise, a non-zero-dimensional array
+        containing the standard deviations. The returned array must have the same data
+        type as ``x``.
+    """
+    return _impl.std(x, axis=axis, correction=correction, keepdims=keepdims)
+
+
+def sum(
+    x: Array,
+    /,
+    *,
+    axis: Optional[Union[int, Tuple[int, ...]]] = None,
+    dtype: Optional[DType] = None,
+    keepdims: bool = False,
+) -> Array:
+    """Calculates the sum of the input array ``x``.
+
+    **Special Cases**
+    Let ``N`` equal the number of elements over which to compute the sum.
+
+    - If ``N`` is ``0``, the sum is ``0`` (i.e., the empty sum).
+
+    For floating-point operands,
+
+    - If ``x_i`` is ``NaN``, the sum is ``NaN`` (i.e., ``NaN`` values propagate).
+
+    Parameters
+    ----------
+    x
+        input array. Should have a numeric data type.
+    axis
+        axis or axes along which sums must be computed. By default, the sum must be
+        computed over the entire array. If a tuple of integers, sums must be computed
+        over multiple axes. Default: ``None``.
+    dtype
+        data type of the returned array.
+    keepdims
+        if ``True``, the reduced axes (dimensions) must be included in the result as
+        singleton dimensions, and, accordingly, the result must be compatible with the
+        input array (see broadcasting). Otherwise, if ``False``, the reduced
+        axes (dimensions) must not be included in the result. Default: ``False``.
+
+    Returns
+    -------
+    out
+        if the sum was computed over the entire array, a zero-dimensional array
+        containing the sum; otherwise, an array containing the sums. The returned
+        array must have a data type as described by the ``dtype`` parameter above.
+    """
+    return _impl.sum(x, axis=axis, dtype=dtype, keepdims=keepdims)
+
+
+def var(
+    x: Array,
+    /,
+    *,
+    axis: Optional[Union[int, Tuple[int, ...]]] = None,
+    correction: Union[int, float] = 0.0,
+    keepdims: bool = False,
+) -> Array:
+    """Calculates the variance of the input array ``x``.
+
+    **Special Cases**
+    Let ``N`` equal the number of elements over which to compute the variance.
+
+    - If ``N - correction`` is less than or equal to ``0``, the variance is ``NaN``.
+    - If ``x_i`` is ``NaN``, the variance is ``NaN`` (i.e., ``NaN`` values propagate).
+
+    Parameters
+    ----------
+    x
+        input array. Should have a floating-point data type.
+    axis
+        axis or axes along which variances must be computed. By default, the variance
+        must be computed over the entire array. If a tuple of integers, variances must
+        be computed over multiple axes. Default: ``None``.
+    correction
+        degrees of freedom adjustment. Setting this parameter to a value other than
+        ``0`` has the effect of adjusting the divisor during the calculation of the
+        variance according to ``N-c`` where ``N`` corresponds to the total number of
+        elements over which the variance is computed and ``c`` corresponds to the
+        provided degrees of freedom adjustment. When computing the variance of a
+        population, setting this parameter to ``0`` is the standard choice (i.e., the
+        provided array contains data constituting an entire population). When computing
+        the unbiased sample variance, setting this parameter to ``1`` is the standard
+        choice (i.e., the provided array contains data sampled from a larger population;
+        this is commonly referred to as Bessel's correction). Default: ``0``.
+    keepdims
+        if ``True``, the reduced axes (dimensions) must be included in the result as
+        singleton dimensions, and, accordingly, the result must be compatible with the
+        input array (see broadcasting). Otherwise, if ``False``, the reduced
+        axes (dimensions) must not be included in the result. Default: ``False``.
+
+    Returns
+    -------
+    out
+        if the variance was computed over the entire array, a zero-dimensional array
+        containing the variance; otherwise, a non-zero-dimensional array containing the
+        variances. The returned array must have the same data type as ``x``.
+    """
+    return _impl.var(x, axis=axis, correction=correction, keepdims=keepdims)
diff --git a/src/probnum/backend/_statistical_functions/_jax.py b/src/probnum/backend/_statistical_functions/_jax.py
new file mode 100644
index 000000000..31aa8f616
--- /dev/null
+++ b/src/probnum/backend/_statistical_functions/_jax.py
@@ -0,0 +1,52 @@
+"""Statistical functions implemented in JAX."""
+
+from typing import Optional, Tuple, Union
+
+try:
+    import jax
+    import jax.numpy as jnp
+    from jax.numpy import mean, prod, sum  # pylint: disable=unused-import
+except ModuleNotFoundError:
+    pass
+
+
+def max(
+    x: "jax.Array",
+    /,
+    *,
+    axis: Optional[Union[int, Tuple[int, ...]]] = None,
+    keepdims: bool = False,
+) -> "jax.Array":
+    return jnp.amax(x, axis=axis, keepdims=keepdims)
+
+
+def min(
+    x: "jax.Array",
+    /,
+    *,
+    axis: Optional[Union[int, Tuple[int, ...]]] = None,
+    keepdims: bool = False,
+) -> "jax.Array":
+    return jnp.amin(x, axis=axis, keepdims=keepdims)
+
+
+def std(
+    x: "jax.Array",
+    /,
+    *,
+    axis: Optional[Union[int, Tuple[int, ...]]] = None,
+    correction: Union[int, float] = 0.0,
+    keepdims: bool = False,
+) -> "jax.Array":
+    return jnp.std(x, axis=axis, ddof=correction, keepdims=keepdims)
+
+
+def var(
+    x: "jax.Array",
+    /,
+    *,
+    axis: Optional[Union[int, Tuple[int, ...]]] = None,
+    correction: Union[int, float] = 0.0,
+    keepdims: bool = False,
+) -> "jax.Array":
+    return jnp.var(x, axis=axis, ddof=correction, keepdims=keepdims)
diff --git a/src/probnum/backend/_statistical_functions/_numpy.py b/src/probnum/backend/_statistical_functions/_numpy.py
new file mode 100644
index 000000000..51ada858e
--- /dev/null
+++ b/src/probnum/backend/_statistical_functions/_numpy.py
@@ -0,0 +1,78 @@
+"""Statistical functions implemented in NumPy."""
+from typing import Optional, Tuple, Union
+
+import numpy as np
+
+
+def max(
+    x: np.ndarray,
+    /,
+    *,
+    axis: Optional[Union[int, Tuple[int, ...]]] = None,
+    keepdims: bool = False,
+) -> np.ndarray:
+    return np.asarray(np.amax(x, axis=axis, keepdims=keepdims))
+
+
+def min(
+    x: np.ndarray,
+    /,
+    *,
+    axis: Optional[Union[int, Tuple[int, ...]]] = None,
+    keepdims: bool = False,
+) -> np.ndarray:
+    return np.asarray(np.amin(x, axis=axis, keepdims=keepdims))
+
+
+def mean(
+    x: np.ndarray,
+    /,
+    *,
+    axis: Optional[Union[int, Tuple[int, ...]]] = None,
+    keepdims: bool = False,
+) -> np.ndarray:
+    return np.asarray(np.mean(x, axis=axis, keepdims=keepdims))
+
+
+def prod(
+    x: np.ndarray,
+    /,
+    *,
+    axis: Optional[Union[int, Tuple[int, ...]]] = None,
+    dtype: Optional[np.dtype] = None,
+    keepdims: bool = False,
+) -> np.ndarray:
+    return np.asarray(np.prod(x, axis=axis, dtype=dtype, keepdims=keepdims))
+
+
+def sum(
+    x: np.ndarray,
+    /,
+    *,
+    axis: Optional[Union[int, Tuple[int, ...]]] = None,
+    dtype: Optional[np.dtype] = None,
+    keepdims: bool = False,
+) -> np.ndarray:
+    return np.asarray(np.sum(x, axis=axis, dtype=dtype, keepdims=keepdims))
+
+
+def std(
+    x: np.ndarray,
+    /,
+    *,
+    axis: Optional[Union[int, Tuple[int, ...]]] = None,
+    correction: Union[int, float] = 0.0,
+    keepdims: bool = False,
+) -> np.ndarray:
+    return np.asarray(np.std(x, axis=axis, ddof=correction, keepdims=keepdims))
+
+
+def var(
+    x: np.ndarray,
+    /,
+    *,
+    axis: Optional[Union[int, Tuple[int, ...]]] = None,
+    correction: Union[int, float] = 0.0,
+    keepdims: bool = False,
+) -> np.ndarray:
+    return np.asarray(np.var(x, axis=axis, ddof=correction, keepdims=keepdims))
diff --git a/src/probnum/backend/_statistical_functions/_torch.py b/src/probnum/backend/_statistical_functions/_torch.py
new file mode 100644
index 000000000..1c43c9845
--- /dev/null
+++ b/src/probnum/backend/_statistical_functions/_torch.py
@@ -0,0 +1,99 @@
+"""Statistical functions implemented in PyTorch."""
+
+from typing import Optional, Tuple, Union
+
+try:
+    import torch
+except ModuleNotFoundError:
+    pass
+
+
+def max(
+    x: "torch.Tensor",
+    /,
+    *,
+    axis: Optional[Union[int, Tuple[int, ...]]] = None,
+    keepdims: bool = False,
+) -> "torch.Tensor":
+    # ``torch.max`` with a ``dim`` argument returns a ``(values, indices)`` pair and
+    # rejects ``dim=None``; reduce via ``torch.amax`` instead.
+    if axis is None:
+        axis = tuple(range(x.ndim))
+    return torch.amax(x, dim=axis, keepdim=keepdims)
+
+
+def min(
+    x: "torch.Tensor",
+    /,
+    *,
+    axis: Optional[Union[int, Tuple[int, ...]]] = None,
+    keepdims: bool = False,
+) -> "torch.Tensor":
+    # Analogous to ``max`` above: use ``torch.amin`` to obtain values only.
+    if axis is None:
+        axis = tuple(range(x.ndim))
+    return torch.amin(x, dim=axis, keepdim=keepdims)
+
+
+def mean(
+    x: "torch.Tensor",
+    /,
+    *,
+    axis: Optional[Union[int, Tuple[int, ...]]] = None,
+    keepdims: bool = False,
+) -> "torch.Tensor":
+    return torch.mean(x, dim=axis, keepdim=keepdims)
+
+
+def prod(
+    x: "torch.Tensor",
+    /,
+    *,
+    axis: Optional[Union[int, Tuple[int, ...]]] = None,
+    dtype: Optional["torch.dtype"] = None,
+    keepdims: bool = False,
+) -> "torch.Tensor":
+    return torch.prod(x, dim=axis, dtype=dtype, keepdim=keepdims)
+
+
+def sum(
+    x: "torch.Tensor",
+    /,
+    *,
+    axis: Optional[Union[int, Tuple[int, ...]]] = None,
+    dtype: Optional["torch.dtype"] = None,
+    keepdims: bool = False,
+) -> "torch.Tensor":
+    return torch.sum(x, dim=axis, dtype=dtype, keepdim=keepdims)
+
+
+def std(
+    x: "torch.Tensor",
+    /,
+    *,
+    axis: Optional[Union[int, Tuple[int, ...]]] = None,
+    correction: Union[int, float] = 0.0,
+    keepdims: bool = False,
+) -> "torch.Tensor":
+    if correction == 0.0:
+        return torch.std(x, dim=axis, unbiased=False, keepdim=keepdims)
+    elif correction == 1.0:
+        return torch.std(x, dim=axis, unbiased=True, keepdim=keepdims)
+    else:
+        raise NotImplementedError("Only correction=0 or =1 implemented.")
+
+
+def var(
+    x: "torch.Tensor",
+    /,
+    *,
+    axis: Optional[Union[int, Tuple[int, ...]]] = None,
+    correction: Union[int, float] = 0.0,
+    keepdims: bool = False,
+) -> "torch.Tensor":
+    if correction == 0.0:
+        return torch.var(x, dim=axis, unbiased=False, keepdim=keepdims)
+    elif correction == 1.0:
+        return torch.var(x, dim=axis, unbiased=True, keepdim=keepdims)
+    else:
+        raise NotImplementedError("Only correction=0 or =1 implemented.")
diff --git a/src/probnum/backend/_vectorization/__init__.py b/src/probnum/backend/_vectorization/__init__.py
new file mode 100644
index 000000000..3b3eaf1ad
--- /dev/null
+++ b/src/probnum/backend/_vectorization/__init__.py
@@ -0,0 +1,98 @@
+"""Vectorization of functions."""
+from typing import AbstractSet, Any, Callable, Optional, Sequence, Union
+
+from ..._select_backend import BACKEND, Backend
+
+if BACKEND is Backend.NUMPY:
+    from . import _numpy as _impl
+elif BACKEND is Backend.JAX:
+    from . import _jax as _impl
+elif BACKEND is Backend.TORCH:
+    from . import _torch as _impl
+
+
+__all__ = [
+    "vectorize",
+    "vmap",
+]
+__all__.sort()
+
+
+def vectorize(
+    fun: Callable,
+    /,
+    *,
+    excluded: Optional[AbstractSet[Union[int, str]]] = None,
+    signature: Optional[str] = None,
+) -> Callable:
+    """Vectorizing map, which creates a function which maps ``fun`` over array elements.
+
+    Define a vectorized function which takes a nested sequence of arrays as inputs
+    and returns a single array or a tuple of arrays. The vectorized function
+    evaluates ``fun`` over successive tuples of the input arrays like the Python
+    ``map`` function, except it uses broadcasting rules.
+
+    .. note::
+        The :func:`~probnum.backend.vectorize` function is primarily provided for
+        convenience, not for performance. The implementation is essentially a for loop.
+
+    Parameters
+    ----------
+    fun
+        Function to be mapped.
+    excluded
+        Set of strings or integers representing the positional or keyword arguments for
+        which the function will not be vectorized. These will be passed directly to
+        ``fun`` unmodified.
+    signature
+        Generalized universal function signature, e.g., ``(m,n),(n)->(m)`` for
+        vectorized matrix-vector multiplication. If provided, ``fun`` will be called
+        with (and expected to return) arrays with shapes given by the size of
+        corresponding core dimensions. By default, ``fun`` is assumed to take scalars as
+        input and output.
+    """
+    return _impl.vectorize(fun, excluded=excluded, signature=signature)
+
+
+def vmap(
+    fun: Callable,
+    /,
+    in_axes: Union[int, Sequence[Any]] = 0,
+    out_axes: Union[int, Sequence[Any]] = 0,
+) -> Callable:
+    """Vectorizing map, which creates a function which maps ``fun`` over argument axes.
+
+    Parameters
+    ----------
+    fun
+        Function to be mapped over additional axes.
+    in_axes
+        Input array axes to map over.
+
+        If each positional argument to ``fun`` is an array, then ``in_axes`` can
+        be an integer, ``None``, or a tuple of integers and ``None`` values with length
+        equal to the number of positional arguments to ``fun``. An integer or ``None``
+        indicates which array axis to map over for all arguments (with ``None``
+        indicating not to map any axis), and a tuple indicates which axis to map
+        for each corresponding positional argument. Axis integers must be in the
+        range ``[-ndim, ndim)`` for each array, where ``ndim`` is the number of
+        axes of the corresponding input array.
+    out_axes
+        Where the mapped axis should appear in the output.
+
+        All outputs with a mapped axis must have a non-None
+        ``out_axes`` specification.
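+        For instance, ``out_axes=1`` places the mapped axis at position 1 of each
+        output.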
Axis integers must be in the range ``[-ndim, + ndim)`` for each output array, where ``ndim`` is the number of dimensions + (axes) of the array returned by the :func:`vmap`-ed function, which is one + more than the number of dimensions (axes) of the corresponding array + returned by ``fun``. + + Returns + ------- + vfun + Batched/vectorized version of ``fun`` with arguments that correspond to + those of ``fun``, but with extra array axes at positions indicated by + ``in_axes``, and a return value that corresponds to that of ``fun``, but + with extra array axes at positions indicated by ``out_axes``. + """ + return _impl.vmap(fun, in_axes, out_axes) diff --git a/src/probnum/backend/_vectorization/_jax.py b/src/probnum/backend/_vectorization/_jax.py new file mode 100644 index 000000000..eb7bd9292 --- /dev/null +++ b/src/probnum/backend/_vectorization/_jax.py @@ -0,0 +1,23 @@ +"""Vectorization in JAX.""" + +from typing import AbstractSet, Callable, Optional, Union + +try: + from jax import vmap # pylint: disable=unused-import + import jax.numpy as jnp +except ModuleNotFoundError: + pass + + +def vectorize( + fun: Callable, + /, + *, + excluded: Optional[AbstractSet[Union[int, str]]] = None, + signature: Optional[str] = None, +) -> Callable: + return jnp.vectorize( + fun, + excluded=excluded if excluded is not None else set(), + signature=signature, + ) diff --git a/src/probnum/backend/_vectorization/_numpy.py b/src/probnum/backend/_vectorization/_numpy.py new file mode 100644 index 000000000..2fd07791d --- /dev/null +++ b/src/probnum/backend/_vectorization/_numpy.py @@ -0,0 +1,13 @@ +"""Vectorization in NumPy.""" + +from typing import Any, Callable, Sequence, Union + +from numpy import vectorize # pylint: disable=redefined-builtin, unused-import + + +def vmap( + fun: Callable, + in_axes: Union[int, Sequence[Any]] = 0, + out_axes: Union[int, Sequence[Any]] = 0, +) -> Callable: + raise NotImplementedError diff --git a/src/probnum/backend/_vectorization/_torch.py b/src/probnum/backend/_vectorization/_torch.py new file mode 100644 index 000000000..2f0e7add2 --- /dev/null +++ b/src/probnum/backend/_vectorization/_torch.py @@ -0,0 +1,25 @@ +"""Vectorization in PyTorch.""" +from typing import AbstractSet, Any, Callable, Optional, Sequence, Union + +try: + import functorch +except ModuleNotFoundError: + pass + + +def vectorize( + fun: Callable, + /, + *, + excluded: Optional[AbstractSet[Union[int, str]]] = None, + signature: Optional[str] = None, +) -> Callable: + raise NotImplementedError() + + +def vmap( + fun: Callable, + in_axes: Union[int, Sequence[Any]] = 0, + out_axes: Union[int, Sequence[Any]] = 0, +) -> Callable: + return functorch.vmap(fun, in_dims=in_axes, out_dims=out_axes) diff --git a/src/probnum/backend/autodiff/__init__.py b/src/probnum/backend/autodiff/__init__.py new file mode 100644 index 000000000..cab1ccde1 --- /dev/null +++ b/src/probnum/backend/autodiff/__init__.py @@ -0,0 +1,197 @@ +"""(Automatic) Differentiation.""" + +from typing import Any, Callable, Sequence, Tuple, Union + +from ..._select_backend import BACKEND, Backend + +if BACKEND is Backend.NUMPY: + from . import _numpy as _impl +elif BACKEND is Backend.JAX: + from . import _jax as _impl +elif BACKEND is Backend.TORCH: + from . 
import _torch as _impl
+
+
+__all__ = [
+    "grad",
+    "hessian",
+    "jacfwd",
+    "jacrev",
+    "value_and_grad",
+]
+__all__.sort()
+
+
+def grad(
+    fun: Callable,
+    argnums: Union[int, Sequence[int]] = 0,
+    *,
+    has_aux: bool = False,
+) -> Callable:
+    """Creates a function that evaluates the gradient of ``fun``.
+
+    Parameters
+    ----------
+    fun
+        Function to be differentiated. Its arguments at positions specified by
+        ``argnums`` should be arrays, scalars, or standard Python containers.
+        Argument arrays in the positions specified by ``argnums`` must be of
+        inexact (i.e., floating-point or complex) type. It
+        should return a scalar (which includes arrays with shape ``()`` but not
+        arrays with shape ``(1,)`` etc.).
+    argnums
+        Specifies which positional argument(s) to differentiate with respect to.
+    has_aux
+        Indicates whether ``fun`` returns a pair where the first element is considered
+        the output of the mathematical function to be differentiated and the second
+        element is auxiliary data.
+
+    Returns
+    -------
+    grad_fun
+        A function with the same arguments as ``fun``, that evaluates the gradient
+        of ``fun``. If ``argnums`` is an integer then the gradient has the same
+        shape and type as the positional argument indicated by that integer. If
+        ``argnums`` is a tuple of integers, the gradient is a tuple of values with the
+        same shapes and types as the corresponding arguments.
+
+    Examples
+    --------
+    >>> from probnum import backend
+    >>> from probnum.backend.autodiff import grad
+    >>> grad_sin = grad(backend.sin)
+    >>> grad_sin(backend.pi)
+    -1.0
+    """
+    return _impl.grad(fun=fun, argnums=argnums, has_aux=has_aux)
+
+
+def hessian(
+    fun: Callable,
+    argnums: Union[int, Sequence[int]] = 0,
+    *,
+    has_aux: bool = False,
+) -> Callable:
+    """Hessian of ``fun`` as a dense array.
+
+    Parameters
+    ----------
+    fun
+        Function whose Hessian is to be computed. Its arguments at positions
+        specified by ``argnums`` should be arrays, scalars, or standard Python
+        containers thereof. It should return arrays, scalars, or standard Python
+        containers thereof.
+    argnums
+        Specifies which positional argument(s) to differentiate with respect to.
+    has_aux
+        Indicates whether ``fun`` returns a pair where the
+        first element is considered the output of the mathematical function to be
+        differentiated and the second element is auxiliary data.
+
+    Returns
+    -------
+    hessian
+        A function with the same arguments as ``fun``, that evaluates the Hessian of
+        ``fun``.
+
+    Examples
+    --------
+    >>> from probnum import backend
+    >>> from probnum.backend.autodiff import hessian
+    >>> g = lambda x: x[0]**3 - 2*x[0]*x[1] - x[1]**6
+    >>> hessian(g)(backend.asarray([1., 2.]))
+    [[   6.   -2.]
+     [  -2. -480.]]
+    """
+    return _impl.hessian(fun=fun, argnums=argnums, has_aux=has_aux)
+
+
+def jacfwd(
+    fun: Callable,
+    argnums: Union[int, Sequence[int]] = 0,
+    *,
+    has_aux: bool = False,
+) -> Callable:
+    """Jacobian of ``fun`` evaluated column-by-column using forward-mode AD.
+
+    Parameters
+    ----------
+    fun
+        Function whose Jacobian is to be computed.
+    argnums
+        Specifies which positional argument(s) to differentiate with respect to.
+    has_aux
+        Indicates whether ``fun`` returns a pair where the
+        first element is considered the output of the mathematical function to be
+        differentiated and the second element is auxiliary data.
+
+    Returns
+    -------
+    jacfun
+        A function with the same arguments as ``fun``, that evaluates the Jacobian of
+        ``fun`` using forward-mode automatic differentiation. If ``has_aux`` is ``True``
+        then a pair of (jacobian, auxiliary_data) is returned.
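+
+    Examples
+    --------
+    A minimal sketch; it assumes a backend with forward-mode support (e.g. JAX):
+
+    >>> from probnum import backend
+    >>> from probnum.backend.autodiff import jacfwd
+    >>> f = lambda x: x**3
+    >>> J = jacfwd(f)(backend.asarray([1.0, 2.0]))
+    >>> J.shape
+    (2, 2)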
+ """ + return _impl.jacfwd(fun, argnums, has_aux=has_aux) + + +def jacrev( + fun: Callable, + argnums: Union[int, Sequence[int]] = 0, + *, + has_aux: bool = False, +) -> Callable: + """Jacobian of ``fun`` evaluated row-by-row using reverse-mode AD. + + Parameters + ---------- + fun + Function whose Jacobian is to be computed. + argnums + Specifies which positional argument(s) to differentiate with respect to. + has_aux + Indicates whether ``fun`` returns a pair where the + first element is considered the output of the mathematical function to be + differentiated and the second element is auxiliary data. + + Returns + ------- + jacfun + A function with the same arguments as ``fun``, that evaluates the Jacobian of + ``fun`` using reverse-mode automatic differentiation. If ``has_aux`` is True + then a pair of (jacobian, auxiliary_data) is returned. + """ + return _impl.jacrev(fun, argnums, has_aux=has_aux) + + +def value_and_grad( + fun: Callable, + argnums: Union[int, Sequence[int]] = 0, + *, + has_aux: bool = False, +) -> Callable[..., Tuple[Any, Any]]: + """Create a function that efficiently evaluates both ``fun`` and the gradient of + ``fun``. + + Parameters + ---------- + fun + Function to be differentiated. Its arguments at positions specified by + ``argnums`` should be arrays, scalars, or standard Python containers. It should + return a scalar (which includes arrays with shape ``()`` but not arrays with + shape ``(1,)`` etc.) + argnums + Specifies which positional argument(s) to differentiate with respect to. + has_aux + Indicates whether ``fun`` returns a pair where the first element is considered + the output of the mathematical function to be differentiated and the second + element is auxiliary data. + + Returns + ------- + value_and_grad + A function with the same arguments as ``fun`` that evaluates both ``fun`` and + the gradient of ``fun`` and returns them as a pair (a two-element tuple). If + ``argnums`` is an integer then the gradient has the same shape and type as the + positional argument indicated by that integer. If ``argnums`` is a sequence of + integers, the gradient is a tuple of values with the same shapes and types as + the corresponding arguments. If ``has_aux`` is ``True`` then a tuple of + ``((value, auxiliary_data), gradient)`` is returned. 
+ """ + return _impl.value_and_grad(fun, argnums, has_aux=has_aux) diff --git a/src/probnum/backend/autodiff/_jax.py b/src/probnum/backend/autodiff/_jax.py new file mode 100644 index 000000000..3de2b5ecd --- /dev/null +++ b/src/probnum/backend/autodiff/_jax.py @@ -0,0 +1,11 @@ +"""(Automatic) Differentiation in JAX.""" +try: + from jax import ( # pylint: disable=unused-import + grad, + hessian, + jacfwd, + jacrev, + value_and_grad, + ) +except ModuleNotFoundError: + pass diff --git a/src/probnum/backend/autodiff/_numpy.py b/src/probnum/backend/autodiff/_numpy.py new file mode 100644 index 000000000..f04eff2b8 --- /dev/null +++ b/src/probnum/backend/autodiff/_numpy.py @@ -0,0 +1,39 @@ +"""(Automatic) Differentiation in NumPy.""" + +from typing import Any, Callable, Sequence, Union + + +def grad( + fun: Callable, argnums: Union[int, Sequence[int]] = 0, has_aux: bool = False +) -> Callable: + raise NotImplementedError() + + +def hessian( + fun: Callable, argnums: Union[int, Sequence[int]] = 0, has_aux: bool = False +) -> Callable: + raise NotImplementedError + + +def jacrev( + fun: Callable, + argnums: Union[int, Sequence[int]] = 0, + *, + has_aux: bool = False, +) -> Callable: + raise NotImplementedError + + +def jacfwd( + fun: Callable, + argnums: Union[int, Sequence[int]] = 0, + *, + has_aux: bool = False, +) -> Callable: + raise NotImplementedError + + +def value_and_grad( + fun: Callable, argnums: Union[int, Sequence[int]] = 0, has_aux: bool = False +) -> Callable: + raise NotImplementedError() diff --git a/src/probnum/backend/autodiff/_torch.py b/src/probnum/backend/autodiff/_torch.py new file mode 100644 index 000000000..b9973f8b2 --- /dev/null +++ b/src/probnum/backend/autodiff/_torch.py @@ -0,0 +1,52 @@ +"""(Automatic) Differentiation in PyTorch.""" + +from typing import Callable, Sequence, Union + +try: + import functorch +except ModuleNotFoundError: + pass + + +def grad( + fun: Callable, argnums: Union[int, Sequence[int]] = 0, has_aux: bool = False +) -> Callable: + return functorch.grad(fun, argnums, has_aux=has_aux) + + +def hessian( + fun: Callable, argnums: Union[int, Sequence[int]] = 0, has_aux: bool = False +) -> Callable: + return functorch.jacfwd( + functorch.jacrev(fun, argnums, has_aux=has_aux), argnums, has_aux=has_aux + ) + + +def jacrev( + fun: Callable, + argnums: Union[int, Sequence[int]] = 0, + *, + has_aux: bool = False, +) -> Callable: + return functorch.jacrev(fun, argnums, has_aux=has_aux) + + +def jacfwd( + fun: Callable, + argnums: Union[int, Sequence[int]] = 0, + *, + has_aux: bool = False, +) -> Callable: + return functorch.jacfwd(fun, argnums, has_aux=has_aux) + + +def value_and_grad( + fun: Callable, argnums: Union[int, Sequence[int]] = 0, has_aux: bool = False +) -> Callable: + gfun_fun = functorch.grad_and_value(fun, argnums, has_aux=has_aux) + + def fun_gradfun(x): + g, f = gfun_fun(x) + return f, g + + return fun_gradfun diff --git a/src/probnum/backend/linalg/__init__.py b/src/probnum/backend/linalg/__init__.py new file mode 100644 index 000000000..45bbf4d85 --- /dev/null +++ b/src/probnum/backend/linalg/__init__.py @@ -0,0 +1,893 @@ +"""Linear algebra.""" +import collections +from typing import Literal, Optional, Tuple, Union + +from probnum.backend.typing import ShapeLike + +from .. import Array, asshape +from ..._select_backend import BACKEND, Backend + +if BACKEND is Backend.NUMPY: + from . import _numpy as _impl +elif BACKEND is Backend.JAX: + from . import _jax as _impl +elif BACKEND is Backend.TORCH: + from . 
import _torch as _impl
+
+from ._cholesky_updates import cholesky_update, tril_to_positive_tril
+from ._inner_product import induced_vector_norm, inner_product
+from ._orthogonalize import gram_schmidt, gram_schmidt_double, gram_schmidt_modified
+
+__all__ = [
+    "cholesky",
+    "cholesky_update",
+    "det",
+    "diagonal",
+    "eigh",
+    "eigvalsh",
+    "einsum",
+    "gram_schmidt",
+    "gram_schmidt_double",
+    "gram_schmidt_modified",
+    "induced_vector_norm",
+    "inner_product",
+    "inv",
+    "kron",
+    "matmul",
+    "matrix_norm",
+    "matrix_rank",
+    "matrix_power",
+    "matrix_transpose",
+    "outer",
+    "pinv",
+    "qr",
+    "slogdet",
+    "solve",
+    "solve_cholesky",
+    "solve_triangular",
+    "svd",
+    "svdvals",
+    "tensordot",
+    "trace",
+    "tril_to_positive_tril",
+    "vecdot",
+    "vector_norm",
+]
+__all__.sort()
+
+
+def cholesky(x: Array, /, *, upper: bool = False) -> Array:
+    r"""Returns the lower (upper) Cholesky decomposition of a complex Hermitian or real
+    symmetric positive-definite matrix (or stack of matrices) ``x``.
+
+    If ``x`` is real-valued, let :math:`\mathbb{K}` be the set of real numbers
+    :math:`\mathbb{R}`, and, if ``x`` is complex-valued, let :math:`\mathbb{K}` be the
+    set of complex numbers :math:`\mathbb{C}`.
+
+    The lower Cholesky decomposition of a complex Hermitian or real symmetric
+    positive-definite matrix :math:`x \in \mathbb{K}^{n \times n}` is defined as
+
+    .. math::
+
+        x = LL^{H} \qquad L \in \mathbb{K}^{n \times n}
+
+    where :math:`L` is a lower triangular matrix and :math:`L^{H}` is the conjugate
+    transpose when :math:`L` is complex-valued and the transpose when :math:`L` is
+    real-valued.
+
+    The upper Cholesky decomposition is defined similarly
+
+    .. math::
+
+        x = UU^{H} \qquad U \in \mathbb{K}^{n \times n}
+
+    where :math:`U` is an upper triangular matrix.
+
+    Parameters
+    ----------
+    x
+        Input array having shape ``(..., M, M)`` and whose innermost two dimensions
+        form square complex Hermitian or real symmetric positive-definite matrices.
+    upper
+        If ``True``, the result will be the upper-triangular Cholesky factor :math:`U`.
+        If ``False``, the result will be the lower-triangular Cholesky factor :math:`L`.
+
+    Returns
+    -------
+    out
+        An array containing the Cholesky factors for each square matrix.
+    """
+    return _impl.cholesky(x, upper=upper)
+
+
+def det(x: Array, /) -> Array:
+    """Returns the determinant of a square matrix (or a stack of square matrices) ``x``.
+
+    Parameters
+    ----------
+    x
+        Input array having shape ``(..., M, M)`` and whose innermost two dimensions form
+        square matrices.
+
+    Returns
+    -------
+    out
+        If ``x`` is a two-dimensional array, a zero-dimensional array containing the
+        determinant; otherwise, a non-zero-dimensional array containing the determinant
+        for each square matrix.
+    """
+    return _impl.det(x)
+
+
+def inv(x: Array, /) -> Array:
+    """Returns the multiplicative inverse of a square matrix (or a stack of square
+    matrices).
+
+    Parameters
+    ----------
+    x
+        Input array having shape ``(..., M, M)`` and whose innermost two dimensions form
+        square matrices.
+
+    Returns
+    -------
+    out
+        An array containing the multiplicative inverses.
+    """
+    return _impl.inv(x)
+
+
+def outer(x1: Array, x2: Array, /) -> Array:
+    """Returns the outer product of two vectors ``x1`` and ``x2``.
+
+    Parameters
+    ----------
+    x1
+        First one-dimensional input array of size ``N``.
+    x2
+        Second one-dimensional input array of size ``M``.
+
+    Returns
+    -------
+    out
+        A two-dimensional array containing the outer product and whose shape is
+        ``(N, M)``.
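+
+    Examples
+    --------
+    A minimal usage sketch (printed output assumes the NumPy backend):
+
+    >>> from probnum import backend
+    >>> backend.linalg.outer(backend.asarray([1.0, 2.0]), backend.asarray([3.0, 4.0]))
+    array([[3., 4.],
+           [6., 8.]])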
+ """ + return _impl.outer(x1, x2) + + +def pinv(x: Array, /, *, rtol: Optional[Union[float, Array]] = None) -> Array: + """Returns the (Moore-Penrose) pseudo-inverse of a matrix (or a stack of matrices). + + Parameters + ---------- + x + Input array having shape ``(..., M, N)`` and whose innermost two dimensions form + ``MxN`` matrices. + rtol + Relative tolerance for small singular values. Singular values approximately less + than or equal to ``rtol * largest_singular_value`` are set to zero. + + Returns + ------- + out + An array containing the pseudo-inverses. + """ + return _impl.pinv(x, rtol=rtol) + + +def matrix_power(x: Array, n: int, /) -> Array: + """Raises a square matrix (or a stack of square matrices) ``x`` to an integer power + ``n``. + + Parameters + ---------- + x + Input array having shape ``(..., M, M)`` and whose innermost two dimensions form square matrices. + n + Integer exponent. + + Returns + ------- + out + If ``n`` is equal to zero, an array containing the identity matrix for each square matrix. If ``n`` is less than zero, an array containing the inverse of each square matrix raised to the absolute value of ``n``, provided that each square matrix is invertible. If ``n`` is greater than zero, an array containing the result of raising each square matrix to the power ``n``. + """ + return _impl.matrix_power(x, n) + + +def matrix_rank(x: Array, /, *, rtol: Optional[Union[float, Array]] = None) -> Array: + """Returns the rank (i.e., number of non-zero singular values) of a matrix (or a + stack of matrices). + + Parameters + ---------- + x + Input array having shape ``(..., M, N)`` and whose innermost two dimensions form + ``MxN`` matrices. + rtol + Relative tolerance for small singular values. Singular values approximately less + than or equal to ``rtol * largest_singular_value`` are set to zero. + + Returns + ------- + out + An array containing the ranks. + """ + return _impl.matrix_rank(x, rtol=rtol) + + +def matrix_transpose(x: Array, /) -> Array: + """Transposes a matrix (or a stack of matrices) ``x``. + + Parameters + ---------- + x + Input array having shape ``(..., M, N)`` and whose innermost two dimensions form ``M x N`` matrices. + + Returns + ------- + out + An array containing the transpose for each matrix and having shape + ``(..., N, M)``. The returned array must have the same data type as ``x``. + """ + return _impl.matrix_transpose(x) + + +Slogdet = collections.namedtuple("Slogdet", ["sign", "logabsdet"]) + + +def slogdet(x: Array, /) -> Tuple[Array, Array]: + """Returns the sign and the natural logarithm of the absolute value of the + determinant of a square matrix (or a stack of square matrices). + + .. note:: + The purpose of this function is to calculate the determinant more accurately when the determinant is either very small or very large, as calling ``det`` may overflow or underflow. + + Parameters + ---------- + x + Input array having shape ``(..., M, M)`` and whose innermost two dimensions form + square matrices. + + Returns + ------- + out + A namedtuple (``sign``, ``logabsdet``) whose + + - first element ``sign`` is an array representing the sign of the determinant + for each square matrix. + - second element ``logabsdet`` is an array containing the determinant for each + square matrix. + """ + sign, logabsdet = _impl.slogdet(x) + return Slogdet(sign, logabsdet) + + +def trace(x: Array, /, *, offset: int = 0) -> Array: + """Returns the sum along the specified diagonals of a matrix (or a stack of + matrices). 
+
+    Parameters
+    ----------
+    x
+        Input array having shape ``(..., M, N)`` and whose innermost two dimensions form
+        ``MxN`` matrices.
+    offset
+        offset specifying the off-diagonal relative to the main diagonal.
+
+        - ``offset = 0``: the main diagonal.
+        - ``offset > 0``: off-diagonal above the main diagonal.
+        - ``offset < 0``: off-diagonal below the main diagonal.
+
+    Returns
+    -------
+    out
+        An array containing the traces and whose shape is determined by removing the
+        last two dimensions and storing the traces in the last array dimension.
+    """
+    return _impl.trace(x, offset=offset)
+
+
+def einsum(
+    *arrays: Array,
+    optimization: Optional[str] = "greedy",
+):
+    """Evaluates the Einstein summation convention on the given ``arrays``.
+
+    Using the Einstein summation convention, many common multi-dimensional, linear
+    algebraic array operations can be represented in a simple fashion.
+
+    Parameters
+    ----------
+    arrays
+        Arrays to use for the operation.
+    optimization
+        Controls what kind of intermediate optimization of the contraction path should
+        occur. Options are:
+
+        +---------------+--------------------------------------------------------+
+        | ``None``      | No optimization will be done.                          |
+        +---------------+--------------------------------------------------------+
+        | ``"optimal"`` | Exhaustively search all possible paths.                |
+        +---------------+--------------------------------------------------------+
+        | ``"greedy"``  | Find a path one step at a time using a cost heuristic. |
+        +---------------+--------------------------------------------------------+
+
+    Returns
+    -------
+    out
+        The calculation based on the Einstein summation convention.
+    """
+    return _impl.einsum(*arrays, optimize=optimization)
+
+
+def vector_norm(
+    x: Array,
+    /,
+    *,
+    axis: Optional[Union[int, Tuple[int, ...]]] = None,
+    keepdims: bool = False,
+    ord: Union[int, float, Literal["inf", "-inf"]] = 2,
+) -> Array:
+    """Computes the vector norm of a vector (or batch of vectors).
+
+    Parameters
+    ----------
+    x
+        Input array. Should have a floating-point data type.
+    axis
+        If an integer, ``axis`` specifies the axis (dimension) along which to compute
+        vector norms. If an n-tuple, ``axis`` specifies the axes (dimensions) along
+        which to compute batched vector norms. If ``None``, the vector norm is
+        computed over all array values (i.e., equivalent to computing the vector norm of
+        a flattened array).
+    keepdims
+        If ``True``, the axes (dimensions) specified by ``axis`` are included in the
+        result as singleton dimensions, and, accordingly, the result is compatible with
+        the input array. Otherwise, if ``False``, the axes (dimensions) specified by
+        ``axis`` are not included in the result.
+    ord
+        Order of the norm. The following mathematical norms are supported:
+
+        +--------------------+----------------------------+
+        | ord                | description                |
+        +====================+============================+
+        | `1`                | L1-norm (Manhattan)        |
+        +--------------------+----------------------------+
+        | `2`                | L2-norm (Euclidean)        |
+        +--------------------+----------------------------+
+        | `inf`              | infinity norm              |
+        +--------------------+----------------------------+
+        | `(int,float >= 1)` | p-norm                     |
+        +--------------------+----------------------------+
+
+        The following non-mathematical "norms" are supported:
+
+        +--------------------+------------------------------------+
+        | ord                | description                        |
+        +====================+====================================+
+        | `0`                | :code:`sum(a != 0)`                |
+        +--------------------+------------------------------------+
+        | `-1`               | :code:`1./sum(1./abs(a))`          |
+        +--------------------+------------------------------------+
+        | `-2`               | :code:`1./sqrt(sum(1./abs(a)**2))` |
+        +--------------------+------------------------------------+
+        | `-inf`             | :code:`min(abs(a))`                |
+        +--------------------+------------------------------------+
+        | `(int,float < 1)`  | :code:`sum(abs(a)**ord)**(1./ord)` |
+        +--------------------+------------------------------------+
+
+    Returns
+    -------
+    out
+        An array containing the vector norms. If ``axis`` is ``None``, the returned
+        array is a zero-dimensional array containing a vector norm. If ``axis`` is a
+        scalar value (``int`` or ``float``), the returned array has a rank which
+        is one less than the rank of ``x``. If ``axis`` is a ``n``-tuple, the returned
+        array has a rank which is ``n`` less than the rank of ``x``.
+    """
+    return _impl.vector_norm(x, axis=axis, keepdims=keepdims, ord=ord)
+
+
+def kron(x: Array, y: Array, /) -> Array:
+    """Kronecker product of two arrays.
+
+    Computes the Kronecker product, a composite array made of blocks of the second array
+    scaled by the first.
+
+    Parameters
+    ----------
+    x
+        First Kronecker factor.
+    y
+        Second Kronecker factor.
+    """
+    return _impl.kron(x, y)
+
+
+def matmul(x1: Array, x2: Array, /) -> Array:
+    """Computes the matrix product.
+
+    Parameters
+    ----------
+    x1
+        First input array.
+    x2
+        Second input array.
+
+    Returns
+    -------
+    out
+        Matrix product of ``x1`` and ``x2``.
+    """
+    return _impl.matmul(x1, x2)
+
+
+def matrix_norm(
+    x: Array,
+    /,
+    *,
+    keepdims: bool = False,
+    ord: Optional[Union[int, float, Literal["inf", "-inf", "fro", "nuc"]]] = "fro",
+) -> Array:
+    """Computes the matrix norm of a matrix (or a stack of matrices) ``x``.
+
+    Parameters
+    ----------
+    x
+        Input array having shape ``(..., M, N)`` and whose innermost two dimensions form
+        ``MxN`` matrices. Should have a floating-point data type.
+    keepdims
+        If ``True``, the last two axes (dimensions) are included in the result as
+        singleton dimensions, and, accordingly, the result is compatible with the
+        input array (see broadcasting). Otherwise, if ``False``, the last two
+        axes (dimensions) are not included in the result.
+    ord
+        Order of the norm. The following mathematical norms are supported:
The following mathematical norms are supported: + + +------------------+---------------------------------+ + | ord | description | + +==================+=================================+ + | 'fro' | Frobenius norm | + +------------------+---------------------------------+ + | 'nuc' | nuclear norm | + +------------------+---------------------------------+ + | 1 | max(sum(abs(x), axis=0)) | + +------------------+---------------------------------+ + | 2 | largest singular value | + +------------------+---------------------------------+ + | inf | max(sum(abs(x), axis=1)) | + +------------------+---------------------------------+ + + The following non-mathematical "norms" are supported: + + +------------------+---------------------------------+ + | ord | description | + +==================+=================================+ + | -1 | min(sum(abs(x), axis=0)) | + +------------------+---------------------------------+ + | -2 | smallest singular value | + +------------------+---------------------------------+ + | -inf | min(sum(abs(x), axis=1)) | + +------------------+---------------------------------+ + + If ``ord=1``, the norm corresponds to the induced matrix norm where ``p=1`` + (i.e., the maximum absolute value column sum). + If ``ord=2``, the norm corresponds to the induced matrix norm where ``p=inf`` + (i.e., the maximum absolute value row sum). + If ``ord=inf``, the norm corresponds to the induced matrix norm where ``p=2`` + (i.e., the largest singular value). + + Returns + ------- + out + An array containing the norms for each ``MxN`` matrix. If ``keepdims`` is + ``False``, the returned array has a rank which is two less than the + rank of ``x``. The returned array must have a floating-point data type + determined by `type-promotion `_. + """ + return _impl.matrix_norm(x, keepdims=keepdims, ord=ord) + + +def solve(A: Array, B: Array, /) -> Array: + """Returns the solution to the system of linear equations represented by the + well-determined (i.e., full rank) linear matrix equation ``AX = B``. + + .. note:: + + Whether an array library explicitly checks whether an input array is full rank is + implementation-defined. + + Parameters + ---------- + A + Coefficient array ``A`` having shape ``(..., M, M)`` and whose innermost two + dimensions form square matrices. Must be of full rank (i.e., all rows or, + equivalently, columns must be linearly independent). + B + Ordinate (or "dependent variable") array ``B``. If ``B`` has shape ``(M,)``, + ``B`` is equivalent to an array having shape ``(..., M, 1)``. If ``B`` has + shape ``(..., M, K)``, each column ``k`` defines a set of ordinate values for + which to compute a solution, and ``shape(B)[:-1]`` must be compatible with + ``shape(x1)[:-1]`` (see `broadcasting `_). + + Returns + ------- + out: + An array containing the solution to the system ``AX = B`` for each square + matrix. + """ + return _impl.solve(A, B) + + +def solve_cholesky( + C: Array, + B: Array, + /, + *, + upper: bool = False, + check_finite: bool = True, +) -> Array: + r"""Computes the solution of the system of linear equations ``A X = B`` + given the Cholesky factor ``C`` of ``A``. + + Parameters + ---------- + C + Cholesky factor(s) ``C`` having shape ``(..., M, M)`` and whose innermost two + dimensions form triangular matrices. + B + Ordinate (or "dependent variable") array ``B``. If ``B`` has shape ``(M,)``, + ``B`` is equivalent to an array having shape ``(..., M, 1)``. 
+        If ``B`` has shape ``(..., M, K)``, each column ``k`` defines a set of
+        ordinate values for which to compute a solution, and ``shape(B)[:-1]`` must
+        be compatible with ``shape(A)[:-1]`` (see `broadcasting `_).
+    upper
+        If ``True``, ``C`` is treated as an upper-triangular Cholesky factor
+        :math:`U` with :math:`A = U^\top U`. If ``False``, ``C`` is treated as a
+        lower-triangular Cholesky factor :math:`L` with :math:`A = L L^\top`.
+    check_finite
+        Whether to check that the input matrices contain only finite numbers.
+        Disabling may give a performance gain, but may result in problems (crashes,
+        non-termination) if the inputs do contain infinities or NaNs.
+
+    Returns
+    -------
+    out:
+        An array containing the solution to the system ``AX = B`` for each Cholesky
+        factor.
+    """
+    return _impl.solve_cholesky(C, B, upper=upper, check_finite=check_finite)
+
+
+def solve_triangular(
+    A: Array,
+    B: Array,
+    /,
+    *,
+    transpose: bool = False,
+    upper: bool = False,
+    unit_diagonal: bool = False,
+) -> Array:
+    r"""Computes the solution of a triangular system of linear equations ``AX = B``
+    with a unique solution.
+
+    Parameters
+    ----------
+    A
+        Coefficient array ``A`` having shape ``(..., M, M)`` and whose innermost two
+        dimensions form triangular matrices. Must be of full rank (i.e., all rows
+        or, equivalently, columns must be linearly independent).
+    B
+        Ordinate (or "dependent variable") array ``B``. If ``B`` has shape ``(M,)``,
+        ``B`` is equivalent to an array having shape ``(..., M, 1)``. If ``B`` has
+        shape ``(..., M, K)``, each column ``k`` defines a set of ordinate values for
+        which to compute a solution, and ``shape(B)[:-1]`` must be compatible with
+        ``shape(A)[:-1]`` (see `broadcasting `_).
+    transpose
+        Whether to solve the system :math:`AX=B` or the system
+        :math:`A^\top X=B`.
+    upper
+        Use only data contained in the upper triangle of ``A``.
+    unit_diagonal
+        Whether the diagonal(s) of the triangular matrices in ``A`` consist of ones.
+
+    Returns
+    -------
+    out:
+        An array containing the solution to the system ``AX = B`` for each
+        triangular matrix.
+    """
+    return _impl.solve_triangular(
+        A, B, transpose=transpose, upper=upper, unit_diagonal=unit_diagonal
+    )
+
+
+def diagonal(
+    x: Array, /, *, offset: int = 0, axis1: int = -2, axis2: int = -1
+) -> Array:
+    """Returns the specified diagonals of a matrix (or a stack of matrices) ``x``.
+
+    Parameters
+    ----------
+    x
+        Input array having shape ``(..., M, N)`` and whose innermost two dimensions
+        form ``MxN`` matrices.
+    offset
+        Offset specifying the off-diagonal relative to the main diagonal.
+
+        - ``offset = 0``: the main diagonal.
+        - ``offset > 0``: off-diagonal above the main diagonal.
+        - ``offset < 0``: off-diagonal below the main diagonal.
+    axis1
+        Axis to be used as the first axis of the 2-D sub-arrays from which the
+        diagonals should be taken.
+    axis2
+        Axis to be used as the second axis of the 2-D sub-arrays from which the
+        diagonals should be taken.
+
+    Returns
+    -------
+    out
+        An array containing the diagonals and whose shape is determined by removing
+        the last two dimensions and appending a dimension equal to the size of the
+        resulting diagonals.
+    """
+    return _impl.diagonal(x, offset, axis1, axis2)
+
+
+Eigh = collections.namedtuple("Eigh", ["eigenvalues", "eigenvectors"])
+
+
+def eigh(x: Array, /) -> Tuple[Array, Array]:
+    """
+    Returns an eigendecomposition ``x = QLQᵀ`` of a symmetric matrix (or a stack of
+    symmetric matrices) ``x``, where ``Q`` is an orthogonal matrix (or a stack of
+    matrices) and ``L`` is a vector (or a stack of vectors).
+
+    .. note::
+
+        Whether an array library explicitly checks whether an input array is a
+        symmetric matrix (or a stack of symmetric matrices) is
+        implementation-defined.
+
+    Parameters
+    ----------
+    x
+        Input array having shape ``(..., M, M)`` and whose innermost two dimensions
+        form square matrices. Must have a floating-point data type.
+
+    Returns
+    -------
+    out
+        A namedtuple (``eigenvalues``, ``eigenvectors``) whose
+
+        - first element is an array consisting of computed eigenvalues and has shape
+          ``(..., M)``.
+        - second element is an array where the columns of the innermost
+          matrices contain the computed eigenvectors. These matrices are
+          orthogonal. The array containing the eigenvectors has shape
+          ``(..., M, M)``.
+
+        Each returned array has the same floating-point data type as ``x``.
+
+    .. note::
+
+        Eigenvalue sort order is left unspecified and is thus
+        implementation-dependent.
+    """
+    eigenvalues, eigenvectors = _impl.eigh(x)
+    return Eigh(eigenvalues, eigenvectors)
+
+
+def eigvalsh(x: Array, /) -> Array:
+    """Returns the eigenvalues of a symmetric matrix (or a stack of symmetric
+    matrices).
+
+    Parameters
+    ----------
+    x
+        Input array having shape ``(..., M, M)`` and whose innermost two dimensions
+        form square matrices. Must have a real-valued floating-point data type.
+
+    Returns
+    -------
+    out
+        An array containing the computed eigenvalues. The returned array must have
+        shape ``(..., M)`` and have the same data type as ``x``.
+
+    .. note::
+
+        Eigenvalue sort order is left unspecified and is thus
+        implementation-dependent.
+    """
+    return _impl.eigvalsh(x)
+
+
+SVD = collections.namedtuple("SVD", ["U", "S", "Vh"])
+
+
+def svd(x: Array, /, *, full_matrices: bool = True) -> Union[Array, Tuple[Array, ...]]:
+    """
+    Returns a singular value decomposition ``x = USVh`` of a matrix (or a stack of
+    matrices) ``x``, where ``U`` is a matrix (or a stack of matrices) with
+    orthonormal columns, ``S`` is a vector of non-negative numbers (or a stack of
+    vectors), and ``Vh`` is a matrix (or a stack of matrices) with orthonormal rows.
+
+    Parameters
+    ----------
+    x
+        Input array having shape ``(..., M, N)`` and whose innermost two dimensions
+        form matrices on which to perform singular value decomposition. Must have a
+        floating-point data type.
+    full_matrices
+        If ``True``, compute full-sized ``U`` and ``Vh``, such that ``U`` has shape
+        ``(..., M, M)`` and ``Vh`` has shape ``(..., N, N)``. If ``False``, compute
+        only the leading ``K`` singular vectors, such that ``U`` has shape
+        ``(..., M, K)`` and ``Vh`` has shape ``(..., K, N)`` and where
+        ``K = min(M, N)``.
+
+    Returns
+    -------
+    out
+        A namedtuple ``(U, S, Vh)`` whose
+
+        - first element is an array whose shape depends on the value of
+          ``full_matrices`` and contains matrices with orthonormal columns (i.e.,
+          the columns are left singular vectors). If ``full_matrices`` is ``True``,
+          the array has shape ``(..., M, M)``. If ``full_matrices`` is ``False``,
+          the array has shape ``(..., M, K)``, where ``K = min(M, N)``. The first
+          ``x.ndim-2`` dimensions have the same shape as those of the input ``x``.
+        - second element is an array with shape ``(..., K)`` that contains the
+          vector(s) of singular values of length ``K``, where ``K = min(M, N)``. For
+          each vector, the singular values must be sorted in descending order by
+          magnitude, such that ``s[..., 0]`` is the largest value, ``s[..., 1]`` is
+          the second largest value, et cetera. The first ``x.ndim-2`` dimensions
+          have the same shape as those of the input ``x``.
+        - third element is an array whose shape depends on the value of
+          ``full_matrices`` and contains orthonormal rows (i.e., the rows are the
+          right singular vectors and the array is the adjoint). If ``full_matrices``
+          is ``True``, the array has shape ``(..., N, N)``. If ``full_matrices`` is
+          ``False``, the array has shape ``(..., K, N)`` where ``K = min(M, N)``.
+          The first ``x.ndim-2`` dimensions have the same shape as those of the
+          input ``x``.
+
+        Each returned array has the same floating-point data type as ``x``.
+    """
+    U, S, Vh = _impl.svd(x, full_matrices=full_matrices)
+    return SVD(U, S, Vh)
+
+
+def svdvals(x: Array, /) -> Array:
+    """Returns the singular values of a matrix (or a stack of matrices) ``x``.
+
+    Parameters
+    ----------
+    x
+        Input array having shape ``(..., M, N)`` and whose innermost two dimensions
+        form matrices on which to perform singular value decomposition.
+
+    Returns
+    -------
+    out
+        An array with shape ``(..., K)`` that contains the vector(s) of singular
+        values of length ``K``, where ``K = min(M, N)``. For each vector, the
+        singular values are sorted in descending order by magnitude.
+    """
+    return _impl.svdvals(x)
+
+
+QR = collections.namedtuple("QR", ["Q", "R"])
+
+
+def qr(
+    x: Array, /, *, mode: Literal["reduced", "complete", "r"] = "reduced"
+) -> Tuple[Array, Array]:
+    """
+    Returns the QR decomposition ``x = QR`` of a full column rank matrix (or a stack
+    of matrices), where ``Q`` is an orthonormal matrix (or a stack of matrices) and
+    ``R`` is an upper-triangular matrix (or a stack of matrices).
+
+    .. note::
+
+        Whether an array library explicitly checks whether an input array is a full
+        column rank matrix (or a stack of full column rank matrices) is
+        implementation-defined.
+
+    Parameters
+    ----------
+    x
+        Input array having shape ``(..., M, N)`` and whose innermost two dimensions
+        form ``MxN`` matrices of rank ``N``. Should have a floating-point data type.
+    mode
+        Decomposition mode. Should be one of the following modes:
+
+        - ``'reduced'``: compute only the leading ``K`` columns of ``q``, such that
+          ``q`` and ``r`` have dimensions ``(..., M, K)`` and ``(..., K, N)``,
+          respectively, and where ``K = min(M, N)``.
+        - ``'complete'``: compute ``q`` and ``r`` with dimensions ``(..., M, M)`` and
+          ``(..., M, N)``, respectively.
+        - ``'r'``: compute only the upper-triangular matrix ``r``; in this mode, the
+          returned ``Q`` is ``None``.
+
+    Returns
+    -------
+    out
+        A namedtuple ``(Q, R)`` whose
+
+        - first element is an array whose shape depends on the value of ``mode`` and
+          contains matrices with orthonormal columns. If ``mode`` is ``'complete'``,
+          the array has shape ``(..., M, M)``. If ``mode`` is ``'reduced'``, the
+          array has shape ``(..., M, K)``, where ``K = min(M, N)``. The first
+          ``x.ndim-2`` dimensions have the same size as those of the input array
+          ``x``.
+        - second element is an array whose shape depends on the value of ``mode``
+          and contains upper-triangular matrices. If ``mode`` is ``'complete'``, the
+          array has shape ``(..., M, N)``. If ``mode`` is ``'reduced'``, the array
+          has shape ``(..., K, N)``, where ``K = min(M, N)``. The first ``x.ndim-2``
+          dimensions have the same size as those of the input ``x``.
+
+        Each returned array has a floating-point data type determined by
+        `type-promotion `_.
+    """
+    Q, R = _impl.qr(x, mode=mode)
+    return QR(Q, R)
+
+
+def vecdot(x1: Array, x2: Array, /, *, axis: int = -1) -> Array:
+    """Computes the (vector) dot product of two arrays along an axis.
+
+    Parameters
+    ----------
+    x1
+        First input array.
+    x2
+        Second input array. Must be compatible with ``x1`` for all non-contracted
+        axes.
+        The size of the axis over which to compute the dot product must be the same
+        size as the respective axis in ``x1``.
+    axis
+        Axis over which to compute the dot product.
+
+    Returns
+    -------
+    out
+        If ``x1`` and ``x2`` are both one-dimensional arrays, a zero-dimensional
+        array containing the dot product; otherwise, a non-zero-dimensional array
+        containing the dot products and having rank ``N-1``, where ``N`` is the rank
+        (number of dimensions) of the shape determined according to broadcasting
+        along the non-contracted axes.
+    """
+    return _impl.vecdot(x1, x2, axis)
+
+
+def tensordot(
+    x1: Array, x2: Array, /, *, axes: Union[int, Tuple[ShapeLike, ShapeLike]] = 2
+) -> Array:
+    """Returns a tensor contraction of ``x1`` and ``x2`` over specific axes.
+
+    Parameters
+    ----------
+    x1
+        First input array.
+    x2
+        Second input array. Corresponding contracted axes of ``x1`` and ``x2`` must
+        be equal.
+    axes
+        Number of axes (dimensions) to contract or explicit sequences of axes
+        (dimensions) for ``x1`` and ``x2``, respectively.
+
+        If ``axes`` is an ``int`` equal to ``N``, then contraction will be performed
+        over the last ``N`` axes of ``x1`` and the first ``N`` axes of ``x2`` in
+        order. The size of each corresponding axis (dimension) must match.
+
+        - If ``N`` equals ``0``, the result is the tensor (outer) product.
+        - If ``N`` equals ``1``, the result is the tensor dot product.
+        - If ``N`` equals ``2``, the result is the tensor double contraction
+          (default).
+
+        If ``axes`` is a tuple of two sequences ``(x1_axes, x2_axes)``, the first
+        sequence must apply to ``x1`` and the second sequence to ``x2``. Both
+        sequences must have the same length. Each axis (dimension) ``x1_axes[i]``
+        for ``x1`` must have the same size as the respective axis (dimension)
+        ``x2_axes[i]`` for ``x2``. Each sequence must consist of unique
+        (nonnegative) integers that specify valid axes for each respective array.
+
+    Returns
+    -------
+    out
+        An array containing the tensor contraction whose shape consists of the
+        non-contracted axes (dimensions) of the first array ``x1``, followed by the
+        non-contracted axes (dimensions) of the second array ``x2``.
+    """
+    if isinstance(axes, tuple):
+        axes = (asshape(axes[0]), asshape(axes[1]))
+    return _impl.tensordot(x1, x2, axes)
diff --git a/src/probnum/utils/linalg/_cholesky_updates.py b/src/probnum/backend/linalg/_cholesky_updates.py
similarity index 83%
rename from src/probnum/utils/linalg/_cholesky_updates.py
rename to src/probnum/backend/linalg/_cholesky_updates.py
index a35c08105..933c3bba8 100644
--- a/src/probnum/utils/linalg/_cholesky_updates.py
+++ b/src/probnum/backend/linalg/_cholesky_updates.py
@@ -1,16 +1,16 @@
 """Cholesky updates."""
-import typing
+from typing import Optional
-import numpy as np
+from probnum import backend
 __all__ = ["cholesky_update", "tril_to_positive_tril"]
 def cholesky_update(
-    S1: np.ndarray, S2: typing.Optional[np.ndarray] = None
-) -> np.ndarray:
+    S1: backend.Array, S2: Optional[backend.Array] = None
+) -> backend.Array:
     r"""Compute Cholesky update/factorization :math:`L` such that
     :math:`L L^\top = S_1 S_1^\top + S_2 S_2^\top` holds.
@@ -39,7 +39,8 @@ def cholesky_update( Examples -------- - >>> from probnum.utils.linalg import cholesky_update + + >>> from probnum.backend.linalg import cholesky_update >>> from probnum.problems.zoo.linalg import random_spd_matrix >>> import numpy as np @@ -63,22 +64,22 @@ def cholesky_update( True """ if S2 is not None: - stacked_up = np.vstack((S1.T, S2.T)) + stacked_up = backend.vstack((S1.T, S2.T)) else: - stacked_up = np.vstack(S1.T) - upper_sqrtm = np.linalg.qr(stacked_up, mode="r") + stacked_up = backend.vstack(S1.T) + _, upper_sqrtm = backend.linalg.qr(stacked_up, mode="r") if S1.ndim == 1: lower_sqrtm = upper_sqrtm.T elif S1.shape[0] <= S1.shape[1]: lower_sqrtm = upper_sqrtm.T else: - lower_sqrtm = np.zeros((S1.shape[0], S1.shape[0])) + lower_sqrtm = backend.zeros((S1.shape[0], S1.shape[0])) lower_sqrtm[:, : -(S1.shape[0] - S1.shape[1])] = upper_sqrtm.T return tril_to_positive_tril(lower_sqrtm) -def tril_to_positive_tril(tril_mat: np.ndarray) -> np.ndarray: +def tril_to_positive_tril(tril_mat: backend.Array) -> backend.Array: r"""Orthogonally transform a lower-triangular matrix into a lower-triangular matrix with positive diagonal. @@ -86,7 +87,7 @@ def tril_to_positive_tril(tril_mat: np.ndarray) -> np.ndarray: The name of the function is based on `np.tril`. """ - d = np.sign(np.diag(tril_mat)) + d = backend.sign(backend.diag(tril_mat)) # Numpy assigns sign 0 to 0.0, which eliminate entire rows in the operation below. d[d == 0] = 1.0 diff --git a/src/probnum/backend/linalg/_inner_product.py b/src/probnum/backend/linalg/_inner_product.py new file mode 100644 index 000000000..de6a507d0 --- /dev/null +++ b/src/probnum/backend/linalg/_inner_product.py @@ -0,0 +1,100 @@ +"""Functions defining useful inner products and associated norms.""" + +from typing import Optional + +from probnum.typing import MatrixType + +from ... import backend as backend + + +def inner_product( + x1: backend.Array, + x2: backend.Array, + /, + A: Optional[MatrixType] = None, + *, + axis: int = -1, +) -> backend.Array: + r"""Computes the inner product :math:`\langle x_1, x_2 \rangle_A := x_1^T A x_2` of + two arrays along an axis. + + For n-d arrays the function computes the inner product over the given axis of the + two arrays ``x1`` and ``x2``. + + Parameters + ---------- + x1 + First input array. + x2 + Second input array. Must be compatible with ``x1`` for all non-contracted axes. + The size of the axis over which to compute the inner product must be the same + size as the respective axis in ``x1``. + A + Symmetric positive (semi-)definite matrix defining the geometry. + axis + Axis over which to compute the inner product. + + Returns + ------- + out : + If ``x1`` and ``x2`` are both one-dimensional arrays, a zero-dimensional array + containing the dot product; otherwise, a non-zero-dimensional array containing + the dot products and having rank ``N-1``, where ``N`` is the rank (number of + dimensions) of the shape determined according to broadcasting along the + non-contracted axes. + + Notes + ----- + Note that the broadcasting behavior of :func:`inner_product` differs from + :func:`numpy.inner`. Rather it follows the broadcasting rules of + :func:`numpy.matmul` in that n-d arrays are treated as stacks of vectors. 
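+
+    Examples
+    --------
+    A minimal sketch of the Euclidean and the :math:`A`-weighted case. This assumes
+    ``inner_product`` is imported from its defining module; concrete array reprs
+    are backend-dependent, so the expected values are noted as comments:
+
+    >>> from probnum import backend
+    >>> from probnum.backend.linalg._inner_product import inner_product
+    >>> x = backend.asarray([1.0, 2.0])
+    >>> y = backend.asarray([3.0, 4.0])
+    >>> A = backend.asarray([[2.0, 0.0], [0.0, 1.0]])
+    >>> xy = inner_product(x, y)      # <x, y> = 1 * 3 + 2 * 4 = 11
+    >>> xAy = inner_product(x, y, A)  # x^T A y = 1 * 2 * 3 + 2 * 1 * 4 = 14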
+ """ + if A is None: + return backend.vecdot(x1, x2) + + ndim = max(x1.ndim, x2.ndim) + x1_shape = (1,) * (ndim - x1.ndim) + tuple(x1.shape) + x2_shape = (1,) * (ndim - x2.ndim) + tuple(x2.shape) + if x1_shape[axis] != x2_shape[axis]: + raise ValueError("x1 and x2 must have the same shape along the given axis.") + + x1_, x2_ = backend.broadcast_arrays(x1, x2) + x1_ = backend.move_axes(x1_, axis, -1) + x2_ = backend.move_axes(x2_, axis, -1) + + res = x1_[..., None, :] @ (A @ x2_[..., None]) + return backend.asarray(res[..., 0, 0]) + + +def induced_vector_norm( + x: backend.Array, + /, + A: Optional[MatrixType] = None, + axis: int = -1, +) -> backend.Array: + r"""Induced vector norm :math:`\lVert x \rVert_A := \sqrt{x^T A x}`. + + Computes the induced norm over the given axis of the array. + + Parameters + ---------- + x + Array. + A + Symmetric positive (semi-)definite linear operator defining the geometry. + axis + Specifies the axis along which to compute the vector norms. + + Returns + ------- + norm : + Vector norm of ``x`` along the given ``axis``. + """ + + if A is None: + return backend.linalg.vector_norm(x, ord=2, axis=axis, keepdims=False) + + x = backend.move_axes(x, axis, -1) + y = backend.squeeze(A @ x[..., :, None], axis=-1) + + return backend.sqrt(backend.sum(x * y, axis=-1)) diff --git a/src/probnum/backend/linalg/_jax.py b/src/probnum/backend/linalg/_jax.py new file mode 100644 index 000000000..d916a3328 --- /dev/null +++ b/src/probnum/backend/linalg/_jax.py @@ -0,0 +1,150 @@ +"""Implementation of linear algebra functionality in JAX.""" + +import functools +from typing import Literal, Optional, Tuple, Union + +try: + import jax + from jax import numpy as jnp + + # pylint: disable=unused-import + from jax.numpy import diagonal, einsum, kron, matmul, outer, tensordot, trace + from jax.numpy.linalg import ( + det, + eigh, + eigvalsh, + inv, + matrix_power, + pinv, + slogdet, + solve, + svd, + ) +except ModuleNotFoundError: + pass + + +def matrix_rank( + x: "jax.Array", /, *, rtol: Optional[Union[float, "jax.Array"]] = None +) -> "jax.Array": + return jnp.linalg.matrix_rank(x, tol=rtol) + + +def matrix_transpose(x: "jax.Array", /) -> "jax.Array": + return jnp.swapaxes(x, -2, -1) + + +def vector_norm( + x: "jax.Array", + /, + *, + axis: Optional[Union[int, Tuple[int, ...]]] = None, + keepdims: bool = False, + ord: Union[int, float, Literal["inf", "-inf"]] = 2, +) -> "jax.Array": + return jnp.linalg.norm(x=x, ord=ord, keepdims=keepdims, axis=axis) + + +def matrix_norm(x: "jax.Array", /, *, keepdims: bool = False, ord="fro") -> "jax.Array": + return jnp.linalg.norm(x=x, ord=ord, keepdims=keepdims, axis=(-2, -1)) + + +def cholesky(x: "jax.Array", /, *, upper: bool = False) -> "jax.Array": + L = jax.numpy.linalg.cholesky(x) + + return jnp.conj(L.swapaxes(-2, -1)) if upper else L + + +@functools.partial(jax.jit, static_argnames=("transpose", "upper", "unit_diagonal")) +def solve_triangular( + A: jax.numpy.ndarray, + b: jax.numpy.ndarray, + *, + transpose: bool = False, + upper: bool = False, + unit_diagonal: bool = False, +) -> jax.numpy.ndarray: + if b.ndim in (1, 2): + return jax.scipy.linalg.solve_triangular( + A, + b, + trans=1 if transpose else 0, + lower=not upper, + unit_diagonal=unit_diagonal, + ) + + @functools.partial(jax.numpy.vectorize, signature="(n,n),(n,k)->(n,k)") + def _solve_triangular_vectorized( + A: jax.numpy.ndarray, + b: jax.numpy.ndarray, + ) -> jax.numpy.ndarray: + return jax.scipy.linalg.solve_triangular( + A, + b, + trans=1 if transpose else 0, + lower=not 
upper, + unit_diagonal=unit_diagonal, + ) + + return _solve_triangular_vectorized(A, b) + + +@functools.partial(jax.jit, static_argnames=("upper", "overwrite_b", "check_finite")) +def solve_cholesky( + cholesky: jax.numpy.ndarray, + b: jax.numpy.ndarray, + *, + upper: bool = False, + overwrite_b: bool = False, + check_finite: bool = True, +): + @functools.partial(jax.numpy.vectorize, signature="(n,n),(n,k)->(n,k)") + def _cho_solve_vectorized( + cholesky: jax.numpy.ndarray, + b: jax.numpy.ndarray, + ): + return jax.scipy.linalg.cho_solve( + (cholesky, not upper), + b, + overwrite_b=overwrite_b, + check_finite=check_finite, + ) + + if b.ndim == 1: + return _cho_solve_vectorized( + cholesky, + b[:, None], + )[:, 0] + + return _cho_solve_vectorized(cholesky, b) + + +def qr( + x: "jax.Array", /, *, mode: Literal["reduced", "complete", "r"] = "reduced" +) -> Tuple["jax.Array", "jax.Array"]: + if mode == "r": + r = jnp.linalg.qr(x, mode=mode) + q = None + else: + q, r = jnp.linalg.qr(x, mode=mode) + + return q, r + + +def vecdot(x1: "jax.Array", x2: "jax.Array", axis: int = -1) -> "jax.Array": + ndim = max(x1.ndim, x2.ndim) + x1_shape = (1,) * (ndim - x1.ndim) + tuple(x1.shape) + x2_shape = (1,) * (ndim - x2.ndim) + tuple(x2.shape) + if x1_shape[axis] != x2_shape[axis]: + raise ValueError("x1 and x2 must have the same shape along the given axis.") + + x1_, x2_ = jnp.broadcast_arrays(x1, x2) + x1_ = jnp.moveaxis(x1_, axis, -1) + x2_ = jnp.moveaxis(x2_, axis, -1) + + res = x1_[..., None, :] @ x2_[..., None] + return jnp.asarray(res[..., 0, 0]) + + +def svdvals(x: "jax.Array", /) -> "jax.Array": + return jnp.linalg.svd(x, compute_uv=False, hermitian=False) diff --git a/src/probnum/backend/linalg/_numpy.py b/src/probnum/backend/linalg/_numpy.py new file mode 100644 index 000000000..5bc6ec1fc --- /dev/null +++ b/src/probnum/backend/linalg/_numpy.py @@ -0,0 +1,187 @@ +"""Implementation of linear algebra functionality in NumPy.""" + +import functools +from typing import Callable, Literal, Optional, Tuple, Union + +import numpy as np + +# pylint: disable=unused-import +from numpy import diagonal, einsum, kron, matmul, outer, tensordot, trace +from numpy.linalg import ( + det, + eigh, + eigvalsh, + inv, + matrix_power, + pinv, + slogdet, + solve, + svd, +) +import scipy.linalg + + +def matrix_rank( + x: np.ndarray, /, *, rtol: Optional[Union[float, np.ndarray]] = None +) -> np.ndarray: + return np.linalg.matrix_rank(x, tol=rtol) + + +def matrix_transpose(x: np.ndarray, /) -> np.ndarray: + return np.swapaxes(x, -2, -1) + + +def vector_norm( + x: np.ndarray, + /, + *, + axis: Optional[Union[int, Tuple[int, ...]]] = None, + keepdims: bool = False, + ord: Union[int, float, Literal["inf", "-inf"]] = 2, +) -> np.ndarray: + return np.asarray(np.linalg.norm(x=x, ord=ord, keepdims=keepdims, axis=axis)) + + +def matrix_norm(x: np.ndarray, /, *, keepdims: bool = False, ord="fro") -> np.ndarray: + return np.asarray(np.linalg.norm(x=x, ord=ord, keepdims=keepdims, axis=(-2, -1))) + + +def cholesky(x: np.ndarray, /, *, upper: bool = False) -> np.ndarray: + try: + L = np.linalg.cholesky(x) + + return np.conj(L.swapaxes(-2, -1)) if upper else L + except np.linalg.LinAlgError: + return (np.triu if upper else np.tril)(np.full_like(x, np.nan)) + + +def solve_triangular( + A: np.ndarray, + b: np.ndarray, + *, + transpose: bool = False, + upper: bool = False, + unit_diagonal: bool = False, +) -> np.ndarray: + if b.ndim in (1, 2): + return scipy.linalg.solve_triangular( + A, + b, + trans=1 if transpose else 0, + lower=not 
upper, + unit_diagonal=unit_diagonal, + ) + + return _matmul_broadcasting( + functools.partial( + scipy.linalg.solve_triangular, + A, + trans=1 if transpose else 0, + lower=not upper, + unit_diagonal=unit_diagonal, + ), + b, + ) + + +def solve_cholesky( + cholesky: np.ndarray, + b: np.ndarray, + *, + upper: bool = False, + overwrite_b: bool = False, + check_finite: bool = True, +): + if b.ndim in (1, 2): + return scipy.linalg.cho_solve( + (cholesky, not upper), + b, + overwrite_b=overwrite_b, + check_finite=check_finite, + ) + + return _matmul_broadcasting( + functools.partial( + scipy.linalg.cho_solve, + (cholesky, not upper), + overwrite_b=overwrite_b, + check_finite=check_finite, + ), + b, + ) + + +def _matmul_broadcasting( + matmul_fn: Callable[[np.ndarray], np.ndarray], + x: np.ndarray, +) -> np.ndarray: + # In order to apply __matmul__ broadcasting, we need to reshape the stack of + # matrices `x` into a matrix whose first axis corresponds to the penultimate axis in + # the matrix stack and whose second axis is a flattened/raveled representation of + # all the remaining axes + + # We can handle a stack of vectors in a simplified manner + stack_of_vectors = x.shape[-1] == 1 + + if stack_of_vectors: + x_batch_first = x[..., 0] + else: + x_batch_first = np.swapaxes(x, -2, -1) + + x_batch_last = np.array(x_batch_first.T, copy=False, order="F") + + # Flatten the trailing axes and remember shape to undo flattening operation later + unflatten_shape = x_batch_last.shape[1:] + x_flat_batch_last = x_batch_last.reshape( + (x_batch_last.shape[0], -1), + order="F", + ) + + assert x_flat_batch_last.flags.f_contiguous + + res_flat_batch_last = np.array( + matmul_fn(x_flat_batch_last), + copy=False, + order="F", + ) + + # Undo flattening operation + res_batch_last = res_flat_batch_last.reshape((-1,) + unflatten_shape, order="F") + + res_batch_first = res_batch_last.T + + if stack_of_vectors: + return res_batch_first[..., None] + + return np.swapaxes(res_batch_first, -2, -1) + + +def qr( + x: np.ndarray, /, *, mode: Literal["reduced", "complete", "r"] = "reduced" +) -> Tuple[np.ndarray, np.ndarray]: + if mode == "r": + r = np.linalg.qr(x, mode=mode) + q = None + else: + q, r = np.linalg.qr(x, mode=mode) + + return q, r + + +def vecdot(x1: np.ndarray, x2: np.ndarray, axis: int = -1) -> np.ndarray: + ndim = max(x1.ndim, x2.ndim) + x1_shape = (1,) * (ndim - x1.ndim) + tuple(x1.shape) + x2_shape = (1,) * (ndim - x2.ndim) + tuple(x2.shape) + if x1_shape[axis] != x2_shape[axis]: + raise ValueError("x1 and x2 must have the same shape along the given axis.") + + x1_, x2_ = np.broadcast_arrays(x1, x2) + x1_ = np.moveaxis(x1_, axis, -1) + x2_ = np.moveaxis(x2_, axis, -1) + + res = x1_[..., None, :] @ x2_[..., None] + return np.asarray(res[..., 0, 0]) + + +def svdvals(x: np.ndarray, /) -> np.ndarray: + return np.linalg.svd(x, compute_uv=False, hermitian=False) diff --git a/src/probnum/utils/linalg/_orthogonalize.py b/src/probnum/backend/linalg/_orthogonalize.py similarity index 92% rename from src/probnum/utils/linalg/_orthogonalize.py rename to src/probnum/backend/linalg/_orthogonalize.py index 745cf1c37..9db8fc562 100644 --- a/src/probnum/utils/linalg/_orthogonalize.py +++ b/src/probnum/backend/linalg/_orthogonalize.py @@ -7,7 +7,7 @@ from probnum import linops -from ._inner_product import induced_norm, inner_product as inner_product_fn +from ._inner_product import induced_vector_norm, inner_product as inner_product_fn def gram_schmidt( @@ -49,10 +49,10 @@ def gram_schmidt( if inner_product is None: inprod_fn 
= inner_product_fn - norm_fn = partial(induced_norm, axis=-1) + norm_fn = partial(induced_vector_norm, axis=-1) elif isinstance(inner_product, (np.ndarray, linops.LinearOperator)): inprod_fn = lambda v, w: inner_product_fn(v, w, A=inner_product) - norm_fn = lambda v: induced_norm(v, A=inner_product, axis=-1) + norm_fn = lambda v: induced_vector_norm(v, A=inner_product, axis=-1) else: inprod_fn = inner_product norm_fn = lambda v: np.sqrt(inprod_fn(v, v)) @@ -68,7 +68,7 @@ def gram_schmidt( return v_orth -def modified_gram_schmidt( +def gram_schmidt_modified( v: np.ndarray, orthogonal_basis: Iterable[np.ndarray], inner_product: Optional[ @@ -108,10 +108,10 @@ def modified_gram_schmidt( if inner_product is None: inprod_fn = inner_product_fn - norm_fn = induced_norm + norm_fn = induced_vector_norm elif isinstance(inner_product, (np.ndarray, linops.LinearOperator)): inprod_fn = lambda v, w: inner_product_fn(v, w, A=inner_product) - norm_fn = lambda v: induced_norm(v, A=inner_product) + norm_fn = lambda v: induced_vector_norm(v, A=inner_product) else: inprod_fn = inner_product norm_fn = lambda v: np.sqrt(inprod_fn(v, v)) @@ -127,7 +127,7 @@ def modified_gram_schmidt( return v_orth -def double_gram_schmidt( +def gram_schmidt_double( v: np.ndarray, orthogonal_basis: Iterable[np.ndarray], inner_product: Optional[ @@ -138,7 +138,7 @@ def double_gram_schmidt( ] ] = None, normalize: bool = False, - gram_schmidt_fn: Callable = modified_gram_schmidt, + gram_schmidt_fn: Callable = gram_schmidt_modified, ) -> np.ndarray: r"""Perform the (modified) Gram-Schmidt process twice. diff --git a/src/probnum/backend/linalg/_torch.py b/src/probnum/backend/linalg/_torch.py new file mode 100644 index 000000000..32be946d9 --- /dev/null +++ b/src/probnum/backend/linalg/_torch.py @@ -0,0 +1,130 @@ +"""Implementation of linear algebra functionality in PyTorch.""" + +from typing import Literal, Optional, Tuple, Union + +try: + import torch + + # pylint: disable=unused-import + from torch import diagonal, kron, matmul, outer, tensordot + from torch.linalg import ( + det, + eigh, + eigvalsh, + inv, + matrix_power, + matrix_rank, + pinv, + qr, + slogdet, + solve, + svd, + svdvals, + vecdot, + ) +except ModuleNotFoundError: + pass + + +def matrix_transpose(x: "torch.Tensor", /) -> "torch.Tensor": + return torch.transpose(x, -2, -1) + + +def trace(x: "torch.Tensor", /, *, offset: int = 0) -> "torch.Tensor": + if offset != 0: + raise NotImplementedError + + return torch.trace(x) + + +def pinv( + x: "torch.Tensor", rtol: Optional[Union[float, "torch.Tensor"]] = None +) -> "torch.Tensor": + return torch.linalg.pinv(x, rtol=rtol) + + +def einsum( + *arrays: "torch.Tensor", + optimization: Optional[str] = "greedy", +): + return torch.einsum(*arrays) + + +def vector_norm( + x: "torch.Tensor", + /, + *, + axis: Optional[Union[int, Tuple[int, ...]]] = None, + keepdims: bool = False, + ord: Union[int, float, Literal["inf", "-inf"]] = 2, +) -> "torch.Tensor": + return torch.linalg.vector_norm(x, ord=ord, dim=axis, keepdim=keepdims) + + +def matrix_norm( + x: "torch.Tensor", /, *, keepdims: bool = False, ord="fro" +) -> "torch.Tensor": + return torch.linalg.matrix_norm(x, ord=ord, dim=(-2, -1), keepdim=keepdims) + + +def norm( + x: "torch.Tensor", + ord: Optional[Union[int, str]] = None, + axis: Optional[Tuple[int, ...]] = None, + keepdims: bool = False, +): + return torch.linalg.norm(x, ord=ord, dim=axis, keepdim=keepdims) + + +def cholesky(x: "torch.Tensor", /, *, upper: bool = False) -> "torch.Tensor": + try: + return 
torch.linalg.cholesky(x, upper=upper)
+    except RuntimeError:
+        return (torch.triu if upper else torch.tril)(torch.full_like(x, float("nan")))
+
+
+def solve_triangular(
+    A: "torch.Tensor",
+    b: "torch.Tensor",
+    *,
+    transpose: bool = False,
+    upper: bool = False,
+    unit_diagonal: bool = False,
+) -> "torch.Tensor":
+    if b.ndim == 1:
+        return torch.triangular_solve(
+            b[:, None],
+            A,
+            upper=upper,
+            transpose=transpose,
+            unitriangular=unit_diagonal,
+        ).solution[:, 0]
+
+    return torch.triangular_solve(
+        b,
+        A,
+        upper=upper,
+        transpose=transpose,
+        unitriangular=unit_diagonal,
+    ).solution
+
+
+def solve_cholesky(
+    cholesky: "torch.Tensor",
+    b: "torch.Tensor",
+    *,
+    upper: bool = False,
+    overwrite_b: bool = False,
+    check_finite: bool = True,
+):
+    if b.ndim == 1:
+        return torch.cholesky_solve(b[:, None], cholesky, upper=upper)[:, 0]
+
+    return torch.cholesky_solve(b, cholesky, upper=upper)
+
+
+def qr(
+    x: "torch.Tensor", /, *, mode: Literal["reduced", "complete", "r"] = "reduced"
+) -> Tuple["torch.Tensor", "torch.Tensor"]:
+    q, r = torch.linalg.qr(x, mode=mode)
+    return q, r
diff --git a/src/probnum/backend/random/__init__.py b/src/probnum/backend/random/__init__.py
new file mode 100644
index 000000000..70d3f23ed
--- /dev/null
+++ b/src/probnum/backend/random/__init__.py
@@ -0,0 +1,296 @@
+"""Functionality for random number generation."""
+from __future__ import annotations
+
+from typing import Optional, Sequence, Union
+
+from probnum.backend.typing import FloatLike, SeedType, ShapeLike
+
+from .. import Array, DType, asscalar, asshape, float64
+from ... import config
+from ..._select_backend import BACKEND, Backend
+
+if BACKEND is Backend.NUMPY:
+    from . import _numpy as _impl
+elif BACKEND is Backend.JAX:
+    from . import _jax as _impl
+elif BACKEND is Backend.TORCH:
+    from . import _torch as _impl
+
+__all__ = [
+    "RNGState",
+    "rng_state",
+    "split",
+    "choice",
+    "gamma",
+    "permutation",
+    "standard_normal",
+    "uniform",
+    "uniform_so_group",
+]
+
+
+RNGState = _impl.RNGState
+"""State of the random number generator."""
+
+
+def rng_state(seed: SeedType) -> RNGState:
+    """Create a state of a random number generator from a seed.
+
+    Parameters
+    ----------
+    seed
+        Seed for the random number generator.
+
+    Returns
+    -------
+    rng_state
+        State of a random number generator.
+    """
+    return _impl.rng_state(seed=seed)
+
+
+def split(rng_state: RNGState, num: int = 2) -> Sequence[RNGState]:
+    """Split the random number generator state into multiple.
+
+    Parameters
+    ----------
+    rng_state
+        Base RNG state.
+    num
+        Number of RNG states to split into.
+
+    Returns
+    -------
+    rng_states
+        Sequence of RNG states.
+    """
+    return _impl.split(rng_state=rng_state, num=num)
+
+
+def choice(
+    rng_state: RNGState,
+    x: Union[int, Array],
+    shape: ShapeLike = (),
+    replace: bool = True,
+    p: Optional[Array] = None,
+    axis: int = 0,
+) -> Array:
+    """Generate a random sample from a given array.
+
+    Parameters
+    ----------
+    rng_state
+        Random number generator state.
+    x
+        If a :class:`~probnum.backend.Array`, a random sample is generated from its
+        elements. If an ``int``, the random sample is generated as if it were
+        :code:`backend.arange(x)`.
+    shape
+        Sample shape.
+    replace
+        Whether the sample is with or without replacement.
+    p
+        The probabilities associated with each entry in ``x``. If not given, the
+        sample assumes a uniform distribution over all entries in ``x``.
+    axis
+        The axis along which the selection is performed.
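+
+    Examples
+    --------
+    A minimal usage sketch, drawing three indices from ``backend.arange(5)``. The
+    concrete values depend on the selected compute backend, so no output is shown:
+
+    >>> from probnum import backend
+    >>> rng_state = backend.random.rng_state(42)
+    >>> sample = backend.random.choice(rng_state, 5, shape=(3,))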
+ """ + return _impl.choice( + rng_state=rng_state, + x=x, + shape=asshape(shape), + replace=replace, + p=p, + axis=axis, + ) + + +def gamma( + rng_state: RNGState, + shape_param: FloatLike, + scale_param: FloatLike = 1.0, + shape: ShapeLike = (), + *, + dtype: DType = None, +) -> Array: + """Draw samples from a Gamma distribution. + + Samples are drawn from a Gamma distribution with specified parameters, shape + (sometimes designated “k”) and scale (sometimes designated “theta”), where both + parameters are > 0. + + Parameters + ---------- + rng_state + Random number generator state. + shape_param + Shape parameter of the Gamma distribution. + scale_param + Scale parameter of the Gamma distribution. + shape + Sample shape. + dtype + Sample data type. + + Returns + ------- + samples + Samples from the Gamma distribution. + """ + if dtype is None: + dtype = config.default_floating_dtype + return _impl.gamma( + rng_state=rng_state, + shape_param=asscalar(shape_param), + scale_param=asscalar(scale_param), + shape=asshape(shape), + dtype=dtype, + ) + + +def permutation( + rng_state: RNGState, + x: Union[int, Array], + *, + axis: int = 0, + independent: bool = False, +) -> Array: + """Returns a randomly permuted array or range. + + Parameters + ---------- + rng_state + Random number generator state. + x + If ``x`` is an integer, randomly permute ``~probnum.backend.arange(x)``. + If ``x`` is an array, make a copy and shuffle the elements + randomly. + axis + The axis which ``x`` is shuffled along. Default is 0. + independent + If set to ``True``, each individual vector along the given axis is shuffled + independently. Default is ``False``. + + Returns + ------- + out + Permuted array or array range. + """ + return _impl.permutation( + rng_state=rng_state, x=x, axis=axis, independent=independent + ) + + +def standard_normal( + rng_state: RNGState, + shape: ShapeLike = (), + *, + dtype: DType = None, +) -> Array: + """Draw samples from a standard Normal distribution (mean=0, stdev=1). + + Parameters + ---------- + rng_state + Random number generator state. + shape + Sample shape. + dtype + Sample data type. + + Returns + ------- + samples + Samples from the standard normal distribution. + """ + if dtype is None: + dtype = config.default_floating_dtype + return _impl.standard_normal( + rng_state=rng_state, + shape=asshape(shape), + dtype=dtype, + ) + + +def uniform( + rng_state: RNGState, + shape: ShapeLike = (), + *, + dtype: DType = None, + minval: FloatLike = 0.0, + maxval: FloatLike = 1.0, +) -> Array: + """Draw samples from a uniform distribution. + + Samples are uniformly distributed over the half-open interval ``[minval, maxval)`` + (includes ``minval``, but excludes ``maxval``). In other words, any value within the + given interval is equally likely to be drawn by :meth:`uniform`. + + Parameters + ---------- + rng_state + Random number generator state. + shape + Sample shape. + dtype + Sample data type. + minval + Lower bound of the sampled values. All values generated will be greater than + or equal to ``minval``. + maxval + Upper bound of the sampled values. All values generated will be strictly smaller + than ``maxval``. + + Returns + ------- + samples + Samples from the uniform distribution. 
+ """ + if dtype is None: + dtype = config.default_floating_dtype + return _impl.uniform( + rng_state=rng_state, + shape=asshape(shape), + dtype=dtype, + minval=asscalar(minval, dtype=dtype), + maxval=asscalar(maxval, dtype=dtype), + ) + + +def uniform_so_group( + rng_state: RNGState, + n: int, + shape: ShapeLike = (), + *, + dtype: DType = None, +) -> Array: + """Draw samples from the Haar distribution, i.e. from the uniform distribution on + :math:`SO(n)`. + + The generated samples are randomly drawn orthogonal matrices with determinant 1, + i.e. elements of the special orthogonal group :math:`SO(n)`. + + Parameters + ---------- + rng_state + Random number generator state. + n + Matrix dimension. + shape + Sample shape. + dtype + Sample data type. + + Returns + ------- + samples + Samples from the Haar distribution. + """ + if dtype is None: + dtype = config.default_floating_dtype + return _impl.uniform_so_group( + rng_state=rng_state, + n=n, + shape=asshape(shape), + dtype=dtype, + ) diff --git a/src/probnum/backend/random/_jax.py b/src/probnum/backend/random/_jax.py new file mode 100644 index 000000000..1ac34ccde --- /dev/null +++ b/src/probnum/backend/random/_jax.py @@ -0,0 +1,141 @@ +"""Functionality for random number generation implemented in the JAX backend.""" +from __future__ import annotations + +import functools +import secrets +from typing import Optional, Sequence, Union + +try: + import jax + from jax import numpy as jnp +except ModuleNotFoundError: + pass + +from probnum.backend.typing import SeedType, ShapeType + +RNGState = jax.random.PRNGKey + + +def rng_state(seed: SeedType) -> RNGState: + if seed is None: + seed = secrets.randbits(128) + + if not isinstance(seed, int): + return seed + + return jax.random.PRNGKey(seed) + + +def split(rng_state: RNGState, num: int = 2) -> Sequence[RNGState]: + return jax.random.split(key=rng_state, num=num) + + +def choice( + rng_state: RNGState, + x: Union[int, "jax.Array"], + shape: ShapeType = (), + replace: bool = True, + p: Optional["jax.Array"] = None, + axis: int = 0, +) -> "jax.Array": + return jax.random.choice( + key=rng_state, a=x, shape=shape, replace=replace, p=p, axis=axis + ) + + +def uniform( + rng_state: RNGState, + shape: ShapeType = (), + dtype: "jnp.dtype" = None, + minval: "jax.Array" = jnp.array(0.0), + maxval: "jax.Array" = jnp.array(1.0), +) -> "jax.Array": + return jax.random.uniform( + key=rng_state, shape=shape, dtype=dtype, minval=minval, maxval=maxval + ) + + +def standard_normal( + rng_state: RNGState, + shape: ShapeType = (), + dtype: jnp.dtype = None, +) -> "jax.Array": + return jax.random.normal(key=rng_state, shape=shape, dtype=dtype) + + +def gamma( + rng_state: RNGState, + shape_param: "jax.Array", + scale_param: "jax.Array" = jnp.array(1.0), + shape: ShapeType = (), + dtype: jnp.dtype = None, +) -> "jax.Array": + return ( + jax.random.gamma(key=rng_state, a=shape_param, shape=shape, dtype=dtype) + * scale_param + ) + + +@functools.partial(jax.jit, static_argnames=("n", "shape", "dtype")) +def uniform_so_group( + rng_state: RNGState, + n: int, + shape: ShapeType = (), + dtype: jnp.dtype = None, +) -> "jax.Array": + if n == 1: + return jnp.ones(shape + (1, 1), dtype=dtype) + + return _uniform_so_group_pushforward_fn( + standard_normal(rng_state, shape=shape + (n - 1, n), dtype=dtype) + ) + + +@functools.partial(jnp.vectorize, signature="(M,N)->(N,N)") +def _uniform_so_group_pushforward_fn(omega: "jax.Array") -> "jax.Array": + n = omega.shape[1] + + assert omega.shape == (n - 1, n) + + X = 
jnp.triu(omega) + + X_diag = jnp.diag(X) + D = jnp.vectorize( + lambda x: jax.lax.cond( + x != 0, + lambda x: jnp.sign(x), + lambda _: jnp.ones((), dtype=omega.dtype), + x, + ), + )(X_diag) + + row_norms_sq = jnp.sum(X**2, axis=1) + + X = X.at[jnp.diag_indices(n - 1)].set(jnp.sqrt(row_norms_sq) * D) + X /= jnp.sqrt((row_norms_sq - X_diag**2 + jnp.diag(X) ** 2) / 2.0)[:, None] + + H = jax.lax.fori_loop( + lower=0, + upper=n - 1, + body_fun=lambda idx, H: H - jnp.outer(H @ X[idx, :], X[idx, :]), + init_val=jnp.eye(n, dtype=omega.dtype), + ) + + D = jnp.append( + D, + (-1.0 if n % 2 == 0 else 1.0) * jnp.prod(D), + ) + + return D[:, None] * H + + +def permutation( + rng_state: RNGState, + x: Union[int, "jax.Array"], + *, + axis: int = 0, + independent: bool = False, +): + return jax.random.permutation( + key=rng_state, x=x, axis=axis, independent=independent + ) diff --git a/src/probnum/backend/random/_numpy.py b/src/probnum/backend/random/_numpy.py new file mode 100644 index 000000000..ca3fafde7 --- /dev/null +++ b/src/probnum/backend/random/_numpy.py @@ -0,0 +1,139 @@ +"""Functionality for random number generation implemented in the NumPy backend.""" +from __future__ import annotations + +import functools +from typing import Optional, Sequence, Union + +import numpy as np + +from probnum import backend +from probnum.backend.typing import SeedType, ShapeType + +RNGState = np.random.SeedSequence + + +def rng_state(seed: SeedType) -> RNGState: + return np.random.SeedSequence(seed) + + +def split(rng_state: RNGState, num: int = 2) -> Sequence[RNGState]: + return rng_state.spawn(num) + + +def _rng_from_rng_state(rng_state: RNGState) -> np.random.Generator: + """Create a random generator instance initialized with the given state.""" + if not isinstance(rng_state, RNGState): + raise TypeError(f"`rng_state`s should always have type {RNGState.__name__}.") + + return np.random.default_rng(rng_state) + + +def choice( + rng_state: RNGState, + x: Union[int, np.ndarray], + shape: ShapeType = (), + replace: bool = True, + p: Optional[np.ndarray] = None, + axis: int = 0, +) -> np.ndarray: + return _rng_from_rng_state(rng_state).choice( + a=x, size=shape, replace=replace, p=p, axis=axis, shuffle=True + ) + + +def uniform( + rng_state: RNGState, + shape: ShapeType = (), + dtype: backend.DType = np.double, + minval: np.ndarray = np.array(0.0), + maxval: np.ndarray = np.array(1.0), +) -> np.ndarray: + return np.asarray( + (maxval - minval) + * _rng_from_rng_state(rng_state).random( + size=shape, + dtype=dtype, + ) + + minval + ) + + +def standard_normal( + rng_state: RNGState, + shape: ShapeType = (), + dtype: np.dtype = np.double, +) -> np.ndarray: + return np.asarray( + _rng_from_rng_state(rng_state).standard_normal(size=shape, dtype=dtype) + ) + + +def gamma( + rng_state: RNGState, + shape_param: np.ndarray, + scale_param: np.ndarray = np.array(1.0), + shape: ShapeType = (), + dtype: np.dtype = np.double, +) -> np.ndarray: + return np.asarray( + _rng_from_rng_state(rng_state).standard_gamma( + shape=shape_param, size=shape, dtype=dtype + ) + * scale_param + ) + + +def uniform_so_group( + rng_state: RNGState, + n: int, + shape: ShapeType = (), + dtype: np.dtype = np.double, +) -> np.ndarray: + if n == 1: + return np.ones(shape + (1, 1), dtype=dtype) + + return np.asarray( + _uniform_so_group_pushforward_fn( + standard_normal(rng_state, shape=shape + (n - 1, n), dtype=dtype) + ) + ) + + +@functools.partial(np.vectorize, signature="(M,N)->(N,N)") +def _uniform_so_group_pushforward_fn(omega: np.ndarray) -> 
np.ndarray:
+    n = omega.shape[1]
+
+    assert omega.shape == (n - 1, n)
+
+    X = np.triu(omega)
+
+    # Copied and modified from https://github.com/scipy/scipy/blob/1c98aa98a55e2aaf2c15c16b47ee5e258bfcd170/scipy/stats/_multivariate.py#L3373-L3387
+    H = np.eye(n, dtype=omega.dtype)
+    D = np.empty((n,), dtype=omega.dtype)
+    for idx in range(n - 1):
+        x = X[idx, idx:]
+        norm2 = np.dot(x, x)
+        x0 = x[0].item()
+        D[idx] = np.sign(x[0]) if x[0] != 0 else 1
+        x[0] += D[idx] * np.sqrt(norm2)
+        x /= np.sqrt((norm2 - x0**2 + x[0] ** 2) / 2.0)
+        # Householder transformation
+        H[:, idx:] -= np.outer(np.dot(H[:, idx:], x), x)
+    D[-1] = (-1) ** (n - 1) * D[:-1].prod()
+    # Equivalent to np.dot(np.diag(D), H) but faster, apparently
+    H = (D * H.T).T
+    return H
+
+
+def permutation(
+    rng_state: RNGState,
+    x: Union[int, np.ndarray],
+    *,
+    axis: int = 0,
+    independent: bool = False,
+):
+    rng = _rng_from_rng_state(rng_state)
+    if independent:
+        return rng.permuted(x=x, axis=axis, out=None)
+    else:
+        return rng.permutation(x=x, axis=axis)
diff --git a/src/probnum/backend/random/_torch.py b/src/probnum/backend/random/_torch.py
new file mode 100644
index 000000000..1670fc42f
--- /dev/null
+++ b/src/probnum/backend/random/_torch.py
@@ -0,0 +1,181 @@
+"""Functionality for random number generation implemented in the PyTorch backend."""
+from __future__ import annotations
+
+from typing import Optional, Sequence, Union
+
+import numpy as np
+
+try:
+    import torch
+    from torch.distributions.utils import broadcast_all
+except ModuleNotFoundError:
+    pass
+
+from probnum import backend
+from probnum.backend.typing import SeedType, ShapeType
+
+RNGState = np.random.SeedSequence
+
+
+def rng_state(seed: SeedType) -> RNGState:
+    return np.random.SeedSequence(seed)
+
+
+def split(rng_state: RNGState, num: int = 2) -> Sequence[RNGState]:
+    return rng_state.spawn(num)
+
+
+def _rng_from_rng_state(rng_state: RNGState) -> torch.Generator:
+    """Create a random generator instance initialized with the given state."""
+    if not isinstance(rng_state, RNGState):
+        raise TypeError(f"`rng_state`s should always have type {RNGState.__name__}.")
+
+    rng = torch.Generator()
+    return rng.manual_seed(int(rng_state.generate_state(1, dtype=np.uint64)[0]))
+
+
+def choice(
+    rng_state: RNGState,
+    x: Union[int, "torch.Tensor"],
+    shape: ShapeType = (),
+    replace: bool = True,
+    p: Optional["torch.Tensor"] = None,
+    axis: int = 0,
+) -> "torch.Tensor":
+    if p is None:
+        # Assume a uniform distribution over all options if no probabilities are
+        # given (matching the documented behavior of `backend.random.choice`).
+        num_options = x.shape[axis] if backend.isarray(x) else x
+        p = torch.ones(num_options) / num_options
+    # `torch.multinomial` expects an integer number of samples, so the requested
+    # sample shape is flattened for sampling and restored afterwards.
+    idcs = torch.multinomial(
+        input=p,
+        num_samples=torch.Size(shape).numel(),
+        replacement=replace,
+        generator=_rng_from_rng_state(rng_state),
+    )
+    if backend.isarray(x):
+        return torch.index_select(input=x, dim=axis, index=idcs)
+    return idcs.reshape(shape)
+
+
+def uniform(
+    rng_state: RNGState,
+    shape: ShapeType = (),
+    dtype: "torch.dtype" = None,
+    minval: float = 0.0,
+    maxval: float = 1.0,
+) -> "torch.Tensor":
+    rng = _rng_from_rng_state(rng_state)
+    return (maxval - minval) * torch.rand(shape, generator=rng, dtype=dtype) + minval
+
+
+def standard_normal(
+    rng_state: RNGState,
+    shape: ShapeType = (),
+    dtype: "torch.dtype" = None,
+) -> "torch.Tensor":
+    rng = _rng_from_rng_state(rng_state)
+
+    return torch.randn(shape, generator=rng, dtype=dtype)
+
+
+def gamma(
+    rng_state: RNGState,
+    shape_param: "torch.Tensor",
+    scale_param: "torch.Tensor",
+    shape: ShapeType = (),
+    dtype: "torch.dtype" = None,
+) -> "torch.Tensor":
+    rng = _rng_from_rng_state(rng_state)
+
+    shape_param = torch.as_tensor(shape_param, dtype=dtype)
+    scale_param = torch.as_tensor(scale_param, dtype=dtype)
+
+    # Adapted version of
# https://github.com/pytorch/pytorch/blob/afff38182457f3500c265f232310438dded0e57d/torch/distributions/gamma.py#L59-L63 + shape_param, scale_param = broadcast_all(shape_param, scale_param) + + res_shape = shape + shape_param.shape + + return torch._standard_gamma( + shape_param.expand(res_shape), rng + ) * scale_param.expand(res_shape) + + +def uniform_so_group( + rng_state: RNGState, + n: int, + shape: ShapeType = (), + dtype: "torch.dtype" = None, +) -> "torch.Tensor": + if n == 1: + return torch.ones(shape + (1, 1), dtype=dtype) + + omega = standard_normal(rng_state, shape=shape + (n - 1, n), dtype=dtype) + + sample = _uniform_so_group_pushforward_fn(omega.reshape((-1, n - 1, n))) + + return sample.reshape(shape + (n, n)) + + +try: + + @torch.jit.script + def _uniform_so_group_pushforward_fn(omega: "torch.Tensor") -> "torch.Tensor": + n = omega.shape[-1] + + assert omega.ndim == 3 and omega.shape[-2] == n - 1 + + samples = [] + + for sample_idx in range(omega.shape[0]): + X = torch.triu(omega[sample_idx, :, :]) + X_diag = torch.diag(X) + + D = torch.where( + X_diag != 0, + torch.sign(X_diag), + torch.ones((), dtype=omega.dtype), + ) + + row_norms_sq = torch.sum(X**2, dim=1) + + diag_indices = torch.arange(n - 1) + X[diag_indices, diag_indices] = torch.sqrt(row_norms_sq) * D + + X /= torch.sqrt((row_norms_sq - X_diag**2 + torch.diag(X) ** 2) / 2.0)[ + :, None + ] + + H = torch.eye(n, dtype=omega.dtype) + + for idx in range(n - 1): + H -= torch.outer(H @ X[idx, :], X[idx, :]) + + D = torch.cat( + ( + D, + (-1.0 if n % 2 == 0 else 1.0) * torch.prod(D, dim=0, keepdim=True), + ), + dim=0, + ) + + samples.append(D[:, None] * H) + + return torch.stack(samples, dim=0) + +except (ModuleNotFoundError, NameError): + pass + + +def permutation( + rng_state: RNGState, + x: Union[int, "torch.Tensor"], + *, + axis: int = 0, + independent: bool = False, +): + rng = _rng_from_rng_state(rng_state) + if independent: + idx = torch.argsort(torch.rand(*x.shape, generator=rng), dim=axis) + return torch.gather(x, dim=axis, index=idx) + else: + idx = torch.randperm(x.shape[axis], generator=rng) + return torch.index_select(x, axis, idx) diff --git a/src/probnum/backend/special/__init__.py b/src/probnum/backend/special/__init__.py new file mode 100644 index 000000000..bf2b46312 --- /dev/null +++ b/src/probnum/backend/special/__init__.py @@ -0,0 +1,95 @@ +"""Special functions.""" +from probnum.backend.typing import FloatLike + +from .. import Array, asscalar +from ..._select_backend import BACKEND, Backend + +if BACKEND is Backend.NUMPY: + from . import _numpy as _impl +elif BACKEND is Backend.JAX: + from . import _jax as _impl +elif BACKEND is Backend.TORCH: + from . import _torch as _impl + +__all__ = [ + "gamma", + "modified_bessel2", + "ndtr", + "ndtri", +] +__all__.sort() + + +def gamma(x: Array, /) -> Array: + r"""Gamma function. + + Evaluates the gamma function defined as + + .. math:: + + \Gamma(x) = \int_0^\infty t^{x-1}e^{-t}\,dt + + for :math:`\text{Real}(x) > 0` and is extended to the rest of the complex plane by + analytic continuation. + + The gamma function is often referred to as the generalized factorial since + :math:`\Gamma(n+1) = n!` for natural numbers :math:`n`. More generally it satisfies + the recurrence relation :math:`\Gamma(x + 1) = x \Gamma(x)` for complex :math:`x`, + which, combined with the fact that :math:`\Gamma(1)=1`, implies the above. + + Parameters + ---------- + x + Argument(s) at which to evaluate the gamma function. 
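+
+    Examples
+    --------
+    A small sketch illustrating :math:`\Gamma(n + 1) = n!`. The NumPy backend is
+    assumed here, since the JAX and PyTorch backends do not implement this function
+    yet; the expected values are noted as a comment:
+
+    >>> from probnum import backend
+    >>> x = backend.asarray([1.0, 2.0, 3.0, 4.0])
+    >>> y = backend.special.gamma(x)  # [0!, 1!, 2!, 3!] = [1., 1., 2., 6.]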
+ """ + return _impl.gamma(x) + + +def modified_bessel2(x: Array, /, *, order: FloatLike) -> Array: + """Modified Bessel function of the second kind of the given order. + + Parameters + ---------- + x + Argument(s) at which to evaluate the Bessel function. + order + Order of Bessel function. + """ + return _impl.modified_bessel2(x, order) + + +def ndtr(x: Array, /) -> Array: + r"""Normal distribution function. + + Returns the area under the Gaussian probability density function, integrated + from minus infinity to x: + + .. math:: + + \begin{align} + \mathrm{ndtr}(x) =& + \ \frac{1}{\sqrt{2 \pi}}\int_{-\infty}^{x} e^{-\frac{1}{2}t^2} dt \\ + =&\ \frac{1}{2} (1 + \mathrm{erf}(\frac{x}{\sqrt{2}})) \\ + =&\ \frac{1}{2} \mathrm{erfc}(\frac{x}{\sqrt{2}}) + \end{align} + + Parameters + ---------- + x + Argument(s) at which to evaluate the Normal distribution function. + """ + return _impl.ndtr(x) + + +def ndtri(p: Array, /) -> Array: + r"""The inverse of the CDF of the Normal distribution function. + + Returns `x` such that the area under the PDF from :math:`-\infty` to `x` is equal + to `p`. + + Parameters + ---------- + p + Argument(s) at which to evaluate the inverse Normal distribution function. + """ + return _impl.ndtri(p) diff --git a/src/probnum/backend/special/_jax.py b/src/probnum/backend/special/_jax.py new file mode 100644 index 000000000..49d4b0a23 --- /dev/null +++ b/src/probnum/backend/special/_jax.py @@ -0,0 +1,15 @@ +"""Special functions in JAX.""" + +try: + import jax.numpy as jnp + from jax.scipy.special import ndtr, ndtri # pylint: disable=unused-import +except ModuleNotFoundError: + pass + + +def modified_bessel2(x: "jax.Array", order: "jax.Array") -> "jax.Array": + return NotImplementedError + + +def gamma(x: "jax.Array", /) -> "jax.Array": + raise NotImplementedError diff --git a/src/probnum/backend/special/_numpy.py b/src/probnum/backend/special/_numpy.py new file mode 100644 index 000000000..8ab061012 --- /dev/null +++ b/src/probnum/backend/special/_numpy.py @@ -0,0 +1,7 @@ +"""Special functions in NumPy / SciPy.""" +import numpy as np +from scipy.special import gamma, kv, ndtr, ndtri # pylint: disable=unused-import + + +def modified_bessel2(x: np.ndarray, order: float) -> np.ndarray: + return kv(order, x) diff --git a/src/probnum/backend/special/_torch.py b/src/probnum/backend/special/_torch.py new file mode 100644 index 000000000..b54375fdf --- /dev/null +++ b/src/probnum/backend/special/_torch.py @@ -0,0 +1,15 @@ +"""Special functions in PyTorch.""" + +try: + import torch + from torch.special import ndtr, ndtri +except ModuleNotFoundError: + pass + + +def gamma(x: torch.Tensor, /) -> torch.Tensor: + raise NotImplementedError + + +def modified_bessel2(x: torch.Tensor, order: torch.Tensor) -> torch.Tensor: + return NotImplementedError diff --git a/src/probnum/backend/typing.py b/src/probnum/backend/typing.py new file mode 100644 index 000000000..e4799d914 --- /dev/null +++ b/src/probnum/backend/typing.py @@ -0,0 +1,127 @@ +"""Custom type aliases. + +This module defines commonly used types in the library. These are separated into two +different kinds, API types and argument types. + +**API types** (``*Type``) are aliases which define custom types used throughout the +library. Objects of this type may be supplied as arguments or returned by a method. + +**Argument types** (``*Like``) are aliases which define commonly used method +arguments that are internally converted to a standardized representation. 
These should
+only ever be used in the signature of a method and then be converted internally, e.g.
+in a class instantiation or an interface. They enable the user to conveniently supply
+a variety of objects of different types for the same argument, while ensuring a
+unified internal representation of those same objects. As an example, take the
+different ways a user might specify a shape: ``2``, ``(2,)``, ``[2, 2]``. These may
+all be acceptable arguments to a function taking a shape, but internally should
+always be converted to a :attr:`ShapeType`, i.e. a tuple of ``int``\\ s.
+"""
+
+from __future__ import annotations
+
+import numbers
+from typing import Iterable, Optional, Tuple, Union
+
+import numpy as np
+from numpy.typing import ArrayLike as _NumPyArrayLike, DTypeLike as _NumPyDTypeLike
+
+from ._array_object import Array, Scalar
+
+__all__ = [
+    # API Types
+    "ShapeType",
+    "SeedType",
+    # Argument Types
+    "IntLike",
+    "FloatLike",
+    "ShapeLike",
+    "DTypeLike",
+    "ArrayIndicesLike",
+    "ScalarLike",
+    "ArrayLike",
+    "NotImplementedType",
+]
+
+########################################################################################
+# API Types
+########################################################################################
+
+# Array Utilities
+ShapeType = Tuple[int, ...]
+"""Type defining a shape of an object."""
+
+# Random Number Generation
+SeedType = Optional[int]
+"""Type defining a seed of a random number generator.
+
+An object of type :attr:`SeedType` is used to initialize the state of a random number
+generator by passing ``seed`` to :func:`~probnum.backend.random.rng_state`."""
+
+########################################################################################
+# Argument Types
+########################################################################################
+
+# Python Numbers
+IntLike = Union[int, numbers.Integral, np.integer]
+"""Object that can be converted to an integer.
+
+Arguments of type :attr:`IntLike` should always be converted into :class:`int`\\ s
+before further internal processing."""
+
+FloatLike = Union[float, numbers.Real, np.floating]
+"""Object that can be converted to a float.
+
+Arguments of type :attr:`FloatLike` should always be converted into :class:`float`\\ s
+before further internal processing."""
+
+# Scalars, Arrays and Matrices
+ScalarLike = Union[Scalar, int, float, complex, numbers.Number, np.number]
+"""Object that can be converted to a scalar value.
+
+Arguments of type :attr:`ScalarLike` should always be converted into objects of
+:class:`~probnum.backend.Scalar` using the function :func:`~probnum.backend.asscalar`
+before further internal processing."""
+
+ArrayLike = Union[Array, _NumPyArrayLike]
+"""Object that can be converted to an array.
+
+Arguments of type :attr:`ArrayLike` should always be converted into
+:class:`~probnum.backend.Array`\\ s
+using the function :func:`~probnum.backend.asarray` before further internal
+processing."""
+
+# Array Utilities
+ShapeLike = Union[IntLike, Iterable[IntLike]]
+"""Object that can be converted to a shape.
+
+Arguments of type :attr:`ShapeLike` should always be converted into
+:class:`ShapeType` using the function :func:`~probnum.backend.asshape` before further
+internal processing."""
+
+DTypeLike = Union["probnum.backend.DType", _NumPyDTypeLike]
+"""Object that can be converted to an array dtype.
+ +Arguments of type :attr:`DTypeLike` should always be converted into +:class:`~probnum.backend.DType`\\s before further internal processing.""" + +_ArrayIndexLike = Union[ + int, + slice, + type(Ellipsis), + None, + "probnum.backend.newaxis", + ArrayLike, +] +ArrayIndicesLike = Union[_ArrayIndexLike, Tuple[_ArrayIndexLike, ...]] +"""Object that can be converted to indices of an array. + +Type of the argument to the :meth:`__getitem__` method of an +:class:`~probnum.backend.Array` or similar object. +""" + +######################################################################################## +# Other Types +######################################################################################## + +NotImplementedType = type(NotImplemented) +"""Type of the `NotImplemented` constant.""" diff --git a/src/probnum/compat/__init__.py b/src/probnum/compat/__init__.py new file mode 100644 index 000000000..85faebc60 --- /dev/null +++ b/src/probnum/compat/__init__.py @@ -0,0 +1,8 @@ +"""Compatibility functions. + +This module implements functions which are typically applied to +:class:`~probnum.backend.Array`\\s and extends their functionality to other objects. +""" + +from . import testing +from ._core import * diff --git a/src/probnum/compat/_core.py b/src/probnum/compat/_core.py new file mode 100644 index 000000000..e971f9447 --- /dev/null +++ b/src/probnum/compat/_core.py @@ -0,0 +1,93 @@ +from typing import Tuple, Union + +import numpy as np + +from probnum import backend, linops, randvars + +__all__ = [ + "to_numpy", + "cast", +] + + +def to_numpy( + *xs: Union[backend.Array, linops.LinearOperator] +) -> Union[np.ndarray, Tuple[np.ndarray, ...]]: + res = [] + + for x in xs: + if backend.isarray(x): + x = backend.to_numpy(x) + elif isinstance(x, linops.LinearOperator): + x = backend.to_numpy(x.todense()) + else: + x = np.asarray(x) + + res.append(x) + + if len(xs) == 1: + return res[0] + + return tuple(res) + + +def cast(a, dtype=None, casting="unsafe", copy=None): + if isinstance(a, linops.LinearOperator): + return a.astype(dtype=dtype, casting=casting, copy=copy) + + return backend.cast(a, dtype, casting=casting, copy=copy) + + +def atleast_1d( + *objs: Union[ + backend.Array, + linops.LinearOperator, + randvars.RandomVariable, + ] +) -> Union[ + Union[ + backend.Array, + linops.LinearOperator, + randvars.RandomVariable, + ], + Tuple[ + Union[ + backend.Array, + linops.LinearOperator, + randvars.RandomVariable, + ], + ..., + ], +]: + """Reshape arrays, linear operators and random variables to have at least 1 + dimension. + + Scalar inputs are converted to 1-dimensional arrays, whilst + higher-dimensional inputs are preserved. + + Parameters + ---------- + objs + One or more input linear operators, random variables or arrays. + + Returns + ------- + res : + An array / random variable / linop or tuple of arrays / random variables / + linear operators, each with ``ndim >= 1``. + """ + res = [] + + for obj in objs: + if isinstance(obj, np.ndarray): + obj = np.atleast_1d(obj) + elif isinstance(obj, backend.Array): + obj = backend.atleast_1d(obj) + elif isinstance(obj, randvars.RandomVariable): + if obj.ndim == 0: + obj = obj.reshape((1,)) + + res.append(obj) + + if len(res) == 1: + return res[0] + + return tuple(res) diff --git a/src/probnum/compat/testing.py b/src/probnum/compat/testing.py new file mode 100644 index 000000000..cf0896301 --- /dev/null +++ b/src/probnum/compat/testing.py @@ -0,0 +1,144 @@ +from typing import Union + +import numpy as np + +from probnum import backend, linops + +from . 
import _core + +__all__ = [ + "assert_allclose", + "assert_array_equal", + "assert_equal", +] + + +def assert_equal( + actual: Union[backend.Array, linops.LinearOperator], + desired: Union[backend.Array, linops.LinearOperator], + /, + *, + err_msg: str = "", + verbose: bool = True, +): + """Raises an AssertionError if two objects are not equal. + + Given two objects (scalars, lists, tuples, dictionaries, + :class:`~probnum.backend.Array`\\s, :class:`~probnum.linops.LinearOperator`\\s), + check that all elements of these objects are equal. An exception is raised + at the first conflicting values. + + When one of ``actual`` and ``desired`` is a scalar and the other is array_like, + the function checks that each element of the array_like object is equal to + the scalar. + + This function handles NaN comparisons as if NaN was a "normal" number. + That is, AssertionError is not raised if both objects have NaNs in the same + positions. This is in contrast to the IEEE standard on NaNs, which says + that NaN compared to anything must return False. + + Parameters + ---------- + actual + The object to check. + desired + The expected object. + err_msg + The error message to be printed in case of failure. + verbose + If True, the conflicting values are appended to the error message. + + Raises + ------ + AssertionError + If actual and desired are not equal. + """ + np.testing.assert_equal( + *_core.to_numpy(actual, desired), err_msg=err_msg, verbose=verbose + ) + + +def assert_allclose( + actual: Union[backend.Array, linops.LinearOperator], + desired: Union[backend.Array, linops.LinearOperator], + /, + *, + rtol: float = 1e-7, + atol: float = 0, + equal_nan: bool = True, + err_msg: str = "", + verbose: bool = True, +): + """Raises an AssertionError if two objects are not equal up to desired tolerance. + + The test compares the difference + between ``actual`` and ``desired`` to ``atol + rtol * abs(desired)``. + + Parameters + ---------- + actual + The ``actual`` object to check. + desired + The ``desired``, expected object. + rtol + Relative tolerance. + atol + Absolute tolerance. + equal_nan + If True, NaNs will compare equal. + err_msg + The error message to be printed in case of failure. + verbose + If True, the conflicting values are appended to the error message. + + Raises + ------ + AssertionError + If ``actual`` and ``desired`` are not equal up to specified precision. + """ + np.testing.assert_allclose( + *_core.to_numpy(actual, desired), + rtol=rtol, + atol=atol, + equal_nan=equal_nan, + err_msg=err_msg, + verbose=verbose, + ) + + +def assert_array_equal( + actual: Union[backend.Array, linops.LinearOperator], + desired: Union[backend.Array, linops.LinearOperator], + /, + *, + err_msg: str = "", + verbose: bool = True, +): + """Raises an AssertionError if two array-like objects are not equal. + + Given two array-like objects, check that the shape is equal and all + elements of these objects are equal (but see the Notes for the special + handling of a scalar). An exception is raised at shape mismatch or + conflicting values. In contrast to the standard usage in numpy, NaNs + are compared like numbers, no assertion is raised if both objects have + NaNs in the same positions. + + Parameters + ---------- + actual + The ``actual`` object to check. + desired + The ``desired``, expected object. + err_msg + The error message to be printed in case of failure. + verbose + If True, the conflicting values are appended to the error message. 
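+ + For illustration, a minimal usage sketch (a hedged example assuming the NumPy + backend; it reuses :class:`~probnum.linops.Scaling` and + :func:`~probnum.backend.asarray` from elsewhere in this patch): + + >>> from probnum import backend, compat, linops + >>> compat.testing.assert_array_equal( + ... linops.Scaling(factors=[2.0, 2.0]), + ... backend.asarray([[2.0, 0.0], [0.0, 2.0]]), + ... ) 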
+ + Raises + ------ + AssertionError + If ``actual`` and ``desired`` objects are not equal. + """ + np.testing.assert_array_equal( + *_core.to_numpy(actual, desired), err_msg=err_msg, verbose=verbose + ) diff --git a/src/probnum/conftest.py b/src/probnum/conftest.py index 7a2faa201..9821b3ddf 100644 --- a/src/probnum/conftest.py +++ b/src/probnum/conftest.py @@ -1,15 +1,18 @@ """Fixtures and configuration for doctests.""" import numpy as np -import pytest import probnum as pn +from probnum import backend + +import pytest @pytest.fixture(autouse=True) def autoimport_packages(doctest_namespace): """This fixture 'imports' standard packages automatically in order to avoid - boilerplate code in doctests""" + boilerplate code in doctests.""" doctest_namespace["pn"] = pn doctest_namespace["np"] = np + doctest_namespace["backend"] = backend diff --git a/src/probnum/diffeq/_odesolution.py b/src/probnum/diffeq/_odesolution.py index 38f8a69e8..113840792 100644 --- a/src/probnum/diffeq/_odesolution.py +++ b/src/probnum/diffeq/_odesolution.py @@ -11,7 +11,7 @@ import numpy as np from probnum import filtsmooth, randvars -from probnum.typing import ArrayLike, FloatLike, IntLike, ShapeLike +from probnum.backend.typing import ArrayLike, FloatLike, IntLike, ShapeLike class ODESolution(filtsmooth.TimeSeriesPosterior): diff --git a/src/probnum/diffeq/_odesolver.py b/src/probnum/diffeq/_odesolver.py index aa512b99c..13c362250 100644 --- a/src/probnum/diffeq/_odesolver.py +++ b/src/probnum/diffeq/_odesolver.py @@ -10,8 +10,8 @@ import numpy as np from probnum import problems +from probnum.backend.typing import FloatLike from probnum.diffeq import callbacks as callback_module # see below -from probnum.typing import FloatLike # From above: # One of the argument to solve() is called 'callback', diff --git a/src/probnum/diffeq/_perturbsolve_ivp.py b/src/probnum/diffeq/_perturbsolve_ivp.py index fcb3741e8..f9e29e0e5 100644 --- a/src/probnum/diffeq/_perturbsolve_ivp.py +++ b/src/probnum/diffeq/_perturbsolve_ivp.py @@ -8,8 +8,8 @@ import scipy.integrate from probnum import problems +from probnum.backend.typing import ArrayLike, FloatLike from probnum.diffeq import perturbed, stepsize -from probnum.typing import ArrayLike, FloatLike __all__ = ["perturbsolve_ivp"] diff --git a/src/probnum/diffeq/_probsolve_ivp.py b/src/probnum/diffeq/_probsolve_ivp.py index c0f3c733b..615ce882b 100644 --- a/src/probnum/diffeq/_probsolve_ivp.py +++ b/src/probnum/diffeq/_probsolve_ivp.py @@ -15,8 +15,8 @@ import numpy as np from probnum import problems, randprocs +from probnum.backend.typing import ArrayLike, FloatLike from probnum.diffeq import _utils, odefilter -from probnum.typing import ArrayLike, FloatLike __all__ = ["probsolve_ivp"] diff --git a/src/probnum/diffeq/odefilter/_odefilter.py b/src/probnum/diffeq/odefilter/_odefilter.py index 986173de0..3c263a399 100644 --- a/src/probnum/diffeq/odefilter/_odefilter.py +++ b/src/probnum/diffeq/odefilter/_odefilter.py @@ -5,7 +5,7 @@ import numpy as np import scipy.linalg -from probnum import filtsmooth, randprocs, randvars, utils +from probnum import backend, filtsmooth, randprocs, randvars from probnum.diffeq import _odesolver, _odesolver_state, stepsize from probnum.diffeq.odefilter import ( _odefilter_solution, @@ -213,7 +213,7 @@ def attempt_step(self, state, dt): noisy_component = randvars.Normal( mean=np.zeros(state.rv.shape), cov=state.rv.cov.copy(), - cov_cholesky=state.rv.cov_cholesky.copy(), + cache={"cov_cholesky": state.rv._cov_cholesky.copy()}, ) # Compute the measurements 
for the error-free component @@ -233,16 +233,16 @@ def attempt_step(self, state, dt): # we manually update only the covariance. # The first two are only matrix square-roots and will be turned into proper # Cholesky factors below. - pred_sqrtm = Phi @ noisy_component.cov_cholesky + pred_sqrtm = Phi @ noisy_component._cov_cholesky meas_sqrtm = H @ pred_sqrtm - full_meas_cov_cholesky = utils.linalg.cholesky_update( - meas_rv_error_free.cov_cholesky, meas_sqrtm + full_meas_cov_cholesky = backend.linalg.cholesky_update( + meas_rv_error_free._cov_cholesky, meas_sqrtm ) full_meas_cov = full_meas_cov_cholesky @ full_meas_cov_cholesky.T meas_rv = randvars.Normal( mean=meas_rv_error_free.mean, cov=full_meas_cov, - cov_cholesky=full_meas_cov_cholesky, + cache={"cov_cholesky": full_meas_cov_cholesky}, ) # Estimate local diffusion_model and error @@ -274,7 +274,7 @@ def attempt_step(self, state, dt): new_rv = randvars.Normal( mean=state.rv.mean.copy(), cov=state.rv.cov.copy(), - cov_cholesky=state.rv.cov_cholesky.copy(), + cache={"cov_cholesky": state.rv._cov_cholesky.copy()}, ) state = _odesolver_state.ODESolverState( ivp=state.ivp, @@ -294,24 +294,24 @@ def attempt_step(self, state, dt): # With the updated diffusion, we need to re-compute the covariances of the # predicted RV and measured RV. # The resulting predicted and measured RV are overwritten herein. - full_pred_cov_cholesky = utils.linalg.cholesky_update( - np.sqrt(local_diffusion) * pred_rv_error_free.cov_cholesky, pred_sqrtm + full_pred_cov_cholesky = backend.linalg.cholesky_update( + np.sqrt(local_diffusion) * pred_rv_error_free._cov_cholesky, pred_sqrtm ) full_pred_cov = full_pred_cov_cholesky @ full_pred_cov_cholesky.T pred_rv = randvars.Normal( mean=pred_rv_error_free.mean, cov=full_pred_cov, - cov_cholesky=full_pred_cov_cholesky, + cache={"cov_cholesky": full_pred_cov_cholesky}, ) - full_meas_cov_cholesky = utils.linalg.cholesky_update( - np.sqrt(local_diffusion) * meas_rv_error_free.cov_cholesky, meas_sqrtm + full_meas_cov_cholesky = backend.linalg.cholesky_update( + np.sqrt(local_diffusion) * meas_rv_error_free._cov_cholesky, meas_sqrtm ) full_meas_cov = full_meas_cov_cholesky @ full_meas_cov_cholesky.T meas_rv = randvars.Normal( mean=meas_rv_error_free.mean, cov=full_meas_cov, - cov_cholesky=full_meas_cov_cholesky, + cache={"cov_cholesky": full_meas_cov_cholesky}, ) else: @@ -319,19 +319,19 @@ def attempt_step(self, state, dt): # This has not been assembled as a standalone random variable yet, # but is needed for the update below. # (The measurement has been updated already.) - full_pred_cov_cholesky = utils.linalg.cholesky_update( - pred_rv_error_free.cov_cholesky, pred_sqrtm + full_pred_cov_cholesky = backend.linalg.cholesky_update( + pred_rv_error_free._cov_cholesky, pred_sqrtm ) full_pred_cov = full_pred_cov_cholesky @ full_pred_cov_cholesky.T pred_rv = randvars.Normal( mean=pred_rv_error_free.mean, cov=full_pred_cov, - cov_cholesky=full_pred_cov_cholesky, + cache={"cov_cholesky": full_pred_cov_cholesky}, ) # Gain needs manual catching up, too. 
Use it to compute the update crosscov = full_pred_cov @ H.T - gain = scipy.linalg.cho_solve((meas_rv.cov_cholesky, True), crosscov.T).T + gain = scipy.linalg.cho_solve((meas_rv._cov_cholesky, True), crosscov.T).T zero_data = np.zeros(meas_rv.mean.shape) filt_rv, _ = self.measurement_model.backward_realization( zero_data, pred_rv, rv_forwarded=meas_rv, gain=gain @@ -382,7 +382,7 @@ def postprocess(self, odesol): state=randvars.Normal( mean=rv.mean, cov=s * rv.cov, - cov_cholesky=np.sqrt(s) * rv.cov_cholesky, + cache={"cov_cholesky": np.sqrt(s) * rv._cov_cholesky}, ), ) diff --git a/src/probnum/diffeq/odefilter/_odefilter_solution.py b/src/probnum/diffeq/odefilter/_odefilter_solution.py index 8cb7ee90c..e3959cc82 100644 --- a/src/probnum/diffeq/odefilter/_odefilter_solution.py +++ b/src/probnum/diffeq/odefilter/_odefilter_solution.py @@ -4,9 +4,9 @@ import numpy as np -from probnum import filtsmooth, randvars, utils +from probnum import backend, filtsmooth, randvars +from probnum.backend.typing import ArrayLike, FloatLike, IntLike, ShapeLike from probnum.diffeq import _odesolution -from probnum.typing import ArrayLike, FloatLike, IntLike, ShapeLike class ODEFilterSolution(_odesolution.ODESolution): @@ -150,5 +150,5 @@ def _project_rv(projmat, rv): new_mean = projmat @ rv.mean new_cov = projmat @ rv.cov @ projmat.T - new_cov_cholesky = utils.linalg.cholesky_update(projmat @ rv.cov_cholesky) - return randvars.Normal(new_mean, new_cov, cov_cholesky=new_cov_cholesky) + new_cov_cholesky = backend.linalg.cholesky_update(projmat @ rv._cov_cholesky) + return randvars.Normal(new_mean, new_cov, cache={"cov_cholesky": new_cov_cholesky}) diff --git a/src/probnum/diffeq/odefilter/information_operators/_information_operator.py b/src/probnum/diffeq/odefilter/information_operators/_information_operator.py index e0773f024..0cde398f5 100644 --- a/src/probnum/diffeq/odefilter/information_operators/_information_operator.py +++ b/src/probnum/diffeq/odefilter/information_operators/_information_operator.py @@ -6,7 +6,7 @@ import numpy as np from probnum import problems, randprocs, randvars -from probnum.typing import FloatLike, IntLike +from probnum.backend.typing import FloatLike, IntLike __all__ = ["InformationOperator", "ODEInformationOperator"] diff --git a/src/probnum/diffeq/odefilter/information_operators/_ode_residual.py b/src/probnum/diffeq/odefilter/information_operators/_ode_residual.py index 7c7a33896..60c5fa01b 100644 --- a/src/probnum/diffeq/odefilter/information_operators/_ode_residual.py +++ b/src/probnum/diffeq/odefilter/information_operators/_ode_residual.py @@ -5,8 +5,8 @@ import numpy as np from probnum import problems, randprocs +from probnum.backend.typing import FloatLike, IntLike from probnum.diffeq.odefilter.information_operators import _information_operator -from probnum.typing import FloatLike, IntLike __all__ = ["ODEResidual"] diff --git a/src/probnum/diffeq/odefilter/init_routines/_autodiff.py b/src/probnum/diffeq/odefilter/init_routines/_autodiff.py index 933c00917..a4b42611d 100644 --- a/src/probnum/diffeq/odefilter/init_routines/_autodiff.py +++ b/src/probnum/diffeq/odefilter/init_routines/_autodiff.py @@ -54,7 +54,7 @@ def __call__( return randvars.Normal( mean=np.asarray(mean), cov=np.asarray(zeros), - cov_cholesky=np.asarray(zeros), + cache={"cov_cholesky": np.asarray(zeros)}, ) def _compute_ode_derivatives(self, *, f, y0, num_derivatives): diff --git a/src/probnum/diffeq/odefilter/init_routines/_non_probabilistic_fit.py 
b/src/probnum/diffeq/odefilter/init_routines/_non_probabilistic_fit.py index 6dd2a8a11..6928bd90c 100644 --- a/src/probnum/diffeq/odefilter/init_routines/_non_probabilistic_fit.py +++ b/src/probnum/diffeq/odefilter/init_routines/_non_probabilistic_fit.py @@ -8,7 +8,7 @@ import scipy.integrate as sci from probnum import filtsmooth, problems, randprocs, randvars -from probnum.typing import FloatLike +from probnum.backend.typing import FloatLike from ._interface import InitializationRoutine @@ -54,7 +54,7 @@ def _improve(self, *, data, prior_process): process_noise = randvars.Normal( mean=np.zeros(ode_dim), cov=np.diag(observation_noise_std**2), - cov_cholesky=np.diag(observation_noise_std), + cache={"cov_cholesky": np.diag(observation_noise_std)}, ) measmod_scipy = randprocs.markov.discrete.LTIGaussian( transition_matrix=proj_to_y, diff --git a/src/probnum/diffeq/odefilter/init_routines/_stack.py b/src/probnum/diffeq/odefilter/init_routines/_stack.py index 2bb00dd14..f8ba840f9 100644 --- a/src/probnum/diffeq/odefilter/init_routines/_stack.py +++ b/src/probnum/diffeq/odefilter/init_routines/_stack.py @@ -29,7 +29,7 @@ def __call__( return randvars.Normal( mean=np.asarray(mean), cov=np.diag(std**2), - cov_cholesky=np.diag(std), + cache={"cov_cholesky": np.diag(std)}, ) def _stack_initial_states(self, *, ivp, num_derivatives): diff --git a/src/probnum/diffeq/odefilter/utils/_problem_utils.py b/src/probnum/diffeq/odefilter/utils/_problem_utils.py index a396b5646..72b1222cc 100644 --- a/src/probnum/diffeq/odefilter/utils/_problem_utils.py +++ b/src/probnum/diffeq/odefilter/utils/_problem_utils.py @@ -5,8 +5,8 @@ import numpy as np from probnum import problems, randprocs, randvars +from probnum.backend.typing import FloatLike from probnum.diffeq.odefilter import approx_strategies, information_operators -from probnum.typing import FloatLike __all__ = ["ivp_to_regression_problem"] @@ -117,7 +117,9 @@ def _construct_measurement_models_gaussian_likelihood( """Construct measurement models for the IVP with Gaussian likelihoods.""" diff = ode_measurement_variance * np.eye(ode_information_operator.output_dim) diff_cholesky = np.sqrt(diff) - noise = randvars.Normal(mean=shift_vector, cov=diff, cov_cholesky=diff_cholesky) + noise = randvars.Normal( + mean=shift_vector, cov=diff, cache={"cov_cholesky": diff_cholesky} + ) measmod_initial_condition = randprocs.markov.discrete.LTIGaussian( transition_matrix=transition_matrix, diff --git a/src/probnum/diffeq/perturbed/scipy_wrapper/_wrapped_scipy_odesolution.py b/src/probnum/diffeq/perturbed/scipy_wrapper/_wrapped_scipy_odesolution.py index 768bf77f7..41004a8b9 100644 --- a/src/probnum/diffeq/perturbed/scipy_wrapper/_wrapped_scipy_odesolution.py +++ b/src/probnum/diffeq/perturbed/scipy_wrapper/_wrapped_scipy_odesolution.py @@ -3,9 +3,9 @@ from scipy.integrate._ivp.common import OdeSolution from probnum import randvars +from probnum.backend.typing import ArrayLike from probnum.diffeq import _odesolution from probnum.filtsmooth._timeseriesposterior import DenseOutputValueType -from probnum.typing import ArrayLike class WrappedScipyODESolution(_odesolution.ODESolution): diff --git a/src/probnum/diffeq/perturbed/scipy_wrapper/_wrapped_scipy_solver.py b/src/probnum/diffeq/perturbed/scipy_wrapper/_wrapped_scipy_solver.py index 6f2dd88be..8d325615b 100644 --- a/src/probnum/diffeq/perturbed/scipy_wrapper/_wrapped_scipy_solver.py +++ b/src/probnum/diffeq/perturbed/scipy_wrapper/_wrapped_scipy_solver.py @@ -8,9 +8,9 @@ from scipy.integrate._ivp.common import OdeSolution 
from probnum import randvars +from probnum.backend.typing import FloatLike from probnum.diffeq import _odesolver, _odesolver_state from probnum.diffeq.perturbed.scipy_wrapper import _wrapped_scipy_odesolution -from probnum.typing import FloatLike class WrappedScipyRungeKutta(_odesolver.ODESolver): diff --git a/src/probnum/diffeq/perturbed/step/_perturbation_functions.py b/src/probnum/diffeq/perturbed/step/_perturbation_functions.py index fe0a7b44b..52a120173 100644 --- a/src/probnum/diffeq/perturbed/step/_perturbation_functions.py +++ b/src/probnum/diffeq/perturbed/step/_perturbation_functions.py @@ -4,7 +4,7 @@ import numpy as np import scipy -from probnum.typing import FloatLike, IntLike, ShapeLike +from probnum.backend.typing import FloatLike, IntLike, ShapeLike def perturb_uniform( diff --git a/src/probnum/diffeq/perturbed/step/_perturbedstepsolution.py b/src/probnum/diffeq/perturbed/step/_perturbedstepsolution.py index a92798e2e..558671355 100644 --- a/src/probnum/diffeq/perturbed/step/_perturbedstepsolution.py +++ b/src/probnum/diffeq/perturbed/step/_perturbedstepsolution.py @@ -6,8 +6,8 @@ from scipy.integrate._ivp import rk from probnum import randvars +from probnum.backend.typing import FloatLike from probnum.diffeq import _odesolution -from probnum.typing import FloatLike class PerturbedStepSolution(_odesolution.ODESolution): diff --git a/src/probnum/diffeq/perturbed/step/_perturbedstepsolver.py b/src/probnum/diffeq/perturbed/step/_perturbedstepsolver.py index 381868675..5f3dd2f02 100644 --- a/src/probnum/diffeq/perturbed/step/_perturbedstepsolver.py +++ b/src/probnum/diffeq/perturbed/step/_perturbedstepsolver.py @@ -5,13 +5,13 @@ import numpy as np from probnum import randvars +from probnum.backend.typing import FloatLike from probnum.diffeq import _odesolver, _odesolver_state from probnum.diffeq.perturbed import scipy_wrapper from probnum.diffeq.perturbed.step import ( _perturbation_functions, _perturbedstepsolution, ) -from probnum.typing import FloatLike class PerturbedStepSolver(_odesolver.ODESolver): diff --git a/src/probnum/diffeq/stepsize/_steprule.py b/src/probnum/diffeq/stepsize/_steprule.py index 5276de28f..ca91f2f9d 100644 --- a/src/probnum/diffeq/stepsize/_steprule.py +++ b/src/probnum/diffeq/stepsize/_steprule.py @@ -5,7 +5,7 @@ import numpy as np -from probnum.typing import ArrayLike, FloatLike, IntLike +from probnum.backend.typing import ArrayLike, FloatLike, IntLike class StepRule(ABC): diff --git a/src/probnum/filtsmooth/_kalman_filter_smoother.py b/src/probnum/filtsmooth/_kalman_filter_smoother.py index f2ad6669f..eb9cd7cdb 100644 --- a/src/probnum/filtsmooth/_kalman_filter_smoother.py +++ b/src/probnum/filtsmooth/_kalman_filter_smoother.py @@ -5,8 +5,8 @@ import numpy as np from probnum import problems, randprocs, randvars +from probnum.backend.typing import ArrayLike from probnum.filtsmooth import gaussian -from probnum.typing import ArrayLike __all__ = ["filter_kalman", "smooth_rts"] diff --git a/src/probnum/filtsmooth/_timeseriesposterior.py b/src/probnum/filtsmooth/_timeseriesposterior.py index 11bf1ce42..a7951946b 100644 --- a/src/probnum/filtsmooth/_timeseriesposterior.py +++ b/src/probnum/filtsmooth/_timeseriesposterior.py @@ -8,7 +8,13 @@ import numpy as np from probnum import randvars -from probnum.typing import ArrayIndicesLike, ArrayLike, FloatLike, IntLike, ShapeLike +from probnum.backend.typing import ( + ArrayIndicesLike, + ArrayLike, + FloatLike, + IntLike, + ShapeLike, +) DenseOutputValueType = Union[randvars.RandomVariable, 
randvars._RandomVariableList] """Output type of interpolation. diff --git a/src/probnum/filtsmooth/gaussian/_kalmanposterior.py b/src/probnum/filtsmooth/gaussian/_kalmanposterior.py index 79eddf73c..934fe0a1d 100644 --- a/src/probnum/filtsmooth/gaussian/_kalmanposterior.py +++ b/src/probnum/filtsmooth/gaussian/_kalmanposterior.py @@ -12,10 +12,10 @@ import numpy as np from scipy import stats -from probnum import randprocs, randvars, utils +from probnum import backend, randprocs, randvars +from probnum.backend.typing import ArrayLike, FloatLike, IntLike, ShapeLike from probnum.filtsmooth import _timeseriesposterior from probnum.filtsmooth.gaussian import approx -from probnum.typing import ArrayLike, FloatLike, IntLike, ShapeLike GaussMarkovPriorTransitionArgType = Union[ randprocs.markov.discrete.LinearGaussian, @@ -70,7 +70,7 @@ def sample( size: Optional[ShapeLike] = (), ) -> np.ndarray: - size = utils.as_shape(size) + size = backend.asshape(size) single_rv_shape = self.states[0].shape single_rv_ndim = self.states[0].ndim diff --git a/src/probnum/filtsmooth/gaussian/approx/_unscentedkalman.py b/src/probnum/filtsmooth/gaussian/approx/_unscentedkalman.py index 92a2a60ae..570d93ec2 100644 --- a/src/probnum/filtsmooth/gaussian/approx/_unscentedkalman.py +++ b/src/probnum/filtsmooth/gaussian/approx/_unscentedkalman.py @@ -94,7 +94,7 @@ def _linearize_via_cubature(*, t, model, rv, unit_params, forw_impl, backw_impl) """Linearize a nonlinear model statistically with spherical cubature integration.""" sigma_points_unit, weights = unit_params - sigma_points = sigma_points_unit @ rv.cov_cholesky.T + rv.mean[None, :] + sigma_points = sigma_points_unit @ rv._cov_cholesky.T + rv.mean[None, :] sigma_points_transitioned = np.stack( [model.transition_fun(t, p) for p in sigma_points], axis=0 diff --git a/src/probnum/filtsmooth/particle/_particle_filter.py b/src/probnum/filtsmooth/particle/_particle_filter.py index c57ea6e65..ddf18c49f 100644 --- a/src/probnum/filtsmooth/particle/_particle_filter.py +++ b/src/probnum/filtsmooth/particle/_particle_filter.py @@ -5,12 +5,12 @@ import numpy as np from probnum import problems, randprocs, randvars +from probnum.backend.typing import FloatLike, IntLike from probnum.filtsmooth import _bayesfiltsmooth from probnum.filtsmooth.particle import ( _importance_distributions, _particle_filter_posterior, ) -from probnum.typing import FloatLike, IntLike # Terribly long variable names, but internal only, so no worries. ParticleFilterMeasurementModelArgType = Union[ @@ -39,9 +39,6 @@ class ParticleFilter(_bayesfiltsmooth.BayesFiltSmooth): A PF estimates the posterior distribution of a Markov process given noisy, non-linear observations, with a set of particles. - The random state of the particle filter is inferred - from the random state of the initial random variable. 
- Parameters ---------- prior_process : diff --git a/src/probnum/filtsmooth/particle/_particle_filter_posterior.py b/src/probnum/filtsmooth/particle/_particle_filter_posterior.py index f19c6fa52..20dbcdc03 100644 --- a/src/probnum/filtsmooth/particle/_particle_filter_posterior.py +++ b/src/probnum/filtsmooth/particle/_particle_filter_posterior.py @@ -5,8 +5,8 @@ import numpy as np from probnum import randvars +from probnum.backend.typing import ArrayLike, FloatLike, ShapeLike from probnum.filtsmooth import _timeseriesposterior -from probnum.typing import ArrayLike, FloatLike, ShapeLike class ParticleFilterPosterior(_timeseriesposterior.TimeSeriesPosterior): diff --git a/src/probnum/functions/_algebra_fallbacks.py b/src/probnum/functions/_algebra_fallbacks.py index d511cdef5..a2b3305eb 100644 --- a/src/probnum/functions/_algebra_fallbacks.py +++ b/src/probnum/functions/_algebra_fallbacks.py @@ -5,10 +5,8 @@ import functools import operator -import numpy as np - -from probnum import utils -from probnum.typing import ScalarLike, ScalarType +from probnum import backend +from probnum.backend.typing import ScalarLike from ._function import Function @@ -61,7 +59,7 @@ def summands(self) -> tuple[SumFunction, ...]: r"""The functions :math:`f_1, \dotsc, f_n` to be added.""" return self._summands - def _evaluate(self, x: np.ndarray) -> np.ndarray: + def _evaluate(self, x: backend.Array) -> backend.Array: return functools.reduce( operator.add, (summand(x) for summand in self._summands) ) @@ -100,7 +98,7 @@ def __init__(self, function: Function, scalar: ScalarLike): ) self._function = function - self._scalar = utils.as_numpy_scalar(scalar) + self._scalar = backend.asscalar(scalar) super().__init__( input_shape=self._function.input_shape, @@ -113,29 +111,29 @@ def function(self) -> Function: return self._function @property - def scalar(self) -> ScalarType: + def scalar(self) -> backend.Scalar: r"""The scalar :math:`\alpha`.""" return self._scalar - def _evaluate(self, x: np.ndarray) -> np.ndarray: + def _evaluate(self, x: backend.Array) -> backend.Array: return self._scalar * self._function(x) @functools.singledispatchmethod def __mul__(self, other): - if np.ndim(other) == 0: + if backend.ndim(other) == 0: return ScaledFunction( function=self._function, - scalar=self._scalar * np.asarray(other), + scalar=self._scalar * backend.asarray(other), ) return super().__mul__(other) @functools.singledispatchmethod def __rmul__(self, other): - if np.ndim(other) == 0: + if backend.ndim(other) == 0: return ScaledFunction( function=self._function, - scalar=np.asarray(other) * self._scalar, + scalar=backend.asarray(other) * self._scalar, ) return super().__rmul__(other) diff --git a/src/probnum/functions/_function.py b/src/probnum/functions/_function.py index 12573c6ec..808a6a3f8 100644 --- a/src/probnum/functions/_function.py +++ b/src/probnum/functions/_function.py @@ -6,10 +6,8 @@ import functools from typing import Callable -import numpy as np - -from probnum import utils -from probnum.typing import ArrayLike, ShapeLike, ShapeType +from probnum import backend +from probnum.backend.typing import ArrayLike, ShapeLike, ShapeType class Function(abc.ABC): @@ -36,10 +34,10 @@ class Function(abc.ABC): """ def __init__(self, input_shape: ShapeLike, output_shape: ShapeLike = ()) -> None: - self._input_shape = utils.as_shape(input_shape) + self._input_shape = backend.asshape(input_shape) self._input_ndim = len(self._input_shape) - self._output_shape = utils.as_shape(output_shape) + self._output_shape = 
backend.asshape(output_shape) self._output_ndim = len(self._output_shape) @property @@ -68,7 +66,7 @@ def output_ndim(self) -> int: """Syntactic sugar for ``len(output_shape)``.""" return self._output_ndim - def __call__(self, x: ArrayLike) -> np.ndarray: + def __call__(self, x: ArrayLike) -> backend.Array: """Evaluate the function at a given input. The function is vectorized over the batch shape of the input. @@ -91,7 +89,7 @@ def __call__(self, x: ArrayLike) -> np.ndarray: If the shape of ``x`` does not match :attr:`input_shape` along its last dimensions. """ - x = np.asarray(x) + x = backend.asarray(x) # Shape checking if x.shape[x.ndim - self.input_ndim :] != self.input_shape: @@ -112,7 +110,7 @@ def __call__(self, x: ArrayLike) -> np.ndarray: return fx @abc.abstractmethod - def _evaluate(self, x: np.ndarray) -> np.ndarray: + def _evaluate(self, x: backend.Array) -> backend.Array: pass def __neg__(self): @@ -128,7 +126,7 @@ def __sub__(self, other): @functools.singledispatchmethod def __mul__(self, other): - if np.ndim(other) == 0: + if backend.ndim(other) == 0: from ._algebra_fallbacks import ( # pylint: disable=import-outside-toplevel ScaledFunction, ) @@ -139,7 +137,7 @@ def __mul__(self, other): @functools.singledispatchmethod def __rmul__(self, other): - if np.ndim(other) == 0: + if backend.ndim(other) == 0: from ._algebra_fallbacks import ( # pylint: disable=import-outside-toplevel ScaledFunction, ) @@ -150,9 +148,9 @@ def __rmul__(self, other): class LambdaFunction(Function): - """Define a :class:`Function` from a given :class:`callable`. + """Define a :class:`Function` from a given :class:`Callable`. - Creates a :class:`Function` from a given :class:`callable` and in- and output + Creates a :class:`Function` from a given :class:`Callable` and in- and output shapes. This provides a convenient interface to define a :class:`Function`. Parameters @@ -166,10 +164,10 @@ class LambdaFunction(Function): Examples -------- - >>> import numpy as np + >>> from probnum import backend >>> from probnum.functions import LambdaFunction >>> fn = LambdaFunction(fn=lambda x: 2 * x + 1, input_shape=(2,), output_shape=(2,)) - >>> fn(np.array([[1, 2], [4, 5]])) + >>> fn(backend.asarray([[1, 2], [4, 5]])) array([[ 3, 5], [ 9, 11]]) @@ -180,7 +178,7 @@ class LambdaFunction(Function): def __init__( self, - fn: Callable[[np.ndarray], np.ndarray], + fn: Callable[[backend.Array], backend.Array], input_shape: ShapeLike, output_shape: ShapeLike = (), ) -> None: @@ -188,5 +186,5 @@ def __init__( super().__init__(input_shape, output_shape) - def _evaluate(self, x: np.ndarray) -> np.ndarray: + def _evaluate(self, x: backend.Array) -> backend.Array: return self._fn(x) diff --git a/src/probnum/functions/_zero.py b/src/probnum/functions/_zero.py index 2c6ae01ea..fd03f6487 100644 --- a/src/probnum/functions/_zero.py +++ b/src/probnum/functions/_zero.py @@ -2,7 +2,7 @@ import functools -import numpy as np +from probnum import backend from . 
import _function @@ -10,8 +10,8 @@ class Zero(_function.Function): """Zero mean function.""" - def _evaluate(self, x: np.ndarray) -> np.ndarray: - return np.zeros_like( # pylint: disable=unexpected-keyword-arg + def _evaluate(self, x: backend.Array) -> backend.Array: + return backend.zeros_like( x, shape=x.shape[: x.ndim - self._input_ndim] + self._output_shape, ) diff --git a/src/probnum/linalg/_problinsolve.py b/src/probnum/linalg/_problinsolve.py index 134caae3f..ce537e723 100644 --- a/src/probnum/linalg/_problinsolve.py +++ b/src/probnum/linalg/_problinsolve.py @@ -15,7 +15,7 @@ import scipy.sparse import probnum # pylint: disable=unused-import -from probnum import linops, randvars, utils +from probnum import linops, randvars from probnum.linalg.solvers.matrixbased import SymmetricMatrixBasedSolver from probnum.typing import LinearOperatorLike @@ -201,7 +201,7 @@ def problinsolve( # Select and initialize solver linear_solver = _init_solver( A=A, - b=utils.as_colvec(b[:, i]), + b=as_colvec(b[:, i]), A0=A0, Ainv0=Ainv0, x0=x, @@ -344,9 +344,9 @@ def _preprocess_linear_system(A, b, x0=None): """ # Transform linear system to correct dimensions if not isinstance(b, randvars.RandomVariable): - b = utils.as_colvec(b) # (n,) -> (n, 1) + b = as_colvec(b) # (n,) -> (n, 1) if x0 is not None: - x0 = utils.as_colvec(x0) # (n,) -> (n, 1) + x0 = as_colvec(x0) # (n,) -> (n, 1) return A, b, x0 @@ -477,3 +477,24 @@ def _postprocess(info, A): scipy.linalg.LinAlgWarning, stacklevel=3, ) + + +def as_colvec( + vec: Union[np.ndarray, "probnum.randvars.RandomVariable"] +) -> Union[np.ndarray, "probnum.randvars.RandomVariable"]: + """Transform the given vector or random variable to column format. Given a vector + (or random variable) of dimension (n,) return an array with dimensions (n, 1) + instead. Higher-dimensional arrays are not changed. + + Parameters + ---------- + vec + Vector, array or random variable to be transformed into a column vector. + """ + if isinstance(vec, probnum.randvars.RandomVariable): + if vec.shape != (vec.shape[0], 1): + return vec.reshape(newshape=(vec.shape[0], 1)) + else: + if vec.ndim == 1: + return vec[:, None] + return vec diff --git a/src/probnum/linalg/solvers/_probabilistic_linear_solver.py b/src/probnum/linalg/solvers/_probabilistic_linear_solver.py index 0c93aaf3e..e64544da3 100644 --- a/src/probnum/linalg/solvers/_probabilistic_linear_solver.py +++ b/src/probnum/linalg/solvers/_probabilistic_linear_solver.py @@ -72,14 +72,15 @@ class ProbabilisticLinearSolver( -------- Define a linear system. - >>> import numpy as np + >>> from probnum import backend >>> from probnum.problems import LinearSystem >>> from probnum.problems.zoo.linalg import random_spd_matrix - >>> rng = np.random.default_rng(42) + >>> rng_state = backend.random.rng_state(42) + >>> rng_state, rng_state_A, rng_state_b = backend.random.split(rng_state, 3) >>> n = 100 - >>> A = random_spd_matrix(rng=rng, dim=n) - >>> b = rng.standard_normal(size=(n,)) + >>> A = random_spd_matrix(rng_state=rng_state_A, shape=(n,n)) + >>> b = backend.random.standard_normal(rng_state_b, shape=(n,)) >>> linsys = LinearSystem(A=A, b=b) Create a custom probabilistic linear solver from pre-defined components. @@ -116,8 +117,8 @@ class ProbabilisticLinearSolver( Solve the linear system using the custom solver. 
>>> belief, solver_state = pls.solve(prior=prior, problem=linsys) - >>> np.linalg.norm(linsys.A @ belief.x.mean - linsys.b) / np.linalg.norm(linsys.b) - 7.1886e-06 + >>> backend.linalg.vector_norm(solver_state.residual) + array(6.56325045e-05) """ def __init__( diff --git a/src/probnum/linalg/solvers/belief_updates/solution_based/_projected_residual_belief_update.py b/src/probnum/linalg/solvers/belief_updates/solution_based/_projected_residual_belief_update.py index 1d6a7637f..234113717 100644 --- a/src/probnum/linalg/solvers/belief_updates/solution_based/_projected_residual_belief_update.py +++ b/src/probnum/linalg/solvers/belief_updates/solution_based/_projected_residual_belief_update.py @@ -4,8 +4,8 @@ import probnum # pylint: disable="unused-import" from probnum import randvars +from probnum.backend.typing import FloatLike from probnum.linalg.solvers.beliefs import LinearSystemBelief -from probnum.typing import FloatLike from .._linear_solver_belief_update import LinearSolverBeliefUpdate diff --git a/src/probnum/linalg/solvers/beliefs/_linear_system_belief.py b/src/probnum/linalg/solvers/beliefs/_linear_system_belief.py index 62be3da21..4c00a7df8 100644 --- a/src/probnum/linalg/solvers/beliefs/_linear_system_belief.py +++ b/src/probnum/linalg/solvers/beliefs/_linear_system_belief.py @@ -109,31 +109,6 @@ def dim_mismatch_error(**kwargs): """ ) - if x is not None and not isinstance(x, randvars.RandomVariable): - raise TypeError( - f"""The belief about the solution 'x' must be a RandomVariable, but - is {type(x)}. - """ - ) - if A is not None and not isinstance(A, randvars.RandomVariable): - raise TypeError( - f"""The belief about the matrix 'A' must be a RandomVariable, but - is {type(A)}. - """ - ) - if Ainv is not None and not isinstance(Ainv, randvars.RandomVariable): - raise TypeError( - f"""The belief about the inverse matrix 'Ainv' must be a RandomVariable, - but is {type(Ainv)}. - """ - ) - if b is not None and not isinstance(b, randvars.RandomVariable): - raise TypeError( - f"""The belief about the right-hand-side 'b' must be a RandomVariable, - but is {type(b)}. 
- """ - ) - self._x = x self._A = A self._Ainv = Ainv diff --git a/src/probnum/linalg/solvers/policies/_conjugate_gradient.py b/src/probnum/linalg/solvers/policies/_conjugate_gradient.py index 4896a315e..4d0a236c2 100644 --- a/src/probnum/linalg/solvers/policies/_conjugate_gradient.py +++ b/src/probnum/linalg/solvers/policies/_conjugate_gradient.py @@ -71,7 +71,9 @@ def __call__( prev_residual = solver_state.residuals[solver_state.step - 1] # A-conjugacy correction (in exact arithmetic) - beta = (np.linalg.norm(residual) / np.linalg.norm(prev_residual)) ** 2 + beta = ( + np.linalg.norm(residual, ord=2) / np.linalg.norm(prev_residual, ord=2) + ) ** 2 action = residual + beta * solver_state.actions[solver_state.step - 1] # Reorthogonalization of the resulting action diff --git a/src/probnum/linalg/solvers/stopping_criteria/_posterior_contraction.py b/src/probnum/linalg/solvers/stopping_criteria/_posterior_contraction.py index 683ea0ab4..7a6827272 100644 --- a/src/probnum/linalg/solvers/stopping_criteria/_posterior_contraction.py +++ b/src/probnum/linalg/solvers/stopping_criteria/_posterior_contraction.py @@ -2,8 +2,8 @@ import numpy as np -import probnum # pylint: disable="unused-import" -from probnum.typing import ScalarLike +import probnum +from probnum.backend.typing import ScalarLike from ._linear_solver_stopping_criterion import LinearSolverStoppingCriterion @@ -33,8 +33,8 @@ def __init__( rtol: ScalarLike = 10**-5, ): self.qoi = qoi - self.atol = probnum.utils.as_numpy_scalar(atol) - self.rtol = probnum.utils.as_numpy_scalar(rtol) + self.atol = probnum.backend.asscalar(atol) + self.rtol = probnum.backend.asscalar(rtol) def __call__( self, solver_state: "probnum.linalg.solvers.LinearSolverState" diff --git a/src/probnum/linalg/solvers/stopping_criteria/_residual_norm.py b/src/probnum/linalg/solvers/stopping_criteria/_residual_norm.py index bbc3cad7c..ebe1eeadb 100644 --- a/src/probnum/linalg/solvers/stopping_criteria/_residual_norm.py +++ b/src/probnum/linalg/solvers/stopping_criteria/_residual_norm.py @@ -3,7 +3,7 @@ import numpy as np import probnum -from probnum.typing import ScalarLike +from probnum.backend.typing import ScalarLike from ._linear_solver_stopping_criterion import LinearSolverStoppingCriterion @@ -28,8 +28,8 @@ def __init__( atol: ScalarLike = 10**-5, rtol: ScalarLike = 10**-5, ): - self.atol = probnum.utils.as_numpy_scalar(atol) - self.rtol = probnum.utils.as_numpy_scalar(rtol) + self.atol = probnum.backend.asscalar(atol) + self.rtol = probnum.backend.asscalar(rtol) def __call__( self, solver_state: "probnum.linalg.solvers.LinearSolverState" diff --git a/src/probnum/linops/__init__.py b/src/probnum/linops/__init__.py index dbe5006e7..7c553534d 100644 --- a/src/probnum/linops/__init__.py +++ b/src/probnum/linops/__init__.py @@ -22,7 +22,7 @@ Selection, ) from ._scaling import Scaling, Zero -from ._utils import LinearOperatorLike, aslinop +from ._utils import aslinop # Public classes and functions. Order is reflected in documentation. 
__all__ = [ diff --git a/src/probnum/linops/_arithmetic.py b/src/probnum/linops/_arithmetic.py index 7a88eb1a6..df64ee28c 100644 --- a/src/probnum/linops/_arithmetic.py +++ b/src/probnum/linops/_arithmetic.py @@ -4,8 +4,8 @@ import numpy as np import scipy.sparse -from probnum import config, utils -from probnum.typing import NotImplementedType, ScalarLike, ShapeLike +from probnum import backend, config +from probnum.backend.typing import NotImplementedType, ScalarLike, ShapeLike from ._arithmetic_fallbacks import ( NegatedLinearOperator, @@ -397,13 +397,13 @@ def _apply( ) -> Union[LinearOperator, NotImplementedType]: if np.ndim(op1) == 0: key1 = np.number - op1 = utils.as_numpy_scalar(op1) + op1 = backend.asscalar(op1) else: key1 = type(op1) if np.ndim(op2) == 0: key2 = np.number - op2 = utils.as_numpy_scalar(op2) + op2 = backend.asscalar(op2) else: key2 = type(op2) diff --git a/src/probnum/linops/_arithmetic_fallbacks.py b/src/probnum/linops/_arithmetic_fallbacks.py index 1ffa6f20b..afc99ff01 100644 --- a/src/probnum/linops/_arithmetic_fallbacks.py +++ b/src/probnum/linops/_arithmetic_fallbacks.py @@ -7,8 +7,8 @@ import numpy as np -from probnum.typing import NotImplementedType, ScalarLike -import probnum.utils +from probnum import backend +from probnum.backend.typing import NotImplementedType, ScalarLike from ._linear_operator import BinaryOperandType, LambdaLinearOperator, LinearOperator @@ -30,7 +30,7 @@ def __init__(self, linop: LinearOperator, scalar: ScalarLike): dtype = np.result_type(linop.dtype, scalar) self._linop = linop - self._scalar = probnum.utils.as_numpy_scalar(scalar, dtype) + self._scalar = backend.asscalar(scalar, dtype) super().__init__( self._linop.shape, @@ -72,7 +72,7 @@ def _symmetrize(self) -> ScaledLinearOperator: class NegatedLinearOperator(ScaledLinearOperator): def __init__(self, linop: LinearOperator): - super().__init__(linop, scalar=probnum.utils.as_numpy_scalar(-1, linop.dtype)) + super().__init__(linop, scalar=backend.asscalar(-1, linop.dtype)) def __neg__(self) -> "LinearOperator": return self._linop diff --git a/src/probnum/linops/_kronecker.py b/src/probnum/linops/_kronecker.py index 5af505a76..be02f0a23 100644 --- a/src/probnum/linops/_kronecker.py +++ b/src/probnum/linops/_kronecker.py @@ -5,7 +5,8 @@ import numpy as np -from probnum.typing import DTypeLike, LinearOperatorLike, NotImplementedType +from probnum.backend.typing import DTypeLike, NotImplementedType +from probnum.typing import LinearOperatorLike from . 
import _linear_operator, _utils diff --git a/src/probnum/linops/_linear_operator.py b/src/probnum/linops/_linear_operator.py index 28010fbf0..10ac71302 100644 --- a/src/probnum/linops/_linear_operator.py +++ b/src/probnum/linops/_linear_operator.py @@ -9,9 +9,8 @@ import scipy.linalg import scipy.sparse.linalg -from probnum import config -from probnum.typing import DTypeLike, ScalarLike, ShapeLike -import probnum.utils +from probnum import backend, config +from probnum.backend.typing import DTypeLike, ScalarLike, ShapeLike BinaryOperandType = Union[ "LinearOperator", ScalarLike, np.ndarray, scipy.sparse.spmatrix @@ -52,7 +51,7 @@ def __init__( shape: ShapeLike, dtype: DTypeLike, ): - self.__shape = probnum.utils.as_shape(shape, ndim=2) + self.__shape = backend.asshape(shape, ndim=2) # DType self.__dtype = np.dtype(dtype) @@ -123,8 +122,7 @@ def _apply(self, x: np.ndarray, axis: int) -> np.ndarray: raise NotImplementedError() def __call__(self, x: np.ndarray, axis: Optional[int] = None) -> np.ndarray: - """Apply the linear operator to an input array along a specified - axis. + """Apply the linear operator to an input array along a specified axis. Parameters ---------- @@ -557,7 +555,7 @@ def logabsdet(self) -> np.inexact: def _logabsdet_fallback(self) -> np.inexact: if self.det() == 0: - return probnum.utils.as_numpy_scalar(-np.inf, dtype=self._inexact_dtype) + return backend.asscalar(-np.inf, dtype=self._inexact_dtype) else: return np.log(np.abs(self.det())) @@ -1544,7 +1542,7 @@ def __init__( shape: ShapeLike, dtype: DTypeLike = np.double, ): - shape = probnum.utils.as_shape(shape) + shape = backend.asshape(shape) if len(shape) == 1: shape = 2 * shape @@ -1568,13 +1566,9 @@ def __init__( rank=lambda: np.intp(shape[0]), eigvals=lambda: np.ones(shape[0], dtype=self._inexact_dtype), cond=self._cond, - det=lambda: probnum.utils.as_numpy_scalar(1.0, dtype=self._inexact_dtype), - logabsdet=lambda: probnum.utils.as_numpy_scalar( - 0.0, dtype=self._inexact_dtype - ), - trace=lambda: probnum.utils.as_numpy_scalar( - self.shape[0], dtype=self.dtype - ), + det=lambda: backend.asscalar(1.0, dtype=self._inexact_dtype), + logabsdet=lambda: backend.asscalar(0.0, dtype=self._inexact_dtype), + trace=lambda: backend.asscalar(self.shape[0], dtype=self.dtype), ) # Matrix properties @@ -1588,11 +1582,9 @@ def _cond( self, p: Optional[Union[None, int, str, np.floating]] = None ) -> np.inexact: if p is None or p in (2, 1, np.inf, -2, -1, -np.inf): - return probnum.utils.as_numpy_scalar(1.0, dtype=self._inexact_dtype) + return backend.asscalar(1.0, dtype=self._inexact_dtype) elif p == "fro": - return probnum.utils.as_numpy_scalar( - self.shape[0], dtype=self._inexact_dtype - ) + return backend.asscalar(self.shape[0], dtype=self._inexact_dtype) else: return np.linalg.cond(self.todense(cache=False), p=p) @@ -1624,7 +1616,7 @@ def __init__(self, indices, shape, dtype=np.double): "output-dimension (shape[0]) is larger than the input-dimension " "(shape[1]), consider using `Embedding`." ) - self._indices = probnum.utils.as_shape(indices) + self._indices = backend.asshape(indices) assert len(self._indices) == shape[0] super().__init__( @@ -1676,8 +1668,8 @@ def __init__( "(shape[1]), consider using `Selection`." 
) - self._take_indices = probnum.utils.as_shape(take_indices) - self._put_indices = probnum.utils.as_shape(put_indices) + self._take_indices = backend.asshape(take_indices) + self._put_indices = backend.asshape(put_indices) self._fill_value = fill_value super().__init__( diff --git a/src/probnum/linops/_scaling.py b/src/probnum/linops/_scaling.py index 95b076105..c9c219b53 100644 --- a/src/probnum/linops/_scaling.py +++ b/src/probnum/linops/_scaling.py @@ -5,8 +5,8 @@ import numpy as np -from probnum.typing import DTypeLike, ScalarLike, ShapeLike -import probnum.utils +from probnum import backend +from probnum.backend.typing import DTypeLike, ScalarLike, ShapeLike from . import _linear_operator @@ -49,7 +49,7 @@ def __init__( if np.ndim(factors) == 0: # Isotropic scaling - self._scalar = probnum.utils.as_numpy_scalar(factors, dtype=dtype) + self._scalar = backend.asscalar(factors, dtype=dtype) if shape is None: raise ValueError( @@ -57,7 +57,7 @@ def __init__( "specified." ) - shape = probnum.utils.as_shape(shape) + shape = backend.asshape(shape) if len(shape) == 1: shape = 2 * shape @@ -110,7 +110,7 @@ def __init__( self._scalar.astype(self._inexact_dtype, copy=False) ** shape[0] ) logabsdet = lambda: ( - probnum.utils.as_numpy_scalar(-np.inf, dtype=self._inexact_dtype) + backend.asscalar(-np.inf, dtype=self._inexact_dtype) if self._scalar == 0 else shape[0] * np.log(np.abs(self._scalar)) ) @@ -272,7 +272,7 @@ def _cond_anisotropic(self, p: Union[None, int, float, str]) -> np.inexact: if abs_min == 0.0: # The operator is singular - return probnum.utils.as_numpy_scalar(np.inf, dtype=self._inexact_dtype) + return backend.asscalar(np.inf, dtype=self._inexact_dtype) if p is None: p = 2 @@ -301,11 +301,9 @@ def _cond_isotropic(self, p: Union[None, int, float, str]) -> np.inexact: return self._inexact_dtype.type(np.inf) if p is None or p in (2, 1, np.inf, -2, -1, -np.inf): - return probnum.utils.as_numpy_scalar(1.0, dtype=self._inexact_dtype) + return backend.asscalar(1.0, dtype=self._inexact_dtype) elif p == "fro": - return probnum.utils.as_numpy_scalar( - min(self.shape), dtype=self._inexact_dtype - ) + return backend.asscalar(min(self.shape), dtype=self._inexact_dtype) else: return np.linalg.cond(self.todense(cache=False), p=p) diff --git a/src/probnum/problems/_problems.py b/src/probnum/problems/_problems.py index 54a2f2398..736b6d884 100644 --- a/src/probnum/problems/_problems.py +++ b/src/probnum/problems/_problems.py @@ -9,7 +9,7 @@ import scipy.sparse from probnum import linops, randvars -from probnum.typing import FloatLike +from probnum.backend.typing import FloatLike @dataclasses.dataclass diff --git a/src/probnum/problems/zoo/filtsmooth/_filtsmooth_problems.py b/src/probnum/problems/zoo/filtsmooth/_filtsmooth_problems.py index 3d0dbdf20..3b21225cb 100644 --- a/src/probnum/problems/zoo/filtsmooth/_filtsmooth_problems.py +++ b/src/probnum/problems/zoo/filtsmooth/_filtsmooth_problems.py @@ -5,8 +5,8 @@ import numpy as np from probnum import diffeq, filtsmooth, problems, randprocs, randvars +from probnum.backend.typing import FloatLike, IntLike from probnum.problems.zoo import diffeq as diffeq_zoo -from probnum.typing import FloatLike, IntLike __all__ = [ "benes_daum", @@ -119,7 +119,7 @@ def car_tracking( initrv = randvars.Normal( np.zeros(model_dim), measurement_variance * np.eye(model_dim), - cov_cholesky=np.sqrt(measurement_variance) * np.eye(model_dim), + cache={"cov_cholesky": np.sqrt(measurement_variance) * np.eye(model_dim)}, ) # Set up regression problem diff --git 
a/src/probnum/problems/zoo/linalg/_random_linear_system.py b/src/probnum/problems/zoo/linalg/_random_linear_system.py index b0f9baafe..76c043207 100644 --- a/src/probnum/problems/zoo/linalg/_random_linear_system.py +++ b/src/probnum/problems/zoo/linalg/_random_linear_system.py @@ -6,16 +6,17 @@ import numpy as np import scipy.sparse -from probnum import linops, problems, randvars +from probnum import backend, linops, problems, randvars +from probnum.backend.random import RNGState from probnum.typing import LinearOperatorLike def random_linear_system( - rng: np.random.Generator, + rng_state: RNGState, matrix: Union[ LinearOperatorLike, Callable[ - [np.random.Generator, Optional[Any]], + [RNGState, Optional[Any]], Union[np.ndarray, scipy.sparse.spmatrix, linops.LinearOperator], ], ], @@ -24,20 +25,18 @@ def random_linear_system( ) -> problems.LinearSystem: """Random linear system. - Generate a random linear system from a (random) matrix. - If ``matrix`` is a callable instead of a matrix or linear operator, - the system matrix is sampled by passing the random generator - instance ``rng``. The solution of the linear system is set - to a realization from ``solution_rv``. If ``None`` the solution - is drawn from a standard normal distribution with iid components. + Generate a random linear system from a (random) matrix. If ``matrix`` is a callable + instead of a matrix or linear operator, the system matrix is sampled by passing the + random generator state ``rng_state``. The solution of the linear system is set to a + realization from ``solution_rv``. If ``None`` the solution is drawn from a + standard normal distribution with iid components. Parameters ---------- - rng - Random number generator. + rng_state + State of the random number generator. matrix - Matrix, linear operator or callable returning either - for a given random number generator instance. + Matrix, linear operator or callable returning either for a given RNG state. solution_rv Random variable from which the solution of the linear system is sampled. kwargs @@ -51,53 +50,58 @@ def random_linear_system( Examples -------- - >>> import numpy as np + >>> from probnum import backend >>> from probnum.problems.zoo.linalg import random_linear_system - >>> rng = np.random.default_rng(42) + >>> rng_state = backend.random.rng_state(42) Linear system with given system matrix. - >>> import scipy.stats - >>> unitary_matrix = scipy.stats.unitary_group.rvs(dim=5, random_state=rng) - >>> linsys_unitary = random_linear_system(rng, unitary_matrix) + >>> unitary_matrix = backend.random.uniform_so_group(rng_state, n=5) + >>> linsys_unitary = random_linear_system(rng_state, unitary_matrix) >>> np.abs(np.linalg.det(linsys_unitary.A)) 1.0 Linear system with random symmetric positive-definite matrix. >>> from probnum.problems.zoo.linalg import random_spd_matrix - >>> linsys_spd = random_linear_system(rng, random_spd_matrix, dim=2) + >>> linsys_spd = random_linear_system(rng_state, random_spd_matrix, shape=(2,2)) >>> linsys_spd - LinearSystem(A=array([[ 9.62543582, 3.14955953], - [ 3.14955953, 13.28720426]]), b=array([-2.7108139 , 1.10779288]), - solution=array([-0.33488503, 0.16275307])) + LinearSystem(A=array([[10.61706238, -0.78723358], + [-0.78723358, 10.06458988]]), b=array([3.96470544, 5.76555243]), + solution=array([0.41832997, 0.60557617])) Linear system with random sparse matrix. 
>>> import scipy.sparse - >>> random_sparse_matrix = lambda rng,m,n: scipy.sparse.random(m=m, n=n,\ - random_state=rng) - >>> linsys_sparse = random_linear_system(rng, random_sparse_matrix, m=4, n=2) + >>> from probnum.problems.zoo.linalg import random_sparse_spd_matrix + >>> linsys_sparse = random_linear_system( + ... rng_state, random_sparse_spd_matrix, shape=(10,10), density=0.1 + ... ) >>> isinstance(linsys_sparse.A, scipy.sparse.spmatrix) True """ + # Generate system matrix if isinstance(matrix, (np.ndarray, scipy.sparse.spmatrix, linops.LinearOperator)): A = matrix else: - A = matrix(rng=rng, **kwargs) + rng_state, matrix_rng_state = backend.random.split(rng_state, num=2) + + A = matrix(rng_state=matrix_rng_state, **kwargs) # Sample solution if solution_rv is None: n = A.shape[1] - x = randvars.Normal(mean=0.0, cov=1.0).sample(size=(n,), rng=rng) + x = backend.random.standard_normal(rng_state, shape=(n,)) else: if A.shape[1] != solution_rv.shape[0]: raise ValueError( f"Shape of the system matrix: {A.shape} must match shape \ of the solution: {solution_rv.shape}." ) - x = solution_rv.sample(size=(), rng=rng) + x = solution_rv.sample(rng_state=rng_state, sample_shape=()) return problems.LinearSystem(A=A, b=A @ x, solution=x) diff --git a/src/probnum/problems/zoo/linalg/_random_spd_matrix.py b/src/probnum/problems/zoo/linalg/_random_spd_matrix.py index f93e5a7c1..f6bfe8e76 100644 --- a/src/probnum/problems/zoo/linalg/_random_spd_matrix.py +++ b/src/probnum/problems/zoo/linalg/_random_spd_matrix.py @@ -6,14 +6,16 @@ import numpy as np import scipy.stats -from probnum.typing import IntLike +from probnum import backend +from probnum.backend.random import RNGState +from probnum.backend.typing import ShapeLike def random_spd_matrix( - rng: np.random.Generator, - dim: IntLike, + rng_state: RNGState, + shape: ShapeLike, spectrum: Sequence = None, -) -> np.ndarray: +) -> backend.Array: r"""Random symmetric positive definite matrix. Constructs a random symmetric positive definite matrix from a given spectrum. An @@ -25,10 +27,10 @@ Parameters ---------- - rng - Random number generator. - dim - Matrix dimension. + rng_state + State of the random number generator. + shape + Shape of the resulting matrix. spectrum Eigenvalues of the matrix. @@ -39,53 +41,60 @@ Examples -------- - >>> import numpy as np + >>> from probnum import backend >>> from probnum.problems.zoo.linalg import random_spd_matrix - >>> rng = np.random.default_rng(1) - >>> mat = random_spd_matrix(rng, dim=5) + >>> rng_state = backend.random.rng_state(1) + >>> mat = random_spd_matrix(rng_state, shape=(5, 5)) >>> mat - array([[10.24394619, 0.05484236, 0.39575826, -0.70032495, -0.75482692], - [ 0.05484236, 11.31516868, 0.6968935 , -0.13877394, 0.52783063], - [ 0.39575826, 0.6968935 , 11.5728974 , 0.21214568, 1.07692458], - [-0.70032495, -0.13877394, 0.21214568, 9.88674751, -1.09750511], - [-0.75482692, 0.52783063, 1.07692458, -1.09750511, 10.193655 ]]) + array([[ 8.93286789, 0.46676405, -2.10171474, 1.44158222, -0.32869563], + [ 0.46676405, 7.63938418, -2.45135608, 2.03734623, 0.8095071 ], + [-2.10171474, -2.45135608, 8.52968389, -0.11968995, 1.74237472], + [ 1.44158222, 2.03734623, -0.11968995, 8.58417432, -1.61553113], + [-0.32869563, 0.8095071 , 1.74237472, -1.61553113, 8.1054103 ]]) Check for symmetry and positive definiteness. 
- >>> np.all(mat == mat.T) + >>> backend.all(mat == mat.T) True - >>> np.linalg.eigvals(mat) - array([ 8.09147328, 12.7635956 , 10.84504988, 10.73086331, 10.78143272]) + >>> backend.linalg.eigvalsh(mat) + array([ 3.51041217, 7.80937731, 8.49510526, 8.76024149, 13.21638435]) """ + shape = backend.asshape(shape) + + if not shape == () and shape[0] != shape[1]: + raise ValueError(f"Shape must represent a square matrix, but is {shape}.") + + gamma_rng_state, so_rng_state = backend.random.split(rng_state, num=2) # Initialization if spectrum is None: - # Create a custom ordered spectrum if none is given. - spectrum_shape: float = 10.0 - spectrum_scale: float = 1.0 - spectrum_offset: float = 0.0 - - spectrum = scipy.stats.gamma.rvs( - spectrum_shape, - loc=spectrum_offset, - scale=spectrum_scale, - size=dim, - random_state=rng, + spectrum = backend.random.gamma( + gamma_rng_state, + shape_param=10.0, + scale_param=1.0, + shape=shape[:1], ) - spectrum = np.sort(spectrum)[::-1] - else: - spectrum = np.asarray(spectrum) - if not np.all(spectrum > 0): + spectrum = backend.asarray(spectrum) + + if spectrum.shape != shape[:1]: + raise ValueError( + f"Size of the spectrum {spectrum.shape} and shape {shape} are not " + + "compatible." + ) + + if not backend.all(spectrum > 0): raise ValueError(f"Eigenvalues must be positive, but are {spectrum}.") - # Early exit for d=1 -- special_ortho_group does not like this case. - if dim == 1: + if len(shape) == 0: + return spectrum + + if shape[0] == 1: return spectrum.reshape((1, 1)) # Draw orthogonal matrix with respect to the Haar measure - orth_mat = scipy.stats.special_ortho_group.rvs(dim, random_state=rng) - spd_mat = orth_mat @ np.diag(spectrum) @ orth_mat.T + orth_mat = backend.random.uniform_so_group(so_rng_state, n=shape[0]) + spd_mat = (orth_mat * spectrum[None, :]) @ orth_mat.T # Symmetrize to avoid numerically not symmetric matrix # Since A commutes with itself (AA' = A'A = AA) the eigenvalues do not change. @@ -93,8 +102,8 @@ def random_spd_matrix( def random_sparse_spd_matrix( - rng: np.random.Generator, - dim: IntLike, + rng_state: RNGState, + shape: ShapeLike, density: float, chol_entry_min: float = 0.1, chol_entry_max: float = 1.0, @@ -110,10 +119,10 @@ def random_sparse_spd_matrix( Parameters ---------- - rng - Random number generator. - dim - Matrix dimension. + rng_state + State of the random number generator. + shape + Shape of the resulting matrix. density Degree of sparsity of the off-diagonal entries of the Cholesky factor. Between 0 and 1 where 1 represents a dense matrix. 
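The hunks above replace the shared, mutable NumPy ``Generator`` with an explicit RNG state that is split into independent child states, one per random quantity drawn. A minimal sketch of the pattern, assuming the ``probnum.backend.random`` API introduced in this diff:

from probnum import backend
from probnum.problems.zoo.linalg import random_spd_matrix

# One integer seed yields an explicit state; splitting it gives
# independent child states for each random draw.
rng_state = backend.random.rng_state(42)
rng_state_A, rng_state_b = backend.random.split(rng_state, num=2)

A = random_spd_matrix(rng_state_A, shape=(5, 5))
b = backend.random.standard_normal(rng_state_b, shape=(5,))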
@@ -130,10 +139,10 @@ def random_sparse_spd_matrix( Examples -------- - >>> import numpy as np + >>> from probnum import backend >>> from probnum.problems.zoo.linalg import random_sparse_spd_matrix - >>> rng = np.random.default_rng(42) - >>> sparsemat = random_sparse_spd_matrix(rng, dim=5, density=0.1) + >>> rng_state = backend.random.rng_state(42) + >>> sparsemat = random_sparse_spd_matrix(rng_state, shape=(5,5), density=0.1) >>> sparsemat <5x5 sparse matrix of type '<class 'numpy.float64'>' with 9 stored elements in Compressed Sparse Row format> @@ -148,17 +157,20 @@ # Initialization if not 0 <= density <= 1: raise ValueError(f"Density must be between 0 and 1, but is {density}.") - chol = scipy.sparse.eye(dim, format="csr") - num_off_diag_cholesky = int(0.5 * dim * (dim - 1)) + if not shape == () and shape[0] != shape[1]: + raise ValueError(f"Shape must represent a square matrix, but is {shape}.") + + chol = scipy.sparse.eye(shape[0], format="csr") + num_off_diag_cholesky = int(0.5 * shape[0] * (shape[0] - 1)) num_nonzero_entries = int(num_off_diag_cholesky * density) if num_nonzero_entries > 0: sparse_matrix = scipy.sparse.rand( - m=dim, - n=dim, + m=shape[0], + n=shape[0], format="csr", density=density, - random_state=rng, + random_state=np.random.default_rng(rng_state), ) # Rescale entries diff --git a/src/probnum/quad/_bayesquad.py b/src/probnum/quad/_bayesquad.py index ffb185dea..36564ffc1 100644 --- a/src/probnum/quad/_bayesquad.py +++ b/src/probnum/quad/_bayesquad.py @@ -13,12 +13,12 @@ import numpy as np +from probnum.backend.typing import FloatLike, IntLike from probnum.quad.integration_measures import IntegrationMeasure, LebesgueMeasure from probnum.quad.solvers import BayesianQuadrature, BQIterInfo from probnum.quad.typing import DomainLike, DomainType from probnum.randprocs.kernels import Kernel from probnum.randvars import Normal -from probnum.typing import FloatLike, IntLike def bayesquad( diff --git a/src/probnum/quad/_utils.py b/src/probnum/quad/_utils.py index 4e7c5a1a5..9a651729d 100644 --- a/src/probnum/quad/_utils.py +++ b/src/probnum/quad/_utils.py @@ -1,4 +1,4 @@ -"""Helper functions for the quad package""" +"""Helper functions for the quad package.""" from __future__ import annotations @@ -6,7 +6,7 @@ import numpy as np -from probnum.typing import IntLike +from probnum.backend.typing import IntLike from .typing import DomainLike, DomainType @@ -14,8 +14,8 @@ def as_domain( domain: DomainLike, input_dim: Optional[IntLike] ) -> Tuple[DomainType, int]: - """Static method that converts the integration domain and input dimension to - the correct types. + """Converts the integration domain and input dimension to the correct types. If no ``input_dim`` is given, the dimension is inferred from the size of domain limits ``domain[0]`` and ``domain[1]``.
These must be either scalars diff --git a/src/probnum/quad/integration_measures/_integration_measures.py b/src/probnum/quad/integration_measures/_integration_measures.py index 6fb210c2b..a89063865 100644 --- a/src/probnum/quad/integration_measures/_integration_measures.py +++ b/src/probnum/quad/integration_measures/_integration_measures.py @@ -7,10 +7,10 @@ import numpy as np import scipy.stats +from probnum.backend.typing import FloatLike, IntLike from probnum.quad._utils import as_domain from probnum.quad.typing import DomainLike from probnum.randvars import Normal -from probnum.typing import FloatLike, IntLike class IntegrationMeasure(abc.ABC): diff --git a/src/probnum/quad/solvers/_bayesian_quadrature.py b/src/probnum/quad/solvers/_bayesian_quadrature.py index 6b037ecba..e91c71468 100644 --- a/src/probnum/quad/solvers/_bayesian_quadrature.py +++ b/src/probnum/quad/solvers/_bayesian_quadrature.py @@ -7,6 +7,7 @@ import numpy as np +from probnum.backend.typing import FloatLike, IntLike from probnum.quad.integration_measures import IntegrationMeasure, LebesgueMeasure from probnum.quad.kernel_embeddings import KernelEmbedding from probnum.quad.solvers._bq_state import BQIterInfo, BQState @@ -22,7 +23,6 @@ from probnum.quad.typing import DomainLike from probnum.randprocs.kernels import ExpQuad, Kernel from probnum.randvars import Normal -from probnum.typing import FloatLike, IntLike # pylint: disable=too-many-branches, too-complex diff --git a/src/probnum/quad/solvers/_bq_state.py b/src/probnum/quad/solvers/_bq_state.py index 337d5c571..ae4bf2f13 100644 --- a/src/probnum/quad/solvers/_bq_state.py +++ b/src/probnum/quad/solvers/_bq_state.py @@ -7,11 +7,11 @@ import numpy as np +from probnum.backend.typing import FloatLike from probnum.quad.integration_measures import IntegrationMeasure from probnum.quad.kernel_embeddings import KernelEmbedding from probnum.randprocs.kernels import Kernel from probnum.randvars import Normal -from probnum.typing import FloatLike # pylint: disable=too-few-public-methods,too-many-instance-attributes diff --git a/src/probnum/quad/solvers/belief_updates/_belief_update.py b/src/probnum/quad/solvers/belief_updates/_belief_update.py index a04bb5356..888f2c401 100644 --- a/src/probnum/quad/solvers/belief_updates/_belief_update.py +++ b/src/probnum/quad/solvers/belief_updates/_belief_update.py @@ -8,11 +8,11 @@ import numpy as np from scipy.linalg import cho_factor, cho_solve +from probnum.backend.typing import FloatLike from probnum.quad.kernel_embeddings import KernelEmbedding from probnum.quad.solvers._bq_state import BQState from probnum.randprocs.kernels import Kernel from probnum.randvars import Normal -from probnum.typing import FloatLike # pylint: disable=too-few-public-methods, too-many-locals @@ -63,7 +63,7 @@ def __call__( def _compute_gram_cho_factor(self, gram: np.ndarray) -> np.ndarray: """Compute the Cholesky decomposition of a positive-definite Gram matrix for use - in scipy.linalg.cho_solve + in scipy.linalg.cho_solve. .. warning:: Uses scipy.linalg.cho_factor. The returned matrix is only to be used in @@ -84,8 +84,11 @@ def _compute_gram_cho_factor(self, gram: np.ndarray) -> np.ndarray: # pylint: disable=no-self-use def _gram_cho_solve(self, gram_cho_factor: np.ndarray, z: np.ndarray) -> np.ndarray: - """Wrapper for scipy.linalg.cho_solve. Meant to be used for linear systems of - the gram matrix. Requires the solution of scipy.linalg.cho_factor as input.""" + """Wrapper for scipy.linalg.cho_solve. 
+ + Meant to be used for linear systems of the gram matrix. Requires the solution of + scipy.linalg.cho_factor as input. + """ return cho_solve(gram_cho_factor, z) @@ -173,8 +176,10 @@ def __call__( # pylint: disable=no-self-use def _estimate_kernel(self, kernel: Kernel) -> Tuple[Kernel, bool]: - """Estimate the intrinsic kernel parameters. That is, all parameters except the - scale.""" + """Estimate the intrinsic kernel parameters. + + That is, all parameters except the scale. + """ new_kernel = kernel kernel_was_updated = False return new_kernel, kernel_was_updated diff --git a/src/probnum/quad/solvers/stopping_criteria/_integral_variance_tol.py b/src/probnum/quad/solvers/stopping_criteria/_integral_variance_tol.py index dbcfbbf52..64a124e45 100644 --- a/src/probnum/quad/solvers/stopping_criteria/_integral_variance_tol.py +++ b/src/probnum/quad/solvers/stopping_criteria/_integral_variance_tol.py @@ -1,12 +1,12 @@ -"""Stopping criterion based on the absolute value of the integral variance""" +"""Stopping criterion based on the absolute value of the integral variance.""" from __future__ import annotations +from probnum.backend.typing import FloatLike from probnum.quad.solvers._bq_state import BQIterInfo, BQState from probnum.quad.solvers.stopping_criteria import BQStoppingCriterion -from probnum.typing import FloatLike -# pylint: disable=too-few-public-methods, fixme +# pylint: disable=too-few-public-methods class IntegralVarianceTolerance(BQStoppingCriterion): diff --git a/src/probnum/quad/solvers/stopping_criteria/_max_nevals.py b/src/probnum/quad/solvers/stopping_criteria/_max_nevals.py index bf87dd252..cf3a57ece 100644 --- a/src/probnum/quad/solvers/stopping_criteria/_max_nevals.py +++ b/src/probnum/quad/solvers/stopping_criteria/_max_nevals.py @@ -2,9 +2,9 @@ from __future__ import annotations +from probnum.backend.typing import IntLike from probnum.quad.solvers._bq_state import BQIterInfo, BQState from probnum.quad.solvers.stopping_criteria import BQStoppingCriterion -from probnum.typing import IntLike # pylint: disable=too-few-public-methods diff --git a/src/probnum/quad/solvers/stopping_criteria/_rel_mean_change.py b/src/probnum/quad/solvers/stopping_criteria/_rel_mean_change.py index 6e9144809..a43502f2c 100644 --- a/src/probnum/quad/solvers/stopping_criteria/_rel_mean_change.py +++ b/src/probnum/quad/solvers/stopping_criteria/_rel_mean_change.py @@ -5,9 +5,9 @@ import numpy as np +from probnum.backend.typing import FloatLike from probnum.quad.solvers._bq_state import BQIterInfo, BQState from probnum.quad.solvers.stopping_criteria import BQStoppingCriterion -from probnum.typing import FloatLike # pylint: disable=too-few-public-methods diff --git a/src/probnum/quad/typing.py b/src/probnum/quad/typing.py index f8009394a..42b2723f0 100644 --- a/src/probnum/quad/typing.py +++ b/src/probnum/quad/typing.py @@ -6,7 +6,7 @@ import numpy as np -from probnum.typing import FloatLike +from probnum.backend.typing import FloatLike __all__ = ["DomainLike", "DomainType"] diff --git a/src/probnum/randprocs/_gaussian_process.py b/src/probnum/randprocs/_gaussian_process.py index 4844e2cb3..65c6db784 100644 --- a/src/probnum/randprocs/_gaussian_process.py +++ b/src/probnum/randprocs/_gaussian_process.py @@ -2,16 +2,14 @@ from __future__ import annotations -import numpy as np - -from probnum import randvars -from probnum.typing import ArrayLike +from probnum import backend, randvars +from probnum.backend.typing import ArrayLike from . import _random_process, kernels from .. 
import functions -class GaussianProcess(_random_process.RandomProcess[ArrayLike, np.ndarray]): +class GaussianProcess(_random_process.RandomProcess[ArrayLike, backend.Array]): """Gaussian processes. A Gaussian process is a continuous stochastic process which if evaluated at a @@ -34,20 +32,19 @@ -------- Define a Gaussian process with a zero mean function and RBF kernel. - >>> import numpy as np - >>> from probnum.functions import Zero + >>> from probnum import backend, functions >>> from probnum.randprocs.kernels import ExpQuad >>> from probnum.randprocs import GaussianProcess - >>> mu = Zero(input_shape=()) # zero-mean function - >>> k = ExpQuad(input_shape=()) # RBF kernel + >>> mu = functions.Zero(input_shape=()) + >>> k = ExpQuad(input_shape=()) >>> gp = GaussianProcess(mu, k) Sample from the Gaussian process. - >>> x = np.linspace(-1, 1, 5) - >>> rng = np.random.default_rng(seed=42) - >>> gp.sample(rng, x) - array([-0.7539949 , -0.6658092 , -0.52972512, 0.0674298 , 0.72066223]) + >>> x = backend.linspace(-1, 1, 5) + >>> rng_state = backend.random.rng_state(seed=42) + >>> gp.sample(rng_state, x) + array([ 0.30471708, -0.22021158, -0.36160304, 0.05888274, 0.27793918]) >>> gp.cov.matrix(x) array([[1. , 0.8824969 , 0.60653066, 0.32465247, 0.13533528], [0.8824969 , 1. , 0.8824969 , 0.60653066, 0.32465247], @@ -67,13 +64,15 @@ def __init__( super().__init__( input_shape=mean.input_shape, output_shape=mean.output_shape, - dtype=np.dtype(np.float_), + dtype=backend.asdtype(backend.float64), mean=mean, cov=cov, ) def __call__(self, args: ArrayLike) -> randvars.Normal: return randvars.Normal( - mean=np.array(self.mean(args), copy=False), # pylint: disable=not-callable + mean=backend.asarray( + self.mean(args), copy=False # pylint: disable=not-callable + ), cov=self.cov.matrix(args), ) diff --git a/src/probnum/randprocs/_random_process.py b/src/probnum/randprocs/_random_process.py index 559c25e00..a8e719f3d 100644 --- a/src/probnum/randprocs/_random_process.py +++ b/src/probnum/randprocs/_random_process.py @@ -5,11 +5,10 @@ import abc from typing import Callable, Generic, Optional, Type, TypeVar, Union -import numpy as np - -from probnum import functions, randvars, utils as _utils +from probnum import backend, functions, randvars +from probnum.backend.random import RNGState +from probnum.backend.typing import DTypeLike, ShapeLike, ShapeType from probnum.randprocs import kernels -from probnum.typing import DTypeLike, ShapeLike, ShapeType InputType = TypeVar("InputType") OutputType = TypeVar("OutputType") @@ -31,7 +30,7 @@ class RandomProcess(Generic[InputType, OutputType], abc.ABC): Output shape of the random process. dtype Data type of the random process evaluated at an input. If ``object`` will be - converted to ``numpy.dtype``. + converted to :class:`~probnum.backend.DType`. mean Mean function of the random process. cov @@ -59,10 +58,10 @@ def __init__( mean: Optional[functions.Function] = None, cov: Optional[kernels.Kernel] = None, ): - self._input_shape = _utils.as_shape(input_shape) + self._input_shape = backend.asshape(input_shape) self._input_ndim = len(self._input_shape) - self._output_shape = _utils.as_shape(output_shape) + self._output_shape = backend.asshape(output_shape) self._output_ndim = len(self._output_shape) if self._output_ndim > 1: @@ -71,7 +70,7 @@ "dimension."
) - self._dtype = np.dtype(dtype) + self._dtype = backend.asdtype(dtype) # Mean function if mean is not None: @@ -135,7 +134,7 @@ def output_ndim(self) -> int: return self._output_ndim @property - def dtype(self) -> np.dtype: + def dtype(self) -> backend.DType: """Data type of (elements of) the random process evaluated at an input.""" return self._dtype @@ -147,7 +146,7 @@ def __repr__(self) -> str: ) @abc.abstractmethod - def __call__(self, args: InputType) -> randvars.RandomVariable[OutputType]: + def __call__(self, args: InputType) -> randvars.RandomVariable: """Evaluate the random process at a set of input arguments. Parameters @@ -233,7 +232,7 @@ def var(self, args: InputType) -> OutputType: assert self._output_ndim == 1 - return np.diagonal(pointwise_covs, axis1=-2, axis2=-1) + return backend.linalg.diagonal(pointwise_covs, axis1=-2, axis2=-1) def std(self, args: InputType) -> OutputType: """Standard deviation function. @@ -250,14 +249,14 @@ def std(self, args: InputType) -> OutputType: *shape=* ``batch_shape +`` :attr:`output_shape` -- Standard deviation of the process at ``args``. """ - return np.sqrt(self.var(args=args)) + return backend.sqrt(self.var(args=args)) def push_forward( self, args: InputType, base_measure: Type[randvars.RandomVariable], - sample: np.ndarray, - ) -> np.ndarray: + sample: backend.Array, + ) -> backend.Array: """Transform samples from a base measure into samples from the random process. This function can be used to control sampling from the random process by @@ -278,9 +277,9 @@ def push_forward( def sample( self, - rng: np.random.Generator, - args: InputType = None, - size: ShapeLike = (), + rng_state: RNGState, + args: Optional[InputType] = None, + sample_shape: ShapeLike = (), ) -> Union[Callable[[InputType], OutputType], OutputType]: """Sample paths from the random process. @@ -290,42 +289,45 @@ def sample( Parameters ---------- - rng - Random number generator. + rng_state + Random number generator state. args - *shape=* ``size +`` :attr:`input_shape` -- (Batch of) input(s) at + *shape=* ``sample_shape +`` :attr:`input_shape` -- (Batch of) input(s) at which the sample paths will be evaluated. Currently, we require - ``size`` to have at most one dimension. If ``None``, sample paths, + ``sample_shape`` to have at most one dimension. If ``None``, sample paths, i.e. callables are returned. - size - Size of the sample. + sample_shape + Shape of the sample. """ if args is None: raise NotImplementedError - return self._sample_at_input(rng=rng, args=args, size=size) + return self._sample_at_input( + rng_state=rng_state, args=args, sample_shape=sample_shape + ) def _sample_at_input( self, - rng: np.random.Generator, + rng_state: RNGState, args: InputType, - size: ShapeLike = (), + sample_shape: ShapeLike = (), ) -> OutputType: """Evaluate a set of sample paths at the given inputs. This function should be implemented by subclasses of :class:`RandomProcess`. This enables :meth:`sample` to both return functions, i.e. sample paths if - only a `size` is provided and random variables if inputs are provided as well. + only a `sample_shape` is provided and random variables if inputs are provided as + well. Parameters ---------- - rng - Random number generator. + rng_state + Random number generator state. args - *shape=* ``size +`` :attr:`input_shape` -- (Batch of) input(s) at + *shape=* ``sample_shape +`` :attr:`input_shape` -- (Batch of) input(s) at which the sample paths will be evaluated. Currently, we require - ``size`` to have at most one dimension. 
- size - Size of the sample. + ``sample_shape`` to have at most one dimension. + sample_shape + Shape of the sample. """ - return self(args).sample(rng, size=size) + return self(args).sample(rng_state=rng_state, sample_shape=sample_shape) diff --git a/src/probnum/randprocs/kernels/__init__.py b/src/probnum/randprocs/kernels/__init__.py index 526a0afe3..d9df641d5 100644 --- a/src/probnum/randprocs/kernels/__init__.py +++ b/src/probnum/randprocs/kernels/__init__.py @@ -21,13 +21,13 @@ __all__ = [ "Kernel", "IsotropicMixin", - "WhiteNoise", - "Linear", - "Polynomial", "ExpQuad", - "RatQuad", + "Linear", "Matern", + "Polynomial", "ProductMatern", + "RatQuad", + "WhiteNoise", ] # Set correct module paths. Corrects links and module paths in documentation. diff --git a/src/probnum/randprocs/kernels/_arithmetic_fallbacks.py b/src/probnum/randprocs/kernels/_arithmetic_fallbacks.py index 856ed8739..7a5ef168e 100644 --- a/src/probnum/randprocs/kernels/_arithmetic_fallbacks.py +++ b/src/probnum/randprocs/kernels/_arithmetic_fallbacks.py @@ -6,15 +6,13 @@ import operator from typing import Optional, Tuple, Union -import numpy as np - -from probnum import utils -from probnum.typing import NotImplementedType, ScalarLike +from probnum import backend +from probnum.backend.typing import NotImplementedType, ScalarLike from ._kernel import BinaryOperandType, Kernel ######################################################################################## -# Generic Linear Operator Arithmetic (Fallbacks) +# Generic Kernel Arithmetic (Fallbacks) ######################################################################################## @@ -41,17 +39,19 @@ def __init__(self, kernel: Kernel, scalar: ScalarLike): if not isinstance(kernel, Kernel): raise TypeError("`kernel` must be a `Kernel`") - if np.ndim(scalar) != 0: + if backend.ndim(scalar) != 0: raise TypeError("`scalar` must be a scalar.") self._kernel = kernel - self._scalar = utils.as_numpy_scalar(scalar) + self._scalar = backend.asscalar(scalar) super().__init__( input_shape=kernel.input_shape, output_shape=kernel.output_shape ) - def _evaluate(self, x0: np.ndarray, x1: Optional[np.ndarray] = None) -> np.ndarray: + def _evaluate( + self, x0: backend.Array, x1: Optional[backend.Array] = None + ) -> backend.Array: return self._scalar * self._kernel(x0, x1) def __repr__(self) -> str: @@ -90,7 +90,9 @@ def __init__(self, *summands: Kernel): input_shape=summands[0].input_shape, output_shape=summands[0].output_shape ) - def _evaluate(self, x0: np.ndarray, x1: Optional[np.ndarray]) -> np.ndarray: + def _evaluate( + self, x0: backend.Array, x1: Optional[backend.Array] + ) -> backend.Array: return functools.reduce( operator.add, (summand(x0, x1) for summand in self._summands) ) @@ -147,7 +149,9 @@ def __init__(self, *factors: Kernel): input_shape=factors[0].input_shape, output_shape=factors[0].output_shape ) - def _evaluate(self, x0: np.ndarray, x1: Optional[np.ndarray]) -> np.ndarray: + def _evaluate( + self, x0: backend.Array, x1: Optional[backend.Array] + ) -> backend.Array: return functools.reduce( operator.mul, (factor(x0, x1) for factor in self._factors) ) @@ -180,9 +184,9 @@ def _mul_fallback( if isinstance(op1, Kernel): if isinstance(op2, Kernel): res = ProductKernel(op1, op2) - elif np.ndim(op2) == 0: + elif backend.ndim(op2) == 0: res = ScaledKernel(kernel=op1, scalar=op2) elif isinstance(op2, Kernel): - if np.ndim(op1) == 0: + if backend.ndim(op1) == 0: res = ScaledKernel(kernel=op2, scalar=op1) return res diff --git 
a/src/probnum/randprocs/kernels/_exponentiated_quadratic.py b/src/probnum/randprocs/kernels/_exponentiated_quadratic.py index 5a868d73a..ea34d69c2 100644 --- a/src/probnum/randprocs/kernels/_exponentiated_quadratic.py +++ b/src/probnum/randprocs/kernels/_exponentiated_quadratic.py @@ -2,10 +2,8 @@ from typing import Optional -import numpy as np - -from probnum.typing import ScalarLike, ShapeLike -import probnum.utils as _utils +from probnum import backend +from probnum.backend.typing import ScalarLike, ShapeLike from ._kernel import IsotropicMixin, Kernel @@ -36,10 +34,10 @@ class ExpQuad(Kernel, IsotropicMixin): Examples -------- - >>> import numpy as np + >>> from probnum import backend >>> from probnum.randprocs.kernels import ExpQuad >>> K = ExpQuad(input_shape=(), lengthscale=0.1) - >>> xs = np.linspace(0, 1, 3) + >>> xs = backend.linspace(0, 1, 3) >>> K.matrix(xs) array([[1.00000000e+00, 3.72665317e-06, 1.92874985e-22], [3.72665317e-06, 1.00000000e+00, 3.72665317e-06], @@ -47,16 +45,19 @@ """ def __init__(self, input_shape: ShapeLike, lengthscale: ScalarLike = 1.0): - self.lengthscale = _utils.as_numpy_scalar(lengthscale) + self.lengthscale = backend.asscalar(lengthscale) super().__init__(input_shape=input_shape) - def _evaluate(self, x0: np.ndarray, x1: Optional[np.ndarray] = None) -> np.ndarray: + @backend.jit_method + def _evaluate( + self, x0: backend.Array, x1: Optional[backend.Array] + ) -> backend.Array: if x1 is None: - return np.ones_like( # pylint: disable=unexpected-keyword-arg + return backend.ones_like( # pylint: disable=unexpected-keyword-arg x0, shape=x0.shape[: x0.ndim - self.input_ndim], ) - return np.exp( + return backend.exp( -self._squared_euclidean_distances(x0, x1) / (2.0 * self.lengthscale**2) ) diff --git a/src/probnum/randprocs/kernels/_kernel.py b/src/probnum/randprocs/kernels/_kernel.py index 62c440696..299308454 100644 --- a/src/probnum/randprocs/kernels/_kernel.py +++ b/src/probnum/randprocs/kernels/_kernel.py @@ -3,12 +3,12 @@ from __future__ import annotations import abc +import functools +import operator from typing import Optional, Union -import numpy as np - -from probnum import utils as _pn_utils -from probnum.typing import ArrayLike, ScalarLike, ShapeLike, ShapeType +from probnum import backend +from probnum.backend.typing import ArrayLike, ScalarLike, ShapeLike, ShapeType BinaryOperandType = Union["Kernel", ScalarLike] @@ -138,7 +138,7 @@ def __init__( input_shape: ShapeLike, output_shape: ShapeLike = (), ): - self._input_shape = _pn_utils.as_shape(input_shape) + self._input_shape = backend.asshape(input_shape) self._input_ndim = len(self._input_shape) if self._input_ndim > 1: @@ -146,7 +146,7 @@ "Currently, we only support kernels with at most 1 input dimension." ) - self._output_shape = _pn_utils.as_shape(output_shape) + self._output_shape = backend.asshape(output_shape) self._output_ndim = len(self._output_shape) @property @@ -159,6 +159,11 @@ def input_ndim(self) -> int: """Syntactic sugar for ``len(input_shape)``.""" return self._input_ndim + @functools.cached_property + def input_size(self) -> int: + """Product over the entries of :attr:`input_shape`.""" + return functools.reduce(operator.mul, self._input_shape, 1) + @property def output_shape(self) -> ShapeType: """Shape of single, i.e. non-batched, return values of the covariance function.
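Two details of the ``Kernel`` base class are worth keeping in mind for the hunks below: ``input_size`` is the product over the entries of ``input_shape``, and ``matrix`` is syntactic sugar for a broadcasted ``__call__``. A short usage sketch, assuming the backend array API from this diff:

from probnum import backend
from probnum.randprocs.kernels import ExpQuad

k = ExpQuad(input_shape=(2,))
assert k.input_size == 2  # product over the entries of input_shape

xs = backend.asarray([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]])

# matrix(xs) computes the same (3, 3) Gram matrix as the explicitly
# broadcasted call below.
K = k.matrix(xs)
K_explicit = k(xs[:, None, :], xs[None, :, :])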
@@ -182,11 +187,12 @@ def __repr__(self) -> str: f" output_shape={self.output_shape}>" ) + @backend.jit_method def __call__( self, x0: ArrayLike, x1: Optional[ArrayLike], - ) -> np.ndarray: + ) -> backend.Array: """Evaluate the (cross-)covariance function(s). The evaluation of the (cross-covariance) function(s) is vectorized over the @@ -244,10 +250,10 @@ def __call__( See documentation of class :class:`Kernel`. """ - x0 = np.asarray(x0) + x0 = backend.asarray(x0) if x1 is not None: - x1 = np.asarray(x1) + x1 = backend.asarray(x1) # Shape checking broadcast_batch_shape = self._check_shapes( @@ -261,11 +267,12 @@ def __call__( return k_x0_x1 + @backend.jit_method def matrix( self, x0: ArrayLike, x1: Optional[ArrayLike] = None, - ) -> np.ndarray: + ) -> backend.Array: """A convenience function for computing a kernel matrix for two sets of inputs. This is syntactic sugar for ``k(x0[:, None], x1[None, :])``. Hence, it @@ -308,8 +315,8 @@ def matrix( See documentation of class :class:`Kernel`. """ - x0 = np.asarray(x0) - x1 = x0 if x1 is None else np.asarray(x1) + x0 = backend.asarray(x0) + x1 = x0 if x1 is None else backend.asarray(x1) # Shape checking errmsg = ( @@ -335,7 +342,7 @@ def _evaluate( self, x0: ArrayLike, x1: Optional[ArrayLike], - ) -> np.ndarray: + ) -> backend.Array: """Implementation of the kernel evaluation which is called after input checking. When implementing a particular kernel, the subclass should implement the kernel @@ -407,7 +414,7 @@ def _check_shapes( raise ValueError(err_msg.format(argname="x1", shape=x1_shape)) try: - broadcast_batch_shape = np.broadcast_shapes( + broadcast_batch_shape = backend.broadcast_shapes( broadcast_batch_shape, x1_shape[: len(x1_shape) - self._input_ndim], ) @@ -420,9 +427,10 @@ def _check_shapes( return broadcast_batch_shape + @backend.jit_method def _euclidean_inner_products( - self, x0: np.ndarray, x1: Optional[np.ndarray] - ) -> np.ndarray: + self, x0: backend.Array, x1: Optional[backend.Array] + ) -> backend.Array: """Implementation of the Euclidean inner product, which supports scalar inputs and an optional second argument.""" prods = x0**2 if x1 is None else x0 * x1 @@ -432,7 +440,7 @@ def _euclidean_inner_products( assert self.input_ndim == 1 - return np.sum(prods, axis=-1) + return backend.sum(prods, axis=-1) #################################################################################### # Binary Arithmetic @@ -478,13 +486,14 @@ class IsotropicMixin(abc.ABC): # pylint: disable=too-few-public-methods Hence, all isotropic kernels are stationary. 
""" + @backend.jit_method def _squared_euclidean_distances( - self, x0: np.ndarray, x1: Optional[np.ndarray] - ) -> np.ndarray: + self, x0: backend.Array, x1: Optional[backend.Array] + ) -> backend.Array: """Implementation of the squared Euclidean distance, which supports scalar inputs and an optional second argument.""" if x1 is None: - return np.zeros_like( # pylint: disable=unexpected-keyword-arg + return backend.zeros_like( x0, shape=x0.shape[: x0.ndim - self._input_ndim], ) @@ -496,17 +505,18 @@ def _squared_euclidean_distances( assert self.input_ndim == 1 - return np.sum(sqdiffs, axis=-1) + return backend.sum(sqdiffs, axis=-1) + @backend.jit_method def _euclidean_distances( - self, x0: np.ndarray, x1: Optional[np.ndarray] - ) -> np.ndarray: + self, x0: backend.Array, x1: Optional[backend.Array] + ) -> backend.Array: """Implementation of the Euclidean distance, which supports scalar inputs and an optional second argument.""" if x1 is None: - return np.zeros_like( # pylint: disable=unexpected-keyword-arg + return backend.zeros_like( x0, shape=x0.shape[: x0.ndim - self._input_ndim], ) - return np.sqrt(self._squared_euclidean_distances(x0, x1)) + return backend.sqrt(self._squared_euclidean_distances(x0, x1)) diff --git a/src/probnum/randprocs/kernels/_linear.py b/src/probnum/randprocs/kernels/_linear.py index 4cbd82993..6d870d4bc 100644 --- a/src/probnum/randprocs/kernels/_linear.py +++ b/src/probnum/randprocs/kernels/_linear.py @@ -4,10 +4,8 @@ from typing import Optional -import numpy as np - -from probnum.typing import ScalarLike, ShapeLike -import probnum.utils as _utils +from probnum import backend +from probnum.backend.typing import ScalarLike, ShapeLike from ._kernel import Kernel @@ -33,18 +31,21 @@ class Linear(Kernel): Examples -------- - >>> import numpy as np + >>> from probnum import backend >>> from probnum.randprocs.kernels import Linear >>> K = Linear(input_shape=2) - >>> xs = np.array([[1, 2], [2, 3]]) + >>> xs = backend.asarray([[1, 2], [2, 3]]) >>> K.matrix(xs) array([[ 5., 8.], [ 8., 13.]]) """ def __init__(self, input_shape: ShapeLike, constant: ScalarLike = 0.0): - self.constant = _utils.as_numpy_scalar(constant) + self.constant = backend.asscalar(constant) super().__init__(input_shape=input_shape) - def _evaluate(self, x0: np.ndarray, x1: Optional[np.ndarray]) -> np.ndarray: + @backend.jit_method + def _evaluate( + self, x0: backend.Array, x1: Optional[backend.Array] + ) -> backend.Array: return self._euclidean_inner_products(x0, x1) + self.constant diff --git a/src/probnum/randprocs/kernels/_matern.py b/src/probnum/randprocs/kernels/_matern.py index 99e28dc03..5c0320e51 100644 --- a/src/probnum/randprocs/kernels/_matern.py +++ b/src/probnum/randprocs/kernels/_matern.py @@ -2,12 +2,8 @@ from typing import Optional -import numpy as np -import scipy.spatial.distance -import scipy.special - -from probnum.typing import ScalarLike, ShapeLike -import probnum.utils as _utils +from probnum import backend +from probnum.backend.typing import FloatLike, ScalarLike, ShapeLike from ._kernel import IsotropicMixin, Kernel @@ -52,10 +48,10 @@ class Matern(Kernel, IsotropicMixin): Examples -------- - >>> import numpy as np + >>> from probnum import backend >>> from probnum.randprocs.kernels import Matern >>> K = Matern(input_shape=(), lengthscale=0.1, nu=2.5) - >>> xs = np.linspace(0, 1, 3) + >>> xs = backend.linspace(0, 1, 3) >>> K.matrix(xs) array([[1.00000000e+00, 7.50933789e-04, 3.69569622e-08], [7.50933789e-04, 1.00000000e+00, 7.50933789e-04], @@ -66,52 +62,55 @@ def __init__( 
self, input_shape: ShapeLike, lengthscale: ScalarLike = 1.0, - nu: ScalarLike = 1.5, + nu: FloatLike = 1.5, ): - self.lengthscale = _utils.as_numpy_scalar(lengthscale) - if not self.lengthscale > 0: + self.lengthscale = backend.asscalar(lengthscale) + if self.lengthscale <= 0.0: raise ValueError(f"Lengthscale l={self.lengthscale} must be positive.") - self.nu = _utils.as_numpy_scalar(nu) - if not self.nu > 0: + self.nu = float(nu) + if self.nu <= 0.0: raise ValueError(f"Hyperparameter nu={self.nu} must be positive.") super().__init__(input_shape=input_shape) - def _evaluate(self, x0: np.ndarray, x1: Optional[np.ndarray] = None) -> np.ndarray: + @backend.jit_method + def _evaluate( + self, x0: backend.Array, x1: Optional[backend.Array] = None + ) -> backend.Array: distances = self._euclidean_distances(x0, x1) # Kernel matrix computation dependent on differentiability if self.nu == 0.5: - return np.exp(-1.0 / self.lengthscale * distances) + return backend.exp(-1.0 / self.lengthscale * distances) if self.nu == 1.5: - scaled_distances = np.sqrt(3) / self.lengthscale * distances - return (1.0 + scaled_distances) * np.exp(-scaled_distances) + scaled_distances = backend.sqrt(3) / self.lengthscale * distances + return (1.0 + scaled_distances) * backend.exp(-scaled_distances) if self.nu == 2.5: - scaled_distances = np.sqrt(5) / self.lengthscale * distances - return (1.0 + scaled_distances + scaled_distances**2 / 3.0) * np.exp( + scaled_distances = backend.sqrt(5) / self.lengthscale * distances + return (1.0 + scaled_distances + scaled_distances**2 / 3.0) * backend.exp( -scaled_distances ) if self.nu == 3.5: - scaled_distances = np.sqrt(7) / self.lengthscale * distances + scaled_distances = backend.sqrt(7) / self.lengthscale * distances # Using Horner's method speeds up computations substantially return ( 1.0 + (1.0 + (2.0 / 5.0 + scaled_distances / 15.0) * scaled_distances) * scaled_distances - ) * np.exp(-scaled_distances) + ) * backend.exp(-scaled_distances) - if self.nu == np.inf: - return np.exp(-1.0 / (2.0 * self.lengthscale**2) * distances**2) + if self.nu == backend.inf: + return backend.exp(-1.0 / (2.0 * self.lengthscale**2) * distances**2) # The modified Bessel function K_nu is not defined for z=0 - distances = np.maximum(distances, np.finfo(distances.dtype).eps) + distances = backend.maximum(distances, backend.finfo(distances.dtype).eps) - scaled_distances = np.sqrt(2 * self.nu) / self.lengthscale * distances + scaled_distances = backend.sqrt(2 * self.nu) / self.lengthscale * distances return ( 2 ** (1.0 - self.nu) - / scipy.special.gamma(self.nu) + / backend.special.gamma(self.nu) * scaled_distances**self.nu - * scipy.special.kv(self.nu, scaled_distances) + * backend.special.modified_bessel2(scaled_distances, order=self.nu) ) diff --git a/src/probnum/randprocs/kernels/_polynomial.py b/src/probnum/randprocs/kernels/_polynomial.py index 6828dc583..518de421e 100644 --- a/src/probnum/randprocs/kernels/_polynomial.py +++ b/src/probnum/randprocs/kernels/_polynomial.py @@ -2,10 +2,8 @@ from typing import Optional -import numpy as np - -from probnum.typing import IntLike, ScalarLike, ShapeLike -import probnum.utils as _utils +from probnum import backend +from probnum.backend.typing import IntLike, ScalarLike, ShapeLike from ._kernel import Kernel @@ -33,10 +31,10 @@ class Polynomial(Kernel): Examples -------- - >>> import numpy as np + >>> from probnum import backend >>> from probnum.randprocs.kernels import Polynomial >>> K = Polynomial(input_shape=2, constant=1.0, exponent=3) - >>> xs = 
np.array([[1, -1], [-1, 0]]) + >>> xs = backend.asarray([[1, -1], [-1, 0]]) >>> K.matrix(xs) array([[27., 0.], [ 0., 8.]]) @@ -48,9 +46,12 @@ def __init__( constant: ScalarLike = 0.0, exponent: IntLike = 1.0, ): - self.constant = _utils.as_numpy_scalar(constant) - self.exponent = _utils.as_numpy_scalar(exponent) + self.constant = backend.asscalar(constant) + self.exponent = backend.asscalar(exponent) super().__init__(input_shape=input_shape) - def _evaluate(self, x0: np.ndarray, x1: Optional[np.ndarray] = None) -> np.ndarray: + @backend.jit_method + def _evaluate( + self, x0: backend.Array, x1: Optional[backend.Array] = None + ) -> backend.Array: return (self._euclidean_inner_products(x0, x1) + self.constant) ** self.exponent diff --git a/src/probnum/randprocs/kernels/_product_matern.py b/src/probnum/randprocs/kernels/_product_matern.py index e4f522d33..d302f295a 100644 --- a/src/probnum/randprocs/kernels/_product_matern.py +++ b/src/probnum/randprocs/kernels/_product_matern.py @@ -1,11 +1,9 @@ """Product Matern kernel.""" -from typing import Optional, Union +from typing import Optional -import numpy as np - -from probnum import utils as _utils -from probnum.typing import ScalarLike, ShapeLike +from probnum import backend +from probnum.backend.typing import ArrayLike, ShapeLike from ._kernel import Kernel from ._matern import Matern @@ -38,12 +36,12 @@ class ProductMatern(Kernel): Examples -------- - >>> import numpy as np + >>> from probnum import backend >>> from probnum.randprocs.kernels import ProductMatern - >>> lengthscales = np.array([0.1, 1.2]) - >>> nus = np.array([0.5, 3.5]) + >>> lengthscales = backend.asarray([0.1, 1.2]) + >>> nus = backend.asarray([0.5, 3.5]) >>> K = ProductMatern(input_shape=(2,), lengthscales=lengthscales, nus=nus) - >>> xs = np.array([[0.0, 0.5], [1.0, 1.0], [0.5, 0.2]]) + >>> xs = backend.asarray([[0.0, 0.5], [1.0, 1.0], [0.5, 0.2]]) >>> K.matrix(xs) array([[1.00000000e+00, 4.03712525e-05, 6.45332482e-03], [4.03712525e-05, 1.00000000e+00, 5.05119251e-03], @@ -58,11 +56,14 @@ class ProductMatern(Kernel): def __init__( self, input_shape: ShapeLike, - lengthscales: Union[np.ndarray, ScalarLike], - nus: Union[np.ndarray, ScalarLike], + lengthscales: ArrayLike, + nus: ArrayLike, ): - input_shape = _utils.as_shape(input_shape) - if input_shape == () and not (np.isscalar(lengthscales) and np.isscalar(nus)): + input_shape = backend.asshape(input_shape) + + if input_shape == () and not ( + backend.ndim(lengthscales) == 0 and backend.ndim(nus) == 0 + ): raise ValueError( f"'lengthscales' and 'nus' must be scalar if 'input_shape' is " f"{input_shape}." 
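The constructor in the hunk below assembles one univariate Matérn kernel per input dimension, and ``_evaluate`` multiplies their values. A hedged sketch of this factorization property (the assertion of exact agreement follows directly from how the product kernel is defined):

from probnum import backend
from probnum.randprocs.kernels import Matern, ProductMatern

k_prod = ProductMatern(
    input_shape=(2,),
    lengthscales=backend.asarray([0.1, 1.2]),
    nus=backend.asarray([0.5, 3.5]),
)
k0 = Matern(input_shape=(), lengthscale=0.1, nu=0.5)
k1 = Matern(input_shape=(), lengthscale=1.2, nu=3.5)

x0 = backend.asarray([0.0, 0.5])
x1 = backend.asarray([1.0, 1.0])

# The product Matern factorizes over the input dimensions.
lhs = k_prod(x0, x1)
rhs = k0(x0[0], x1[0]) * k1(x0[1], x1[1])
assert abs(float(lhs) - float(rhs)) < 1e-12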
@@ -72,33 +73,34 @@ def __init__( # If only a single scalar lengthscale or nu is given, use it in every dimension def expand_array(x, ndim): - return np.full((ndim,), _utils.as_numpy_scalar(x)) + return backend.full((ndim,), backend.asscalar(x)) - if isinstance(lengthscales, np.ndarray): - if lengthscales.shape == (): - lengthscales = expand_array(lengthscales, input_dim) - if isinstance(nus, np.ndarray): - if nus.shape == (): - nus = expand_array(nus, input_dim) + lengthscales = backend.asarray(lengthscales) - # also expand if scalars are given - if np.isscalar(lengthscales): + if lengthscales.shape == (): lengthscales = expand_array(lengthscales, input_dim) - if np.isscalar(nus): + + self.lengthscales = lengthscales + + nus = backend.asarray(nus) + + if nus.shape == (): nus = expand_array(nus, input_dim) + self.nus = nus + univariate_materns = [] for dim in range(input_dim): univariate_materns.append( Matern(input_shape=(), lengthscale=lengthscales[dim], nu=nus[dim]) ) self.univariate_materns = univariate_materns - self.nus = nus - self.lengthscales = lengthscales super().__init__(input_shape=input_shape) - def _evaluate(self, x0: np.ndarray, x1: Optional[np.ndarray] = None) -> np.ndarray: + def _evaluate( + self, x0: backend.Array, x1: Optional[backend.Array] = None + ) -> backend.Array: # scalar case is same as a scalar Matern if self.input_shape == (): diff --git a/src/probnum/randprocs/kernels/_rational_quadratic.py b/src/probnum/randprocs/kernels/_rational_quadratic.py index ad970a0b1..128bb670c 100644 --- a/src/probnum/randprocs/kernels/_rational_quadratic.py +++ b/src/probnum/randprocs/kernels/_rational_quadratic.py @@ -2,10 +2,8 @@ from typing import Optional -import numpy as np - -from probnum.typing import ScalarLike, ShapeLike -import probnum.utils as _utils +from probnum import backend +from probnum.backend.typing import ScalarLike, ShapeLike from ._kernel import IsotropicMixin, Kernel @@ -46,10 +44,10 @@ class RatQuad(Kernel, IsotropicMixin): Examples -------- - >>> import numpy as np + >>> from probnum import backend >>> from probnum.randprocs.kernels import RatQuad >>> K = RatQuad(input_shape=1, lengthscale=0.1, alpha=3) - >>> xs = np.linspace(0, 1, 3)[:, None] + >>> xs = backend.linspace(0, 1, 3)[:, None] >>> K(xs[:, None, :], xs[None, :, :]) array([[1.00000000e+00, 7.25051190e-03, 1.81357765e-04], [7.25051190e-03, 1.00000000e+00, 7.25051190e-03], @@ -62,15 +60,17 @@ def __init__( self, input_shape: ShapeLike, lengthscale: ScalarLike = 1.0, alpha: ScalarLike = 1.0, ): - self.lengthscale = _utils.as_numpy_scalar(lengthscale) - self.alpha = _utils.as_numpy_scalar(alpha) + self.lengthscale = backend.asscalar(lengthscale) + self.alpha = backend.asscalar(alpha) if not self.alpha > 0: raise ValueError(f"Scale mixture alpha={self.alpha} must be positive.") super().__init__(input_shape=input_shape) - def _evaluate(self, x0: np.ndarray, x1: Optional[np.ndarray] = None) -> np.ndarray: + def _evaluate( + self, x0: backend.Array, x1: Optional[backend.Array] = None + ) -> backend.Array: if x1 is None: - return np.ones_like( # pylint: disable=unexpected-keyword-arg + return backend.ones_like( # pylint: disable=unexpected-keyword-arg x0, shape=x0.shape[: x0.ndim - self.input_ndim], ) diff --git a/src/probnum/randprocs/kernels/_white_noise.py b/src/probnum/randprocs/kernels/_white_noise.py index f66659334..3b4e1d894 100644 --- a/src/probnum/randprocs/kernels/_white_noise.py +++ b/src/probnum/randprocs/kernels/_white_noise.py @@ -2,10 +2,8 @@ from typing import Optional -import numpy as np - -from probnum import utils
as _utils -from probnum.typing import ScalarLike, ShapeLike +from probnum import backend +from probnum.backend.typing import ScalarLike, ShapeLike from ._kernel import Kernel @@ -31,13 +29,15 @@ def __init__(self, input_shape: ShapeLike, sigma_sq: ScalarLike = 1.0): if sigma_sq < 0: raise ValueError(f"Noise level sigma_sq={sigma_sq} must be non-negative.") - self.sigma_sq = _utils.as_numpy_scalar(sigma_sq) + self.sigma_sq = backend.asscalar(sigma_sq) super().__init__(input_shape=input_shape) - def _evaluate(self, x0: np.ndarray, x1: Optional[np.ndarray]) -> np.ndarray: + def _evaluate( + self, x0: backend.Array, x1: Optional[backend.Array] + ) -> backend.Array: if x1 is None: - return np.full_like( # pylint: disable=unexpected-keyword-arg + return backend.full_like( # pylint: disable=unexpected-keyword-arg x0, self.sigma_sq, shape=x0.shape[: x0.ndim - self.input_ndim], @@ -46,4 +46,4 @@ def _evaluate(self, x0: np.ndarray, x1: Optional[np.ndarray]) -> np.ndarray: if self.input_shape == (): return self.sigma_sq * (x0 == x1) - return self.sigma_sq * np.all(x0 == x1, axis=-1) + return self.sigma_sq * backend.all(x0 == x1, axis=-1) diff --git a/src/probnum/randprocs/markov/_markov.py b/src/probnum/randprocs/markov/_markov.py index bb0d89dbf..ecf660922 100644 --- a/src/probnum/randprocs/markov/_markov.py +++ b/src/probnum/randprocs/markov/_markov.py @@ -1,17 +1,12 @@ """Markovian processes.""" -from typing import Optional, Union +from typing import Optional -import numpy as np -import scipy.stats - -from probnum import functions, randvars, utils +from probnum import backend, functions, randvars +from probnum.backend.random import RNGState +from probnum.backend.typing import ArrayLike, ShapeLike from probnum.randprocs import _random_process, kernels from probnum.randprocs.markov import _transition, continuous, discrete -from probnum.typing import ShapeLike - -InputType = Union[np.floating, np.ndarray] -OutputType = Union[np.floating, np.ndarray] class _MarkovBase(_random_process.RandomProcess): @@ -30,7 +25,7 @@ def __init__( super().__init__( input_shape=input_shape, output_shape=output_shape, - dtype=np.dtype(np.float_), + dtype=backend.float64, mean=functions.LambdaFunction( lambda x: self.__call__(args=x).mean, input_shape=input_shape, @@ -43,42 +38,43 @@ def __init__( ), ) - def __call__(self, args: InputType) -> randvars.RandomVariable: + def __call__(self, args: ArrayLike) -> randvars.RandomVariable: raise NotImplementedError def _sample_at_input( self, - rng: np.random.Generator, - args: InputType, - size: ShapeLike = (), - ) -> OutputType: + rng_state: RNGState, + args: ArrayLike, + sample_shape: ShapeLike = (), + ) -> backend.Array: - size = utils.as_shape(size) - args = np.atleast_1d(args) + sample_shape = backend.asshape(sample_shape) + args = backend.asarray(args) if args.ndim > 1: raise ValueError(f"Invalid args shape {args.shape}") - base_measure_realizations = scipy.stats.norm.rvs( - size=(size + args.shape + self.initrv.shape), random_state=rng + base_measure_realizations = backend.random.standard_normal( + rng_state=rng_state, + shape=(sample_shape + args.shape + self.initrv.shape), ) - if size == (): - return np.array( + if sample_shape == (): + return backend.asarray( self.transition.jointly_transform_base_measure_realization_list_forward( base_measure_realizations=base_measure_realizations, t=args, initrv=self.initrv, - _diffusion_list=np.ones_like(args[:-1]), + _diffusion_list=backend.ones_like(args[:-1]), ) ) - return np.stack( + return backend.stack( [ 
self.transition.jointly_transform_base_measure_realization_list_forward( base_measure_realizations=base_real, t=args, initrv=self.initrv, - _diffusion_list=np.ones_like(args[:-1]), + _diffusion_list=backend.ones_like(args[:-1]), ) for base_real in base_measure_realizations ] @@ -95,7 +91,9 @@ def __init__( output_shape=output_shape, ) - def _evaluate(self, x0: np.ndarray, x1: Optional[np.ndarray]) -> np.ndarray: + def _evaluate( + self, x0: backend.Array, x1: Optional[backend.Array] + ) -> backend.Array: if x1 is None: return self._markov_proc_call(args=x0).cov @@ -129,7 +127,7 @@ class MarkovProcess(_MarkovBase): def __init__( self, *, - initarg: np.ndarray, + initarg: backend.Array, initrv: randvars.RandomVariable, transition: continuous.SDE, ): @@ -140,7 +138,7 @@ def __init__( super().__init__( initrv=initrv, transition=transition, - input_shape=np.asarray(initarg).shape, + input_shape=backend.asarray(initarg).shape, ) self.initarg = initarg @@ -151,7 +149,7 @@ class MarkovSequence(_MarkovBase): def __init__( self, *, - initarg: np.ndarray, + initarg: backend.Array, initrv: randvars.RandomVariable, transition: continuous.SDE, ): @@ -162,6 +160,6 @@ def __init__( super().__init__( initrv=initrv, transition=transition, - input_shape=np.asarray(initarg).shape, + input_shape=backend.asarray(initarg).shape, ) self.initarg = initarg diff --git a/src/probnum/randprocs/markov/_transition.py b/src/probnum/randprocs/markov/_transition.py index 25c94f004..0cbbaa4bb 100644 --- a/src/probnum/randprocs/markov/_transition.py +++ b/src/probnum/randprocs/markov/_transition.py @@ -5,7 +5,7 @@ import numpy as np from probnum import randvars -from probnum.typing import FloatLike, IntLike +from probnum.backend.typing import FloatLike, IntLike class Transition(abc.ABC): @@ -379,7 +379,7 @@ def jointly_transform_base_measure_realization_list_backward( """ curr_rv = rv_list[-1] - curr_sample = curr_rv.mean + curr_rv.cov_cholesky @ base_measure_realizations[ + curr_sample = curr_rv.mean + curr_rv._cov_cholesky @ base_measure_realizations[ -1 ].reshape((-1,)) out_samples = [curr_sample] @@ -403,7 +403,7 @@ def jointly_transform_base_measure_realization_list_backward( ) curr_sample = ( curr_rv.mean - + curr_rv.cov_cholesky + + curr_rv._cov_cholesky @ base_measure_realizations[idx - 1].reshape( -1, ) @@ -448,7 +448,7 @@ def jointly_transform_base_measure_realization_list_forward( """ curr_rv = initrv - curr_sample = curr_rv.mean + curr_rv.cov_cholesky @ base_measure_realizations[ + curr_sample = curr_rv.mean + curr_rv._cov_cholesky @ base_measure_realizations[ 0 ].reshape((-1,)) out_samples = [curr_sample] @@ -470,7 +470,7 @@ def jointly_transform_base_measure_realization_list_forward( ) curr_sample = ( curr_rv.mean - + curr_rv.cov_cholesky + + curr_rv._cov_cholesky @ base_measure_realizations[idx - 1].reshape((-1,)) ) out_samples.append(curr_sample) diff --git a/src/probnum/randprocs/markov/continuous/_diffusions.py b/src/probnum/randprocs/markov/continuous/_diffusions.py index 43d33cf1d..a103633fa 100644 --- a/src/probnum/randprocs/markov/continuous/_diffusions.py +++ b/src/probnum/randprocs/markov/continuous/_diffusions.py @@ -8,7 +8,7 @@ import scipy.linalg from probnum import randvars -from probnum.typing import ArrayIndicesLike, ArrayLike, FloatLike +from probnum.backend.typing import ArrayIndicesLike, ArrayLike, FloatLike class Diffusion(abc.ABC): @@ -194,7 +194,7 @@ def tmax(self) -> float: def _compute_local_quasi_mle(meas_rv): - std_like = meas_rv.cov_cholesky + std_like = meas_rv._cov_cholesky 
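# Whiten the residual: solve the lower-triangular system L z = mean, where L is the Cholesky factor of the measurement covariance.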
whitened_res = scipy.linalg.solve_triangular(std_like, meas_rv.mean, lower=True) ssq = whitened_res @ whitened_res / meas_rv.size return ssq diff --git a/src/probnum/randprocs/markov/continuous/_linear_sde.py b/src/probnum/randprocs/markov/continuous/_linear_sde.py index 2e7f611f8..6b1da0929 100644 --- a/src/probnum/randprocs/markov/continuous/_linear_sde.py +++ b/src/probnum/randprocs/markov/continuous/_linear_sde.py @@ -7,9 +7,9 @@ import scipy.linalg from probnum import randvars +from probnum.backend.linalg import tril_to_positive_tril +from probnum.backend.typing import FloatLike, IntLike from probnum.randprocs.markov.continuous import _sde -from probnum.typing import FloatLike, IntLike -from probnum.utils.linalg import tril_to_positive_tril class LinearSDE(_sde.SDE): @@ -212,7 +212,7 @@ def _solve_mde_forward_sqrt(self, rv, t, dt, _diffusion=1.0): ) return randvars.Normal( - mean=new_mean, cov=new_cov, cov_cholesky=new_cov_cholesky + mean=new_mean, cov=new_cov, cache={"cov_cholesky": new_cov_cholesky} ), { "sol": sol, "sol_mean": sol_mean, @@ -403,7 +403,7 @@ def f(t, y): y_new = np.hstack((new_mean, new_cov_cholesky_flat)) return y_new - initcov_cholesky_flat = initrv.cov_cholesky.flatten() + initcov_cholesky_flat = initrv._cov_cholesky.flatten() y0 = np.hstack((initrv.mean, initcov_cholesky_flat)) return f, y0 diff --git a/src/probnum/randprocs/markov/continuous/_sde.py b/src/probnum/randprocs/markov/continuous/_sde.py index 0dd54a77a..435c51781 100644 --- a/src/probnum/randprocs/markov/continuous/_sde.py +++ b/src/probnum/randprocs/markov/continuous/_sde.py @@ -4,8 +4,8 @@ import numpy as np +from probnum.backend.typing import FloatLike, IntLike from probnum.randprocs.markov import _transition -from probnum.typing import FloatLike, IntLike class SDE(_transition.Transition): diff --git a/src/probnum/randprocs/markov/discrete/_condition_state.py b/src/probnum/randprocs/markov/discrete/_condition_state.py index 4ff4b35c7..bbf3f0dbc 100644 --- a/src/probnum/randprocs/markov/discrete/_condition_state.py +++ b/src/probnum/randprocs/markov/discrete/_condition_state.py @@ -6,7 +6,9 @@ def condition_state_on_measurement(measurement, forwarded_rv, rv, gain): zero_mat = np.zeros((len(measurement), len(measurement))) - meas_as_rv = randvars.Normal(mean=measurement, cov=zero_mat, cov_cholesky=zero_mat) + meas_as_rv = randvars.Normal( + mean=measurement, cov=zero_mat, cache={"cov_cholesky": zero_mat} + ) return condition_state_on_rv(meas_as_rv, forwarded_rv, rv, gain) diff --git a/src/probnum/randprocs/markov/discrete/_linear_gaussian.py b/src/probnum/randprocs/markov/discrete/_linear_gaussian.py index 150d095c7..afa4ce7be 100644 --- a/src/probnum/randprocs/markov/discrete/_linear_gaussian.py +++ b/src/probnum/randprocs/markov/discrete/_linear_gaussian.py @@ -7,9 +7,10 @@ import scipy.linalg from probnum import config, linops, randvars +from probnum.backend.linalg import cholesky_update, tril_to_positive_tril +from probnum.backend.typing import FloatLike, IntLike from probnum.randprocs.markov.discrete import _nonlinear_gaussian -from probnum.typing import FloatLike, IntLike, LinearOperatorLike -from probnum.utils.linalg import cholesky_update, tril_to_positive_tril +from probnum.typing import LinearOperatorLike class LinearGaussian(_nonlinear_gaussian.NonlinearGaussian): @@ -193,11 +194,11 @@ def _forward_rv_sqrt( H = self.transition_matrix_fun(t) noise = self.noise_fun(t) - shift, SR = noise.mean, noise.cov_cholesky + shift, SR = noise.mean, noise._cov_cholesky new_mean = H @ rv.mean + shift 
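# Square-root prediction: combine the propagated factor H @ L with the scaled process-noise factor into one lower-triangular Cholesky factor of the predicted covariance (assuming the QR-based cholesky_update from probnum.backend.linalg).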
new_cov_cholesky = cholesky_update( - H @ rv.cov_cholesky, np.sqrt(_diffusion) * SR + H @ rv._cov_cholesky, np.sqrt(_diffusion) * SR ) new_cov = new_cov_cholesky @ new_cov_cholesky.T crosscov = rv.cov @ H.T @@ -207,7 +208,9 @@ def _forward_rv_sqrt( (new_cov_cholesky, True), crosscov.T ).T return ( - randvars.Normal(new_mean, cov=new_cov, cov_cholesky=new_cov_cholesky), + randvars.Normal( + new_mean, cov=new_cov, cache={"cov_cholesky": new_cov_cholesky} + ), info, ) @@ -247,10 +250,10 @@ def _backward_rv_sqrt( state_trans = self.transition_matrix_fun(t) noise = self.noise_fun(t) shift = noise.mean - proc_noise_chol = np.sqrt(_diffusion) * noise.cov_cholesky + proc_noise_chol = np.sqrt(_diffusion) * noise._cov_cholesky - chol_past = rv.cov_cholesky - chol_obtained = rv_obtained.cov_cholesky + chol_past = rv._cov_cholesky + chol_obtained = rv_obtained._cov_cholesky output_dim = self.output_dim input_dim = self.input_dim @@ -284,7 +287,12 @@ def _backward_rv_sqrt( new_cov = new_cov_cholesky @ new_cov_cholesky.T info = {"rv_forwarded": rv_forwarded} - return randvars.Normal(new_mean, new_cov, cov_cholesky=new_cov_cholesky), info + return ( + randvars.Normal( + new_mean, new_cov, cache={"cov_cholesky": new_cov_cholesky} + ), + info, + ) def _backward_rv_joseph( self, diff --git a/src/probnum/randprocs/markov/discrete/_lti_gaussian.py b/src/probnum/randprocs/markov/discrete/_lti_gaussian.py index 1971b2769..07047de86 100644 --- a/src/probnum/randprocs/markov/discrete/_lti_gaussian.py +++ b/src/probnum/randprocs/markov/discrete/_lti_gaussian.py @@ -2,8 +2,9 @@ from probnum import randvars +from probnum.backend.typing import ArrayLike from probnum.randprocs.markov.discrete import _linear_gaussian -from probnum.typing import ArrayLike, LinearOperatorLike +from probnum.typing import LinearOperatorLike class LTIGaussian(_linear_gaussian.LinearGaussian): diff --git a/src/probnum/randprocs/markov/discrete/_nonlinear_gaussian.py b/src/probnum/randprocs/markov/discrete/_nonlinear_gaussian.py index c92cf0549..bb8796d46 100644 --- a/src/probnum/randprocs/markov/discrete/_nonlinear_gaussian.py +++ b/src/probnum/randprocs/markov/discrete/_nonlinear_gaussian.py @@ -5,9 +5,9 @@ import numpy as np from probnum import randvars +from probnum.backend.typing import ArrayLike, FloatLike, IntLike from probnum.randprocs.markov import _transition from probnum.randprocs.markov.discrete import _condition_state -from probnum.typing import ArrayLike, FloatLike, IntLike class NonlinearGaussian(_transition.Transition): diff --git a/src/probnum/randprocs/markov/integrator/_ioup.py b/src/probnum/randprocs/markov/integrator/_ioup.py index e024c35b0..20ca389ef 100644 --- a/src/probnum/randprocs/markov/integrator/_ioup.py +++ b/src/probnum/randprocs/markov/integrator/_ioup.py @@ -108,7 +108,7 @@ def __init__( zeros = np.zeros(ioup_transition.state_dimension) cov_cholesky = scale_cholesky * np.eye(ioup_transition.state_dimension) initrv = randvars.Normal( - mean=zeros, cov=cov_cholesky**2, cov_cholesky=cov_cholesky + mean=zeros, cov=cov_cholesky**2, cache={"cov_cholesky": cov_cholesky} ) super().__init__(transition=ioup_transition, initrv=initrv, initarg=initarg) diff --git a/src/probnum/randprocs/markov/integrator/_iwp.py b/src/probnum/randprocs/markov/integrator/_iwp.py index b62c6ac59..2c6e8754d 100644 --- a/src/probnum/randprocs/markov/integrator/_iwp.py +++ b/src/probnum/randprocs/markov/integrator/_iwp.py @@ -107,7 +107,7 @@ def __init__( zeros = np.zeros(iwp_transition.state_dimension) cov_cholesky = scale_cholesky * 
np.eye(iwp_transition.state_dimension) initrv = randvars.Normal( - mean=zeros, cov=cov_cholesky**2, cov_cholesky=cov_cholesky + mean=zeros, cov=cov_cholesky**2, cache={"cov_cholesky": cov_cholesky} ) super().__init__(transition=iwp_transition, initrv=initrv, initarg=initarg) @@ -208,7 +208,7 @@ def equivalent_discretisation_preconditioned(self): return discrete.LTIGaussian( transition_matrix=state_transition, noise=randvars.Normal( - mean=empty_shift, cov=noise, cov_cholesky=noise_cholesky + mean=empty_shift, cov=noise, cache={"cov_cholesky": noise_cholesky} ), forward_implementation=self.forward_implementation, backward_implementation=self.backward_implementation, @@ -302,7 +302,7 @@ def discretise(self, dt): # always exists, even for non-square root implementations. proc_noise_cov_cholesky = ( self.precon(dt) - @ self.equivalent_discretisation_preconditioned.noise.cov_cholesky + @ self.equivalent_discretisation_preconditioned.noise._cov_cholesky ) return discrete.LTIGaussian( @@ -310,7 +310,7 @@ def discretise(self, dt): noise=randvars.Normal( mean=zero_shift, cov=proc_noise_cov_mat, - cov_cholesky=proc_noise_cov_cholesky, + cache={"cov_cholesky": proc_noise_cov_cholesky}, ), forward_implementation=self.forward_implementation, backward_implementation=self.forward_implementation, diff --git a/src/probnum/randprocs/markov/integrator/_matern.py b/src/probnum/randprocs/markov/integrator/_matern.py index e8db12d82..489396678 100644 --- a/src/probnum/randprocs/markov/integrator/_matern.py +++ b/src/probnum/randprocs/markov/integrator/_matern.py @@ -108,7 +108,7 @@ def __init__( zeros = np.zeros(matern_transition.state_dimension) cov_cholesky = scale_cholesky * np.eye(matern_transition.state_dimension) initrv = randvars.Normal( - mean=zeros, cov=cov_cholesky**2, cov_cholesky=cov_cholesky + mean=zeros, cov=cov_cholesky**2, cache={"cov_cholesky": cov_cholesky} ) super().__init__(transition=matern_transition, initrv=initrv, initarg=initarg) diff --git a/src/probnum/randprocs/markov/integrator/_preconditioner.py b/src/probnum/randprocs/markov/integrator/_preconditioner.py index 3be0a8859..8b8aaf006 100644 --- a/src/probnum/randprocs/markov/integrator/_preconditioner.py +++ b/src/probnum/randprocs/markov/integrator/_preconditioner.py @@ -23,10 +23,10 @@ def apply_precon(precon, rv): # When they are resolved, this function here will hopefully be superfluous. 
new_mean = precon @ rv.mean - new_cov_cholesky = precon @ rv.cov_cholesky # precon is diagonal, so this is valid + new_cov_cholesky = precon @ rv._cov_cholesky # precon is diagonal, so this is valid new_cov = new_cov_cholesky @ new_cov_cholesky.T - return randvars.Normal(new_mean, new_cov, cov_cholesky=new_cov_cholesky) + return randvars.Normal(new_mean, new_cov, cache={"cov_cholesky": new_cov_cholesky}) class Preconditioner(abc.ABC): diff --git a/src/probnum/randprocs/markov/integrator/convert/_convert.py b/src/probnum/randprocs/markov/integrator/convert/_convert.py index 5e4240e1a..8fada1e0e 100644 --- a/src/probnum/randprocs/markov/integrator/convert/_convert.py +++ b/src/probnum/randprocs/markov/integrator/convert/_convert.py @@ -2,8 +2,8 @@ import numpy as np +from probnum.backend.typing import IntLike from probnum.randprocs.markov.integrator import _integrator -from probnum.typing import IntLike def convert_derivwise_to_coordwise( diff --git a/src/probnum/randprocs/markov/utils/_generate_measurements.py b/src/probnum/randprocs/markov/utils/_generate_measurements.py index 8b029d55e..a03d9dc26 100644 --- a/src/probnum/randprocs/markov/utils/_generate_measurements.py +++ b/src/probnum/randprocs/markov/utils/_generate_measurements.py @@ -2,6 +2,7 @@ import numpy as np +from probnum import backend from probnum.randprocs.markov import _markov, _transition @@ -34,9 +35,14 @@ def generate_artificial_measurements( """ obs = np.zeros((len(times), measmod.output_dim)) - latent_states = prior_process.sample(rng, args=times) + rng_state = backend.random.rng_state( + int(rng.bit_generator._seed_seq.generate_state(1, dtype=np.uint64)[0] // 2) + ) + latent_states_rng_state, rng_state = backend.random.split(rng_state, num=2) + latent_states = prior_process.sample(rng_state=latent_states_rng_state, args=times) for idx, (state, t) in enumerate(zip(latent_states, times)): measured_rv, _ = measmod.forward_realization(state, t=t) - obs[idx] = measured_rv.sample(rng=rng) + sample_rng_state, rng_state = backend.random.split(rng_state, num=2) + obs[idx] = measured_rv.sample(rng_state=sample_rng_state) return latent_states, obs diff --git a/src/probnum/randvars/__init__.py b/src/probnum/randvars/__init__.py index 64832a8d6..e00ad1436 100644 --- a/src/probnum/randvars/__init__.py +++ b/src/probnum/randvars/__init__.py @@ -15,11 +15,7 @@ RandomVariable, ) from ._randomvariablelist import _RandomVariableList -from ._scipy_stats import ( - WrappedSciPyContinuousRandomVariable, - WrappedSciPyDiscreteRandomVariable, - WrappedSciPyRandomVariable, -) +from ._sym_mat_normal import SymmetricMatrixNormal from ._utils import asrandvar # Public classes and functions. Order is reflected in documentation. 
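# (Aside, an illustrative sketch rather than part of the patch: throughout this
# patch, stateful np.random.Generator objects are replaced by an explicit,
# splittable RNG state, as used above in `_generate_measurements.py`:
#
#   from probnum import backend
#
#   rng_state = backend.random.rng_state(42)
#   rng_state, step_rng_state = backend.random.split(rng_state, num=2)
#   x = backend.random.standard_normal(step_rng_state, shape=(3,))
#
# Every draw consumes a freshly split state, which keeps sampling reproducible
# and compatible with JAX-style functional RNGs.)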
@@ -30,10 +26,8 @@ "ContinuousRandomVariable", "Constant", "Normal", + "SymmetricMatrixNormal", "Categorical", - "WrappedSciPyRandomVariable", - "WrappedSciPyDiscreteRandomVariable", - "WrappedSciPyContinuousRandomVariable", "_RandomVariableList", ] @@ -42,12 +36,9 @@ DiscreteRandomVariable.__module__ = "probnum.randvars" ContinuousRandomVariable.__module__ = "probnum.randvars" -WrappedSciPyRandomVariable.__module__ = "probnum.randvars" -WrappedSciPyDiscreteRandomVariable.__module__ = "probnum.randvars" -WrappedSciPyContinuousRandomVariable.__module__ = "probnum.randvars" - Constant.__module__ = "probnum.randvars" Normal.__module__ = "probnum.randvars" +SymmetricMatrixNormal.__module__ = "probnum.randvars" Categorical.__module__ = "probnum.randvars" _RandomVariableList.__module__ = "probnum.randvars" diff --git a/src/probnum/randvars/_arithmetic.py b/src/probnum/randvars/_arithmetic.py index 1ea239029..bb52f1c73 100644 --- a/src/probnum/randvars/_arithmetic.py +++ b/src/probnum/randvars/_arithmetic.py @@ -4,9 +4,8 @@ import operator from typing import Any, Callable, Dict, Tuple, Union -import numpy as np - -from probnum import utils as _utils +from probnum import backend +from probnum.backend.typing import NotImplementedType import probnum.linops as _linear_operators from ._constant import Constant as _Constant @@ -56,7 +55,7 @@ def pow_(rv1: Any, rv2: Any) -> _RandomVariable: ######################################################################################## _RandomVariableBinaryOperator = Callable[ - [_RandomVariable, _RandomVariable], Union[_RandomVariable, type(NotImplemented)] + [_RandomVariable, _RandomVariable], Union[_RandomVariable, NotImplementedType] ] _OperatorRegistryType = Dict[Tuple[type, type], _RandomVariableBinaryOperator] @@ -76,12 +75,12 @@ def _apply( op_registry: _OperatorRegistryType, rv1: Any, rv2: Any, -) -> Union[_RandomVariable, type(NotImplemented)]: +) -> Union[_RandomVariable, NotImplementedType]: # Convert arguments to random variables rv1 = _asrandvar(rv1) rv2 = _asrandvar(rv2) - # Search specific operatir + # Search specific operator key = (type(rv1), type(rv2)) if key in op_registry: @@ -125,13 +124,16 @@ def _rv_binary_op(rv1: _RandomVariable, rv2: _RandomVariable) -> _RandomVariable def _make_rv_binary_op_result_shape_dtype_sample_fn(op_fn, rv1, rv2): - sample_fn = lambda rng, size: op_fn( - rv1.sample(rng, size=size), - rv2.sample(rng, size=size), - ) + def sample_fn(rng_state, sample_shape): + rng_state1, rng_state2 = backend.random.split(rng_state, 2) + + return op_fn( + rv1.sample(rng_state=rng_state1, sample_shape=sample_shape), + rv2.sample(rng_state=rng_state2, sample_shape=sample_shape), + ) # Infer shape and dtype - infer_sample = sample_fn(np.random.default_rng(1), ()) + infer_sample = sample_fn(backend.random.rng_state(1), ()) shape = infer_sample.shape dtype = infer_sample.dtype @@ -181,29 +183,32 @@ def _generic_rv_add(rv1: _RandomVariable, rv2: _RandomVariable) -> _RandomVariab # Constant - Constant Arithmetic ######################################################################################## -_add_fns[(_Constant, _Constant)] = _Constant._binary_operator_factory(operator.add) -_sub_fns[(_Constant, _Constant)] = _Constant._binary_operator_factory(operator.sub) -_mul_fns[(_Constant, _Constant)] = _Constant._binary_operator_factory(operator.mul) -_matmul_fns[(_Constant, _Constant)] = _Constant._binary_operator_factory( +_constant_constant_operator_factory = ( + _Constant._binary_operator_factory # pylint: disable=protected-access +) 
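# (Aside, an illustrative sketch rather than part of the patch: the registries
# populated below map pairs of operand types to implementations. `_apply`
# converts both operands via `_asrandvar`, looks up `(type(rv1), type(rv2))`,
# and returns `NotImplemented` when no entry exists, e.g.:
#
#   _add_fns[(_Normal, _Constant)] = _add_normal_constant
#   result = _apply(_add_fns, norm_rv, const_rv)
# )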
+ +_add_fns[(_Constant, _Constant)] = _constant_constant_operator_factory(operator.add) +_sub_fns[(_Constant, _Constant)] = _constant_constant_operator_factory(operator.sub) +_mul_fns[(_Constant, _Constant)] = _constant_constant_operator_factory(operator.mul) +_matmul_fns[(_Constant, _Constant)] = _constant_constant_operator_factory( operator.matmul ) -_truediv_fns[(_Constant, _Constant)] = _Constant._binary_operator_factory( +_truediv_fns[(_Constant, _Constant)] = _constant_constant_operator_factory( operator.truediv ) -_floordiv_fns[(_Constant, _Constant)] = _Constant._binary_operator_factory( +_floordiv_fns[(_Constant, _Constant)] = _constant_constant_operator_factory( operator.floordiv ) -_mod_fns[(_Constant, _Constant)] = _Constant._binary_operator_factory(operator.mod) -_divmod_fns[(_Constant, _Constant)] = _Constant._binary_operator_factory(divmod) -_pow_fns[(_Constant, _Constant)] = _Constant._binary_operator_factory(operator.pow) +_mod_fns[(_Constant, _Constant)] = _constant_constant_operator_factory(operator.mod) +_divmod_fns[(_Constant, _Constant)] = _constant_constant_operator_factory(divmod) +_pow_fns[(_Constant, _Constant)] = _constant_constant_operator_factory(operator.pow) ######################################################################################## # Normal - Normal Arithmetic ######################################################################################## -_add_fns[(_Normal, _Normal)] = _Normal._add_normal -_sub_fns[(_Normal, _Normal)] = _Normal._sub_normal - +_add_fns[(_Normal, _Normal)] = _Normal._add_normal # pylint: disable=protected-access +_sub_fns[(_Normal, _Normal)] = _Normal._sub_normal # pylint: disable=protected-access ######################################################################################## # Normal - Constant Arithmetic @@ -211,11 +216,15 @@ def _generic_rv_add(rv1: _RandomVariable, rv2: _RandomVariable) -> _RandomVariab def _add_normal_constant(norm_rv: _Normal, constant_rv: _Constant) -> _Normal: - cov_cholesky = norm_rv.cov_cholesky if norm_rv.cov_cholesky_is_precomputed else None + if "cov_cholesky" in norm_rv._cache: + cache = {"cov_cholesky": norm_rv._cache["cov_cholesky"]} + else: + cache = None + return _Normal( mean=norm_rv.mean + constant_rv.support, cov=norm_rv.cov, - cov_cholesky=cov_cholesky, + cache=cache, ) @@ -224,11 +233,15 @@ def _add_normal_constant(norm_rv: _Normal, constant_rv: _Constant) -> _Normal: def _sub_normal_constant(norm_rv: _Normal, constant_rv: _Constant) -> _Normal: - cov_cholesky = norm_rv.cov_cholesky if norm_rv.cov_cholesky_is_precomputed else None + if "cov_cholesky" in norm_rv._cache: + cache = {"cov_cholesky": norm_rv._cache["cov_cholesky"]} + else: + cache = None + return _Normal( mean=norm_rv.mean - constant_rv.support, cov=norm_rv.cov, - cov_cholesky=cov_cholesky, + cache=cache, ) @@ -236,11 +249,15 @@ def _sub_normal_constant(norm_rv: _Normal, constant_rv: _Constant) -> _Normal: def _sub_constant_normal(constant_rv: _Constant, norm_rv: _Normal) -> _Normal: - cov_cholesky = norm_rv.cov_cholesky if norm_rv.cov_cholesky_is_precomputed else None + if "cov_cholesky" in norm_rv._cache: + cache = {"cov_cholesky": norm_rv._cache["cov_cholesky"]} + else: + cache = None + return _Normal( mean=constant_rv.support - norm_rv.mean, cov=norm_rv.cov, - cov_cholesky=cov_cholesky, + cache=cache, ) @@ -249,22 +266,24 @@ def _sub_constant_normal(constant_rv: _Constant, norm_rv: _Normal) -> _Normal: def _mul_normal_constant( norm_rv: _Normal, constant_rv: _Constant -) -> Union[_Normal, _Constant, type(NotImplemented)]: +)
-> Union[_Normal, _Constant, NotImplementedType]: if constant_rv.size == 1: if constant_rv.support == 0: return _Constant( - support=np.zeros_like(norm_rv.mean), + support=backend.zeros_like(norm_rv.mean), ) - if norm_rv.cov_cholesky_is_precomputed: - cov_cholesky = constant_rv.support * norm_rv.cov_cholesky + if "cov_cholesky" in norm_rv._cache: + cache = { + "cov_cholesky": constant_rv.support * norm_rv._cache["cov_cholesky"] + } else: - cov_cholesky = None + cache = None return _Normal( mean=constant_rv.support * norm_rv.mean, cov=(constant_rv.support**2) * norm_rv.cov, - cov_cholesky=cov_cholesky, + cache=cache, ) return NotImplemented @@ -281,9 +300,10 @@ def _matmul_normal_constant(norm_rv: _Normal, constant_rv: _Constant) -> _Normal is a matrix- or multi-variate normal random variable and :math:`A` a constant. """ if norm_rv.ndim == 1 or (norm_rv.ndim == 2 and norm_rv.shape[0] == 1): - if norm_rv.cov_cholesky_is_precomputed: - cov_cholesky = _utils.linalg.cholesky_update( - constant_rv.support.T @ norm_rv.cov_cholesky + + if "cov_cholesky" in norm_rv._cache: + cov_cholesky = backend.linalg.cholesky_update( + constant_rv.support.T @ norm_rv._cache["cov_cholesky"] ) else: cov_cholesky = None @@ -291,10 +311,21 @@ def _matmul_normal_constant(norm_rv: _Normal, constant_rv: _Constant) -> _Normal mean = norm_rv.mean @ constant_rv.support cov = constant_rv.support.T @ (norm_rv.cov @ constant_rv.support) - if cov.shape == () and mean.shape == (1,): + if mean.shape == (): + cov = cov.reshape(()) + + if cov_cholesky is not None: + cov_cholesky = cov_cholesky.reshape(()) + elif mean.shape == (1,): cov = cov.reshape((1, 1)) - return _Normal(mean=mean, cov=cov, cov_cholesky=cov_cholesky) + if cov_cholesky is not None: + cov_cholesky = cov_cholesky.reshape((1, 1)) + + if cov_cholesky is not None: + return _Normal(mean=mean, cov=cov, cache={"cov_cholesky": cov_cholesky}) + + return _Normal(mean=mean, cov=cov) # This part does not do the Cholesky update, # because of performance configurations: currently, there is no way of switching @@ -326,17 +357,32 @@ def _matmul_constant_normal(constant_rv: _Constant, norm_rv: _Normal) -> _Normal a matrix- or multi-variate normal random variable and :math:`A` a constant. 
""" if norm_rv.ndim == 1 or (norm_rv.ndim == 2 and norm_rv.shape[1] == 1): - if norm_rv.cov_cholesky_is_precomputed: - cov_cholesky = _utils.linalg.cholesky_update( - constant_rv.support @ norm_rv.cov_cholesky + + if "cov_cholesky" in norm_rv._cache: + cov_cholesky = backend.linalg.cholesky_update( + constant_rv.support @ norm_rv._cache["cov_cholesky"] ) else: cov_cholesky = None - return _Normal( - mean=constant_rv.support @ norm_rv.mean, - cov=constant_rv.support @ (norm_rv.cov @ constant_rv.support.T), - cov_cholesky=cov_cholesky, - ) + + mean = constant_rv.support @ norm_rv.mean + cov = constant_rv.support @ (norm_rv.cov @ constant_rv.support.T) + + if mean.shape == (): + cov = cov.reshape(()) + + if cov_cholesky is not None: + cov_cholesky = cov_cholesky.reshape(()) + elif mean.shape == (1,): + cov = cov.reshape((1, 1)) + + if cov_cholesky is not None: + cov_cholesky = cov_cholesky.reshape((1, 1)) + + if cov_cholesky is not None: + return _Normal(mean=mean, cov=cov, cache={"cov_cholesky": cov_cholesky}) + + return _Normal(mean=mean, cov=cov) # This part does not do the Cholesky update, # because of performance configurations: currently, there is no way of switching @@ -345,7 +391,7 @@ def _matmul_constant_normal(constant_rv: _Constant, norm_rv: _Normal) -> _Normal if constant_rv.support.ndim == 1: constant_rv_support = constant_rv.support[None, :] else: - constant_rv_support = constant_rv.support + constant_rv_support = constant_rv.supportndarray cov_update = _linear_operators.Kronecker( constant_rv_support, @@ -367,15 +413,17 @@ def _truediv_normal_constant(norm_rv: _Normal, constant_rv: _Constant) -> _Norma if constant_rv.support == 0: raise ZeroDivisionError - if norm_rv.cov_cholesky_is_precomputed: - cov_cholesky = norm_rv.cov_cholesky / constant_rv.support + if "cov_cholesky" in norm_rv._cache: + cache = { + "cov_cholesky": norm_rv._cache["cov_cholesky"] / constant_rv.support + } else: - cov_cholesky = None + cache = None return _Normal( mean=norm_rv.mean / constant_rv.support, cov=norm_rv.cov / (constant_rv.support**2), - cov_cholesky=cov_cholesky, + cache=cache, ) return NotImplemented diff --git a/src/probnum/randvars/_categorical.py b/src/probnum/randvars/_categorical.py index a66955635..136e06c92 100644 --- a/src/probnum/randvars/_categorical.py +++ b/src/probnum/randvars/_categorical.py @@ -3,6 +3,10 @@ import numpy as np +from probnum import BACKEND, Backend, backend +from probnum.backend.random import RNGState +from probnum.backend.typing import ArrayLike, SeedType, ShapeLike + from ._random_variable import DiscreteRandomVariable @@ -11,26 +15,29 @@ class Categorical(DiscreteRandomVariable): Parameters ---------- - probabilities : + probabilities Probabilities of the events. - support : + support Support of the categorical distribution. Optional. Default is None, in which case the support is chosen as :math:`(0, ..., K-1)` where - :math:`K` is the number of elements in `event_probabilities`. + :math:`K` is the number of elements in `probabilities`. """ def __init__( self, - probabilities: np.ndarray, - support: Optional[np.ndarray] = None, + probabilities: ArrayLike, + support: Optional[backend.Array] = None, ): - # The set of events is names "support" to be aligned with the method + + # The set of events is named "support" to be aligned with the method # DiscreteRandomVariable.in_support(). 
+ self._probabilities = backend.asarray(probabilities) num_categories = len(probabilities) - self._probabilities = np.asarray(probabilities) self._support = ( - np.asarray(support) if support is not None else np.arange(num_categories) + backend.asarray(support) + if support is not None + else backend.arange(num_categories) ) parameters = { @@ -39,28 +46,30 @@ def __init__( "num_categories": num_categories, } - def _sample_categorical(rng, size=()): + def _sample_categorical(rng_state: RNGState, sample_shape: ShapeLike = ()): """Sample from a categorical distribution. - While on first sight, one might think that this - implementation can be replaced by - `np.random.choice(self.support, size, self.probabilities)`, - this is not true, because `np.random.choice` cannot handle - arrays with `ndim > 1`, but `self.support` can be just that. - This detour via the `mask` avoids this problem. + While on first sight, one might think that this implementation can be + replaced by `np.random.choice(self.support, sample_shape, + self.probabilities)`, this is not true, because `np.random.choice` cannot + handle arrays with `ndim > 1`, but `self.support` can be just that. This + detour via the `mask` avoids this problem. """ - - indices = rng.choice( - np.arange(len(self.support)), size=size, p=self.probabilities - ).reshape(size) + sample_shape = backend.asshape(sample_shape) + indices = backend.random.choice( + rng_state, + np.arange(len(self.support)), + shape=sample_shape, + p=self.probabilities, + ).reshape(sample_shape) return self.support[indices] - def _pmf_categorical(x): + def _pmf_categorical(x: ArrayLike): """PMF of a categorical distribution.""" # This implementation is defense against cryptic warnings such as: # https://stackoverflow.com/questions/45020217/numpy-where-function-throws-a-futurewarning-returns-scalar-instead-of-list - x = np.asarray(x) + x = backend.asarray(x) if x.dtype != self.dtype: raise ValueError( "The data type of x does not match with the data type of the " @@ -71,7 +80,7 @@ def _pmf_categorical(x): return self.probabilities[mask][0] if len(mask) > 0 else 0.0 def _mode_categorical(): - mask = np.argmax(self.probabilities) + mask = backend.argmax(self.probabilities) return self.support[mask] super().__init__( @@ -84,16 +93,16 @@ def _mode_categorical(): ) @property - def probabilities(self) -> np.ndarray: + def probabilities(self) -> backend.Array: """Event probabilities of the categorical distribution.""" return self._probabilities @property - def support(self) -> np.ndarray: + def support(self) -> backend.Array: """Support of the categorical distribution.""" return self._support - def resample(self, rng: np.random.Generator) -> "Categorical": + def resample(self, rng_state: RNGState) -> "Categorical": """Resample the support of the categorical random variable. Return a new categorical random variable (RV), where the support @@ -103,18 +112,18 @@ def resample(self, rng: np.random.Generator) -> "Categorical": Parameters ---------- - rng : - Random number generator. + rng_state + Random number generator state. Returns ------- Categorical Categorical random variable with resampled support - (according to self.probabilities). + (according to ``self.probabilities``). 
""" num_events = len(self.support) - new_support = self.sample(rng=rng, size=num_events) - new_probabilities = np.ones(self.probabilities.shape) / num_events + new_support = self.sample(rng_state, sample_shape=num_events) + new_probabilities = backend.ones(self.probabilities.shape) / num_events return Categorical( support=new_support, probabilities=new_probabilities, diff --git a/src/probnum/randvars/_constant.py b/src/probnum/randvars/_constant.py index c831417d2..f14be313b 100644 --- a/src/probnum/randvars/_constant.py +++ b/src/probnum/randvars/_constant.py @@ -3,19 +3,16 @@ from __future__ import annotations from functools import cached_property -from typing import Callable, TypeVar +from typing import Callable -import numpy as np - -from probnum import config, linops, utils as _utils -from probnum.typing import ArrayIndicesLike, ShapeLike, ShapeType +from probnum import backend, config, linops +from probnum.backend.random import RNGState +from probnum.backend.typing import ArrayIndicesLike, ShapeLike, ShapeType from . import _random_variable -ValueType = TypeVar("ValueType") - -class Constant(_random_variable.DiscreteRandomVariable[ValueType]): +class Constant(_random_variable.DiscreteRandomVariable): """Random variable representing a constant value. Discrete random variable which (with probability one) takes a constant value. The @@ -45,38 +42,35 @@ class Constant(_random_variable.DiscreteRandomVariable[ValueType]): Examples -------- - >>> from probnum import randvars + >>> from probnum import backend, randvars >>> import numpy as np >>> rv1 = randvars.Constant(support=0.) >>> rv2 = randvars.Constant(support=1.) >>> rv = rv1 + rv2 - >>> rng = np.random.default_rng(seed=42) - >>> rv.sample(rng, size=5) + >>> rng_state = backend.random.rng_state(42) + >>> rv.sample(rng_state, 5) array([1., 1., 1., 1., 1.]) """ def __init__( self, - support: ValueType, + support: backend.Array, ): - if np.isscalar(support): - support = _utils.as_numpy_scalar(support) - - self._support = support + self._support = backend.asarray(support) support_floating = self._support.astype( - np.promote_types(self._support.dtype, np.float_) + backend.promote_types(self._support.dtype, backend.float64) ) if config.matrix_free: cov = lambda: ( linops.Zero(shape=((self._support.size, self._support.size))) if self._support.ndim > 0 - else _utils.as_numpy_scalar(0.0, support_floating.dtype) + else backend.asscalar(0.0, support_floating.dtype) ) else: - cov = lambda: np.broadcast_to( - _utils.as_numpy_scalar(0.0, support_floating.dtype), + cov = lambda: backend.broadcast_to( + backend.asscalar(0.0, support_floating.dtype), shape=( (self._support.size, self._support.size) if self._support.ndim > 0 @@ -84,8 +78,8 @@ def __init__( ), ) - var = lambda: np.broadcast_to( - _utils.as_numpy_scalar(0.0, support_floating.dtype), + var = lambda: backend.broadcast_to( + backend.asscalar(0.0, support_floating.dtype), shape=self._support.shape, ) @@ -94,9 +88,13 @@ def __init__( dtype=self._support.dtype, parameters={"support": self._support}, sample=self._sample, - in_support=lambda x: np.all(x == self._support), - pmf=lambda x: np.float_(1.0 if np.all(x == self._support) else 0.0), - cdf=lambda x: np.float_(1.0 if np.all(x >= self._support) else 0.0), + in_support=lambda x: backend.all(x == self._support), + pmf=lambda x: backend.float64( + 1.0 if backend.all(x == self._support) else 0.0 + ), + cdf=lambda x: backend.float64( + 1.0 if backend.all(x >= self._support) else 0.0 + ), mode=lambda: self._support, median=lambda: 
support_floating, mean=lambda: support_floating, @@ -106,13 +104,13 @@ def __init__( ) @cached_property - def cov_cholesky(self): + def _cov_cholesky(self): # Pure utility attribute (it is zero anyway). # Make Constant behave more like Normal with zero covariance. return self.cov @property - def support(self) -> ValueType: + def support(self) -> backend.Array: """Constant value taken by the random variable.""" return self._support @@ -141,13 +139,15 @@ def transpose(self, *axes: int) -> "Constant": support=self._support.transpose(*axes), ) - def _sample(self, rng: np.random.Generator, size: ShapeLike = ()) -> ValueType: - size = _utils.as_shape(size) + def _sample( + self, rng_state: RNGState, sample_shape: ShapeLike = () + ) -> backend.Array: + # pylint: disable=unused-argument - if size == (): + if sample_shape == (): return self._support.copy() - return np.tile(self._support, reps=size + (1,) * self.ndim) + return backend.tile(self._support, reps=sample_shape + (1,) * self.ndim) # Unary arithmetic operations @@ -170,7 +170,7 @@ def __abs__(self) -> "Constant": @staticmethod def _binary_operator_factory( - operator: Callable[[ValueType, ValueType], ValueType] + operator: Callable[[backend.Array, backend.Array], backend.Array] ) -> Callable[["Constant", "Constant"], "Constant"]: def _constant_rv_binary_operator( constant_rv1: Constant, constant_rv2: Constant diff --git a/src/probnum/randvars/_normal.py b/src/probnum/randvars/_normal.py index cdc428a57..30569bb09 100644 --- a/src/probnum/randvars/_normal.py +++ b/src/probnum/randvars/_normal.py @@ -1,25 +1,25 @@ """Normally distributed / Gaussian random variables.""" from __future__ import annotations -from functools import cached_property -from typing import Callable, Optional, Union - -import numpy as np -import scipy.linalg -import scipy.stats - -from probnum import config, linops, utils as _utils -from probnum.typing import ArrayIndicesLike, FloatLike, ShapeLike, ShapeType +import functools +import operator +from typing import Any, Dict, Optional, Union + +from probnum import backend, config, linops +from probnum.backend.random import RNGState +from probnum.backend.typing import ( + ArrayIndicesLike, + ArrayLike, + FloatLike, + ShapeLike, + ShapeType, +) +from probnum.typing import MatrixType from . import _random_variable -ValueType = Union[np.floating, np.ndarray, linops.LinearOperator] - - -# pylint: disable="too-complex" - -class Normal(_random_variable.ContinuousRandomVariable[ValueType]): +class Normal(_random_variable.ContinuousRandomVariable): """Random variable with a normal distribution. Gaussian random variables are ubiquitous in probability theory, since the @@ -37,15 +37,6 @@ class Normal(_random_variable.ContinuousRandomVariable[ValueType]): Mean of the random variable. cov : (Co-)variance of the random variable. - cov_cholesky : - (Lower triangular) Cholesky factor of the covariance matrix. If None, then the - Cholesky factor of the covariance matrix is computed when - :attr:`Normal.cov_cholesky` is called and then cached. If specified, the value - is returned by :attr:`Normal.cov_cholesky`. In this case, its type and data type - are compared to the type and data type of the covariance. If the types do not - match, an exception is thrown. If the data types do not match, - the data type of the Cholesky factor is promoted to the data type of the - covariance matrix. 
See Also -------- @@ -53,52 +44,48 @@ class Normal(_random_variable.ContinuousRandomVariable[ValueType]): Examples -------- - >>> import numpy as np - >>> from probnum import randvars + >>> from probnum import backend, randvars >>> x = randvars.Normal(mean=0.5, cov=1.0) - >>> rng = np.random.default_rng(42) - >>> x.sample(rng=rng, size=(2, 2)) + >>> rng_state = backend.random.rng_state(42) + >>> x.sample(rng_state=rng_state, sample_shape=(2, 2)) array([[ 0.80471708, -0.53998411], [ 1.2504512 , 1.44056472]]) """ - # pylint: disable=too-many-locals,too-many-branches,too-many-statements + # TODO (#678): `cov_cholesky` should be passed to the `cov` `LinearOperator` def __init__( self, - mean: Union[float, np.floating, np.ndarray, linops.LinearOperator], - cov: Union[float, np.floating, np.ndarray, linops.LinearOperator], - cov_cholesky: Optional[ - Union[float, np.floating, np.ndarray, linops.LinearOperator] - ] = None, + mean: Union[ArrayLike, linops.LinearOperator], + cov: Union[ArrayLike, linops.LinearOperator], + cache: Optional[Dict[str, Any]] = None, ): - # Type normalization - if np.isscalar(mean): - mean = _utils.as_numpy_scalar(mean) + # pylint: disable=too-many-branches - if np.isscalar(cov): - cov = _utils.as_numpy_scalar(cov) + # Type normalization + if not isinstance(mean, linops.LinearOperator): + mean = backend.asarray(mean) - if np.isscalar(cov_cholesky): - cov_cholesky = _utils.as_numpy_scalar(cov_cholesky) + if not isinstance(cov, linops.LinearOperator): + cov = backend.asarray(cov) # Data type normalization - dtype = np.promote_types(mean.dtype, cov.dtype) + dtype = backend.promote_types(mean.dtype, cov.dtype) - if not np.issubdtype(dtype, np.floating): - dtype = np.dtype(np.double) + if not backend.is_floating_dtype(dtype): + dtype = config.default_floating_dtype - mean = mean.astype(dtype, order="C", casting="safe", copy=False) - cov = cov.astype(dtype, order="C", casting="safe", copy=False) + # Circular dependency -> defer import + from probnum import compat # pylint: disable=import-outside-toplevel - # Shape checking - if not 0 <= mean.ndim <= 2: - raise ValueError( - f"Gaussian random variables must either be scalars, vectors, or " - f"matrices (or linear operators), but the given mean is a {mean.ndim}-" - f"dimensional tensor." - ) + mean = compat.cast(mean, dtype=dtype, casting="safe", copy=False) + cov = compat.cast(cov, dtype=dtype, casting="safe", copy=False) - expected_cov_shape = (np.prod(mean.shape),) * 2 if len(mean.shape) > 0 else () + # Shape checking + expected_cov_shape = ( + (functools.reduce(operator.mul, mean.shape, 1),) * 2 + if mean.ndim > 0 + else () + ) if cov.shape != expected_cov_shape: raise ValueError( @@ -106,184 +93,79 @@ def __init__( f"shape {cov.shape} was given." 
) - # Method selection - univariate = mean.ndim == 0 - dense = isinstance(mean, np.ndarray) and isinstance(cov, np.ndarray) - cov_operator = isinstance(cov, linops.LinearOperator) - compute_cov_cholesky: Callable[[], ValueType] = None - - if univariate: - # Univariate Gaussian - sample = self._univariate_sample - in_support = Normal._univariate_in_support - pdf = self._univariate_pdf - logpdf = self._univariate_logpdf - cdf = self._univariate_cdf - logcdf = self._univariate_logcdf - quantile = self._univariate_quantile - - median = lambda: mean - var = lambda: cov - entropy = self._univariate_entropy - - compute_cov_cholesky = self._univariate_cov_cholesky - - elif dense or cov_operator: - # Multi- and matrixvariate Gaussians - sample = self._dense_sample - in_support = Normal._dense_in_support - pdf = self._dense_pdf - logpdf = self._dense_logpdf - cdf = self._dense_cdf - logcdf = self._dense_logcdf - quantile = None - - median = None - var = self._dense_var - entropy = self._dense_entropy - - compute_cov_cholesky = self.dense_cov_cholesky - - # Ensure that the Cholesky factor has the same type as the covariance, - # and, if necessary, promote data types. Check for (in this order): type, - # shape, dtype. - if cov_cholesky is not None: - - if not isinstance(cov_cholesky, type(cov)): - raise TypeError( - f"The covariance matrix is of type `{type(cov)}`, so its " - f"Cholesky decomposition must be of the same type, but an " - f"object of type `{type(cov_cholesky)}` was given." - ) - - if cov_cholesky.shape != cov.shape: - raise ValueError( - f"The cholesky decomposition of the covariance matrix must " - f"have the same shape as the covariance matrix, i.e. " - f"{cov.shape}, but shape {cov_cholesky.shape} was given" - ) - - if cov_cholesky.dtype != cov.dtype: - cov_cholesky = cov_cholesky.astype( - cov.dtype, casting="safe", copy=False - ) - - if cov_operator: - if isinstance(cov, linops.SymmetricKronecker): - m, n = mean.shape - - if m != n or n != cov.A.shape[0] or n != cov.B.shape[1]: - raise ValueError( - "Normal distributions with symmetric Kronecker structured " - "kernels must have square mean and square kernels factors " - "with matching dimensions." - ) - - if cov.identical_factors: - sample = self._symmetric_kronecker_identical_factors_sample - - compute_cov_cholesky = ( - self._symmetric_kronecker_identical_factors_cov_cholesky - ) - elif isinstance(cov, linops.Kronecker): - compute_cov_cholesky = self._kronecker_cov_cholesky - if mean.ndim == 2: - m, n = mean.shape - - if ( - m != cov.A.shape[0] - or m != cov.A.shape[1] - or n != cov.B.shape[0] - or n != cov.B.shape[1] - ): - raise ValueError( - "Kronecker structured kernels must have factors with " - "the same shape as the mean." - ) - - else: - # This case handles all linear operators, for which no Cholesky - # factorization is implemented, yet. - # Computes the dense Cholesky and converts it to a LinearOperator. 
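# (Aside, an illustrative sketch rather than part of the patch: the
# `cov_cholesky` keyword and the precompute machinery removed here are
# superseded by the `cache` argument, so a precomputed factor is now passed as:
#
#   L = backend.linalg.cholesky(cov, upper=False)
#   rv = Normal(mean, cov, cache={"cov_cholesky": L})
# )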
- compute_cov_cholesky = self._dense_cov_cholesky_as_linop - + self._cache = cache if cache is not None else {} + + if mean.ndim == 0: + # Scalar Gaussian + super().__init__( + shape=(), + dtype=mean.dtype, + parameters={"mean": mean, "cov": cov}, + sample=self._scalar_sample, + in_support=Normal._scalar_in_support, + pdf=self._scalar_pdf, + logpdf=self._scalar_logpdf, + cdf=self._scalar_cdf, + logcdf=self._scalar_logcdf, + quantile=self._scalar_quantile, + mode=lambda: mean, + median=lambda: mean, + mean=lambda: mean, + cov=lambda: cov, + var=lambda: cov, + entropy=self._scalar_entropy, + ) else: - raise ValueError( - f"Cannot instantiate normal distribution with mean of type " - f"{mean.__class__.__name__} and kernels of type " - f"{cov.__class__.__name__}." + # Multi- and matrix- and tensorvariate Gaussians + super().__init__( + shape=mean.shape, + dtype=mean.dtype, + parameters={"mean": mean, "cov": cov}, + sample=self._sample, + in_support=self._in_support, + pdf=self._pdf, + logpdf=self._logpdf, + cdf=self._cdf, + logcdf=self._logcdf, + quantile=None, + mode=lambda: mean, + median=None, + mean=lambda: mean, + cov=lambda: cov, + var=self._var, + entropy=self._entropy, ) - super().__init__( - shape=mean.shape, - dtype=mean.dtype, - parameters={"mean": mean, "cov": cov}, - sample=sample, - in_support=in_support, - pdf=pdf, - logpdf=logpdf, - cdf=cdf, - logcdf=logcdf, - quantile=quantile, - mode=lambda: mean, - median=median, - mean=lambda: mean, - cov=lambda: cov, - var=var, - entropy=entropy, - ) - - self._compute_cov_cholesky = compute_cov_cholesky - self._cov_cholesky = cov_cholesky - - @property - def cov_cholesky(self) -> ValueType: - """Cholesky factor :math:`L` of the covariance - :math:`\\operatorname{Cov}(X) =LL^\\top`.""" - - if not self.cov_cholesky_is_precomputed: - self.precompute_cov_cholesky() - return self._cov_cholesky - - def precompute_cov_cholesky( - self, - damping_factor: Optional[FloatLike] = None, - ): - """(P)recompute Cholesky factors (careful: in-place operation!).""" - if damping_factor is None: - damping_factor = config.covariance_inversion_damping - if self.cov_cholesky_is_precomputed: - raise Exception("A Cholesky factor is already available.") - self._cov_cholesky = self._compute_cov_cholesky(damping_factor=damping_factor) - @property - def cov_cholesky_is_precomputed(self) -> bool: - """Return truth-value of whether the Cholesky factor of the covariance is - readily available. - - This happens if (i) the Cholesky factor is specified during initialization or if - (ii) the property `self.cov_cholesky` has been called before. 
- """ - if self._cov_cholesky is None: - return False - return True - - @cached_property - def dense_mean(self) -> Union[np.floating, np.ndarray]: + def dense_mean(self) -> backend.Array: """Dense representation of the mean.""" if isinstance(self.mean, linops.LinearOperator): return self.mean.todense() return self.mean - @cached_property - def dense_cov(self) -> Union[np.floating, np.ndarray]: + @property + def dense_cov(self) -> backend.Array: """Dense representation of the covariance.""" if isinstance(self.cov, linops.LinearOperator): return self.cov.todense() return self.cov + @functools.cached_property + def cov_matrix(self) -> backend.Array: + if isinstance(self.cov, linops.LinearOperator): + return self.cov.todense() + + return self.cov + + @functools.cached_property + def cov_op(self) -> linops.LinearOperator: + if isinstance(self.cov, linops.LinearOperator): + return self.cov + + return linops.aslinop(self.cov) + def __getitem__(self, key: ArrayIndicesLike) -> "Normal": """Marginalization in multi- and matrixvariate normal random variables, expressed as (advanced) indexing, masking and slicing. @@ -292,14 +174,15 @@ def __getitem__(self, key: ArrayIndicesLike) -> "Normal": https://numpy.org/doc/1.19/reference/arrays.indexing.html. - Note that, currently, this method only works for multi- and matrixvariate - normal distributions. - Parameters ---------- - key : int or slice or ndarray or tuple of None, int, slice, or ndarray + key : Indices, slice objects and/or boolean masks specifying which entries to keep while marginalizing over all other entries. + + Returns + ------- + Random variable after marginalization. """ if not isinstance(key, tuple): @@ -310,7 +193,8 @@ def __getitem__(self, key: ArrayIndicesLike) -> "Normal": # Select submatrix from covariance matrix cov = self.dense_cov.reshape(self.shape + self.shape) - cov = cov[key][(...,) + key] + cov = cov[key] + cov = cov[tuple(slice(cov.shape[i]) for i in range(cov.ndim - self.ndim)) + key] if mean.ndim > 0: cov = cov.reshape(mean.size, mean.size) @@ -374,9 +258,6 @@ def __pos__(self) -> "Normal": cov=self.cov, ) - # TODO: Overwrite __abs__ and add absolute moments of normal - # TODO: (https://arxiv.org/pdf/1209.4340.pdf) - # Binary arithmetic operations def _add_normal(self, other: "Normal") -> "Normal": @@ -404,200 +285,334 @@ def _sub_normal(self, other: "Normal") -> "Normal": ) # Univariate Gaussians - def _univariate_cov_cholesky( - self, - damping_factor: FloatLike, - ) -> np.floating: - return np.sqrt(self.cov + damping_factor) - - def _univariate_sample( + @functools.partial(backend.jit_method, static_argnums=(1,)) + def _scalar_sample( self, - rng: np.random.Generator, - size: ShapeType = (), - ) -> Union[np.floating, np.ndarray]: - sample = scipy.stats.norm.rvs( - loc=self.mean, scale=self.std, size=size, random_state=rng + rng_state: RNGState, + sample_shape: ShapeType = (), + ) -> backend.Array: + sample = backend.random.standard_normal( + rng_state, + shape=sample_shape, + dtype=self.dtype, ) - if np.isscalar(sample): - sample = _utils.as_numpy_scalar(sample, dtype=self.dtype) - else: - sample = sample.astype(self.dtype) - - assert sample.shape == size - - return sample + return self.std * sample + self.mean @staticmethod - def _univariate_in_support(x: ValueType) -> bool: - return np.isfinite(x) - - def _univariate_pdf(self, x: ValueType) -> np.float_: - return scipy.stats.norm.pdf(x, loc=self.mean, scale=self.std) + @backend.jit + def _scalar_in_support(x: backend.Array) -> backend.Array: + return 
backend.isfinite(x) + + @backend.jit_method + def _scalar_pdf(self, x: backend.Array) -> backend.Array: + return backend.exp(-((x - self.mean) ** 2) / (2.0 * self.var)) / backend.sqrt( + 2 * backend.pi * self.var + ) - def _univariate_logpdf(self, x: ValueType) -> np.float_: - return scipy.stats.norm.logpdf(x, loc=self.mean, scale=self.std) + @backend.jit_method + def _scalar_logpdf(self, x: backend.Array) -> backend.Array: + return -((x - self.mean) ** 2) / (2.0 * self.var) - 0.5 * backend.log( + 2.0 * backend.pi * self.var + ) - def _univariate_cdf(self, x: ValueType) -> np.float_: - return scipy.stats.norm.cdf(x, loc=self.mean, scale=self.std) + @backend.jit_method + def _scalar_cdf(self, x: backend.Array) -> backend.Array: + return backend.special.ndtr((x - self.mean) / self.std) - def _univariate_logcdf(self, x: ValueType) -> np.float_: - return scipy.stats.norm.logcdf(x, loc=self.mean, scale=self.std) + @backend.jit_method + def _scalar_logcdf(self, x: backend.Array) -> backend.Array: + return backend.log(self._scalar_cdf(x)) - def _univariate_quantile(self, p: FloatLike) -> np.floating: - return scipy.stats.norm.ppf(p, loc=self.mean, scale=self.std) + @backend.jit_method + def _scalar_quantile(self, p: FloatLike) -> backend.Array: + return self.mean + self.std * backend.special.ndtri(p) - def _univariate_entropy(self: ValueType) -> np.float_: - return _utils.as_numpy_scalar( - scipy.stats.norm.entropy(loc=self.mean, scale=self.std), - dtype=np.float_, - ) + @backend.jit_method + def _scalar_entropy(self) -> backend.Scalar: + return 0.5 * backend.log(2.0 * backend.pi * self.var) + 0.5 # Multi- and matrixvariate Gaussians - def dense_cov_cholesky( - self, - damping_factor: Optional[FloatLike] = None, - ) -> np.ndarray: - """Compute the Cholesky factorization of the covariance from its dense - representation.""" - if damping_factor is None: - damping_factor = config.covariance_inversion_damping - dense_cov = self.dense_cov - - return scipy.linalg.cholesky( - dense_cov + damping_factor * np.eye(self.size, dtype=self.dtype), - lower=True, - ) - def _dense_cov_cholesky_as_linop( - self, damping_factor: FloatLike - ) -> linops.LinearOperator: - return linops.aslinop(self.dense_cov_cholesky(damping_factor=damping_factor)) - - def _dense_sample( - self, rng: np.random.Generator, size: ShapeType = () - ) -> np.ndarray: - sample = scipy.stats.multivariate_normal.rvs( - mean=self.dense_mean.ravel(), - cov=self.dense_cov, - size=size, - random_state=rng, + @functools.partial(backend.jit_method, static_argnums=(1,)) + def _sample( + self, rng_state: RNGState, sample_shape: ShapeType = () + ) -> backend.Array: + samples = backend.random.standard_normal( + rng_state, + shape=sample_shape + (self.size,), + dtype=self.dtype, ) - return sample.reshape(sample.shape[:-1] + self.shape) + samples = self._cov_sqrtm @ samples[..., None] + samples = samples.reshape(sample_shape + self.shape) + samples += self.dense_mean + + return samples @staticmethod - def _arg_todense(x: Union[np.ndarray, linops.LinearOperator]) -> np.ndarray: + def _arg_todense(x: Union[backend.Array, linops.LinearOperator]) -> backend.Array: if isinstance(x, linops.LinearOperator): return x.todense() - if isinstance(x, np.ndarray): + if backend.isarray(x): return x raise ValueError(f"Unsupported argument type {type(x)}") - @staticmethod - def _dense_in_support(x: ValueType) -> bool: - return np.all(np.isfinite(Normal._arg_todense(x))) - - def _dense_pdf(self, x: ValueType) -> np.float_: - return scipy.stats.multivariate_normal.pdf( - 
Normal._arg_todense(x).reshape(x.shape[: -self.ndim] + (-1,)), - mean=self.dense_mean.ravel(), - cov=self.dense_cov, + @backend.jit_method + def _in_support(self, x: backend.Array) -> backend.Array: + return backend.all( + backend.isfinite(Normal._arg_todense(x)), + axis=tuple(range(-self.ndim, 0)), + keepdims=False, ) - def _dense_logpdf(self, x: ValueType) -> np.float_: - return scipy.stats.multivariate_normal.logpdf( - Normal._arg_todense(x).reshape(x.shape[: -self.ndim] + (-1,)), - mean=self.dense_mean.ravel(), - cov=self.dense_cov, - ) + @backend.jit_method + def _pdf(self, x: backend.Array) -> backend.Array: + return backend.exp(self._logpdf(x)) - def _dense_cdf(self, x: ValueType) -> np.float_: - return scipy.stats.multivariate_normal.cdf( - Normal._arg_todense(x).reshape(x.shape[: -self.ndim] + (-1,)), - mean=self.dense_mean.ravel(), - cov=self.dense_cov, + @backend.jit_method + def _logpdf(self, x: backend.Array) -> backend.Array: + x_centered = Normal._arg_todense(x - self.dense_mean).reshape( + x.shape[: -self.ndim] + (-1,) ) - def _dense_logcdf(self, x: ValueType) -> np.float_: - return scipy.stats.multivariate_normal.logcdf( - Normal._arg_todense(x).reshape(x.shape[: -self.ndim] + (-1,)), - mean=self.dense_mean.ravel(), - cov=self.dense_cov, + return -0.5 * ( + # TODO (#569,#678): backend.sum( + # x_centered * self._cov_op.inv()(x_centered, axis=-1), + # axis=-1 + # ) + # Here, we use: + # ||L^{-1}(x - \mu)||_2^2 = (x - \mu)^T \Sigma^{-1} (x - \mu) + backend.sum(self._cov_sqrtm_solve(x_centered) ** 2, axis=-1) + + self.size * backend.log(backend.asarray(2.0 * backend.pi)) + + self._cov_logdet ) - def _dense_var(self) -> np.ndarray: - return np.diag(self.dense_cov).reshape(self.shape) + _cdf = backend.Dispatcher() - def _dense_entropy(self) -> np.float_: - return _utils.as_numpy_scalar( - scipy.stats.multivariate_normal.entropy( - mean=self.dense_mean.ravel(), - cov=self.dense_cov, - ), - dtype=np.float_, + @_cdf.numpy_impl + def _cdf_numpy(self, x: backend.Array) -> backend.Array: + import scipy.stats # pylint: disable=import-outside-toplevel + + x_batch_shape = x.shape[: x.ndim - self.ndim] + + scipy_cdf = scipy.stats.multivariate_normal.cdf( + Normal._arg_todense(x).reshape(x_batch_shape + (-1,)), + mean=self.dense_mean.reshape(-1), + cov=self.cov_matrix, ) - # Matrixvariate Gaussian with Kronecker covariance - def _kronecker_cov_cholesky( - self, - damping_factor: FloatLike, - ) -> linops.Kronecker: - assert isinstance(self.cov, linops.Kronecker) + # scipy's implementation happily squeezes `1` dimensions out of the batch + expected_shape = x.shape[: x.ndim - self.ndim] - A = self.cov.A.todense() - B = self.cov.B.todense() + if any(dim == 1 for dim in expected_shape): + assert all(dim != 1 for dim in scipy_cdf.shape) - return linops.Kronecker( - A=scipy.linalg.cholesky( - A + damping_factor * np.eye(A.shape[0], dtype=self.dtype), - lower=True, - ), - B=scipy.linalg.cholesky( - B + damping_factor * np.eye(B.shape[0], dtype=self.dtype), - lower=True, + scipy_cdf = scipy_cdf.reshape(expected_shape) + + return scipy_cdf + + def _logcdf(self, x: backend.Array) -> backend.Array: + return backend.log(self.cdf(x)) + + @backend.jit_method + def _var(self) -> backend.Array: + return backend.diag(self.dense_cov).reshape(self.shape) + + @backend.jit_method + def _entropy(self) -> backend.Scalar: + entropy = 0.5 * self.size * (backend.log(2.0 * backend.pi) + 1.0) + entropy += 0.5 * self._cov_logdet + + return entropy + + def compute_cov_sqrtm(self) -> Normal: + if "cov_cholesky" in 
self._cache and "cov_eigh" in self._cache: + return self + + cache = self._cache + + if "cov_cholesky" not in self._cache: + cache["cov_cholesky"] = self._cov_cholesky + + return Normal( + self.mean, + self.cov, + cache=backend.cond( + backend.any(backend.isnan(cache["cov_cholesky"])), + lambda: cache + {"cov_eigh": self._cov_eigh}, + lambda: cache, ), ) - # Matrixvariate Gaussian with symmetric Kronecker covariance from identical - # factors - def _symmetric_kronecker_identical_factors_cov_cholesky( - self, - damping_factor: FloatLike, - ) -> linops.SymmetricKronecker: - assert ( + # TODO (#678): Use `LinearOperator.cholesky` once the backend is supported + + @property + @backend.jit_method + def _cov_cholesky(self) -> MatrixType: + if "cov_cholesky" in self._cache: + return self._cache["cov_cholesky"] + + if self.ndim == 0: + return backend.sqrt(self.cov) + + if backend.isarray(self.cov): + return backend.linalg.cholesky(self.cov, upper=False) + + if isinstance(self.cov, linops.Kronecker): + return linops.Kronecker( + backend.linalg.cholesky(self.cov.A.todense(), upper=False), + backend.linalg.cholesky(self.cov.B.todense(), upper=False), + ) + + if ( isinstance(self.cov, linops.SymmetricKronecker) and self.cov.identical_factors - ) + ): + return linops.SymmetricKronecker( + backend.linalg.cholesky(self.cov.A.todense(), upper=False) + ) - A = self.cov.A.todense() + assert isinstance(self.cov, linops.LinearOperator) - return linops.SymmetricKronecker( - A=scipy.linalg.cholesky( - A + damping_factor * np.eye(A.shape[0], dtype=self.dtype), - lower=True, - ), - ) + return linops.aslinop(backend.linalg.cholesky(self.cov.todense(), upper=False)) + + @property + def _cov_matrix_cholesky(self) -> backend.Array: + if isinstance(self._cov_cholesky, linops.LinearOperator): + return self._cov_cholesky.todense() - def _symmetric_kronecker_identical_factors_sample( - self, rng: np.random.Generator, size: ShapeType = () - ) -> np.ndarray: - assert ( + return self._cov_cholesky + + # TODO (#569,#678): Use `LinearOperator.eig` it is implemented and once the backend + # is supported + + @property + @backend.jit_method + def _cov_eigh(self) -> MatrixType: + if "cov_eigh" in self._cache: + return self._cache["cov_eigh"] + + if self.ndim == 0: + eigvals = self.cov + Q = backend.ones_like(self.cov) + elif backend.isarray(self.cov): + eigvals, Q = backend.linalg.eigh(self.cov) + elif isinstance(self.cov, linops.Kronecker): + A_eigvals, A_eigvecs = backend.linalg.eigh(self.cov.A.todense()) + B_eigvals, B_eigvecs = backend.linalg.eigh(self.cov.B.todense()) + + eigvals = backend.linalg.kron(A_eigvals, B_eigvals) + Q = linops.Kronecker(A_eigvecs, B_eigvecs) + elif ( isinstance(self.cov, linops.SymmetricKronecker) and self.cov.identical_factors - ) + ): + A_eigvals, A_eigvecs = backend.linalg.eigh(self.cov.A.todense()) + + eigvals = backend.linalg.kron(A_eigvals, B_eigvals) + Q = linops.SymmetricKronecker(A_eigvecs) + else: + assert isinstance(self.cov, linops.LinearOperator) + + eigvals, Q = backend.linalg.eigh(self.cov_matrix) + + Q = linops.aslinop(Q) + + return (_clip_eigvals(eigvals), Q) + + # TODO (#569,#678): Replace `_cov_{sqrtm,sqrtm_solve,logdet}` with + # `self._cov_op.{sqrtm,inv,logdet}` once they are supported and once linops support + # the backend + + @property + @backend.jit_method + def _cov_sqrtm(self) -> MatrixType: + cov_cholesky = self._cov_cholesky + + def _fallback_eigh(): + eigvals, Q = self._cov_eigh - n = self.mean.shape[1] + if isinstance(Q, linops.LinearOperator): + return Q @ 
linops.Scaling(backend.sqrt(eigvals)) + + return Q * backend.sqrt(eigvals)[None, :] + + if isinstance(cov_cholesky, (linops.Kronecker, linops.SymmetricKronecker)): + return backend.cond( + backend.any(backend.isnan(cov_cholesky.A.todense())) + & backend.any(backend.isnan(cov_cholesky.B.todense())), + _fallback_eigh, + lambda: cov_cholesky, + ) + + if isinstance(cov_cholesky, linops.Scaling): + return backend.cond( + backend.any(backend.isnan(cov_cholesky.factors)), + _fallback_eigh, + lambda: cov_cholesky, + ) + + if isinstance(cov_cholesky, linops.LinearOperator): + return backend.cond( + backend.any(backend.isnan(cov_cholesky.todense())), + _fallback_eigh, + lambda: cov_cholesky, + ) + + return backend.cond( + backend.any(backend.isnan(cov_cholesky)), + _fallback_eigh, + lambda: cov_cholesky, + ) + + @backend.jit_method + def _cov_sqrtm_solve(self, x: backend.Array) -> backend.Array: + def _eigh_fallback(x): + eigvals, Q = self._cov_eigh + + return (x @ Q) / backend.sqrt(eigvals) + + return backend.cond( + backend.any(backend.isnan(self._cov_matrix_cholesky)), + _eigh_fallback, + lambda x: backend.linalg.solve_triangular( + self._cov_matrix_cholesky, + x[..., None], + upper=False, + )[..., 0], + x, + ) + + @property + @backend.jit_method + def _cov_logdet(self) -> backend.Array: + return backend.cond( + backend.any(backend.isnan(self._cov_matrix_cholesky)), + lambda: backend.sum(backend.log(self._cov_eigh[0])), + lambda: ( + 2.0 * backend.sum(backend.log(backend.diag(self._cov_matrix_cholesky))) + ), + ) - # Appendix E: Bartels, S., Probabilistic Linear Algebra, PhD Thesis 2019 - samples_scaled = linops.Symmetrize(n) @ (self.cov_cholesky @ stdnormal_samples) - # TODO: can we avoid todense here and just return operator samples?
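# (Aside, an illustrative sketch rather than part of the patch: `_cov_logdet`
# above relies on log|Sigma| = 2 * sum(log(diag(L))) for a Cholesky factor
# Sigma = L L^T, and falls back to log|Sigma| = sum(log(eigvals)) when the
# factorization failed (signalled by NaNs). A quick NumPy check:
#
#   import numpy as np
#
#   Sigma = np.array([[2.0, 1.0], [1.0, 2.0]])
#   L = np.linalg.cholesky(Sigma)
#   assert np.isclose(2.0 * np.log(np.diag(L)).sum(), np.linalg.slogdet(Sigma)[1])
# )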
- return self.dense_mean[None, :, :] + samples_scaled.T.reshape(-1, n, n) +def _clip_eigvals(eigvals: backend.Array) -> backend.Array: + # Clip eigenvalues as in + # https://github.com/scipy/scipy/blob/b5d8bab88af61d61de09641243848df63380a67f/scipy/stats/_multivariate.py#L60-L166 + if eigvals.dtype == backend.float64: + eigvals_clip = 1e6 + elif eigvals.dtype == backend.float32: + eigvals_clip = 1e3 + else: + raise TypeError("Unsupported dtype") + + eigvals_clip *= backend.finfo(eigvals.dtype).eps + eigvals_clip *= backend.max(backend.abs(eigvals)) + + return backend.cond( + backend.any(eigvals < -eigvals_clip), + lambda: backend.full_like(eigvals, backend.nan), + lambda: eigvals * (eigvals >= eigvals_clip), + ) diff --git a/src/probnum/randvars/_random_variable.py b/src/probnum/randvars/_random_variable.py index 6f6180e55..0433197a0 100644 --- a/src/probnum/randvars/_random_variable.py +++ b/src/probnum/randvars/_random_variable.py @@ -1,26 +1,27 @@ """Random Variables.""" from __future__ import annotations +import functools from functools import cached_property -from typing import Any, Callable, Dict, Generic, Optional, Tuple, TypeVar, Union +import operator +from typing import Any, Callable, Dict, Optional import numpy as np -from probnum import utils as _utils -from probnum.typing import ArrayIndicesLike, DTypeLike, FloatLike, ShapeLike, ShapeType - -ValueType = TypeVar("ValueType") +from probnum import backend +from probnum.backend.random import RNGState +from probnum.backend.typing import ArrayIndicesLike, DTypeLike, ShapeLike, ShapeType # pylint: disable="too-many-lines" -class RandomVariable(Generic[ValueType]): +class RandomVariable: """Random variables represent uncertainty about a value. Random variables generalize multi-dimensional arrays by encoding uncertainty about the (numerical) quantity in question. Despite their name, they do not necessarily represent stochastic objects. Random variables are also the - primary in- and outputs of probabilistic numerical methods. + primary in- and outputs of probabilistic numerical methods. Instances of :class:`RandomVariable` can be added, multiplied, etc. with arrays and linear operators. This may change their distribution and therefore @@ -60,23 +61,6 @@ class RandomVariable(Generic[ValueType]): (Element-wise) standard deviation of the random variable. entropy : Information-theoretic entropy :math:`H(X)` of the random variable. - as_value_type : - Function which can be used to transform user-supplied arguments, interpreted as - realizations of this random variable, to an easy-to-process, normalized format. - Will be called internally to transform the argument of functions like - :meth:`~RandomVariable.in_support`, :meth:`~RandomVariable.cdf` - and :meth:`~RandomVariable.logcdf`, :meth:`~DiscreteRandomVariable.pmf` - and :meth:`~DiscreteRandomVariable.logpmf` (in :class:`DiscreteRandomVariable`), - :meth:`~ContinuousRandomVariable.pdf` and - :meth:`~ContinuousRandomVariable.logpdf` (in :class:`ContinuousRandomVariable`), - and potentially by similar functions in subclasses. - - For instance, this method is useful if (``log``) - :meth:`~ContinousRandomVariable.cdf` and (``log``) - :meth:`~ContinuousRandomVariable.pdf` both only work on :class:`numpy.float_` - arguments, but we still want the user to be able to pass Python - :class:`float`. Then :meth:`~RandomVariable.as_value_type` - should be set to something like ``lambda x: np.float64(x)``. 
See Also -------- @@ -108,28 +92,25 @@ def __init__( shape: ShapeLike, dtype: DTypeLike, parameters: Optional[Dict[str, Any]] = None, - sample: Optional[Callable[[np.random.Generator, ShapeType], ValueType]] = None, - in_support: Optional[Callable[[ValueType], bool]] = None, - cdf: Optional[Callable[[ValueType], np.float_]] = None, - logcdf: Optional[Callable[[ValueType], np.float_]] = None, - quantile: Optional[Callable[[FloatLike], ValueType]] = None, - mode: Optional[Callable[[], ValueType]] = None, - median: Optional[Callable[[], ValueType]] = None, - mean: Optional[Callable[[], ValueType]] = None, - cov: Optional[Callable[[], ValueType]] = None, - var: Optional[Callable[[], ValueType]] = None, - std: Optional[Callable[[], ValueType]] = None, - entropy: Optional[Callable[[], np.float_]] = None, - as_value_type: Optional[Callable[[Any], ValueType]] = None, + sample: Optional[Callable[[RNGState, ShapeType], backend.Array]] = None, + in_support: Optional[Callable[[backend.Array], bool]] = None, + cdf: Optional[Callable[[backend.Array], backend.Array]] = None, + logcdf: Optional[Callable[[backend.Array], backend.Array]] = None, + quantile: Optional[Callable[[backend.Array], backend.Array]] = None, + mode: Optional[Callable[[], backend.Array]] = None, + median: Optional[Callable[[], backend.Array]] = None, + mean: Optional[Callable[[], backend.Array]] = None, + cov: Optional[Callable[[], backend.Array]] = None, + var: Optional[Callable[[], backend.Array]] = None, + std: Optional[Callable[[], backend.Array]] = None, + entropy: Optional[Callable[[], backend.Scalar]] = None, ): # pylint: disable=too-many-arguments,too-many-locals """Create a new random variable.""" - self.__shape = _utils.as_shape(shape) + self.__shape = backend.asshape(shape) # Data Types - self.__dtype = np.dtype(dtype) - self.__median_dtype = RandomVariable.infer_median_dtype(self.__dtype) - self.__moment_dtype = RandomVariable.infer_moment_dtype(self.__dtype) + self.__dtype = backend.asdtype(dtype) # Probability distribution of the random variable self.__parameters = parameters.copy() if parameters is not None else {} @@ -150,9 +131,6 @@ def __init__( self.__std = std self.__entropy = entropy - # Utilities - self.__as_value_type = as_value_type - def __repr__(self) -> str: return ( f"<{self.__class__.__name__} with shape={self.shape}, dtype" @@ -172,40 +150,40 @@ def ndim(self) -> int: @cached_property def size(self) -> int: """Size of realizations of the random variable, defined as the product over all - components of :meth:`shape`.""" - return int(np.prod(self.__shape)) + components of :attr:`shape`.""" + return functools.reduce(operator.mul, self.__shape, 1) @property - def dtype(self) -> np.dtype: + def dtype(self) -> backend.DType: """Data type of (elements of) a realization of this random variable.""" return self.__dtype - @property - def median_dtype(self) -> np.dtype: - """The dtype of the :attr:`median`. - - It will be set to the dtype arising from the multiplication of - values with dtypes :attr:`dtype` and :class:`numpy.float_`. This - is motivated by the fact that, even for discrete random - variables, e.g. integer-valued random variables, the - :attr:`median` might lie in between two values in which case - these values are averaged. For example, a uniform random - variable on :math:`\\{ 1, 2, 3, 4 \\}` will have a median of - :math:`2.5`. + @cached_property + def median_dtype(self) -> backend.DType: + r"""The dtype of the :attr:`median`. 
+ + It will be set to the dtype arising from the multiplication of values with + dtypes :attr:`dtype` and :class:`~probnum.backend.float64`. This is motivated by + the fact that, even for discrete random variables, e.g. integer-valued random + variables, the :attr:`median` might lie in between two values in which case + these values are averaged. For example, a uniform random variable on :math:`\{ + 1, 2, 3, 4 \}` will have a median of :math:`2.5`. """ - return self.__median_dtype + return backend.promote_types(self.dtype, backend.float64) - @property - def moment_dtype(self) -> np.dtype: - """The dtype of any (function of a) moment of the random variable, e.g. its - :attr:`mean`, :attr:`cov`, :attr:`var`, or :attr:`std`. It will be set to the - dtype arising from the multiplication of values with dtypes :attr:`dtype` - and :class:`numpy.float_`. This is motivated by the mathematical definition of a - moment as a sum or an integral over products of probabilities and values of the - random variable, which are represented as using the dtypes :class:`numpy.float_` - and :attr:`dtype`, respectively. + @cached_property + def expectation_dtype(self) -> backend.DType: + r"""The dtype of an expectation of (a function of) the random variable. + + For instance, the :attr:`mean`, :attr:`cov`, :attr:`var`, :attr:`std`, and + :attr:`entropy` of the random variable will have this dtype. It will be set + to the dtype arising from the multiplication of values with dtypes :attr:`dtype` + and :class:`~probnum.backend.float64`. This is motivated by the mathematical + definition of an expectation as a sum or an integral over products of + probabilities and values of the random variable, which are represented using + the dtypes :class:`~probnum.backend.float64` and :attr:`dtype`, respectively. """ - return self.__moment_dtype + return backend.promote_types(self.dtype, backend.float64) @property def parameters(self) -> Dict[str, Any]: @@ -217,7 +195,7 @@ def parameters(self) -> Dict[str, Any]: return self.__parameters.copy() @cached_property - def mode(self) -> ValueType: + def mode(self) -> backend.Array: """Mode of the random variable.""" if self.__mode is None: raise NotImplementedError @@ -238,7 +216,7 @@ def mode(self) -> ValueType: return mode @cached_property - def median(self) -> ValueType: + def median(self) -> backend.Array: """Median of the random variable. To learn about the dtype of the median, see @@ -256,7 +234,7 @@ def median(self) -> ValueType: "median", median, shape=self.__shape, - dtype=self.__median_dtype, + dtype=self.median_dtype, ) # Make immutable @@ -266,10 +244,10 @@ def median(self) -> ValueType: return median @cached_property - def mean(self) -> ValueType: + def mean(self) -> backend.Array: """Mean :math:`\\mathbb{E}(X)` of the random variable. - To learn about the dtype of the mean, see :attr:`moment_dtype`. + To learn about the dtype of the mean, see :attr:`expectation_dtype`. """ if self.__mean is None: raise NotImplementedError @@ -280,7 +258,7 @@ def mean(self) -> ValueType: "mean", mean, shape=self.__shape, - dtype=self.__moment_dtype, + dtype=self.expectation_dtype, ) # Make immutable @@ -290,11 +268,11 @@ def mean(self) -> ValueType: return mean @cached_property - def cov(self) -> ValueType: + def cov(self) -> backend.Array: """Covariance :math:`\\operatorname{Cov}(X) = \\mathbb{E}((X-\\mathbb{E}(X))(X-\\mathbb{E}(X))^\\top)` of the random variable. - To learn about the dtype of the covariance, see :attr:`moment_dtype`.
- """ # pylint: disable=line-too-long + To learn about the dtype of the covariance, see :attr:`expectation_dtype`. + """ if self.__cov is None: raise NotImplementedError @@ -304,7 +282,7 @@ def cov(self) -> ValueType: "covariance", cov, shape=(self.size, self.size) if self.ndim > 0 else (), - dtype=self.__moment_dtype, + dtype=self.expectation_dtype, ) # Make immutable @@ -314,15 +292,18 @@ def cov(self) -> ValueType: return cov @cached_property - def var(self) -> ValueType: + def var(self) -> backend.Array: """Variance :math:`\\operatorname{Var}(X) = \\mathbb{E}((X-\\mathbb{E}(X))^2)` of the random variable. - To learn about the dtype of the variance, see :attr:`moment_dtype`. + To learn about the dtype of the variance, see :attr:`expectation_dtype`. """ if self.__var is None: try: - var = np.diag(self.cov).reshape(self.__shape).copy() + var = backend.reshape( + backend.diag(self.cov), + self.__shape, + ).copy() except NotImplementedError as exc: raise NotImplementedError from exc else: @@ -332,7 +313,7 @@ def var(self) -> ValueType: "variance", var, shape=self.__shape, - dtype=self.__moment_dtype, + dtype=self.expectation_dtype, ) # Make immutable @@ -342,17 +323,14 @@ def var(self) -> ValueType: return var @cached_property - def std(self) -> ValueType: + def std(self) -> backend.Array: """Standard deviation of the random variable. To learn about the dtype of the standard deviation, see - :attr:`moment_dtype`. + :attr:`expectation_dtype`. """ if self.__std is None: - try: - std = np.sqrt(self.var) - except NotImplementedError as exc: - raise NotImplementedError from exc + std = backend.sqrt(self.var) else: std = self.__std() @@ -360,7 +338,7 @@ def std(self) -> ValueType: "standard deviation", std, shape=self.__shape, - dtype=self.__moment_dtype, + dtype=self.expectation_dtype, ) # Make immutable @@ -370,20 +348,23 @@ def std(self) -> ValueType: return std @cached_property - def entropy(self) -> np.float_: - """Information-theoretic entropy :math:`H(X)` of the random variable.""" + def entropy(self) -> backend.Scalar: + r"""Information-theoretic entropy :math:`H(X)` of the random variable.""" if self.__entropy is None: raise NotImplementedError entropy = self.__entropy() - entropy = RandomVariable._ensure_numpy_float( - "entropy", entropy, force_scalar=True + RandomVariable._check_property_value( + "entropy", + value=entropy, + shape=(), + dtype=self.expectation_dtype, ) return entropy - def in_support(self, x: ValueType) -> bool: + def in_support(self, x: backend.Array) -> backend.Array: """Check whether the random variable takes value ``x`` with non-zero probability, i.e. if ``x`` is in the support of its distribution. @@ -395,36 +376,40 @@ def in_support(self, x: ValueType) -> bool: if self.__in_support is None: raise NotImplementedError - in_support = self.__in_support(self._as_value_type(x)) + in_support = self.__in_support(backend.asarray(x)) - if not isinstance(in_support, bool): - raise ValueError( - f"The function `in_support` must return a `bool`, but its return value " - f"is of type `{type(x)}`." - ) + self._check_return_value( + "in_support", + input_value=x, + return_value=in_support, + expected_shape=x.shape[: -self.ndim], + expected_dtype=backend.bool, + ) return in_support - def sample(self, rng: np.random.Generator, size: ShapeLike = ()) -> ValueType: + def sample( + self, rng_state: RNGState, sample_shape: ShapeLike = () + ) -> backend.Array: """Draw realizations from a random variable. Parameters ---------- - rng - Random number generator used for sampling. 
- size + rng_state + Random number generator state used for sampling. + sample_shape Shape of the drawn sample of realizations. """ if self.__sample is None: raise NotImplementedError("No sampling method provided.") - if not isinstance(rng, np.random.Generator): - msg = "Random number generators must be of type np.random.Generator." - raise TypeError(msg) + samples = self.__sample(rng_state, backend.asshape(sample_shape)) + + # TODO: Check shape and dtype - return self.__sample(rng=rng, size=_utils.as_shape(size)) + return samples - def cdf(self, x: ValueType) -> np.float_: + def cdf(self, x: backend.Array) -> backend.Array: """Cumulative distribution function. Parameters @@ -436,21 +421,26 @@ def cdf(self, x: ValueType) -> np.float_: The cdf evaluation will be broadcast over all additional dimensions. """ if self.__cdf is not None: - return RandomVariable._ensure_numpy_float( - "cdf", self.__cdf(self._as_value_type(x)) + cdf = self.__cdf(backend.asarray(x)) + elif self.__logcdf is not None: + cdf = backend.exp(self.logcdf(x)) + else: + raise NotImplementedError( + f"Neither the `cdf` nor the `logcdf` of the random variable object " + f"with type `{type(self).__name__}` is implemented." ) - if self.__logcdf is not None: - cdf = np.exp(self.logcdf(self._as_value_type(x))) - assert isinstance(cdf, np.float_) - return cdf - - raise NotImplementedError( - f"Neither the `cdf` nor the `logcdf` of the random variable object " - f"with type `{type(self).__name__}` is implemented." + self._check_return_value( + "cdf", + input_value=x, + return_value=cdf, + expected_shape=x.shape[: x.ndim - self.ndim], + expected_dtype=backend.float64, ) - def logcdf(self, x: ValueType) -> np.float_: + return cdf + + def logcdf(self, x: backend.Array) -> backend.Array: """Log-cumulative distribution function. Parameters @@ -462,21 +452,26 @@ def logcdf(self, x: ValueType) -> np.float_: The logcdf evaluation will be broadcast over all additional dimensions. """ if self.__logcdf is not None: - return RandomVariable._ensure_numpy_float( - "logcdf", self.__logcdf(self._as_value_type(x)) + logcdf = self.__logcdf(backend.asarray(x)) + elif self.__cdf is not None: + logcdf = backend.log(self.cdf(x)) + else: + raise NotImplementedError( + f"Neither the `logcdf` nor the `cdf` of the random variable object " + f"with type `{type(self).__name__}` is implemented." ) - if self.__cdf is not None: - logcdf = np.log(self.__cdf(x)) - assert isinstance(logcdf, np.float_) - return logcdf - - raise NotImplementedError( - f"Neither the `logcdf` nor the `cdf` of the random variable object " - f"with type `{type(self).__name__}` is implemented." + self._check_return_value( + "logcdf", + input_value=x, + return_value=logcdf, + expected_shape=x.shape[: x.ndim - self.ndim], + expected_dtype=backend.float64, ) - def quantile(self, p: FloatLike) -> ValueType: + return logcdf + + def quantile(self, p: backend.Array) -> backend.Array: """Quantile function. The quantile function :math:`Q \\colon [0, 1] \\to \\mathbb{R}` of a random @@ -498,34 +493,25 @@ def quantile(self, p: FloatLike) -> ValueType: if self.__quantile is None: raise NotImplementedError - try: - p = _utils.as_numpy_scalar(p, dtype=np.floating) - except TypeError as exc: - raise TypeError( - "The given argument `p` can not be cast to a `np.floating` object." - ) from exc - quantile = self.__quantile(p) - if quantile.shape != self.__shape: - raise ValueError( - f"The quantile function should return values of the same shape as the " - f"random variable, i.e.
{self.__shape}, but it returned a value with " - f"{quantile.shape}." - ) - - if quantile.dtype != self.__dtype: - raise ValueError( - f"The quantile function should return values of the same dtype as the " - f"random variable, i.e. `{self.__dtype.name}`, but it returned a value " - f"with dtype `{quantile.dtype.name}`." - ) + self._check_return_value( + "quantile", + input_value=p, + return_value=quantile, + expected_shape=p.shape + self.shape, + expected_dtype=self.dtype, + ) return quantile def __getitem__(self, key: ArrayIndicesLike) -> "RandomVariable": + # Shape inference + # For simplicity, this should not be computed using backend, but rather in numpy + shape = np.broadcast_to(np.empty(()), self.shape)[key].shape + return RandomVariable( - shape=np.empty(shape=self.shape)[key].shape, + shape=shape, dtype=self.dtype, sample=lambda rng, size: self.sample(rng, size)[key], mode=lambda: self.mode[key], @@ -533,7 +519,6 @@ def __getitem__(self, key: ArrayIndicesLike) -> "RandomVariable": var=lambda: self.var[key], std=lambda: self.std[key], entropy=lambda: self.entropy, - as_value_type=self.__as_value_type, ) def reshape(self, newshape: ShapeLike) -> "RandomVariable": @@ -545,7 +530,7 @@ def reshape(self, newshape: ShapeLike) -> "RandomVariable": New shape for the random variable. It must be compatible with the original shape. """ - newshape = _utils.as_shape(newshape) + newshape = backend.asshape(newshape) return RandomVariable( shape=newshape, @@ -558,7 +543,6 @@ def reshape(self, newshape: ShapeLike) -> "RandomVariable": var=lambda: self.var.reshape(newshape), std=lambda: self.std.reshape(newshape), entropy=lambda: self.entropy, - as_value_type=self.__as_value_type, ) def transpose(self, *axes: int) -> "RandomVariable": @@ -569,8 +553,13 @@ def transpose(self, *axes: int) -> "RandomVariable": axes : See documentation of :meth:`numpy.ndarray.transpose`. 
""" + + # Shape inference + # For simplicity, this should not be computed using backend, but rather in numpy + shape = np.broadcast_to(np.empty(()), self.shape).transpose(*axes).shape + return RandomVariable( - shape=np.empty(shape=self.shape).transpose(*axes).shape, + shape=shape, dtype=self.dtype, sample=lambda rng, size: self.sample(rng, size).transpose(*axes), mode=lambda: self.mode.transpose(*axes), @@ -580,7 +569,6 @@ def transpose(self, *axes: int) -> "RandomVariable": var=lambda: self.var.transpose(*axes), std=lambda: self.std.transpose(*axes), entropy=lambda: self.entropy, - as_value_type=self.__as_value_type, ) T = property(transpose) @@ -591,7 +579,9 @@ def __neg__(self) -> "RandomVariable": return RandomVariable( shape=self.shape, dtype=self.dtype, - sample=lambda rng, size: -self.sample(rng=rng, size=size), + sample=lambda seed, sample_shape: -self.sample( + rng_state=seed, sample_shape=sample_shape + ), in_support=lambda x: self.in_support(-x), mode=lambda: -self.mode, median=lambda: -self.median, @@ -599,14 +589,15 @@ def __neg__(self) -> "RandomVariable": cov=lambda: self.cov, var=lambda: self.var, std=lambda: self.std, - as_value_type=self.__as_value_type, ) def __pos__(self) -> "RandomVariable": return RandomVariable( shape=self.shape, dtype=self.dtype, - sample=lambda rng, size: +self.sample(rng=rng, size=size), + sample=lambda seed, sample_shape: +self.sample( + rng_state=seed, sample_shape=sample_shape + ), in_support=lambda x: self.in_support(+x), mode=lambda: +self.mode, median=lambda: +self.median, @@ -614,14 +605,15 @@ def __pos__(self) -> "RandomVariable": cov=lambda: self.cov, var=lambda: self.var, std=lambda: self.std, - as_value_type=self.__as_value_type, ) def __abs__(self) -> "RandomVariable": return RandomVariable( shape=self.shape, dtype=self.dtype, - sample=lambda rng, size: abs(self.sample(rng=rng, size=size)), + sample=lambda seed, sample_shape: abs( + self.sample(rng_state=seed, sample_shape=sample_shape) + ), ) # Binary arithmetic operations @@ -743,56 +735,12 @@ def __rpow__(self, other: Any) -> "RandomVariable": return pow_(other, self) - @staticmethod - def infer_median_dtype(value_dtype: DTypeLike) -> np.dtype: - """Infer the dtype of the median. - - Set the dtype to the dtype arising from - the multiplication of values with dtypes :attr:`dtype` and - :class:`numpy.float_`. This is motivated by the fact that, even for discrete - random variables, e.g. integer-valued random variables, the :attr:`median` - might lie in between two values in which case these values are averaged. For - example, a uniform random variable on :math:`\\{ 1, 2, 3, 4 \\}` will have a - median of :math:`2.5`. - - Parameters - ---------- - value_dtype : - Dtype of a value. - """ - return RandomVariable.infer_moment_dtype(value_dtype) - - @staticmethod - def infer_moment_dtype(value_dtype: DTypeLike) -> np.dtype: - """Infer the dtype of any moment. - - Infers the dtype of any (function of a) moment of the random variable, e.g. its - :attr:`mean`, :attr:`cov`, :attr:`var`, or :attr:`std`. Returns the - dtype arising from the multiplication of values with dtypes :attr:`dtype` - and :class:`numpy.float_`. This is motivated by the mathematical definition of a - moment as a sum or an integral over products of probabilities and values of the - random variable, which are represented as using the dtypes :class:`numpy.float_` - and :attr:`dtype`, respectively. - - Parameters - ---------- - value_dtype : - Dtype of a value. 
- """ - return np.promote_types(value_dtype, np.float_) - - def _as_value_type(self, x: Any) -> ValueType: - if self.__as_value_type is not None: - return self.__as_value_type(x) - - return x - @staticmethod def _check_property_value( name: str, - value: Any, - shape: Optional[Tuple[int, ...]] = None, - dtype: Optional[np.dtype] = None, + value: backend.Array, + shape: Optional[ShapeType] = None, + dtype: Optional[backend.DType] = None, ): if shape is not None: if value.shape != shape: @@ -802,50 +750,43 @@ def _check_property_value( ) if dtype is not None: - if not np.issubdtype(value.dtype, dtype): + if value.dtype != dtype: raise ValueError( f"The {name} of the random variable does not have the correct " - f"dtype. Expected {dtype.name} but got {value.dtype.name}." + f"dtype. Expected {str(dtype)} but got {str(value.dtype)}." ) - @classmethod - def _ensure_numpy_float( - cls, name: str, value: Any, force_scalar: bool = False - ) -> Union[np.float_, np.ndarray]: - if np.isscalar(value): - if not isinstance(value, np.float_): - try: - value = _utils.as_numpy_scalar(value, dtype=np.float_) - except TypeError as err: - raise TypeError( - f"The function `{name}` specified via the constructor of " - f"`{cls.__name__}` must return a scalar value that can be " - f"converted to a `np.float_`, which is not possible for " - f"{value} of type {type(value)}." - ) from err - elif not force_scalar: - try: - value = np.asarray(value, dtype=np.float_) - except TypeError as err: - raise TypeError( - f"The function `{name}` specified via the constructor of " - f"`{cls.__name__}` must return a value that can be converted " - f"to a `np.ndarray` of type `np.float_`, which is not possible " - f"for {value} of type {type(value)}." - ) from err - else: - raise TypeError( - f"The function `{name}` specified via the constructor of " - f"`{cls.__name__}` must return a scalar value, but {value} of type " - f"{type(value)} is not scalar." - ) + def _check_return_value( + self, + method_name: str, + input_value: backend.Array, + return_value: backend.Array, + expected_shape: Optional[ShapeType] = None, + expected_dtype: Optional[backend.DType] = None, + ): + # pylint: disable=too-many-arguments - assert isinstance(value, (np.float_, np.ndarray)) + if expected_shape is not None: + if return_value.shape != expected_shape: + raise ValueError( + f"The return value of the function `{method_name}` does not have " + f"the correct shape for an input with shape {input_value.shape} " + f"and a random variable with shape {self.shape}. Expected " + f"{expected_shape} but got {return_value.shape}." + ) - return value + if expected_dtype is not None: + if return_value.dtype != expected_dtype: + raise ValueError( + f"The return value of the function `{method_name}` does not have " + f"the correct dtype for an input with dtype " + f"{str(input_value.dtype)} and a random variable with dtype " + f"{str(self.dtype)}. Expected {str(expected_dtype)} but got " + f"{str(return_value.dtype)}." + ) -class DiscreteRandomVariable(RandomVariable[ValueType]): +class DiscreteRandomVariable(RandomVariable): """Random variable with countable range. Discrete random variables map to a countable set. Typical examples are the natural @@ -888,21 +829,6 @@ class DiscreteRandomVariable(RandomVariable[ValueType]): (Element-wise) standard deviation of the random variable. entropy : Shannon entropy :math:`H(X)` of the random variable. 
- as_value_type : - Function which can be used to transform user-supplied arguments, interpreted as - realizations of this random variable, to an easy-to-process, normalized format. - Will be called internally to transform the argument of functions like - :meth:`~DiscreteRandomVariable.in_support`, :meth:`~DiscreteRandomVariable.cdf` - and :meth:`~DiscreteRandomVariable.logcdf`, :meth:`~DiscreteRandomVariable.pmf` - and :meth:`~DiscreteRandomVariable.logpmf`, and potentially by similar - functions in subclasses. - - For instance, this method is useful if (``log``) - :meth:`~DiscreteRandomVariable.cdf` and (``log``) - :meth:`~DiscreteRandomVariable.pmf` both only work on :class:`numpy.float_` - arguments, but we still want the user to be able to pass Python - :class:`float`. Then :meth:`~DiscreteRandomVariable.as_value_type` - should be set to something like ``lambda x: np.float64(x)``. See Also -------- @@ -912,42 +838,40 @@ class DiscreteRandomVariable(RandomVariable[ValueType]): Examples -------- >>> # Create a custom categorical random variable - >>> import numpy as np + >>> from probnum import backend >>> from probnum.randvars import DiscreteRandomVariable >>> >>> # Distribution parameters - >>> support = np.array([-1, 0, 1]) - >>> p = np.array([0.2, 0.5, 0.3]) + >>> support = backend.asarray([-1, 0, 1]) + >>> p = backend.asarray([0.2, 0.5, 0.3]) >>> parameters_categorical = { ... "support" : support, ... "p" : p} >>> >>> # Sampling function - >>> def sample_categorical(rng, size=()): - ... return rng.choice(a=support, size=size, p=p) + >>> def sample_categorical(rng_state, sample_shape=()): + ... return backend.random.choice( + ... rng_state=rng_state, x=support, shape=sample_shape, p=p + ... ) >>> >>> # Probability mass function >>> def pmf_categorical(x): - ... idx = np.where(x == support)[0] - ... if len(idx) > 0: - ... return p[idx] - ... else: - ... return 0.0 + ... return backend.max(backend.where(x == support, p, backend.zeros_like(p))) >>> >>> # Create custom random variable >>> x = DiscreteRandomVariable( ... shape=(), - ... dtype=np.dtype(np.int64), + ... dtype=backend.int64, ... parameters=parameters_categorical, ... sample=sample_categorical, ... pmf=pmf_categorical, - ... mean=lambda : np.float64(0), - ... median=lambda : np.float64(0), + ... mean=lambda : backend.float64(0), + ... median=lambda : backend.float64(0), ... ) >>> >>> # Sample from new random variable - >>> rng = np.random.default_rng(42) - >>> x.sample(rng=rng, size=3) + >>> rng_state = backend.random.rng_state(42) + >>> x.sample(rng_state=rng_state, sample_shape=3) array([1, 0, 1]) >>> x.pmf(2) array(0.)
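
A note on the sampling API exercised in the doctest above: an `RNGState` is consumed functionally rather than mutated in place, so drawing further independent samples requires splitting the state explicitly. A minimal sketch of the intended usage, assuming the `backend.random.rng_state` and `backend.random.split` helpers used elsewhere in this patch (`x` is the categorical random variable constructed above):

    from probnum import backend

    # Derive two independent states from one base seed; reusing the same state
    # would reproduce identical draws under a functional backend such as JAX.
    rng_state = backend.random.rng_state(42)
    rng_state_a, rng_state_b = backend.random.split(rng_state, num=2)

    samples_a = x.sample(rng_state=rng_state_a, sample_shape=(3,))
    samples_b = x.sample(rng_state=rng_state_b, sample_shape=(3,))
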
@@ -960,22 +884,23 @@ def __init__( shape: ShapeLike, dtype: DTypeLike, parameters: Optional[Dict[str, Any]] = None, - sample: Optional[Callable[[np.random.Generator, ShapeLike], ValueType]] = None, - in_support: Optional[Callable[[ValueType], bool]] = None, - pmf: Optional[Callable[[ValueType], np.float_]] = None, - logpmf: Optional[Callable[[ValueType], np.float_]] = None, - cdf: Optional[Callable[[ValueType], np.float_]] = None, - logcdf: Optional[Callable[[ValueType], np.float_]] = None, - quantile: Optional[Callable[[FloatLike], ValueType]] = None, - mode: Optional[Callable[[], ValueType]] = None, - median: Optional[Callable[[], ValueType]] = None, - mean: Optional[Callable[[], ValueType]] = None, - cov: Optional[Callable[[], ValueType]] = None, - var: Optional[Callable[[], ValueType]] = None, - std: Optional[Callable[[], ValueType]] = None, - entropy: Optional[Callable[[], np.float_]] = None, - as_value_type: Optional[Callable[[Any], ValueType]] = None, + sample: Optional[Callable[[RNGState, ShapeType], backend.Array]] = None, + in_support: Optional[Callable[[backend.Array], backend.Array]] = None, + pmf: Optional[Callable[[backend.Array], backend.Array]] = None, + logpmf: Optional[Callable[[backend.Array], backend.Array]] = None, + cdf: Optional[Callable[[backend.Array], backend.Array]] = None, + logcdf: Optional[Callable[[backend.Array], backend.Array]] = None, + quantile: Optional[Callable[[backend.Array], backend.Array]] = None, + mode: Optional[Callable[[], backend.Array]] = None, + median: Optional[Callable[[], backend.Array]] = None, + mean: Optional[Callable[[], backend.Array]] = None, + cov: Optional[Callable[[], backend.Array]] = None, + var: Optional[Callable[[], backend.Array]] = None, + std: Optional[Callable[[], backend.Array]] = None, + entropy: Optional[Callable[[], backend.Scalar]] = None, ): + # pylint: disable=too-many-arguments,too-many-locals + # Probability mass function self.__pmf = pmf self.__logpmf = logpmf @@ -996,10 +921,9 @@ def __init__( var=var, std=std, entropy=entropy, - as_value_type=as_value_type, ) - def pmf(self, x: ValueType) -> np.float_: + def pmf(self, x: backend.Array) -> backend.Array: """Probability mass function. Computes the probability of the random variable being equal to the given @@ -1020,19 +944,26 @@ def pmf(self, x: ValueType) -> np.float_: The pmf evaluation will be broadcast over all additional dimensions. """ if self.__pmf is not None: - return DiscreteRandomVariable._ensure_numpy_float("pmf", self.__pmf(x)) - - if self.__logpmf is not None: - pmf = np.exp(self.__logpmf(x)) - assert isinstance(pmf, np.float_) - return pmf + pmf = self.__pmf(backend.asarray(x)) + elif self.__logpmf is not None: + pmf = backend.exp(self.logpmf(x)) + else: + raise NotImplementedError( + f"Neither the `pmf` nor the `logpmf` of the discrete random variable " + f"object with type `{type(self).__name__}` is implemented." + ) - raise NotImplementedError( - f"Neither the `pmf` nor the `logpmf` of the discrete random variable " - f"object with type `{type(self).__name__}` is implemented." + self._check_return_value( + "pmf", + input_value=x, + return_value=pmf, + expected_shape=x.shape[: x.ndim - self.ndim], + expected_dtype=backend.float64, ) - def logpmf(self, x: ValueType) -> np.float_: + return pmf + + def logpmf(self, x: backend.Array) -> backend.Array: """Natural logarithm of the probability mass function. Parameters @@ -1044,22 +975,27 @@ def logpmf(self, x: ValueType) -> np.float_: The logpmf evaluation will be broadcast over all additional dimensions.
""" if self.__logpmf is not None: - return DiscreteRandomVariable._ensure_numpy_float( - "logpmf", self.__logpmf(self._as_value_type(x)) + logpmf = self.__logpmf(backend.asarray(x)) + elif self.__pmf is not None: + logpmf = backend.log(self.pmf(x)) + else: + raise NotImplementedError( + f"Neither the `logpmf` nor the `pmf` of the discrete random variable " + f"object with type `{type(self).__name__}` is implemented." ) - if self.__pmf is not None: - logpmf = np.log(self.__pmf(self._as_value_type(x))) - assert isinstance(logpmf, np.float_) - return logpmf - - raise NotImplementedError( - f"Neither the `logpmf` nor the `pmf` of the discrete random variable " - f"object with type `{type(self).__name__}` is implemented." + self._check_return_value( + "logpmf", + input_value=x, + return_value=logpmf, + expected_shape=x.shape[: -self.ndim], + expected_dtype=backend.float64, ) + return logpmf + -class ContinuousRandomVariable(RandomVariable[ValueType]): +class ContinuousRandomVariable(RandomVariable): """Random variable with uncountably infinite range. Continuous random variables map to a uncountably infinite set. Typically, this is a @@ -1102,23 +1038,6 @@ class ContinuousRandomVariable(RandomVariable[ValueType]): (Element-wise) standard deviation of the random variable. entropy : Differential entropy :math:`H(X)` of the random variable. - as_value_type : - Function which can be used to transform user-supplied arguments, interpreted as - realizations of this random variable, to an easy-to-process, normalized format. - Will be called internally to transform the argument of functions like - :meth:`~ContinuousRandomVariable.in_support`, - :meth:`~ContinuousRandomVariable.cdf` - and :meth:`~ContinuousRandomVariable.logcdf`, - :meth:`~ContinuousRandomVariable.pdf` and - :meth:`~ContinuousRandomVariable.logpdf`, and potentially by similar - functions in subclasses. - - For instance, this method is useful if (``log``) - :meth:`~ContinuousRandomVariable.cdf` and (``log``) - :meth:`~ContinuousRandomVariable.pdf` both only work on :class:`numpy.float_` - arguments, but we still want the user to be able to pass Python - :class:`float`. Then :meth:`~ContinuousRandomVariable.as_value_type` - should be set to something like ``lambda x: np.float64(x)``. See Also -------- @@ -1128,7 +1047,7 @@ class ContinuousRandomVariable(RandomVariable[ValueType]): Examples -------- >>> # Create a custom uniformly distributed random variable - >>> import numpy as np + >>> from probnum import backend >>> >>> # Distribution parameters >>> a = 0.0 @@ -1136,8 +1055,8 @@ class ContinuousRandomVariable(RandomVariable[ValueType]): >>> parameters_uniform = {"bounds" : [a, b]} >>> >>> # Sampling function - >>> def sample_uniform(rng, size=()): - ... return rng.uniform(size=size) + >>> def sample_uniform(rng_state, sample_shape=()): + ... return backend.random.uniform(rng_state=rng_state, shape=sample_shape) >>> >>> # Probability density function >>> def pdf_uniform(x): @@ -1160,8 +1079,8 @@ class ContinuousRandomVariable(RandomVariable[ValueType]): ... ) >>> >>> # Sample from new random variable - >>> rng = np.random.default_rng(42) - >>> u.sample(rng=rng, size=3) + >>> rng_state = backend.random.rng_state(42) + >>> u.sample(rng_state, 3) array([0.77395605, 0.43887844, 0.85859792]) >>> u.pdf(0.5) array(1.) 
@@ -1174,22 +1093,23 @@ def __init__( shape: ShapeLike, dtype: DTypeLike, parameters: Optional[Dict[str, Any]] = None, - sample: Optional[Callable[[np.random.Generator, ShapeLike], ValueType]] = None, - in_support: Optional[Callable[[ValueType], bool]] = None, - pdf: Optional[Callable[[ValueType], np.float_]] = None, - logpdf: Optional[Callable[[ValueType], np.float_]] = None, - cdf: Optional[Callable[[ValueType], np.float_]] = None, - logcdf: Optional[Callable[[ValueType], np.float_]] = None, - quantile: Optional[Callable[[FloatLike], ValueType]] = None, - mode: Optional[Callable[[], ValueType]] = None, - median: Optional[Callable[[], ValueType]] = None, - mean: Optional[Callable[[], ValueType]] = None, - cov: Optional[Callable[[], ValueType]] = None, - var: Optional[Callable[[], ValueType]] = None, - std: Optional[Callable[[], ValueType]] = None, - entropy: Optional[Callable[[], np.float_]] = None, - as_value_type: Optional[Callable[[Any], ValueType]] = None, + sample: Optional[Callable[[RNGState, ShapeType], backend.Array]] = None, + in_support: Optional[Callable[[backend.Array], backend.Array]] = None, + pdf: Optional[Callable[[backend.Array], backend.Array]] = None, + logpdf: Optional[Callable[[backend.Array], backend.Array]] = None, + cdf: Optional[Callable[[backend.Array], backend.Array]] = None, + logcdf: Optional[Callable[[backend.Array], backend.Array]] = None, + quantile: Optional[Callable[[backend.Array], backend.Array]] = None, + mode: Optional[Callable[[], backend.Array]] = None, + median: Optional[Callable[[], backend.Array]] = None, + mean: Optional[Callable[[], backend.Array]] = None, + cov: Optional[Callable[[], backend.Array]] = None, + var: Optional[Callable[[], backend.Array]] = None, + std: Optional[Callable[[], backend.Array]] = None, + entropy: Optional[Callable[[], backend.Array]] = None, ): + # pylint: disable=too-many-arguments,too-many-locals + # Probability density function self.__pdf = pdf self.__logpdf = logpdf @@ -1210,10 +1130,9 @@ def __init__( var=var, std=std, entropy=entropy, - as_value_type=as_value_type, ) - def pdf(self, x: ValueType) -> np.float_: + def pdf(self, x: backend.Array) -> backend.Array: """Probability density function. The area under the curve defined by the probability density function @@ -1234,21 +1153,26 @@ def pdf(self, x: ValueType) -> np.float_: The pdf evaluation will be broadcast over all additional dimensions. """ if self.__pdf is not None: - return ContinuousRandomVariable._ensure_numpy_float( - "pdf", self.__pdf(self._as_value_type(x)) + pdf = self.__pdf(backend.asarray(x)) + elif self.__logpdf is not None: + pdf = backend.exp(self.logpdf(x)) + else: + raise NotImplementedError( + f"Neither the `pdf` nor the `logpdf` of the continuous random variable " + f"object with type `{type(self).__name__}` is implemented." ) - if self.__logpdf is not None: - pdf = np.exp(self.__logpdf(self._as_value_type(x))) - assert isinstance(pdf, np.float_) - - return pdf - raise NotImplementedError( - f"Neither the `pdf` nor the `logpdf` of the continuous random variable " - f"object with type `{type(self).__name__}` is implemented." + self._check_return_value( + "pdf", + input_value=x, + return_value=pdf, + expected_shape=x.shape[: x.ndim - self.ndim], + expected_dtype=backend.float64, ) - def logpdf(self, x: ValueType) -> np.float_: + return pdf + + def logpdf(self, x: backend.Array) -> backend.Array: """Natural logarithm of the probability density function. 
Parameters @@ -1260,16 +1184,21 @@ def logpdf(self, x: ValueType) -> np.float_: The logpdf evaluation will be broadcast over all additional dimensions. """ if self.__logpdf is not None: - return ContinuousRandomVariable._ensure_numpy_float( - "logpdf", self.__logpdf(self._as_value_type(x)) + logpdf = self.__logpdf(backend.asarray(x)) + elif self.__pdf is not None: + logpdf = backend.log(self.pdf(x)) + else: + raise NotImplementedError( + f"Neither the `logpdf` nor the `pdf` of the continuous random variable " + f"object with type `{type(self).__name__}` is implemented." ) - if self.__pdf is not None: - logpdf = np.log(self.__pdf(self._as_value_type(x))) - assert isinstance(logpdf, np.float_) - return logpdf - - raise NotImplementedError( - f"Neither the `logpdf` nor the `pdf` of the continuous random variable " - f"object with type `{type(self).__name__}` is implemented." + self._check_return_value( + "logpdf", + input_value=x, + return_value=logpdf, + expected_shape=x.shape[: x.ndim - self.ndim], + expected_dtype=backend.float64, ) + + return logpdf diff --git a/src/probnum/randvars/_randomvariablelist.py b/src/probnum/randvars/_randomvariablelist.py index 21cbf65a8..d80a209e3 100644 --- a/src/probnum/randvars/_randomvariablelist.py +++ b/src/probnum/randvars/_randomvariablelist.py @@ -5,7 +5,7 @@ import numpy as np -from probnum import randvars +from probnum.randvars import _random_variable class _RandomVariableList(list): @@ -25,14 +25,16 @@ def __init__(self, rv_list: list): if len(rv_list) > 0: # First element as a proxy for checking all elements - if not isinstance(rv_list[0], randvars.RandomVariable): + if not isinstance(rv_list[0], _random_variable.RandomVariable): raise TypeError( "RandomVariableList expects RandomVariable elements, but " + f"first element has type {type(rv_list[0])}." ) super().__init__(rv_list) - def __getitem__(self, idx) -> Union[randvars.RandomVariable, "_RandomVariableList"]: + def __getitem__( + self, idx + ) -> Union[_random_variable.RandomVariable, "_RandomVariableList"]: result = super().__getitem__(idx) # Make sure to wrap the result into a _RandomVariableList if necessary diff --git a/src/probnum/randvars/_scipy_stats.py b/src/probnum/randvars/_scipy_stats.py deleted file mode 100644 index 9367f5494..000000000 --- a/src/probnum/randvars/_scipy_stats.py +++ /dev/null @@ -1,259 +0,0 @@ -"""Wrapper classes for SciPy random variables.""" - -from typing import Any, Dict, Union - -import numpy as np -import scipy.stats - -from probnum import utils as _utils - -from . import _normal, _random_variable - -ValueType = Union[np.generic, np.ndarray] - -# pylint: disable=protected-access - - -class _SciPyRandomVariableMixin: - """Mix-in class for SciPy random variable wrappers.""" - - @property - def scipy_rv(self): - """SciPy random variable.""" - return self._scipy_rv - - -class WrappedSciPyRandomVariable( - _SciPyRandomVariableMixin, _random_variable.RandomVariable[ValueType] -): - """Wrapper for SciPy random variable objects. - - Parameters - ---------- - scipy_rv - SciPy random variable. - """ - - def __init__( - self, - scipy_rv: Union[ - scipy.stats._distn_infrastructure.rv_frozen, - scipy.stats._multivariate.multi_rv_frozen, - ], - ): - self._scipy_rv = scipy_rv - - super().__init__(**_rv_init_kwargs_from_scipy_rv(scipy_rv)) - - -class WrappedSciPyDiscreteRandomVariable( - _SciPyRandomVariableMixin, _random_variable.DiscreteRandomVariable[ValueType] -): - """Wrapper for discrete SciPy random variable objects.
- - Parameters - ---------- - scipy_rv - Discrete SciPy random variable. - """ - - def __init__( - self, - scipy_rv: Union[ - scipy.stats._distn_infrastructure.rv_frozen, - scipy.stats._multivariate.multi_rv_frozen, - ], - ): - if isinstance(scipy_rv, scipy.stats._distn_infrastructure.rv_frozen): - if not isinstance(scipy_rv.dist, scipy.stats.rv_discrete): - raise ValueError("The given SciPy random variable is not discrete.") - - self._scipy_rv = scipy_rv - - rv_kwargs = _rv_init_kwargs_from_scipy_rv(scipy_rv) - - rv_kwargs["pmf"] = _return_numpy( - getattr(scipy_rv, "pmf", None), - dtype=np.float_, - ) - - rv_kwargs["logpmf"] = _return_numpy( - getattr(scipy_rv, "logpmf", None), - dtype=np.float_, - ) - - super().__init__(**rv_kwargs) - - -class WrappedSciPyContinuousRandomVariable( - _SciPyRandomVariableMixin, _random_variable.ContinuousRandomVariable[ValueType] -): - """Wrapper for continuous SciPy random variable objects. - - Parameters - ---------- - scipy_rv - Continuous SciPy random variable. - """ - - def __init__( - self, - scipy_rv: Union[ - scipy.stats._distn_infrastructure.rv_frozen, - scipy.stats._multivariate.multi_rv_frozen, - ], - ): - if isinstance(scipy_rv, scipy.stats._distn_infrastructure.rv_frozen): - if not isinstance(scipy_rv.dist, scipy.stats.rv_continuous): - raise ValueError("The given SciPy random variable is not continuous.") - - self._scipy_rv = scipy_rv - - rv_kwargs = _rv_init_kwargs_from_scipy_rv(scipy_rv) - - rv_kwargs["pdf"] = _return_numpy( - getattr(scipy_rv, "pdf", None), - dtype=np.float_, - ) - - rv_kwargs["logpdf"] = _return_numpy( - getattr(scipy_rv, "logpdf", None), - dtype=np.float_, - ) - - super().__init__(**rv_kwargs) - - -def wrap_scipy_rv( - scipy_rv: Union[ - scipy.stats._distn_infrastructure.rv_frozen, - scipy.stats._multivariate.multi_rv_frozen, - ] -) -> _random_variable.RandomVariable: - """Transform SciPy distributions to ProbNum :class:`RandomVariable`s. - - Parameters - ---------- - scipy_rv : - SciPy random variable. 
- """ - - # pylint: disable=too-many-return-statements - - # Random variables with concrete implementations in ProbNum - if isinstance(scipy_rv, scipy.stats._distn_infrastructure.rv_frozen): - # Univariate distributions - if scipy_rv.dist.name == "norm": - # Normal distribution - return _normal.Normal( - mean=scipy_rv.mean(), - cov=scipy_rv.var(), - ) - elif isinstance(scipy_rv, scipy.stats._multivariate.multi_rv_frozen): - # Multivariate distributions - if scipy_rv.__class__.__name__ == "multivariate_normal_frozen": - # Multivariate normal distribution - return _normal.Normal( - mean=scipy_rv.mean, - cov=scipy_rv.cov, - ) - - # Generic random variables - if isinstance(scipy_rv, scipy.stats._distn_infrastructure.rv_frozen): - if isinstance(scipy_rv.dist, scipy.stats.rv_discrete): - return WrappedSciPyDiscreteRandomVariable(scipy_rv) - if isinstance(scipy_rv.dist, scipy.stats.rv_continuous): - return WrappedSciPyContinuousRandomVariable(scipy_rv) - - assert isinstance(scipy_rv.dist, scipy.stats.rv_generic) - return WrappedSciPyRandomVariable(scipy_rv) - - if isinstance(scipy_rv, scipy.stats._multivariate.multi_rv_frozen): - has_pmf = hasattr(scipy_rv, "pmf") or hasattr(scipy_rv, "logpmf") - has_pdf = hasattr(scipy_rv, "pdf") or hasattr(scipy_rv, "logpdf") - - if has_pdf and has_pmf: - return WrappedSciPyRandomVariable(scipy_rv) - if has_pmf: - return WrappedSciPyDiscreteRandomVariable(scipy_rv) - if has_pdf: - return WrappedSciPyContinuousRandomVariable(scipy_rv) - - assert not has_pmf and not has_pdf - return WrappedSciPyRandomVariable(scipy_rv) - - raise ValueError(f"Unsupported argument type {type(scipy_rv)}") - - -def _rv_init_kwargs_from_scipy_rv( - scipy_rv: Union[ - scipy.stats._distn_infrastructure.rv_frozen, - scipy.stats._multivariate.multi_rv_frozen, - ], -) -> Dict[str, Any]: - """Create dictionary of random variable properties from a Scipy random variable. - - Parameters - ---------- - scipy_rv - SciPy random variable. 
- """ - # Infer shape and dtype - sample = _return_numpy(scipy_rv.rvs)() - - shape = sample.shape - dtype = sample.dtype - - median_dtype = np.promote_types(dtype, np.float_) - moments_dtype = np.promote_types(dtype, np.float_) - - # Support of univariate random variables - if isinstance(scipy_rv, scipy.stats._distn_infrastructure.rv_frozen): - - def in_support(x): - low, high = scipy_rv.support() - - return bool(low <= x <= high) - - else: - in_support = None - - def sample_from_scipy_rv(rng, size): - return scipy_rv.rvs(size=size, random_state=rng) - - if hasattr(scipy_rv, "rvs"): - sample_wrapper = sample_from_scipy_rv - else: - sample_wrapper = None - - return { - "shape": shape, - "dtype": dtype, - "sample": _return_numpy(sample_wrapper, dtype), - "in_support": in_support, - "cdf": _return_numpy(getattr(scipy_rv, "cdf", None), np.float_), - "logcdf": _return_numpy(getattr(scipy_rv, "logcdf", None), np.float_), - "quantile": _return_numpy(getattr(scipy_rv, "ppf", None), dtype), - "mode": None, # not offered by scipy.stats - "median": _return_numpy(getattr(scipy_rv, "median", None), median_dtype), - "mean": _return_numpy(getattr(scipy_rv, "mean", None), moments_dtype), - "cov": _return_numpy(getattr(scipy_rv, "cov", None), moments_dtype), - "var": _return_numpy(getattr(scipy_rv, "var", None), moments_dtype), - "std": _return_numpy(getattr(scipy_rv, "std", None), moments_dtype), - "entropy": _return_numpy(getattr(scipy_rv, "entropy", None), np.float_), - } - - -def _return_numpy(fun, dtype=None): - if fun is None: - return None - - def _wrapper(*args, **kwargs): - res = fun(*args, **kwargs) - - if np.isscalar(res): - return _utils.as_numpy_scalar(res, dtype=dtype) - - return np.asarray(res, dtype=dtype) - - return _wrapper diff --git a/src/probnum/randvars/_sym_mat_normal.py b/src/probnum/randvars/_sym_mat_normal.py new file mode 100644 index 000000000..ce6867c9f --- /dev/null +++ b/src/probnum/randvars/_sym_mat_normal.py @@ -0,0 +1,54 @@ +import numpy as np + +from probnum import backend, linops +from probnum.backend.random import RNGState +from probnum.backend.typing import ShapeType +from probnum.typing import LinearOperatorLike + +from . import _normal + + +class SymmetricMatrixNormal(_normal.Normal): + def __init__( + self, + mean: LinearOperatorLike, + cov: linops.SymmetricKronecker, + ) -> None: + if not isinstance(cov, linops.SymmetricKronecker): + raise ValueError( + "The covariance operator must have type `SymmetricKronecker`." + ) + if not cov.identical_factors: + raise ValueError("The covariance operator must have identical factors.") + + m, n = mean.shape + + if m != n or n != cov.A.shape[0] or n != cov.B.shape[1]: + raise ValueError( + "Normal distributions with symmetric Kronecker structured " + "kernels must have square mean and square kernels factors with " + "matching dimensions." + ) + + super().__init__(mean=linops.aslinop(mean), cov=cov) + + def _sample(self, rng_state: RNGState, sample_shape: ShapeType = ()) -> np.ndarray: + assert ( + isinstance(self.cov, linops.SymmetricKronecker) + and self.cov.identical_factors + ) + + n = self.mean.shape[1] + + # Draw standard normal samples + stdnormal_samples = backend.random.standard_normal( + rng_state, + shape=sample_shape + (n * n, 1), + dtype=self.dtype, + ) + + # Appendix E: Bartels, S., Probabilistic Linear Algebra, PhD Thesis 2019 + samples_scaled = linops.Symmetrize(n) @ (self._cov_cholesky @ stdnormal_samples) + + # TODO: can we avoid todense here and just return operator samples? 
+ return self.dense_mean + samples_scaled.reshape(*sample_shape, n, n) diff --git a/src/probnum/randvars/_utils.py b/src/probnum/randvars/_utils.py index ff9821cc6..5253c2ba9 100644 --- a/src/probnum/randvars/_utils.py +++ b/src/probnum/randvars/_utils.py @@ -1,12 +1,11 @@ """Utility functions for random variables.""" from typing import Any -import numpy as np import scipy.sparse -import probnum.linops +from probnum import backend, linops -from . import _constant, _random_variable, _scipy_stats +from . import _constant, _random_variable def asrandvar(obj: Any) -> _random_variable.RandomVariable: @@ -17,53 +16,39 @@ def asrandvar(obj: Any) -> _random_variable.RandomVariable: Parameters ---------- - obj : + obj Object to be represented as a :class:`RandomVariable`. + Returns + ------- + randvar + Object as a :class:`RandomVariable`. + + Raises + ------ + ValueError + If the object cannot be represented as a :class:`RandomVariable`. + See Also -------- RandomVariable : Class representing random variables. - - Examples - -------- - >>> from scipy.stats import bernoulli - >>> import probnum as pn - >>> import numpy as np - >>> bern = bernoulli(p=0.5) - >>> bern_pn = pn.asrandvar(bern) - >>> rng = np.random.default_rng(42) - >>> bern_pn.sample(rng=rng, size=5) - array([1, 0, 1, 1, 0]) """ - # pylint: disable=protected-access # RandomVariable if isinstance(obj, _random_variable.RandomVariable): return obj # Scalar - if np.isscalar(obj): + if backend.ndim(obj) == 0: return _constant.Constant(support=obj) - # Numpy array or sparse matrix - if isinstance(obj, (np.ndarray, scipy.sparse.spmatrix)): + # NumPy array or sparse matrix + if backend.isarray(obj) or isinstance(obj, scipy.sparse.spmatrix): return _constant.Constant(support=obj) # Linear Operators - if isinstance( - obj, (probnum.linops.LinearOperator, scipy.sparse.linalg.LinearOperator) - ): - return _constant.Constant(support=probnum.linops.aslinop(obj)) - - # Scipy random variable - if isinstance( - obj, - ( - scipy.stats._distn_infrastructure.rv_frozen, - scipy.stats._multivariate.multi_rv_frozen, - ), - ): - return _scipy_stats.wrap_scipy_rv(obj) + if isinstance(obj, (linops.LinearOperator, scipy.sparse.linalg.LinearOperator)): + return _constant.Constant(support=linops.aslinop(obj)) raise ValueError( f"Argument of type {type(obj)} cannot be converted to a random variable." diff --git a/src/probnum/typing.py b/src/probnum/typing.py index 1d110b363..4a75ac163 100644 --- a/src/probnum/typing.py +++ b/src/probnum/typing.py @@ -3,125 +3,55 @@ This module defines commonly used types in the library. These are separated into two different kinds, API types and argument types. -**API types** (``*Type``) are aliases which define custom types -used throughout the library. Objects ofthis type may be supplied as arguments -or returned by a method. +**API types** (``*Type``) are aliases which define custom types used throughout the +library. Objects of this type may be supplied as arguments or returned by a method. **Argument types** (``*Like``) are aliases which define commonly used method -arguments that are internally converted to a standardized representation. -These should only ever be used in the signature of a method and then -be converted internally, e.g. in a class instantiation or an interface. -They enable the user to conveniently supply a variety of objects of different -types for the same argument, while ensuring a unified internal representation of -those same objects. 
As an example, take the different ways a user might specify -a shape: ``2``, ``(2,)``, ``[2, 2]``. These may all be acceptable arguments -to a function taking a shape, but internally should always be converted -to a :attr:`ShapeType`, i.e. a tuple of ``int``\\ s. +arguments that are internally converted to a standardized representation. These should +only ever be used in the signature of a method and then be converted internally, e.g. in +a class instantiation or an interface. They enable the user to conveniently supply a +variety of objects of different types for the same argument, while ensuring a unified +internal representation of those same objects. As an example, a user might pass an +object which can be converted to a finite dimensional linear operator. This argument +could be an :class:`~probnum.backend.Array`, a sparse matrix +:class:`~scipy.sparse.spmatrix` or a :class:`~probnum.linops.LinearOperator`. The type +alias :attr:`LinearOperatorLike` combines all these in a single type. Internally, the +passed argument is then converted to a :class:`~probnum.linops.LinearOperator`. """ from __future__ import annotations -import numbers -from typing import Iterable, Tuple, Union +from typing import Union -import numpy as np -from numpy.typing import ArrayLike as _NumPyArrayLike, DTypeLike as _NumPyDTypeLike import scipy.sparse +from probnum import backend +from probnum.backend.typing import ArrayLike + __all__ = [ # API Types - "ShapeType", - "ScalarType", "MatrixType", # Argument Types - "IntLike", - "FloatLike", - "ShapeLike", - "DTypeLike", - "ArrayIndicesLike", - "ScalarLike", - "ArrayLike", "LinearOperatorLike", - "NotImplementedType", ] ######################################################################################## # API Types ######################################################################################## -# Array Utilities -ShapeType = Tuple[int, ...] -"""Type defining a shape of an object.""" - # Scalars, Arrays and Matrices -ScalarType = np.ndarray -"""Type defining a scalar.""" +MatrixType = Union[backend.Array, "probnum.linops.LinearOperator"] +"""Type defining a matrix, i.e. a linear map between finite-dimensional vector spaces. -MatrixType = Union[np.ndarray, "probnum.linops.LinearOperator"] -"""Type defining a matrix, i.e. a linear map between \ -finite-dimensional vector spaces.""" +An object :code:`matrix`, which behaves like an :class:`~probnum.backend.Array` and +satisfies :code:`matrix.ndim == 2`. +""" ######################################################################################## # Argument Types ######################################################################################## -# Python Numbers -IntLike = Union[int, numbers.Integral, np.integer] -"""Object that can be converted to an integer. - -Arguments of type :attr:`IntLike` should always be converted -into :class:`int`\\ s before further internal processing.""" - -FloatLike = Union[float, numbers.Real, np.floating] -"""Object that can be converted to a float. - -Arguments of type :attr:`FloatLike` should always be converted -into :class:`float`\\ s before further internal processing.""" - -# Array Utilities -ShapeLike = Union[IntLike, Iterable[IntLike]] -"""Object that can be converted to a shape. - -Arguments of type :attr:`ShapeLike` should always be converted -into :class:`ShapeType` using the function :func:`probnum.utils.as_shape` -before further internal processing.""" - -DTypeLike = _NumPyDTypeLike -"""Object that can be converted to an array dtype. 
- -Arguments of type :attr:`DTypeLike` should always be converted -into :class:`numpy.dtype`\\ s before further internal processing.""" - -_ArrayIndexLike = Union[ - int, - slice, - type(Ellipsis), - None, - np.newaxis, - np.ndarray, -] -ArrayIndicesLike = Union[_ArrayIndexLike, Tuple[_ArrayIndexLike, ...]] -"""Object that can be converted to indices of an array. - -Type of the argument to the :meth:`__getitem__` method of a NumPy-like array type -such as :class:`numpy.ndarray`, :class:`probnum.linops.LinearOperator` or -:class:`probnum.randvars.RandomVariable`.""" - # Scalars, Arrays and Matrices -ScalarLike = Union[int, float, complex, numbers.Number, np.number] -"""Object that can be converted to a scalar value. - -Arguments of type :attr:`ScalarLike` should always be converted -into :class:`numpy.number`\\ s using the function :func:`probnum.utils.as_scalar` -before further internal processing.""" - -ArrayLike = _NumPyArrayLike -"""Object that can be converted to an array. - -Arguments of type :attr:`ArrayLike` should always be converted -into :class:`numpy.ndarray`\\ s using the function :func:`np.asarray` -before further internal processing.""" - LinearOperatorLike = Union[ ArrayLike, scipy.sparse.spmatrix, @@ -129,14 +59,6 @@ ] """Object that can be converted to a :class:`~probnum.linops.LinearOperator`. -Arguments of type :attr:`LinearOperatorLike` should always be converted -into :class:`~probnum.linops.\\ -LinearOperator`\\ s using the function :func:`probnum.linops.aslinop` before further -internal processing.""" - -######################################################################################## -# Other Types -######################################################################################## - -NotImplementedType = type(NotImplemented) -"""Type of the `NotImplemented` constant.""" +Arguments of type :attr:`LinearOperatorLike` should always be converted into +:class:`~probnum.linops.LinearOperator`\\ s using the function +:func:`probnum.linops.aslinop` before further internal processing.""" diff --git a/src/probnum/utils/__init__.py b/src/probnum/utils/__init__.py deleted file mode 100644 index c89157c2e..000000000 --- a/src/probnum/utils/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -"""Utility Functions.""" - -from .argutils import * -from .arrayutils import * - -# Public classes and functions. Order is reflected in documentation. -__all__ = [ - "as_colvec", - "atleast_1d", - "as_numpy_scalar", - "as_shape", -] diff --git a/src/probnum/utils/argutils.py b/src/probnum/utils/argutils.py deleted file mode 100644 index 55754dba0..000000000 --- a/src/probnum/utils/argutils.py +++ /dev/null @@ -1,59 +0,0 @@ -"""Utility functions for argument types.""" - -import numbers -from typing import Optional - -import numpy as np - -from probnum.typing import DTypeLike, ScalarLike, ShapeLike, ShapeType - -__all__ = ["as_shape", "as_numpy_scalar"] - - -def as_shape(x: ShapeLike, ndim: Optional[numbers.Integral] = None) -> ShapeType: - """Convert a shape representation into a shape defined as a tuple of ints. - - Parameters - ---------- - x - Shape representation. - """ - if isinstance(x, (int, numbers.Integral, np.integer)): - shape = (int(x),) - elif isinstance(x, tuple) and all(isinstance(item, int) for item in x): - shape = x - else: - try: - _ = iter(x) - except TypeError as e: - raise TypeError( - f"The given shape {x} must be an integer or an iterable of integers." 
- ) from e - - if not all(isinstance(item, (int, numbers.Integral, np.integer)) for item in x): - raise TypeError(f"The given shape {x} must only contain integer values.") - - shape = tuple(int(item) for item in x) - - if isinstance(ndim, numbers.Integral): - if len(shape) != ndim: - raise TypeError(f"The given shape {shape} must have {ndim} dimensions.") - - return shape - - -def as_numpy_scalar(x: ScalarLike, dtype: DTypeLike = None) -> np.ndarray: - """Convert a scalar into a scalar NumPy array. - - Parameters - ---------- - x - Scalar value. - dtype - Data type of the scalar. - """ - - if np.ndim(x) != 0: - raise ValueError("The given input is not a scalar.") - - return np.asarray(x, dtype=dtype) diff --git a/src/probnum/utils/arrayutils.py b/src/probnum/utils/arrayutils.py deleted file mode 100644 index c18a9675a..000000000 --- a/src/probnum/utils/arrayutils.py +++ /dev/null @@ -1,62 +0,0 @@ -"""Utility functions for arrays and the like.""" - -from typing import Union - -import numpy as np -import scipy - -import probnum.randvars - - -def atleast_1d(*rvs): - """Convert arrays or random variables to arrays or random variables with at least - one dimension. - - Scalar inputs are converted to 1-dimensional arrays, whilst - higher-dimensional inputs are preserved. Sparse arrays are not - transformed. - Parameters - ---------- - rvs: array-like or RandomVariable - One or more input random variables or arrays. - Returns - ------- - res : array-like or list - An array / random variable or list of arrays / random variables, - each with ``a.ndim >= 1``. - """ - res = [] - for rv in rvs: - if isinstance(rv, scipy.sparse.spmatrix): - result = rv - elif isinstance(rv, np.ndarray): - result = np.atleast_1d(rv) - elif isinstance(rv, probnum.randvars.RandomVariable): - raise NotImplementedError - else: - result = rv - res.append(result) - if len(res) == 1: - return res[0] - - return res - - -def as_colvec( - vec: Union[np.ndarray, "probnum.randvars.RandomVariable"] -) -> Union[np.ndarray, "probnum.randvars.RandomVariable"]: - """Transform the given vector or random variable to column format. Given a vector - (or random variable) of dimension (n,) return an array with dimensions (n, 1) - instead. Higher-dimensional arrays are not changed. - - Parameters - ---------- - vec - Vector, array or random variable to be transformed into a column vector. 
- """ - if isinstance(vec, probnum.randvars.RandomVariable): - if vec.shape != (vec.shape[0], 1): - vec.reshape(newshape=(vec.shape[0], 1)) - elif vec.ndim == 1: - return vec[:, None] - return vec diff --git a/src/probnum/utils/linalg/__init__.py b/src/probnum/utils/linalg/__init__.py deleted file mode 100644 index a817cdd0f..000000000 --- a/src/probnum/utils/linalg/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -"""Utility functions that involve numerical linear algebra.""" - -from ._cholesky_updates import cholesky_update, tril_to_positive_tril -from ._inner_product import induced_norm, inner_product -from ._orthogonalize import double_gram_schmidt, gram_schmidt, modified_gram_schmidt - -__all__ = [ - "inner_product", - "induced_norm", - "cholesky_update", - "tril_to_positive_tril", - "gram_schmidt", - "modified_gram_schmidt", - "double_gram_schmidt", -] diff --git a/src/probnum/utils/linalg/_inner_product.py b/src/probnum/utils/linalg/_inner_product.py deleted file mode 100644 index d58441b4a..000000000 --- a/src/probnum/utils/linalg/_inner_product.py +++ /dev/null @@ -1,83 +0,0 @@ -"""Functions defining useful inner products.""" -from __future__ import annotations - -from typing import TYPE_CHECKING, Optional, Union - -import numpy as np - -if TYPE_CHECKING: - from probnum import linops - - -def inner_product( - v: np.ndarray, - w: np.ndarray, - A: Optional[Union[np.ndarray, linops.LinearOperator]] = None, -) -> np.ndarray: - r"""Inner product :math:`\langle v, w \rangle_A := v^T A w`. - - For n-d arrays the function computes the inner product over the last axis of the - two arrays ``v`` and ``w``. - - Parameters - ---------- - v - First array. - w - Second array. - A - Symmetric positive (semi-)definite matrix defining the geometry. - - Returns - ------- - inprod : - Inner product(s) of ``v`` and ``w``. - - Notes - ----- - Note that the broadcasting behavior of :func:`inner_product` differs from - :func:`numpy.inner`. Rather it follows the broadcasting rules of - :func:`numpy.matmul` in that n-d arrays are treated as stacks of vectors. - """ - v_T = v[..., None, :] - w = w[..., :, None] - - if A is None: - vw_inprod = v_T @ w - else: - vw_inprod = v_T @ (A @ w) - - return np.squeeze(vw_inprod, axis=(-2, -1)) - - -def induced_norm( - v: np.ndarray, - A: Optional[Union[np.ndarray, linops.LinearOperator]] = None, - axis: int = -1, -) -> np.ndarray: - r"""Induced norm :math:`\lVert v \rVert_A := \sqrt{v^T A v}`. - - Computes the induced norm over the given axis of the array. - - Parameters - ---------- - v - Array. - A - Symmetric positive (semi-)definite linear operator defining the geometry. - axis - Specifies the axis along which to compute the vector norms. - - Returns - ------- - norm : - Vector norm of ``v`` along the given ``axis``. - """ - - if A is None: - return np.linalg.norm(v, ord=2, axis=axis, keepdims=False) - - v = np.moveaxis(v, axis, -1) - w = np.squeeze(A @ v[..., :, None], axis=-1) - - return np.sqrt(np.sum(v * w, axis=-1)) diff --git a/tests/conftest.py b/tests/conftest.py index 4c8ef4478..ab47294eb 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,8 +1,20 @@ -# -*- coding: utf-8 -*- -"""Dummy conftest.py for probnum. +from probnum import BACKEND -If you don't know what this is for, just leave it empty. 
Read more about -conftest.py under: https://pytest.org/latest/plugins.html -""" +import pytest -# import pytest + +def pytest_configure(config: "_pytest.config.Config"): + config.addinivalue_line( + "markers", "skipif_backend(backend): Skip test for the given backend." + ) + + +def pytest_runtest_setup(item: pytest.Item): + # Setup conditional backend skip + skipped_backends = [ + mark.args[0] for mark in item.iter_markers(name="skipif_backend") + ] + + if skipped_backends: + if BACKEND in skipped_backends: + pytest.skip(f"Test skipped for backend {BACKEND}.") diff --git a/tests/test_functions/__init__.py b/tests/probnum/__init__.py similarity index 100% rename from tests/test_functions/__init__.py rename to tests/probnum/__init__.py diff --git a/tests/test_linops/__init__.py b/tests/probnum/_pnmethod/__init__.py similarity index 100% rename from tests/test_linops/__init__.py rename to tests/probnum/_pnmethod/__init__.py diff --git a/tests/test_pnmethod/test_stopping_citerion/test_stopping_criterion.py b/tests/probnum/_pnmethod/test_stopping_criterion.py similarity index 100% rename from tests/test_pnmethod/test_stopping_citerion/test_stopping_criterion.py rename to tests/probnum/_pnmethod/test_stopping_criterion.py index 33661bbcd..ac9ce0c39 100644 --- a/tests/test_pnmethod/test_stopping_citerion/test_stopping_criterion.py +++ b/tests/probnum/_pnmethod/test_stopping_criterion.py @@ -3,10 +3,10 @@ import operator from typing import Callable -import pytest - from probnum import LambdaStoppingCriterion, StoppingCriterion +import pytest + @pytest.fixture def stopcrit(): diff --git a/tests/test_linops/test_linops_cases/__init__.py b/tests/probnum/backend/__init__.py similarity index 100% rename from tests/test_linops/test_linops_cases/__init__.py rename to tests/probnum/backend/__init__.py diff --git a/tests/probnum/backend/autodiff/test_autodiff.py b/tests/probnum/backend/autodiff/test_autodiff.py new file mode 100644 index 000000000..07ddf4440 --- /dev/null +++ b/tests/probnum/backend/autodiff/test_autodiff.py @@ -0,0 +1,29 @@ +"""Basic tests for automatic differentiation functionality.""" +from probnum import Backend, backend, compat +from probnum.backend.autodiff import grad, hessian, jacfwd, jacrev + +import pytest + + +@pytest.mark.skipif_backend(Backend.NUMPY) +@pytest.mark.parametrize("x", backend.linspace(0, 2 * backend.pi, 10)) +def test_grad_basic_function(x: backend.Array): + compat.testing.assert_allclose(grad(backend.sin)(x), backend.cos(x)) + + +@pytest.mark.skipif_backend(Backend.NUMPY) +@pytest.mark.parametrize("x", backend.linspace(0, 2 * backend.pi, 10)) +def test_jacfwd_basic_function(x: backend.Array): + compat.testing.assert_allclose(jacfwd(backend.sin)(x), backend.cos(x)) + + +@pytest.mark.skipif_backend(Backend.NUMPY) +@pytest.mark.parametrize("x", backend.linspace(0, 2 * backend.pi, 10)) +def test_jacrev_basic_function(x: backend.Array): + compat.testing.assert_allclose(jacrev(backend.sin)(x), backend.cos(x)) + + +@pytest.mark.skipif_backend(Backend.NUMPY) +@pytest.mark.parametrize("x", backend.linspace(0, 2 * backend.pi, 10)) +def test_hessian_basic_function(x: backend.Array): + compat.testing.assert_allclose(hessian(backend.sin)(x), -backend.sin(x)) diff --git a/tests/test_pnmethod/__init__.py b/tests/probnum/backend/linalg/__init__.py similarity index 100% rename from tests/test_pnmethod/__init__.py rename to tests/probnum/backend/linalg/__init__.py diff --git a/tests/probnum/backend/linalg/test_cholesky_updates.py 
b/tests/probnum/backend/linalg/test_cholesky_updates.py new file mode 100644 index 000000000..46dcb77c6 --- /dev/null +++ b/tests/probnum/backend/linalg/test_cholesky_updates.py
@@ -0,0 +1,82 @@
+from probnum import backend, compat
+from probnum.problems.zoo.linalg import random_spd_matrix
+
+import pytest
+import tests.utils
+
+
+@pytest.fixture
+def even_ndim():
+    """Even dimension for the tests, because it is halved in test_cholesky_optional
+    below."""
+    return 10
+
+
+@pytest.fixture
+def spdmats(even_ndim):
+    rng_state = tests.utils.random.rng_state_from_sampling_args(
+        base_seed=3897, shape=even_ndim
+    )
+    rng_state1, rng_state2 = backend.random.split(rng_state, num=2)
+
+    spdmat1 = random_spd_matrix(rng_state1, shape=(even_ndim, even_ndim))
+    spdmat2 = random_spd_matrix(rng_state2, shape=(even_ndim, even_ndim))
+
+    return spdmat1, spdmat2
+
+
+@pytest.fixture
+def spdmat1(spdmats):
+    return spdmats[0]
+
+
+@pytest.fixture
+def spdmat2(spdmats):
+    return spdmats[1]
+
+
+def test_cholesky_update(spdmat1, spdmat2):
+    expected = backend.linalg.cholesky(spdmat1 + spdmat2, upper=False)
+
+    S1 = backend.linalg.cholesky(spdmat1, upper=False)
+    S2 = backend.linalg.cholesky(spdmat2, upper=False)
+    received = backend.linalg.cholesky_update(S1, S2)
+    compat.testing.assert_allclose(expected, received)
+
+
+def test_cholesky_optional(spdmat1, even_ndim):
+    """Assert that cholesky_update() transforms a non-square matrix square-root into a
+    correct Cholesky factor."""
+    H_shape = (even_ndim // 2, even_ndim)
+    H = backend.random.uniform(
+        rng_state=tests.utils.random.rng_state_from_sampling_args(
+            base_seed=2908,
+            shape=H_shape,
+        ),
+        shape=H_shape,
+    )
+    expected = backend.linalg.cholesky(H @ spdmat1 @ H.T, upper=False)
+    S1 = backend.linalg.cholesky(spdmat1, upper=False)
+    received = backend.linalg.cholesky_update(H @ S1)
+    compat.testing.assert_allclose(expected, received)
+
+
+def test_tril_to_positive_tril():
+
+    # Make a random tril matrix
+    mat = backend.tril(
+        backend.random.uniform(rng_state=backend.random.rng_state(4897), shape=(4, 4))
+    )
+    scale = backend.asarray([1.0, 1.0, 1e-5, 1e-5])
+    signs = backend.asarray([1.0, -1.0, -1.0, -1.0])
+    tril = mat @ backend.diag(scale)
+    tril_wrong_signs = tril @ backend.diag(signs)
+
+    # Call tril_to_positive_tril
+    tril_received = backend.linalg.tril_to_positive_tril(tril_wrong_signs)
+
+    # Sanity check
+    compat.testing.assert_allclose(tril @ tril.T, tril_received @ tril_received.T)
+
+    # Assert that the initial tril matrix comes out
+    compat.testing.assert_allclose(tril_received, tril)
diff --git a/tests/probnum/backend/linalg/test_diagonal.py b/tests/probnum/backend/linalg/test_diagonal.py new file mode 100644 index 000000000..2b2890dc1 --- /dev/null +++ b/tests/probnum/backend/linalg/test_diagonal.py
@@ -0,0 +1,10 @@
+from probnum import backend
+
+import pytest
+
+
+@pytest.mark.parametrize(
+    "x", [backend.random.uniform(backend.random.rng_state(42), shape=(5, 2, 6))]
+)
+def test_diagonal_acts_on_last_axes(x: backend.Array):
+    assert x.shape[:-2] == backend.linalg.diagonal(x).shape[:-1]
diff --git a/tests/probnum/backend/linalg/test_inner_product.py b/tests/probnum/backend/linalg/test_inner_product.py new file mode 100644 index 000000000..0d115a488 --- /dev/null +++ b/tests/probnum/backend/linalg/test_inner_product.py
@@ -0,0 +1,113 @@
+"""Tests for general inner products."""
+
+from probnum import backend
+from probnum.backend.linalg import induced_vector_norm, inner_product
+from probnum.problems.zoo.linalg import random_spd_matrix
+
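+# NOTE: `inner_product` contracts over the *last* axis of its arguments and
+# broadcasts like `matmul` (stacks of vectors) rather than like `numpy.inner`,
+# while `induced_vector_norm(v, A)` is the A-weighted norm sqrt(v^T A v).
+# A minimal NumPy sketch of the Euclidean special case exercised below
+# (illustrative only, not part of the API under test):
+#
+#     import numpy as np
+#     v, w = np.ones(3), np.arange(3.0)
+#     assert np.isclose(np.einsum("...i,...i", v, w), 3.0)  # <v, w> = sum(v * w)
+#     assert np.isclose(np.sqrt(v @ v), np.linalg.norm(v))  # ||v||_I = 2-norm
+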
+import pytest +import tests.utils + + +@pytest.fixture(scope="module", params=[1, 10, 50]) +def n(request) -> int: + """Vector size.""" + return request.param + + +@pytest.fixture(scope="module", params=[1, 3, 5]) +def m(request) -> int: + """Number of simultaneous vectors.""" + return request.param + + +@pytest.fixture(scope="module", params=[1, 3]) +def p(request) -> int: + """Number of matrices.""" + return request.param + + +@pytest.fixture(scope="module") +def vector0(n: int) -> backend.Array: + shape = (n,) + return backend.random.standard_normal( + rng_state=tests.utils.random.rng_state_from_sampling_args( + base_seed=86, + shape=shape, + ), + shape=shape, + ) + + +@pytest.fixture(scope="module") +def vector1(n: int) -> backend.Array: + shape = (n,) + return backend.random.standard_normal( + rng_state=tests.utils.random.rng_state_from_sampling_args( + base_seed=567, + shape=shape, + ), + shape=shape, + ) + + +@pytest.fixture(scope="module") +def array0(p: int, m: int, n: int) -> backend.Array: + shape = (p, m, n) + return backend.random.standard_normal( + rng_state=tests.utils.random.rng_state_from_sampling_args( + base_seed=86, + shape=shape, + ), + shape=shape, + ) + + +@pytest.fixture(scope="module") +def array1(m: int, n: int) -> backend.Array: + shape = (m, n) + return backend.random.standard_normal( + rng_state=tests.utils.random.rng_state_from_sampling_args( + base_seed=567, + shape=shape, + ), + shape=shape, + ) + + +def test_inner_product_vectors(vector0: backend.Array, vector1: backend.Array): + assert inner_product(vector0, vector1) == pytest.approx( + backend.sum(vector0 * vector1) + ) + + +def test_inner_product_arrays(array0: backend.Array, array1: backend.Array): + assert inner_product(array0, array1) == pytest.approx( + backend.einsum("...i,...i", array0, array1) + ) + + +def test_euclidean_norm_vector(vector0: backend.Array): + assert backend.sqrt(backend.sum(vector0**2)) == pytest.approx( + induced_vector_norm(vector0) + ) + + +@pytest.mark.parametrize("axis", [0, 1]) +def test_euclidean_norm_array(array0: backend.Array, axis: int): + assert backend.sqrt(backend.sum(array0**2, axis=axis)) == pytest.approx( + induced_vector_norm(array0, axis=axis) + ) + + +@pytest.mark.parametrize("axis", [0, 1]) +def test_induced_vector_norm_array(array0: backend.Array, axis: int): + inprod_mat = random_spd_matrix( + rng_state=backend.random.rng_state(254), + shape=(array0.shape[axis], array0.shape[axis]), + ) + array0_moved_axis = backend.move_axes(array0, axis, -1) + A_array_0_moved_axis = (inprod_mat @ array0_moved_axis[..., :, None])[..., 0] + + assert backend.sqrt( + backend.sum(array0_moved_axis * A_array_0_moved_axis, axis=-1) + ) == pytest.approx(induced_vector_norm(array0, A=inprod_mat, axis=axis)) diff --git a/tests/probnum/backend/linalg/test_orthogonalize.py b/tests/probnum/backend/linalg/test_orthogonalize.py new file mode 100644 index 000000000..41faf530c --- /dev/null +++ b/tests/probnum/backend/linalg/test_orthogonalize.py @@ -0,0 +1,193 @@ +"""Tests for orthogonalization functions.""" + +from functools import partial +from typing import Callable, Union + +from probnum import backend, compat, linops +from probnum.backend.linalg import ( + gram_schmidt, + gram_schmidt_double, + gram_schmidt_modified, +) +from probnum.problems.zoo.linalg import random_spd_matrix + +import pytest +import tests.utils + +n = 100 + + +@pytest.fixture(scope="module", params=[1, 10, 50]) +def basis_size(request) -> int: + """Number of basis vectors.""" + return request.param + + 
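+# NOTE: The routines under test are Gram-Schmidt variants: they subtract from a
+# vector its projection onto the span of an (A-)orthonormal basis, leaving a
+# residual that is orthogonal to every basis element; `gram_schmidt_double`
+# presumably applies the given `gram_schmidt_fn` twice, the standard remedy for
+# loss of orthogonality in floating-point arithmetic. A minimal NumPy sketch of
+# the Euclidean case (illustrative only):
+#
+#     import numpy as np
+#
+#     def gram_schmidt_step(v, U):  # U: shape (k, n), orthonormal rows
+#         return v - U.T @ (U @ v)  # afterwards, U @ residual == 0
+
+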
+@pytest.fixture(scope="module") +def vector() -> backend.Array: + shape = (n,) + return backend.random.standard_normal( + rng_state=tests.utils.random.rng_state_from_sampling_args( + base_seed=526367, + shape=shape, + ), + shape=shape, + ) + + +@pytest.fixture(scope="module") +def vectors() -> backend.Array: + shape = (2, 10, n) + return backend.random.standard_normal( + rng_state=tests.utils.random.rng_state_from_sampling_args( + base_seed=234, + shape=shape, + ), + shape=shape, + ) + + +@pytest.fixture( + scope="module", + params=[ + backend.eye(n), + linops.Identity(n), + linops.Scaling(factors=1.0, shape=(n, n)), + # backend.inner, + ], +) +def inprod(request) -> int: + return request.param + + +@pytest.fixture( + scope="module", + params=[ + partial(gram_schmidt_double, gram_schmidt_fn=gram_schmidt), + partial(gram_schmidt_double, gram_schmidt_fn=gram_schmidt_modified), + ], +) +def orthogonalization_fn(request) -> int: + return request.param + + +def test_is_orthogonal( + vector: backend.Array, + basis_size: int, + inprod: Union[ + backend.Array, + linops.LinearOperator, + Callable[[backend.Array, backend.Array], backend.Array], + ], + orthogonalization_fn: Callable[[backend.Array, backend.Array], backend.Array], +): + # Compute orthogonal basis + basis_shape = (vector.shape[0], basis_size) + basis = backend.random.standard_normal( + rng_state=tests.utils.random.rng_state_from_sampling_args( + base_seed=32, + shape=basis_shape, + ), + shape=basis_shape, + ) + orthogonal_basis, _ = backend.linalg.qr(basis) + orthogonal_basis = orthogonal_basis.T + + # Orthogonalize vector + ortho_vector = orthogonalization_fn( + v=vector, orthogonal_basis=orthogonal_basis, inner_product=inprod + ) + compat.testing.assert_allclose( + orthogonal_basis @ ortho_vector, + backend.zeros((basis_size,)), + atol=1e-12, + rtol=1e-12, + ) + + +def test_is_normalized( + vector: backend.Array, + basis_size: int, + orthogonalization_fn: Callable[[backend.Array, backend.Array], backend.Array], +): + # Compute orthogonal basis + basis_shape = (vector.shape[0], basis_size) + basis = backend.random.standard_normal( + rng_state=tests.utils.random.rng_state_from_sampling_args( + base_seed=9467, + shape=basis_shape, + ), + shape=basis_shape, + ) + orthogonal_basis, _ = backend.linalg.qr(basis) + orthogonal_basis = orthogonal_basis.T + + # Orthogonalize vector + ortho_vector = orthogonalization_fn( + v=vector, orthogonal_basis=orthogonal_basis, normalize=True + ) + + assert backend.sum(ortho_vector**2) == pytest.approx(1.0) + + +@pytest.mark.parametrize( + "inner_product_matrix", + [ + backend.diag( + backend.random.gamma(backend.random.rng_state(123), 1.0, shape=(n,)) + ), + 5 * backend.eye(n), + random_spd_matrix(rng_state=backend.random.rng_state(46), shape=(n, n)), + ], +) +def test_noneuclidean_innerprod( + vector: backend.Array, + basis_size: int, + inner_product_matrix: backend.Array, + orthogonalization_fn: Callable[[backend.Array, backend.Array], backend.Array], +): + evals, evecs = backend.linalg.eigh(inner_product_matrix) + orthogonal_basis = evecs * 1 / backend.sqrt(evals) + orthogonal_basis = orthogonal_basis[:, 0:basis_size].T + + # Orthogonalize vector + ortho_vector = orthogonalization_fn( + v=vector, + orthogonal_basis=orthogonal_basis, + inner_product=inner_product_matrix, + normalize=False, + ) + + compat.testing.assert_allclose( + orthogonal_basis @ inner_product_matrix @ ortho_vector, + backend.zeros((basis_size,)), + atol=1e-12, + rtol=1e-12, + ) + + +def test_broadcasting( + vectors: backend.Array, 
+ basis_size: int, + orthogonalization_fn: Callable[[backend.Array, backend.Array], backend.Array], +): + # Compute orthogonal basis + basis_shape = (vectors.shape[-1], basis_size) + basis = backend.random.standard_normal( + rng_state=tests.utils.random.rng_state_from_sampling_args( + base_seed=32, + shape=basis_shape, + ), + shape=basis_shape, + ) + orthogonal_basis, _ = backend.linalg.qr(basis) + orthogonal_basis = orthogonal_basis.T + + # Orthogonalize vector + ortho_vectors = orthogonalization_fn(v=vectors, orthogonal_basis=orthogonal_basis) + compat.testing.assert_allclose( + (orthogonal_basis @ ortho_vectors[..., None])[..., 0], + backend.zeros(vectors.shape[:-1] + (basis_size,)), + atol=1e-12, + rtol=1e-12, + ) diff --git a/tests/test_pnmethod/test_stopping_citerion/__init__.py b/tests/probnum/backend/random/__init__.py similarity index 100% rename from tests/test_pnmethod/test_stopping_citerion/__init__.py rename to tests/probnum/backend/random/__init__.py diff --git a/tests/probnum/backend/random/test_uniform_so_group.py b/tests/probnum/backend/random/test_uniform_so_group.py new file mode 100644 index 000000000..43cda6588 --- /dev/null +++ b/tests/probnum/backend/random/test_uniform_so_group.py @@ -0,0 +1,43 @@ +import numpy as np + +from probnum import backend, compat +from probnum.backend.typing import SeedType, ShapeType + +import pytest_cases +import tests.utils + + +@pytest_cases.fixture(scope="module") +@pytest_cases.parametrize("seed", (234789, 7890)) +@pytest_cases.parametrize("n", (1, 2, 5, 9)) +@pytest_cases.parametrize("shape", ((), (1,), (2,), (3, 2))) +@pytest_cases.parametrize("dtype", (backend.float32, backend.float64)) +def so_group_sample( + seed: SeedType, n: int, shape: ShapeType, dtype: backend.DType +) -> backend.Array: + return backend.random.uniform_so_group( + rng_state=tests.utils.random.rng_state_from_sampling_args( + base_seed=seed, shape=shape, dtype=dtype, n=n + ), + n=n, + shape=shape, + dtype=dtype, + ) + + +def test_orthogonal(so_group_sample: backend.Array): + n = so_group_sample.shape[-2] + + compat.testing.assert_allclose( + so_group_sample @ backend.swap_axes(so_group_sample, -2, -1), + backend.broadcast_arrays(backend.eye(n), so_group_sample)[0], + atol=1e-6 if so_group_sample.dtype == backend.float32 else 1e-12, + ) + + +def test_determinant_1(so_group_sample: backend.Array): + compat.testing.assert_allclose( + np.linalg.det(compat.to_numpy(so_group_sample)), + 1.0, + rtol=2e-6 if so_group_sample.dtype == backend.float32 else 1e-7, + ) diff --git a/tests/probnum/backend/test_array_object.py b/tests/probnum/backend/test_array_object.py new file mode 100644 index 000000000..3d70e40b4 --- /dev/null +++ b/tests/probnum/backend/test_array_object.py @@ -0,0 +1,34 @@ +"""Tests for the basic array object and associated functions.""" +import numpy as np + +from probnum import Backend + +import pytest + +try: + import jax.numpy as jnp +except ImportError as e: + pass + +try: + import torch +except ImportError as e: + pass + + +@pytest.mark.skipif_backend(Backend.NUMPY) +@pytest.mark.skipif_backend(Backend.TORCH) +def test_jax_ndarray_module_is_not_updated(): + assert jnp.ndarray.__module__ != "probnum.backend" + + +@pytest.mark.skipif_backend(Backend.JAX) +@pytest.mark.skipif_backend(Backend.TORCH) +def test_numpy_ndarray_module_is_not_updated(): + assert np.ndarray.__module__ != "probnum.backend" + + +@pytest.mark.skipif_backend(Backend.JAX) +@pytest.mark.skipif_backend(Backend.NUMPY) +def test_torch_tensor_module_is_not_updated(): + assert 
torch.Tensor.__module__ != "probnum.backend"
diff --git a/tests/test_utils/test_argutils.py b/tests/probnum/backend/test_core.py similarity index 63% rename from tests/test_utils/test_argutils.py rename to tests/probnum/backend/test_core.py index dfc98576d..61efcb3fd 100644 --- a/tests/test_utils/test_argutils.py +++ b/tests/probnum/backend/test_core.py
@@ -1,33 +1,17 @@
-"""Basic tests for argutils."""
-
 import numpy as np
-import pytest
 
-import probnum.utils as pnut
+from probnum import backend, compat
 
-
-@pytest.mark.parametrize("scalar", [1, 1.0, 1.0 + 2.0j, np.array(1.0)])
-def test_as_numpy_scalar_returns_scalar_array(scalar):
-    """All sorts of scalars are transformed into a np.generic."""
-    as_scalar = pnut.as_numpy_scalar(scalar)
-    assert isinstance(as_scalar, np.ndarray) and as_scalar.shape == ()
-    np.testing.assert_allclose(as_scalar, scalar, atol=0.0, rtol=1e-12)
-
-
-@pytest.mark.parametrize("sequence", [[1.0], (1,), np.array([1.0])])
-def test_as_numpy_scalar_bad_sequence_is_bad(sequence):
-    """Sequence types give rise to ValueErrors in `as_numpy_scalar`."""
-    with pytest.raises(ValueError):
-        pnut.as_numpy_scalar(sequence)
+import pytest
 
 
 @pytest.mark.parametrize("shape_arg", list(range(5)) + [np.int32(8)])
 @pytest.mark.parametrize("ndim", [False, True])
 def test_as_shape_int(shape_arg, ndim):
     if ndim:
-        shape = pnut.as_shape(shape_arg, ndim=1)
+        shape = backend.asshape(shape_arg, ndim=1)
     else:
-        shape = pnut.as_shape(shape_arg)
+        shape = backend.asshape(shape_arg)
 
     assert isinstance(shape, tuple)
     assert len(shape) == 1
@@ -50,9 +34,9 @@ def test_as_shape_int(shape_arg, ndim):
 @pytest.mark.parametrize("ndim", [False, True])
 def test_as_shape_iterable(shape_arg, ndim):
     if ndim:
-        shape = pnut.as_shape(shape_arg, ndim=len(shape_arg))
+        shape = backend.asshape(shape_arg, ndim=len(shape_arg))
     else:
-        shape = pnut.as_shape(shape_arg)
+        shape = backend.asshape(shape_arg)
 
     assert isinstance(shape, tuple)
     assert len(shape) == len(shape_arg)
@@ -73,7 +57,7 @@ def test_as_shape_iterable(shape_arg, ndim):
 )
 def test_as_shape_wrong_type(shape_arg):
     with pytest.raises(TypeError):
-        pnut.as_shape(shape_arg)
+        backend.asshape(shape_arg)
 
 
 @pytest.mark.parametrize(
@@ -91,4 +75,19 @@
 )
 def test_as_shape_wrong_ndim(shape_arg, ndim):
     with pytest.raises(TypeError):
-        pnut.as_shape(shape_arg, ndim=ndim)
+        backend.asshape(shape_arg, ndim=ndim)
+
+
+@pytest.mark.parametrize("scalar", [1, 1.0, 1.0 + 2.0j, backend.asarray(1.0)])
+def test_asscalar_returns_scalar_array(scalar):
+    """All sorts of scalars are transformed into a scalar (zero-dimensional) array."""
+    asscalar = backend.asscalar(scalar)
+    assert backend.isarray(asscalar) and asscalar.shape == ()
+    compat.testing.assert_allclose(asscalar, scalar, atol=0.0, rtol=1e-12)
+
+
+@pytest.mark.parametrize("sequence", [[1.0], (1,), backend.asarray([1.0])])
+def test_asscalar_sequence_error(sequence):
+    """Sequence types give rise to ValueErrors in `asscalar`."""
+    with pytest.raises(ValueError):
+        backend.asscalar(sequence)
diff --git a/tests/probnum/backend/test_hypergrad.py b/tests/probnum/backend/test_hypergrad.py new file mode 100644 index 000000000..244fc093e --- /dev/null +++ b/tests/probnum/backend/test_hypergrad.py
@@ -0,0 +1,66 @@
+from scipy.optimize._numdiff import approx_derivative
+
+from probnum import Backend, backend, compat, functions, randprocs, randvars
+
+import pytest
+
+
+def assert_gradient_approx_finite_differences(
+    func,
+    grad,
+    x0,
+    *,
+    epsilon=None,
+    method="3-point",
+    rtol=1e-7,
+    atol=0.0,
+):
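+    """Check a gradient computed via automatic differentiation against a
+    finite-difference approximation.
+
+    Uses :func:`scipy.optimize._numdiff.approx_derivative` as the reference. With
+    the default "3-point" method, the derivative is approximated by the central
+    difference (f(x + h) - f(x - h)) / (2 h); unless given explicitly, the
+    relative step size `epsilon` is the square root of the machine epsilon of
+    the output dtype.
+    """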
+    if epsilon is None:
+        out = func(x0)
+
+        epsilon = backend.sqrt(backend.finfo(out.dtype).eps)
+
+    compat.testing.assert_allclose(
+        grad(x0),
+        approx_derivative(
+            lambda x: backend.asarray(func(x), copy=False),
+            x0,
+            method=method,
+            rel_step=epsilon,
+        ),
+        rtol=rtol,
+        atol=atol,
+    )
+
+
+def g(l):
+    l = l[0]
+
+    gp = randprocs.GaussianProcess(
+        mean=functions.Zero(input_shape=()),
+        cov=randprocs.kernels.ExpQuad(input_shape=(), lengthscale=l),
+    )
+
+    xs = backend.linspace(-1.0, 1.0, 10)
+    ys = backend.linspace(-1.0, 1.0, 10)
+
+    fX = gp(xs)
+
+    e = randvars.Normal(mean=backend.zeros(10), cov=backend.eye(10))
+
+    return -(fX + e).logpdf(ys)
+
+
+@pytest.mark.skipif_backend(Backend.NUMPY)
+def test_compare_grad():
+    l = backend.asarray([3.0])
+    dg = backend.autodiff.grad(g)
+
+    assert_gradient_approx_finite_differences(
+        g,
+        dg,
+        x0=l,
+    )
+
+
+if __name__ == "__main__":
+    test_compare_grad()
diff --git a/tests/test_problems/__init__.py b/tests/probnum/functions/__init__.py similarity index 100% rename from tests/test_problems/__init__.py rename to tests/probnum/functions/__init__.py
diff --git a/tests/test_problems/test_zoo/__init__.py b/tests/probnum/functions/conftest.py similarity index 100% rename from tests/test_problems/test_zoo/__init__.py rename to tests/probnum/functions/conftest.py
diff --git a/tests/test_functions/test_algebra.py b/tests/probnum/functions/test_algebra.py similarity index 53% rename from tests/test_functions/test_algebra.py rename to tests/probnum/functions/test_algebra.py index 873d8c646..295e709de 100644 --- a/tests/test_functions/test_algebra.py +++ b/tests/probnum/functions/test_algebra.py
@@ -1,15 +1,15 @@
-import numpy as np
+from probnum import backend, compat, functions
+from probnum.backend.typing import ShapeType
+
 import pytest
 from pytest_cases import param_fixture, param_fixtures
-
-from probnum import functions
-from probnum.typing import ScalarLike, ShapeType
+from tests.utils.random import rng_state_from_sampling_args
 
 lambda_fn_0 = functions.LambdaFunction(
     lambda xs: (
-        np.sin(
-            np.linspace(0.5, 2.0, 6).reshape((3, 2))
-            * np.sum(xs**2, axis=-1)[..., None, None]
+        backend.sin(
+            backend.linspace(0.5, 2.0, 6).reshape((3, 2))
+            * backend.sum(xs**2, axis=-1)[..., None, None]
         )
     ),
     input_shape=(2,),
@@ -18,8 +18,8 @@
 
 lambda_fn_1 = functions.LambdaFunction(
     lambda xs: (
-        np.linspace(0.5, 2.0, 6).reshape((3, 2))
-        * np.exp(-0.5 * np.sum(xs**2, axis=-1))[..., None, None]
+        backend.linspace(0.5, 2.0, 6).reshape((3, 2))
+        * backend.exp(-0.5 * backend.sum(xs**2, axis=-1))[..., None, None]
     ),
     input_shape=(2,),
     output_shape=(3, 2),
@@ -55,28 +55,38 @@
 
 def test_add_evaluation(
-    op0: functions.Function, op1: functions.Function, batch_shape: ShapeType, seed: int
+    op0: functions.Function, op1: functions.Function, batch_shape: ShapeType
 ):
     fn_add = op0 + op1
 
-    rng = np.random.default_rng(seed)
-    xs = rng.uniform(-1.0, 1.0, batch_shape + op0.input_shape)
+    rng_state = rng_state_from_sampling_args(base_seed=2457, shape=batch_shape)
+    xs = backend.random.uniform(
+        rng_state=rng_state,
+        minval=-1.0,
+        maxval=1.0,
+        shape=batch_shape + op0.input_shape,
+    )
 
-    np.testing.assert_array_equal(
+    compat.testing.assert_array_equal(
         fn_add(xs),
         op0(xs) + op1(xs),
     )
 
 
 def test_sub_evaluation(
-    op0: functions.Function, op1: functions.Function, batch_shape: ShapeType, seed: int
+    op0: functions.Function, op1: functions.Function, batch_shape: ShapeType
 ):
     fn_sub = op0 - op1
 
-    rng = np.random.default_rng(seed)
-    xs = rng.uniform(-1.0, 1.0, batch_shape + op0.input_shape)
+    rng_state =
rng_state_from_sampling_args(base_seed=27545, shape=batch_shape) + xs = backend.random.uniform( + rng_state=rng_state, + minval=-1.0, + maxval=1.0, + shape=batch_shape + op0.input_shape, + ) - np.testing.assert_array_equal( + compat.testing.assert_array_equal( fn_sub(xs), op0(xs) - op1(xs), ) @@ -85,16 +95,20 @@ def test_sub_evaluation( @pytest.mark.parametrize("scalar", [1.0, 3, 1000.0]) def test_mul_scalar_evaluation( op0: functions.Function, - scalar: ScalarLike, + scalar: backend.Scalar, batch_shape: ShapeType, - seed: int, ): fn_scaled = op0 * scalar - rng = np.random.default_rng(seed) - xs = rng.uniform(-1.0, 1.0, batch_shape + op0.input_shape) + rng_state = rng_state_from_sampling_args(base_seed=2527, shape=batch_shape) + xs = backend.random.uniform( + rng_state=rng_state, + minval=-1.0, + maxval=1.0, + shape=batch_shape + op0.input_shape, + ) - np.testing.assert_array_equal( + compat.testing.assert_array_equal( fn_scaled(xs), op0(xs) * scalar, ) @@ -103,16 +117,20 @@ def test_mul_scalar_evaluation( @pytest.mark.parametrize("scalar", [1.0, 3, 1000.0]) def test_rmul_scalar_evaluation( op0: functions.Function, - scalar: ScalarLike, + scalar: backend.Scalar, batch_shape: ShapeType, - seed: int, ): fn_scaled = scalar * op0 - rng = np.random.default_rng(seed) - xs = rng.uniform(-1.0, 1.0, batch_shape + op0.input_shape) + rng_state = rng_state_from_sampling_args(base_seed=83664, shape=batch_shape) + xs = backend.random.uniform( + rng_state=rng_state, + minval=-1.0, + maxval=1.0, + shape=batch_shape + op0.input_shape, + ) - np.testing.assert_array_equal( + compat.testing.assert_array_equal( fn_scaled(xs), scalar * op0(xs), ) diff --git a/tests/test_functions/test_algebra_fallbacks.py b/tests/probnum/functions/test_algebra_fallbacks.py similarity index 99% rename from tests/test_functions/test_algebra_fallbacks.py rename to tests/probnum/functions/test_algebra_fallbacks.py index b3bc3f51b..1bb034e71 100644 --- a/tests/test_functions/test_algebra_fallbacks.py +++ b/tests/probnum/functions/test_algebra_fallbacks.py @@ -1,8 +1,9 @@ import numpy as np -import pytest from probnum import functions +import pytest + @pytest.fixture(scope="module") def fn0() -> functions.LambdaFunction: diff --git a/tests/test_functions/test_function.py b/tests/probnum/functions/test_function.py similarity index 99% rename from tests/test_functions/test_function.py rename to tests/probnum/functions/test_function.py index e0242ace1..5c184bf6a 100644 --- a/tests/test_functions/test_function.py +++ b/tests/probnum/functions/test_function.py @@ -1,10 +1,11 @@ """Tests for functions with fixed in- and output shape.""" import numpy as np -import pytest from probnum.functions import LambdaFunction +import pytest + def test_input_shape_mismatch_raises_error(): fn = LambdaFunction(fn=lambda x: 2 * x, input_shape=(1,), output_shape=(1,)) diff --git a/tests/test_problems/test_zoo/test_diffeq/__init__.py b/tests/probnum/linops/__init__.py similarity index 100% rename from tests/test_problems/test_zoo/test_diffeq/__init__.py rename to tests/probnum/linops/__init__.py diff --git a/tests/test_problems/test_zoo/test_filtsmooth/__init__.py b/tests/probnum/linops/cases/__init__.py similarity index 100% rename from tests/test_problems/test_zoo/test_filtsmooth/__init__.py rename to tests/probnum/linops/cases/__init__.py diff --git a/tests/test_linops/test_linops_cases/arithmetic_cases.py b/tests/probnum/linops/cases/arithmetic_cases.py similarity index 83% rename from tests/test_linops/test_linops_cases/arithmetic_cases.py rename to 
tests/probnum/linops/cases/arithmetic_cases.py index a3088afce..d65785961 100644 --- a/tests/test_linops/test_linops_cases/arithmetic_cases.py +++ b/tests/probnum/linops/cases/arithmetic_cases.py @@ -1,9 +1,9 @@ from typing import Tuple import numpy as np -import pytest_cases import probnum as pn +from probnum import backend from probnum.linops._arithmetic_fallbacks import ( NegatedLinearOperator, ScaledLinearOperator, @@ -11,18 +11,24 @@ ) from probnum.problems.zoo.linalg import random_spd_matrix +import pytest_cases + square_matrix_pairs = [ ( - np.random.default_rng(n + 478).standard_normal((n, n)), - np.random.default_rng(n + 267).standard_normal((n, n)), + backend.random.standard_normal( + rng_state=backend.random.rng_state(n + 478), shape=(n, n) + ), + backend.random.standard_normal( + rng_state=backend.random.rng_state(n + 267), shape=(n, n) + ), ) for n in [1, 2, 3, 5, 8] ] spd_matrix_pairs = [ ( - random_spd_matrix(np.random.default_rng(n + 9872), dim=n), - random_spd_matrix(np.random.default_rng(n + 1231), dim=n), + random_spd_matrix(backend.random.rng_state(n + 9872), shape=(n, n)), + random_spd_matrix(backend.random.rng_state(n + 1231), shape=(n, n)), ) for n in [1, 2, 3, 5, 8] ] diff --git a/tests/test_linops/test_linops_cases/kronecker_cases.py b/tests/probnum/linops/cases/kronecker_cases.py similarity index 94% rename from tests/test_linops/test_linops_cases/kronecker_cases.py rename to tests/probnum/linops/cases/kronecker_cases.py index a5886f9c1..cb1b78564 100644 --- a/tests/test_linops/test_linops_cases/kronecker_cases.py +++ b/tests/probnum/linops/cases/kronecker_cases.py @@ -2,16 +2,18 @@ from typing import Tuple, Union import numpy as np -import pytest -import pytest_cases import probnum as pn +from probnum import backend from probnum.problems.zoo.linalg import random_spd_matrix +import pytest +import pytest_cases + spd_matrices = ( pn.linops.Identity(shape=(1, 1)), np.array([[1.0, -2.0], [-2.0, 5.0]]), - random_spd_matrix(np.random.default_rng(597), dim=9), + random_spd_matrix(rng_state=backend.random.rng_state(597), shape=(9, 9)), ) @@ -108,8 +110,12 @@ def case_symmetric_kronecker( "A,B", [ ( - random_spd_matrix(np.random.default_rng(234789 + n), dim=n), - random_spd_matrix(np.random.default_rng(347892 + n), dim=n), + random_spd_matrix( + rng_state=backend.random.rng_state(234789 + n), shape=(n, n) + ), + random_spd_matrix( + rng_state=backend.random.rng_state(347892 + n), shape=(n, n) + ), ) for n in [1, 2, 3, 6] ], diff --git a/tests/test_linops/test_linops_cases/linear_operator_cases.py b/tests/probnum/linops/cases/linear_operator_cases.py similarity index 96% rename from tests/test_linops/test_linops_cases/linear_operator_cases.py rename to tests/probnum/linops/cases/linear_operator_cases.py index ef5815d9c..072ff2201 100644 --- a/tests/test_linops/test_linops_cases/linear_operator_cases.py +++ b/tests/probnum/linops/cases/linear_operator_cases.py @@ -1,13 +1,15 @@ from typing import Tuple import numpy as np -import pytest -import pytest_cases import scipy.sparse import probnum as pn +from probnum import backend from probnum.problems.zoo.linalg import random_spd_matrix +import pytest +import pytest_cases + matrices = [ np.array([[-1.5, 3], [0, -230]]), np.array([[2, 0], [1, 3]]), @@ -16,7 +18,7 @@ spd_matrices = [ np.array([[1.0]]), np.array([[1.0, -2.0], [-2.0, 5.0]]), - random_spd_matrix(np.random.default_rng(597), dim=10), + random_spd_matrix(rng_state=backend.random.rng_state(597), shape=(10, 10)), ] diff --git 
a/tests/test_linops/test_linops_cases/scaling_cases.py b/tests/probnum/linops/cases/scaling_cases.py similarity index 99% rename from tests/test_linops/test_linops_cases/scaling_cases.py rename to tests/probnum/linops/cases/scaling_cases.py index fb12b21fe..b024da728 100644 --- a/tests/test_linops/test_linops_cases/scaling_cases.py +++ b/tests/probnum/linops/cases/scaling_cases.py @@ -1,10 +1,11 @@ from typing import Tuple import numpy as np -import pytest_cases import probnum as pn +import pytest_cases + @pytest_cases.case(tags=["square", "symmetric", "indefinite"]) @pytest_cases.parametrize( diff --git a/tests/test_linops/test_linops_cases/selectionembedding_cases.py b/tests/probnum/linops/cases/selectionembedding_cases.py similarity index 100% rename from tests/test_linops/test_linops_cases/selectionembedding_cases.py rename to tests/probnum/linops/cases/selectionembedding_cases.py diff --git a/tests/test_linops/test_arithmetics.py b/tests/probnum/linops/test_arithmetics.py similarity index 97% rename from tests/test_linops/test_arithmetics.py rename to tests/probnum/linops/test_arithmetics.py index 8b3b1313d..318ecaa5d 100644 --- a/tests/test_linops/test_arithmetics.py +++ b/tests/probnum/linops/test_arithmetics.py @@ -4,9 +4,8 @@ import itertools import numpy as np -import pytest -from probnum import config +from probnum import backend, config from probnum.linops._arithmetic import _add_fns, _matmul_fns, _mul_fns, _sub_fns from probnum.linops._arithmetic_fallbacks import ( NegatedLinearOperator, @@ -32,9 +31,14 @@ from probnum.linops._scaling import Scaling, Zero from probnum.problems.zoo.linalg import random_spd_matrix +import pytest + def _aslist(arg): - """Converts anything to a list. Non-iterables become single-element lists.""" + """Converts anything to a list. + + Non-iterables become single-element lists. 
+ """ try: return list(arg) except TypeError: # excepts TypeError: '' object is not iterable @@ -69,7 +73,9 @@ def get_linop(linop_type): elif linop_type is Matrix: return (Matrix(np.random.rand(4, 4)), Matrix(np.random.rand(6, 3))) elif linop_type is _InverseLinearOperator: - _posdef_randmat = random_spd_matrix(rng=np.random.default_rng(123), dim=4) + _posdef_randmat = random_spd_matrix( + rng_state=backend.random.rng_state(123), shape=(4, 4) + ) return Matrix(_posdef_randmat).inv() elif linop_type is TransposedLinearOperator: return TransposedLinearOperator(linop=Matrix(np.random.rand(4, 4))) diff --git a/tests/test_linops/test_arithmetics_fallbacks.py b/tests/probnum/linops/test_arithmetics_fallbacks.py similarity index 85% rename from tests/test_linops/test_arithmetics_fallbacks.py rename to tests/probnum/linops/test_arithmetics_fallbacks.py index d74fba490..559ab4411 100644 --- a/tests/test_linops/test_arithmetics_fallbacks.py +++ b/tests/probnum/linops/test_arithmetics_fallbacks.py @@ -1,17 +1,14 @@ """Tests for linear operator arithmetics fallbacks.""" import numpy as np -import pytest # NegatedLinearOperator,; ProductLinearOperator,; SumLinearOperator,; +from probnum import backend from probnum.linops._arithmetic_fallbacks import ScaledLinearOperator from probnum.linops._linear_operator import Matrix from probnum.problems.zoo.linalg import random_spd_matrix - -@pytest.fixture -def rng(): - return np.random.default_rng(123) +import pytest @pytest.fixture @@ -20,8 +17,9 @@ def scalar(): @pytest.fixture -def rand_spd_mat(rng): - return Matrix(random_spd_matrix(rng, dim=4)) +def rand_spd_mat(): + rng_state = backend.random.rng_state(1237) + return Matrix(random_spd_matrix(rng_state, shape=(4, 4))) def test_scaled_linop(rand_spd_mat, scalar): diff --git a/tests/test_linops/test_kronecker.py b/tests/probnum/linops/test_kronecker.py similarity index 97% rename from tests/test_linops/test_kronecker.py rename to tests/probnum/linops/test_kronecker.py index 427556bb6..d5fe2c88c 100644 --- a/tests/test_linops/test_kronecker.py +++ b/tests/probnum/linops/test_kronecker.py @@ -1,15 +1,16 @@ """Tests for Kronecker-type linear operators.""" import numpy as np -import pytest -import pytest_cases import probnum as pn +import pytest +import pytest_cases + @pytest_cases.parametrize_with_cases( "linop,matrix", - cases=".test_linops_cases.kronecker_cases", + cases=".cases.kronecker_cases", has_tag="symmetric_kronecker", ) def test_symmetric_kronecker_commutative( diff --git a/tests/test_linops/test_linop_decompositions.py b/tests/probnum/linops/test_linop_decompositions.py similarity index 90% rename from tests/test_linops/test_linop_decompositions.py rename to tests/probnum/linops/test_linop_decompositions.py index 19b3644fd..f5c5aa9c8 100644 --- a/tests/test_linops/test_linop_decompositions.py +++ b/tests/probnum/linops/test_linop_decompositions.py @@ -1,16 +1,17 @@ import pathlib import numpy as np -import pytest -import pytest_cases -from pytest_cases import filters import scipy.linalg import probnum as pn +import pytest +import pytest_cases +from pytest_cases import filters + case_modules = [ - ".test_linops_cases." + path.stem - for path in (pathlib.Path(__file__).parent / "test_linops_cases").glob("*_cases.py") + ".cases." 
+ path.stem
+    for path in (pathlib.Path(__file__).parent / "cases").glob("*_cases.py")
 ]
 
@@ -70,8 +71,8 @@ def test_cholesky(linop: pn.linops.LinearOperator, matrix: np.ndarray, lower: bo
 def test_cholesky_is_symmetric_not_true(
     linop: pn.linops.LinearOperator, matrix: np.ndarray, lower: bool
 ):  # pylint: disable=unused-argument
-    """Tests whether computing the Cholesky decomposition of a ``LinearOperator``
-    whose ``is_symmetric`` property is not set to ``True`` results in an error."""
+    """Tests whether computing the Cholesky decomposition of a ``LinearOperator`` whose
+    ``is_symmetric`` property is not set to ``True`` results in an error."""
 
     if linop.is_symmetric is not True:
         with pytest.raises(np.linalg.LinAlgError):
@@ -86,8 +87,8 @@ def test_cholesky_is_symmetric_not_true(
 def test_cholesky_is_positive_definite_false(
     linop: pn.linops.LinearOperator, matrix: np.ndarray, lower: bool
 ):  # pylint: disable=unused-argument
-    """Tests whether computing the Cholesky decomposition of a ``LinearOperator``
-    whose ``is_symmetric`` property is not set to ``True`` results in an error."""
+    """Tests whether computing the Cholesky decomposition of a ``LinearOperator`` whose
+    ``is_positive_definite`` property is set to ``False`` results in an error."""
 
     if linop.is_positive_definite is False:
         with pytest.raises(np.linalg.LinAlgError):
@@ -111,7 +112,7 @@ def test_cholesky_not_positive_definite(
     linop: pn.linops.LinearOperator, matrix: np.ndarray, lower: bool
 ):
     """Tests whether computing the Cholesky decomposition of a symmetric, but not
-    positive definite matrix results in an error"""
+    positive definite matrix results in an error."""
 
     expected_exception = None
diff --git a/tests/test_linops/test_linop_properties.py b/tests/probnum/linops/test_linop_properties.py similarity index 84% rename from tests/test_linops/test_linop_properties.py rename to tests/probnum/linops/test_linop_properties.py index 77089df58..748c9b4d9 100644 --- a/tests/test_linops/test_linop_properties.py +++ b/tests/probnum/linops/test_linop_properties.py
@@ -1,14 +1,15 @@
 import pathlib
 
 import numpy as np
-import pytest
-import pytest_cases
 
 import probnum as pn
 
+import pytest
+import pytest_cases
+
 
 case_modules = [
-    ".test_linops_cases." + path.stem
-    for path in (pathlib.Path(__file__).parent / "test_linops_cases").glob("*_cases.py")
+    ".cases." + path.stem
+    for path in (pathlib.Path(__file__).parent / "cases").glob("*_cases.py")
 ]
diff --git a/tests/test_linops/test_linops.py b/tests/probnum/linops/test_linops.py similarity index 98% rename from tests/test_linops/test_linops.py rename to tests/probnum/linops/test_linops.py index 4b3fe104a..20b404aa4 100644 --- a/tests/test_linops/test_linops.py +++ b/tests/probnum/linops/test_linops.py
@@ -2,14 +2,15 @@
 from typing import Optional, Tuple, Union
 
 import numpy as np
-import pytest
-import pytest_cases
 
 import probnum as pn
 
+import pytest
+import pytest_cases
+
 
 case_modules = [
-    ".test_linops_cases." + path.stem
-    for path in (pathlib.Path(__file__).parent / "test_linops_cases").glob("*_cases.py")
+    ".cases."
+ path.stem + for path in (pathlib.Path(__file__).parent / "cases").glob("*_cases.py") ] diff --git a/tests/test_linops/test_matrix.py b/tests/probnum/linops/test_matrix.py similarity index 76% rename from tests/test_linops/test_matrix.py rename to tests/probnum/linops/test_matrix.py index b45f3ee99..37c89d7e1 100644 --- a/tests/test_linops/test_matrix.py +++ b/tests/probnum/linops/test_matrix.py @@ -2,7 +2,10 @@ import probnum as pn +import pytest + +@pytest.mark.filterwarnings("ignore:the matrix subclass is not the recommended way") def test_matrix_linop_converts_numpy_matrix(): matrix = np.asmatrix(np.eye(10)) linop = pn.linops.Matrix(matrix) diff --git a/tests/test_problems/test_zoo/test_linalg/__init__.py b/tests/probnum/problems/__init__.py similarity index 100% rename from tests/test_problems/test_zoo/test_linalg/__init__.py rename to tests/probnum/problems/__init__.py diff --git a/tests/test_randprocs/__init__.py b/tests/probnum/problems/zoo/__init__.py similarity index 100% rename from tests/test_randprocs/__init__.py rename to tests/probnum/problems/zoo/__init__.py diff --git a/tests/test_randprocs/test_kernels/__init__.py b/tests/probnum/problems/zoo/diffeq/__init__.py similarity index 100% rename from tests/test_randprocs/test_kernels/__init__.py rename to tests/probnum/problems/zoo/diffeq/__init__.py diff --git a/tests/test_problems/test_zoo/test_diffeq/test_ivp_examples.py b/tests/probnum/problems/zoo/diffeq/test_ivp_examples.py similarity index 99% rename from tests/test_problems/test_zoo/test_diffeq/test_ivp_examples.py rename to tests/probnum/problems/zoo/diffeq/test_ivp_examples.py index 851b16b28..1fd57770e 100644 --- a/tests/test_problems/test_zoo/test_diffeq/test_ivp_examples.py +++ b/tests/probnum/problems/zoo/diffeq/test_ivp_examples.py @@ -1,9 +1,10 @@ import numpy as np -import pytest import probnum.problems as pnpr import probnum.problems.zoo.diffeq as diffeqzoo +import pytest + ODE_LIST = [ diffeqzoo.vanderpol(), diffeqzoo.threebody(), diff --git a/tests/test_problems/test_zoo/test_diffeq/test_ivp_examples_jax.py b/tests/probnum/problems/zoo/diffeq/test_ivp_examples_jax.py similarity index 100% rename from tests/test_problems/test_zoo/test_diffeq/test_ivp_examples_jax.py rename to tests/probnum/problems/zoo/diffeq/test_ivp_examples_jax.py index 6c932e28e..1a93fe991 100644 --- a/tests/test_problems/test_zoo/test_diffeq/test_ivp_examples_jax.py +++ b/tests/probnum/problems/zoo/diffeq/test_ivp_examples_jax.py @@ -1,7 +1,7 @@ -import pytest - import probnum.problems.zoo.diffeq as diffeq_zoo +import pytest + # Jax dependency handling # pylint: disable=unused-import try: diff --git a/tests/test_randprocs/test_markov/__init__.py b/tests/probnum/problems/zoo/filtsmooth/__init__.py similarity index 100% rename from tests/test_randprocs/test_markov/__init__.py rename to tests/probnum/problems/zoo/filtsmooth/__init__.py diff --git a/tests/test_problems/test_zoo/test_filtsmooth/test_filtsmooth_problems.py b/tests/probnum/problems/zoo/filtsmooth/test_filtsmooth_problems.py similarity index 99% rename from tests/test_problems/test_zoo/test_filtsmooth/test_filtsmooth_problems.py rename to tests/probnum/problems/zoo/filtsmooth/test_filtsmooth_problems.py index b856df85e..4dd123f2a 100644 --- a/tests/test_problems/test_zoo/test_filtsmooth/test_filtsmooth_problems.py +++ b/tests/probnum/problems/zoo/filtsmooth/test_filtsmooth_problems.py @@ -1,9 +1,10 @@ import numpy as np -import pytest from probnum import problems import probnum.problems.zoo.filtsmooth as filtsmooth_zoo +import 
pytest + def rng(): return np.random.default_rng(seed=123) diff --git a/tests/test_randprocs/test_markov/test_continuous/__init__.py b/tests/probnum/problems/zoo/linalg/__init__.py similarity index 100% rename from tests/test_randprocs/test_markov/test_continuous/__init__.py rename to tests/probnum/problems/zoo/linalg/__init__.py diff --git a/tests/test_problems/test_zoo/test_linalg/conftest.py b/tests/probnum/problems/zoo/linalg/conftest.py similarity index 68% rename from tests/test_problems/test_zoo/test_linalg/conftest.py rename to tests/probnum/problems/zoo/linalg/conftest.py index 4b01ebf1c..faffbbad3 100644 --- a/tests/test_problems/test_zoo/test_linalg/conftest.py +++ b/tests/probnum/problems/zoo/linalg/conftest.py @@ -1,10 +1,8 @@ """Test fixtures for the linear algebra test problem zoo.""" -import numpy as np -import pytest -import pytest_cases import scipy.sparse +from probnum import backend from probnum.problems.zoo.linalg import ( SuiteSparseMatrix, random_sparse_spd_matrix, @@ -12,10 +10,9 @@ suitesparse_matrix, ) - -@pytest_cases.fixture() -def rng() -> np.random.Generator: - return np.random.default_rng(42) +import pytest +import pytest_cases +from tests.utils.random import rng_state_from_sampling_args @pytest_cases.fixture() @@ -38,21 +35,23 @@ def density(density: float) -> float: @pytest_cases.fixture() -def rnd_dense_spd_mat(n_cols: int, rng: np.random.Generator) -> np.ndarray: +def rnd_dense_spd_mat(n_cols: int) -> backend.Array: """Random spd matrix generated from :meth:`random_spd_matrix`.""" - return random_spd_matrix(rng=rng, dim=n_cols) + rng_state = rng_state_from_sampling_args(base_seed=2984357, shape=n_cols) + return random_spd_matrix(rng_state=rng_state, shape=(n_cols, n_cols)) @pytest_cases.fixture() -def rnd_sparse_spd_mat( - n_cols: int, density: float, rng: np.random.Generator -) -> scipy.sparse.spmatrix: +def rnd_sparse_spd_mat(n_cols: int, density: float) -> scipy.sparse.spmatrix: """Random sparse spd matrix generated from :meth:`random_sparse_spd_matrix`.""" - return random_sparse_spd_matrix(rng=rng, dim=n_cols, density=density) + rng_state = rng_state_from_sampling_args(base_seed=2984357, shape=n_cols) + return random_sparse_spd_matrix( + rng_state=rng_state, shape=(n_cols, n_cols), density=density + ) rnd_spd_mat = pytest_cases.fixture_union( - "spd_mat", [rnd_dense_spd_mat, rnd_sparse_spd_mat] + "spd_mat", [rnd_dense_spd_mat, rnd_sparse_spd_mat], idstyle="explicit" ) diff --git a/tests/probnum/problems/zoo/linalg/test_random_linear_system.py b/tests/probnum/problems/zoo/linalg/test_random_linear_system.py new file mode 100644 index 000000000..1ad808210 --- /dev/null +++ b/tests/probnum/problems/zoo/linalg/test_random_linear_system.py @@ -0,0 +1,35 @@ +"""Tests for functions generating random linear systems.""" + +from probnum import backend, randvars +from probnum.problems.zoo.linalg import random_linear_system, random_spd_matrix + +import pytest + + +def test_custom_random_matrix(): + rng_state = backend.random.rng_state(305985) + random_unitary_matrix = lambda rng_state, n: backend.random.uniform_so_group( + n=n, rng_state=rng_state + ) + _ = random_linear_system(rng_state, random_unitary_matrix, n=5) + + +def test_custom_solution_randvar(): + n = 5 + rng_state = backend.random.rng_state(3453) + x = randvars.Normal(mean=backend.ones(n), cov=backend.eye(n)) + _ = random_linear_system( + rng_state=rng_state, matrix=random_spd_matrix, solution_rv=x, shape=(n, n) + ) + + +def test_incompatible_matrix_and_solution(): + rng_state = 
backend.random.rng_state(3453) + + with pytest.raises(ValueError): + _ = random_linear_system( + rng_state=rng_state, + matrix=random_spd_matrix, + solution_rv=randvars.Normal(backend.ones(2), backend.eye(2)), + shape=(5, 5), + ) diff --git a/tests/test_problems/test_zoo/test_linalg/test_random_spd_matrix.py b/tests/probnum/problems/zoo/linalg/test_random_spd_matrix.py similarity index 55% rename from tests/test_problems/test_zoo/test_linalg/test_random_spd_matrix.py rename to tests/probnum/problems/zoo/linalg/test_random_spd_matrix.py index 8c9a292aa..c7fbaa8fe 100644 --- a/tests/test_problems/test_zoo/test_linalg/test_random_spd_matrix.py +++ b/tests/probnum/problems/zoo/linalg/test_random_spd_matrix.py @@ -2,58 +2,68 @@ from typing import Union -import numpy as np -import pytest -import pytest_cases import scipy.sparse +from probnum import backend, compat from probnum.problems.zoo.linalg import random_sparse_spd_matrix, random_spd_matrix +import pytest +import pytest_cases + def test_dimension( - rnd_spd_mat: Union[np.ndarray, scipy.sparse.csr_matrix], n_cols: int + rnd_spd_mat: Union[backend.Array, scipy.sparse.csr_matrix], n_cols: int ): """Test whether matrix dimension matches specified dimension.""" assert rnd_spd_mat.shape == (n_cols, n_cols) -def test_symmetric(rnd_spd_mat: Union[np.ndarray, scipy.sparse.csr_matrix]): +def test_symmetric(rnd_spd_mat: Union[backend.Array, scipy.sparse.csr_matrix]): """Test whether the matrix is symmetric.""" if isinstance(rnd_spd_mat, scipy.sparse.spmatrix): rnd_spd_mat = rnd_spd_mat.todense() - np.testing.assert_equal(rnd_spd_mat, rnd_spd_mat.T) + compat.testing.assert_equal(rnd_spd_mat, rnd_spd_mat.T) -def test_positive_definite(rnd_spd_mat: Union[np.ndarray, scipy.sparse.csr_matrix]): +def test_positive_definite(rnd_spd_mat: Union[backend.Array, scipy.sparse.csr_matrix]): """Test whether the matrix is positive definite.""" if isinstance(rnd_spd_mat, scipy.sparse.spmatrix): rnd_spd_mat = rnd_spd_mat.todense() - eigvals = np.linalg.eigvals(rnd_spd_mat) - assert np.all(eigvals > 0.0), "Eigenvalues are not all positive." + eigvals = backend.linalg.eigvalsh(rnd_spd_mat) + assert backend.all(eigvals > 0.0), "Eigenvalues are not all positive." 
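+
+
+# NOTE: The two tests above rely on the textbook characterization of a symmetric
+# positive definite matrix: it equals its transpose and all of its eigenvalues
+# are strictly positive. The same check as a minimal, self-contained NumPy
+# sketch (illustrative only, not part of the test suite):
+#
+#     import numpy as np
+#     A = np.array([[2.0, -1.0], [-1.0, 2.0]])
+#     assert np.allclose(A, A.T)                  # symmetric
+#     assert np.all(np.linalg.eigvalsh(A) > 0.0)  # positive definite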
-def test_spectrum_matches_given(rng: np.random.Generator):
+def test_spectrum_matches_given():
     """Test whether the spectrum of the test problem matches the provided spectrum."""
-    dim = 10
-    spectrum = np.sort(rng.uniform(0.1, 1, size=dim))
-    spdmat = random_spd_matrix(rng=rng, dim=dim, spectrum=spectrum)
-    eigvals = np.sort(np.linalg.eigvals(spdmat))
-    np.testing.assert_allclose(
+    n = 10
+    rng_state_spectrum, rng_state_mat = backend.random.split(
+        backend.random.rng_state(234985)
+    )
+    spectrum = backend.sort(
+        backend.random.uniform(
+            rng_state=rng_state_spectrum, minval=0.1, maxval=1.0, shape=n
+        )
+    )
+    spdmat = random_spd_matrix(rng_state=rng_state_mat, shape=(n, n), spectrum=spectrum)
+    eigvals = backend.sort(backend.linalg.eigvalsh(spdmat))
+    compat.testing.assert_allclose(
         spectrum,
         eigvals,
         err_msg="Provided spectrum doesn't match actual.",
     )
 
 
-def test_negative_eigenvalues_throws_error(rng: np.random.Generator):
+def test_negative_eigenvalues_throws_error():
     """Test whether a non-positive spectrum throws an error."""
     with pytest.raises(ValueError):
-        random_spd_matrix(rng=rng, dim=3, spectrum=[-1, 1, 2])
+        random_spd_matrix(
+            rng_state=backend.random.rng_state(1), shape=(3, 3), spectrum=[-1, 1, 2]
+        )
 
 
-def test_is_ndarray(rnd_dense_spd_mat: np.ndarray):
-    """Test whether the random dense spd matrix is a `np.ndarray`."""
-    assert isinstance(rnd_dense_spd_mat, np.ndarray)
+def test_is_ndarray(rnd_dense_spd_mat: backend.Array):
+    """Test whether the random dense spd matrix is a `backend.Array`."""
+    assert isinstance(rnd_dense_spd_mat, backend.Array)
 
 
 def test_is_spmatrix(rnd_sparse_spd_mat: scipy.sparse.spmatrix):
@@ -75,27 +85,36 @@ def test_is_spmatrix(rnd_sparse_spd_mat: scipy.sparse.spmatrix):
     ],
 )
 def test_sparse_formats(
-    spformat: str, sparse_matrix_class: scipy.sparse.spmatrix, rng: np.random.Generator
+    spformat: str,
+    sparse_matrix_class: scipy.sparse.spmatrix,
 ):
     """Test whether sparse matrices in different formats can be created."""
 
     # Scipy warns that creating DIA matrices with many diagonals is inefficient.
     # This should not dilute the test output, as the test
    # only checks the *ability* to create large random sparse matrices.
+ + rng_state = backend.random.rng_state(4378354) + n = 1000 if spformat == "dia": with pytest.warns(scipy.sparse.SparseEfficiencyWarning): sparse_mat = random_sparse_spd_matrix( - rng=rng, dim=1000, density=10**-3, format=spformat + rng_state=rng_state, + shape=(n, n), + density=10**-3, + format=spformat, ) else: sparse_mat = random_sparse_spd_matrix( - rng=rng, dim=1000, density=10**-3, format=spformat + rng_state=rng_state, shape=(n, n), density=10**-3, format=spformat ) assert isinstance(sparse_mat, sparse_matrix_class) -def test_large_sparse_matrix(rng: np.random.Generator): +def test_large_sparse_matrix(): """Test whether a large random spd matrix can be created.""" n = 10**5 - sparse_mat = random_sparse_spd_matrix(rng=rng, dim=n, density=10**-8) + sparse_mat = random_sparse_spd_matrix( + rng_state=backend.random.rng_state(345), shape=(n, n), density=10**-8 + ) assert sparse_mat.shape == (n, n) diff --git a/tests/test_problems/test_zoo/test_linalg/test_suitesparse_matrix.py b/tests/probnum/problems/zoo/linalg/test_suitesparse_matrix.py similarity index 100% rename from tests/test_problems/test_zoo/test_linalg/test_suitesparse_matrix.py rename to tests/probnum/problems/zoo/linalg/test_suitesparse_matrix.py diff --git a/tests/probnum/randprocs/conftest.py b/tests/probnum/randprocs/conftest.py new file mode 100644 index 000000000..e63a21c47 --- /dev/null +++ b/tests/probnum/randprocs/conftest.py @@ -0,0 +1,137 @@ +"""Fixtures for random process tests.""" + +from typing import Any, Callable, Dict, Tuple, Type + +from probnum import backend, functions, randprocs +from probnum.backend.typing import ShapeType +from probnum.randprocs import kernels + +import pytest +import pytest_cases +import tests.utils + + +@pytest_cases.fixture(scope="package") +@pytest_cases.parametrize( + "shape", [(), (1,), (10,), (100,)], idgen="input_shape{shape}" +) +def input_shape(shape: ShapeType) -> ShapeType: + """Input dimension of the random process.""" + return shape + + +@pytest_cases.fixture(scope="package") +@pytest_cases.parametrize("shape", [()], idgen="output_shape{shape}") +def output_shape(shape: ShapeType) -> ShapeType: + """Output dimension of the random process.""" + return shape + + +@pytest_cases.fixture(scope="package") +@pytest_cases.parametrize( + "meanfndef", + [ + ("Zero", functions.Zero), + ( + "Lambda", + lambda input_shape, output_shape: functions.LambdaFunction( + lambda x: ( + backend.full_like(x, 2.0, shape=output_shape) + * backend.sum(x, axis=tuple(range(-len(input_shape), 0))) + + 1.0 + ), + input_shape=input_shape, + output_shape=output_shape, + ), + ), + ], + idgen="{meanfndef[0]}", +) +def mean( + meanfndef: Tuple[str, Callable[[ShapeType, ShapeType], functions.Function]], + input_shape: ShapeType, + output_shape: ShapeType, +) -> functions.Function: + """Mean function of a random process.""" + return meanfndef[1](input_shape=input_shape, output_shape=output_shape) + + +@pytest_cases.fixture(scope="package") +@pytest_cases.parametrize( + "kerndef", + [ + (kernels.Polynomial, {"constant": 1.0, "exponent": 3}), + (kernels.ExpQuad, {"lengthscale": 1.5}), + (kernels.RatQuad, {"lengthscale": 0.5, "alpha": 2.0}), + (kernels.Matern, {"lengthscale": 0.5, "nu": 1.5}), + ], + idgen="{kerndef[0].__name__}", +) +def cov( + kerndef: Tuple[Type[kernels.Kernel], Dict[str, Any]], + input_shape: ShapeType, + output_shape: ShapeType, +) -> kernels.Kernel: + """Covariance function.""" + + if output_shape != (): + pytest.skip() + + kernel_type, kwargs = kerndef + + return 
kernel_type(input_shape=input_shape, **kwargs) + + +@pytest_cases.fixture(scope="package") +@pytest_cases.parametrize( + "randprocdef", + [ + ( + "GP-Zero-Matern", + lambda input_shape, output_shape: randprocs.GaussianProcess( + mean=functions.Zero(input_shape=input_shape), + cov=kernels.Matern(input_shape=input_shape), + ), + ), + ], + idgen="{randprocdef[0]}", +) +def random_process( + randprocdef: Tuple[str, Callable[[ShapeType, ShapeType], randprocs.RandomProcess]], + input_shape: ShapeType, + output_shape: ShapeType, +) -> randprocs.RandomProcess: + """Random process.""" + return randprocdef[1](input_shape, output_shape) + + +@pytest_cases.fixture(scope="package") +def gaussian_process( + mean: functions.Function, cov: kernels.Kernel +) -> randprocs.GaussianProcess: + """Gaussian process.""" + return randprocs.GaussianProcess(mean=mean, cov=cov) + + +@pytest_cases.fixture(scope="session") +@pytest_cases.parametrize("shape", [(), (1,), (10,)], idgen="batch_shape{shape}") +def args0_batch_shape(shape: ShapeType) -> ShapeType: + return shape + + +@pytest_cases.fixture(scope="package") +@pytest_cases.parametrize("seed", [0, 1, 2], idgen="seed{seed}") +def args0( + random_process: randprocs.RandomProcess, + seed: int, + args0_batch_shape: ShapeType, +) -> backend.Array: + """Input(s) to a random process.""" + args0_shape = args0_batch_shape + random_process.input_shape + + return backend.random.standard_normal( + rng_state=tests.utils.random.rng_state_from_sampling_args( + base_seed=seed, shape=args0_shape + ), + shape=args0_shape, + ) diff --git a/tests/test_randprocs/test_markov/test_discrete/__init__.py b/tests/probnum/randprocs/kernels/__init__.py similarity index 100% rename from tests/test_randprocs/test_markov/test_discrete/__init__.py rename to tests/probnum/randprocs/kernels/__init__.py diff --git a/tests/probnum/randprocs/kernels/conftest.py b/tests/probnum/randprocs/kernels/conftest.py new file mode 100644 index 000000000..1d0488f25 --- /dev/null +++ b/tests/probnum/randprocs/kernels/conftest.py @@ -0,0 +1,135 @@ +"""Test fixtures for kernels.""" + +from typing import Callable, Optional + +from probnum import Backend, backend +from probnum.backend.typing import ShapeType +from probnum.randprocs import kernels + +import pytest +import tests.utils + + +# Kernel objects +@pytest.fixture( + params=[ + pytest.param(input_shape, id=f"inshape{input_shape}") + for input_shape in [(), (1,), (10,), (100,)] + ], + scope="package", +) +def input_shape(request) -> ShapeType: + """Input shape of the covariance function.""" + return request.param + + +@pytest.fixture( + params=[ + pytest.param(kerndef, id=kerndef[0].__name__) + for kerndef in [ + (kernels.Linear, {"constant": 1.0}), + (kernels.WhiteNoise, {"sigma_sq": 1.0}), + (kernels.Polynomial, {"constant": 1.0, "exponent": 3}), + (kernels.ExpQuad, {"lengthscale": 1.5}), + (kernels.RatQuad, {"lengthscale": 0.5, "alpha": 2.0}), + (kernels.Matern, {"lengthscale": 0.5, "nu": 0.5}), + (kernels.Matern, {"lengthscale": 0.5, "nu": 1.5}), + (kernels.Matern, {"lengthscale": 1.5, "nu": 2.5}), + (kernels.Matern, {"lengthscale": 2.5, "nu": 7.0}), + (kernels.Matern, {"lengthscale": 3.0, "nu": backend.inf}), + (kernels.ProductMatern, {"lengthscales": 0.5, "nus": 0.5}), + ] + ], + scope="package", +) +def kernel(request, input_shape: ShapeType) -> kernels.Kernel: + """Kernel / covariance function.""" + return request.param[0](input_shape=input_shape, **request.param[1]) + + +@pytest.mark.skipif_backend(Backend.TORCH) +@pytest.fixture(scope="package") +def 
kernel_call_naive( + kernel: kernels.Kernel, +) -> Callable[[backend.Array, Optional[backend.Array]], backend.Array]: + """Naive implementation of kernel broadcasting which applies the kernel function to + scalar arguments while looping over the first dimensions of the inputs explicitly. + + Can be used as a reference implementation of `Kernel.__call__` vectorization. + """ + + if kernel.input_ndim == 0: + kernel_vectorized = backend.vectorize(kernel, signature="(),()->()") + else: + assert kernel.input_ndim == 1 + + kernel_vectorized = backend.vectorize(kernel, signature="(d),(d)->()") + + return lambda x0, x1: ( + kernel_vectorized(x0, x0) if x1 is None else kernel_vectorized(x0, x1) + ) + + +# Test data for `Kernel.matrix` +@pytest.fixture( + params=[ + pytest.param(shape, id=f"x0{shape}") + for shape in [ + (), + (1,), + (2,), + (10,), + (100,), + ] + ], + scope="package", +) +def x0_batch_shape(request) -> ShapeType: + """Batch shape of the first argument of ``Kernel.matrix``.""" + return request.param + + +@pytest.fixture( + params=[ + pytest.param(shape, id=f"x1{shape}") + for shape in [ + None, + (), + (1,), + (3,), + (10,), + ] + ], + scope="package", +) +def x1_batch_shape(request) -> Optional[ShapeType]: + """Batch shape of the second argument of ``Kernel.matrix`` or ``None`` if the second + argument is ``None``.""" + return request.param + + +@pytest.fixture(scope="package") +def x0(input_shape: ShapeType, x0_batch_shape: ShapeType) -> backend.Array: + """Random data from a standard normal distribution.""" + shape = x0_batch_shape + input_shape + + rng_state = tests.utils.random.rng_state_from_sampling_args( + base_seed=34897, shape=shape + ) + + return backend.random.standard_normal(rng_state, shape=shape) + + +@pytest.fixture(scope="package") +def x1(input_shape: ShapeType, x1_batch_shape: ShapeType) -> Optional[backend.Array]: + """Random data from a standard normal distribution.""" + if x1_batch_shape is None: + return None + + shape = x1_batch_shape + input_shape + + rng_state = tests.utils.random.rng_state_from_sampling_args( + base_seed=533, shape=shape + ) + + return backend.random.standard_normal(rng_state, shape=shape) diff --git a/tests/test_randprocs/test_kernels/test_arithmetic.py b/tests/probnum/randprocs/kernels/test_arithmetic.py similarity index 100% rename from tests/test_randprocs/test_kernels/test_arithmetic.py rename to tests/probnum/randprocs/kernels/test_arithmetic.py diff --git a/tests/test_randprocs/test_kernels/test_arithmetic_fallbacks.py b/tests/probnum/randprocs/kernels/test_arithmetic_fallbacks.py similarity index 76% rename from tests/test_randprocs/test_kernels/test_arithmetic_fallbacks.py rename to tests/probnum/randprocs/kernels/test_arithmetic_fallbacks.py index 28f387ecb..7b87c7e6e 100644 --- a/tests/test_randprocs/test_kernels/test_arithmetic_fallbacks.py +++ b/tests/probnum/randprocs/kernels/test_arithmetic_fallbacks.py @@ -1,40 +1,39 @@ """Tests for fall-back implementations of kernel arithmetic.""" -import numpy as np -import pytest -from pytest_cases import parametrize - +from probnum import backend, compat from probnum.randprocs import kernels from probnum.randprocs.kernels._arithmetic_fallbacks import ( ProductKernel, ScaledKernel, SumKernel, ) -from probnum.typing import ScalarType + +import pytest +from pytest_cases import parametrize @parametrize("scalar", [1.0, 3, 1000.0]) def test_scaled_kernel_evaluation( - kernel: kernels.Kernel, scalar: ScalarType, x0: np.ndarray + kernel: kernels.Kernel, scalar: backend.Scalar, x0: 
backend.Array ): k_scaled = ScaledKernel(kernel=kernel, scalar=scalar) - np.testing.assert_allclose(k_scaled.matrix(x0), scalar * kernel.matrix(x0)) + compat.testing.assert_allclose(k_scaled.matrix(x0), scalar * kernel.matrix(x0)) def test_non_scalar_raises_error(): with pytest.raises(TypeError): - ScaledKernel(kernel=kernels.WhiteNoise(input_shape=()), scalar=np.array([0, 1])) + ScaledKernel(kernel=kernels.WhiteNoise(input_shape=()), scalar=[0, 1]) def test_non_kernel_raises_error(): with pytest.raises(TypeError): - ScaledKernel(kernel=np.eye(5), scalar=1.0) + ScaledKernel(kernel=backend.eye(5), scalar=1.0) -def test_sum_kernel_evaluation(kernel: kernels.Kernel, x0: np.ndarray): +def test_sum_kernel_evaluation(kernel: kernels.Kernel, x0: backend.Array): k_whitenoise = kernels.WhiteNoise(input_shape=kernel.input_shape) k_sum = SumKernel(kernel, k_whitenoise) - np.testing.assert_allclose( + compat.testing.assert_allclose( k_sum.matrix(x0), kernel.matrix(x0) + k_whitenoise.matrix(x0) ) @@ -53,10 +52,12 @@ def test_sum_kernel_contracts(): assert all(not isinstance(summand, SumKernel) for summand in k_sum._summands) -def test_product_kernel_evaluation(kernel: kernels.Kernel, x0: np.ndarray): +def test_product_kernel_evaluation(kernel: kernels.Kernel, x0: backend.Array): k_poly = kernels.Polynomial(input_shape=kernel.input_shape) k_sum = ProductKernel(kernel, k_poly) - np.testing.assert_allclose(k_sum.matrix(x0), kernel.matrix(x0) * k_poly.matrix(x0)) + compat.testing.assert_allclose( + k_sum.matrix(x0), kernel.matrix(x0) * k_poly.matrix(x0) + ) def test_product_kernel_shape_mismatch_raises_error(): diff --git a/tests/test_randprocs/test_kernels/test_call.py b/tests/probnum/randprocs/kernels/test_call.py similarity index 58% rename from tests/test_randprocs/test_kernels/test_call.py rename to tests/probnum/randprocs/kernels/test_call.py index c0159af48..3ecb84589 100644 --- a/tests/test_randprocs/test_kernels/test_call.py +++ b/tests/probnum/randprocs/kernels/test_call.py @@ -1,12 +1,13 @@ """Test cases for `Kernel.__call__`.""" -from typing import Callable, Optional, Tuple, Union +from typing import Callable, Optional, Tuple -import numpy as np -import pytest +from probnum import backend, compat +from probnum.backend.typing import ShapeType +from probnum.randprocs import kernels -import probnum as pn -from probnum.typing import ShapeType +import pytest +import tests.utils @pytest.fixture( @@ -37,12 +38,16 @@ ] ], name="input_shapes", + scope="module", ) def fixture_input_shapes( request, input_shape: ShapeType ) -> Tuple[ShapeType, Optional[ShapeType]]: - """Shapes for the first and second argument of the covariance function. The second - shape is ``None`` if the second argument to the covariance function is ``None``.""" + """Shapes for the first and second argument of the covariance function. + + The second shape is ``None`` if the second argument to the covariance function is + ``None``. 
+ """ x0_shape, x1_shape = request.param @@ -52,23 +57,25 @@ def fixture_input_shapes( ) -@pytest.fixture(name="x0") -def fixture_x0( - rng: np.random.Generator, input_shapes: Tuple[ShapeType, Optional[ShapeType]] -) -> np.ndarray: +@pytest.fixture(name="x0", scope="module") +def fixture_x0(input_shapes: Tuple[ShapeType, Optional[ShapeType]]) -> backend.Array: """The first argument to the covariance function drawn from a standard normal distribution.""" x0_shape, _ = input_shapes - return rng.normal(0, 1, size=x0_shape) + return backend.random.standard_normal( + rng_state=tests.utils.random.rng_state_from_sampling_args( + base_seed=899803, shape=x0_shape + ), + shape=x0_shape, + ) -@pytest.fixture(name="x1") +@pytest.fixture(name="x1", scope="module") def fixture_x1( - rng: np.random.Generator, - input_shapes: Tuple[ShapeType, Optional[ShapeType]], -) -> Optional[np.ndarray]: + input_shapes: Tuple[ShapeType, Optional[ShapeType]] +) -> Optional[backend.Array]: """The second argument to the covariance function drawn from a standard normal distribution.""" @@ -77,40 +84,47 @@ def fixture_x1( if x1_shape is None: return None - return rng.normal(0, 1, size=x1_shape) + return backend.random.standard_normal( + rng_state=tests.utils.random.rng_state_from_sampling_args( + base_seed=4569, shape=x1_shape + ), + shape=x1_shape, + ) -@pytest.fixture(name="call_result") +@pytest.fixture(name="call_result", scope="module") def fixture_call_result( - kernel: pn.randprocs.kernels.Kernel, x0: np.ndarray, x1: Optional[np.ndarray] -) -> Union[np.ndarray, np.floating]: + kernel: kernels.Kernel, x0: backend.Array, x1: Optional[backend.Array] +) -> backend.Array: """Result of ``Kernel.__call__`` when given ``x0`` and ``x1``.""" return kernel(x0, x1) -@pytest.fixture(name="call_result_naive") +@pytest.fixture(name="call_result_naive", scope="module") def fixture_call_result_naive( - kernel_call_naive: Callable[[np.ndarray, Optional[np.ndarray]], np.ndarray], - x0: np.ndarray, - x1: Optional[np.ndarray], -) -> Union[np.ndarray, np.floating]: + kernel_call_naive: Callable[ + [backend.Array, Optional[backend.Array]], backend.Array + ], + x0: backend.Array, + x1: Optional[backend.Array], +) -> backend.Array: """Result of ``Kernel.__call__`` when applied to the entries of ``x0`` and ``x1`` in a loop.""" return kernel_call_naive(x0, x1) -def test_type(call_result: Union[np.ndarray, np.floating]): - """Test whether the type of the output of ``Kernel.__call__`` is a NumPy type, i.e. 
-    an ``np.ndarray`` or a ``np.floating``."""
+def test_type(call_result: backend.Array):
+    """Test whether the output of ``Kernel.__call__`` is a ``backend.Array``."""
 
-    assert isinstance(call_result, (np.ndarray, np.floating))
+    assert backend.isarray(call_result)
 
 
 def test_shape(
-    call_result: Union[np.ndarray, np.floating],
-    call_result_naive: Union[np.ndarray, np.floating],
+    call_result: backend.Array,
+    call_result_naive: backend.Array,
 ):
     """Test whether the shape of the output of ``Kernel.__call__`` matches the shape
     of the naive reference implementation."""
 
@@ -119,13 +133,13 @@ def test_shape(
 
 def test_values(
-    call_result: Union[np.ndarray, np.floating],
-    call_result_naive: Union[np.ndarray, np.floating],
+    call_result: backend.Array,
+    call_result_naive: backend.Array,
 ):
     """Test whether the entries of the output of ``Kernel.__call__`` match the
     entries generated by the naive reference implementation."""
 
-    np.testing.assert_allclose(
+    compat.testing.assert_allclose(
         call_result,
         call_result_naive,
         rtol=10**-12,
@@ -143,20 +157,20 @@ def test_values(
         (4, 25),
     ],
 )
-def test_wrong_input_dimension(kernel: pn.randprocs.kernels.Kernel, shape: ShapeType):
+def test_wrong_input_dimension(kernel: kernels.Kernel, shape: ShapeType):
     """Test whether passing an input with the wrong input dimension raises an error."""
 
     if kernel.input_ndim > 0:
         input_shape = shape + tuple(dim + 1 for dim in kernel.input_shape)
 
         with pytest.raises(ValueError):
-            kernel(np.zeros(input_shape), None)
+            kernel(backend.zeros(input_shape), None)
 
         with pytest.raises(ValueError):
-            kernel(np.ones(input_shape), np.zeros(shape + kernel.input_shape))
+            kernel(backend.ones(input_shape), backend.zeros(shape + kernel.input_shape))
 
         with pytest.raises(ValueError):
-            kernel(np.ones(shape + kernel.input_shape), np.zeros(input_shape))
+            kernel(backend.ones(shape + kernel.input_shape), backend.zeros(input_shape))
 
 
@@ -168,15 +182,15 @@ def test_wrong_input_dimension(kernel: pn.randprocs.kernels.Kernel, shape: Shape
     ],
 )
 def test_broadcasting_error(
-    kernel: pn.randprocs.kernels.Kernel,
-    x0_shape: np.ndarray,
-    x1_shape: np.ndarray,
+    kernel: kernels.Kernel,
+    x0_shape: ShapeType,
+    x1_shape: ShapeType,
 ):
     """Test whether an error is raised if the inputs can not be broadcast to a common
     shape."""
     with pytest.raises(ValueError):
         kernel(
-            np.zeros(x0_shape + kernel.input_shape),
-            np.ones(x1_shape + kernel.input_shape),
+            backend.zeros(x0_shape + kernel.input_shape),
+            backend.ones(x1_shape + kernel.input_shape),
         )
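[Editor's note, not part of the patch] The test in the following file relies on a standard property of the Matérn family: with distance r = ||x0 - x1|| and lengthscale l, the Matérn covariance

    k_nu(r) = 2^(1 - nu) / Gamma(nu) * (sqrt(2 nu) r / l)^nu * K_nu(sqrt(2 nu) r / l),

where K_nu denotes the modified Bessel function of the second kind, converges to the squared-exponential (ExpQuad) covariance exp(-r^2 / (2 l^2)) as nu -> infinity. The test therefore treats nu = 15 as "large enough" for the asserted tolerance.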
diff --git a/tests/test_randprocs/test_kernels/test_matern.py b/tests/probnum/randprocs/kernels/test_matern.py
similarity index 82%
rename from tests/test_randprocs/test_kernels/test_matern.py
rename to tests/probnum/randprocs/kernels/test_matern.py
index b58dc75b5..e5555804b 100644
--- a/tests/test_randprocs/test_kernels/test_matern.py
+++ b/tests/probnum/randprocs/kernels/test_matern.py
@@ -1,10 +1,10 @@
 """Test cases for the Matern kernel."""
 
-import numpy as np
-import pytest
-
+from probnum import backend, compat
+from probnum.backend.typing import ShapeType
 from probnum.randprocs import kernels
-from probnum.typing import ShapeType
+
+import pytest
 
 
 @pytest.mark.parametrize("nu", [-1, -1.0, 0.0, 0])
@@ -15,14 +15,14 @@ def test_nonpositive_nu_raises_exception(nu):
 
 
 def test_nu_large_recovers_rbf_kernel(
-    x0: np.ndarray, x1: np.ndarray, input_shape: ShapeType
+    x0: backend.Array, x1: backend.Array, input_shape: ShapeType
 ):
     """Test whether a Matern kernel with nu large is close to an RBF kernel."""
     lengthscale = 1.25
     rbf = kernels.ExpQuad(input_shape=input_shape, lengthscale=lengthscale)
     matern = kernels.Matern(input_shape=input_shape, lengthscale=lengthscale, nu=15)
 
-    np.testing.assert_allclose(
+    compat.testing.assert_allclose(
         rbf.matrix(x0, x1),
         matern.matrix(x0, x1),
         err_msg="RBF and Matern kernel are not sufficiently close for nu->infty.",
diff --git a/tests/test_randprocs/test_kernels/test_matrix.py b/tests/probnum/randprocs/kernels/test_matrix.py
similarity index 55%
rename from tests/test_randprocs/test_kernels/test_matrix.py
rename to tests/probnum/randprocs/kernels/test_matrix.py
index 330102f70..f5cc9e67b 100644
--- a/tests/test_randprocs/test_kernels/test_matrix.py
+++ b/tests/probnum/randprocs/kernels/test_matrix.py
@@ -2,35 +2,37 @@
 
 from typing import Callable, Optional
 
-import numpy as np
-import pytest
+from probnum import backend, compat
+from probnum.backend.typing import ShapeType
+from probnum.randprocs import kernels
 
-import probnum as pn
-from probnum.typing import ShapeType
+import pytest
 
 
-@pytest.fixture(name="kernmat")
+@pytest.fixture(name="kernmat", scope="module")
 def fixture_kernmat(
-    kernel: pn.randprocs.kernels.Kernel, x0: np.ndarray, x1: Optional[np.ndarray]
-) -> np.ndarray:
+    kernel: kernels.Kernel, x0: backend.Array, x1: Optional[backend.Array]
+) -> backend.Array:
     """Kernel evaluated at the data."""
-    if x1 is None and np.prod(x0.shape[:-1]) >= 100:
+    if x1 is None and x0.size // kernel.input_size >= 100:
         pytest.skip("Runs too long")
 
     return kernel.matrix(x0, x1)
 
 
-@pytest.fixture(name="kernmat_naive")
+@pytest.fixture(name="kernmat_naive", scope="module")
 def fixture_kernmat_naive(
-    kernel: pn.randprocs.kernels.Kernel,
-    kernel_call_naive: Callable[[np.ndarray, Optional[np.ndarray]], np.ndarray],
-    x0: np.ndarray,
-    x1: Optional[np.ndarray],
-) -> np.ndarray:
+    kernel: kernels.Kernel,
+    kernel_call_naive: Callable[
+        [backend.Array, Optional[backend.Array]], backend.Array
+    ],
+    x0: backend.Array,
+    x1: Optional[backend.Array],
+) -> backend.Array:
     """Kernel evaluated at the data."""
 
     if x1 is None:
-        if np.prod(x0.shape[:-1]) >= 100:
+        if x0.size // kernel.input_size >= 100:
             pytest.skip("Runs too long")
 
         x1 = x0
@@ -41,18 +43,18 @@ def fixture_kernmat_naive(
 
     return kernel_call_naive(x0, x1)
 
 
-def test_type(kernmat: np.ndarray):
-    """Check whether a kernel evaluates to a numpy scalar or array."""
+def test_type(kernmat: backend.Array):
+    """Check whether a kernel evaluates to a backend array."""
 
-    assert isinstance(kernmat, (np.ndarray, np.number))
+    assert backend.isarray(kernmat)
 
 
 def test_shape(
-    kernel: pn.randprocs.kernels.Kernel,
-    x0: np.ndarray,
-    x1: Optional[np.ndarray],
-    kernmat: np.ndarray,
-    kernmat_naive: np.ndarray,
+    kernel: kernels.Kernel,
+    x0: backend.Array,
+    x1: Optional[backend.Array],
+    kernmat: backend.Array,
+    kernmat_naive: backend.Array,
 ):
     """Test the shape of a kernel evaluated at sets of inputs."""
@@ -64,12 +66,12 @@ def test_shape(
 
 def test_kernel_matrix_against_naive(
-    kernmat: np.ndarray,
-    kernmat_naive: np.ndarray,
+    kernmat: backend.Array,
+    kernmat_naive: backend.Array,
 ):
     """Test the computation of the kernel matrix against a naive computation."""
 
-    np.testing.assert_allclose(
+    compat.testing.assert_allclose(
         kernmat,
         kernmat_naive,
         rtol=10**-12,
@@ -85,20 +87,20 @@
     ],
 )
 def test_invalid_shape(
-    kernel: pn.randprocs.kernels.Kernel,
-    x0_shape: np.ndarray,
-    x1_shape: np.ndarray,
+    kernel: kernels.Kernel,
+    x0_shape: ShapeType,
+    x1_shape: ShapeType,
 ):
     """Test whether an error is raised if the inputs can not be broadcast to a
     common shape."""
 
     with pytest.raises(ValueError):
-        kernel.matrix(np.zeros(x0_shape + kernel.input_shape))
+        kernel.matrix(backend.zeros(x0_shape + kernel.input_shape))
 
     with pytest.raises(ValueError):
         kernel.matrix(
-            np.zeros(x0_shape + kernel.input_shape),
-            np.ones(x1_shape + kernel.input_shape),
+            backend.zeros(x0_shape + kernel.input_shape),
+            backend.ones(x1_shape + kernel.input_shape),
         )
@@ -110,7 +112,7 @@ def test_invalid_shape(
         (10,),
     ],
 )
-def test_wrong_input_dimension(kernel: pn.randprocs.kernels.Kernel, shape: ShapeType):
+def test_wrong_input_dimension(kernel: kernels.Kernel, shape: ShapeType):
     """Test whether passing an input with the wrong input dimension raises an error."""
 
     if kernel.input_ndim == 0:
@@ -119,10 +121,14 @@ def test_wrong_input_dimension(kernel: pn.randprocs.kernels.Kernel, shape: Shape
         input_shape = shape + tuple(dim + 1 for dim in kernel.input_shape)
 
         with pytest.raises(ValueError):
-            kernel.matrix(np.zeros(input_shape))
+            kernel.matrix(backend.zeros(input_shape))
 
         with pytest.raises(ValueError):
-            kernel.matrix(np.ones(input_shape), np.zeros(shape + kernel.input_shape))
+            kernel.matrix(
+                backend.ones(input_shape), backend.zeros(shape + kernel.input_shape)
+            )
 
         with pytest.raises(ValueError):
-            kernel.matrix(np.ones(shape + kernel.input_shape), np.zeros(input_shape))
+            kernel.matrix(
+                backend.ones(shape + kernel.input_shape), backend.zeros(input_shape)
+            )
diff --git a/tests/probnum/randprocs/kernels/test_product_matern.py b/tests/probnum/randprocs/kernels/test_product_matern.py
new file mode 100644
index 000000000..a3218340f
--- /dev/null
+++ b/tests/probnum/randprocs/kernels/test_product_matern.py
@@ -0,0 +1,60 @@
+"""Test cases for the product Matern kernel."""
+
+import functools
+import operator
+
+from probnum import backend, compat
+from probnum.backend.typing import ArrayLike, ShapeType
+from probnum.randprocs import kernels
+
+import pytest
+import tests.utils
+
+
+@pytest.mark.parametrize("lengthscale", [1.25])
+@pytest.mark.parametrize("nu", [0.5, 1.5, 2.5, 3.0])
+def test_kernel_matrix(input_shape: ShapeType, lengthscale: float, nu: float):
+    """Check that the product Matérn kernel matrix is an elementwise product of 1D
+    Matérn kernel matrices."""
+    if len(input_shape) > 1:
+        pytest.skip("only applicable to input shapes with at most one dimension")
+
+    matern = kernels.Matern(input_shape=(), lengthscale=lengthscale, nu=nu)
+    product_matern = kernels.ProductMatern(
+        input_shape=input_shape, lengthscales=lengthscale, nus=nu
+    )
+
+    xs_shape = (15,) + input_shape
+    xs = backend.random.uniform(
+        rng_state=tests.utils.random.rng_state_from_sampling_args(
+            base_seed=42, shape=xs_shape
+        ),
+        shape=xs_shape,
+    )
+
+    kernel_matrix1 = product_matern.matrix(xs)
+
+    if len(input_shape) > 0:
+        assert len(input_shape) == 1
+
+        kernel_matrix2 = functools.reduce(
+            operator.mul, (matern.matrix(xs[:, dim]) for dim in range(input_shape[0]))
+        )
+    else:
+        kernel_matrix2 = matern.matrix(xs)
+
+    compat.testing.assert_allclose(kernel_matrix1, kernel_matrix2)
+
+
+@pytest.mark.parametrize(
+    "ell,nu",
+    [
+        ([3.0], 0.5),
+        (3.0, [0.5]),
+        ([3.0], [0.5]),
+    ],
+)
+def test_wrong_initialization_raises_exception(ell: ArrayLike, nu: ArrayLike):
+    """Parameters must be scalars if kernel input is scalar."""
+    with pytest.raises(ValueError):
+        kernels.ProductMatern(input_shape=(), lengthscales=ell, nus=nu)
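[Editor's note, not part of the patch] The identity exercised by `test_kernel_matrix` above is that a product kernel factorizes over input dimensions, k(x, y) = prod_d k_d(x_d, y_d), so its Gram matrix is the elementwise (Hadamard) product of the factors' Gram matrices. A minimal self-contained sketch in plain NumPy (the `gram` helper and kernel choices are made up for illustration):

    import numpy as np

    def gram(k, xs, ys):
        # Naive Gram matrix of a scalar kernel k.
        return np.array([[k(x, y) for y in ys] for x in xs])

    k1 = lambda a, b: np.exp(-np.abs(a - b))       # Matern(nu=1/2), lengthscale 1
    k2 = lambda a, b: np.exp(-0.5 * (a - b) ** 2)  # the nu -> infinity factor

    x = np.random.default_rng(0).uniform(size=(4, 2))

    K_product = gram(lambda a, b: k1(a[0], b[0]) * k2(a[1], b[1]), x, x)
    K_factors = gram(k1, x[:, 0], x[:, 0]) * gram(k2, x[:, 1], x[:, 1])

    np.testing.assert_allclose(K_product, K_factors)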
diff --git a/tests/test_randprocs/test_kernels/test_rational_quadratic.py b/tests/probnum/randprocs/kernels/test_rational_quadratic.py
similarity index 84%
rename from tests/test_randprocs/test_kernels/test_rational_quadratic.py
rename to tests/probnum/randprocs/kernels/test_rational_quadratic.py
index 6c2263f55..f25971c61 100644
--- a/tests/test_randprocs/test_kernels/test_rational_quadratic.py
+++ b/tests/probnum/randprocs/kernels/test_rational_quadratic.py
@@ -1,12 +1,12 @@
 """Test cases for the rational quadratic kernel."""
 
-import pytest
-
 from probnum.randprocs import kernels
 
+import pytest
+
 
 @pytest.mark.parametrize("alpha", [-1, -1.0, 0.0, 0])
-def test_nonpositive_alpha_raises_exception(alpha):
+def test_nonpositive_alpha_raises_exception(alpha: float):
     """Check whether a non-positive alpha parameter raises a ValueError."""
     with pytest.raises(ValueError):
         kernels.RatQuad(input_shape=(), alpha=alpha)
diff --git a/tests/test_randprocs/test_markov/test_integrator/__init__.py b/tests/probnum/randprocs/markov/__init__.py
similarity index 100%
rename from tests/test_randprocs/test_markov/test_integrator/__init__.py
rename to tests/probnum/randprocs/markov/__init__.py
diff --git a/tests/probnum/randprocs/markov/conftest.py b/tests/probnum/randprocs/markov/conftest.py
new file mode 100644
index 000000000..4095a9654
--- /dev/null
+++ b/tests/probnum/randprocs/markov/conftest.py
@@ -0,0 +1,104 @@
+"""Fixtures for Markov processes."""
+
+import numpy as np
+
+from probnum import backend, randvars
+from probnum.problems.zoo.linalg import random_spd_matrix
+
+import pytest
+from tests.utils.random import rng_state_from_sampling_args
+
+
+@pytest.fixture(params=[2])
+def state_dim(request) -> int:
+    """State dimension."""
+    return request.param
+
+
+# Covariance matrices
+
+
+@pytest.fixture
+def spdmat1(state_dim: int):
+    rng_state = rng_state_from_sampling_args(base_seed=3245956, shape=state_dim)
+    return random_spd_matrix(rng_state, shape=(state_dim, state_dim))
+
+
+@pytest.fixture
+def spdmat2(state_dim: int):
+    rng_state = rng_state_from_sampling_args(base_seed=1, shape=state_dim)
+    return random_spd_matrix(rng_state, shape=(state_dim, state_dim))
+
+
+@pytest.fixture
+def spdmat3(state_dim: int):
+    rng_state = rng_state_from_sampling_args(base_seed=2498, shape=state_dim)
+    return random_spd_matrix(rng_state, shape=(state_dim, state_dim))
+
+
+@pytest.fixture
+def spdmat4(state_dim: int):
+    rng_state = rng_state_from_sampling_args(base_seed=4056, shape=state_dim)
+    return random_spd_matrix(rng_state, shape=(state_dim, state_dim))
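+
+# Editor's note (added for clarity): `rng_state_from_sampling_args` is assumed to
+# derive a deterministic backend rng state by mixing `base_seed` with the sampling
+# arguments (here, the shape), so that fixtures with hand-picked base seeds never
+# share a random stream. The actual implementation lives in `tests/utils/random.py`
+# and may differ from this description.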
+
+
+# 'Normal' random variables
+
+
+@pytest.fixture
+def some_normal_rv1(state_dim, spdmat1):
+    rng_state = rng_state_from_sampling_args(base_seed=6879, shape=spdmat1.shape)
+    return randvars.Normal(
+        mean=backend.random.uniform(rng_state=rng_state, shape=state_dim),
+        cov=spdmat1,
+        cache={"cov_cholesky": np.linalg.cholesky(spdmat1)},
+    )
+
+
+@pytest.fixture
+def some_normal_rv2(state_dim, spdmat2):
+    rng_state = rng_state_from_sampling_args(base_seed=2344, shape=spdmat2.shape)
+    return randvars.Normal(
+        mean=backend.random.uniform(rng_state=rng_state, shape=state_dim),
+        cov=spdmat2,
+        cache={"cov_cholesky": np.linalg.cholesky(spdmat2)},
+    )
+
+
+@pytest.fixture
+def some_normal_rv3(state_dim, spdmat3):
+    rng_state = rng_state_from_sampling_args(base_seed=76, shape=spdmat3.shape)
+    return randvars.Normal(
+        mean=backend.random.uniform(rng_state=rng_state, shape=state_dim),
+        cov=spdmat3,
+        cache={"cov_cholesky": np.linalg.cholesky(spdmat3)},
+    )
+
+
+@pytest.fixture
+def some_normal_rv4(state_dim, spdmat4):
+    rng_state = rng_state_from_sampling_args(base_seed=22, shape=spdmat4.shape)
+    return randvars.Normal(
+        mean=backend.random.uniform(rng_state=rng_state, shape=state_dim),
+        cov=spdmat4,
+        cache={"cov_cholesky": np.linalg.cholesky(spdmat4)},
+    )
+
+
+@pytest.fixture
+def diffusion():
+    """A diffusion != 1 makes it easier to see whether ``_diffusion`` is actually
+    used in the forward and backward transitions."""
+    return 5.1412512431
+
+
+@pytest.fixture(params=["classic", "sqrt"])
+def forw_impl_string_linear_gauss(request):
+    """Forward implementation choices passed via strings."""
+    return request.param
+
+
+@pytest.fixture(params=["classic", "joseph", "sqrt"])
+def backw_impl_string_linear_gauss(request):
+    """Backward implementation choices passed via strings."""
+    return request.param
diff --git a/tests/test_utils/__init__.py b/tests/probnum/randprocs/markov/continuous/__init__.py
similarity index 100%
rename from tests/test_utils/__init__.py
rename to tests/probnum/randprocs/markov/continuous/__init__.py
diff --git a/tests/test_randprocs/test_markov/test_continuous/test_diffusions.py b/tests/probnum/randprocs/markov/continuous/test_diffusions.py
similarity index 96%
rename from tests/test_randprocs/test_markov/test_continuous/test_diffusions.py
rename to tests/probnum/randprocs/markov/continuous/test_diffusions.py
index a5096a86f..716d722a1 100644
--- a/tests/test_randprocs/test_markov/test_continuous/test_diffusions.py
+++ b/tests/probnum/randprocs/markov/continuous/test_diffusions.py
@@ -3,17 +3,17 @@
 import abc
 
 import numpy as np
-import pytest
 
 from probnum import randprocs, randvars
 
+import pytest
+
 
 @pytest.fixture
 def some_meas_rv1():
     """Generic measurement RV used to test calibration.
 
-    This config should return 9.776307498421126 for
-    Diffusion.calibrate_locally.
+    This config should return 9.776307498421126 for Diffusion.calibrate_locally.
     """
     some_mean = np.arange(10, 13)
     some_cov = np.arange(9).reshape((3, 3)) @ np.arange(9).reshape((3, 3)).T + np.eye(3)
@@ -25,8 +25,7 @@ def some_meas_rv1():
 def some_meas_rv2():
     """Another generic measurement RV used to test calibration.
 
-    This config should return 9.776307498421126 for
-    Diffusion.calibrate_locally.
+    This config should return 9.776307498421126 for Diffusion.calibrate_locally.
     """
     some_mean = np.arange(10, 13)
     some_cov = np.arange(3, 12).reshape((3, 3)) @ np.arange(3, 12).reshape(
diff --git a/tests/test_randprocs/test_markov/test_continuous/test_linear_sde.py b/tests/probnum/randprocs/markov/continuous/test_linear_sde.py
similarity index 96%
rename from tests/test_randprocs/test_markov/test_continuous/test_linear_sde.py
rename to tests/probnum/randprocs/markov/continuous/test_linear_sde.py
index 5e4970e81..854c12f0f 100644
--- a/tests/test_randprocs/test_markov/test_continuous/test_linear_sde.py
+++ b/tests/probnum/randprocs/markov/continuous/test_linear_sde.py
@@ -1,8 +1,9 @@
 import numpy as np
-import pytest
 
 from probnum import randprocs, randvars
-from tests.test_randprocs.test_markov.test_continuous import test_sde
+
+import pytest
+from tests.probnum.randprocs.markov.continuous import test_sde
 
 
 class TestLinearSDE(test_sde.TestSDE):
 
     # Replacement for an __init__ in the pytest language.
See: # https://stackoverflow.com/questions/21430900/py-test-skips-test-class-if-constructor-is-defined @pytest.fixture(autouse=True) - def _setup(self, test_ndim, spdmat1, spdmat2): + def _setup(self, state_dim, spdmat1, spdmat2): self.G = lambda t: spdmat1 - self.v = lambda t: np.arange(test_ndim) + self.v = lambda t: np.arange(state_dim) self.L = lambda t: spdmat2 self.transition = randprocs.markov.continuous.LinearSDE( - state_dimension=test_ndim, - wiener_process_dimension=test_ndim, + state_dimension=state_dim, + wiener_process_dimension=state_dim, drift_matrix_function=self.G, force_vector_function=self.v, dispersion_matrix_function=self.L, diff --git a/tests/test_randprocs/test_markov/test_continuous/test_lti_sde.py b/tests/probnum/randprocs/markov/continuous/test_lti_sde.py similarity index 93% rename from tests/test_randprocs/test_markov/test_continuous/test_lti_sde.py rename to tests/probnum/randprocs/markov/continuous/test_lti_sde.py index a8e475279..407a4203f 100644 --- a/tests/test_randprocs/test_markov/test_continuous/test_lti_sde.py +++ b/tests/probnum/randprocs/markov/continuous/test_lti_sde.py @@ -1,8 +1,9 @@ import numpy as np -import pytest from probnum import randprocs, randvars -from tests.test_randprocs.test_markov.test_continuous import test_linear_sde + +import pytest +from tests.probnum.randprocs.markov.continuous import test_linear_sde class TestLTISDE(test_linear_sde.TestLinearSDE): @@ -12,7 +13,7 @@ class TestLTISDE(test_linear_sde.TestLinearSDE): @pytest.fixture(autouse=True) def _setup( self, - test_ndim, + state_dim, spdmat1, spdmat2, forw_impl_string_linear_gauss, @@ -20,7 +21,7 @@ def _setup( ): self.G_const = spdmat1 - self.v_const = np.arange(test_ndim) + self.v_const = np.arange(state_dim) self.L_const = spdmat2 self.transition = randprocs.markov.continuous.LTISDE( diff --git a/tests/test_randprocs/test_markov/test_continuous/test_mfd.py b/tests/probnum/randprocs/markov/continuous/test_mfd.py similarity index 99% rename from tests/test_randprocs/test_markov/test_continuous/test_mfd.py rename to tests/probnum/randprocs/markov/continuous/test_mfd.py index e33ea23c0..a3528a0f2 100644 --- a/tests/test_randprocs/test_markov/test_continuous/test_mfd.py +++ b/tests/probnum/randprocs/markov/continuous/test_mfd.py @@ -1,8 +1,9 @@ import numpy as np -import pytest from probnum import randprocs +import pytest + @pytest.fixture def dt(): diff --git a/tests/test_randprocs/test_markov/test_continuous/test_sde.py b/tests/probnum/randprocs/markov/continuous/test_sde.py similarity index 78% rename from tests/test_randprocs/test_markov/test_continuous/test_sde.py rename to tests/probnum/randprocs/markov/continuous/test_sde.py index 0c904aa55..7d33382ff 100644 --- a/tests/test_randprocs/test_markov/test_continuous/test_sde.py +++ b/tests/probnum/randprocs/markov/continuous/test_sde.py @@ -1,8 +1,9 @@ import numpy as np -import pytest from probnum import randprocs -from tests.test_randprocs.test_markov import test_transition + +import pytest +from tests.probnum.randprocs.markov import test_transition class TestSDE(test_transition.InterfaceTestTransition): @@ -10,14 +11,14 @@ class TestSDE(test_transition.InterfaceTestTransition): # Replacement for an __init__ in the pytest language. 
See: # https://stackoverflow.com/questions/21430900/py-test-skips-test-class-if-constructor-is-defined @pytest.fixture(autouse=True) - def _setup(self, test_ndim, spdmat1): + def _setup(self, state_dim, spdmat1): self.g = lambda t, x: np.sin(x) self.l = lambda t, x: spdmat1 self.dg = lambda t, x: np.cos(x) self.transition = randprocs.markov.continuous.SDE( - state_dimension=test_ndim, - wiener_process_dimension=test_ndim, + state_dimension=state_dim, + wiener_process_dimension=state_dim, drift_function=self.g, dispersion_function=self.l, drift_jacobian=self.dg, @@ -60,14 +61,14 @@ def test_backward_realization(self, some_normal_rv1, some_normal_rv2): some_normal_rv1.mean, some_normal_rv2, 0.0, dt=0.1 ) - def test_input_dim(self, test_ndim): - assert self.transition.input_dim == test_ndim + def test_input_dim(self, state_dim): + assert self.transition.input_dim == state_dim - def test_output_dim(self, test_ndim): - assert self.transition.output_dim == test_ndim + def test_output_dim(self, state_dim): + assert self.transition.output_dim == state_dim - def test_state_dimension(self, test_ndim): - assert self.transition.state_dimension == test_ndim + def test_state_dimension(self, state_dim): + assert self.transition.state_dimension == state_dim - def test_wiener_process_dimension(self, test_ndim): - assert self.transition.wiener_process_dimension == test_ndim + def test_wiener_process_dimension(self, state_dim): + assert self.transition.wiener_process_dimension == state_dim diff --git a/tests/test_utils/test_linalg/__init__.py b/tests/probnum/randprocs/markov/discrete/__init__.py similarity index 100% rename from tests/test_utils/test_linalg/__init__.py rename to tests/probnum/randprocs/markov/discrete/__init__.py diff --git a/tests/test_randprocs/test_markov/test_discrete/test_condition_state.py b/tests/probnum/randprocs/markov/discrete/test_condition_state.py similarity index 100% rename from tests/test_randprocs/test_markov/test_discrete/test_condition_state.py rename to tests/probnum/randprocs/markov/discrete/test_condition_state.py diff --git a/tests/test_randprocs/test_markov/test_discrete/test_linear_gaussian.py b/tests/probnum/randprocs/markov/discrete/test_linear_gaussian.py similarity index 95% rename from tests/test_randprocs/test_markov/test_discrete/test_linear_gaussian.py rename to tests/probnum/randprocs/markov/discrete/test_linear_gaussian.py index 0696d03dc..00e0edab5 100644 --- a/tests/test_randprocs/test_markov/test_discrete/test_linear_gaussian.py +++ b/tests/probnum/randprocs/markov/discrete/test_linear_gaussian.py @@ -1,8 +1,9 @@ import numpy as np -import pytest from probnum import config, linops, randprocs, randvars -from tests.test_randprocs.test_markov.test_discrete import test_nonlinear_gaussian + +import pytest +from tests.probnum.randprocs.markov.discrete import test_nonlinear_gaussian @pytest.fixture(params=["classic", "sqrt"]) @@ -30,7 +31,7 @@ class TestLinearGaussian(test_nonlinear_gaussian.TestNonlinearGaussian): @pytest.fixture(autouse=True) def _setup( self, - test_ndim, + state_dim, spdmat1, spdmat2, forw_impl_string_linear_gauss, @@ -39,12 +40,12 @@ def _setup( self.transition_matrix_fun = lambda t: spdmat1 self.noise_fun = lambda t: randvars.Normal( - mean=np.arange(test_ndim), cov=spdmat2 + mean=np.arange(state_dim), cov=spdmat2 ) self.transition = randprocs.markov.discrete.LinearGaussian( - input_dim=test_ndim, - output_dim=test_ndim, + input_dim=state_dim, + output_dim=state_dim, transition_matrix_fun=self.transition_matrix_fun, 
noise_fun=self.noise_fun, forward_implementation=forw_impl_string_linear_gauss, @@ -255,27 +256,27 @@ class TestLinearGaussianLinOps: @pytest.fixture(autouse=True) def _setup( self, - test_ndim, + state_dim, spdmat1, spdmat2, ): with config(matrix_free=True): self.noise_fun = lambda t: randvars.Normal( - mean=np.arange(test_ndim), cov=linops.aslinop(spdmat2) + mean=np.arange(state_dim), cov=linops.aslinop(spdmat2) ) self.transition_matrix_fun = lambda t: linops.aslinop(spdmat1) self.transition = randprocs.markov.discrete.LinearGaussian( - input_dim=test_ndim, - output_dim=test_ndim, + input_dim=state_dim, + output_dim=state_dim, transition_matrix_fun=self.transition_matrix_fun, noise_fun=self.noise_fun, forward_implementation="classic", backward_implementation="classic", ) self.sqrt_transition = randprocs.markov.discrete.LinearGaussian( - input_dim=test_ndim, - output_dim=test_ndim, + input_dim=state_dim, + output_dim=state_dim, transition_matrix_fun=self.transition_matrix_fun, noise_fun=self.noise_fun, forward_implementation="sqrt", @@ -306,7 +307,7 @@ def test_forward_rv(self, some_normal_rv1): out, _ = self.transition.forward_rv(linop_cov_rv, 0.0) assert isinstance(out, randvars.Normal) assert isinstance(out.cov, linops.LinearOperator) - assert isinstance(out.cov_cholesky, linops.LinearOperator) + assert isinstance(out._cov_cholesky, linops.LinearOperator) with pytest.raises(NotImplementedError): self.sqrt_transition.forward_rv(array_cov_rv, 0.0) @@ -333,7 +334,7 @@ def test_backward_rv_classic(self, some_normal_rv1, some_normal_rv2): out, _ = self.transition.backward_rv(linop_cov_rv1, linop_cov_rv2) assert isinstance(out, randvars.Normal) assert isinstance(out.cov, linops.LinearOperator) - assert isinstance(out.cov_cholesky, linops.LinearOperator) + assert isinstance(out._cov_cholesky, linops.LinearOperator) with pytest.raises(NotImplementedError): self.sqrt_transition.backward_rv(array_cov_rv1, array_cov_rv2) diff --git a/tests/test_randprocs/test_markov/test_discrete/test_lti_gaussian.py b/tests/probnum/randprocs/markov/discrete/test_lti_gaussian.py similarity index 90% rename from tests/test_randprocs/test_markov/test_discrete/test_lti_gaussian.py rename to tests/probnum/randprocs/markov/discrete/test_lti_gaussian.py index 09132c92a..df7aad4c6 100644 --- a/tests/test_randprocs/test_markov/test_discrete/test_lti_gaussian.py +++ b/tests/probnum/randprocs/markov/discrete/test_lti_gaussian.py @@ -1,8 +1,9 @@ import numpy as np -import pytest from probnum import randprocs, randvars -from tests.test_randprocs.test_markov.test_discrete import test_linear_gaussian + +import pytest +from tests.probnum.randprocs.markov.discrete import test_linear_gaussian class TestLTIGaussian(test_linear_gaussian.TestLinearGaussian): @@ -12,7 +13,7 @@ class TestLTIGaussian(test_linear_gaussian.TestLinearGaussian): @pytest.fixture(autouse=True) def _setup( self, - test_ndim, + state_dim, spdmat1, spdmat2, forw_impl_string_linear_gauss, @@ -20,7 +21,7 @@ def _setup( ): self.transition_matrix = spdmat1 - self.noise = randvars.Normal(mean=np.arange(test_ndim), cov=spdmat2) + self.noise = randvars.Normal(mean=np.arange(state_dim), cov=spdmat2) self.transition = randprocs.markov.discrete.LTIGaussian( transition_matrix=self.transition_matrix, diff --git a/tests/test_randprocs/test_markov/test_discrete/test_nonlinear_gaussian.py b/tests/probnum/randprocs/markov/discrete/test_nonlinear_gaussian.py similarity index 87% rename from tests/test_randprocs/test_markov/test_discrete/test_nonlinear_gaussian.py rename to 
tests/probnum/randprocs/markov/discrete/test_nonlinear_gaussian.py index 7f564b09b..fe42de4de 100644 --- a/tests/test_randprocs/test_markov/test_discrete/test_nonlinear_gaussian.py +++ b/tests/probnum/randprocs/markov/discrete/test_nonlinear_gaussian.py @@ -1,8 +1,9 @@ import numpy as np -import pytest from probnum import randprocs, randvars -from tests.test_randprocs.test_markov import test_transition + +import pytest +from tests.probnum.randprocs.markov import test_transition class TestNonlinearGaussian(test_transition.InterfaceTestTransition): @@ -17,17 +18,17 @@ class TestNonlinearGaussian(test_transition.InterfaceTestTransition): # Replacement for an __init__ in the pytest language. See: # https://stackoverflow.com/questions/21430900/py-test-skips-test-class-if-constructor-is-defined @pytest.fixture(autouse=True) - def _setup(self, test_ndim, spdmat1): + def _setup(self, state_dim, spdmat1): self.transition_fun = lambda t, x: np.sin(x) self.noise_fun = lambda t: randvars.Normal( - mean=np.zeros(test_ndim), cov=spdmat1 + mean=np.zeros(state_dim), cov=spdmat1 ) self.transition_fun_jacobian = lambda t, x: np.cos(x) self.transition = randprocs.markov.discrete.NonlinearGaussian( - input_dim=test_ndim, - output_dim=test_ndim, + input_dim=state_dim, + output_dim=state_dim, transition_fun=self.transition_fun, transition_fun_jacobian=self.transition_fun_jacobian, noise_fun=self.noise_fun, @@ -71,8 +72,8 @@ def test_backward_realization(self, some_normal_rv1, some_normal_rv2): with pytest.raises(NotImplementedError): self.transition.backward_realization(some_normal_rv1.mean, some_normal_rv2) - def test_input_dim(self, test_ndim): - assert self.transition.input_dim == test_ndim + def test_input_dim(self, state_dim): + assert self.transition.input_dim == state_dim - def test_output_dim(self, test_ndim): - assert self.transition.output_dim == test_ndim + def test_output_dim(self, state_dim): + assert self.transition.output_dim == state_dim diff --git a/tests/probnum/randprocs/markov/integrator/__init__.py b/tests/probnum/randprocs/markov/integrator/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/test_randprocs/test_markov/test_integrator/conftest.py b/tests/probnum/randprocs/markov/integrator/conftest.py similarity index 50% rename from tests/test_randprocs/test_markov/test_integrator/conftest.py rename to tests/probnum/randprocs/markov/integrator/conftest.py index 4663c8f19..fdba8663b 100644 --- a/tests/test_randprocs/test_markov/test_integrator/conftest.py +++ b/tests/probnum/randprocs/markov/integrator/conftest.py @@ -4,5 +4,5 @@ @pytest.fixture -def some_num_derivatives(test_ndim): - return test_ndim - 1 +def some_num_derivatives(state_dim): + return state_dim - 1 diff --git a/tests/test_randprocs/test_markov/test_integrator/test_convert.py b/tests/probnum/randprocs/markov/integrator/test_convert.py similarity index 99% rename from tests/test_randprocs/test_markov/test_integrator/test_convert.py rename to tests/probnum/randprocs/markov/integrator/test_convert.py index 4a71b56b5..a6592dc24 100644 --- a/tests/test_randprocs/test_markov/test_integrator/test_convert.py +++ b/tests/probnum/randprocs/markov/integrator/test_convert.py @@ -1,10 +1,11 @@ """Tests for the coordinate conversion functions.""" import numpy as np -import pytest from probnum import randprocs +import pytest + @pytest.fixture def some_order(): diff --git a/tests/test_randprocs/test_markov/test_integrator/test_integrator.py b/tests/probnum/randprocs/markov/integrator/test_integrator.py similarity 
index 93% rename from tests/test_randprocs/test_markov/test_integrator/test_integrator.py rename to tests/probnum/randprocs/markov/integrator/test_integrator.py index bfaf618f1..3355db22a 100644 --- a/tests/test_randprocs/test_markov/test_integrator/test_integrator.py +++ b/tests/probnum/randprocs/markov/integrator/test_integrator.py @@ -1,9 +1,10 @@ import numpy as np -import pytest -from probnum import randprocs, randvars +from probnum import backend, randprocs, randvars from probnum.problems.zoo import linalg as linalg_zoo +import pytest + class TestIntegratorTransition: """An integrator should be usable as is, but its tests are also useful for @@ -101,11 +102,15 @@ def test_same_forward_outputs(both_transitions, diffusion): "both_transitions", [both_transitions_ibm(), both_transitions_ioup(), both_transitions_matern()], ) -def test_same_backward_outputs(both_transitions, diffusion, rng): +def test_same_backward_outputs(both_transitions, diffusion): + rng_state = backend.random.rng_state(3058) + trans1, trans2 = both_transitions real = 1 + 0.1 * np.random.rand(trans1.state_dimension) real2 = 1 + 0.1 * np.random.rand(trans1.state_dimension) - cov = linalg_zoo.random_spd_matrix(rng, dim=trans1.state_dimension) + cov = linalg_zoo.random_spd_matrix( + rng_state, shape=(trans1.state_dimension, trans1.state_dimension) + ) rv = randvars.Normal(real2, cov) out_1, info1 = trans1.backward_realization( real, rv, t=0.0, dt=0.5, compute_gain=True, _diffusion=diffusion diff --git a/tests/test_randprocs/test_markov/test_integrator/test_ioup.py b/tests/probnum/randprocs/markov/integrator/test_ioup.py similarity index 94% rename from tests/test_randprocs/test_markov/test_integrator/test_ioup.py rename to tests/probnum/randprocs/markov/integrator/test_ioup.py index ab59d5fb4..2e1bba274 100644 --- a/tests/test_randprocs/test_markov/test_integrator/test_ioup.py +++ b/tests/probnum/randprocs/markov/integrator/test_ioup.py @@ -2,11 +2,12 @@ import numpy as np -import pytest from probnum import randprocs, randvars -from tests.test_randprocs.test_markov.test_continuous import test_lti_sde -from tests.test_randprocs.test_markov.test_integrator import test_integrator + +import pytest +from tests.probnum.randprocs.markov.continuous import test_lti_sde +from tests.probnum.randprocs.markov.integrator import test_integrator @pytest.mark.parametrize("driftspeed", [-2.0, 0.0, 2.0]) @@ -96,5 +97,5 @@ def _setup( def integrator(self): return self.transition - def test_wiener_process_dimension(self, test_ndim): + def test_wiener_process_dimension(self, state_dim): assert self.transition.wiener_process_dimension == 1 diff --git a/tests/test_randprocs/test_markov/test_integrator/test_iwp.py b/tests/probnum/randprocs/markov/integrator/test_iwp.py similarity index 95% rename from tests/test_randprocs/test_markov/test_integrator/test_iwp.py rename to tests/probnum/randprocs/markov/integrator/test_iwp.py index 0d9251d9d..c4d3cbc27 100644 --- a/tests/test_randprocs/test_markov/test_integrator/test_iwp.py +++ b/tests/probnum/randprocs/markov/integrator/test_iwp.py @@ -1,12 +1,13 @@ """Tests for integrated Wiener processes.""" import numpy as np -import pytest -from probnum import config, randprocs, randvars +from probnum import backend, config, randprocs, randvars from probnum.problems.zoo import linalg as linalg_zoo -from tests.test_randprocs.test_markov.test_continuous import test_lti_sde -from tests.test_randprocs.test_markov.test_integrator import test_integrator + +import pytest +from 
tests.probnum.randprocs.markov.continuous import test_lti_sde +from tests.probnum.randprocs.markov.integrator import test_integrator @pytest.mark.parametrize("initarg", [0.0, 2.0]) @@ -87,7 +88,7 @@ def _setup( def integrator(self): return self.transition - def test_wiener_process_dimension(self, test_ndim): + def test_wiener_process_dimension(self, state_dim): assert self.transition.wiener_process_dimension == 1 def test_discretise_no_force(self): @@ -140,7 +141,7 @@ def _setup( def integrator(self): return self.transition - def test_wiener_process_dimension(self, test_ndim): + def test_wiener_process_dimension(self, state_dim): assert self.transition.wiener_process_dimension == 1 def test_drift(self, some_normal_rv1): @@ -221,8 +222,9 @@ def qh_22_ibm(dt): @pytest.fixture -def spdmat3x3(rng): - return linalg_zoo.random_spd_matrix(rng, dim=3) +def spdmat3x3(): + rng_state = backend.random.rng_state(134) + return linalg_zoo.random_spd_matrix(rng_state=rng_state, shape=(3, 3)) @pytest.fixture @@ -231,7 +233,7 @@ def normal_rv3x3(spdmat3x3): return randvars.Normal( mean=np.random.rand(3), cov=spdmat3x3, - cov_cholesky=np.linalg.cholesky(spdmat3x3), + cache={"cov_cholesky": np.linalg.cholesky(spdmat3x3)}, ) diff --git a/tests/test_randprocs/test_markov/test_integrator/test_matern.py b/tests/probnum/randprocs/markov/integrator/test_matern.py similarity index 94% rename from tests/test_randprocs/test_markov/test_integrator/test_matern.py rename to tests/probnum/randprocs/markov/integrator/test_matern.py index 97ce8a8e6..4362ed067 100644 --- a/tests/test_randprocs/test_markov/test_integrator/test_matern.py +++ b/tests/probnum/randprocs/markov/integrator/test_matern.py @@ -2,11 +2,12 @@ import numpy as np -import pytest from probnum import randprocs, randvars -from tests.test_randprocs.test_markov.test_continuous import test_lti_sde -from tests.test_randprocs.test_markov.test_integrator import test_integrator + +import pytest +from tests.probnum.randprocs.markov.continuous import test_lti_sde +from tests.probnum.randprocs.markov.integrator import test_integrator @pytest.mark.parametrize("lengthscale", [-2.0, 2.0]) @@ -91,5 +92,5 @@ def _setup( def integrator(self): return self.transition - def test_wiener_process_dimension(self, test_ndim): + def test_wiener_process_dimension(self, state_dim): assert self.transition.wiener_process_dimension == 1 diff --git a/tests/test_randprocs/test_markov/test_integrator/test_preconditioner.py b/tests/probnum/randprocs/markov/integrator/test_preconditioner.py similarity index 99% rename from tests/test_randprocs/test_markov/test_integrator/test_preconditioner.py rename to tests/probnum/randprocs/markov/integrator/test_preconditioner.py index 8aa9fd6fd..a74224a98 100644 --- a/tests/test_randprocs/test_markov/test_integrator/test_preconditioner.py +++ b/tests/probnum/randprocs/markov/integrator/test_preconditioner.py @@ -1,8 +1,9 @@ import numpy as np -import pytest from probnum import randprocs +import pytest + @pytest.fixture def precon(): diff --git a/tests/test_randprocs/test_markov/test_markov_process.py b/tests/probnum/randprocs/markov/test_markov_process.py similarity index 78% rename from tests/test_randprocs/test_markov/test_markov_process.py rename to tests/probnum/randprocs/markov/test_markov_process.py index 78ac7e508..9c9f1c263 100644 --- a/tests/test_randprocs/test_markov/test_markov_process.py +++ b/tests/probnum/randprocs/markov/test_markov_process.py @@ -1,13 +1,13 @@ """Tests for Markov processes.""" import numpy as np -import pytest -from 
probnum import randprocs, randvars +from probnum import backend, randprocs, randvars + +import pytest def test_bad_args_shape(): - rng = np.random.default_rng(seed=1) time_domain = (0.0, 10.0) time_grid = np.arange(*time_domain) @@ -27,4 +27,6 @@ def test_bad_args_shape(): ) with pytest.raises(ValueError): - prior_process.sample(rng=rng, args=time_grid.reshape(-1, 1)) + prior_process.sample( + rng_state=backend.random.rng_state(1), args=time_grid.reshape(-1, 1) + ) diff --git a/tests/test_randprocs/test_markov/test_transition.py b/tests/probnum/randprocs/markov/test_transition.py similarity index 100% rename from tests/test_randprocs/test_markov/test_transition.py rename to tests/probnum/randprocs/markov/test_transition.py diff --git a/tests/test_randprocs/test_gaussian_process.py b/tests/probnum/randprocs/test_gaussian_process.py similarity index 79% rename from tests/test_randprocs/test_gaussian_process.py rename to tests/probnum/randprocs/test_gaussian_process.py index 885457047..9e18b6c7d 100644 --- a/tests/test_randprocs/test_gaussian_process.py +++ b/tests/probnum/randprocs/test_gaussian_process.py @@ -1,16 +1,16 @@ """Tests for Gaussian processes.""" -import numpy as np -import pytest - -from probnum import functions, randprocs, randvars +from probnum import backend, functions, randprocs, randvars from probnum.randprocs import kernels +import pytest +import tests.utils + def test_mean_not_function_raises_error(): with pytest.raises(TypeError): randprocs.GaussianProcess( - mean=np.zeros_like, + mean=backend.zeros_like, cov=kernels.ExpQuad(input_shape=(1,)), ) @@ -20,7 +20,8 @@ def test_cov_not_kernel_raises_error(): TypeError.""" with pytest.raises(TypeError): randprocs.GaussianProcess( - mean=functions.Zero(input_shape=(1,), output_shape=(1,)), cov=np.dot + mean=functions.Zero(input_shape=(1,), output_shape=(1,)), + cov=lambda x0, x1: backend.exp(-backend.abs(x0 - x1)), ) @@ -55,5 +56,12 @@ def test_mean_wrong_input_shape_raises_error(): def test_finite_evaluation_is_normal(gaussian_process: randprocs.GaussianProcess): """A Gaussian process evaluated at a finite set of inputs is a Gaussian random variable.""" - x = np.random.normal(size=(5,) + gaussian_process.input_shape) + x_shape = (5,) + gaussian_process.input_shape + x = backend.random.standard_normal( + rng_state=tests.utils.random.rng_state_from_sampling_args( + base_seed=98998123, + shape=x_shape, + ), + shape=x_shape, + ) assert isinstance(gaussian_process(x), randvars.Normal) diff --git a/tests/test_randprocs/test_random_process.py b/tests/probnum/randprocs/test_random_process.py similarity index 60% rename from tests/test_randprocs/test_random_process.py rename to tests/probnum/randprocs/test_random_process.py index f7d76a2d2..197154e28 100644 --- a/tests/test_randprocs/test_random_process.py +++ b/tests/probnum/randprocs/test_random_process.py @@ -1,51 +1,77 @@ """Tests for random processes.""" -import numpy as np -import pytest +from probnum import backend, compat, functions, randprocs, randvars +from probnum.backend.typing import ShapeType -from probnum import functions, randprocs, randvars +import pytest +import tests.utils # pylint: disable=invalid-name -def test_output_shape(random_process: randprocs.RandomProcess, args0: np.ndarray): +def test_output_shape( + random_process: randprocs.RandomProcess, + args0: backend.Array, + args0_batch_shape: ShapeType, +): """Test whether evaluations of the random process have the correct shape.""" - expected_shape = args0.shape[:-1] + random_process.output_shape + expected_shape 
= args0_batch_shape + random_process.output_shape assert random_process(args0).shape == expected_shape -def test_mean_shape(random_process: randprocs.RandomProcess, args0: np.ndarray): +def test_mean_shape( + random_process: randprocs.RandomProcess, + args0: backend.Array, + args0_batch_shape: ShapeType, +): """Test whether the mean of the random process has the correct shape.""" - expected_shape = args0.shape[:-1] + random_process.output_shape + expected_shape = args0_batch_shape + random_process.output_shape assert random_process.mean(args0).shape == expected_shape -def test_var_shape(random_process: randprocs.RandomProcess, args0: np.ndarray): +def test_var_shape( + random_process: randprocs.RandomProcess, + args0: backend.Array, + args0_batch_shape: ShapeType, +): """Test whether the variance of the random process has the correct shape.""" - expected_shape = args0.shape[:-1] + random_process.output_shape + expected_shape = args0_batch_shape + random_process.output_shape assert random_process.var(args0).shape == expected_shape -def test_std_shape(random_process: randprocs.RandomProcess, args0: np.ndarray): +def test_std_shape( + random_process: randprocs.RandomProcess, + args0: backend.Array, + args0_batch_shape: ShapeType, +): """Test whether the standard deviation of the random process has the correct shape.""" - expected_shape = args0.shape[:-1] + random_process.output_shape + expected_shape = args0_batch_shape + random_process.output_shape assert random_process.std(args0).shape == expected_shape -def test_cov_shape(random_process: randprocs.RandomProcess, args0: np.ndarray): +def test_cov_shape( + random_process: randprocs.RandomProcess, + args0: backend.Array, + args0_batch_shape: ShapeType, +): """Test whether the covariance of the random process has the correct shape.""" - n = args0.shape[0] - expected_shape = 2 * random_process.output_shape + (n, n) + expected_shape = 2 * args0_batch_shape + 2 * random_process.output_shape assert random_process.cov.matrix(args0).shape == expected_shape def test_evaluated_random_process_is_random_variable( - random_process: randprocs.RandomProcess, rng: np.random.Generator + random_process: randprocs.RandomProcess, ): """Test whether evaluating a random process returns a random variable.""" - n_inputs_args0 = 10 - args0 = rng.normal(size=(n_inputs_args0,) + random_process.input_shape) + args0_shape = (10,) + random_process.input_shape + args0 = backend.random.standard_normal( + rng_state=tests.utils.random.rng_state_from_sampling_args( + base_seed=98332, + shape=args0_shape, + ), + shape=args0_shape, + ) y0 = random_process(args0) assert isinstance(y0, randvars.RandomVariable), ( @@ -54,40 +80,45 @@ def test_evaluated_random_process_is_random_variable( @pytest.mark.xfail(reason="Not yet implemented for random processes.") -def test_samples_are_callables( - random_process: randprocs.RandomProcess, rng: np.random.Generator -): +def test_samples_are_callables(random_process: randprocs.RandomProcess): """When not specifying inputs to the sample method it should return ``size`` number of callables.""" - assert callable(random_process.sample(rng=rng)) + assert callable(random_process.sample(rng_state=backend.random.rng_state(42))) @pytest.mark.xfail(reason="Not yet implemented for random processes.") def test_sample_paths_are_deterministic_functions( - random_process: randprocs.RandomProcess, args0: np.ndarray + random_process: randprocs.RandomProcess, args0: backend.Array ): """When sampling paths from a random process, repeated evaluation of the 
sample path at the same inputs should return the same values.""" - sample_path = random_process.sample() - np.testing.assert_array_equal(sample_path(args0), sample_path(args0)) + sample_path = random_process.sample(rng_state=backend.random.rng_state(43)) + compat.testing.assert_array_equal(sample_path(args0), sample_path(args0)) def test_rp_mean_cov_evaluated_matches_rv_mean_cov( - random_process: randprocs.RandomProcess, rng: np.random.Generator + random_process: randprocs.RandomProcess, ): """Check whether the evaluated mean and covariance function of a random process is equivalent to the mean and covariance of the evaluated random process as a random variable.""" - x = rng.normal(size=(10,) + random_process.input_shape) + x_shape = (10,) + random_process.input_shape + x = backend.random.standard_normal( + rng_state=tests.utils.random.rng_state_from_sampling_args( + base_seed=98332, + shape=x_shape, + ), + shape=x_shape, + ) - np.testing.assert_allclose( + compat.testing.assert_allclose( random_process(x).mean, random_process.mean(x), err_msg=f"Mean of evaluated {repr(random_process)} does not match the " f"random process mean function evaluated.", ) - np.testing.assert_allclose( + compat.testing.assert_allclose( random_process(x).cov, random_process.cov.matrix(x), err_msg=f"Covariance of evaluated {repr(random_process)} does not match the " @@ -105,8 +136,8 @@ def test_invalid_mean_type_raises(): DummyRandomProcess( input_shape=(), output_shape=(), - dtype=np.double, - mean=np.zeros_like, + dtype=backend.float64, + mean=backend.zeros_like, ) @@ -115,8 +146,8 @@ def test_invalid_cov_type_raises(): DummyRandomProcess( input_shape=(), output_shape=(3,), - dtype=np.double, - cov=lambda x: np.zeros_like( # pylint: disable=unexpected-keyword-arg + dtype=backend.float64, + cov=lambda x: backend.zeros_like( # pylint: disable=unexpected-keyword-arg x, shape=x.shape + (3, 3), ), @@ -128,7 +159,7 @@ def test_inconsistent_mean_shape_errors(): DummyRandomProcess( input_shape=(42,), output_shape=(), - dtype=np.double, + dtype=backend.float64, mean=functions.Zero( input_shape=(3,), output_shape=(3,), @@ -139,7 +170,7 @@ def test_inconsistent_mean_shape_errors(): DummyRandomProcess( input_shape=(), output_shape=(1,), - dtype=np.double, + dtype=backend.float64, mean=functions.Zero( input_shape=(), output_shape=(3,), @@ -152,7 +183,7 @@ def test_inconsistent_cov_shape_errors(): DummyRandomProcess( input_shape=(42,), output_shape=(), - dtype=np.double, + dtype=backend.float64, cov=randprocs.kernels.ExpQuad( input_shape=(3,), ), @@ -162,7 +193,7 @@ def test_inconsistent_cov_shape_errors(): DummyRandomProcess( input_shape=(), output_shape=(1,), - dtype=np.double, + dtype=backend.float64, cov=randprocs.kernels.ExpQuad( input_shape=(), ), diff --git a/tests/probnum/randvars/__init__.py b/tests/probnum/randvars/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/probnum/randvars/arithmetic/__init__.py b/tests/probnum/randvars/arithmetic/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/probnum/randvars/arithmetic/operand_generators.py b/tests/probnum/randvars/arithmetic/operand_generators.py new file mode 100644 index 000000000..29b14bc75 --- /dev/null +++ b/tests/probnum/randvars/arithmetic/operand_generators.py @@ -0,0 +1,42 @@ +from typing import Callable, Union + +from probnum import backend, randvars +from probnum.backend.typing import ShapeType +from probnum.problems.zoo.linalg import random_spd_matrix + +import tests.utils + +GeneratorFnType = 
Callable[[ShapeType], Union[randvars.RandomVariable, backend.Array]] + + +def array_generator(shape: ShapeType) -> backend.Array: + return 3.0 * backend.random.standard_normal( + tests.utils.random.rng_state_from_sampling_args( + base_seed=561562, + shape=shape, + ), + shape=shape, + ) + + +def constant_generator(shape: ShapeType) -> randvars.Constant: + return randvars.Constant(array_generator(shape)) + + +def normal_generator(shape: ShapeType) -> randvars.Normal: + rng_state_mean, rng_state_cov = backend.random.split( + tests.utils.random.rng_state_from_sampling_args( + base_seed=561562, + shape=shape, + ), + num=2, + ) + + mean = 5.0 * backend.random.standard_normal(rng_state_mean, shape=shape) + + return randvars.Normal( + mean=mean, + cov=random_spd_matrix( + rng_state_cov, shape=() if mean.shape == () else (mean.size, mean.size) + ), + ) diff --git a/tests/probnum/randvars/arithmetic/test_addition.py b/tests/probnum/randvars/arithmetic/test_addition.py new file mode 100644 index 000000000..3bc0c9351 --- /dev/null +++ b/tests/probnum/randvars/arithmetic/test_addition.py @@ -0,0 +1,206 @@ +import operator +from typing import Any, Callable, Tuple, Type, Union + +from probnum import backend, compat, randvars +from probnum.backend.typing import ShapeType + +from .operand_generators import ( + GeneratorFnType, + array_generator, + constant_generator, + normal_generator, +) + +import pytest +from pytest_cases import fixture, parametrize + + +@fixture(scope="package") +@parametrize( + shapes_=[ + ((), ()), + ((1,), (1,)), + ((4,), (4,)), + ((2, 3), (2, 3)), + ((2, 3, 2), (2, 3, 2)), + # ((3,), ()), # This is broken if the `Normal` random variable has fewer + # entries. + # ((3, 1), (1, 4)), # This is broken if `Normal`s are involved + ] +) +def shapes(shapes_: Tuple[ShapeType, ShapeType]) -> Tuple[ShapeType, ShapeType]: + return shapes_ + + +OperandType = Union[randvars.RandomVariable, backend.Array] + + +@fixture(scope="package") +@parametrize( + operator_operands_and_expected_result_type_=[ + (operator.add, constant_generator, constant_generator, randvars.Constant), + (operator.sub, constant_generator, constant_generator, randvars.Constant), + (operator.add, constant_generator, array_generator, randvars.Constant), + (operator.sub, constant_generator, array_generator, randvars.Constant), + (operator.add, array_generator, constant_generator, randvars.Constant), + (operator.sub, array_generator, constant_generator, randvars.Constant), + (operator.add, normal_generator, normal_generator, randvars.Normal), + (operator.sub, normal_generator, normal_generator, randvars.Normal), + (operator.add, normal_generator, constant_generator, randvars.Normal), + (operator.sub, normal_generator, constant_generator, randvars.Normal), + (operator.add, constant_generator, normal_generator, randvars.Normal), + (operator.sub, constant_generator, normal_generator, randvars.Normal), + (operator.add, normal_generator, array_generator, randvars.Normal), + (operator.sub, normal_generator, array_generator, randvars.Normal), + (operator.add, array_generator, normal_generator, randvars.Normal), + (operator.sub, array_generator, normal_generator, randvars.Normal), + ], +) +def operator_operands_and_expected_result_type( + shapes: Tuple[ShapeType, ShapeType], + operator_operands_and_expected_result_type_: Tuple[ + Callable[[Any, Any], Any], + GeneratorFnType, + GeneratorFnType, + Type[randvars.RandomVariable], + ], +) -> Tuple[ + Callable[[Any, Any], Any], + OperandType, + OperandType, + Type[randvars.RandomVariable], +]: 
+ shape0, shape1 = shapes + + ( + operator, + generator0, + generator1, + expected_result_type, + ) = operator_operands_and_expected_result_type_ + + return operator, generator0(shape0), generator1(shape1), expected_result_type + + +@fixture(scope="package") +def operator( + operator_operands_and_expected_result_type: Tuple[ + Callable[[Any, Any], Any], + OperandType, + OperandType, + Type[randvars.RandomVariable], + ] +) -> Callable[[Any, Any], Any]: + return operator_operands_and_expected_result_type[0] + + +@fixture(scope="package") +def operand0( + operator_operands_and_expected_result_type: Tuple[ + Callable[[Any, Any], Any], + OperandType, + OperandType, + Type[randvars.RandomVariable], + ] +) -> OperandType: + return operator_operands_and_expected_result_type[1] + + +@fixture(scope="package") +def operand1( + operator_operands_and_expected_result_type: Tuple[ + Callable[[Any, Any], Any], + OperandType, + OperandType, + Type[randvars.RandomVariable], + ] +) -> OperandType: + return operator_operands_and_expected_result_type[2] + + +@fixture(scope="package") +def expected_result_type( + operator_operands_and_expected_result_type: Tuple[ + Callable[[Any, Any], Any], + OperandType, + OperandType, + Type[randvars.RandomVariable], + ] +) -> Type[randvars.RandomVariable]: + return operator_operands_and_expected_result_type[3] + + +@fixture(scope="package") +def result( + operator: Callable[[Any, Any], Any], + operand0: OperandType, + operand1: OperandType, +) -> randvars.RandomVariable: + return operator(operand0, operand1) + + +def test_type( + result: randvars.RandomVariable, expected_result_type: Type[randvars.RandomVariable] +): + assert isinstance(result, expected_result_type) + + +def test_shape( + operand0: OperandType, + operand1: OperandType, + result: randvars.RandomVariable, +): + if not isinstance(operand0, randvars.RandomVariable): + operand0 = randvars.asrandvar(operand0) + + if not isinstance(operand1, randvars.RandomVariable): + operand1 = randvars.asrandvar(operand1) + + expected_shape = backend.broadcast_shapes(operand0.shape, operand1.shape) + assert result.shape == expected_shape + + +def test_mean( + operator: Callable[[Any, Any], Any], + operand0: OperandType, + operand1: OperandType, + result: randvars.RandomVariable, +): + if not isinstance(operand0, randvars.RandomVariable): + operand0 = randvars.asrandvar(operand0) + + if not isinstance(operand1, randvars.RandomVariable): + operand1 = randvars.asrandvar(operand1) + + try: + mean0 = operand0.mean + mean1 = operand1.mean + except NotImplementedError: + pytest.skip() + + compat.testing.assert_allclose(result.mean, operator(mean0, mean1)) + + +def test_cov( + operand0: OperandType, + operand1: OperandType, + result: randvars.RandomVariable, +): + if not isinstance(operand0, randvars.RandomVariable): + operand0 = randvars.asrandvar(operand0) + + if not isinstance(operand1, randvars.RandomVariable): + operand1 = randvars.asrandvar(operand1) + + try: + cov0 = operand0.cov + cov1 = operand1.cov + except NotImplementedError: + pytest.skip() + + expected_cov = ( + cov0.reshape(operand0.shape + operand0.shape) + + cov1.reshape(operand1.shape + operand1.shape) + ).reshape(result.cov.shape) + + compat.testing.assert_allclose(result.cov, expected_cov) diff --git a/tests/probnum/randvars/arithmetic/test_const_matmul.py b/tests/probnum/randvars/arithmetic/test_const_matmul.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/probnum/randvars/arithmetic/test_const_multiplication.py
b/tests/probnum/randvars/arithmetic/test_const_multiplication.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/probnum/randvars/normal/__init__.py b/tests/probnum/randvars/normal/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/probnum/randvars/normal/cases.py b/tests/probnum/randvars/normal/cases.py new file mode 100644 index 000000000..e59b05503 --- /dev/null +++ b/tests/probnum/randvars/normal/cases.py @@ -0,0 +1,131 @@ +"""Test cases defining random variables with a normal distribution.""" + +from probnum import backend, linops, randvars +from probnum.backend.typing import ScalarLike, ShapeType +from probnum.problems.zoo.linalg import random_spd_matrix +from probnum.typing import MatrixType + +from pytest_cases import case, parametrize +import tests.utils + + +@case(tags=["scalar"]) +@parametrize(mean=[0.0, -1.0, 4]) +@parametrize(var=[3.0, 2]) +def case_scalar(mean: ScalarLike, var: ScalarLike) -> randvars.Normal: + return randvars.Normal(mean, var) + + +@case(tags=["scalar", "degenerate", "constant"]) +@parametrize(mean=[0.0, 12.23]) +def case_scalar_constant(mean: ScalarLike) -> randvars.Normal: + return randvars.Normal(mean=mean, cov=0.0) + + +@case(tags=["vector"]) +@parametrize(shape=[(1,), (2,), (5,), (10,)]) +def case_vector(shape: ShapeType) -> randvars.Normal: + rng_state_mean, rng_state_cov = backend.random.split( + tests.utils.random.rng_state_from_sampling_args( + base_seed=654, + shape=shape, + ), + num=2, + ) + + return randvars.Normal( + mean=5.0 * backend.random.standard_normal(rng_state_mean, shape=shape), + cov=random_spd_matrix(rng_state_cov, shape=(shape[0], shape[0])), + ) + + +@case(tags=["vector", "diag-cov"]) +@parametrize( + cov=[backend.eye(7, dtype=backend.float32), linops.Scaling(2.7, shape=(20, 20))], + ids=["backend.eye", "linops.Scaling"], +) +def case_vector_diag_cov(cov: MatrixType) -> randvars.Normal: + rng_state = tests.utils.random.rng_state_from_sampling_args( + base_seed=12390, + shape=cov.shape, + dtype=cov.dtype, + ) + + return randvars.Normal( + mean=3.1 * backend.random.standard_normal(rng_state, shape=cov.shape[0]), + cov=cov, + ) + + +@case(tags=["degenerate", "constant", "vector"]) +@parametrize( + cov=[backend.zeros, linops.Zero], ids=["cov=backend.zeros", "cov=linops.Zero"] +) +@parametrize(shape=[(3,)]) +def case_vector_zero_cov(cov: MatrixType, shape: ShapeType) -> randvars.Normal: + rng_state_mean = tests.utils.random.rng_state_from_sampling_args( + base_seed=624, + shape=shape, + ) + mean = backend.random.standard_normal(shape=shape, rng_state=rng_state_mean) + return randvars.Normal(mean=mean, cov=cov(shape=2 * shape)) + + +@case(tags=["matrix"]) +@parametrize(shape=[(1, 1), (5, 1), (1, 4), (2, 2), (3, 4)]) +def case_matrix(shape: ShapeType) -> randvars.Normal: + rng_state_mean, rng_state_cov = backend.random.split( + tests.utils.random.rng_state_from_sampling_args( + base_seed=453987, + shape=shape, + ), + num=2, + ) + + return randvars.Normal( + mean=4.0 * backend.random.standard_normal(rng_state_mean, shape=shape), + cov=random_spd_matrix( + rng_state_cov, shape=(shape[0] * shape[1], shape[0] * shape[1]) + ), + ) + + +@case(tags=["matrix", "mean-op", "cov-op"]) +@parametrize(shape=[(1, 1), (2, 1), (1, 3), (2, 2)]) +def case_matrix_mean_op_kronecker_cov(shape: ShapeType) -> randvars.Normal: + rng_state_mean, rng_state_cov_A, rng_state_cov_B = backend.random.split( + tests.utils.random.rng_state_from_sampling_args( + base_seed=421376, + shape=shape, + ), + num=3, + ) + + cov = 
linops.Kronecker( + A=random_spd_matrix(rng_state_cov_A, shape=(shape[0], shape[0])), + B=random_spd_matrix(rng_state_cov_B, shape=(shape[1], shape[1])), + ) + cov.is_symmetric = True + cov.A.is_symmetric = True + cov.B.is_symmetric = True + + return randvars.Normal( + mean=linops.aslinop( + backend.random.standard_normal(rng_state_mean, shape=shape) + ), + cov=cov, + ) + + +@case(tags=["degenerate", "constant", "matrix", "cov-op"]) +@parametrize(shape=[(2, 3)]) +def case_matrix_zero_cov(shape: ShapeType) -> randvars.Normal: + rng_state_mean = tests.utils.random.rng_state_from_sampling_args( + base_seed=624, + shape=shape, + ) + mean = backend.random.standard_normal(shape=shape, rng_state=rng_state_mean) + cov = linops.Kronecker( + linops.Zero(shape=(shape[0], shape[0])), linops.Zero(shape=(shape[1], shape[1])) + ) + return randvars.Normal(mean=mean, cov=cov) diff --git a/tests/probnum/randvars/normal/test_cholesky_updates.py b/tests/probnum/randvars/normal/test_cholesky_updates.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/probnum/randvars/normal/test_compare_scipy.py b/tests/probnum/randvars/normal/test_compare_scipy.py new file mode 100644 index 000000000..5acd838bb --- /dev/null +++ b/tests/probnum/randvars/normal/test_compare_scipy.py @@ -0,0 +1,119 @@ +"""Test properties of normal random variables.""" + +import scipy.stats + +from probnum import Backend, backend, compat, randvars +from probnum.backend.typing import ShapeType + +import pytest +from pytest_cases import filters, parametrize, parametrize_with_cases +import tests.utils + + +@parametrize_with_cases( + "rv", + cases=".cases", + filter=filters.has_tag("scalar") & ~filters.has_tag("degenerate"), +) +def test_entropy(rv: randvars.Normal): + scipy_entropy = scipy.stats.norm.entropy( + loc=backend.to_numpy(rv.mean), + scale=backend.to_numpy(rv.std), + ) + + compat.testing.assert_allclose(rv.entropy, scipy_entropy) + + +@parametrize_with_cases( + "rv", + cases=".cases", + filter=filters.has_tag("scalar") & ~filters.has_tag("degenerate"), +) +@parametrize("shape", ([(), (1,), (5,), (2, 3), (3, 1, 2)])) +def test_pdf_scalar(rv: randvars.Normal, shape: ShapeType): + x = backend.random.standard_normal( + tests.utils.random.rng_state_from_sampling_args(base_seed=245, shape=shape), + shape=shape, + dtype=rv.dtype, + ) + + scipy_pdf = scipy.stats.norm.pdf( + backend.to_numpy(x), + loc=backend.to_numpy(rv.mean), + scale=backend.to_numpy(rv.std), + ) + + compat.testing.assert_allclose(rv.pdf(x), scipy_pdf) + + +@parametrize_with_cases( + "rv", + cases=".cases", + filter=( + (filters.has_tag("vector") | filters.has_tag("matrix")) + & ~filters.has_tag("degenerate") + ), +) +@parametrize("shape", ((), (1,), (5,), (2, 3), (3, 1, 2))) +def test_pdf_multivariate(rv: randvars.Normal, shape: ShapeType): + x = rv.sample( + tests.utils.random.rng_state_from_sampling_args(base_seed=65465, shape=shape), + sample_shape=shape, + ) + + scipy_pdf = scipy.stats.multivariate_normal.pdf( + backend.to_numpy(x.reshape(shape + (-1,))), + mean=backend.to_numpy(rv.dense_mean.reshape(-1)), + cov=backend.to_numpy(rv.dense_cov), + ) + + # There is a bug in scipy's implementation of the pdf for the multivariate normal: + expected_shape = x.shape[: x.ndim - rv.ndim] + + if any(dim == 1 for dim in expected_shape): + # scipy's implementation happily squeezes `1` dimensions out of the batch + assert all(dim != 1 for dim in scipy_pdf.shape) + + scipy_pdf = scipy_pdf.reshape(expected_shape) + + compat.testing.assert_allclose(rv.pdf(x), scipy_pdf) + 
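Note: the reshape workaround above (and the analogous one in the cdf test that follows) compensates for scipy squeezing singleton batch axes out of multivariate results. A minimal sketch of the behaviour in plain NumPy/SciPy, with illustrative shapes (not part of the patch):

import numpy as np
import scipy.stats

scipy_rv = scipy.stats.multivariate_normal(mean=np.zeros(3), cov=np.eye(3))
x = np.zeros((3, 1, 2, 3))  # batch shape (3, 1, 2), event dimension 3

pdf = scipy_rv.pdf(x)
assert pdf.shape == (3, 2)  # the singleton batch axis was squeezed away

pdf = pdf.reshape((3, 1, 2))  # restore the expected batch shape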
+ +@pytest.mark.skipif_backend(Backend.JAX) +@pytest.mark.skipif_backend(Backend.TORCH) +@parametrize_with_cases( + "rv", + cases=".cases", + filter=( + (filters.has_tag("vector") | filters.has_tag("matrix")) + & ~filters.has_tag("degenerate") + ), +) +@parametrize("shape", ((), (1,), (5,), (2, 3), (3, 1, 2))) +def test_cdf_multivariate(rv: randvars.Normal, shape: ShapeType): + scipy_rv = scipy.stats.multivariate_normal( + mean=backend.to_numpy(rv.dense_mean.reshape(-1)), + cov=backend.to_numpy(rv.dense_cov), + ) + + x = rv.sample( + tests.utils.random.rng_state_from_sampling_args(base_seed=978134, shape=shape), + sample_shape=shape, + ) + + cdf = rv.cdf(x) + + scipy_cdf = scipy_rv.cdf(backend.to_numpy(x.reshape(shape + (-1,)))) + + # There is a bug in scipy's implementation of the cdf for the multivariate normal: + expected_shape = x.shape[: x.ndim - rv.ndim] + + if any(dim == 1 for dim in expected_shape): + # scipy's implementation happily squeezes `1` dimensions out of the batch + assert all(dim != 1 for dim in scipy_cdf.shape) + + scipy_cdf = scipy_cdf.reshape(expected_shape) + + compat.testing.assert_allclose( + cdf, scipy_cdf, atol=scipy_rv.abseps, rtol=scipy_rv.releps + ) diff --git a/tests/probnum/randvars/normal/test_construction.py b/tests/probnum/randvars/normal/test_construction.py new file mode 100644 index 000000000..686502d1f --- /dev/null +++ b/tests/probnum/randvars/normal/test_construction.py @@ -0,0 +1,19 @@ +"""Test the construction of Normal random variables.""" +from probnum import backend, randvars +from probnum.backend.typing import ShapeType + +import pytest +from pytest_cases import parametrize +import tests.utils + + +@parametrize(shape=[(), (3,), (2, 2)]) +def test_mean_cov_shape_mismatch(shape: ShapeType): + rng_state = tests.utils.random.rng_state_from_sampling_args( + base_seed=54784, shape=shape + ) + mean = backend.random.standard_normal(rng_state, shape=shape) + cov = backend.eye(10) + + with pytest.raises(ValueError): + randvars.Normal(mean=mean, cov=cov) diff --git a/tests/probnum/randvars/normal/test_sampling.py b/tests/probnum/randvars/normal/test_sampling.py new file mode 100644 index 000000000..9ccdca57a --- /dev/null +++ b/tests/probnum/randvars/normal/test_sampling.py @@ -0,0 +1,53 @@ +from probnum import backend, compat, randvars +from probnum.backend.typing import ShapeLike, ShapeType + +from pytest_cases import fixture, parametrize, parametrize_with_cases +import tests.utils + + +@fixture(scope="module") +@parametrize(shape=[(), 3, (1,), (1, 1), (2, 3, 2)]) +def sample_shape_arg(shape: ShapeLike) -> ShapeLike: + return shape + + +@fixture(scope="module") +def sample_shape(sample_shape_arg: ShapeLike) -> ShapeType: + return backend.asshape(sample_shape_arg) + + +@fixture(scope="module") +@parametrize_with_cases("rv_", cases=".cases", scope="module") +def rv(rv_: randvars.Normal) -> randvars.Normal: + return rv_ + + +@fixture(scope="module") +def samples( + rv: randvars.Normal, sample_shape_arg: ShapeLike, sample_shape: ShapeType +) -> backend.Array: + return rv.sample( + rng_state=tests.utils.random.rng_state_from_sampling_args( + base_seed=9879, + shape=sample_shape + rv.shape, + ), + sample_shape=sample_shape_arg, + ) + + +def test_sample_shape( + samples: backend.Array, rv: randvars.Normal, sample_shape: ShapeType +): + assert samples.shape == sample_shape + rv.shape + + +@parametrize_with_cases("rv_constant", cases=".cases", has_tag=["constant"]) +def test_sample_constant(rv_constant: randvars.Normal): + sample = rv_constant.sample(
rng_state=tests.utils.random.rng_state_from_sampling_args( + base_seed=2346, + shape=rv_constant.shape, + ) + ) + + compat.testing.assert_allclose(sample, rv_constant.mean) diff --git a/tests/probnum/randvars/test_getitem.py b/tests/probnum/randvars/test_getitem.py new file mode 100644 index 000000000..bc483894b --- /dev/null +++ b/tests/probnum/randvars/test_getitem.py @@ -0,0 +1,211 @@ +import functools +from typing import Tuple + +import numpy as np + +from probnum import backend, compat, linops, randvars +from probnum.backend.typing import ArrayIndicesLike, ShapeType +from probnum.problems.zoo.linalg import random_spd_matrix + +import pytest +from pytest_cases import THIS_MODULE, case, fixture, parametrize, parametrize_with_cases +import tests.utils + + +@case(tags=["normal"]) +@parametrize( + shape_and_getitem_arg=[ + # Indexing + [(), ()], + [(1,), 0], + [(2,), -1], + [(4, 5), 2], + [(3, 2), (0, 1)], + [(2,), None], + # Slicing + [(4,), slice(1, 4)], + [(2, 3), (slice(1, 2), slice(0, 3, 2))], + [(3,), slice(-1, -3, -2)], + # Advanced Indexing + ((3, 4), ([2, 0], [3, 0])), + ((3, 4), ([[2, 1]], [[3], [1], [2], [0]])), # broadcasting to (4, 2) + # Masking + ((1,), True), + ((2, 3), np.array([False, True])), + ( + (2, 3), + np.array( + [ + [True, True, False], + [False, True, False], + ] + ), + ), + ] +) +@parametrize(cov_linop=[False, True]) +def case_normal( + shape_and_getitem_arg: Tuple[ShapeType, ArrayIndicesLike], cov_linop: bool +) -> Tuple[randvars.Normal, ArrayIndicesLike]: + shape, getitem_arg = shape_and_getitem_arg + + # Generate `Normal` random variable with random parameters + mean_rng_state, cov_rng_state = backend.random.split( + rng_state=tests.utils.random.rng_state_from_sampling_args( + base_seed=98723, + shape=shape, + ), + num=2, + ) + + mean = backend.random.standard_normal(rng_state=mean_rng_state, shape=shape) + cov = random_spd_matrix( + rng_state=cov_rng_state, shape=() if shape == () else 2 * (mean.size,) + ) + + if cov_linop: + if shape == (): + pytest.skip("`LinearOperator`s don't support scalar shapes") + + cov = linops.aslinop(cov) + + rv = randvars.Normal(mean, cov) + + return rv, getitem_arg + + +@fixture(scope="module") +@parametrize_with_cases("rv_,getitem_arg_", cases=THIS_MODULE, scope="module") +def rv_and_getitem_arg( + rv_: randvars.Normal, getitem_arg_: ArrayIndicesLike +) -> Tuple[randvars.Normal, ArrayIndicesLike]: + return rv_, getitem_arg_ + + +@fixture(scope="module") +def rv(rv_and_getitem_arg: Tuple[randvars.Normal, ArrayIndicesLike]) -> randvars.Normal: + return rv_and_getitem_arg[0] + + +@fixture(scope="module") +def getitem_arg( + rv_and_getitem_arg: Tuple[randvars.Normal, ArrayIndicesLike], +) -> ArrayIndicesLike: + return rv_and_getitem_arg[1] + + +@fixture(scope="module") +def getitem_rv(rv: randvars.Normal, getitem_arg: ArrayIndicesLike): + return rv[getitem_arg] + + +def test_shape( + rv: randvars.Normal, + getitem_arg: ArrayIndicesLike, + getitem_rv: randvars.RandomVariable, +): + expected_shape = backend.zeros(rv.shape)[getitem_arg].shape + + assert getitem_rv.shape == expected_shape + + +def test_sample_shape( + rv: randvars.Normal, + getitem_arg: ArrayIndicesLike, + getitem_rv: randvars.RandomVariable, +): + expected_shape = backend.zeros(rv.shape)[getitem_arg].shape + + sample = getitem_rv.sample( + rng_state=tests.utils.random.rng_state_from_sampling_args( + base_seed=123897, shape=expected_shape + ) + ) + + assert sample.shape == expected_shape + + +def test_mean( + rv: randvars.Normal, + getitem_arg: ArrayIndicesLike, + 
getitem_rv: randvars.RandomVariable, +): + compat.testing.assert_array_equal(getitem_rv.mean, rv.mean[getitem_arg]) + + +def test_var( + rv: randvars.Normal, + getitem_arg: ArrayIndicesLike, + getitem_rv: randvars.RandomVariable, +): + compat.testing.assert_array_equal(getitem_rv.var, rv.var[getitem_arg]) + + +def test_std( + rv: randvars.Normal, + getitem_arg: ArrayIndicesLike, + getitem_rv: randvars.RandomVariable, +): + compat.testing.assert_array_equal(getitem_rv.std, rv.std[getitem_arg]) + + +def test_cov( + rv: randvars.Normal, + getitem_arg: ArrayIndicesLike, + getitem_rv: randvars.RandomVariable, +): + # Create a tensor which contains indices as elements + if rv.ndim > 0: + index_array = np.stack( + np.meshgrid( + *(np.arange(0, dim) for dim in rv.shape), + indexing="ij", + ), + axis=-1, + ) + + @functools.partial(np.vectorize, otypes=[np.object_], signature="(d)->()") + def _make_index_objects(idcs: np.ndarray): + return list(int(idx) for idx in idcs) + + index_array = _make_index_objects(index_array) + else: + index_array = np.empty(shape=(), dtype=np.object_) + index_array[()] = [] + + # Select indices according to `getitem_arg` + getitem_idx_to_original_idx = index_array[getitem_arg] + + # "Unravel" original covariance + dense_cov = ( + rv.cov.todense() if isinstance(rv.cov, linops.LinearOperator) else rv.cov + ) + + cov_unraveled = dense_cov.reshape(rv.shape + rv.shape, order="C") + + if isinstance(getitem_idx_to_original_idx, list): + # __getitem__ returned a scalar random variable + assert getitem_rv.cov.shape == () + + cov_unraveled_idx = tuple( + getitem_idx_to_original_idx + getitem_idx_to_original_idx + ) + + assert getitem_rv.cov[()] == cov_unraveled[cov_unraveled_idx] + else: + # __getitem__ returned a multi-dimensional random variable + + # Row-vectorization of indices + raveled_getitem_idx_to_original_idx = getitem_idx_to_original_idx.reshape( + -1, order="C" + ) + + for i in range(getitem_rv.cov.shape[0]): + for j in range(getitem_rv.cov.shape[1]): + cov_unraveled_idx = tuple( + raveled_getitem_idx_to_original_idx[i] + + raveled_getitem_idx_to_original_idx[j] + ) + + assert getitem_rv.cov[i, j] == cov_unraveled[cov_unraveled_idx] diff --git a/tests/probnum/randvars/test_reshape.py b/tests/probnum/randvars/test_reshape.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/probnum/randvars/test_shapes.py b/tests/probnum/randvars/test_shapes.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/probnum/randvars/test_sym_matrix_normal.py b/tests/probnum/randvars/test_sym_matrix_normal.py new file mode 100644 index 000000000..a400734fb --- /dev/null +++ b/tests/probnum/randvars/test_sym_matrix_normal.py @@ -0,0 +1,70 @@ +from probnum import backend, compat, linops, randvars +from probnum.backend.typing import ShapeLike, ShapeType +from probnum.problems.zoo.linalg import random_spd_matrix + +from pytest_cases import THIS_MODULE, case, fixture, parametrize, parametrize_with_cases +import tests.utils + + +@case(tags=["symmetric-matrix"]) +@parametrize("shape", [(1, 1), (2, 2), (3, 3), (5, 5)]) +def case_symmetric_matrix(shape: ShapeType) -> randvars.SymmetricMatrixNormal: + rng_state_mean, rng_state_cov = backend.random.split( + tests.utils.random.rng_state_from_sampling_args( + base_seed=453987, + shape=shape, + ), + num=2, + ) + + assert shape[0] == shape[1] + + return randvars.SymmetricMatrixNormal( + mean=random_spd_matrix(rng_state_mean, shape=(shape[0], shape[0])),
+ cov=linops.SymmetricKronecker( + random_spd_matrix(rng_state_cov, shape=(shape[0], shape[0])) + ), + ) + + +@fixture(scope="module") +@parametrize(shape=[(), 3, (1,), (1, 1), (2, 1, 3)]) +def sample_shape_arg(shape: ShapeLike) -> ShapeLike: + return shape + + +@fixture(scope="module") +def sample_shape(sample_shape_arg: ShapeLike) -> ShapeType: + return backend.asshape(sample_shape_arg) + + +@fixture(scope="module") +@parametrize_with_cases("rv_", cases=THIS_MODULE, scope="module") +def rv(rv_: randvars.Normal) -> randvars.Normal: + return rv_ + + +@fixture(scope="module") +def samples( + rv: randvars.Normal, sample_shape_arg: ShapeLike, sample_shape: ShapeType +) -> backend.Array: + return rv.sample( + rng_state=tests.utils.random.rng_state_from_sampling_args( + base_seed=355231, + shape=sample_shape + rv.shape, + ), + sample_shape=sample_shape_arg, + ) + + +def test_sample_shape( + samples: backend.Array, rv: randvars.Normal, sample_shape: ShapeType +): + assert samples.shape == sample_shape + rv.shape + + +def test_samples_symmetric(samples: backend.Array): + compat.testing.assert_array_equal( + backend.swap_axes(samples, -2, -1), + samples, + ) diff --git a/tests/probnum/randvars/test_transpose.py b/tests/probnum/randvars/test_transpose.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/test_config.py b/tests/probnum/test_config.py similarity index 100% rename from tests/test_config.py rename to tests/probnum/test_config.py index 4775d377b..d63696503 100644 --- a/tests/test_config.py +++ b/tests/probnum/test_config.py @@ -1,8 +1,8 @@ -import pytest - import probnum from probnum._config import _DEFAULT_CONFIG_OPTIONS +import pytest + def test_defaults(): none_vals = {key: None for (key, _, _) in _DEFAULT_CONFIG_OPTIONS} diff --git a/tests/test_diffeq/test_callbacks/test_discrete_callback.py b/tests/test_diffeq/test_callbacks/test_discrete_callback.py index a4f984956..1eb57a916 100644 --- a/tests/test_diffeq/test_callbacks/test_discrete_callback.py +++ b/tests/test_diffeq/test_callbacks/test_discrete_callback.py @@ -3,9 +3,9 @@ import dataclasses -import pytest - from probnum import diffeq + +import pytest from tests.test_diffeq.test_callbacks import _callback_test_interface diff --git a/tests/test_diffeq/test_odefilter/test_approx_strategies/_approx_test_interface.py b/tests/test_diffeq/test_odefilter/test_approx_strategies/_approx_test_interface.py index 0575aec01..1fd2ae7d8 100644 --- a/tests/test_diffeq/test_odefilter/test_approx_strategies/_approx_test_interface.py +++ b/tests/test_diffeq/test_odefilter/test_approx_strategies/_approx_test_interface.py @@ -2,10 +2,10 @@ import abc -import pytest - from probnum.problems.zoo import diffeq as diffeq_zoo +import pytest + class ApproximationStrategyTest(abc.ABC): @abc.abstractmethod diff --git a/tests/test_diffeq/test_odefilter/test_approx_strategies/test_ek.py b/tests/test_diffeq/test_odefilter/test_approx_strategies/test_ek.py index 1d1c3bb39..8869890e8 100644 --- a/tests/test_diffeq/test_odefilter/test_approx_strategies/test_ek.py +++ b/tests/test_diffeq/test_odefilter/test_approx_strategies/test_ek.py @@ -1,9 +1,10 @@ """Tests for EK0/1.""" import numpy as np -import pytest from probnum import diffeq, filtsmooth + +import pytest from tests.test_diffeq.test_odefilter.test_approx_strategies import ( _approx_test_interface, ) diff --git a/tests/test_diffeq/test_odefilter/test_information_operators/_information_operator_test_inferface.py 
b/tests/test_diffeq/test_odefilter/test_information_operators/_information_operator_test_inferface.py index 52f8e6613..b68b9cd07 100644 --- a/tests/test_diffeq/test_odefilter/test_information_operators/_information_operator_test_inferface.py +++ b/tests/test_diffeq/test_odefilter/test_information_operators/_information_operator_test_inferface.py @@ -2,10 +2,10 @@ import abc -import pytest - from probnum.problems.zoo import diffeq as diffeq_zoo +import pytest + class InformationOperatorTest(abc.ABC): @abc.abstractmethod diff --git a/tests/test_diffeq/test_odefilter/test_information_operators/test_ode_residual.py b/tests/test_diffeq/test_odefilter/test_information_operators/test_ode_residual.py index 6f78d3099..8d988e212 100644 --- a/tests/test_diffeq/test_odefilter/test_information_operators/test_ode_residual.py +++ b/tests/test_diffeq/test_odefilter/test_information_operators/test_ode_residual.py @@ -1,9 +1,10 @@ """Test for ODE residual information operator.""" import numpy as np -import pytest from probnum import diffeq, randprocs, randvars + +import pytest from tests.test_diffeq.test_odefilter.test_information_operators import ( _information_operator_test_inferface, ) @@ -53,7 +54,7 @@ def test_as_transition(self, fitzhughnagumo): noise = transition.noise_fun(0.0) assert isinstance(transition, randprocs.markov.discrete.NonlinearGaussian) assert np.linalg.norm(noise.cov) > 0.0 - assert np.linalg.norm(noise.cov_cholesky) > 0.0 + assert np.linalg.norm(noise._cov_cholesky) > 0.0 def test_incorporate_ode(self, fitzhughnagumo): self.info_op.incorporate_ode(ode=fitzhughnagumo) diff --git a/tests/test_diffeq/test_odefilter/test_init_routines/test_init_routines.py b/tests/test_diffeq/test_odefilter/test_init_routines/test_init_routines.py index a6d81d834..2c8d3eb09 100644 --- a/tests/test_diffeq/test_odefilter/test_init_routines/test_init_routines.py +++ b/tests/test_diffeq/test_odefilter/test_init_routines/test_init_routines.py @@ -2,11 +2,12 @@ import numpy as np -import pytest -import pytest_cases from probnum import randprocs +import pytest +import pytest_cases + try: from jax.config import config # speed... diff --git a/tests/test_diffeq/test_odefilter/test_init_routines/test_init_routines_cases.py b/tests/test_diffeq/test_odefilter/test_init_routines/test_init_routines_cases.py index 76a5cd4a6..eeed8a639 100644 --- a/tests/test_diffeq/test_odefilter/test_init_routines/test_init_routines_cases.py +++ b/tests/test_diffeq/test_odefilter/test_init_routines/test_init_routines_cases.py @@ -1,12 +1,12 @@ """Test cases for initialization.""" -import pytest_cases - from probnum.diffeq.odefilter import init_routines from probnum.problems.zoo import diffeq as diffeq_zoo from . import known_initial_derivatives +import pytest_cases + try: from jax.config import config # speed... 
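Note: a pattern recurring throughout these test diffs is the replacement of a shared np.random.Generator fixture by functional rng_state handling. A minimal sketch of the idiom (the probnum.backend.random signatures are inferred from the hunks in this patch and should be treated as assumptions, not documented API):

from probnum import backend

# Derive one deterministic root state per test from a base seed ...
rng_state = backend.random.rng_state(42)

# ... split it into independent child states, one per random draw ...
rng_state_mean, rng_state_cov = backend.random.split(rng_state, num=2)

# ... and consume each child state exactly once, never reusing it.
mean = backend.random.standard_normal(rng_state_mean, shape=(5,))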
diff --git a/tests/test_diffeq/test_odefilter/test_odefilter.py b/tests/test_diffeq/test_odefilter/test_odefilter.py index 6d99e2ddd..b8f62f996 100644 --- a/tests/test_diffeq/test_odefilter/test_odefilter.py +++ b/tests/test_diffeq/test_odefilter/test_odefilter.py @@ -1,11 +1,11 @@ """Tests for ODE filters.""" +from probnum import diffeq, randprocs + import pytest import pytest_cases -from probnum import diffeq, randprocs - try: import jax as _ diff --git a/tests/test_diffeq/test_odefilter/test_odefilter_cases.py b/tests/test_diffeq/test_odefilter/test_odefilter_cases.py index f73d3014c..62e2cf1b8 100644 --- a/tests/test_diffeq/test_odefilter/test_odefilter_cases.py +++ b/tests/test_diffeq/test_odefilter/test_odefilter_cases.py @@ -1,11 +1,11 @@ """Test-cases for ODE filters.""" -import pytest_cases - from probnum import diffeq, randprocs import probnum.problems.zoo.diffeq as diffeq_zoo +import pytest_cases + # logistic.rhs is implemented backend-agnostic, # thus it works for both numpy and jax diff --git a/tests/test_diffeq/test_odefilter/test_odefilter_solution.py b/tests/test_diffeq/test_odefilter/test_odefilter_solution.py index 4c7f7cab9..dce0f6fd9 100644 --- a/tests/test_diffeq/test_odefilter/test_odefilter_solution.py +++ b/tests/test_diffeq/test_odefilter/test_odefilter_solution.py @@ -1,9 +1,10 @@ import numpy as np -import pytest from probnum import diffeq, randvars import probnum.problems.zoo.diffeq as diffeq_zoo +import pytest + @pytest.fixture def rng(): diff --git a/tests/test_diffeq/test_odefilter/test_odefilter_special.py b/tests/test_diffeq/test_odefilter/test_odefilter_special.py index e2c3a728e..8b3df2092 100644 --- a/tests/test_diffeq/test_odefilter/test_odefilter_special.py +++ b/tests/test_diffeq/test_odefilter/test_odefilter_special.py @@ -4,11 +4,12 @@ but the implementation. Therefore this test module is named w.r.t. ivpfiltsmooth.py. """ import numpy as np -import pytest from probnum import diffeq, randprocs import probnum.problems.zoo.diffeq as diffeq_zoo +import pytest + @pytest.fixture(name="ivp") def fixture_ivp(): diff --git a/tests/test_diffeq/test_odefilter/test_utils/test_problem_utils.py b/tests/test_diffeq/test_odefilter/test_utils/test_problem_utils.py index 317f1e22c..c7d6dcb97 100644 --- a/tests/test_diffeq/test_odefilter/test_utils/test_problem_utils.py +++ b/tests/test_diffeq/test_odefilter/test_utils/test_problem_utils.py @@ -2,11 +2,12 @@ import numpy as np -import pytest from probnum import diffeq, filtsmooth, problems, randprocs, randvars from probnum.problems.zoo import diffeq as diffeq_zoo +import pytest + @pytest.fixture def locations(): @@ -87,8 +88,8 @@ def test_ivp_to_regression_problem( # the process noise covariance matrices should be non-zero. if ode_measurement_variance > 0.0: noise = regprob.measurement_models[1].noise_fun(locations[0]) - assert np.linalg.norm(noise.cov > 0.0) - assert np.linalg.norm(noise.cov_cholesky > 0.0) + assert np.linalg.norm(noise.cov) > 0.0 + assert np.linalg.norm(noise._cov_cholesky) > 0.0 # If an approximation strategy is passed, the output should be an EKF component # which should support forward_rv().
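Note: the last hunk above also repairs misplaced parentheses: the old assertions took the norm of a boolean mask instead of comparing the norm itself to zero. In plain NumPy, with an illustrative matrix:

import numpy as np

cov = np.array([[2.0, -1.0], [-1.0, 2.0]])  # illustrative covariance

np.linalg.norm(cov > 0.0)         # norm of a boolean mask, not the intended check
assert np.linalg.norm(cov) > 0.0  # fixed form: compare the matrix norm to zero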
diff --git a/tests/test_diffeq/test_perturbed/test_scipy_wrapper/test_wrapped_scipy_cases.py b/tests/test_diffeq/test_perturbed/test_scipy_wrapper/test_wrapped_scipy_cases.py index 5050dbc95..d2439681c 100644 --- a/tests/test_diffeq/test_perturbed/test_scipy_wrapper/test_wrapped_scipy_cases.py +++ b/tests/test_diffeq/test_perturbed/test_scipy_wrapper/test_wrapped_scipy_cases.py @@ -1,10 +1,11 @@ import numpy as np -import pytest from scipy.integrate._ivp import rk from probnum import diffeq import probnum.problems.zoo.diffeq as diffeq_zoo +import pytest + _ADAPTIVE_STEPS = diffeq.stepsize.AdaptiveSteps(atol=1e-4, rtol=1e-4, firststep=0.1) _CONSTANT_STEPS = diffeq.stepsize.ConstantSteps(0.1) diff --git a/tests/test_diffeq/test_perturbed/test_scipy_wrapper/test_wrapped_scipy_odesolution.py b/tests/test_diffeq/test_perturbed/test_scipy_wrapper/test_wrapped_scipy_odesolution.py index f103d86c4..571650211 100644 --- a/tests/test_diffeq/test_perturbed/test_scipy_wrapper/test_wrapped_scipy_odesolution.py +++ b/tests/test_diffeq/test_perturbed/test_scipy_wrapper/test_wrapped_scipy_odesolution.py @@ -1,8 +1,9 @@ import numpy as np -import pytest_cases from probnum import randvars +import pytest_cases + @pytest_cases.fixture @pytest_cases.parametrize_with_cases( diff --git a/tests/test_diffeq/test_perturbed/test_scipy_wrapper/test_wrapped_scipy_solver.py b/tests/test_diffeq/test_perturbed/test_scipy_wrapper/test_wrapped_scipy_solver.py index 1933ee2d7..70e7cb8c2 100644 --- a/tests/test_diffeq/test_perturbed/test_scipy_wrapper/test_wrapped_scipy_solver.py +++ b/tests/test_diffeq/test_perturbed/test_scipy_wrapper/test_wrapped_scipy_solver.py @@ -1,12 +1,13 @@ import numpy as np -import pytest -import pytest_cases from scipy.integrate._ivp import base, rk from scipy.integrate._ivp.common import OdeSolution from probnum import diffeq, randvars import probnum.problems.zoo.diffeq as diffeq_zoo +import pytest +import pytest_cases + @pytest_cases.fixture @pytest_cases.parametrize_with_cases( diff --git a/tests/test_diffeq/test_perturbed/test_step/test_perturbation_functions.py b/tests/test_diffeq/test_perturbed/test_step/test_perturbation_functions.py index a20c8da1a..58aee7d9f 100644 --- a/tests/test_diffeq/test_perturbed/test_step/test_perturbation_functions.py +++ b/tests/test_diffeq/test_perturbed/test_step/test_perturbation_functions.py @@ -1,8 +1,9 @@ import numpy as np -import pytest from probnum import diffeq +import pytest + @pytest.fixture def rng(): diff --git a/tests/test_diffeq/test_perturbed/test_step/test_perturbed_cases.py b/tests/test_diffeq/test_perturbed/test_step/test_perturbed_cases.py index 94133fcd3..579ab99c4 100644 --- a/tests/test_diffeq/test_perturbed/test_step/test_perturbed_cases.py +++ b/tests/test_diffeq/test_perturbed/test_step/test_perturbed_cases.py @@ -1,10 +1,11 @@ import numpy as np -import pytest from scipy.integrate._ivp import rk from probnum import diffeq import probnum.problems.zoo.diffeq as diffeq_zoo +import pytest + _ADAPTIVE_STEPS = diffeq.stepsize.AdaptiveSteps(atol=1e-4, rtol=1e-4, firststep=0.1) _CONSTANT_STEPS = diffeq.stepsize.ConstantSteps(0.1) diff --git a/tests/test_diffeq/test_perturbed/test_step/test_perturbedstepsolution.py b/tests/test_diffeq/test_perturbed/test_step/test_perturbedstepsolution.py index fdc1207c0..ec5530c9c 100644 --- a/tests/test_diffeq/test_perturbed/test_step/test_perturbedstepsolution.py +++ b/tests/test_diffeq/test_perturbed/test_step/test_perturbedstepsolution.py @@ -1,10 +1,11 @@ import numpy as np -import pytest from 
scipy.integrate._ivp import rk from probnum import diffeq, randvars import probnum.problems.zoo.diffeq as diffeq_zoo +import pytest + @pytest.fixture def steprule(): diff --git a/tests/test_diffeq/test_perturbed/test_step/test_perturbedstepsolver.py b/tests/test_diffeq/test_perturbed/test_step/test_perturbedstepsolver.py index e1522fde3..83c591020 100644 --- a/tests/test_diffeq/test_perturbed/test_step/test_perturbedstepsolver.py +++ b/tests/test_diffeq/test_perturbed/test_step/test_perturbedstepsolver.py @@ -1,10 +1,11 @@ import numpy as np -import pytest -import pytest_cases from scipy.integrate._ivp import base from probnum import diffeq, randvars +import pytest +import pytest_cases + @pytest_cases.fixture @pytest_cases.parametrize_with_cases( diff --git a/tests/test_diffeq/test_perturbsolve_ivp.py b/tests/test_diffeq/test_perturbsolve_ivp.py index e04c0cc10..ceb523f24 100644 --- a/tests/test_diffeq/test_perturbsolve_ivp.py +++ b/tests/test_diffeq/test_perturbsolve_ivp.py @@ -1,9 +1,10 @@ import numpy as np -import pytest from probnum import diffeq import probnum.problems.zoo.diffeq as diffeq_zoo +import pytest + @pytest.fixture def rng(): diff --git a/tests/test_diffeq/test_probsolve_ivp.py b/tests/test_diffeq/test_probsolve_ivp.py index 843ca24a7..1f6862674 100644 --- a/tests/test_diffeq/test_probsolve_ivp.py +++ b/tests/test_diffeq/test_probsolve_ivp.py @@ -1,10 +1,11 @@ import numpy as np -import pytest from probnum.diffeq import probsolve_ivp from probnum.diffeq.odefilter import ODEFilterSolution import probnum.problems.zoo.diffeq as diffeq_zoo +import pytest + @pytest.fixture def ivp(): diff --git a/tests/test_filtsmooth/conftest.py b/tests/test_filtsmooth/conftest.py index 0a8b18231..cff05e952 100644 --- a/tests/test_filtsmooth/conftest.py +++ b/tests/test_filtsmooth/conftest.py @@ -2,6 +2,7 @@ import numpy as np + import pytest diff --git a/tests/test_filtsmooth/test_gaussian/test_approx/_linearization_test_interface.py b/tests/test_filtsmooth/test_gaussian/test_approx/_linearization_test_interface.py index 6f4cedf04..f96473698 100644 --- a/tests/test_filtsmooth/test_gaussian/test_approx/_linearization_test_interface.py +++ b/tests/test_filtsmooth/test_gaussian/test_approx/_linearization_test_interface.py @@ -1,11 +1,12 @@ """Test interface for EKF and UKF.""" import numpy as np -import pytest from probnum import filtsmooth, problems, randprocs, randvars import probnum.problems.zoo.filtsmooth as filtsmooth_zoo +import pytest + class InterfaceDiscreteLinearizationTest: """Test approximate Gaussian filtering and smoothing. 
diff --git a/tests/test_filtsmooth/test_gaussian/test_approx/test_extendedkalman.py b/tests/test_filtsmooth/test_gaussian/test_approx/test_extendedkalman.py index e8ef5ef66..ab55082f2 100644 --- a/tests/test_filtsmooth/test_gaussian/test_approx/test_extendedkalman.py +++ b/tests/test_filtsmooth/test_gaussian/test_approx/test_extendedkalman.py @@ -1,7 +1,5 @@ """Tests for extended Kalman filtering.""" -import pytest - from probnum import filtsmooth from ._linearization_test_interface import ( @@ -9,6 +7,8 @@ InterfaceDiscreteLinearizationTest, ) +import pytest + class TestDiscreteEKFComponent(InterfaceDiscreteLinearizationTest): diff --git a/tests/test_filtsmooth/test_gaussian/test_approx/test_unscentedkalman.py b/tests/test_filtsmooth/test_gaussian/test_approx/test_unscentedkalman.py index 366fff5be..b893fb405 100644 --- a/tests/test_filtsmooth/test_gaussian/test_approx/test_unscentedkalman.py +++ b/tests/test_filtsmooth/test_gaussian/test_approx/test_unscentedkalman.py @@ -1,11 +1,11 @@ """Tests for unscented Kalman filtering.""" -import pytest - from probnum import filtsmooth from ._linearization_test_interface import InterfaceDiscreteLinearizationTest +import pytest + class TestDiscreteUKFComponent(InterfaceDiscreteLinearizationTest): diff --git a/tests/test_filtsmooth/test_gaussian/test_kalman.py b/tests/test_filtsmooth/test_gaussian/test_kalman.py index e098fffad..ff60d7bbf 100644 --- a/tests/test_filtsmooth/test_gaussian/test_kalman.py +++ b/tests/test_filtsmooth/test_gaussian/test_kalman.py @@ -1,9 +1,10 @@ import numpy as np -import pytest from probnum import filtsmooth import probnum.problems.zoo.filtsmooth as filtsmooth_zoo +import pytest + # Problems diff --git a/tests/test_filtsmooth/test_gaussian/test_kalmanposterior.py b/tests/test_filtsmooth/test_gaussian/test_kalmanposterior.py index 5131ade65..7233300e9 100644 --- a/tests/test_filtsmooth/test_gaussian/test_kalmanposterior.py +++ b/tests/test_filtsmooth/test_gaussian/test_kalmanposterior.py @@ -1,9 +1,10 @@ import numpy as np -import pytest -from probnum import filtsmooth, problems, randprocs, randvars, utils +from probnum import backend, filtsmooth, problems, randprocs, randvars import probnum.problems.zoo.filtsmooth as filtsmooth_zoo +import pytest + @pytest.fixture(name="problem") def fixture_problem(rng): @@ -195,7 +196,7 @@ def test_sampling_shapes_1d(locs, size): ) posterior, _ = kalman.filtsmooth(regression_problem) - size = utils.as_shape(size) + size = backend.asshape(size) if locs is None: base_measure_reals = np.random.randn(*(size + posterior.locations.shape + (1,))) samples = posterior.transform_base_measure_realizations( diff --git a/tests/test_filtsmooth/test_kalman_filter_smoother.py b/tests/test_filtsmooth/test_kalman_filter_smoother.py index 1f490f261..e145bae3d 100644 --- a/tests/test_filtsmooth/test_kalman_filter_smoother.py +++ b/tests/test_filtsmooth/test_kalman_filter_smoother.py @@ -1,10 +1,11 @@ """Test for the convenience functions.""" import numpy as np -import pytest from probnum import filtsmooth +import pytest + @pytest.fixture(name="prior_dimension") def fixture_prior_dimension(): diff --git a/tests/test_filtsmooth/test_optim/test_gauss_newton.py b/tests/test_filtsmooth/test_optim/test_gauss_newton.py index c266d5cac..1c4146b1b 100644 --- a/tests/test_filtsmooth/test_optim/test_gauss_newton.py +++ b/tests/test_filtsmooth/test_optim/test_gauss_newton.py @@ -1,9 +1,10 @@ import numpy as np -import pytest from probnum import filtsmooth import probnum.problems.zoo.filtsmooth as filtsmooth_zoo 
+import pytest + @pytest.fixture(name="setup", params=[filtsmooth_zoo.logistic_ode]) def fixture_setup(request): diff --git a/tests/test_filtsmooth/test_optim/test_stoppingcriterion.py b/tests/test_filtsmooth/test_optim/test_stoppingcriterion.py index a86cd454b..515d53766 100644 --- a/tests/test_filtsmooth/test_optim/test_stoppingcriterion.py +++ b/tests/test_filtsmooth/test_optim/test_stoppingcriterion.py @@ -2,10 +2,11 @@ import numpy as np -import pytest from probnum.filtsmooth.optim import _stopping_criterion +import pytest + @pytest.fixture(name="d1") def fixture_d1(): diff --git a/tests/test_filtsmooth/test_particle/test_particle_filter.py b/tests/test_filtsmooth/test_particle/test_particle_filter.py index f89315a4f..e99d38f80 100644 --- a/tests/test_filtsmooth/test_particle/test_particle_filter.py +++ b/tests/test_filtsmooth/test_particle/test_particle_filter.py @@ -1,9 +1,10 @@ import numpy as np -import pytest from probnum import filtsmooth, randvars import probnum.problems.zoo.filtsmooth as filtsmooth_zoo +import pytest + def test_effective_number_of_events(): weights = np.random.rand(10) diff --git a/tests/test_filtsmooth/test_particle/test_particle_filter_posterior.py b/tests/test_filtsmooth/test_particle/test_particle_filter_posterior.py index 51415a3c8..e4d8a1c2f 100644 --- a/tests/test_filtsmooth/test_particle/test_particle_filter_posterior.py +++ b/tests/test_filtsmooth/test_particle/test_particle_filter_posterior.py @@ -1,8 +1,9 @@ import numpy as np -import pytest from probnum import filtsmooth, randvars +import pytest + @pytest.fixture(name="state_list") def fixture_state_list(): diff --git a/tests/test_filtsmooth/test_utils.py b/tests/test_filtsmooth/test_utils.py index b4d090054..99adcae54 100644 --- a/tests/test_filtsmooth/test_utils.py +++ b/tests/test_filtsmooth/test_utils.py @@ -1,11 +1,12 @@ import functools import numpy as np -import pytest from probnum import filtsmooth, problems import probnum.problems.zoo.filtsmooth as filtsmooth_zoo +import pytest + @pytest.fixture(name="car_tracking1") def fixture_car_tracking1(rng): diff --git a/tests/test_functions/conftest.py b/tests/test_functions/conftest.py deleted file mode 100644 index ff86e4bf2..000000000 --- a/tests/test_functions/conftest.py +++ /dev/null @@ -1,6 +0,0 @@ -import pytest - - -@pytest.fixture(scope="module") -def seed() -> int: - return 234 diff --git a/tests/test_linalg/cases/linear_systems.py b/tests/test_linalg/cases/linear_systems.py index 4554ae717..76a7f5adb 100644 --- a/tests/test_linalg/cases/linear_systems.py +++ b/tests/test_linalg/cases/linear_systems.py @@ -3,22 +3,23 @@ from typing import Union import numpy as np -import pytest_cases import scipy.sparse -from probnum import linops, problems +from probnum import backend, linops, problems from probnum.problems.zoo.linalg import random_linear_system +import pytest_cases + cases_matrices = ".matrices" @pytest_cases.parametrize_with_cases("matrix", cases=cases_matrices, scope="module") def case_linsys( matrix: Union[np.ndarray, scipy.sparse.spmatrix, linops.LinearOperator], - rng: np.random.Generator, ) -> problems.LinearSystem: """Linear system.""" - return random_linear_system(rng=rng, matrix=matrix) + rng_state = backend.random.rng_state(abs(hash(matrix.shape))) + return random_linear_system(rng_state=rng_state, matrix=matrix) @pytest_cases.parametrize_with_cases( @@ -29,7 +30,7 @@ def case_linsys( ) def case_spd_linsys( spd_matrix: Union[np.ndarray, scipy.sparse.spmatrix, linops.LinearOperator], - rng: np.random.Generator, ) ->
problems.LinearSystem: """Linear system with symmetric positive definite matrix.""" - return random_linear_system(rng=rng, matrix=spd_matrix) + rng_state = backend.random.rng_state(abs(hash(spd_matrix.shape))) + return random_linear_system(rng_state=rng_state, matrix=spd_matrix) diff --git a/tests/test_linalg/cases/matrices.py b/tests/test_linalg/cases/matrices.py index fe97c5e67..d790a537f 100644 --- a/tests/test_linalg/cases/matrices.py +++ b/tests/test_linalg/cases/matrices.py @@ -2,41 +2,49 @@ import os -import numpy as np -from pytest_cases import case, parametrize import scipy -from probnum import linops +from probnum import backend, linops from probnum.problems.zoo.linalg import random_sparse_spd_matrix, random_spd_matrix from probnum.randprocs import kernels +from pytest_cases import case, parametrize +from tests.utils.random import rng_state_from_sampling_args + m_rows = [1, 2, 10, 100] n_cols = [1, 2, 10, 100] @case(tags=["symmetric", "positive_definite"]) @parametrize("n", n_cols) -def case_random_spd_matrix(n: int, rng: np.random.Generator) -> np.ndarray: - return random_spd_matrix(dim=n, rng=rng) +def case_random_spd_matrix(n: int) -> backend.Array: + rng_state = rng_state_from_sampling_args(n) + return random_spd_matrix(rng_state=rng_state, shape=(n, n)) @case(tags=["symmetric", "positive_definite"]) -def case_random_sparse_spd_matrix(rng: np.random.Generator) -> scipy.sparse.spmatrix: - return random_sparse_spd_matrix(dim=1000, density=0.01, rng=rng) +def case_random_sparse_spd_matrix() -> scipy.sparse.spmatrix: + rng_state = backend.random.rng_state(1) + return random_sparse_spd_matrix( + rng_state=rng_state, shape=(1000, 1000), density=0.01 + ) @case(tags=["symmetric", "positive_definite"]) @parametrize("n", n_cols) -def case_kernel_matrix(n: int, rng: np.random.Generator) -> np.ndarray: +def case_kernel_matrix(n: int) -> backend.Array: """Kernel Gram matrix.""" + rng_state = rng_state_from_sampling_args(n) x_min, x_max = (-4.0, 4.0) - X = rng.uniform(x_min, x_max, (n, 1)) - kern = kernels.ExpQuad(input_shape=1, lengthscale=1) + X = backend.random.uniform( + rng_state=rng_state, minval=x_min, maxval=x_max, shape=(n, 1) + ) + kern = kernels.ExpQuad(input_shape=1, lengthscale=1.0) return kern(X) @case(tags=["symmetric", "positive_definite"]) -def case_poisson() -> np.ndarray: +def case_poisson() -> backend.Array: """Poisson equation with Dirichlet conditions.
- Laplace(u) = f in the interior @@ -53,4 +61,4 @@ def case_poisson() -> np.ndarray: @case(tags=["symmetric", "positive_definite"]) def case_scaling_linop() -> linops.Scaling: - return linops.Scaling(np.arange(10)) + return linops.Scaling(backend.arange(10)) diff --git a/tests/test_linalg/conftest.py b/tests/test_linalg/conftest.py index e6bb1f0ad..e91a3c335 100644 --- a/tests/test_linalg/conftest.py +++ b/tests/test_linalg/conftest.py @@ -1,10 +1 @@ """Test fixtures for linear algebra.""" - -import numpy as np -import pytest_cases - - -@pytest_cases.fixture() -def rng() -> np.random.Generator: - """Random number generator.""" - return np.random.default_rng(42) diff --git a/tests/test_linalg/test_problinsolve.py b/tests/test_linalg/test_problinsolve.py index 7f95d9292..f37bdd41d 100644 --- a/tests/test_linalg/test_problinsolve.py +++ b/tests/test_linalg/test_problinsolve.py @@ -7,6 +7,7 @@ import scipy.sparse.linalg from probnum import linalg, linops, randvars + from tests.testing import NumpyAssertions diff --git a/tests/test_linalg/test_solvers/cases/belief_updates.py b/tests/test_linalg/test_solvers/cases/belief_updates.py index 7629ad742..f47185d23 100644 --- a/tests/test_linalg/test_solvers/cases/belief_updates.py +++ b/tests/test_linalg/test_solvers/cases/belief_updates.py @@ -1,9 +1,9 @@ """Test cases describing different belief updates over quantities of interest of a linear system.""" -from pytest_cases import parametrize - from probnum.linalg.solvers.belief_updates import matrix_based, solution_based +from pytest_cases import parametrize + @parametrize(noise_var=[0.0, 0.001, 1.0]) def case_solution_based_projected_residual_belief_update(noise_var: float): diff --git a/tests/test_linalg/test_solvers/cases/beliefs.py b/tests/test_linalg/test_solvers/cases/beliefs.py index 3913d9966..89356bf16 100644 --- a/tests/test_linalg/test_solvers/cases/beliefs.py +++ b/tests/test_linalg/test_solvers/cases/beliefs.py @@ -2,11 +2,12 @@ system.""" import numpy as np -from pytest_cases import case from probnum import linops, randvars from probnum.linalg.solvers import beliefs +from pytest_cases import case + @case(tags=["sym", "posdef", "square"]) def case_trivial_sym_prior(ncols: int) -> beliefs.LinearSystemBelief: diff --git a/tests/test_linalg/test_solvers/cases/policies.py b/tests/test_linalg/test_solvers/cases/policies.py index 9a942ea62..77bd13c1a 100644 --- a/tests/test_linalg/test_solvers/cases/policies.py +++ b/tests/test_linalg/test_solvers/cases/policies.py @@ -1,8 +1,8 @@ """Test cases defined by policies.""" -from pytest_cases import case - +from probnum.backend.linalg import gram_schmidt_double, gram_schmidt_modified from probnum.linalg.solvers import policies -from probnum.utils.linalg import double_gram_schmidt, modified_gram_schmidt + +from pytest_cases import case def case_conjugate_gradient(): @@ -11,13 +11,13 @@ def case_conjugate_gradient(): def case_conjugate_gradient_reorthogonalized_residuals(): return policies.ConjugateGradientPolicy( - reorthogonalization_fn_residual=double_gram_schmidt + reorthogonalization_fn_residual=gram_schmidt_double ) def case_conjugate_gradient_reorthogonalized_actions(): return policies.ConjugateGradientPolicy( - reorthogonalization_fn_action=modified_gram_schmidt + reorthogonalization_fn_action=gram_schmidt_modified ) diff --git a/tests/test_linalg/test_solvers/cases/problems.py b/tests/test_linalg/test_solvers/cases/problems.py index 9b964c814..107fcd622 100644 --- a/tests/test_linalg/test_solvers/cases/problems.py +++ 
b/tests/test_linalg/test_solvers/cases/problems.py @@ -1,19 +1,19 @@ """Test cases defining linear systems to be solved.""" -import numpy as np -from pytest_cases import case - -from probnum import problems +from probnum import backend, problems from probnum.problems.zoo.linalg import random_sparse_spd_matrix, random_spd_matrix +from pytest_cases import case + @case(tags=["sym", "posdef"]) def case_random_spd_linsys( ncols: int, ) -> problems.LinearSystem: - rng = np.random.default_rng(1) - A = random_spd_matrix(rng=rng, dim=ncols) - x = rng.normal(size=(ncols,)) + rng_state = backend.random.rng_state(1) + rng_state_A, rng_state_x = backend.random.split(rng_state) + A = random_spd_matrix(rng_state=rng_state_A, shape=(ncols, ncols)) + x = backend.random.standard_normal(rng_state=rng_state_x, shape=(ncols,)) b = A @ x return problems.LinearSystem(A=A, b=b, solution=x) @@ -22,8 +22,11 @@ def case_random_spd_linsys( def case_random_sparse_spd_linsys( ncols: int, ) -> problems.LinearSystem: - rng = np.random.default_rng(1) - A = random_sparse_spd_matrix(rng=rng, dim=ncols, density=0.1) - x = rng.normal(size=(ncols,)) + rng_state = backend.random.rng_state(1) + rng_state_A, rng_state_x = backend.random.split(rng_state) + A = random_sparse_spd_matrix( + rng_state=rng_state_A, shape=(ncols, ncols), density=0.1 + ) + x = backend.random.standard_normal(rng_state=rng_state_x, shape=(ncols,)) b = A @ x return problems.LinearSystem(A=A, b=b, solution=x) diff --git a/tests/test_linalg/test_solvers/cases/solvers.py b/tests/test_linalg/test_solvers/cases/solvers.py index 20b63447b..3a5b37c9d 100644 --- a/tests/test_linalg/test_solvers/cases/solvers.py +++ b/tests/test_linalg/test_solvers/cases/solvers.py @@ -1,9 +1,9 @@ """Test cases defining probabilistic linear solvers.""" -from pytest_cases import case - from probnum.linalg import solvers +from pytest_cases import case + @case(tags=["solutionbased", "sym"]) def case_bayescg(): diff --git a/tests/test_linalg/test_solvers/cases/states.py b/tests/test_linalg/test_solvers/cases/states.py index 589b1b30c..f7b29ac4f 100644 --- a/tests/test_linalg/test_solvers/cases/states.py +++ b/tests/test_linalg/test_solvers/cases/states.py @@ -1,22 +1,23 @@ """Probabilistic linear solver state test cases.""" import numpy as np -from pytest_cases import case -from probnum import linalg, linops, randvars +from probnum import backend, linalg, linops, randvars from probnum.problems.zoo.linalg import random_linear_system, random_spd_matrix +from pytest_cases import case + # Problem n = 10 linsys = random_linear_system( - rng=np.random.default_rng(42), matrix=random_spd_matrix, dim=n + backend.random.rng_state(42), matrix=random_spd_matrix, shape=(n, n) ) # Prior Ainv = randvars.Normal( mean=linops.Identity(n), cov=linops.SymmetricKronecker(linops.Identity(n)) ) -b = randvars.Constant(linsys.b) +b = randvars.Constant(backend.to_numpy(linsys.b)) prior = linalg.solvers.beliefs.LinearSystemBelief( A=randvars.Constant(linsys.A), Ainv=Ainv, @@ -32,21 +33,21 @@ def case_initial_state(): @case(tags=["has_action"]) -def case_state( - rng: np.random.Generator, -): +def case_state(): """State of a linear solver.""" + rng_state = backend.random.rng_state(35792) state = linalg.solvers.LinearSolverState(problem=linsys, prior=prior) - state.action = rng.standard_normal(size=state.problem.A.shape[1]) + state.action = backend.random.standard_normal( + rng_state=rng_state, shape=state.problem.A.shape[1] + ) return state @case(tags=["has_action", "has_observation", "matrix_based"]) -def 
case_state_matrix_based(
-    rng: np.random.Generator,
-):
+def case_state_matrix_based():
     """State of a matrix-based linear solver."""
+    rng_state = backend.random.rng_state(9876534)
     prior = linalg.solvers.beliefs.LinearSystemBelief(
         A=randvars.Normal(
             mean=linops.Matrix(linsys.A),
@@ -60,17 +61,21 @@ def case_state_matrix_based(
         b=b,
     )
     state = linalg.solvers.LinearSolverState(problem=linsys, prior=prior)
-    state.action = rng.standard_normal(size=state.problem.A.shape[1])
-    state.observation = rng.standard_normal(size=state.problem.A.shape[1])
+    rng_state_action, rng_state_observation = backend.random.split(rng_state)
+    state.action = backend.random.standard_normal(
+        rng_state=rng_state_action, shape=state.problem.A.shape[1]
+    )
+    state.observation = backend.random.standard_normal(
+        rng_state=rng_state_observation, shape=state.problem.A.shape[1]
+    )
     return state


 @case(tags=["has_action", "has_observation", "symmetric_matrix_based"])
-def case_state_symmetric_matrix_based(
-    rng: np.random.Generator,
-):
+def case_state_symmetric_matrix_based():
     """State of a symmetric matrix-based linear solver."""
+    rng_state = backend.random.rng_state(93456)
     prior = linalg.solvers.beliefs.LinearSystemBelief(
         A=randvars.Normal(
             mean=linops.Matrix(linsys.A),
@@ -84,27 +89,33 @@ def case_state_symmetric_matrix_based(
         b=b,
     )
     state = linalg.solvers.LinearSolverState(problem=linsys, prior=prior)
-    state.action = rng.standard_normal(size=state.problem.A.shape[1])
-    state.observation = rng.standard_normal(size=state.problem.A.shape[1])
+    rng_state_action, rng_state_observation = backend.random.split(rng_state)
+    state.action = backend.random.standard_normal(
+        rng_state=rng_state_action, shape=state.problem.A.shape[1]
+    )
+    state.observation = backend.random.standard_normal(
+        rng_state=rng_state_observation, shape=state.problem.A.shape[1]
+    )
     return state


 @case(tags=["has_action", "has_observation", "solution_based"])
-def case_state_solution_based(
-    rng: np.random.Generator,
-):
+def case_state_solution_based():
     """State of a solution-based linear solver."""
+    rng_state = backend.random.rng_state(4832)
+
     initial_state = linalg.solvers.LinearSolverState(problem=linsys, prior=prior)
-    initial_state.action = rng.standard_normal(size=initial_state.problem.A.shape[1])
-    initial_state.observation = rng.standard_normal()
+    rng_state_action, rng_state_observation = backend.random.split(rng_state)
+    initial_state.action = backend.random.standard_normal(
+        rng_state=rng_state_action, shape=initial_state.problem.A.shape[1]
+    )
+    initial_state.observation = backend.random.standard_normal(rng_state=rng_state_observation)

     return initial_state


-def case_state_converged(
-    rng: np.random.Generator,
-):
+def case_state_converged():
     """State of a linear solver, which has converged at initialization."""
     belief = linalg.solvers.beliefs.LinearSystemBelief(
         A=randvars.Constant(linsys.A),
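The cases above follow the splittable-RNG convention used throughout this diff: `backend.random` treats RNG states functionally, so two `standard_normal` draws that share an `rng_state` and a shape would be element-wise identical. Each independent draw therefore gets its own state via `backend.random.split`. A minimal sketch of the convention, assuming only the `backend.random` API exercised in this diff:

    from probnum import backend

    rng_state = backend.random.rng_state(35792)

    # One split per independent draw: reusing a single RNG state for both the
    # action and the observation would make the two samples identical.
    rng_state_action, rng_state_observation = backend.random.split(rng_state)

    action = backend.random.standard_normal(rng_state_action, shape=(10,))
    observation = backend.random.standard_normal(rng_state_observation, shape=(10,))

diff --git a/tests/test_linalg/test_solvers/cases/stopping_criteria.py b/tests/test_linalg/test_solvers/cases/stopping_criteria.py
index 826255d70..a5927e8e4 100644
--- a/tests/test_linalg/test_solvers/cases/stopping_criteria.py
+++ b/tests/test_linalg/test_solvers/cases/stopping_criteria.py
@@ -1,9 +1,9 @@
 """Stopping criteria test cases."""

-from pytest_cases import parametrize
-
 from probnum.linalg.solvers import stopping_criteria

+from pytest_cases import parametrize
+

 def case_maxiter():
     return stopping_criteria.MaxIterationsStoppingCriterion()
diff --git a/tests/test_linalg/test_solvers/test_belief_updates/test_matrix_based/test_matrix_based_linear_belief_update.py b/tests/test_linalg/test_solvers/test_belief_updates/test_matrix_based/test_matrix_based_linear_belief_update.py
index 28c2bec01..18bfd1ed6 100644
---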
a/tests/test_linalg/test_solvers/test_belief_updates/test_matrix_based/test_matrix_based_linear_belief_update.py +++ b/tests/test_linalg/test_solvers/test_belief_updates/test_matrix_based/test_matrix_based_linear_belief_update.py @@ -3,12 +3,13 @@ import pathlib import numpy as np -import pytest -from pytest_cases import parametrize_with_cases from probnum import linops, randvars from probnum.linalg.solvers import LinearSolverState, belief_updates, beliefs +import pytest +from pytest_cases import parametrize_with_cases + case_modules = (pathlib.Path(__file__).parent.parent / "cases").stem cases_belief_updates = case_modules + ".belief_updates" cases_states = case_modules + ".states" diff --git a/tests/test_linalg/test_solvers/test_belief_updates/test_matrix_based/test_symmetric_matrix_based_linear_belief_update.py b/tests/test_linalg/test_solvers/test_belief_updates/test_matrix_based/test_symmetric_matrix_based_linear_belief_update.py index de45f883d..312e7013e 100644 --- a/tests/test_linalg/test_solvers/test_belief_updates/test_matrix_based/test_symmetric_matrix_based_linear_belief_update.py +++ b/tests/test_linalg/test_solvers/test_belief_updates/test_matrix_based/test_symmetric_matrix_based_linear_belief_update.py @@ -3,12 +3,13 @@ import pathlib import numpy as np -import pytest -from pytest_cases import parametrize_with_cases from probnum import linops, randvars from probnum.linalg.solvers import LinearSolverState, belief_updates, beliefs +import pytest +from pytest_cases import parametrize_with_cases + case_modules = (pathlib.Path(__file__).parent.parent / "cases").stem cases_belief_updates = case_modules + ".belief_updates" cases_states = case_modules + ".states" diff --git a/tests/test_linalg/test_solvers/test_belief_updates/test_solution_based/test_projected_residual_belief_update.py b/tests/test_linalg/test_solvers/test_belief_updates/test_solution_based/test_projected_residual_belief_update.py index 374dce6d0..1dc9f0781 100644 --- a/tests/test_linalg/test_solvers/test_belief_updates/test_solution_based/test_projected_residual_belief_update.py +++ b/tests/test_linalg/test_solvers/test_belief_updates/test_solution_based/test_projected_residual_belief_update.py @@ -4,12 +4,13 @@ import pathlib import numpy as np -import pytest -from pytest_cases import parametrize_with_cases from probnum import randvars from probnum.linalg.solvers import LinearSolverState, belief_updates, beliefs +import pytest +from pytest_cases import parametrize_with_cases + case_modules = (pathlib.Path(__file__).parent.parent / "cases").stem cases_belief_updates = case_modules + ".belief_updates" cases_states = case_modules + ".states" diff --git a/tests/test_linalg/test_solvers/test_beliefs/test_linear_system_belief.py b/tests/test_linalg/test_solvers/test_beliefs/test_linear_system_belief.py index 25ad93b3f..298c44993 100644 --- a/tests/test_linalg/test_solvers/test_beliefs/test_linear_system_belief.py +++ b/tests/test_linalg/test_solvers/test_beliefs/test_linear_system_belief.py @@ -1,11 +1,12 @@ """Tests for beliefs about quantities of interest of a linear system.""" import numpy as np -import pytest -from probnum import linops, randvars +from probnum import backend, linops, randvars from probnum.linalg.solvers.beliefs import LinearSystemBelief from probnum.problems.zoo.linalg import random_spd_matrix +import pytest + def test_init_invalid_belief(): """Test whether instantiating a belief over neither x nor Ainv raises an error.""" @@ -80,35 +81,39 @@ def test_non_two_dimensional_raises_value_error(): 
LinearSystemBelief(A=A, Ainv=Ainv, x=x, b=b[:, None]) -def test_non_randvar_arguments_raises_type_error(): - A = np.eye(5) - Ainv = np.eye(5) - x = np.ones((5, 1)) - b = np.ones((5, 1)) +# def test_non_randvar_arguments_raises_type_error(): +# A = np.eye(5) +# Ainv = np.eye(5) +# x = np.ones((5, 1)) +# b = np.ones((5, 1)) - with pytest.raises(TypeError): - LinearSystemBelief(x=x) +# with pytest.raises(TypeError): +# LinearSystemBelief(x=x) - with pytest.raises(TypeError): - LinearSystemBelief(Ainv=Ainv) +# with pytest.raises(TypeError): +# LinearSystemBelief(Ainv=Ainv) - with pytest.raises(TypeError): - LinearSystemBelief(x=randvars.Constant(x), A=A) +# with pytest.raises(TypeError): +# LinearSystemBelief(x=randvars.Constant(x), A=A) - with pytest.raises(TypeError): - LinearSystemBelief(x=randvars.Constant(x), b=b) +# with pytest.raises(TypeError): +# LinearSystemBelief(x=randvars.Constant(x), b=b) -def test_induced_solution_belief(rng: np.random.Generator): +def test_induced_solution_belief(): """Test whether a consistent belief over the solution is inferred from a belief over the inverse.""" + rng_state = backend.random.rng_state(8294) + rng_state_A, rng_state_b = backend.random.split(rng_state=rng_state) n = 5 - A = randvars.Constant(random_spd_matrix(dim=n, rng=rng)) + A = randvars.Constant(random_spd_matrix(rng_state=rng_state_A, shape=(n, n))) Ainv = randvars.Normal( mean=linops.Scaling(factors=1 / np.diag(A.mean)), cov=linops.SymmetricKronecker(linops.Identity(n)), ) - b = randvars.Constant(rng.normal(size=(n, 1))) + b = randvars.Constant( + backend.random.standard_normal(rng_state=rng_state_b, shape=(n, 1)) + ) prior = LinearSystemBelief(A=A, Ainv=Ainv, x=None, b=b) x_infer = Ainv @ b diff --git a/tests/test_linalg/test_solvers/test_information_ops/test_linear_solver_info_op.py b/tests/test_linalg/test_solvers/test_information_ops/test_linear_solver_info_op.py index 9577ea7ba..655c5cb45 100644 --- a/tests/test_linalg/test_solvers/test_information_ops/test_linear_solver_info_op.py +++ b/tests/test_linalg/test_solvers/test_information_ops/test_linear_solver_info_op.py @@ -3,10 +3,11 @@ import pathlib import numpy as np -from pytest_cases import parametrize_with_cases from probnum.linalg.solvers import LinearSolverState, information_ops +from pytest_cases import parametrize_with_cases + case_modules = (pathlib.Path(__file__).parent / "cases").stem cases_information_ops = case_modules + ".information_ops" cases_states = case_modules + ".states" diff --git a/tests/test_linalg/test_solvers/test_information_ops/test_matvec.py b/tests/test_linalg/test_solvers/test_information_ops/test_matvec.py index 3433377dc..0f5f81312 100644 --- a/tests/test_linalg/test_solvers/test_information_ops/test_matvec.py +++ b/tests/test_linalg/test_solvers/test_information_ops/test_matvec.py @@ -3,10 +3,11 @@ import pathlib import numpy as np -from pytest_cases import parametrize_with_cases from probnum.linalg.solvers import LinearSolverState, information_ops +from pytest_cases import parametrize_with_cases + case_modules = (pathlib.Path(__file__).parent / "cases").stem cases_information_ops = case_modules + ".information_ops" cases_states = case_modules + ".states" diff --git a/tests/test_linalg/test_solvers/test_information_ops/test_projected_residual.py b/tests/test_linalg/test_solvers/test_information_ops/test_projected_residual.py index e5101b7f2..f13d1b2dc 100644 --- a/tests/test_linalg/test_solvers/test_information_ops/test_projected_residual.py +++ 
b/tests/test_linalg/test_solvers/test_information_ops/test_projected_residual.py @@ -3,10 +3,11 @@ import pathlib import numpy as np -from pytest_cases import parametrize_with_cases from probnum.linalg.solvers import LinearSolverState, information_ops +from pytest_cases import parametrize_with_cases + case_modules = (pathlib.Path(__file__).parent / "cases").stem cases_information_ops = case_modules + ".information_ops" cases_states = case_modules + ".states" diff --git a/tests/test_linalg/test_solvers/test_policies/test_conjugate_gradient.py b/tests/test_linalg/test_solvers/test_policies/test_conjugate_gradient.py index 60f5d142b..7b3bba53b 100644 --- a/tests/test_linalg/test_solvers/test_policies/test_conjugate_gradient.py +++ b/tests/test_linalg/test_solvers/test_policies/test_conjugate_gradient.py @@ -3,11 +3,12 @@ import pathlib import numpy as np -from pytest_cases import parametrize_with_cases from probnum import randvars from probnum.linalg.solvers import LinearSolverState, policies +from pytest_cases import parametrize_with_cases + case_modules = (pathlib.Path(__file__).parent / "cases").stem cases_policies = case_modules + ".policies" cases_states = case_modules + ".states" diff --git a/tests/test_linalg/test_solvers/test_policies/test_linear_solver_policy.py b/tests/test_linalg/test_solvers/test_policies/test_linear_solver_policy.py index 0d51acae2..1c73bcb67 100644 --- a/tests/test_linalg/test_solvers/test_policies/test_linear_solver_policy.py +++ b/tests/test_linalg/test_solvers/test_policies/test_linear_solver_policy.py @@ -3,10 +3,11 @@ import pathlib import numpy as np -from pytest_cases import parametrize_with_cases from probnum.linalg.solvers import LinearSolverState, policies +from pytest_cases import parametrize_with_cases + case_modules = (pathlib.Path(__file__).parent / "cases").stem cases_policies = case_modules + ".policies" cases_states = case_modules + ".states" diff --git a/tests/test_linalg/test_solvers/test_policies/test_random_unit_vector.py b/tests/test_linalg/test_solvers/test_policies/test_random_unit_vector.py index 8201c2705..fd2dfb8a0 100644 --- a/tests/test_linalg/test_solvers/test_policies/test_random_unit_vector.py +++ b/tests/test_linalg/test_solvers/test_policies/test_random_unit_vector.py @@ -2,11 +2,12 @@ import pathlib import numpy as np -import pytest -from pytest_cases import parametrize, parametrize_with_cases from probnum.linalg.solvers import LinearSolverState, policies +import pytest +from pytest_cases import parametrize, parametrize_with_cases + case_modules = (pathlib.Path(__file__).parent / "cases").stem cases_policies = case_modules + ".policies" cases_states = case_modules + ".states" diff --git a/tests/test_linalg/test_solvers/test_probabilistic_linear_solver/test_asymmetric.py b/tests/test_linalg/test_solvers/test_probabilistic_linear_solver/test_asymmetric.py index 8f8b8ebac..3f04fe9b4 100644 --- a/tests/test_linalg/test_solvers/test_probabilistic_linear_solver/test_asymmetric.py +++ b/tests/test_linalg/test_solvers/test_probabilistic_linear_solver/test_asymmetric.py @@ -3,11 +3,12 @@ import pathlib import numpy as np -from pytest_cases import filters, parametrize_with_cases from probnum import linops, problems, randvars from probnum.linalg.solvers import ProbabilisticLinearSolver, beliefs +from pytest_cases import filters, parametrize_with_cases + case_modules = pathlib.Path("cases").stem cases_solvers = case_modules + ".solvers" cases_beliefs = case_modules + ".beliefs" diff --git 
a/tests/test_linalg/test_solvers/test_probabilistic_linear_solver/test_symmetric.py b/tests/test_linalg/test_solvers/test_probabilistic_linear_solver/test_symmetric.py index 734342e7d..187a7e0ab 100644 --- a/tests/test_linalg/test_solvers/test_probabilistic_linear_solver/test_symmetric.py +++ b/tests/test_linalg/test_solvers/test_probabilistic_linear_solver/test_symmetric.py @@ -2,11 +2,12 @@ import pathlib import numpy as np -from pytest_cases import parametrize_with_cases from probnum import linops, problems, randvars from probnum.linalg.solvers import ProbabilisticLinearSolver, beliefs +from pytest_cases import parametrize_with_cases + case_modules = pathlib.Path("cases").stem cases_solvers = case_modules + ".solvers" cases_beliefs = case_modules + ".beliefs" diff --git a/tests/test_linalg/test_solvers/test_state.py b/tests/test_linalg/test_solvers/test_state.py index 00b64eff8..73898c3e2 100644 --- a/tests/test_linalg/test_solvers/test_state.py +++ b/tests/test_linalg/test_solvers/test_state.py @@ -1,10 +1,11 @@ """Tests for the state of a probabilistic linear solver.""" import numpy as np -from pytest_cases import parametrize, parametrize_with_cases from probnum.linalg.solvers import LinearSolverState +from pytest_cases import parametrize, parametrize_with_cases + cases_states = "cases.states" diff --git a/tests/test_linalg/test_solvers/test_stopping_criteria/test_linear_solver_stopping_criterion.py b/tests/test_linalg/test_solvers/test_stopping_criteria/test_linear_solver_stopping_criterion.py index 44344c5cb..e0a019af8 100644 --- a/tests/test_linalg/test_solvers/test_stopping_criteria/test_linear_solver_stopping_criterion.py +++ b/tests/test_linalg/test_solvers/test_stopping_criteria/test_linear_solver_stopping_criterion.py @@ -2,10 +2,10 @@ import pathlib -from pytest_cases import parametrize_with_cases - from probnum.linalg.solvers import LinearSolverState, stopping_criteria +from pytest_cases import parametrize_with_cases + case_modules = (pathlib.Path(__file__).parent / "cases").stem cases_stopping_criteria = case_modules + ".stopping_criteria" cases_states = case_modules + ".states" diff --git a/tests/test_linalg/test_solvers/test_stopping_criteria/test_maxiter.py b/tests/test_linalg/test_solvers/test_stopping_criteria/test_maxiter.py index 97915a7d2..2bd213ac0 100644 --- a/tests/test_linalg/test_solvers/test_stopping_criteria/test_maxiter.py +++ b/tests/test_linalg/test_solvers/test_stopping_criteria/test_maxiter.py @@ -2,10 +2,10 @@ import pathlib -from pytest_cases import parametrize_with_cases - from probnum.linalg.solvers import LinearSolverState, stopping_criteria +from pytest_cases import parametrize_with_cases + case_modules = (pathlib.Path(__file__).parent / "cases").stem cases_stopping_criteria = case_modules + ".stopping_criteria" cases_states = case_modules + ".states" diff --git a/tests/test_linalg/test_solvers/test_stopping_criteria/test_posterior_contraction.py b/tests/test_linalg/test_solvers/test_stopping_criteria/test_posterior_contraction.py index 22f5e7467..a28b08917 100644 --- a/tests/test_linalg/test_solvers/test_stopping_criteria/test_posterior_contraction.py +++ b/tests/test_linalg/test_solvers/test_stopping_criteria/test_posterior_contraction.py @@ -2,10 +2,10 @@ import pathlib -from pytest_cases import parametrize_with_cases - from probnum.linalg.solvers import LinearSolverState, stopping_criteria +from pytest_cases import parametrize_with_cases + case_modules = (pathlib.Path(__file__).parent / "cases").stem cases_stopping_criteria = case_modules + 
".stopping_criteria" cases_states = case_modules + ".states" diff --git a/tests/test_linalg/test_solvers/test_stopping_criteria/test_residual_norm.py b/tests/test_linalg/test_solvers/test_stopping_criteria/test_residual_norm.py index 013c10257..935cad8bb 100644 --- a/tests/test_linalg/test_solvers/test_stopping_criteria/test_residual_norm.py +++ b/tests/test_linalg/test_solvers/test_stopping_criteria/test_residual_norm.py @@ -2,10 +2,10 @@ import pathlib -from pytest_cases import parametrize_with_cases - from probnum.linalg.solvers import LinearSolverState, stopping_criteria +from pytest_cases import parametrize_with_cases + case_modules = (pathlib.Path(__file__).parent / "cases").stem cases_stopping_criteria = case_modules + ".stopping_criteria" cases_states = case_modules + ".states" diff --git a/tests/test_problems/test_zoo/test_linalg/test_random_linear_system.py b/tests/test_problems/test_zoo/test_linalg/test_random_linear_system.py deleted file mode 100644 index c2a0bf05f..000000000 --- a/tests/test_problems/test_zoo/test_linalg/test_random_linear_system.py +++ /dev/null @@ -1,32 +0,0 @@ -"""Tests for functions generating random linear systems.""" - -import numpy as np -import pytest -import scipy.stats - -from probnum import randvars -from probnum.problems.zoo.linalg import random_linear_system, random_spd_matrix - - -def test_custom_random_matrix(rng: np.random.Generator): - random_unitary_matrix = lambda rng, dim: scipy.stats.unitary_group.rvs( - dim=dim, random_state=rng - ) - _ = random_linear_system(rng, random_unitary_matrix, dim=5) - - -def test_custom_solution_randvar(rng: np.random.Generator): - n = 5 - x = randvars.Normal(mean=np.ones(n), cov=np.eye(n)) - _ = random_linear_system(rng=rng, matrix=random_spd_matrix, solution_rv=x, dim=n) - - -def test_incompatible_matrix_and_solution(rng: np.random.Generator): - - with pytest.raises(ValueError): - _ = random_linear_system( - rng=rng, - matrix=random_spd_matrix, - solution_rv=randvars.Normal(np.ones(2), np.eye(2)), - dim=5, - ) diff --git a/tests/test_quad/conftest.py b/tests/test_quad/conftest.py index 961002356..dcc6d1f39 100644 --- a/tests/test_quad/conftest.py +++ b/tests/test_quad/conftest.py @@ -3,7 +3,6 @@ from typing import Dict import numpy as np -import pytest from probnum.quad.integration_measures import ( GaussianMeasure, @@ -13,6 +12,8 @@ from probnum.quad.kernel_embeddings import KernelEmbedding from probnum.randprocs import kernels +import pytest + # pylint: disable=unnecessary-lambda diff --git a/tests/test_quad/test_bayesian_quadrature.py b/tests/test_quad/test_bayesian_quadrature.py index 9af6fd4a9..f3fb27dcb 100644 --- a/tests/test_quad/test_bayesian_quadrature.py +++ b/tests/test_quad/test_bayesian_quadrature.py @@ -1,7 +1,6 @@ """Basic tests for Bayesian quadrature method.""" import numpy as np -import pytest from probnum import LambdaStoppingCriterion from probnum.quad.integration_measures import LebesgueMeasure @@ -10,6 +9,8 @@ from probnum.quad.solvers.stopping_criteria import ImmediateStop from probnum.randprocs.kernels import ExpQuad +import pytest + @pytest.fixture def input_dim(): diff --git a/tests/test_quad/test_bayesquad/test_bq.py b/tests/test_quad/test_bayesquad/test_bq.py index 704c62976..9b744700a 100644 --- a/tests/test_quad/test_bayesquad/test_bq.py +++ b/tests/test_quad/test_bayesquad/test_bq.py @@ -1,7 +1,6 @@ """Test cases for Bayesian quadrature.""" import numpy as np -import pytest from scipy.integrate import quad as scipyquad from probnum.quad import bayesquad, bayesquad_from_data @@ 
-11,6 +10,8 @@

 from ..util import gauss_hermite_tensor, gauss_legendre_tensor

+import pytest
+

 @pytest.fixture
 def rng():
@@ -194,8 +195,8 @@ def test_domain_ignored_if_lebesgue(input_dim, measure):


 def test_zero_function_gives_zero_variance_with_mle():
-    """Test that BQ variance is zero for zero function when MLE is used to set the
-    scale parameter."""
+    """Test that the BQ variance is zero for the zero function when MLE is used to
+    set the scale parameter."""
     input_dim = 1
     domain = (0, 1)
     fun = lambda x: np.zeros(x.shape[0])
diff --git a/tests/test_quad/test_belief_update.py b/tests/test_quad/test_belief_update.py
index 7a7041589..3aaf0bcaf 100644
--- a/tests/test_quad/test_belief_update.py
+++ b/tests/test_quad/test_belief_update.py
@@ -1,9 +1,9 @@
 """Test cases for the BQ belief updater."""

-import pytest
-
 from probnum.quad.solvers.belief_updates import BQStandardBeliefUpdate

+import pytest
+

 def test_belief_update_raises():
     # negative jitter is not allowed
diff --git a/tests/test_quad/test_bq_state.py b/tests/test_quad/test_bq_state.py
index 85123f06c..abee07cbe 100644
--- a/tests/test_quad/test_bq_state.py
+++ b/tests/test_quad/test_bq_state.py
@@ -1,7 +1,6 @@
 """Basic tests for the BQ info container and BQ state."""

 import numpy as np
-import pytest

 from probnum.quad.integration_measures import IntegrationMeasure, LebesgueMeasure
 from probnum.quad.kernel_embeddings import KernelEmbedding
@@ -9,6 +8,8 @@
 from probnum.randprocs.kernels import ExpQuad, Kernel
 from probnum.randvars import Normal

+import pytest
+

 @pytest.fixture
 def nevals():
diff --git a/tests/test_quad/test_bq_utils.py b/tests/test_quad/test_bq_utils.py
index 9fd145bb9..34b9e7848 100644
--- a/tests/test_quad/test_bq_utils.py
+++ b/tests/test_quad/test_bq_utils.py
@@ -1,10 +1,11 @@
 """Basic tests for bq utils."""

 import numpy as np
-import pytest

 from probnum.quad._utils import as_domain

+import pytest
+

 # fmt: off
 @pytest.mark.parametrize(
diff --git a/tests/test_quad/test_integration_measure.py b/tests/test_quad/test_integration_measure.py
index 7596a2167..a86ca91b5 100644
--- a/tests/test_quad/test_integration_measure.py
+++ b/tests/test_quad/test_integration_measure.py
@@ -1,10 +1,11 @@
 """Test cases for integration measures."""

 import numpy as np
-import pytest

 from probnum.quad.integration_measures import GaussianMeasure, LebesgueMeasure

+import pytest
+

 # Tests for Gaussian measure
 def test_gaussian_diagonal_covariance(input_dim: int):
diff --git a/tests/test_quad/test_kernel_conversion.py b/tests/test_quad/test_kernel_conversion.py
index ed3d09d4e..150353e26 100644
--- a/tests/test_quad/test_kernel_conversion.py
+++ b/tests/test_quad/test_kernel_conversion.py
@@ -1,10 +1,10 @@
 """Test cases for converting kernels to product kernels in quad."""

-import pytest
-
 from probnum.quad.kernel_embeddings._matern_lebesgue import _convert_to_product_matern
 from probnum.randprocs.kernels import Matern

+import pytest
+

 def test_product_kernel_conversion_matern():
     kernel = Matern(input_shape=(1,))
diff --git a/tests/test_quad/test_kernel_embeddings.py b/tests/test_quad/test_kernel_embeddings.py
index e6b4fe137..f5768d6a5 100644
--- a/tests/test_quad/test_kernel_embeddings.py
+++ b/tests/test_quad/test_kernel_embeddings.py
@@ -1,13 +1,14 @@
 """Test cases for kernel embeddings."""

 import numpy as np
-import pytest
 from scipy.integrate import quad

 from probnum.quad.kernel_embeddings import KernelEmbedding

 from .util import gauss_hermite_tensor, gauss_legendre_tensor

+import pytest
+

 # Common tests
 def
test_kernel_mean_shape(kernel_embedding, x): diff --git a/tests/test_quad/test_policy.py b/tests/test_quad/test_policy.py index fa848da73..77ebb6349 100644 --- a/tests/test_quad/test_policy.py +++ b/tests/test_quad/test_policy.py @@ -1,11 +1,12 @@ """Basic tests for BQ policies.""" import numpy as np -import pytest from probnum.quad.integration_measures import GaussianMeasure, LebesgueMeasure from probnum.quad.solvers.policies import VanDerCorputPolicy +import pytest + def test_van_der_corput_multi_d_error(): """Check that van der Corput policy fails in dimensions higher than one.""" diff --git a/tests/test_quad/test_stopping_criterion.py b/tests/test_quad/test_stopping_criterion.py index ae91acd97..b7c9c64b3 100644 --- a/tests/test_quad/test_stopping_criterion.py +++ b/tests/test_quad/test_stopping_criterion.py @@ -3,7 +3,6 @@ from typing import Tuple import numpy as np -import pytest from probnum.quad.integration_measures import LebesgueMeasure from probnum.quad.solvers import BQIterInfo, BQState @@ -17,6 +16,8 @@ from probnum.randprocs.kernels import ExpQuad from probnum.randvars import Normal +import pytest + _nevals = 5 _rel_tol = 1e-5 _var_tol = 1e-5 diff --git a/tests/test_quad/util.py b/tests/test_quad/util.py index 84397c055..4c79c7f2e 100644 --- a/tests/test_quad/util.py +++ b/tests/test_quad/util.py @@ -5,7 +5,7 @@ from scipy.linalg import sqrtm from scipy.special import roots_legendre -from probnum.typing import FloatLike, IntLike +from probnum.backend.typing import FloatLike, IntLike # Auxiliary functions for quadrature tests diff --git a/tests/test_randprocs/conftest.py b/tests/test_randprocs/conftest.py deleted file mode 100644 index a0f56efa0..000000000 --- a/tests/test_randprocs/conftest.py +++ /dev/null @@ -1,113 +0,0 @@ -"""Fixtures for random process tests.""" - -import functools -from typing import Callable - -import numpy as np -import pytest - -from probnum import functions, randprocs -from probnum.randprocs import kernels - - -@pytest.fixture( - params=[pytest.param(seed, id=f"seed{seed}") for seed in range(3)], - name="rng", -) -def fixture_rng(request) -> np.random.Generator: - """Random state(s) used for test parameterization.""" - return np.random.default_rng(seed=request.param) - - -@pytest.fixture( - params=[ - pytest.param(input_dim, id=f"indim{input_dim}") for input_dim in [1, 10, 100] - ], - name="input_dim", -) -def fixture_input_dim(request) -> int: - """Input dimension of the random process.""" - return request.param - - -@pytest.fixture( - params=[ - pytest.param(output_dim, id=f"outdim{output_dim}") for output_dim in [1, 2, 10] - ] -) -def output_dim(request) -> int: - """Output dimension of the random process.""" - return request.param - - -@pytest.fixture( - params=[ - pytest.param(mu, id=mu[0]) - for mu in [ - ("zero", functions.Zero), - ( - "lin", - functools.partial( - functions.LambdaFunction, lambda x: 2 * x.sum(axis=1) + 1.0 - ), - ), - ] - ], - name="mean", -) -def fixture_mean(request, input_dim: int) -> Callable: - """Mean function of a random process.""" - return request.param[1](input_shape=(input_dim,), output_shape=()) - - -@pytest.fixture( - params=[ - pytest.param(kerndef, id=kerndef[0].__name__) - for kerndef in [ - (kernels.Polynomial, {"constant": 1.0, "exponent": 3}), - (kernels.ExpQuad, {"lengthscale": 1.5}), - (kernels.RatQuad, {"lengthscale": 0.5, "alpha": 2.0}), - (kernels.Matern, {"lengthscale": 0.5, "nu": 1.5}), - ] - ], - name="cov", -) -def fixture_cov(request, input_dim: int) -> kernels.Kernel: - """Covariance function.""" 
- return request.param[0](**request.param[1], input_shape=(input_dim,)) - - -@pytest.fixture( - params=[ - pytest.param(randprocdef, id=randprocdef[0]) - for randprocdef in [ - ( - "gp", - randprocs.GaussianProcess( - mean=functions.Zero(input_shape=(1,)), - cov=kernels.Matern(input_shape=(1,)), - ), - ), - ] - ], - name="random_process", -) -def fixture_random_process(request) -> randprocs.RandomProcess: - """Random process.""" - return request.param[1] - - -@pytest.fixture(name="gaussian_process") -def fixture_gaussian_process(mean, cov) -> randprocs.GaussianProcess: - """Gaussian process.""" - return randprocs.GaussianProcess(mean=mean, cov=cov) - - -@pytest.fixture(params=[pytest.param(n, id=f"n{n}") for n in [1, 10]], name="args0") -def fixture_args0( - request, - random_process: randprocs.RandomProcess, - rng: np.random.Generator, -) -> np.ndarray: - """Input(s) to a random process.""" - return rng.normal(size=(request.param,) + random_process.input_shape) diff --git a/tests/test_randprocs/test_kernels/conftest.py b/tests/test_randprocs/test_kernels/conftest.py deleted file mode 100644 index 60a3cd0d1..000000000 --- a/tests/test_randprocs/test_kernels/conftest.py +++ /dev/null @@ -1,136 +0,0 @@ -"""Test fixtures for kernels.""" - -from typing import Callable, Optional - -import numpy as np -import pytest - -import probnum as pn -from probnum.typing import ShapeType - - -@pytest.fixture( - params=[pytest.param(seed, id=f"seed{seed}") for seed in range(1)], - name="rng", -) -def fixture_rng(request): - """Random state(s) used for test parameterization.""" - return np.random.default_rng(seed=request.param) - - -# Kernel objects -@pytest.fixture( - params=[ - pytest.param(input_shape, id=f"inshape{input_shape}") - for input_shape in [(), (1,), (10,), (100,)] - ], - name="input_shape", -) -def fixture_input_shape(request) -> ShapeType: - """Input shape of the covariance function.""" - return request.param - - -@pytest.fixture( - params=[ - pytest.param(kerndef, id=kerndef[0].__name__) - for kerndef in [ - (pn.randprocs.kernels.Linear, {"constant": 1.0}), - (pn.randprocs.kernels.WhiteNoise, {"sigma_sq": 1.0}), - (pn.randprocs.kernels.Polynomial, {"constant": 1.0, "exponent": 3}), - (pn.randprocs.kernels.ExpQuad, {"lengthscale": 1.5}), - (pn.randprocs.kernels.RatQuad, {"lengthscale": 0.5, "alpha": 2.0}), - (pn.randprocs.kernels.Matern, {"lengthscale": 0.5, "nu": 0.5}), - (pn.randprocs.kernels.Matern, {"lengthscale": 0.5, "nu": 1.5}), - (pn.randprocs.kernels.Matern, {"lengthscale": 1.5, "nu": 2.5}), - (pn.randprocs.kernels.Matern, {"lengthscale": 2.5, "nu": 7.0}), - (pn.randprocs.kernels.Matern, {"lengthscale": 3.0, "nu": np.inf}), - (pn.randprocs.kernels.ProductMatern, {"lengthscales": 0.5, "nus": 0.5}), - ] - ], - name="kernel", -) -def fixture_kernel(request, input_shape: ShapeType) -> pn.randprocs.kernels.Kernel: - """Kernel / covariance function.""" - return request.param[0](input_shape=input_shape, **request.param[1]) - - -@pytest.fixture(name="kernel_call_naive") -def fixture_kernel_call_naive( - kernel: pn.randprocs.kernels.Kernel, -) -> Callable[[np.ndarray, Optional[np.ndarray]], np.ndarray]: - """Naive implementation of kernel broadcasting which applies the kernel function to - scalar arguments while looping over the first dimensions of the inputs explicitly. - - Can be used as a reference implementation of `Kernel.__call__` vectorization. 
- """ - - if kernel.input_ndim == 0: - kernel_vectorized = np.vectorize(kernel, signature="(),()->()") - else: - assert kernel.input_ndim == 1 - - kernel_vectorized = np.vectorize(kernel, signature="(d),(d)->()") - - return lambda x0, x1: ( - kernel_vectorized(x0, x0) if x1 is None else kernel_vectorized(x0, x1) - ) - - -# Test data for `Kernel.matrix` -@pytest.fixture( - params=[ - pytest.param(shape, id=f"x0{shape}") - for shape in [ - (), - (1,), - (2,), - (10,), - (100,), - ] - ], - name="x0_batch_shape", -) -def fixture_x0_batch_shape(request) -> ShapeType: - """Batch shape of the first argument of ``Kernel.matrix``.""" - return request.param - - -@pytest.fixture( - params=[ - pytest.param(shape, id=f"x1{shape}") - for shape in [ - None, - (), - (1,), - (3,), - (10,), - ] - ], - name="x1_batch_shape", -) -def fixture_x1_batch_shape(request) -> Optional[ShapeType]: - """Batch shape of the second argument of ``Kernel.matrix`` or ``None`` if the second - argument is ``None``.""" - return request.param - - -@pytest.fixture(name="x0") -def fixture_x0( - rng: np.random.Generator, x0_batch_shape: ShapeType, input_shape: ShapeType -) -> np.ndarray: - """Random data from a standard normal distribution.""" - return rng.normal(0, 1, size=x0_batch_shape + input_shape) - - -@pytest.fixture(name="x1") -def fixture_x1( - rng: np.random.Generator, - x1_batch_shape: Optional[ShapeType], - input_shape: ShapeType, -) -> Optional[np.ndarray]: - """Random data from a standard normal distribution.""" - if x1_batch_shape is None: - return None - - return rng.normal(0, 1, size=x1_batch_shape + input_shape) diff --git a/tests/test_randprocs/test_kernels/test_product_matern.py b/tests/test_randprocs/test_kernels/test_product_matern.py deleted file mode 100644 index 31221b974..000000000 --- a/tests/test_randprocs/test_kernels/test_product_matern.py +++ /dev/null @@ -1,43 +0,0 @@ -"""Test cases for the product Matern kernel.""" - -import numpy as np -import pytest - -from probnum.randprocs import kernels -import probnum.utils as _utils - - -@pytest.mark.parametrize("nu", [0.5, 1.5, 2.5, 3.0]) -def test_kernel_matrix(input_dim, nu): - """Check that the product Matérn kernel matrix is an elementwise product of 1D - Matérn kernel matrices.""" - lengthscale = 1.25 - matern = kernels.Matern(input_shape=(1,), lengthscale=lengthscale, nu=nu) - product_matern = kernels.ProductMatern( - input_shape=(input_dim,), lengthscales=lengthscale, nus=nu - ) - rng = np.random.default_rng(42) - num_xs = 15 - xs = rng.random(size=(num_xs, input_dim)) - kernel_matrix1 = product_matern.matrix(xs) - kernel_matrix2 = np.ones(shape=(num_xs, num_xs)) - for dim in range(input_dim): - kernel_matrix2 *= matern.matrix(_utils.as_colvec(xs[:, dim])) - np.testing.assert_allclose( - kernel_matrix1, - kernel_matrix2, - ) - - -@pytest.mark.parametrize( - "ell,nu", - [ - (np.array([3.0]), 0.5), - (3.0, np.array([0.5])), - (np.array([3.0]), np.array([0.5])), - ], -) -def test_wrong_initialization_raises_exception(ell, nu): - """Parameters must be scalars if kernel input is scalar.""" - with pytest.raises(ValueError): - kernels.ProductMatern(input_shape=(), lengthscales=ell, nus=nu) diff --git a/tests/test_randprocs/test_markov/conftest.py b/tests/test_randprocs/test_markov/conftest.py deleted file mode 100644 index c1e534da6..000000000 --- a/tests/test_randprocs/test_markov/conftest.py +++ /dev/null @@ -1,103 +0,0 @@ -"""Fixtures to be shared across all modules in this directory. - -Mostly some random variables of matching dimensions. 
-""" - -import numpy as np -import pytest - -from probnum import randvars -from probnum.problems.zoo.linalg import random_spd_matrix - - -@pytest.fixture -def rng(): - return np.random.default_rng(seed=123) - - -@pytest.fixture(params=[2]) -def test_ndim(request): - """Test dimension.""" - return request.param - - -# A few covariance matrices - - -@pytest.fixture -def spdmat1(test_ndim, rng): - return random_spd_matrix(rng, dim=test_ndim) - - -@pytest.fixture -def spdmat2(test_ndim, rng): - return random_spd_matrix(rng, dim=test_ndim) - - -@pytest.fixture -def spdmat3(test_ndim, rng): - return random_spd_matrix(rng, dim=test_ndim) - - -@pytest.fixture -def spdmat4(test_ndim, rng): - return random_spd_matrix(rng, dim=test_ndim) - - -# A few 'Normal' random variables - - -@pytest.fixture -def some_normal_rv1(test_ndim, spdmat1, rng): - - return randvars.Normal( - mean=rng.uniform(size=test_ndim), - cov=spdmat1, - cov_cholesky=np.linalg.cholesky(spdmat1), - ) - - -@pytest.fixture -def some_normal_rv2(test_ndim, spdmat2, rng): - return randvars.Normal( - mean=rng.uniform(size=test_ndim), - cov=spdmat2, - cov_cholesky=np.linalg.cholesky(spdmat2), - ) - - -@pytest.fixture -def some_normal_rv3(test_ndim, spdmat3, rng): - return randvars.Normal( - mean=rng.uniform(size=test_ndim), - cov=spdmat3, - cov_cholesky=np.linalg.cholesky(spdmat3), - ) - - -@pytest.fixture -def some_normal_rv4(test_ndim, spdmat4, rng): - return randvars.Normal( - mean=rng.uniform(size=test_ndim), - cov=spdmat4, - cov_cholesky=np.linalg.cholesky(spdmat4), - ) - - -@pytest.fixture -def diffusion(): - """A diffusion != 1 makes it easier to see if _diffusion is actually used in forward - and backward.""" - return 5.1412512431 - - -@pytest.fixture(params=["classic", "sqrt"]) -def forw_impl_string_linear_gauss(request): - """Forward implementation choices passed via strings.""" - return request.param - - -@pytest.fixture(params=["classic", "joseph", "sqrt"]) -def backw_impl_string_linear_gauss(request): - """Backward implementation choices passed via strings.""" - return request.param diff --git a/tests/test_randvars/test_arithmetic/conftest.py b/tests/test_randvars/test_arithmetic/conftest.py index 7522c99ae..9be62c857 100644 --- a/tests/test_randvars/test_arithmetic/conftest.py +++ b/tests/test_randvars/test_arithmetic/conftest.py @@ -1,59 +1,77 @@ """Fixtures for random variable arithmetic.""" -import numpy as np -import pytest - -from probnum import linops, randvars +from probnum import backend, linops, randvars +from probnum.backend.typing import ShapeLike from probnum.problems.zoo.linalg import random_spd_matrix -from probnum.typing import ShapeLike - -@pytest.fixture -def rng() -> np.random.Generator: - return np.random.default_rng(42) +import pytest +import tests.utils @pytest.fixture -def constant(shape_const: ShapeLike, rng: np.random.Generator) -> randvars.Constant: - return randvars.Constant(support=rng.normal(size=shape_const)) +def constant(shape_const: ShapeLike) -> randvars.Constant: + rng_state = tests.utils.random.rng_state_from_sampling_args( + base_seed=19836, shape=shape_const + ) + + return randvars.Constant( + support=backend.random.standard_normal(rng_state, shape=shape_const) + ) @pytest.fixture def multivariate_normal( - shape: ShapeLike, precompute_cov_cholesky: bool, rng: np.random.Generator + shape: ShapeLike, precompute_cov_cholesky: bool ) -> randvars.Normal: + rng_state = tests.utils.random.rng_state_from_sampling_args( + base_seed=1908, shape=shape + ) + rng_state_mean, rng_state_cov = 
backend.random.split(rng_state)
+
     rv = randvars.Normal(
-        mean=rng.normal(size=shape),
-        cov=random_spd_matrix(rng=rng, dim=shape[0]),
+        mean=backend.random.standard_normal(rng_state_mean, shape=shape),
+        cov=random_spd_matrix(rng_state_cov, shape=(shape[0], shape[0])),
     )
     if precompute_cov_cholesky:
-        rv.precompute_cov_cholesky()
+        rv._compute_cov_cholesky()
     return rv


 @pytest.fixture
 def matrixvariate_normal(
-    shape: ShapeLike, precompute_cov_cholesky: bool, rng: np.random.Generator
+    shape: ShapeLike, precompute_cov_cholesky: bool
 ) -> randvars.Normal:
+    rng_state = tests.utils.random.rng_state_from_sampling_args(
+        base_seed=354, shape=shape
+    )
+    rng_state_mean, rng_state_cov_A, rng_state_cov_B = backend.random.split(
+        rng_state, num=3
+    )
+
     rv = randvars.Normal(
-        mean=rng.normal(size=shape),
+        mean=backend.random.standard_normal(rng_state_mean, shape=shape),
         cov=linops.Kronecker(
-            A=random_spd_matrix(dim=shape[0], rng=rng),
-            B=random_spd_matrix(dim=shape[1], rng=rng),
+            A=random_spd_matrix(rng_state_cov_A, shape=(shape[0], shape[0])),
+            B=random_spd_matrix(rng_state_cov_B, shape=(shape[1], shape[1])),
         ),
     )
     if precompute_cov_cholesky:
-        rv.precompute_cov_cholesky()
+        rv._compute_cov_cholesky()
     return rv


 @pytest.fixture
 def symmetric_matrixvariate_normal(
-    shape: ShapeLike, precompute_cov_cholesky: bool, rng: np.random.Generator
+    shape: ShapeLike, precompute_cov_cholesky: bool
 ) -> randvars.Normal:
+    rng_state = tests.utils.random.rng_state_from_sampling_args(
+        base_seed=246, shape=shape
+    )
+    rng_state_mean, rng_state_cov = backend.random.split(rng_state)
+
     rv = randvars.Normal(
-        mean=random_spd_matrix(dim=shape[0], rng=rng),
-        cov=linops.SymmetricKronecker(A=random_spd_matrix(dim=shape[0], rng=rng)),
+        mean=random_spd_matrix(rng_state_mean, shape=(shape[0], shape[0])),
+        cov=linops.SymmetricKronecker(A=random_spd_matrix(rng_state_cov, shape=(shape[0], shape[0]))),
     )
     if precompute_cov_cholesky:
-        rv.precompute_cov_cholesky()
+        rv._compute_cov_cholesky()
     return rv
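The fixtures above derive their RNG state from the fixture parameters via `tests.utils.random.rng_state_from_sampling_args` (defined at the end of this diff), so that parametrizations with different shapes do not draw overlapping sample streams, and then split it once per independent draw. A minimal sketch of the resulting fixture pattern; the fixture name and the `base_seed` value are illustrative only:

    import pytest

    from probnum import backend, randvars

    import tests.utils


    @pytest.fixture
    def standard_normal_constant(shape):
        # Mixing the shape into the seed keeps, e.g., the shape=(2,) and the
        # shape=(4,) parametrizations from sharing a prefix of one sample stream.
        rng_state = tests.utils.random.rng_state_from_sampling_args(
            base_seed=1234, shape=shape  # base_seed chosen arbitrarily
        )
        return randvars.Constant(
            support=backend.random.standard_normal(rng_state, shape=shape)
        )

diff --git a/tests/test_randvars/test_arithmetic/test_constant.py b/tests/test_randvars/test_arithmetic/test_constant.py
index d0abbe45f..9249dd842 100644
--- a/tests/test_randvars/test_arithmetic/test_constant.py
+++ b/tests/test_randvars/test_arithmetic/test_constant.py
@@ -3,10 +3,11 @@
 from typing import Callable

 import numpy as np
-import pytest

 from probnum import randvars

+import pytest
+

 @pytest.mark.parametrize(
     "op",
diff --git a/tests/test_randvars/test_arithmetic/test_generic.py b/tests/test_randvars/test_arithmetic/test_generic.py
index 8561c3fed..89e3966ca 100644
--- a/tests/test_randvars/test_arithmetic/test_generic.py
+++ b/tests/test_randvars/test_arithmetic/test_generic.py
@@ -2,18 +2,21 @@

 import numpy as np
 from numpy.typing import DTypeLike
-import pytest

-from probnum import randvars
-from probnum.typing import ShapeLike
+from probnum import backend, randvars
+from probnum.backend.typing import ShapeLike
+
+import pytest


 @pytest.mark.parametrize("shape,dtype", [((5,), np.single), ((2, 3), np.double)])
 def test_generic_randvar_dtype_shape_inference(shape: ShapeLike, dtype: DTypeLike):
     x = randvars.RandomVariable(
-        shape=shape, dtype=dtype, sample=lambda size, rng: np.zeros(size + shape)
+        shape=shape,
+        dtype=dtype,
+        sample=lambda seed, sample_shape: backend.zeros(sample_shape + shape),
     )
     y = np.array(5.0)
     z = x + y
-    assert z.dtype == np.promote_types(dtype, y.dtype)
+    assert z.dtype == backend.promote_types(dtype, y.dtype)
     assert z.shape == shape
diff --git a/tests/test_randvars/test_arithmetic/test_matrixvariate_normal.py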
b/tests/test_randvars/test_arithmetic/test_matrixvariate_normal.py index 936b0a5c3..5a0c58851 100644 --- a/tests/test_randvars/test_arithmetic/test_matrixvariate_normal.py +++ b/tests/test_randvars/test_arithmetic/test_matrixvariate_normal.py @@ -1,10 +1,11 @@ """Tests for matrix-variate normal arithmetic.""" import numpy as np -import pytest from probnum import linops +import pytest + @pytest.mark.parametrize( "shape_const,shape", diff --git a/tests/test_randvars/test_arithmetic/test_multivariate_normal.py b/tests/test_randvars/test_arithmetic/test_multivariate_normal.py index 67f25f8ea..d5d967e4d 100644 --- a/tests/test_randvars/test_arithmetic/test_multivariate_normal.py +++ b/tests/test_randvars/test_arithmetic/test_multivariate_normal.py @@ -1,9 +1,10 @@ """Tests for multi-variate normal arithmetic.""" import numpy as np -import pytest -from probnum import utils +from probnum import backend + +import pytest @pytest.mark.parametrize("shape,shape_const", [((3,), (3,))]) @@ -112,7 +113,7 @@ def test_constant_multivariate_normal_matrix_multiplication_right( if matrix_product.cov_cholesky_is_precomputed: np.testing.assert_allclose( matrix_product.cov_cholesky, - utils.linalg.cholesky_update( + backend.linalg.cholesky_update( constant.support @ multivariate_normal.cov_cholesky ), ) @@ -142,7 +143,7 @@ def test_constant_multivariate_normal_matrix_multiplication_left( if matrix_product.cov_cholesky_is_precomputed: np.testing.assert_allclose( matrix_product.cov_cholesky, - utils.linalg.cholesky_update( + backend.linalg.cholesky_update( constant.support.T @ multivariate_normal.cov_cholesky ), ) diff --git a/tests/test_randvars/test_categorical.py b/tests/test_randvars/test_categorical.py index f4bd9b961..77c74feb2 100644 --- a/tests/test_randvars/test_categorical.py +++ b/tests/test_randvars/test_categorical.py @@ -4,9 +4,10 @@ import string import numpy as np -import pytest -from probnum import randvars, utils +from probnum import backend, randvars + +import pytest NDIM = 5 @@ -53,7 +54,7 @@ def test_support(categ): @pytest.mark.parametrize("size", [(), 1, (1,), (1, 1)]) def test_sample(categ, size, rng): samples = categ.sample(rng=rng, size=size) - expected_shape = utils.as_shape(size) + categ.shape + expected_shape = backend.asshape(size) + categ.shape assert samples.shape == expected_shape diff --git a/tests/test_randvars/test_normal.py b/tests/test_randvars/test_normal.py index 3d6f99731..b5faed6fc 100644 --- a/tests/test_randvars/test_normal.py +++ b/tests/test_randvars/test_normal.py @@ -3,84 +3,17 @@ import unittest import numpy as np -import scipy.linalg -import scipy.sparse import scipy.stats from probnum import config, linops, randvars from probnum.problems.zoo.linalg import random_spd_matrix + from tests.testing import NumpyAssertions class NormalTestCase(unittest.TestCase, NumpyAssertions): """General test case for the normal distribution.""" - def setUp(self): - """Resources for tests.""" - - # Seed - self.seed = 42 - self.rng = np.random.default_rng(seed=self.seed) - - # Parameters - m = 7 - n = 3 - self.constants = [-1, -2.4, 0, 200, np.pi] - sparsemat = scipy.sparse.rand(m=m, n=n, density=0.1, random_state=self.rng) - self.normal_params = [ - # Univariate - (-1.0, 3.0), - (1, 3), - # Multivariate - (np.random.uniform(size=10), np.eye(10)), - (np.random.uniform(size=10), random_spd_matrix(rng=self.rng, dim=10)), - # Matrixvariate - ( - np.random.uniform(size=(2, 2)), - linops.SymmetricKronecker( - A=np.array([[1.0, 2.0], [2.0, 10.0]]), - B=np.array([[5.0, -1.0], [-1.0, 10.0]]), 
- ).todense(), - ), - # Operatorvariate - ( - np.array([1.0, -5.0]), - linops.Matrix(A=np.array([[2.0, 1.0], [1.0, 1.0]])), - ), - ( - linops.Matrix(A=np.array([[0.0, -5.0]])), - linops.Identity(shape=(2, 2)), - ), - ( - np.array([[1.0, 2.0], [-3.0, -0.4], [4.0, 1.0]]), - linops.Kronecker(A=np.eye(3), B=5 * np.eye(2)), - ), - ( - linops.Matrix(A=sparsemat.todense()), - linops.Kronecker(linops.Identity(m), linops.Identity(n)), - ), - ( - linops.Matrix(A=np.random.uniform(size=(2, 2))), - linops.SymmetricKronecker( - A=np.array([[1.0, 2.0], [2.0, 10.0]]), - B=np.array([[5.0, -1.0], [-1.0, 10.0]]), - ), - ), - # Symmetric Kronecker Identical Factors - ( - linops.Identity(shape=25), - linops.SymmetricKronecker(A=linops.Identity(25)), - ), - ] - - def test_correct_instantiation(self): - """Test whether different variants of the normal distribution are instances of - Normal.""" - for mean, cov in self.normal_params: - with self.subTest(): - dist = randvars.Normal(mean=mean, cov=cov) - self.assertIsInstance(dist, randvars.Normal) - def test_scalarmult(self): """Multiply a rv with a normal distribution with a scalar.""" for (mean, cov), const in list( @@ -454,23 +387,6 @@ def test_precompute_cov_cholesky(self): with self.subTest("Cholesky is precomputed"): self.assertTrue(rv.cov_cholesky_is_precomputed) - def test_damping_factor_config(self): - mean, cov = self.params - rv = randvars.Normal(mean, cov) - - chol_standard_damping = rv.dense_cov_cholesky(damping_factor=None) - self.assertAllClose( - chol_standard_damping, - np.sqrt(rv.cov + 1e-12), - ) - - with config(covariance_inversion_damping=1e-3): - chol_altered_damping = rv.dense_cov_cholesky(damping_factor=None) - self.assertAllClose( - chol_altered_damping, - np.sqrt(rv.cov + 1e-3), - ) - def test_cov_cholesky_cov_cholesky_passed(self): """A value for cov_cholesky is passed in init. 
@@ -483,7 +399,7 @@ def test_cov_cholesky_cov_cholesky_passed(self): # This is purposely not the correct Cholesky factor for test reasons cov_cholesky = np.random.rand() - rv = randvars.Normal(mean, cov, cov_cholesky=cov_cholesky) + rv = randvars.Normal(mean, cov, cache={"cov_cholesky": cov_cholesky}) with self.subTest("Cholesky precomputed"): self.assertTrue(rv.cov_cholesky_is_precomputed) @@ -497,25 +413,6 @@ def test_cov_cholesky_cov_cholesky_passed(self): class MultivariateNormalTestCase(unittest.TestCase, NumpyAssertions): - def setUp(self): - - self.seed = 42 - self.rng = np.random.default_rng(self.seed) - - self.params = ( - self.rng.uniform(size=10), - random_spd_matrix(rng=self.rng, dim=10), - ) - - def test_newaxis(self): - vector_rv = randvars.Normal(*self.params) - - matrix_rv = vector_rv[:, np.newaxis] - - self.assertEqual(matrix_rv.shape, (10, 1)) - self.assertArrayEqual(np.squeeze(matrix_rv.mean), vector_rv.mean) - self.assertArrayEqual(matrix_rv.cov, vector_rv.cov) - def test_reshape(self): rv = randvars.Normal(*self.params) @@ -621,7 +518,7 @@ def test_cov_cholesky_cov_cholesky_passed(self): # This is purposely not the correct Cholesky factor for test reasons cov_cholesky = np.random.rand(*cov.shape) - rv = randvars.Normal(mean, cov, cov_cholesky=cov_cholesky) + rv = randvars.Normal(mean, cov, cache={"cov_cholesky": cov_cholesky}) with self.subTest("Cholesky precomputed"): self.assertTrue(rv.cov_cholesky_is_precomputed) @@ -641,12 +538,16 @@ def test_cholesky_cov_incompatible_types(self): cov_cholesky_wrong_type = cov_cholesky.tolist() with self.subTest("Different type raises ValueError"): with self.assertRaises(TypeError): - randvars.Normal(mean, cov, cov_cholesky=cov_cholesky_wrong_type) + randvars.Normal( + mean, cov, cache={"cov_cholesky": cov_cholesky_wrong_type} + ) cov_cholesky_wrong_shape = cov_cholesky[1:] with self.subTest("Different shape raises ValueError"): with self.assertRaises(ValueError): - randvars.Normal(mean, cov, cov_cholesky=cov_cholesky_wrong_shape) + randvars.Normal( + mean, cov, cache={"cov_cholesky": cov_cholesky_wrong_shape} + ) cov_cholesky_wrong_dtype = cov_cholesky.astype(int) with self.subTest("Different data type is promoted"): @@ -656,7 +557,7 @@ def test_cholesky_cov_incompatible_types(self): # Assert data type of cov_cholesky is changed during __init__ normal_new_dtype = randvars.Normal( - mean, cov, cov_cholesky=cov_cholesky_wrong_dtype + mean, cov, cache={"cov_cholesky": cov_cholesky_wrong_dtype} ) self.assertEqual( normal_new_dtype.cov.dtype, normal_new_dtype.cov_cholesky.dtype @@ -664,11 +565,6 @@ def test_cholesky_cov_incompatible_types(self): class MatrixvariateNormalTestCase(unittest.TestCase, NumpyAssertions): - def setUp(self): - # Seed - self.seed = 42 - self.rng = np.random.default_rng(seed=self.seed) - def test_reshape(self): rv = randvars.Normal( mean=np.random.uniform(size=(4, 3)), @@ -765,7 +661,7 @@ def test_cov_cholesky_cov_cholesky_passed(self): rv = randvars.Normal( mean=np.random.uniform(size=(2, 2)), cov=random_spd_matrix(rng=self.rng, dim=4), - cov_cholesky=cov_cholesky, + cache={"cov_cholesky": cov_cholesky}, ) with self.subTest("Cholesky precomputed"): diff --git a/tests/test_randvars/test_random_variable.py b/tests/test_randvars/test_random_variable.py index 81713c081..7d1b94958 100644 --- a/tests/test_randvars/test_random_variable.py +++ b/tests/test_randvars/test_random_variable.py @@ -4,11 +4,12 @@ import unittest import numpy as np -import pytest import scipy.stats import probnum from probnum import linops, 
randvars
+
+import pytest

 from tests.testing import NumpyAssertions
diff --git a/tests/test_utils/test_linalg/test_cholesky_updates.py b/tests/test_utils/test_linalg/test_cholesky_updates.py
deleted file mode 100644
index 2f05f272c..000000000
--- a/tests/test_utils/test_linalg/test_cholesky_updates.py
+++ /dev/null
@@ -1,65 +0,0 @@
-import numpy as np
-import pytest
-
-from probnum.problems.zoo.linalg import random_spd_matrix
-import probnum.utils.linalg as utlin
-
-
-@pytest.fixture
-def even_ndim():
-    """Even dimension for the tests, because it is halfed in test_cholesky_optional
-    below."""
-    return 10
-
-
-@pytest.fixture
-def rng():
-    return np.random.default_rng(seed=123)
-
-
-@pytest.fixture
-def spdmat1(even_ndim, rng):
-    return random_spd_matrix(rng, dim=even_ndim)
-
-
-@pytest.fixture
-def spdmat2(even_ndim, rng):
-    return random_spd_matrix(rng, dim=even_ndim)
-
-
-def test_cholesky_update(spdmat1, spdmat2):
-    expected = np.linalg.cholesky(spdmat1 + spdmat2)
-
-    S1 = np.linalg.cholesky(spdmat1)
-    S2 = np.linalg.cholesky(spdmat2)
-    received = utlin.cholesky_update(S1, S2)
-    np.testing.assert_allclose(expected, received)
-
-
-def test_cholesky_optional(spdmat1, even_ndim):
-    """Assert that cholesky_update() transforms a non-square matrix square-root into a
-    correct Cholesky factor."""
-    H = np.random.rand(even_ndim // 2, even_ndim)
-    expected = np.linalg.cholesky(H @ spdmat1 @ H.T)
-    S1 = np.linalg.cholesky(spdmat1)
-    received = utlin.cholesky_update(H @ S1)
-    np.testing.assert_allclose(expected, received)
-
-
-def test_tril_to_positive_tril():
-
-    # Make a random tril matrix
-    mat = np.tril(np.random.rand(4, 4))
-    scale = np.array([1.0, 1.0, 1e-5, 1e-5])
-    signs = np.array([1.0, -1.0, -1.0, -1.0])
-    tril = mat @ np.diag(scale)
-    tril_wrong_signs = tril @ np.diag(signs)
-
-    # Call triu_to_positive_til
-    tril_received = utlin.tril_to_positive_tril(tril_wrong_signs)
-
-    # Sanity check
-    np.testing.assert_allclose(tril @ tril.T, tril_received @ tril_received.T)
-
-    # Assert that the initial tril matrix comes out
-    np.testing.assert_allclose(tril_received, tril)
diff --git a/tests/test_utils/test_linalg/test_inner_product.py b/tests/test_utils/test_linalg/test_inner_product.py
deleted file mode 100644
index 57822628f..000000000
--- a/tests/test_utils/test_linalg/test_inner_product.py
+++ /dev/null
@@ -1,87 +0,0 @@
-"""Tests for general inner products."""
-
-import numpy as np
-import pytest
-
-from probnum.problems.zoo.linalg import random_spd_matrix
-from probnum.utils.linalg import induced_norm, inner_product
-
-
-@pytest.fixture(scope="module", params=[1, 10, 50])
-def n(request) -> int:
-    """Vector size."""
-    return request.param
-
-
-@pytest.fixture(scope="module", params=[1, 3, 5])
-def m(request) -> int:
-    """Number of simultaneous vectors."""
-    return request.param
-
-
-@pytest.fixture(scope="module", params=[1, 3])
-def p(request) -> int:
-    """Number of matrices."""
-    return request.param
-
-
-@pytest.fixture(scope="module")
-def vector0(n: int) -> np.ndarray:
-    rng = np.random.default_rng(86 + n)
-    return rng.standard_normal(size=(n,))
-
-
-@pytest.fixture(scope="module")
-def vector1(n: int) -> np.ndarray:
-    rng = np.random.default_rng(567 + n)
-    return rng.standard_normal(size=(n,))
-
-
-@pytest.fixture(scope="module")
-def array0(p: int, m: int, n: int) -> np.ndarray:
-    rng = np.random.default_rng(86 + p + m + n)
-    return rng.standard_normal(size=(p, m, n))
-
-
-@pytest.fixture(scope="module")
-def array1(m: int, n: int) -> np.ndarray:
-    rng = np.random.default_rng(567 + m + n)
-    return rng.standard_normal(size=(m, n))
-
-
-def test_inner_product_vectors(vector0: np.ndarray, vector1: np.ndarray):
-    assert inner_product(v=vector0, w=vector1) == pytest.approx(
-        np.inner(vector0, vector1)
-    )
-
-
-def test_inner_product_arrays(array0: np.ndarray, array1: np.ndarray):
-    assert inner_product(v=array0, w=array1) == pytest.approx(
-        np.einsum("...i,...i", array0, array1)
-    )
-
-
-def test_euclidean_norm_vector(vector0: np.ndarray):
-    assert np.linalg.norm(vector0, ord=2) == pytest.approx(induced_norm(v=vector0))
-
-
-@pytest.mark.parametrize("axis", [0, 1])
-def test_euclidean_norm_array(array0: np.ndarray, axis: int):
-    assert np.linalg.norm(array0, axis=axis, ord=2) == pytest.approx(
-        induced_norm(v=array0, axis=axis)
-    )
-
-
-@pytest.mark.parametrize("axis", [0, 1])
-def test_induced_norm_array(array0: np.ndarray, axis: int):
-    inprod_mat = random_spd_matrix(
-        rng=np.random.default_rng(254), dim=array0.shape[axis]
-    )
-    array0_moved_axis = np.moveaxis(array0, axis, -1)
-    A_array_0_moved_axis = np.squeeze(
-        inprod_mat @ array0_moved_axis[..., :, None], axis=-1
-    )
-
-    assert np.sqrt(
-        np.sum(array0_moved_axis * A_array_0_moved_axis, axis=-1)
-    ) == pytest.approx(induced_norm(v=array0, A=inprod_mat, axis=axis))
diff --git a/tests/test_utils/test_linalg/test_orthogonalize.py b/tests/test_utils/test_linalg/test_orthogonalize.py
deleted file mode 100644
index 2e4c25f2a..000000000
--- a/tests/test_utils/test_linalg/test_orthogonalize.py
+++ /dev/null
@@ -1,161 +0,0 @@
-"""Tests for orthogonalization functions."""
-
-from functools import partial
-from typing import Callable, Union
-
-import numpy as np
-import pytest
-
-from probnum import linops
-from probnum.problems.zoo.linalg import random_spd_matrix
-from probnum.utils.linalg import (
-    double_gram_schmidt,
-    gram_schmidt,
-    modified_gram_schmidt,
-)
-
-n = 100
-
-
-@pytest.fixture(scope="module", params=[1, 10, 50])
-def basis_size(request) -> int:
-    """Number of basis vectors."""
-    return request.param
-
-
-@pytest.fixture(scope="module")
-def vector() -> np.ndarray:
-    rng = np.random.default_rng(526367 + n)
-    return rng.standard_normal(size=(n,))
-
-
-@pytest.fixture(scope="module")
-def vectors() -> np.ndarray:
-    rng = np.random.default_rng(234 + n)
-    return rng.standard_normal(size=(2, 10, n))
-
-
-@pytest.fixture(
-    scope="module",
-    params=[
-        np.eye(n),
-        linops.Identity(n),
-        linops.Scaling(factors=1.0, shape=(n, n)),
-        np.inner,
-    ],
-)
-def inprod(request) -> int:
-    return request.param
-
-
-@pytest.fixture(
-    scope="module",
-    params=[
-        partial(double_gram_schmidt, gram_schmidt_fn=gram_schmidt),
-        partial(double_gram_schmidt, gram_schmidt_fn=modified_gram_schmidt),
-    ],
-)
-def orthogonalization_fn(request) -> int:
-    return request.param
-
-
-def test_is_orthogonal(
-    vector: np.ndarray,
-    basis_size: int,
-    inprod: Union[
-        np.ndarray,
-        linops.LinearOperator,
-        Callable[[np.ndarray, np.ndarray], np.ndarray],
-    ],
-    orthogonalization_fn: Callable[[np.ndarray, np.ndarray], np.ndarray],
-):
-    # Compute orthogonal basis
-    seed = abs(32 + hash(basis_size))
-    basis = np.random.default_rng(seed).normal(size=(vector.shape[0], basis_size))
-    orthogonal_basis, _ = np.linalg.qr(basis)
-    orthogonal_basis = orthogonal_basis.T
-
-    # Orthogonalize vector
-    ortho_vector = orthogonalization_fn(
-        v=vector, orthogonal_basis=orthogonal_basis, inner_product=inprod
-    )
-    np.testing.assert_allclose(
-        orthogonal_basis @ ortho_vector,
-        np.zeros((basis_size,)),
-        atol=1e-12,
-        rtol=1e-12,
-    )
-
-
-def test_is_normalized(
-    vector: np.ndarray,
-    basis_size: int,
-    orthogonalization_fn: Callable[[np.ndarray, np.ndarray], np.ndarray],
-):
-    # Compute orthogonal basis
-    seed = abs(9467 + hash(basis_size))
-    basis = np.random.default_rng(seed).normal(size=(vector.shape[0], basis_size))
-    orthogonal_basis, _ = np.linalg.qr(basis)
-    orthogonal_basis = orthogonal_basis.T
-
-    # Orthogonalize vector
-    ortho_vector = orthogonalization_fn(
-        v=vector, orthogonal_basis=orthogonal_basis, normalize=True
-    )
-
-    assert np.inner(ortho_vector, ortho_vector) == pytest.approx(1.0)
-
-
-@pytest.mark.parametrize(
-    "inner_product_matrix",
-    [
-        np.diag(np.random.default_rng(123).standard_gamma(1.0, size=(n,))),
-        5 * np.eye(n),
-        random_spd_matrix(rng=np.random.default_rng(46), dim=n),
-    ],
-)
-def test_noneuclidean_innerprod(
-    vector: np.ndarray,
-    basis_size: int,
-    inner_product_matrix: np.ndarray,
-    orthogonalization_fn: Callable[[np.ndarray, np.ndarray], np.ndarray],
-):
-    evals, evecs = np.linalg.eigh(inner_product_matrix)
-    orthogonal_basis = evecs * 1 / np.sqrt(evals)
-    orthogonal_basis = orthogonal_basis[:, 0:basis_size].T
-
-    # Orthogonalize vector
-    ortho_vector = orthogonalization_fn(
-        v=vector,
-        orthogonal_basis=orthogonal_basis,
-        inner_product=inner_product_matrix,
-        normalize=False,
-    )
-
-    np.testing.assert_allclose(
-        orthogonal_basis @ inner_product_matrix @ ortho_vector,
-        np.zeros((basis_size,)),
-        atol=1e-12,
-        rtol=1e-12,
-    )
-
-
-def test_broadcasting(
-    vectors: np.ndarray,
-    basis_size: int,
-    orthogonalization_fn: Callable[[np.ndarray, np.ndarray], np.ndarray],
-):
-    # Compute orthogonal basis
-    seed = abs(32 + hash(basis_size))
-    basis = np.random.default_rng(seed).normal(size=(vectors.shape[-1], basis_size))
-    orthogonal_basis, _ = np.linalg.qr(basis)
-    orthogonal_basis = orthogonal_basis.T
-
-    # Orthogonalize vector
-    ortho_vectors = orthogonalization_fn(v=vectors, orthogonal_basis=orthogonal_basis)
-    np.testing.assert_allclose(
-        np.squeeze(orthogonal_basis @ ortho_vectors[..., None], axis=-1),
-        np.zeros(vectors.shape[:-1] + (basis_size,)),
-        atol=1e-12,
-        rtol=1e-12,
-    )
diff --git a/tests/testing/__init__.py b/tests/testing/__init__.py
index 132fafc4a..6da1aa37e 100644
--- a/tests/testing/__init__.py
+++ b/tests/testing/__init__.py
@@ -1,2 +1 @@
 from .assertions import *
-from .statistics import *
diff --git a/tests/testing/statistics.py b/tests/testing/statistics.py
deleted file mode 100644
index 4907224d4..000000000
--- a/tests/testing/statistics.py
+++ /dev/null
@@ -1,47 +0,0 @@
-"""This module implements some test statistics that are used in multiple test suites."""
-
-
-import numpy as np
-
-__all__ = ["chi_squared_statistic"]
-
-
-def chi_squared_statistic(realisations, means, covs):
-    """Compute the multivariate chi-squared test statistic for a set of realisations of
-    a random variable.
-
-    For :math:`N`, :math:`d`-dimensional realisations :math:`x_1, ..., x_N`
-    with (assumed) means :math:`m_1, ..., m_N` and covariances
-    :math:`C_1, ..., C_N`, compute the value
-
-    .. math:`\\chi^2
-        = \\frac{1}{Nd} \\sum_{n=1}^N (x_n - m_n)^\\top C_n^{-1}(x_n - m_n).`
-
-    If it is roughly equal to 1, the samples are likely to correspond to given
-    mean and covariance.
-
-    Parameters
-    ----------
-    realisations : array_like
-        :math:`N` realisations of a :math:`d`-dimensional random variable. Shape (N, d).
-    means : array_like
-        :math:`N`, :math:`d`-dimensional (assumed) means of a random variable.
-        Shape (N, d).
-    realisations : array_like
-        :math:`N`, :math:`d \\times d`-dimensional (assumed) covariances of a random
-        variable. Shape (N, d, d).
-    """
-    if not realisations.shape == means.shape == covs.shape[:-1]:
-        print(realisations.shape, means.shape, covs.shape)
-        raise TypeError("Inputs do not align")
-    centered_realisations = realisations - means
-    centered_2 = np.linalg.solve(covs, centered_realisations)
-    return _dot_along_last_axis(centered_realisations, centered_2).mean()
-
-
-def _dot_along_last_axis(a, b):
-    """Dot product of (N, K) and (N, K) into (N,).
-
-    Extracted, because otherwise I keep having to look up einsum...
-    """
-    return np.einsum("...j, ...j->...", a, b)
diff --git a/tests/utils/__init__.py b/tests/utils/__init__.py
new file mode 100644
index 000000000..04987fba8
--- /dev/null
+++ b/tests/utils/__init__.py
@@ -0,0 +1 @@
+from . import random
diff --git a/tests/utils/random.py b/tests/utils/random.py
new file mode 100644
index 000000000..61348aac8
--- /dev/null
+++ b/tests/utils/random.py
@@ -0,0 +1,143 @@
+import hashlib
+import numbers
+from typing import Optional, Union
+
+import numpy as np
+
+from probnum import backend
+from probnum.backend.random import RNGState
+from probnum.backend.typing import DTypeLike, IntLike, ShapeLike
+
+__all__ = [
+    "rng_state_from_sampling_args",
+]
+
+
+def rng_state_from_sampling_args(
+    *,
+    base_seed: IntLike,
+    shape: ShapeLike,
+    dtype: Optional[DTypeLike] = None,
+    **kwargs: Union[numbers.Number, np.ndarray, backend.Array],
+) -> RNGState:
+    """Diversify random states for deterministic testing.
+
+    When writing a test relying on "random" input data generated from a fixed random
+    seed, a common pattern is to parametrize over seed and shape like so:
+
+    >>> import pytest
+    >>> from probnum.backend.typing import ShapeType
+    >>> @pytest.fixture(params=[42, 43])
+    ... def seed(request) -> int:
+    ...     return request.param
+
+    >>> @pytest.fixture(params=((2,), (4,)))
+    ... def shape(request) -> ShapeType:
+    ...     return request.param
+
+    >>> def test_function(seed: int, shape: ShapeType):
+    ...     x = backend.random.uniform(
+    ...         backend.random.rng_state(seed),
+    ...         shape=shape,
+    ...     )
+    ...     ...  # Test something
+
+    Unfortunately, when sampling with the same RNG state but with different shapes in
+    NumPy and JAX, some sampling routines produce partially identical arrays.
+
+    >>> np.random.default_rng(42).uniform(size=(2,))
+    array([0.77395605, 0.43887844])
+    >>> np.random.default_rng(42).uniform(size=(4,))
+    array([0.77395605, 0.43887844, 0.85859792, 0.69736803])
+
+    To diversify test data while retaining test determinism (in particular,
+    independence of the test execution order!), `rng_state_from_sampling_args`
+    provides a deterministic way to modify the base seed through other arguments
+    passed to the sampling routine:
+
+    >>> def test_data(seed: int, shape: ShapeType) -> backend.Array:
+    ...     return backend.random.uniform(
+    ...         rng_state_from_sampling_args(base_seed=seed, shape=shape),
+    ...         shape=shape,
+    ...     )
+
+    >>> backend.all(test_data(42, shape=(2,)) != test_data(42, shape=(4,))[:2])
+    True
+
+    Parameters
+    ----------
+    base_seed
+        Seed value common to all sample calls in a parametrized test.
+    shape
+        `shape` argument to the `backend.random.` call.
+    dtype
+        `dtype` argument to the `backend.random.` call.
+    **kwargs
+        Any other keyword argument passed to the `backend.random.` call.
+
+    Returns
+    -------
+    rng_state
+        An RNG state object that is deterministically generated from the function's
+        arguments using a cryptographic hash function.
+
+    Raises
+    ------
+    ValueError
+        If the `base_seed` is a negative number.
+    TypeError
+        If the type of any of the `kwargs` is not supported.
+    """
+
+    # Hash unique representations of the arguments into a 7-byte positive integer.
+    # We choose 7 bytes, since an 8-byte positive integer could already overflow as an
+    # int64.
+    h = hashlib.blake2b(digest_size=7)
+
+    # `base_seed`
+    base_seed = int(base_seed)
+
+    if base_seed < 0:
+        raise ValueError("`base_seed` must be a non-negative `int`")
+
+    h.update(hex(base_seed).encode())
+
+    # `shape`
+    shape = backend.asshape(shape)
+
+    h.update(b"(")
+
+    for entry in shape:
+        h.update(hex(entry).encode())
+
+    h.update(b")")
+
+    # `dtype`
+    if dtype is not None:
+        dtype = backend.asdtype(dtype)
+
+        h.update(str(dtype).encode())
+
+    # `kwargs`
+    for key, value in kwargs.items():
+        h.update(key.encode())
+
+        if isinstance(value, numbers.Number) and (
+            # NumPy doesn't handle `fractions.Fraction` too well
+            not isinstance(value, numbers.Rational)
+            or isinstance(value, numbers.Real)
+        ):
+            h.update(np.asarray(value).tobytes())
+        elif isinstance(value, np.ndarray):
+            h.update(value.tobytes(order="A"))
+        elif backend.isarray(value):
+            h.update(backend.to_numpy(value).tobytes(order="A"))
+        else:
+            raise TypeError(
+                "Values passed by `kwargs` must be either numbers, `np.ndarray`s, or "
+                f"`backend.Array`s, not {type(value)}."
+            )
+
+    # Convert hash to positive integer
+    seed_int = abs(int(h.hexdigest(), base=16))
+
+    return backend.random.rng_state(seed_int)
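For concreteness, the following is a minimal sketch of how `rng_state_from_sampling_args` slots into a parametrized test. The fixture values and the test body are hypothetical; `backend.random.standard_normal` and `backend.random.rng_state` are used with the signatures that appear elsewhere in this changeset.

import pytest

from probnum import backend
from probnum.backend.typing import ShapeType
from tests.utils.random import rng_state_from_sampling_args


@pytest.fixture(params=[42, 43])
def seed(request) -> int:
    return request.param


@pytest.fixture(params=[(2,), (4,)])
def shape(request) -> ShapeType:
    return request.param


def test_sample_shapes(seed: int, shape: ShapeType):
    # The shape is hashed into the seed, so draws for different shapes under
    # the same base seed do not share a common prefix.
    rng_state = rng_state_from_sampling_args(base_seed=seed, shape=shape)
    x = backend.random.standard_normal(rng_state, shape=shape)
    assert x.shape == shape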
diff --git a/tox.ini b/tox.ini
index 4705caac6..e53d364ef 100644
--- a/tox.ini
+++ b/tox.ini
@@ -4,11 +4,18 @@
 # and then run "tox" from this directory.

 [tox]
-envlist = py3, docs, benchmarks, black, isort, pylint
+envlist = py3-{numpy,jax,torch}, docs, benchmarks, black, isort, pylint

 [testenv]
 usedevelop = True
-extras = full
+extras =
+    full
+    jax: jax
+    torch: torch
+setenv =
+    numpy: PROBNUM_BACKEND = numpy
+    jax: PROBNUM_BACKEND = jax
+    torch: PROBNUM_BACKEND = torch
 deps = -r{toxinidir}/tests/requirements.txt
 commands = pytest {posargs:--cov=probnum --no-cov-on-fail --cov-report=xml} --doctest-modules --color=yes
@@ -66,6 +73,8 @@ commands =
     # Global Linting Pass
     pylint src/probnum --disable="no-member,abstract-method,arguments-differ,arguments-renamed,redefined-builtin,redefined-outer-name,too-many-instance-attributes,too-many-arguments,too-many-locals,too-many-lines,too-many-statements,too-many-branches,too-complex,too-few-public-methods,protected-access,unnecessary-pass,unused-variable,unused-argument,attribute-defined-outside-init,no-else-return,no-else-raise,no-self-use,else-if-used,consider-using-from-import,duplicate-code,missing-module-docstring,missing-class-docstring,missing-function-docstring,missing-param-doc,missing-type-doc,missing-raises-doc,useless-param-doc,useless-type-doc,missing-return-type-doc" --jobs=0
     # Per-package Linting Passes
+    pylint src/probnum/backend --jobs=0
+    pylint src/probnum/compat --jobs=0
     pylint src/probnum/diffeq --disable="redefined-outer-name,too-many-instance-attributes,too-many-arguments,too-many-locals,too-few-public-methods,protected-access,unnecessary-pass,unused-variable,unused-argument,no-self-use,duplicate-code,missing-function-docstring,missing-param-doc,missing-type-doc,missing-raises-doc,missing-return-type-doc" --jobs=0
     pylint src/probnum/filtsmooth --disable="no-member,arguments-differ,too-many-arguments,too-many-locals,too-few-public-methods,protected-access,unused-variable,unused-argument,no-self-use,duplicate-code,useless-param-doc" --jobs=0
     pylint src/probnum/linalg --disable="no-member,abstract-method,arguments-differ,else-if-used,redefined-builtin,too-many-instance-attributes,too-many-arguments,too-many-locals,too-many-lines,too-many-statements,too-many-branches,too-complex,too-few-public-methods,protected-access,unused-argument,attribute-defined-outside-init,no-else-return,no-else-raise,no-self-use,duplicate-code,missing-module-docstring,missing-param-doc,missing-type-doc,missing-raises-doc,missing-return-type-doc" --jobs=0
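Each entry of the new environment matrix, e.g. `tox -e py3-jax`, runs the same test suite with `PROBNUM_BACKEND` set to the corresponding backend. A backend-agnostic test that all three environments would execute unchanged might look like the following sketch (a hypothetical test; it assumes only the `backend.random` and `backend.to_numpy` calls used above).

from probnum import backend


def test_uniform_unit_interval():
    # All array operations go through the `probnum.backend` dispatch layer, so
    # this test runs identically under the numpy, jax, and torch backends.
    rng_state = backend.random.rng_state(42)
    x = backend.random.uniform(rng_state, shape=(10,))

    # Convert to NumPy so the assertions are backend-independent.
    x_np = backend.to_numpy(x)
    assert x_np.shape == (10,)
    assert ((0.0 <= x_np) & (x_np < 1.0)).all()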