diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 561cde407..8fc2a8cf5 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -53,7 +53,7 @@ jobs: cache-environment: true - name: Install pymc-extras run: | - pip install -e . + pip install -e ".[dev]" python --version - name: Run tests run: | @@ -97,7 +97,7 @@ jobs: cache-environment: true - name: Install pymc-extras run: | - pip install -e . + pip install -e ".[dev]" python --version - name: Run tests # This job uses a cmd shell, therefore the environment variable syntax is different! diff --git a/conda-envs/environment-test.yml b/conda-envs/environment-test.yml index 25251e5e0..66ad0149a 100644 --- a/conda-envs/environment-test.yml +++ b/conda-envs/environment-test.yml @@ -13,6 +13,8 @@ dependencies: - pytest - pytest-cov - libgcc<15 + - pydantic>=2.0.0 + - preliz - pip - pip: - jax diff --git a/docs/api_reference.rst b/docs/api_reference.rst index 5d4f62bb2..fbe157ce0 100644 --- a/docs/api_reference.rst +++ b/docs/api_reference.rst @@ -46,6 +46,21 @@ Distributions Skellam histogram_approximation +Prior +===== + +.. currentmodule:: pymc_extras.prior +.. autosummary:: + :toctree: generated/ + + create_dim_handler + handle_dims + Prior + VariableFactory + sample_prior + Censored + Scaled + Transforms ========== diff --git a/pymc_extras/prior.py b/pymc_extras/prior.py new file mode 100644 index 000000000..20702bfc9 --- /dev/null +++ b/pymc_extras/prior.py @@ -0,0 +1,1186 @@ +"""Class that represents a prior distribution. + +The `Prior` class is a wrapper around PyMC distributions that allows the user +to create priors outside of a PyMC model. + +Examples +-------- +Create a normal prior. + +.. code-block:: python + + from pymc_extras.prior import Prior + + normal = Prior("Normal") + +Create a hierarchical normal prior by using distributions for the parameters +and specifying the dims. + +.. 
code-block:: python + + hierarchical_normal = Prior( + "Normal", + mu=Prior("Normal"), + sigma=Prior("HalfNormal"), + dims="channel", + ) + +Create a non-centered hierarchical normal prior with the `centered` parameter. + +.. code-block:: python + + non_centered_hierarchical_normal = Prior( + "Normal", + mu=Prior("Normal"), + sigma=Prior("HalfNormal"), + dims="channel", + # Only change needed to make it non-centered + centered=False, + ) + +Create a hierarchical beta prior by using Beta distribution, distributions for +the parameters, and specifying the dims. + +.. code-block:: python + + hierarchical_beta = Prior( + "Beta", + alpha=Prior("HalfNormal"), + beta=Prior("HalfNormal"), + dims="channel", + ) + +Create a transformed hierarchical normal prior by using the `transform` +parameter. Here the "sigmoid" transformation comes from `pm.math`. + +.. code-block:: python + + transformed_hierarchical_normal = Prior( + "Normal", + mu=Prior("Normal"), + sigma=Prior("HalfNormal"), + transform="sigmoid", + dims="channel", + ) + +Create a prior with a custom transform function by registering it with +`register_tensor_transform`. + +.. 
code-block:: python + + from pymc_extras.prior import register_tensor_transform + + def custom_transform(x): + return x ** 2 + + register_tensor_transform("square", custom_transform) + + custom_distribution = Prior("Normal", transform="square") + +""" + +from __future__ import annotations + +import copy + +from collections.abc import Callable +from inspect import signature +from typing import Any, Protocol, runtime_checkable + +import numpy as np +import pymc as pm +import pytensor.tensor as pt +import xarray as xr + +from pydantic import InstanceOf, validate_call +from pydantic.dataclasses import dataclass +from pymc.distributions.shape_utils import Dims + + +class UnsupportedShapeError(Exception): + """Error for when the shapes from variables are not compatible.""" + + +class UnsupportedDistributionError(Exception): + """Error for when an unsupported distribution is used.""" + + +class UnsupportedParameterizationError(Exception): + """Error for when an unsupported parameterization is used.""" + + +class MuAlreadyExistsError(Exception): + """Error for when 'mu' is present in Prior.""" + + def __init__(self, distribution: Prior) -> None: + self.distribution = distribution + self.message = f"The mu parameter is already defined in {distribution}" + super().__init__(self.message) + + +class UnknownTransformError(Exception): + """Error for when an unknown transform is used.""" + + +def _remove_leading_xs(args: list[str | int]) -> list[str | int]: + """Remove leading 'x' from the args.""" + while args and args[0] == "x": + args.pop(0) + + return args + + +def handle_dims(x: pt.TensorLike, dims: Dims, desired_dims: Dims) -> pt.TensorVariable: + """Take a tensor of dims `dims` and align it to `desired_dims`. + + Only checks that `dims` is a subset of `desired_dims`. + + Examples + -------- + 1D to 2D with new dim + + .. 
code-block:: python + + x = np.array([1, 2, 3]) + dims = "channel" + + desired_dims = ("channel", "group") + + handle_dims(x, dims, desired_dims) + + """ + x = pt.as_tensor_variable(x) + + if np.ndim(x) == 0: + return x + + dims = dims if isinstance(dims, tuple) else (dims,) + desired_dims = desired_dims if isinstance(desired_dims, tuple) else (desired_dims,) + + if difference := set(dims).difference(desired_dims): + raise UnsupportedShapeError( + f"Dims {dims} of data are not a subset of the desired dims {desired_dims}. " + f"{difference} is missing from the desired dims." + ) + + aligned_dims = np.array(dims)[:, None] == np.array(desired_dims) + + missing_dims = aligned_dims.sum(axis=0) == 0 + new_idx = aligned_dims.argmax(axis=0) + + args = ["x" if missing else idx for (idx, missing) in zip(new_idx, missing_dims, strict=False)] + args = _remove_leading_xs(args) + return x.dimshuffle(*args) + + +DimHandler = Callable[[pt.TensorLike, Dims], pt.TensorLike] + + +def create_dim_handler(desired_dims: Dims) -> DimHandler: + """Wrap the `handle_dims` function to act like the previous `create_dim_handler` function.""" + + def func(x: pt.TensorLike, dims: Dims) -> pt.TensorVariable: + return handle_dims(x, dims, desired_dims) + + return func + + +def _dims_to_str(obj: tuple[str, ...]) -> str: + if len(obj) == 1: + return f'"{obj[0]}"' + + return "(" + ", ".join(f'"{i}"' if isinstance(i, str) else str(i) for i in obj) + ")" + + +def _get_pymc_distribution(name: str) -> type[pm.Distribution]: + if not hasattr(pm, name): + raise UnsupportedDistributionError(f"PyMC doesn't have a distribution of name {name!r}") + + return getattr(pm, name) + + +Transform = Callable[[pt.TensorLike], pt.TensorLike] + +CUSTOM_TRANSFORMS: dict[str, Transform] = {} + + +def register_tensor_transform(name: str, transform: Transform) -> None: + """Register a tensor transform function to be used in the `Prior` class. + + Parameters + ---------- + name : str + The name of the transform. 
+ transform : Callable[[pt.TensorLike], pt.TensorLike] + The function to apply to the tensor. + + Examples + -------- + Register a custom transform function. + + .. code-block:: python + + from pymc_extras.prior import ( + Prior, + register_tensor_transform, + ) + + def custom_transform(x): + return x ** 2 + + register_tensor_transform("square", custom_transform) + + custom_distribution = Prior("Normal", transform="square") + + """ + CUSTOM_TRANSFORMS[name] = transform + + +def _get_transform(name: str): + if name in CUSTOM_TRANSFORMS: + return CUSTOM_TRANSFORMS[name] + + for module in (pt, pm.math): + if hasattr(module, name): + break + else: + module = None + + if not module: + msg = ( + f"Neither pytensor.tensor nor pymc.math have the function {name!r}. " + "If this is a custom function, register it with the " + "`pymc_extras.prior.register_tensor_transform` function before " + "previous function call." + ) + + raise UnknownTransformError(msg) + + return getattr(module, name) + + +def _get_pymc_parameters(distribution: pm.Distribution) -> set[str]: + return set(signature(distribution.dist).parameters.keys()) - {"kwargs", "args"} + + +@runtime_checkable +class VariableFactory(Protocol): + """Protocol for something that works like a Prior class.""" + + dims: tuple[str, ...] + + def create_variable(self, name: str) -> pt.TensorVariable: + """Create a TensorVariable.""" + + +def sample_prior( + factory: VariableFactory, + coords=None, + name: str = "var", + wrap: bool = False, + **sample_prior_predictive_kwargs, +) -> xr.Dataset: + """Sample the prior for an arbitrary VariableFactory. + + Parameters + ---------- + factory : VariableFactory + The factory to sample from. + coords : dict[str, list[str]], optional + The coordinates for the variable, by default None. + Only required if the dims are specified. + name : str, optional + The name of the variable, by default "var". 
+ wrap : bool, optional + Whether to wrap the variable in a `pm.Deterministic` node, by default False. + sample_prior_predictive_kwargs : dict + Additional arguments to pass to `pm.sample_prior_predictive`. + + Returns + ------- + xr.Dataset + The dataset of the prior samples. + + Example + ------- + Sample from an arbitrary variable factory. + + .. code-block:: python + + import pymc as pm + + import pytensor.tensor as pt + + from pymc_extras.prior import sample_prior + + class CustomVariableDefinition: + def __init__(self, dims, n: int): + self.dims = dims + self.n = n + + def create_variable(self, name: str) -> "TensorVariable": + x = pm.Normal(f"{name}_x", mu=0, sigma=1, dims=self.dims) + return pt.sum([x ** n for n in range(1, self.n + 1)], axis=0) + + cubic = CustomVariableDefinition(dims=("channel",), n=3) + coords = {"channel": ["C1", "C2", "C3"]} + # Doesn't include the return value + prior = sample_prior(cubic, coords=coords) + + prior_with = sample_prior(cubic, coords=coords, wrap=True) + + """ + coords = coords or {} + + if isinstance(factory.dims, str): + dims = (factory.dims,) + else: + dims = factory.dims + + if missing_keys := set(dims) - set(coords.keys()): + raise KeyError(f"Coords are missing the following dims: {missing_keys}") + + with pm.Model(coords=coords) as model: + if wrap: + pm.Deterministic(name, factory.create_variable(name), dims=factory.dims) + else: + factory.create_variable(name) + + return pm.sample_prior_predictive( + model=model, + **sample_prior_predictive_kwargs, + ).prior + + +class Prior: + """A class to represent a prior distribution. + + Make use of the various helper methods to understand the distributions + better. 
+ + - `preliz` attribute to get the equivalent distribution in `preliz` + - `sample_prior` method to sample from the prior + - `to_graph` method to get a dummy model graph with the distribution + - `constrain` method to shift the distribution to a different range + + Parameters + ---------- + distribution : str + The name of the PyMC distribution. + dims : Dims, optional + The dimensions of the variable, by default None + centered : bool, optional + Whether the variable is centered or not, by default True. + Non-centered is only allowed for Normal, StudentT and ZeroSumNormal. + transform : str, optional + The name of the transform to apply to the variable after it is + created, by default None or no transform. The transformation must + be registered with `register_tensor_transform` function or + be available in either `pytensor.tensor` or `pymc.math`. + + """ + + # Taken from https://en.wikipedia.org/wiki/Location%E2%80%93scale_family + non_centered_distributions: dict[str, dict[str, float]] = { + "Normal": {"mu": 0, "sigma": 1}, + "StudentT": {"mu": 0, "sigma": 1}, + "ZeroSumNormal": {"sigma": 1}, + } + + pymc_distribution: type[pm.Distribution] + pytensor_transform: Callable[[pt.TensorLike], pt.TensorLike] | None + + @validate_call + def __init__( + self, + distribution: str, + *, + dims: Dims | None = None, + centered: bool = True, + transform: str | None = None, + **parameters, + ) -> None: + self.distribution = distribution + self.parameters = parameters + self.dims = dims + self.centered = centered + self.transform = transform + + self._checks() + + @property + def distribution(self) -> str: + """The name of the PyMC distribution.""" + return self._distribution + + @distribution.setter + def distribution(self, distribution: str) -> None: + if hasattr(self, "_distribution"): + raise AttributeError("Can't change the distribution") + + self._distribution = distribution + self.pymc_distribution = _get_pymc_distribution(distribution) + + @property + def transform(self) -> str | None: + """The name of the transform to 
apply to the variable after it is created.""" + return self._transform + + @transform.setter + def transform(self, transform: str | None) -> None: + self._transform = transform + self.pytensor_transform = _get_transform(transform) if transform else None + + @property + def dims(self) -> Dims: + """The dimensions of the variable.""" + return self._dims + + @dims.setter + def dims(self, dims) -> None: + if isinstance(dims, str): + dims = (dims,) + + if isinstance(dims, list): + dims = tuple(dims) + + self._dims = dims or () + + self._param_dims_work() + self._unique_dims() + + def __getitem__(self, key: str) -> Prior | Any: + """Return the parameter of the prior.""" + return self.parameters[key] + + def _checks(self) -> None: + if not self.centered: + self._correct_non_centered_distribution() + + self._parameters_are_at_least_subset_of_pymc() + self._convert_lists_to_numpy() + self._parameters_are_correct_type() + + def _parameters_are_at_least_subset_of_pymc(self) -> None: + pymc_params = _get_pymc_parameters(self.pymc_distribution) + if not set(self.parameters.keys()).issubset(pymc_params): + msg = ( + f"Parameters {set(self.parameters.keys())} " + "are not a subset of the pymc distribution " + f"parameters {set(pymc_params)}" + ) + raise ValueError(msg) + + def _convert_lists_to_numpy(self) -> None: + def convert(x): + if not isinstance(x, list): + return x + + return np.array(x) + + self.parameters = {key: convert(value) for key, value in self.parameters.items()} + + def _parameters_are_correct_type(self) -> None: + supported_types = ( + int, + float, + np.ndarray, + Prior, + pt.TensorVariable, + VariableFactory, + ) + + incorrect_types = { + param: type(value) + for param, value in self.parameters.items() + if not isinstance(value, supported_types) + } + if incorrect_types: + msg = ( + "Parameters must be one of the following types: " + f"(int, float, np.array, Prior, pt.TensorVariable). 
Incorrect parameters: {incorrect_types}" + ) + raise ValueError(msg) + + def _correct_non_centered_distribution(self) -> None: + if not self.centered and self.distribution not in self.non_centered_distributions: + raise UnsupportedParameterizationError( + f"{self.distribution!r} is not supported for non-centered parameterization. " + f"Choose from {list(self.non_centered_distributions.keys())}" + ) + + required_parameters = set(self.non_centered_distributions[self.distribution].keys()) + + if set(self.parameters.keys()) < required_parameters: + msg = " and ".join([f"{param!r}" for param in required_parameters]) + raise ValueError( + f"Must have at least {msg} parameter for non-centered for {self.distribution!r}" + ) + + def _unique_dims(self) -> None: + if not self.dims: + return + + if len(self.dims) != len(set(self.dims)): + raise ValueError("Dims must be unique") + + def _param_dims_work(self) -> None: + other_dims = set() + for value in self.parameters.values(): + if hasattr(value, "dims"): + other_dims.update(value.dims) + + if not other_dims.issubset(self.dims): + raise UnsupportedShapeError( + f"Parameter dims {other_dims} are not a subset of the prior dims {self.dims}" + ) + + def __str__(self) -> str: + """Return a string representation of the prior.""" + param_str = ", ".join([f"{param}={value}" for param, value in self.parameters.items()]) + param_str = "" if not param_str else f", {param_str}" + + dim_str = f", dims={_dims_to_str(self.dims)}" if self.dims else "" + centered_str = f", centered={self.centered}" if not self.centered else "" + transform_str = f', transform="{self.transform}"' if self.transform else "" + return f'Prior("{self.distribution}"{param_str}{dim_str}{centered_str}{transform_str})' + + def __repr__(self) -> str: + """Return a string representation of the prior.""" + return f"{self}" + + def _create_parameter(self, param, value, name): + if not hasattr(value, "create_variable"): + return value + + child_name = f"{name}_{param}" + 
return self.dim_handler(value.create_variable(child_name), value.dims) + + def _create_centered_variable(self, name: str): + parameters = { + param: self._create_parameter(param, value, name) + for param, value in self.parameters.items() + } + return self.pymc_distribution(name, **parameters, dims=self.dims) + + def _create_non_centered_variable(self, name: str) -> pt.TensorVariable: + def handle_variable(var_name: str): + parameter = self.parameters[var_name] + if not hasattr(parameter, "create_variable"): + return parameter + + return self.dim_handler( + parameter.create_variable(f"{name}_{var_name}"), + parameter.dims, + ) + + defaults = self.non_centered_distributions[self.distribution] + other_parameters = { + param: handle_variable(param) + for param in self.parameters.keys() + if param not in defaults + } + offset = self.pymc_distribution( + f"{name}_offset", + **defaults, + **other_parameters, + dims=self.dims, + ) + if "mu" in self.parameters: + mu = ( + handle_variable("mu") + if isinstance(self.parameters["mu"], Prior) + else self.parameters["mu"] + ) + else: + mu = 0 + + sigma = ( + handle_variable("sigma") + if isinstance(self.parameters["sigma"], Prior) + else self.parameters["sigma"] + ) + + return pm.Deterministic( + name, + mu + sigma * offset, + dims=self.dims, + ) + + def create_variable(self, name: str) -> pt.TensorVariable: + """Create a PyMC variable from the prior. + + Must be used in a PyMC model context. + + Parameters + ---------- + name : str + The name of the variable. + + Returns + ------- + pt.TensorVariable + The PyMC variable. + + Examples + -------- + Create a hierarchical normal variable in larger PyMC model. + + .. 
code-block:: python + + dist = Prior( + "Normal", + mu=Prior("Normal"), + sigma=Prior("HalfNormal"), + dims="channel", + ) + + coords = {"channel": ["C1", "C2", "C3"]} + with pm.Model(coords=coords): + var = dist.create_variable("var") + + """ + self.dim_handler = create_dim_handler(self.dims) + + if self.transform: + var_name = f"{name}_raw" + + def transform(var): + return pm.Deterministic(name, self.pytensor_transform(var), dims=self.dims) + else: + var_name = name + + def transform(var): + return var + + create_variable = ( + self._create_centered_variable if self.centered else self._create_non_centered_variable + ) + var = create_variable(name=var_name) + return transform(var) + + @property + def preliz(self): + """Create an equivalent preliz distribution. + + Helpful to visualize a distribution when it is univariate. + + Returns + ------- + preliz.distributions.Distribution + + Examples + -------- + Create a preliz distribution from a prior. + + .. code-block:: python + + from pymc_extras.prior import Prior + + dist = Prior("Gamma", alpha=5, beta=1) + dist.preliz.plot_pdf() + + """ + import preliz as pz + + return getattr(pz, self.distribution)(**self.parameters) + + def constrain(self, lower: float, upper: float, mass: float = 0.95, kwargs=None) -> Prior: + """Create a new prior with a given mass constrained within the given bounds. + + Wrapper around `preliz.maxent`. + + Parameters + ---------- + lower : float + The lower bound. + upper : float + The upper bound. + mass: float = 0.95 + The mass of the distribution to keep within the bounds. + kwargs : dict + Additional arguments to pass to `pz.maxent`. + + Returns + ------- + Prior + The maximum entropy prior with a mass constrained to the given bounds. + + Examples + -------- + Create a Beta distribution that is constrained to have 95% of the mass + between 0.5 and 0.8. + + .. 
code-block:: python + + dist = Prior( + "Beta", + ).constrain(lower=0.5, upper=0.8) + + Create a Beta distribution with mean 0.6, that is constrained to + have 95% of the mass between 0.5 and 0.8. + + .. code-block:: python + + dist = Prior( + "Beta", + mu=0.6, + ).constrain(lower=0.5, upper=0.8) + + """ + from preliz import maxent + + if self.transform: + raise ValueError("Can't constrain a transformed variable") + + if kwargs is None: + kwargs = {} + kwargs.setdefault("plot", False) + + if kwargs["plot"]: + new_parameters = maxent(self.preliz, lower, upper, mass, **kwargs)[0].params_dict + else: + new_parameters = maxent(self.preliz, lower, upper, mass, **kwargs).params_dict + + return Prior( + self.distribution, + dims=self.dims, + transform=self.transform, + centered=self.centered, + **new_parameters, + ) + + def __eq__(self, other) -> bool: + """Check if two priors are equal.""" + if not isinstance(other, Prior): + return False + + try: + np.testing.assert_equal(self.parameters, other.parameters) + except AssertionError: + return False + + return ( + self.distribution == other.distribution + and self.dims == other.dims + and self.centered == other.centered + and self.transform == other.transform + ) + + def sample_prior( + self, + coords=None, + name: str = "var", + **sample_prior_predictive_kwargs, + ) -> xr.Dataset: + """Sample the prior distribution for the variable. + + Parameters + ---------- + coords : dict[str, list[str]], optional + The coordinates for the variable, by default None. + Only required if the dims are specified. + name : str, optional + The name of the variable, by default "var". + sample_prior_predictive_kwargs : dict + Additional arguments to pass to `pm.sample_prior_predictive`. + + Returns + ------- + xr.Dataset + The dataset of the prior samples. + + Example + ------- + Sample from a hierarchical normal distribution. + + .. 
code-block:: python + + dist = Prior( + "Normal", + mu=Prior("Normal"), + sigma=Prior("HalfNormal"), + dims="channel", + ) + + coords = {"channel": ["C1", "C2", "C3"]} + prior = dist.sample_prior(coords=coords) + + """ + return sample_prior( + factory=self, + coords=coords, + name=name, + **sample_prior_predictive_kwargs, + ) + + def __deepcopy__(self, memo) -> Prior: + """Return a deep copy of the prior.""" + if id(self) in memo: + return memo[id(self)] + + copy_obj = Prior( + self.distribution, + dims=copy.copy(self.dims), + centered=self.centered, + transform=self.transform, + **copy.deepcopy(self.parameters), + ) + memo[id(self)] = copy_obj + return copy_obj + + def deepcopy(self) -> Prior: + """Return a deep copy of the prior.""" + return copy.deepcopy(self) + + def to_graph(self): + """Generate a graph of the variables. + + Examples + -------- + Create the graph for a 2D transformed hierarchical distribution. + + .. code-block:: python + + from pymc_extras.prior import Prior + + mu = Prior( + "Normal", + mu=Prior("Normal"), + sigma=Prior("HalfNormal"), + dims="channel", + ) + sigma = Prior("HalfNormal", dims="channel") + dist = Prior( + "Normal", + mu=mu, + sigma=sigma, + dims=("channel", "geo"), + centered=False, + transform="sigmoid", + ) + + dist.to_graph() + + .. image:: /_static/example-graph.png + :alt: Example graph + + """ + coords = {name: ["DUMMY"] for name in self.dims} + with pm.Model(coords=coords) as model: + self.create_variable("var") + + return pm.model_to_graphviz(model) + + def create_likelihood_variable( + self, + name: str, + mu: pt.TensorLike, + observed: pt.TensorLike, + ) -> pt.TensorVariable: + """Create a likelihood variable from the prior. + + Will require that the distribution has a `mu` parameter + and that it has not been set in the parameters. + + Parameters + ---------- + name : str + The name of the variable. + mu : pt.TensorLike + The mu parameter for the likelihood. + observed : pt.TensorLike + The observed data. 
+ + Returns + ------- + pt.TensorVariable + The PyMC variable. + + Examples + -------- + Create a likelihood variable in a larger PyMC model. + + .. code-block:: python + + import pymc as pm + + dist = Prior("Normal", sigma=Prior("HalfNormal")) + + with pm.Model(): + # Create the likelihood variable + mu = pm.Normal("mu", mu=0, sigma=1) + dist.create_likelihood_variable("y", mu=mu, observed=observed) + + """ + if "mu" not in _get_pymc_parameters(self.pymc_distribution): + raise UnsupportedDistributionError( + f"Likelihood distribution {self.distribution!r} is not supported." + ) + + if "mu" in self.parameters: + raise MuAlreadyExistsError(self) + + distribution = self.deepcopy() + distribution.parameters["mu"] = mu + distribution.parameters["observed"] = observed + return distribution.create_variable(name) + + +class VariableNotFound(Exception): + """Variable is not found.""" + + +def _remove_random_variable(var: pt.TensorVariable) -> None: + if var.name is None: + raise ValueError("This isn't removable") + + name: str = var.name + + model = pm.modelcontext(None) + for idx, free_rv in enumerate(model.free_RVs): + if var == free_rv: + index_to_remove = idx + break + else: + raise VariableNotFound(f"Variable {var.name!r} not found") + + var.name = None + model.free_RVs.pop(index_to_remove) + model.named_vars.pop(name) + + +@dataclass +class Censored: + """Create censored random variable. + + Examples + -------- + Create a censored Normal distribution: + + .. code-block:: python + + from pymc_extras.prior import Prior, Censored + + normal = Prior("Normal") + censored_normal = Censored(normal, lower=0) + + Create hierarchical censored Normal distribution: + + .. 
code-block:: python + + from pymc_extras.prior import Prior, Censored + + normal = Prior( + "Normal", + mu=Prior("Normal"), + sigma=Prior("HalfNormal"), + dims="channel", + ) + censored_normal = Censored(normal, lower=0) + + coords = {"channel": range(3)} + samples = censored_normal.sample_prior(coords=coords) + + """ + + distribution: InstanceOf[Prior] + lower: float | InstanceOf[pt.TensorVariable] = -np.inf + upper: float | InstanceOf[pt.TensorVariable] = np.inf + + def __post_init__(self) -> None: + """Check validity at initialization.""" + if not self.distribution.centered: + raise ValueError( + "Censored distribution must be centered so that .dist() API can be used on distribution." + ) + + if self.distribution.transform is not None: + raise ValueError( + "Censored distribution can't have a transform so that .dist() API can be used on distribution." + ) + + @property + def dims(self) -> tuple[str, ...]: + """The dims from the distribution to censor.""" + return self.distribution.dims + + @dims.setter + def dims(self, dims) -> None: + self.distribution.dims = dims + + def create_variable(self, name: str) -> pt.TensorVariable: + """Create censored random variable.""" + dist = self.distribution.create_variable(name) + _remove_random_variable(var=dist) + + return pm.Censored( + name, + dist, + lower=self.lower, + upper=self.upper, + dims=self.dims, + ) + + def sample_prior( + self, + coords=None, + name: str = "variable", + **sample_prior_predictive_kwargs, + ) -> xr.Dataset: + """Sample the prior distribution for the variable. + + Parameters + ---------- + coords : dict[str, list[str]], optional + The coordinates for the variable, by default None. + Only required if the dims are specified. + name : str, optional + The name of the variable, by default "variable". + sample_prior_predictive_kwargs : dict + Additional arguments to pass to `pm.sample_prior_predictive`. + + Returns + ------- + xr.Dataset + The dataset of the prior samples. 
+ + Example + ------- + Sample from a censored Gamma distribution. + + .. code-block:: python + + gamma = Prior("Gamma", mu=1, sigma=1, dims="channel") + dist = Censored(gamma, lower=0.5) + + coords = {"channel": ["C1", "C2", "C3"]} + prior = dist.sample_prior(coords=coords) + + """ + return sample_prior( + factory=self, + coords=coords, + name=name, + **sample_prior_predictive_kwargs, + ) + + def to_graph(self): + """Generate a graph of the variables. + + Examples + -------- + Create graph for a censored Normal distribution + + .. code-block:: python + + from pymc_extras.prior import Prior, Censored + + normal = Prior("Normal") + censored_normal = Censored(normal, lower=0) + + censored_normal.to_graph() + + """ + coords = {name: ["DUMMY"] for name in self.dims} + with pm.Model(coords=coords) as model: + self.create_variable("var") + + return pm.model_to_graphviz(model) + + def create_likelihood_variable( + self, + name: str, + mu: pt.TensorLike, + observed: pt.TensorLike, + ) -> pt.TensorVariable: + """Create observed censored variable. + + Will require that the distribution has a `mu` parameter + and that it has not been set in the parameters. + + Parameters + ---------- + name : str + The name of the variable. + mu : pt.TensorLike + The mu parameter for the likelihood. + observed : pt.TensorLike + The observed data. + + Returns + ------- + pt.TensorVariable + The PyMC variable. + + Examples + -------- + Create a censored likelihood variable in a larger PyMC model. + + .. 
code-block:: python + + import pymc as pm + from pymc_extras.prior import Prior, Censored + + normal = Prior("Normal", sigma=Prior("HalfNormal")) + dist = Censored(normal, lower=0) + + observed = 1 + + with pm.Model(): + # Create the likelihood variable + mu = pm.HalfNormal("mu", sigma=1) + dist.create_likelihood_variable("y", mu=mu, observed=observed) + + """ + if "mu" not in _get_pymc_parameters(self.distribution.pymc_distribution): + raise UnsupportedDistributionError( + f"Likelihood distribution {self.distribution.distribution!r} is not supported." + ) + + if "mu" in self.distribution.parameters: + raise MuAlreadyExistsError(self.distribution) + + distribution = self.distribution.deepcopy() + distribution.parameters["mu"] = mu + + dist = distribution.create_variable(name) + _remove_random_variable(var=dist) + + return pm.Censored( + name, + dist, + observed=observed, + lower=self.lower, + upper=self.upper, + dims=self.dims, + ) + + +class Scaled: + """Scaled distribution for numerical stability.""" + + def __init__(self, dist: Prior, factor: float | pt.TensorVariable) -> None: + self.dist = dist + self.factor = factor + + @property + def dims(self) -> Dims: + """The dimensions of the scaled distribution.""" + return self.dist.dims + + def create_variable(self, name: str) -> pt.TensorVariable: + """Create a scaled variable. + + Parameters + ---------- + name : str + The name of the variable. + + Returns + ------- + pt.TensorVariable + The scaled variable. 
+        """
+        var = self.dist.create_variable(f"{name}_unscaled")
+        return pm.Deterministic(name, var * self.factor, dims=self.dims)
diff --git a/pyproject.toml b/pyproject.toml
index 02a604703..982699723 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -37,7 +37,9 @@ dynamic = ["version"] # specify the version in the __init__.py file
 dependencies = [
     "pymc>=5.21.1",
     "scikit-learn",
-    "better-optimize>=0.1.2"
+    "better-optimize>=0.1.2",
+    "pydantic>=2.0.0",
+    "preliz",
 ]
 
 [project.optional-dependencies]
@@ -47,6 +49,7 @@ complete = [
 ]
 dev = [
     "pytest>=6.0",
+    "pytest-mock",
     "dask[all]<2025.1.1",
     "blackjax",
     "statsmodels",
diff --git a/tests/test_prior.py b/tests/test_prior.py
new file mode 100644
index 000000000..c534da021
--- /dev/null
+++ b/tests/test_prior.py
@@ -0,0 +1,933 @@
+from copy import deepcopy
+from typing import NamedTuple
+
+import numpy as np
+import pymc as pm
+import pytensor.tensor as pt
+import pytest
+import xarray as xr
+
+from graphviz.graphs import Digraph
+from preliz.distributions import distributions as preliz_distributions
+from pydantic import ValidationError
+from pymc.model_graph import fast_eval
+
+from pymc_extras import prior as prior_module
+from pymc_extras.prior import (
+    Censored,
+    MuAlreadyExistsError,
+    Prior,
+    Scaled,
+    UnknownTransformError,
+    UnsupportedDistributionError,
+    UnsupportedParameterizationError,
+    UnsupportedShapeError,
+    VariableFactory,
+    handle_dims,
+    register_tensor_transform,
+    sample_prior,
+)
+
+
+@pytest.mark.parametrize(
+    "x, dims, desired_dims, expected_fn",
+    [
+        (np.arange(3), "channel", "channel", lambda x: x),
+        (np.arange(3), "channel", ("geo", "channel"), lambda x: x),
+        (np.arange(3), "channel", ("channel", "geo"), lambda x: x[:, None]),
+        (np.arange(3), "channel", ("x", "y", "channel", "geo"), lambda x: x[:, None]),
+        (
+            np.arange(3 * 2).reshape(3, 2),
+            ("channel", "geo"),
+            ("geo", "x", "y", "channel"),
+            lambda x: x.T[:, None, None, :],
+        ),
+        (
+            np.arange(4 * 2 * 3).reshape(4, 2, 3),
+            ("channel", "geo", "store"),
+            ("geo", "x", "store", "channel"),
+            lambda x: x.swapaxes(0, 2).swapaxes(0, 1)[:, None, :, :],
+        ),
+    ],
+    ids=[
+        "same_dims",
+        "different_dims",
+        "dim_padding",
+        "just_enough_dims",
+        "transpose_and_padding",
+        "swaps_and_padding",
+    ],
+)
+def test_handle_dims(x, dims, desired_dims, expected_fn) -> None:
+    result = handle_dims(x, dims, desired_dims)
+    if isinstance(result, pt.TensorVariable):
+        result = fast_eval(result)
+
+    np.testing.assert_array_equal(result, expected_fn(x))
+
+
+@pytest.mark.parametrize(
+    "x, dims, desired_dims",
+    [
+        (np.ones(3), "channel", "something_else"),
+        (np.ones((3, 2)), ("a", "b"), ("a", "B")),
+    ],
+    ids=["no_incommon", "some_incommon"],
+)
+def test_handle_dims_with_impossible_dims(x, dims, desired_dims) -> None:
+    match = " are not a subset of the desired dims "
+    with pytest.raises(UnsupportedShapeError, match=match):
+        handle_dims(x, dims, desired_dims)
+
+
+def test_missing_transform() -> None:
+    match = "Neither pytensor.tensor nor pymc.math have the function 'foo_bar'"
+    with pytest.raises(UnknownTransformError, match=match):
+        Prior("Normal", transform="foo_bar")
+
+
+def test_get_item() -> None:
+    var = Prior("Normal", mu=0, sigma=1)
+
+    assert var["mu"] == 0
+    assert var["sigma"] == 1
+
+
+def test_noncentered_needs_params() -> None:
+    with pytest.raises(ValueError):
+        Prior(
+            "Normal",
+            centered=False,
+        )
+
+
+def test_different_than_pymc_params() -> None:
+    with pytest.raises(ValueError):
+        Prior("Normal", mu=0, b=1)
+
+
+def test_non_unique_dims() -> None:
+    with pytest.raises(ValueError):
+        Prior("Normal", mu=0, sigma=1, dims=("channel", "channel"))
+
+
+def test_doesnt_check_validity_parameterization() -> None:
+    try:
+        Prior("Normal", mu=0, sigma=1, tau=1)
+    except Exception as e:
+        pytest.fail(f"Unexpected exception: {e}")
+
+
+def test_doesnt_check_validity_values() -> None:
+    try:
+        Prior("Normal", mu=0, sigma=-1)
+    except Exception as e:
+        pytest.fail(f"Unexpected exception: {e}")
+
+
+def test_preliz() -> None:
+    var = Prior("Normal", mu=0, sigma=1)
+    dist = var.preliz
+    assert isinstance(dist, preliz_distributions.Distribution)
+
+
+@pytest.mark.parametrize(
+    "var, expected",
+    [
+        (Prior("Normal", mu=0, sigma=1), 'Prior("Normal", mu=0, sigma=1)'),
+        (
+            Prior("Normal", mu=Prior("Normal"), sigma=Prior("HalfNormal")),
+            'Prior("Normal", mu=Prior("Normal"), sigma=Prior("HalfNormal"))',
+        ),
+        (Prior("Normal", dims="channel"), 'Prior("Normal", dims="channel")'),
+        (
+            Prior("Normal", mu=0, sigma=1, transform="sigmoid"),
+            'Prior("Normal", mu=0, sigma=1, transform="sigmoid")',
+        ),
+    ],
+)
+def test_str(var, expected) -> None:
+    assert str(var) == expected
+
+
+@pytest.mark.parametrize(
+    "var",
+    [
+        Prior("Normal", mu=0, sigma=1),
+        Prior("Normal", mu=Prior("Normal"), sigma=Prior("HalfNormal"), dims="channel"),
+        Prior("Normal", dims=("geo", "channel")),
+    ],
+)
+def test_repr(var) -> None:
+    assert eval(repr(var)) == var
+
+
+def test_invalid_distribution() -> None:
+    with pytest.raises(UnsupportedDistributionError):
+        Prior("Invalid")
+
+
+def test_broadcast_doesnt_work():
+    with pytest.raises(UnsupportedShapeError):
+        Prior(
+            "Normal",
+            mu=0,
+            sigma=Prior("HalfNormal", sigma=1, dims="x"),
+            dims="y",
+        )
+
+
+def test_dim_workaround_flaw() -> None:
+    distribution = Prior(
+        "Normal",
+        mu=Prior("Normal"),
+        sigma=Prior("HalfNormal"),
+        dims="y",
+    )
+
+    try:
+        distribution["mu"].dims = "x"
+    except Exception as e:
+        pytest.fail(f"Unexpected exception: {e}")
+
+    with pytest.raises(UnsupportedShapeError):
+        distribution._param_dims_work()
+
+
+def test_noncentered_error() -> None:
+    with pytest.raises(UnsupportedParameterizationError):
+        Prior(
+            "Gamma",
+            mu=0,
+            sigma=1,
+            dims="x",
+            centered=False,
+        )
+
+
+def test_create_variable_multiple_times() -> None:
+    mu = Prior(
+        "Normal",
+        mu=Prior("Normal"),
+        sigma=Prior("HalfNormal"),
+        dims="channel",
+        centered=False,
+    )
+
+    coords = {
+        "channel": ["a", "b", "c"],
+    }
+    with pm.Model(coords=coords) as model:
+        mu.create_variable("mu")
+        mu.create_variable("mu_2")
+
+    suffixes = [
+        "",
+        "_offset",
+        "_mu",
+        "_sigma",
+    ]
+    dims = [(3,), (3,), (), ()]
+
+    for prefix in ["mu", "mu_2"]:
+        for suffix, dim in zip(suffixes, dims, strict=True):
+            assert fast_eval(model[f"{prefix}{suffix}"]).shape == dim
+
+
+@pytest.fixture
+def large_var() -> Prior:
+    mu = Prior(
+        "Normal",
+        mu=Prior("Normal", mu=1),
+        sigma=Prior("HalfNormal"),
+        dims="channel",
+        centered=False,
+    )
+    sigma = Prior("HalfNormal", sigma=Prior("HalfNormal"), dims="geo")
+
+    return Prior("Normal", mu=mu, sigma=sigma, dims=("geo", "channel"))
+
+
+def test_create_variable(large_var) -> None:
+    coords = {
+        "channel": ["a", "b", "c"],
+        "geo": ["x", "y"],
+    }
+    with pm.Model(coords=coords) as model:
+        large_var.create_variable("var")
+
+    var_names = [
+        "var",
+        "var_mu",
+        "var_sigma",
+        "var_mu_offset",
+        "var_mu_mu",
+        "var_mu_sigma",
+        "var_sigma_sigma",
+    ]
+    assert {var.name for var in model.unobserved_RVs} == set(var_names)
+    dims = [
+        (2, 3),
+        (3,),
+        (2,),
+        (3,),
+        (),
+        (),
+        (),
+    ]
+    for var_name, dim in zip(var_names, dims, strict=True):
+        assert fast_eval(model[var_name]).shape == dim
+
+
+def test_transform() -> None:
+    var = Prior("Normal", mu=0, sigma=1, transform="sigmoid")
+
+    with pm.Model() as model:
+        var.create_variable("var")
+
+    var_names = [
+        "var",
+        "var_raw",
+    ]
+    dims = [
+        (),
+        (),
+    ]
+    for var_name, dim in zip(var_names, dims, strict=True):
+        assert fast_eval(model[var_name]).shape == dim
+
+
+def test_constrain_with_transform_error() -> None:
+    var = Prior("Normal", transform="sigmoid")
+
+    with pytest.raises(ValueError):
+        var.constrain(lower=0, upper=1)
+
+
+def test_constrain(mocker) -> None:
+    var = Prior("Normal")
+
+    mocker.patch(
+        "preliz.maxent",
+        return_value=mocker.Mock(params_dict={"mu": 5, "sigma": 2}),
+    )
+
+    new_var = var.constrain(lower=0, upper=1)
+    assert new_var == Prior("Normal", mu=5, sigma=2)
+
+
+def test_dims_change() -> None:
+    var = Prior("Normal", mu=0, sigma=1)
+    var.dims = "channel"
+
+    assert var.dims == ("channel",)
+
+
+def test_dims_change_error() -> None:
+    mu = Prior("Normal", dims="channel")
+    var = Prior("Normal", mu=mu, dims="channel")
+
+    with pytest.raises(UnsupportedShapeError):
+        var.dims = "geo"
+
+
+def test_deepcopy() -> None:
+    priors = {
+        "alpha": Prior("Beta", alpha=1, beta=1),
+        "gamma": Prior("Normal", mu=0, sigma=1),
+    }
+
+    new_priors = deepcopy(priors)
+    priors["alpha"].dims = "channel"
+
+    assert new_priors["alpha"].dims == ()
+
+
+@pytest.fixture
+def mmm_default_model_config():
+    return {
+        "intercept": {"dist": "Normal", "kwargs": {"mu": 0, "sigma": 2}},
+        "likelihood": {
+            "dist": "Normal",
+            "kwargs": {
+                "sigma": {"dist": "HalfNormal", "kwargs": {"sigma": 2}},
+            },
+        },
+        "gamma_control": {
+            "dist": "Normal",
+            "kwargs": {"mu": 0, "sigma": 2},
+            "dims": "control",
+        },
+        "gamma_fourier": {
+            "dist": "Laplace",
+            "kwargs": {"mu": 0, "b": 1},
+            "dims": "fourier_mode",
+        },
+    }
+
+
+def test_sample_prior() -> None:
+    var = Prior(
+        "Normal",
+        mu=Prior("Normal"),
+        sigma=Prior("HalfNormal"),
+        dims="channel",
+        transform="sigmoid",
+    )
+
+    coords = {"channel": ["A", "B", "C"]}
+    prior = var.sample_prior(coords=coords, draws=25)
+
+    assert isinstance(prior, xr.Dataset)
+    assert prior.sizes == {"chain": 1, "draw": 25, "channel": 3}
+
+
+def test_sample_prior_missing_coords() -> None:
+    dist = Prior("Normal", dims="channel")
+
+    with pytest.raises(KeyError, match="Coords"):
+        dist.sample_prior()
+
+
+def test_to_graph() -> None:
+    hierarchical_distribution = Prior(
+        "Normal",
+        mu=Prior("Normal"),
+        sigma=Prior("HalfNormal"),
+        dims="channel",
+    )
+
+    G = hierarchical_distribution.to_graph()
+    assert isinstance(G, Digraph)
+
+
+def test_equality_non_prior() -> None:
+    dist = Prior("Normal")
+
+    assert dist != 1
+
+
+def test_deepcopy_memo() -> None:
+    memo = {}
+    dist = Prior("Normal")
+    memo[id(dist)] = dist
+    deepcopy(dist, memo)
+    assert len(memo) == 1
+    deepcopy(dist, memo)
+
+    assert len(memo) == 1
+
+
+def test_create_likelihood_variable() -> None:
+    distribution = Prior("Normal", sigma=Prior("HalfNormal"))
+
+    with pm.Model() as model:
+        mu = pm.Normal("mu")
+
+        data = distribution.create_likelihood_variable("data", mu=mu, observed=10)
+
+    assert model.observed_RVs == [data]
+    assert "data_sigma" in model
+
+
+def test_create_likelihood_variable_already_has_mu() -> None:
+    distribution = Prior("Normal", mu=Prior("Normal"), sigma=Prior("HalfNormal"))
+
+    with pm.Model():
+        mu = pm.Normal("mu")
+
+        with pytest.raises(MuAlreadyExistsError):
+            distribution.create_likelihood_variable("data", mu=mu, observed=10)
+
+
+def test_create_likelihood_non_mu_parameterized_distribution() -> None:
+    distribution = Prior("Cauchy")
+
+    with pm.Model():
+        mu = pm.Normal("mu")
+        with pytest.raises(UnsupportedDistributionError):
+            distribution.create_likelihood_variable("data", mu=mu, observed=10)
+
+
+def test_non_centered_student_t() -> None:
+    try:
+        Prior(
+            "StudentT",
+            mu=Prior("Normal"),
+            sigma=Prior("HalfNormal"),
+            nu=Prior("HalfNormal"),
+            dims="channel",
+            centered=False,
+        )
+    except Exception as e:
+        pytest.fail(f"Unexpected exception: {e}")
+
+
+def test_cant_reset_distribution() -> None:
+    dist = Prior("Normal")
+    with pytest.raises(AttributeError, match="Can't change the distribution"):
+        dist.distribution = "Cauchy"
+
+
+def test_nonstring_distribution() -> None:
+    with pytest.raises(ValidationError, match=".*Input should be a valid string.*"):
+        Prior(pm.Normal)
+
+
+def test_change_the_transform() -> None:
+    dist = Prior("Normal")
+    dist.transform = "logit"
+    assert dist.transform == "logit"
+
+
+def test_nonstring_transform() -> None:
+    with pytest.raises(ValidationError, match=".*Input should be a valid string.*"):
+        Prior("Normal", transform=pm.math.log)
+
+
+def test_checks_param_value_types() -> None:
+    with pytest.raises(ValueError, match="Parameters must be one of the following types"):
+        Prior("Normal", mu="str", sigma="str")
+
+
+def test_check_equality_with_numpy() -> None:
+    dist = Prior("Normal", mu=np.array([1, 2, 3]), sigma=1)
+    assert dist == dist.deepcopy()
+
+
+def clear_custom_transforms() -> None:
+    """Empty the registry that ``register_tensor_transform`` writes to."""
+    # Rebinding a global named CUSTOM_TRANSFORMS in this test module never
+    # touches the library, so the library's registry is cleared in place.
+    # NOTE(review): assumes the registry is the dict
+    # ``pymc_extras.prior.CUSTOM_TRANSFORMS`` - confirm against prior.py.
+    prior_module.CUSTOM_TRANSFORMS.clear()
+
+
+def test_custom_transform() -> None:
+    new_transform_name = "foo_bar"
+    with pytest.raises(UnknownTransformError):
+        Prior("Normal", transform=new_transform_name)
+
+    register_tensor_transform(new_transform_name, lambda x: x**2)
+
+    dist = Prior("Normal", transform=new_transform_name)
+    prior = dist.sample_prior(draws=10)
+    df_prior = prior.to_dataframe()
+
+    np.testing.assert_array_equal(df_prior["var"].to_numpy(), df_prior["var_raw"].to_numpy() ** 2)
+
+    # Unregister "foo_bar" so test_missing_transform still sees it as unknown.
+    clear_custom_transforms()
+
+
+def test_custom_transform_comes_first() -> None:
+    # function in pytensor.tensor
+    register_tensor_transform("square", lambda x: 2 * x)
+
+    dist = Prior("Normal", transform="square")
+    prior = dist.sample_prior(draws=10)
+    df_prior = prior.to_dataframe()
+
+    np.testing.assert_array_equal(df_prior["var"].to_numpy(), 2 * df_prior["var_raw"].to_numpy())
+
+    clear_custom_transforms()
+
+
+def test_zsn_non_centered() -> None:
+    try:
+        Prior("ZeroSumNormal", sigma=1, centered=False)
+    except Exception as e:
+        pytest.fail(f"Unexpected exception: {e}")
+
+
+class Arbitrary:
+    """Minimal stand-in exposing ``dims`` and ``create_variable`` without being a ``Prior``."""
+
+    def __init__(self, dims: str | tuple[str, ...]) -> None:
+        self.dims = dims
+
+    def create_variable(self, name: str):
+        return pm.Normal(name, dims=self.dims)
+
+
+class ArbitraryWithoutName:
+    """Variable factory whose ``create_variable`` returns an unnamed expression."""
+
+    def __init__(self, dims: str | tuple[str, ...]) -> None:
+        self.dims = dims
+
+    def create_variable(self, name: str):
+        with pm.Model(name=name):
+            location = pm.Normal("location", dims=self.dims)
+            scale = pm.HalfNormal("scale", dims=self.dims)
+
+            return pm.Normal("standard_normal") * scale + location
+
+
+def test_sample_prior_arbitrary() -> None:
+    var = Arbitrary(dims="channel")
+
+    prior = sample_prior(var, coords={"channel": ["A", "B", "C"]}, draws=25)
+
+    assert isinstance(prior, xr.Dataset)
+
+
+def test_sample_prior_arbitrary_no_name() -> None:
+    var = ArbitraryWithoutName(dims="channel")
+
+    prior = sample_prior(var, coords={"channel": ["A", "B", "C"]}, draws=25)
+
+    assert isinstance(prior, xr.Dataset)
+    assert "var" not in prior
+
+    prior_with = sample_prior(
+        var,
+        coords={"channel": ["A", "B", "C"]},
+        draws=25,
+        wrap=True,
+    )
+
+    assert isinstance(prior_with, xr.Dataset)
+    assert "var" in prior_with
+
+
+def test_create_prior_with_arbitrary() -> None:
+    dist = Prior(
+        "Normal",
+        mu=Arbitrary(dims=("channel",)),
+        sigma=1,
+        dims=("channel", "geo"),
+    )
+
+    coords = {
+        "channel": ["C1", "C2", "C3"],
+        "geo": ["G1", "G2"],
+    }
+    with pm.Model(coords=coords) as model:
+        dist.create_variable("var")
+
+    assert "var_mu" in model
+    var_mu = model["var_mu"]
+
+    assert fast_eval(var_mu).shape == (len(coords["channel"]),)
+
+
+def test_censored_is_variable_factory() -> None:
+    normal = Prior("Normal")
+    censored_normal = Censored(normal, lower=0)
+
+    assert isinstance(censored_normal, VariableFactory)
+
+
+@pytest.mark.parametrize(
+    "dims, expected_dims",
+    [
+        ("channel", ("channel",)),
+        (("channel", "geo"), ("channel", "geo")),
+    ],
+    ids=["string", "tuple"],
+)
+def test_censored_dims_from_distribution(dims, expected_dims) -> None:
+    normal = Prior("Normal", dims=dims)
+    censored_normal = Censored(normal, lower=0)
+
+    assert censored_normal.dims == expected_dims
+
+
+def test_censored_variables_created() -> None:
+    normal = Prior("Normal", mu=Prior("Normal"), dims="dim")
+    censored_normal = Censored(normal, lower=0)
+
+    coords = {"dim": range(3)}
+    with pm.Model(coords=coords) as model:
+        censored_normal.create_variable("var")
+
+    var_names = ["var", "var_mu"]
+    assert {var.name for var in model.unobserved_RVs} == set(var_names)
+    dims = [(3,), ()]
+    for var_name, dim in zip(var_names, dims, strict=True):
+        assert fast_eval(model[var_name]).shape == dim
+
+
+def test_censored_sample_prior() -> None:
+    normal = Prior("Normal", dims="channel")
+    censored_normal = Censored(normal, lower=0)
+
+    coords = {"channel": ["A", "B", "C"]}
+    prior = censored_normal.sample_prior(coords=coords, draws=25)
+
+    assert isinstance(prior, xr.Dataset)
+    assert prior.sizes == {"chain": 1, "draw": 25, "channel": 3}
+
+
+def test_censored_to_graph() -> None:
+    normal = Prior("Normal", dims="channel")
+    censored_normal = Censored(normal, lower=0)
+
+    G = censored_normal.to_graph()
+    assert isinstance(G, Digraph)
+
+
+def test_censored_likelihood_variable() -> None:
+    normal = Prior("Normal", sigma=Prior("HalfNormal"), dims="channel")
+    censored_normal = Censored(normal, lower=0)
+
+    coords = {"channel": range(3)}
+    with pm.Model(coords=coords) as model:
+        mu = pm.Normal("mu")
+        variable = censored_normal.create_likelihood_variable(
+            name="likelihood",
+            mu=mu,
+            observed=[1, 2, 3],
+        )
+
+    assert isinstance(variable, pt.TensorVariable)
+    assert model.observed_RVs == [variable]
+    assert "likelihood_sigma" in model
+
+
+def test_censored_likelihood_unsupported_distribution() -> None:
+    cauchy = Prior("Cauchy")
+    censored_cauchy = Censored(cauchy, lower=0)
+
+    with pm.Model():
+        mu = pm.Normal("mu")
+        with pytest.raises(UnsupportedDistributionError):
+            censored_cauchy.create_likelihood_variable(
+                name="likelihood",
+                mu=mu,
+                observed=1,
+            )
+
+
+def test_censored_likelihood_already_has_mu() -> None:
+    normal = Prior("Normal", mu=Prior("Normal"), sigma=Prior("HalfNormal"))
+    censored_normal = Censored(normal, lower=0)
+
+    with pm.Model():
+        mu = pm.Normal("mu")
+        with pytest.raises(MuAlreadyExistsError):
+            censored_normal.create_likelihood_variable(
+                name="likelihood",
+                mu=mu,
+                observed=1,
+            )
+
+
+def test_censored_dims_setter() -> None:
+    normal = Prior("Normal", dims="channel")
+    censored_normal = Censored(normal, lower=0)
+    censored_normal.dims = "date"
+    assert normal.dims == ("date",)
+
+
+class ModelData(NamedTuple):
+    mu: float
+    observed: list[float]
+
+
+@pytest.fixture(scope="session")
+def model_data() -> ModelData:
+    return ModelData(mu=0, observed=[0, 1, 2, 3, 4])
+
+
+@pytest.fixture(scope="session")
+def normal_model_with_censored_API(model_data) -> pm.Model:
+    coords = {"idx": range(len(model_data.observed))}
+    with pm.Model(coords=coords) as model:
+        sigma = Prior("HalfNormal")
+        normal = Prior("Normal", sigma=sigma, dims="idx")
+        Censored(normal, lower=0).create_likelihood_variable(
+            "censored_normal",
+            mu=model_data.mu,
+            observed=model_data.observed,
+        )
+
+    return model
+
+
+@pytest.fixture(scope="session")
+def normal_model_with_censored_logp(normal_model_with_censored_API):
+    return normal_model_with_censored_API.compile_logp()
+
+
+@pytest.fixture(scope="session")
+def expected_normal_model(model_data) -> pm.Model:
+    n_points = len(model_data.observed)
+    with pm.Model() as expected_model:
+        sigma = pm.HalfNormal("censored_normal_sigma")
+        normal = pm.Normal.dist(mu=model_data.mu, sigma=sigma, shape=n_points)
+        pm.Censored(
+            "censored_normal",
+            normal,
+            lower=0,
+            upper=np.inf,
+            observed=model_data.observed,
+        )
+
+    return expected_model
+
+
+@pytest.fixture(scope="session")
+def expected_normal_model_logp(expected_normal_model):
+    return expected_normal_model.compile_logp()
+
+
+@pytest.mark.parametrize("sigma_log__", [-10, -5, -2.5, 0, 2.5, 5, 10])
+def test_censored_normal_logp(
+    sigma_log__,
+    normal_model_with_censored_logp,
+    expected_normal_model_logp,
+) -> None:
+    points = {"censored_normal_sigma_log__": sigma_log__}
+    normal_model_logp = normal_model_with_censored_logp(points)
+    expected_model_logp = expected_normal_model_logp(points)
+    np.testing.assert_allclose(normal_model_logp, expected_model_logp)
+
+
+@pytest.mark.parametrize(
+    "mu",
+    [
+        0,
+        np.arange(10),
+    ],
+    ids=["scalar", "vector"],
+)
+def test_censored_logp(mu) -> None:
+    n_points = 10
+    observed = np.zeros(n_points)
+    coords = {"idx": range(n_points)}
+    with pm.Model(coords=coords) as model:
+        normal = Prior("Normal", dims="idx")
+        Censored(normal, lower=0).create_likelihood_variable(
+            "censored_normal",
+            observed=observed,
+            mu=mu,
+        )
+    logp = model.compile_logp()
+
+    with pm.Model() as expected_model:
+        pm.Censored(
+            "censored_normal",
+            pm.Normal.dist(mu=mu, sigma=1, shape=n_points),
+            lower=0,
+            upper=np.inf,
+            observed=observed,
+        )
+    expected_logp = expected_model.compile_logp()
+
+    point = {}
+    np.testing.assert_allclose(logp(point), expected_logp(point))
+
+
+def test_scaled_initializes_correctly() -> None:
+    """Test that the Scaled class initializes correctly."""
+    normal = Prior("Normal", mu=0, sigma=1)
+    scaled = Scaled(normal, factor=2.0)
+
+    assert scaled.dist == normal
+    assert scaled.factor == 2.0
+
+
+def test_scaled_dims_property() -> None:
+    """Test that the dims property returns the dimensions of the underlying distribution."""
+    normal = Prior("Normal", mu=0, sigma=1, dims="channel")
+    scaled = Scaled(normal, factor=2.0)
+
+    assert scaled.dims == ("channel",)
+
+    # Test with multiple dimensions
+    normal.dims = ("channel", "geo")
+    assert scaled.dims == ("channel", "geo")
+
+
+def test_scaled_create_variable() -> None:
+    """Test that the create_variable method properly scales the variable."""
+    normal = Prior("Normal", mu=0, sigma=1)
+    scaled = Scaled(normal, factor=2.0)
+
+    with pm.Model() as model:
+        scaled_var = scaled.create_variable("scaled_var")
+
+    # Check that both the scaled and unscaled variables exist
+    assert "scaled_var" in model
+    assert "scaled_var_unscaled" in model
+
+    # The deterministic node should be the scaled variable
+    assert model["scaled_var"] == scaled_var
+
+
+def test_scaled_creates_correct_dimensions() -> None:
+    """Test that the scaled variable has the correct dimensions."""
+    normal = Prior("Normal", dims="channel")
+    scaled = Scaled(normal, factor=2.0)
+
+    coords = {"channel": ["A", "B", "C"]}
+    with pm.Model(coords=coords):
+        scaled_var = scaled.create_variable("scaled_var")
+
+    # Check that the scaled variable has the correct dimensions
+    assert fast_eval(scaled_var).shape == (3,)
+
+
+def test_scaled_applies_factor() -> None:
+    """Test that the scaling factor is correctly applied."""
+    normal = Prior("Normal", mu=0, sigma=1)
+    factor = 3.5
+    scaled = Scaled(normal, factor=factor)
+
+    # Sample from prior to verify scaling
+    prior = sample_prior(scaled, draws=10, name="scaled_var")
+    df_prior = prior.to_dataframe()
+
+    # Check that scaled values are original values times the factor
+    unscaled_values = df_prior["scaled_var_unscaled"].to_numpy()
+    scaled_values = df_prior["scaled_var"].to_numpy()
+    np.testing.assert_allclose(scaled_values, unscaled_values * factor)
+
+
+def test_scaled_with_tensor_factor() -> None:
+    """Test that the Scaled class works with a tensor factor."""
+    normal = Prior("Normal", mu=0, sigma=1)
+    factor = pt.as_tensor_variable(2.5)
+    scaled = Scaled(normal, factor=factor)
+
+    # Sample from prior to verify tensor scaling
+    prior = sample_prior(scaled, draws=10, name="scaled_var")
+    df_prior = prior.to_dataframe()
+
+    # Check that scaled values are original values times the factor
+    unscaled_values = df_prior["scaled_var_unscaled"].to_numpy()
+    scaled_values = df_prior["scaled_var"].to_numpy()
+    np.testing.assert_allclose(scaled_values, unscaled_values * 2.5)
+
+
+def test_scaled_with_hierarchical_prior() -> None:
+    """Test that the Scaled class works with hierarchical priors."""
+    normal = Prior("Normal", mu=Prior("Normal"), sigma=Prior("HalfNormal"), dims="channel")
+    scaled = Scaled(normal, factor=2.0)
+
+    coords = {"channel": ["A", "B", "C"]}
+    with pm.Model(coords=coords) as model:
+        scaled.create_variable("scaled_var")
+
+    # Check that all necessary variables were created
+    assert "scaled_var" in model
+    assert "scaled_var_unscaled" in model
+    assert "scaled_var_unscaled_mu" in model
+    assert "scaled_var_unscaled_sigma" in model
+
+
+def test_scaled_sample_prior() -> None:
+    """Test that sample_prior works with the Scaled class."""
+    normal = Prior("Normal", dims="channel")
+    scaled = Scaled(normal, factor=2.0)
+
+    coords = {"channel": ["A", "B", "C"]}
+    prior = sample_prior(scaled, coords=coords, draws=25, name="scaled_var")
+
+    assert isinstance(prior, xr.Dataset)
+    assert prior.sizes == {"chain": 1, "draw": 25, "channel": 3}
+    assert "scaled_var" in prior
+    assert "scaled_var_unscaled" in prior