diff --git a/conda-envs/environment-test.yml b/conda-envs/environment-test.yml
index 260c7b0e..84e467ec 100644
--- a/conda-envs/environment-test.yml
+++ b/conda-envs/environment-test.yml
@@ -10,6 +10,7 @@ dependencies:
 - xhistogram
 - statsmodels
 - numba<=0.60.0
+- pydantic>=2.0.0
 - pip
 - pip:
   - blackjax
diff --git a/conda-envs/windows-environment-test.yml b/conda-envs/windows-environment-test.yml
index 6a92aea5..7b53c434 100644
--- a/conda-envs/windows-environment-test.yml
+++ b/conda-envs/windows-environment-test.yml
@@ -11,6 +11,7 @@ dependencies:
 - statsmodels
 - numba<=0.60.0
 - pymc>=5.21
+- pydantic>=2.0.0
 - pip:
   - blackjax
   - scikit-learn
diff --git a/pymc_extras/deserialize.py b/pymc_extras/deserialize.py
new file mode 100644
index 00000000..8172cc23
--- /dev/null
+++ b/pymc_extras/deserialize.py
@@ -0,0 +1,230 @@
+"""Deserialize dictionaries into Python objects.
+
+This is a two-step process:
+
+1. Determine if the data is of the correct type.
+2. Deserialize the data into a Python object.
+
+This is used to deserialize JSON data for PyMC-Extras.
+
+Examples
+--------
+Make use of the already registered PyMC-Extras deserializers:
+
+.. code-block:: python
+
+    from pymc_extras.deserialize import deserialize
+
+    prior_class_data = {
+        "dist": "Normal",
+        "kwargs": {"mu": 0, "sigma": 1}
+    }
+    prior = deserialize(prior_class_data)
+    # Prior("Normal", mu=0, sigma=1)
+
+Register custom class deserialization:
+
+.. code-block:: python
+
+    from pymc_extras.deserialize import register_deserialization
+
+    class MyClass:
+        def __init__(self, value: int):
+            self.value = value
+
+        def to_dict(self) -> dict:
+            # Example of what the to_dict method might look like.
+            return {"value": self.value}
+
+    register_deserialization(
+        is_type=lambda data: data.keys() == {"value"} and isinstance(data["value"], int),
+        deserialize=lambda data: MyClass(value=data["value"]),
+    )
+
+Deserialize data into that custom class:
+
+.. code-block:: python
+
+    from pymc_extras.deserialize import deserialize
+
+    data = {"value": 42}
+    obj = deserialize(data)
+    assert isinstance(obj, MyClass)
+
+"""
+
+from collections.abc import Callable
+from dataclasses import dataclass
+from typing import Any
+
+IsType = Callable[[Any], bool]
+Deserialize = Callable[[Any], Any]
+
+
+@dataclass
+class Deserializer:
+    """Object to store information required for deserialization.
+
+    All deserializers should be stored via the :func:`register_deserialization` function
+    instead of creating this object directly.
+
+    Attributes
+    ----------
+    is_type : IsType
+        Function to determine if the data is of the correct type.
+    deserialize : Deserialize
+        Function to deserialize the data.
+
+    Examples
+    --------
+    .. code-block:: python
+
+        from typing import Any
+
+        class MyClass:
+            def __init__(self, value: int):
+                self.value = value
+
+        from pymc_extras.deserialize import Deserializer
+
+        def is_type(data: Any) -> bool:
+            return data.keys() == {"value"} and isinstance(data["value"], int)
+
+        def deserialize(data: dict) -> MyClass:
+            return MyClass(value=data["value"])
+
+        deserialize_logic = Deserializer(is_type=is_type, deserialize=deserialize)
+
+    """
+
+    is_type: IsType
+    deserialize: Deserialize
+
+
+DESERIALIZERS: list[Deserializer] = []
+
+
+class DeserializableError(Exception):
+    """Error raised when data cannot be deserialized."""
+
+    def __init__(self, data: Any):
+        self.data = data
+        super().__init__(
+            f"Couldn't deserialize {data}. Use register_deserialization to add a deserialization mapping."
+        )
+
+
+def deserialize(data: Any) -> Any:
+    """Deserialize a dictionary into a Python object.
+
+    Use the :func:`register_deserialization` function to add custom deserializations.
+
+    Deserialization is a two-step process due to the dynamic nature of the data:
+
+    1. Determine if the data is of the correct type.
+    2. Deserialize the data into a Python object.
+
+    Each registered deserialization is checked in order until one is found that can
+    deserialize the data. If no deserializer matches, or the matching deserializer
+    fails, a :class:`DeserializableError` is raised.
+
+    Parameters
+    ----------
+    data : Any
+        The data to deserialize.
+
+    Returns
+    -------
+    Any
+        The deserialized object.
+
+    Raises
+    ------
+    DeserializableError
+        Raised when the data doesn't match any registered deserializations
+        or fails to be deserialized.
+
+    Examples
+    --------
+    Deserialize a :class:`pymc_extras.prior.Prior` object:
+
+    .. code-block:: python
+
+        from pymc_extras.deserialize import deserialize
+
+        data = {"dist": "Normal", "kwargs": {"mu": 0, "sigma": 1}}
+        prior = deserialize(data)
+        # Prior("Normal", mu=0, sigma=1)
+
+    """
+    for mapping in DESERIALIZERS:
+        try:
+            is_type = mapping.is_type(data)
+        except Exception:
+            is_type = False
+
+        if not is_type:
+            continue
+
+        try:
+            return mapping.deserialize(data)
+        except Exception as e:
+            raise DeserializableError(data) from e
+    else:
+        raise DeserializableError(data)
+
+
+def register_deserialization(is_type: IsType, deserialize: Deserialize) -> None:
+    """Register an arbitrary deserialization.
+
+    Use the :func:`deserialize` function to then deserialize data using all registered
+    deserialize functions.
+
+    Classes from PyMC-Extras have their deserialization mappings registered
+    automatically. However, custom classes will need to be registered manually
+    using this function before they can be deserialized.
+
+    Parameters
+    ----------
+    is_type : Callable[[Any], bool]
+        Function to determine if the data is of the correct type.
+    deserialize : Callable[[dict], Any]
+        Function to deserialize the data of that type.
+
+    Examples
+    --------
+    Register a custom class deserialization:
+
+    .. code-block:: python
+
+        from pymc_extras.deserialize import register_deserialization
+
+        class MyClass:
+            def __init__(self, value: int):
+                self.value = value
+
+            def to_dict(self) -> dict:
+                # Example of what the to_dict method might look like.
+                return {"value": self.value}
+
+        register_deserialization(
+            is_type=lambda data: data.keys() == {"value"} and isinstance(data["value"], int),
+            deserialize=lambda data: MyClass(value=data["value"]),
+        )
+
+    Use that custom class deserialization:
+
+    .. code-block:: python
+
+        from pymc_extras.deserialize import deserialize
+
+        data = {"value": 42}
+        obj = deserialize(data)
+        assert isinstance(obj, MyClass)
+
+    """
+    mapping = Deserializer(is_type=is_type, deserialize=deserialize)
+    DESERIALIZERS.append(mapping)
diff --git a/pymc_extras/prior.py b/pymc_extras/prior.py
new file mode 100644
index 00000000..53ff43e5
--- /dev/null
+++ b/pymc_extras/prior.py
@@ -0,0 +1,1362 @@
+"""Class that represents a prior distribution.
+
+The `Prior` class is a wrapper around PyMC distributions that allows the user
+to define a distribution outside of a PyMC model.
+
+Examples
+--------
+Create a normal prior.
+
+.. code-block:: python
+
+    from pymc_extras.prior import Prior
+
+    normal = Prior("Normal")
+
+Create a hierarchical normal prior by using distributions for the parameters
+and specifying the dims.
+
+.. code-block:: python
+
+    hierarchical_normal = Prior(
+        "Normal",
+        mu=Prior("Normal"),
+        sigma=Prior("HalfNormal"),
+        dims="channel",
+    )
+
+Create a non-centered hierarchical normal prior with the `centered` parameter.
+
+.. code-block:: python
+
+    non_centered_hierarchical_normal = Prior(
+        "Normal",
+        mu=Prior("Normal"),
+        sigma=Prior("HalfNormal"),
+        dims="channel",
+        # Only change needed to make it non-centered
+        centered=False,
+    )
+
+Create a hierarchical beta prior by using the Beta distribution, distributions
+for the parameters, and specifying the dims.
+
+.. code-block:: python
+
+    hierarchical_beta = Prior(
+        "Beta",
+        alpha=Prior("HalfNormal"),
+        beta=Prior("HalfNormal"),
+        dims="channel",
+    )
+
+Create a transformed hierarchical normal prior by using the `transform`
+parameter. Here the "sigmoid" transformation comes from `pm.math`.
+
+.. code-block:: python
+
+    transformed_hierarchical_normal = Prior(
+        "Normal",
+        mu=Prior("Normal"),
+        sigma=Prior("HalfNormal"),
+        transform="sigmoid",
+        dims="channel",
+    )
+
+Create a prior with a custom transform function by registering it with
+`register_tensor_transform`.
+
+.. code-block:: python
+
+    from pymc_extras.prior import register_tensor_transform
+
+    def custom_transform(x):
+        return x ** 2
+
+    register_tensor_transform("square", custom_transform)
+
+    custom_distribution = Prior("Normal", transform="square")
+
+"""
+
+from __future__ import annotations
+
+import copy
+
+from collections.abc import Callable
+from inspect import signature
+from typing import Any, Protocol, runtime_checkable
+
+import numpy as np
+import pymc as pm
+import pytensor.tensor as pt
+import xarray as xr
+
+from pydantic import InstanceOf, validate_call
+from pydantic.dataclasses import dataclass
+from pymc.distributions.shape_utils import Dims
+
+from pymc_extras.deserialize import deserialize, register_deserialization
+
+
+class UnsupportedShapeError(Exception):
+    """Error for when the shapes from variables are not compatible."""
+
+
+class UnsupportedDistributionError(Exception):
+    """Error for when an unsupported distribution is used."""
+
+
+class UnsupportedParameterizationError(Exception):
+    """Error for when the given parameterization is not supported."""
+
+
+class MuAlreadyExistsError(Exception):
+    """Error for when 'mu' is already present in the Prior."""
+
+    def __init__(self, distribution: Prior) -> None:
+        self.distribution = distribution
+        self.message = f"The mu parameter is already defined in {distribution}"
+        super().__init__(self.message)
+
+
+class UnknownTransformError(Exception):
+    """Error for when an unknown transform is used."""
+
+
+def _remove_leading_xs(args: list[str | int]) -> list[str | int]:
+    """Remove leading 'x' from the args."""
+    while args and args[0] == "x":
+        args.pop(0)
+
+    return args
+
+
+def handle_dims(x: pt.TensorLike, dims: Dims, desired_dims: Dims) -> pt.TensorVariable:
+    """Take a tensor of dims `dims` and align it to `desired_dims`.
+
+    Doesn't check for validity of the dims.
+
+    Examples
+    --------
+    1D to 2D with new dim
+
+    .. code-block:: python
+
+        x = np.array([1, 2, 3])
+        dims = "channel"
+
+        desired_dims = ("channel", "group")
+
+        handle_dims(x, dims, desired_dims)
+
+    """
+    x = pt.as_tensor_variable(x)
+
+    if np.ndim(x) == 0:
+        return x
+
+    dims = dims if isinstance(dims, tuple) else (dims,)
+    desired_dims = desired_dims if isinstance(desired_dims, tuple) else (desired_dims,)
+
+    if difference := set(dims).difference(desired_dims):
+        raise UnsupportedShapeError(
+            f"Dims {dims} of data are not a subset of the desired dims {desired_dims}. "
+            f"{difference} is missing from the desired dims."
+        )
+
+    aligned_dims = np.array(dims)[:, None] == np.array(desired_dims)
+
+    missing_dims = aligned_dims.sum(axis=0) == 0
+    new_idx = aligned_dims.argmax(axis=0)
+
+    args = ["x" if missing else idx for (idx, missing) in zip(new_idx, missing_dims, strict=False)]
+    args = _remove_leading_xs(args)
+    return x.dimshuffle(*args)
+
+
+DimHandler = Callable[[pt.TensorLike, Dims], pt.TensorLike]
+
+
+def create_dim_handler(desired_dims: Dims) -> DimHandler:
+    """Wrap the `handle_dims` function to act like the previous `create_dim_handler` function."""
+
+    def func(x: pt.TensorLike, dims: Dims) -> pt.TensorVariable:
+        return handle_dims(x, dims, desired_dims)
+
+    return func
+
+
+def _dims_to_str(obj: tuple[str, ...]) -> str:
+    if len(obj) == 1:
+        return f'"{obj[0]}"'
+
+    return "(" + ", ".join(f'"{i}"' if isinstance(i, str) else str(i) for i in obj) + ")"
+
+
+def _get_pymc_distribution(name: str) -> type[pm.Distribution]:
+    if not hasattr(pm, name):
+        raise UnsupportedDistributionError(f"PyMC doesn't have a distribution of name {name!r}")
+
+    return getattr(pm, name)
+
+
+Transform = Callable[[pt.TensorLike], pt.TensorLike]
+
+CUSTOM_TRANSFORMS: dict[str, Transform] = {}
+
+
+def register_tensor_transform(name: str, transform: Transform) -> None:
+    """Register a tensor transform function to be used in the `Prior` class.
+
+    Parameters
+    ----------
+    name : str
+        The name of the transform.
+    transform : Callable[[pt.TensorLike], pt.TensorLike]
+        The function to apply to the tensor.
+
+    Examples
+    --------
+    Register a custom transform function.
+
+    .. code-block:: python
+
+        from pymc_extras.prior import (
+            Prior,
+            register_tensor_transform,
+        )
+
+        def custom_transform(x):
+            return x ** 2
+
+        register_tensor_transform("square", custom_transform)
+
+        custom_distribution = Prior("Normal", transform="square")
+
+    """
+    CUSTOM_TRANSFORMS[name] = transform
+
+
+def _get_transform(name: str):
+    if name in CUSTOM_TRANSFORMS:
+        return CUSTOM_TRANSFORMS[name]
+
+    for module in (pt, pm.math):
+        if hasattr(module, name):
+            break
+    else:
+        module = None
+
+    if not module:
+        msg = (
+            f"Neither pytensor.tensor nor pymc.math have the function {name!r}. "
+            "If this is a custom function, register it with the "
+            "`pymc_extras.prior.register_tensor_transform` function before "
+            "it is used."
+        )
+
+        raise UnknownTransformError(msg)
+
+    return getattr(module, name)
+
+
+def _get_pymc_parameters(distribution: pm.Distribution) -> set[str]:
+    return set(signature(distribution.dist).parameters.keys()) - {"kwargs", "args"}
+
+
+@runtime_checkable
+class VariableFactory(Protocol):
+    """Protocol for something that works like a Prior class."""
+
+    dims: tuple[str, ...]
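+    # ``@runtime_checkable`` makes this checkable at runtime: any object with
+    # a ``dims`` attribute and a ``create_variable`` method passes
+    # ``isinstance(obj, VariableFactory)``.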
+
+    def create_variable(self, name: str) -> pt.TensorVariable:
+        """Create a TensorVariable."""
+
+
+def sample_prior(
+    factory: VariableFactory,
+    coords=None,
+    name: str = "var",
+    wrap: bool = False,
+    **sample_prior_predictive_kwargs,
+) -> xr.Dataset:
+    """Sample the prior for an arbitrary VariableFactory.
+
+    Parameters
+    ----------
+    factory : VariableFactory
+        The factory to sample from.
+    coords : dict[str, list[str]], optional
+        The coordinates for the variable, by default None.
+        Only required if the dims are specified.
+    name : str, optional
+        The name of the variable, by default "var".
+    wrap : bool, optional
+        Whether to wrap the variable in a `pm.Deterministic` node, by default False.
+    sample_prior_predictive_kwargs : dict
+        Additional arguments to pass to `pm.sample_prior_predictive`.
+
+    Returns
+    -------
+    xr.Dataset
+        The dataset of the prior samples.
+
+    Examples
+    --------
+    Sample from an arbitrary variable factory.
+
+    .. code-block:: python
+
+        import pymc as pm
+
+        import pytensor.tensor as pt
+
+        from pymc_extras.prior import sample_prior
+
+        class CustomVariableDefinition:
+            def __init__(self, dims, n: int):
+                self.dims = dims
+                self.n = n
+
+            def create_variable(self, name: str) -> "TensorVariable":
+                x = pm.Normal(f"{name}_x", mu=0, sigma=1, dims=self.dims)
+                return pt.sum([x ** n for n in range(1, self.n + 1)], axis=0)
+
+        cubic = CustomVariableDefinition(dims=("channel",), n=3)
+        coords = {"channel": ["C1", "C2", "C3"]}
+        # Doesn't include the return value
+        prior = sample_prior(cubic, coords=coords)
+
+        prior_with = sample_prior(cubic, coords=coords, wrap=True)
+
+    """
+    coords = coords or {}
+
+    if isinstance(factory.dims, str):
+        dims = (factory.dims,)
+    else:
+        dims = factory.dims
+
+    if missing_keys := set(dims) - set(coords.keys()):
+        raise KeyError(f"Coords are missing the following dims: {missing_keys}")
+
+    with pm.Model(coords=coords) as model:
+        if wrap:
+            pm.Deterministic(name, factory.create_variable(name), dims=factory.dims)
+        else:
+            factory.create_variable(name)
+
+    return pm.sample_prior_predictive(
+        model=model,
+        **sample_prior_predictive_kwargs,
+    ).prior
+
+
+class Prior:
+    """A class to represent a prior distribution.
+
+    Make use of the various helper methods to understand the distributions
+    better.
+
+    - `preliz` attribute to get the equivalent distribution in `preliz`
+    - `sample_prior` method to sample from the prior
+    - `to_graph` method to get a dummy model graph with the distribution
+    - `constrain` method to shift the distribution to a different range
+
+    Parameters
+    ----------
+    distribution : str
+        The name of the PyMC distribution.
+    dims : Dims, optional
+        The dimensions of the variable, by default None.
+    centered : bool, optional
+        Whether the variable is centered or not, by default True.
+        Only allowed for the Normal, StudentT, and ZeroSumNormal distributions.
+    transform : str, optional
+        The name of the transform to apply to the variable after it is
+        created, by default None or no transform. The transformation must
+        be registered with the `register_tensor_transform` function or
+        be available in either `pytensor.tensor` or `pymc.math`.
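+
+    Examples
+    --------
+    A hierarchical normal prior, as a sketch of typical usage (the dim name
+    "channel" is illustrative):
+
+    .. code-block:: python
+
+        from pymc_extras.prior import Prior
+
+        dist = Prior(
+            "Normal",
+            mu=Prior("Normal"),
+            sigma=Prior("HalfNormal"),
+            dims="channel",
+            centered=False,
+        )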
+ + """ + + # Taken from https://en.wikipedia.org/wiki/Location%E2%80%93scale_family + non_centered_distributions: dict[str, dict[str, float]] = { + "Normal": {"mu": 0, "sigma": 1}, + "StudentT": {"mu": 0, "sigma": 1}, + "ZeroSumNormal": {"sigma": 1}, + } + + pymc_distribution: type[pm.Distribution] + pytensor_transform: Callable[[pt.TensorLike], pt.TensorLike] | None + + @validate_call + def __init__( + self, + distribution: str, + *, + dims: Dims | None = None, + centered: bool = True, + transform: str | None = None, + **parameters, + ) -> None: + self.distribution = distribution + self.parameters = parameters + self.dims = dims + self.centered = centered + self.transform = transform + + self._checks() + + @property + def distribution(self) -> str: + """The name of the PyMC distribution.""" + return self._distribution + + @distribution.setter + def distribution(self, distribution: str) -> None: + if hasattr(self, "_distribution"): + raise AttributeError("Can't change the distribution") + + self._distribution = distribution + self.pymc_distribution = _get_pymc_distribution(distribution) + + @property + def transform(self) -> str | None: + """The name of the transform to apply to the variable after it is created.""" + return self._transform + + @transform.setter + def transform(self, transform: str | None) -> None: + self._transform = transform + self.pytensor_transform = not transform or _get_transform(transform) # type: ignore + + @property + def dims(self) -> Dims: + """The dimensions of the variable.""" + return self._dims + + @dims.setter + def dims(self, dims) -> None: + if isinstance(dims, str): + dims = (dims,) + + self._dims = dims or () + + self._param_dims_work() + self._unique_dims() + + def __getitem__(self, key: str) -> Prior | Any: + """Return the parameter of the prior.""" + return self.parameters[key] + + def _checks(self) -> None: + if not self.centered: + self._correct_non_centered_distribution() + + self._parameters_are_at_least_subset_of_pymc() + self._convert_lists_to_numpy() + self._parameters_are_correct_type() + + def _parameters_are_at_least_subset_of_pymc(self) -> None: + pymc_params = _get_pymc_parameters(self.pymc_distribution) + if not set(self.parameters.keys()).issubset(pymc_params): + msg = ( + f"Parameters {set(self.parameters.keys())} " + "are not a subset of the pymc distribution " + f"parameters {set(pymc_params)}" + ) + raise ValueError(msg) + + def _convert_lists_to_numpy(self) -> None: + def convert(x): + if not isinstance(x, list): + return x + + return np.array(x) + + self.parameters = {key: convert(value) for key, value in self.parameters.items()} + + def _parameters_are_correct_type(self) -> None: + supported_types = ( + int, + float, + np.ndarray, + Prior, + pt.TensorVariable, + VariableFactory, + ) + + incorrect_types = { + param: type(value) + for param, value in self.parameters.items() + if not isinstance(value, supported_types) + } + if incorrect_types: + msg = ( + "Parameters must be one of the following types: " + f"(int, float, np.array, Prior, pt.TensorVariable). Incorrect parameters: {incorrect_types}" + ) + raise ValueError(msg) + + def _correct_non_centered_distribution(self) -> None: + if not self.centered and self.distribution not in self.non_centered_distributions: + raise UnsupportedParameterizationError( + f"{self.distribution!r} is not supported for non-centered parameterization. 
" + f"Choose from {list(self.non_centered_distributions.keys())}" + ) + + required_parameters = set(self.non_centered_distributions[self.distribution].keys()) + + if set(self.parameters.keys()) < required_parameters: + msg = " and ".join([f"{param!r}" for param in required_parameters]) + raise ValueError( + f"Must have at least {msg} parameter for non-centered for {self.distribution!r}" + ) + + def _unique_dims(self) -> None: + if not self.dims: + return + + if len(self.dims) != len(set(self.dims)): + raise ValueError("Dims must be unique") + + def _param_dims_work(self) -> None: + other_dims = set() + for value in self.parameters.values(): + if hasattr(value, "dims"): + other_dims.update(value.dims) + + if not other_dims.issubset(self.dims): + raise UnsupportedShapeError( + f"Parameter dims {other_dims} are not a subset of the prior dims {self.dims}" + ) + + def __str__(self) -> str: + """Return a string representation of the prior.""" + param_str = ", ".join([f"{param}={value}" for param, value in self.parameters.items()]) + param_str = "" if not param_str else f", {param_str}" + + dim_str = f", dims={_dims_to_str(self.dims)}" if self.dims else "" + centered_str = f", centered={self.centered}" if not self.centered else "" + transform_str = f', transform="{self.transform}"' if self.transform else "" + return f'Prior("{self.distribution}"{param_str}{dim_str}{centered_str}{transform_str})' + + def __repr__(self) -> str: + """Return a string representation of the prior.""" + return f"{self}" + + def _create_parameter(self, param, value, name): + if not hasattr(value, "create_variable"): + return value + + child_name = f"{name}_{param}" + return self.dim_handler(value.create_variable(child_name), value.dims) + + def _create_centered_variable(self, name: str): + parameters = { + param: self._create_parameter(param, value, name) + for param, value in self.parameters.items() + } + return self.pymc_distribution(name, **parameters, dims=self.dims) + + def _create_non_centered_variable(self, name: str) -> pt.TensorVariable: + def handle_variable(var_name: str): + parameter = self.parameters[var_name] + if not hasattr(parameter, "create_variable"): + return parameter + + return self.dim_handler( + parameter.create_variable(f"{name}_{var_name}"), + parameter.dims, + ) + + defaults = self.non_centered_distributions[self.distribution] + other_parameters = { + param: handle_variable(param) + for param in self.parameters.keys() + if param not in defaults + } + offset = self.pymc_distribution( + f"{name}_offset", + **defaults, + **other_parameters, + dims=self.dims, + ) + if "mu" in self.parameters: + mu = ( + handle_variable("mu") + if isinstance(self.parameters["mu"], Prior) + else self.parameters["mu"] + ) + else: + mu = 0 + + sigma = ( + handle_variable("sigma") + if isinstance(self.parameters["sigma"], Prior) + else self.parameters["sigma"] + ) + + return pm.Deterministic( + name, + mu + sigma * offset, + dims=self.dims, + ) + + def create_variable(self, name: str) -> pt.TensorVariable: + """Create a PyMC variable from the prior. + + Must be used in a PyMC model context. + + Parameters + ---------- + name : str + The name of the variable. + + Returns + ------- + pt.TensorVariable + The PyMC variable. + + Examples + -------- + Create a hierarchical normal variable in larger PyMC model. + + .. 
code-block:: python + + dist = Prior( + "Normal", + mu=Prior("Normal"), + sigma=Prior("HalfNormal"), + dims="channel", + ) + + coords = {"channel": ["C1", "C2", "C3"]} + with pm.Model(coords=coords): + var = dist.create_variable("var") + + """ + self.dim_handler = create_dim_handler(self.dims) + + if self.transform: + var_name = f"{name}_raw" + + def transform(var): + return pm.Deterministic(name, self.pytensor_transform(var), dims=self.dims) + else: + var_name = name + + def transform(var): + return var + + create_variable = ( + self._create_centered_variable if self.centered else self._create_non_centered_variable + ) + var = create_variable(name=var_name) + return transform(var) + + @property + def preliz(self): + """Create an equivalent preliz distribution. + + Helpful to visualize a distribution when it is univariate. + + Returns + ------- + preliz.distributions.Distribution + + Examples + -------- + Create a preliz distribution from a prior. + + .. code-block:: python + + from pymc_extras.prior import Prior + + dist = Prior("Gamma", alpha=5, beta=1) + dist.preliz.plot_pdf() + + """ + import preliz as pz + + return getattr(pz, self.distribution)(**self.parameters) + + def to_dict(self) -> dict[str, Any]: + """Convert the prior to dictionary format. + + Returns + ------- + dict[str, Any] + The dictionary format of the prior. + + Examples + -------- + Convert a prior to the dictionary format. + + .. code-block:: python + + from pymc_extras.prior import Prior + + dist = Prior("Normal", mu=0, sigma=1) + + dist.to_dict() + # {"dist": "Normal", "kwargs": {"mu": 0, "sigma": 1}} + + Convert a hierarchical prior to the dictionary format. + + .. code-block:: python + + dist = Prior( + "Normal", + mu=Prior("Normal"), + sigma=Prior("HalfNormal"), + dims="channel", + ) + + dist.to_dict() + # { + # "dist": "Normal", + # "kwargs": { + # "mu": {"dist": "Normal"}, + # "sigma": {"dist": "HalfNormal"}, + # }, + # "dims": "channel", + # } + + """ + data: dict[str, Any] = { + "dist": self.distribution, + } + if self.parameters: + + def handle_value(value): + if isinstance(value, Prior): + return value.to_dict() + + if isinstance(value, pt.TensorVariable): + value = value.eval() + + if isinstance(value, np.ndarray): + return value.tolist() + + if hasattr(value, "to_dict"): + return value.to_dict() + + return value + + data["kwargs"] = { + param: handle_value(value) for param, value in self.parameters.items() + } + if not self.centered: + data["centered"] = False + + if self.dims: + data["dims"] = self.dims + + if self.transform: + data["transform"] = self.transform + + return data + + @classmethod + def from_dict(cls, data) -> Prior: + """Create a Prior from the dictionary format. + + Parameters + ---------- + data : dict[str, Any] + The dictionary format of the prior. + + Returns + ------- + Prior + The prior distribution. + + Examples + -------- + Convert prior in the dictionary format to a Prior instance. + + .. code-block:: python + + from pymc_extras.prior import Prior + + data = { + "dist": "Normal", + "kwargs": {"mu": 0, "sigma": 1}, + } + + dist = Prior.from_dict(data) + dist + # Prior("Normal", mu=0, sigma=1) + + """ + if not isinstance(data, dict): + msg = ( + "Must be a dictionary representation of a prior distribution. 
" + f"Not of type: {type(data)}" + ) + raise ValueError(msg) + + dist = data["dist"] + kwargs = data.get("kwargs", {}) + + def handle_value(value): + if isinstance(value, dict): + return deserialize(value) + + if isinstance(value, list): + return np.array(value) + + return value + + kwargs = {param: handle_value(value) for param, value in kwargs.items()} + centered = data.get("centered", True) + dims = data.get("dims") + if isinstance(dims, list): + dims = tuple(dims) + transform = data.get("transform") + + return cls(dist, dims=dims, centered=centered, transform=transform, **kwargs) + + def constrain(self, lower: float, upper: float, mass: float = 0.95, kwargs=None) -> Prior: + """Create a new prior with a given mass constrained within the given bounds. + + Wrapper around `preliz.maxent`. + + Parameters + ---------- + lower : float + The lower bound. + upper : float + The upper bound. + mass: float = 0.95 + The mass of the distribution to keep within the bounds. + kwargs : dict + Additional arguments to pass to `pz.maxent`. + + Returns + ------- + Prior + The maximum entropy prior with a mass constrained to the given bounds. + + Examples + -------- + Create a Beta distribution that is constrained to have 95% of the mass + between 0.5 and 0.8. + + .. code-block:: python + + dist = Prior( + "Beta", + ).constrain(lower=0.5, upper=0.8) + + Create a Beta distribution with mean 0.6, that is constrained to + have 95% of the mass between 0.5 and 0.8. + + .. code-block:: python + + dist = Prior( + "Beta", + mu=0.6, + ).constrain(lower=0.5, upper=0.8) + + """ + from preliz import maxent + + if self.transform: + raise ValueError("Can't constrain a transformed variable") + + if kwargs is None: + kwargs = {} + kwargs.setdefault("plot", False) + + if kwargs["plot"]: + new_parameters = maxent(self.preliz, lower, upper, mass, **kwargs)[0].params_dict + else: + new_parameters = maxent(self.preliz, lower, upper, mass, **kwargs).params_dict + + return Prior( + self.distribution, + dims=self.dims, + transform=self.transform, + centered=self.centered, + **new_parameters, + ) + + def __eq__(self, other) -> bool: + """Check if two priors are equal.""" + if not isinstance(other, Prior): + return False + + try: + np.testing.assert_equal(self.parameters, other.parameters) + except AssertionError: + return False + + return ( + self.distribution == other.distribution + and self.dims == other.dims + and self.centered == other.centered + and self.transform == other.transform + ) + + def sample_prior( + self, + coords=None, + name: str = "var", + **sample_prior_predictive_kwargs, + ) -> xr.Dataset: + """Sample the prior distribution for the variable. + + Parameters + ---------- + coords : dict[str, list[str]], optional + The coordinates for the variable, by default None. + Only required if the dims are specified. + name : str, optional + The name of the variable, by default "var". + sample_prior_predictive_kwargs : dict + Additional arguments to pass to `pm.sample_prior_predictive`. + + Returns + ------- + xr.Dataset + The dataset of the prior samples. + + Example + ------- + Sample from a hierarchical normal distribution. + + .. 
code-block:: python + + dist = Prior( + "Normal", + mu=Prior("Normal"), + sigma=Prior("HalfNormal"), + dims="channel", + ) + + coords = {"channel": ["C1", "C2", "C3"]} + prior = dist.sample_prior(coords=coords) + + """ + return sample_prior( + factory=self, + coords=coords, + name=name, + **sample_prior_predictive_kwargs, + ) + + def __deepcopy__(self, memo) -> Prior: + """Return a deep copy of the prior.""" + if id(self) in memo: + return memo[id(self)] + + copy_obj = Prior( + self.distribution, + dims=copy.copy(self.dims), + centered=self.centered, + transform=self.transform, + **copy.deepcopy(self.parameters), + ) + memo[id(self)] = copy_obj + return copy_obj + + def deepcopy(self) -> Prior: + """Return a deep copy of the prior.""" + return copy.deepcopy(self) + + def to_graph(self): + """Generate a graph of the variables. + + Examples + -------- + Create the graph for a 2D transformed hierarchical distribution. + + .. code-block:: python + + from pymc_extras.prior import Prior + + mu = Prior( + "Normal", + mu=Prior("Normal"), + sigma=Prior("HalfNormal"), + dims="channel", + ) + sigma = Prior("HalfNormal", dims="channel") + dist = Prior( + "Normal", + mu=mu, + sigma=sigma, + dims=("channel", "geo"), + centered=False, + transform="sigmoid", + ) + + dist.to_graph() + + .. image:: /_static/example-graph.png + :alt: Example graph + + """ + coords = {name: ["DUMMY"] for name in self.dims} + with pm.Model(coords=coords) as model: + self.create_variable("var") + + return pm.model_to_graphviz(model) + + def create_likelihood_variable( + self, + name: str, + mu: pt.TensorLike, + observed: pt.TensorLike, + ) -> pt.TensorVariable: + """Create a likelihood variable from the prior. + + Will require that the distribution has a `mu` parameter + and that it has not been set in the parameters. + + Parameters + ---------- + name : str + The name of the variable. + mu : pt.TensorLike + The mu parameter for the likelihood. + observed : pt.TensorLike + The observed data. + + Returns + ------- + pt.TensorVariable + The PyMC variable. + + Examples + -------- + Create a likelihood variable in a larger PyMC model. + + .. code-block:: python + + import pymc as pm + + dist = Prior("Normal", sigma=Prior("HalfNormal")) + + with pm.Model(): + # Create the likelihood variable + mu = pm.Normal("mu", mu=0, sigma=1) + dist.create_likelihood_variable("y", mu=mu, observed=observed) + + """ + if "mu" not in _get_pymc_parameters(self.pymc_distribution): + raise UnsupportedDistributionError( + f"Likelihood distribution {self.distribution!r} is not supported." + ) + + if "mu" in self.parameters: + raise MuAlreadyExistsError(self) + + distribution = self.deepcopy() + distribution.parameters["mu"] = mu + distribution.parameters["observed"] = observed + return distribution.create_variable(name) + + +class VariableNotFound(Exception): + """Variable is not found.""" + + +def _remove_random_variable(var: pt.TensorVariable) -> None: + if var.name is None: + raise ValueError("This isn't removable") + + name: str = var.name + + model = pm.modelcontext(None) + for idx, free_rv in enumerate(model.free_RVs): + if var == free_rv: + index_to_remove = idx + break + else: + raise VariableNotFound(f"Variable {var.name!r} not found") + + var.name = None + model.free_RVs.pop(index_to_remove) + model.named_vars.pop(name) + + +@dataclass +class Censored: + """Create censored random variable. + + Examples + -------- + Create a censored Normal distribution: + + .. 
code-block:: python + + from pymc_extras.prior import Prior, Censored + + normal = Prior("Normal") + censored_normal = Censored(normal, lower=0) + + Create hierarchical censored Normal distribution: + + .. code-block:: python + + from pymc_extras.prior import Prior, Censored + + normal = Prior( + "Normal", + mu=Prior("Normal"), + sigma=Prior("HalfNormal"), + dims="channel", + ) + censored_normal = Censored(normal, lower=0) + + coords = {"channel": range(3)} + samples = censored_normal.sample_prior(coords=coords) + + """ + + distribution: InstanceOf[Prior] + lower: float | InstanceOf[pt.TensorVariable] = -np.inf + upper: float | InstanceOf[pt.TensorVariable] = np.inf + + def __post_init__(self) -> None: + """Check validity at initialization.""" + if not self.distribution.centered: + raise ValueError( + "Censored distribution must be centered so that .dist() API can be used on distribution." + ) + + if self.distribution.transform is not None: + raise ValueError( + "Censored distribution can't have a transform so that .dist() API can be used on distribution." + ) + + @property + def dims(self) -> tuple[str, ...]: + """The dims from the distribution to censor.""" + return self.distribution.dims + + @dims.setter + def dims(self, dims) -> None: + self.distribution.dims = dims + + def create_variable(self, name: str) -> pt.TensorVariable: + """Create censored random variable.""" + dist = self.distribution.create_variable(name) + _remove_random_variable(var=dist) + + return pm.Censored( + name, + dist, + lower=self.lower, + upper=self.upper, + dims=self.dims, + ) + + def to_dict(self) -> dict[str, Any]: + """Convert the censored distribution to a dictionary.""" + + def handle_value(value): + if isinstance(value, pt.TensorVariable): + return value.eval().tolist() + + return value + + return { + "class": "Censored", + "data": { + "dist": self.distribution.to_dict(), + "lower": handle_value(self.lower), + "upper": handle_value(self.upper), + }, + } + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> Censored: + """Create a censored distribution from a dictionary.""" + data = data["data"] + return cls( # type: ignore + distribution=Prior.from_dict(data["dist"]), + lower=data["lower"], + upper=data["upper"], + ) + + def sample_prior( + self, + coords=None, + name: str = "variable", + **sample_prior_predictive_kwargs, + ) -> xr.Dataset: + """Sample the prior distribution for the variable. + + Parameters + ---------- + coords : dict[str, list[str]], optional + The coordinates for the variable, by default None. + Only required if the dims are specified. + name : str, optional + The name of the variable, by default "var". + sample_prior_predictive_kwargs : dict + Additional arguments to pass to `pm.sample_prior_predictive`. + + Returns + ------- + xr.Dataset + The dataset of the prior samples. + + Example + ------- + Sample from a censored Gamma distribution. + + .. code-block:: python + + gamma = Prior("Gamma", mu=1, sigma=1, dims="channel") + dist = Censored(gamma, lower=0.5) + + coords = {"channel": ["C1", "C2", "C3"]} + prior = dist.sample_prior(coords=coords) + + """ + return sample_prior( + factory=self, + coords=coords, + name=name, + **sample_prior_predictive_kwargs, + ) + + def to_graph(self): + """Generate a graph of the variables. + + Examples + -------- + Create graph for a censored Normal distribution + + .. 
code-block:: python + + from pymc_extras.prior import Prior, Censored + + normal = Prior("Normal") + censored_normal = Censored(normal, lower=0) + + censored_normal.to_graph() + + """ + coords = {name: ["DUMMY"] for name in self.dims} + with pm.Model(coords=coords) as model: + self.create_variable("var") + + return pm.model_to_graphviz(model) + + def create_likelihood_variable( + self, + name: str, + mu: pt.TensorLike, + observed: pt.TensorLike, + ) -> pt.TensorVariable: + """Create observed censored variable. + + Will require that the distribution has a `mu` parameter + and that it has not been set in the parameters. + + Parameters + ---------- + name : str + The name of the variable. + mu : pt.TensorLike + The mu parameter for the likelihood. + observed : pt.TensorLike + The observed data. + + Returns + ------- + pt.TensorVariable + The PyMC variable. + + Examples + -------- + Create a censored likelihood variable in a larger PyMC model. + + .. code-block:: python + + import pymc as pm + from pymc_extras.prior import Prior, Censored + + normal = Prior("Normal", sigma=Prior("HalfNormal")) + dist = Censored(normal, lower=0) + + observed = 1 + + with pm.Model(): + # Create the likelihood variable + mu = pm.HalfNormal("mu", sigma=1) + dist.create_likelihood_variable("y", mu=mu, observed=observed) + + """ + if "mu" not in _get_pymc_parameters(self.distribution.pymc_distribution): + raise UnsupportedDistributionError( + f"Likelihood distribution {self.distribution.distribution!r} is not supported." + ) + + if "mu" in self.distribution.parameters: + raise MuAlreadyExistsError(self.distribution) + + distribution = self.distribution.deepcopy() + distribution.parameters["mu"] = mu + + dist = distribution.create_variable(name) + _remove_random_variable(var=dist) + + return pm.Censored( + name, + dist, + observed=observed, + lower=self.lower, + upper=self.upper, + dims=self.dims, + ) + + +class Scaled: + """Scaled distribution for numerical stability.""" + + def __init__(self, dist: Prior, factor: float | pt.TensorVariable) -> None: + self.dist = dist + self.factor = factor + + @property + def dims(self) -> Dims: + """The dimensions of the scaled distribution.""" + return self.dist.dims + + def create_variable(self, name: str) -> pt.TensorVariable: + """Create a scaled variable. + + Parameters + ---------- + name : str + The name of the variable. + + Returns + ------- + pt.TensorVariable + The scaled variable. 
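+
+        Examples
+        --------
+        A minimal sketch (the prior, factor, and names are illustrative):
+
+        .. code-block:: python
+
+            import pymc as pm
+
+            from pymc_extras.prior import Prior, Scaled
+
+            scaled = Scaled(Prior("HalfNormal"), factor=100)
+
+            with pm.Model():
+                var = scaled.create_variable("var")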
+ """ + var = self.dist.create_variable(f"{name}_unscaled") + return pm.Deterministic(name, var * self.factor, dims=self.dims) + + +def _is_prior_type(data: dict) -> bool: + return "dist" in data + + +def _is_censored_type(data: dict) -> bool: + return data.keys() == {"class", "data"} and data["class"] == "Censored" + + +register_deserialization(is_type=_is_prior_type, deserialize=Prior.from_dict) +register_deserialization(is_type=_is_censored_type, deserialize=Censored.from_dict) diff --git a/requirements-dev.txt b/requirements-dev.txt index a28518d8..8bbe5642 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -3,3 +3,5 @@ blackjax # Used as benchmark for statespace models statsmodels +pydantic>=2.0.0 +preliz diff --git a/requirements.txt b/requirements.txt index 49c7d88a..9636089d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ pymc>=5.21.1 scikit-learn better-optimize +pydantic>=2.0.0 diff --git a/tests/test_deserialize.py b/tests/test_deserialize.py new file mode 100644 index 00000000..2fcc28a1 --- /dev/null +++ b/tests/test_deserialize.py @@ -0,0 +1,59 @@ +import pytest + +from pymc_extras.deserialize import ( + DESERIALIZERS, + DeserializableError, + deserialize, + register_deserialization, +) + + +@pytest.mark.parametrize( + "unknown_data", + [ + {"unknown": 1}, + {"dist": "Normal", "kwargs": {"something": "else"}}, + 1, + ], + ids=["unknown_structure", "prior_like", "non_dict"], +) +def test_unknown_type_raises(unknown_data) -> None: + match = "Couldn't deserialize" + with pytest.raises(DeserializableError, match=match): + deserialize(unknown_data) + + +class ArbitraryObject: + def __init__(self, code: str): + self.code = code + self.value = 1 + + +@pytest.fixture +def register_arbitrary_object(): + register_deserialization( + is_type=lambda data: data.keys() == {"code"}, + deserialize=lambda data: ArbitraryObject(code=data["code"]), + ) + + yield + + DESERIALIZERS.pop() + + +def test_registration(register_arbitrary_object) -> None: + instance = deserialize({"code": "test"}) + + assert isinstance(instance, ArbitraryObject) + assert instance.code == "test" + + +def test_registeration_mixup() -> None: + data_that_looks_like_prior = { + "dist": "Normal", + "kwargs": {"something": "else"}, + } + + match = "Couldn't deserialize" + with pytest.raises(DeserializableError, match=match): + deserialize(data_that_looks_like_prior) diff --git a/tests/test_prior.py b/tests/test_prior.py new file mode 100644 index 00000000..d0a3490e --- /dev/null +++ b/tests/test_prior.py @@ -0,0 +1,1144 @@ +from copy import deepcopy +from typing import NamedTuple + +import numpy as np +import pymc as pm +import pytensor.tensor as pt +import pytest +import xarray as xr + +from graphviz.graphs import Digraph +from pydantic import ValidationError +from pymc.model_graph import fast_eval + +from pymc_extras.deserialize import ( + DESERIALIZERS, + deserialize, + register_deserialization, +) +from pymc_extras.prior import ( + Censored, + MuAlreadyExistsError, + Prior, + Scaled, + UnknownTransformError, + UnsupportedDistributionError, + UnsupportedParameterizationError, + UnsupportedShapeError, + VariableFactory, + handle_dims, + register_tensor_transform, + sample_prior, +) + +pz = pytest.importorskip("preliz") + + +@pytest.mark.parametrize( + "x, dims, desired_dims, expected_fn", + [ + (np.arange(3), "channel", "channel", lambda x: x), + (np.arange(3), "channel", ("geo", "channel"), lambda x: x), + (np.arange(3), "channel", ("channel", "geo"), lambda x: x[:, None]), + (np.arange(3), 
"channel", ("x", "y", "channel", "geo"), lambda x: x[:, None]), + ( + np.arange(3 * 2).reshape(3, 2), + ("channel", "geo"), + ("geo", "x", "y", "channel"), + lambda x: x.T[:, None, None, :], + ), + ( + np.arange(4 * 2 * 3).reshape(4, 2, 3), + ("channel", "geo", "store"), + ("geo", "x", "store", "channel"), + lambda x: x.swapaxes(0, 2).swapaxes(0, 1)[:, None, :, :], + ), + ], + ids=[ + "same_dims", + "different_dims", + "dim_padding", + "just_enough_dims", + "transpose_and_padding", + "swaps_and_padding", + ], +) +def test_handle_dims(x, dims, desired_dims, expected_fn) -> None: + result = handle_dims(x, dims, desired_dims) + if isinstance(result, pt.TensorVariable): + result = fast_eval(result) + + np.testing.assert_array_equal(result, expected_fn(x)) + + +@pytest.mark.parametrize( + "x, dims, desired_dims", + [ + (np.ones(3), "channel", "something_else"), + (np.ones((3, 2)), ("a", "b"), ("a", "B")), + ], + ids=["no_incommon", "some_incommon"], +) +def test_handle_dims_with_impossible_dims(x, dims, desired_dims) -> None: + match = " are not a subset of the desired dims " + with pytest.raises(UnsupportedShapeError, match=match): + handle_dims(x, dims, desired_dims) + + +def test_missing_transform() -> None: + match = "Neither pytensor.tensor nor pymc.math have the function 'foo_bar'" + with pytest.raises(UnknownTransformError, match=match): + Prior("Normal", transform="foo_bar") + + +def test_get_item() -> None: + var = Prior("Normal", mu=0, sigma=1) + + assert var["mu"] == 0 + assert var["sigma"] == 1 + + +def test_noncentered_needs_params() -> None: + with pytest.raises(ValueError): + Prior( + "Normal", + centered=False, + ) + + +def test_different_than_pymc_params() -> None: + with pytest.raises(ValueError): + Prior("Normal", mu=0, b=1) + + +def test_non_unique_dims() -> None: + with pytest.raises(ValueError): + Prior("Normal", mu=0, sigma=1, dims=("channel", "channel")) + + +def test_doesnt_check_validity_parameterization() -> None: + try: + Prior("Normal", mu=0, sigma=1, tau=1) + except Exception as e: + pytest.fail(f"Unexpected exception: {e}") + + +def test_doesnt_check_validity_values() -> None: + try: + Prior("Normal", mu=0, sigma=-1) + except Exception as e: + pytest.fail(f"Unexpected exception: {e}") + + +def test_preliz() -> None: + var = Prior("Normal", mu=0, sigma=1) + dist = var.preliz + assert isinstance(dist, pz.distributions.Distribution) + + +@pytest.mark.parametrize( + "var, expected", + [ + (Prior("Normal", mu=0, sigma=1), 'Prior("Normal", mu=0, sigma=1)'), + ( + Prior("Normal", mu=Prior("Normal"), sigma=Prior("HalfNormal")), + 'Prior("Normal", mu=Prior("Normal"), sigma=Prior("HalfNormal"))', + ), + (Prior("Normal", dims="channel"), 'Prior("Normal", dims="channel")'), + ( + Prior("Normal", mu=0, sigma=1, transform="sigmoid"), + 'Prior("Normal", mu=0, sigma=1, transform="sigmoid")', + ), + ], +) +def test_str(var, expected) -> None: + assert str(var) == expected + + +@pytest.mark.parametrize( + "var", + [ + Prior("Normal", mu=0, sigma=1), + Prior("Normal", mu=Prior("Normal"), sigma=Prior("HalfNormal"), dims="channel"), + Prior("Normal", dims=("geo", "channel")), + ], +) +def test_repr(var) -> None: + assert eval(repr(var)) == var + + +def test_invalid_distribution() -> None: + with pytest.raises(UnsupportedDistributionError): + Prior("Invalid") + + +def test_broadcast_doesnt_work(): + with pytest.raises(UnsupportedShapeError): + Prior( + "Normal", + mu=0, + sigma=Prior("HalfNormal", sigma=1, dims="x"), + dims="y", + ) + + +def test_dim_workaround_flaw() -> None: + 
distribution = Prior( + "Normal", + mu=Prior("Normal"), + sigma=Prior("HalfNormal"), + dims="y", + ) + + try: + distribution["mu"].dims = "x" + except Exception as e: + pytest.fail(f"Unexpected exception: {e}") + + with pytest.raises(UnsupportedShapeError): + distribution._param_dims_work() + + +def test_noncentered_error() -> None: + with pytest.raises(UnsupportedParameterizationError): + Prior( + "Gamma", + mu=0, + sigma=1, + dims="x", + centered=False, + ) + + +def test_create_variable_multiple_times() -> None: + mu = Prior( + "Normal", + mu=Prior("Normal"), + sigma=Prior("HalfNormal"), + dims="channel", + centered=False, + ) + + coords = { + "channel": ["a", "b", "c"], + } + with pm.Model(coords=coords) as model: + mu.create_variable("mu") + mu.create_variable("mu_2") + + suffixes = [ + "", + "_offset", + "_mu", + "_sigma", + ] + dims = [(3,), (3,), (), ()] + + for prefix in ["mu", "mu_2"]: + for suffix, dim in zip(suffixes, dims, strict=False): + assert fast_eval(model[f"{prefix}{suffix}"]).shape == dim + + +@pytest.fixture +def large_var() -> Prior: + mu = Prior( + "Normal", + mu=Prior("Normal", mu=1), + sigma=Prior("HalfNormal"), + dims="channel", + centered=False, + ) + sigma = Prior("HalfNormal", sigma=Prior("HalfNormal"), dims="geo") + + return Prior("Normal", mu=mu, sigma=sigma, dims=("geo", "channel")) + + +def test_create_variable(large_var) -> None: + coords = { + "channel": ["a", "b", "c"], + "geo": ["x", "y"], + } + with pm.Model(coords=coords) as model: + large_var.create_variable("var") + + var_names = [ + "var", + "var_mu", + "var_sigma", + "var_mu_offset", + "var_mu_mu", + "var_mu_sigma", + "var_sigma_sigma", + ] + assert set(var.name for var in model.unobserved_RVs) == set(var_names) + dims = [ + (2, 3), + (3,), + (2,), + (3,), + (), + (), + (), + ] + for var_name, dim in zip(var_names, dims, strict=False): + assert fast_eval(model[var_name]).shape == dim + + +def test_transform() -> None: + var = Prior("Normal", mu=0, sigma=1, transform="sigmoid") + + with pm.Model() as model: + var.create_variable("var") + + var_names = [ + "var", + "var_raw", + ] + dims = [ + (), + (), + ] + for var_name, dim in zip(var_names, dims, strict=False): + assert fast_eval(model[var_name]).shape == dim + + +def test_to_dict(large_var) -> None: + data = large_var.to_dict() + + assert data == { + "dist": "Normal", + "kwargs": { + "mu": { + "dist": "Normal", + "kwargs": { + "mu": { + "dist": "Normal", + "kwargs": { + "mu": 1, + }, + }, + "sigma": { + "dist": "HalfNormal", + }, + }, + "centered": False, + "dims": ("channel",), + }, + "sigma": { + "dist": "HalfNormal", + "kwargs": { + "sigma": { + "dist": "HalfNormal", + }, + }, + "dims": ("geo",), + }, + }, + "dims": ("geo", "channel"), + } + + +def test_to_dict_numpy() -> None: + var = Prior("Normal", mu=np.array([0, 10, 20]), dims="channel") + assert var.to_dict() == { + "dist": "Normal", + "kwargs": { + "mu": [0, 10, 20], + }, + "dims": ("channel",), + } + + +def test_dict_round_trip(large_var) -> None: + assert Prior.from_dict(large_var.to_dict()) == large_var + + +def test_constrain_with_transform_error() -> None: + var = Prior("Normal", transform="sigmoid") + + with pytest.raises(ValueError): + var.constrain(lower=0, upper=1) + + +def test_constrain(mocker) -> None: + var = Prior("Normal") + + mocker.patch( + "preliz.maxent", + return_value=mocker.Mock(params_dict={"mu": 5, "sigma": 2}), + ) + + new_var = var.constrain(lower=0, upper=1) + assert new_var == Prior("Normal", mu=5, sigma=2) + + +def test_dims_change() -> None: + var = 
Prior("Normal", mu=0, sigma=1) + var.dims = "channel" + + assert var.dims == ("channel",) + + +def test_dims_change_error() -> None: + mu = Prior("Normal", dims="channel") + var = Prior("Normal", mu=mu, dims="channel") + + with pytest.raises(UnsupportedShapeError): + var.dims = "geo" + + +def test_deepcopy() -> None: + priors = { + "alpha": Prior("Beta", alpha=1, beta=1), + "gamma": Prior("Normal", mu=0, sigma=1), + } + + new_priors = deepcopy(priors) + priors["alpha"].dims = "channel" + + assert new_priors["alpha"].dims == () + + +@pytest.fixture +def mmm_default_model_config(): + return { + "intercept": {"dist": "Normal", "kwargs": {"mu": 0, "sigma": 2}}, + "likelihood": { + "dist": "Normal", + "kwargs": { + "sigma": {"dist": "HalfNormal", "kwargs": {"sigma": 2}}, + }, + }, + "gamma_control": { + "dist": "Normal", + "kwargs": {"mu": 0, "sigma": 2}, + "dims": "control", + }, + "gamma_fourier": { + "dist": "Laplace", + "kwargs": {"mu": 0, "b": 1}, + "dims": "fourier_mode", + }, + } + + +def test_backwards_compat(mmm_default_model_config) -> None: + result = {param: Prior.from_dict(value) for param, value in mmm_default_model_config.items()} + assert result == { + "intercept": Prior("Normal", mu=0, sigma=2), + "likelihood": Prior("Normal", sigma=Prior("HalfNormal", sigma=2)), + "gamma_control": Prior("Normal", mu=0, sigma=2, dims="control"), + "gamma_fourier": Prior("Laplace", mu=0, b=1, dims="fourier_mode"), + } + + +def test_sample_prior() -> None: + var = Prior( + "Normal", + mu=Prior("Normal"), + sigma=Prior("HalfNormal"), + dims="channel", + transform="sigmoid", + ) + + coords = {"channel": ["A", "B", "C"]} + prior = var.sample_prior(coords=coords, samples=25) + + assert isinstance(prior, xr.Dataset) + assert prior.sizes == {"chain": 1, "draw": 25, "channel": 3} + + +def test_sample_prior_missing_coords() -> None: + dist = Prior("Normal", dims="channel") + + with pytest.raises(KeyError, match="Coords"): + dist.sample_prior() + + +def test_to_graph() -> None: + hierarchical_distribution = Prior( + "Normal", + mu=Prior("Normal"), + sigma=Prior("HalfNormal"), + dims="channel", + ) + + G = hierarchical_distribution.to_graph() + assert isinstance(G, Digraph) + + +def test_from_dict_list() -> None: + data = { + "dist": "Normal", + "kwargs": { + "mu": [0, 1, 2], + "sigma": 1, + }, + "dims": "channel", + } + + var = Prior.from_dict(data) + assert var.dims == ("channel",) + assert isinstance(var["mu"], np.ndarray) + np.testing.assert_array_equal(var["mu"], [0, 1, 2]) + + +def test_from_dict_list_dims() -> None: + data = { + "dist": "Normal", + "kwargs": { + "mu": 0, + "sigma": 1, + }, + "dims": ["channel", "geo"], + } + + var = Prior.from_dict(data) + assert var.dims == ("channel", "geo") + + +def test_to_dict_transform() -> None: + dist = Prior("Normal", transform="sigmoid") + + data = dist.to_dict() + assert data == { + "dist": "Normal", + "transform": "sigmoid", + } + + +def test_equality_non_prior() -> None: + dist = Prior("Normal") + + assert dist != 1 + + +def test_deepcopy_memo() -> None: + memo = {} + dist = Prior("Normal") + memo[id(dist)] = dist + deepcopy(dist, memo) + assert len(memo) == 1 + deepcopy(dist, memo) + + assert len(memo) == 1 + + +def test_create_likelihood_variable() -> None: + distribution = Prior("Normal", sigma=Prior("HalfNormal")) + + with pm.Model() as model: + mu = pm.Normal("mu") + + data = distribution.create_likelihood_variable("data", mu=mu, observed=10) + + assert model.observed_RVs == [data] + assert "data_sigma" in model + + +def 
test_create_likelihood_variable_already_has_mu() -> None:
+    distribution = Prior("Normal", mu=Prior("Normal"), sigma=Prior("HalfNormal"))
+
+    with pm.Model():
+        mu = pm.Normal("mu")
+
+        with pytest.raises(MuAlreadyExistsError):
+            distribution.create_likelihood_variable("data", mu=mu, observed=10)
+
+
+def test_create_likelihood_non_mu_parameterized_distribution() -> None:
+    distribution = Prior("Cauchy")
+
+    with pm.Model():
+        mu = pm.Normal("mu")
+        with pytest.raises(UnsupportedDistributionError):
+            distribution.create_likelihood_variable("data", mu=mu, observed=10)
+
+
+def test_non_centered_student_t() -> None:
+    try:
+        Prior(
+            "StudentT",
+            mu=Prior("Normal"),
+            sigma=Prior("HalfNormal"),
+            nu=Prior("HalfNormal"),
+            dims="channel",
+            centered=False,
+        )
+    except Exception as e:
+        pytest.fail(f"Unexpected exception: {e}")
+
+
+def test_cant_reset_distribution() -> None:
+    dist = Prior("Normal")
+    with pytest.raises(AttributeError, match="Can't change the distribution"):
+        dist.distribution = "Cauchy"
+
+
+def test_nonstring_distribution() -> None:
+    with pytest.raises(ValidationError, match=".*Input should be a valid string.*"):
+        Prior(pm.Normal)
+
+
+def test_change_the_transform() -> None:
+    dist = Prior("Normal")
+    dist.transform = "logit"
+    assert dist.transform == "logit"
+
+
+def test_nonstring_transform() -> None:
+    with pytest.raises(ValidationError, match=".*Input should be a valid string.*"):
+        Prior("Normal", transform=pm.math.log)
+
+
+def test_checks_param_value_types() -> None:
+    with pytest.raises(ValueError, match="Parameters must be one of the following types"):
+        Prior("Normal", mu="str", sigma="str")
+
+
+def test_check_equality_with_numpy() -> None:
+    dist = Prior("Normal", mu=np.array([1, 2, 3]), sigma=1)
+    assert dist == dist.deepcopy()
+
+
+def clear_custom_transforms() -> None:
+    # ``global CUSTOM_TRANSFORMS`` would only rebind a name in this test
+    # module; clear the actual registry in pymc_extras.prior instead.
+    from pymc_extras.prior import CUSTOM_TRANSFORMS
+
+    CUSTOM_TRANSFORMS.clear()
+
+
+def test_custom_transform() -> None:
+    new_transform_name = "foo_bar"
+    with pytest.raises(UnknownTransformError):
+        Prior("Normal", transform=new_transform_name)
+
+    register_tensor_transform(new_transform_name, lambda x: x**2)
+
+    dist = Prior("Normal", transform=new_transform_name)
+    prior = dist.sample_prior(samples=10)
+    df_prior = prior.to_dataframe()
+
+    np.testing.assert_array_equal(df_prior["var"].to_numpy(), df_prior["var_raw"].to_numpy() ** 2)
+
+
+def test_custom_transform_comes_first() -> None:
+    # "square" shadows the function of the same name in pytensor.tensor
+    register_tensor_transform("square", lambda x: 2 * x)
+
+    dist = Prior("Normal", transform="square")
+    prior = dist.sample_prior(samples=10)
+    df_prior = prior.to_dataframe()
+
+    np.testing.assert_array_equal(df_prior["var"].to_numpy(), 2 * df_prior["var_raw"].to_numpy())
+
+    clear_custom_transforms()
+
+
+def test_serialize_with_pytensor() -> None:
+    sigma = pt.arange(1, 4)
+    dist = Prior("Normal", mu=0, sigma=sigma)
+
+    assert dist.to_dict() == {
+        "dist": "Normal",
+        "kwargs": {
+            "mu": 0,
+            "sigma": [1, 2, 3],
+        },
+    }
+
+
+def test_zsn_non_centered() -> None:
+    try:
+        Prior("ZeroSumNormal", sigma=1, centered=False)
+    except Exception as e:
+        pytest.fail(f"Unexpected exception: {e}")
+
+
+class Arbitrary:
+    def __init__(self, dims: str | tuple[str, ...]) -> None:
+        self.dims = dims
+
+    def create_variable(self, name: str):
+        return pm.Normal(name, dims=self.dims)
+
+
+class ArbitraryWithoutName:
+    def __init__(self, dims: str | tuple[str, ...]) -> None:
+        self.dims = dims
+
+    def create_variable(self, name: str):
+        with pm.Model(name=name):
+            location = pm.Normal("location", dims=self.dims)
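+            # pm.Model(name=name) prefixes these variable names
+            # (e.g. "var::location"), so nothing is named plainly "var".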
scale = pm.HalfNormal("scale", dims=self.dims) + + return pm.Normal("standard_normal") * scale + location + + +def test_sample_prior_arbitrary() -> None: + var = Arbitrary(dims="channel") + + prior = sample_prior(var, coords={"channel": ["A", "B", "C"]}, draws=25) + + assert isinstance(prior, xr.Dataset) + + +def test_sample_prior_arbitrary_no_name() -> None: + var = ArbitraryWithoutName(dims="channel") + + prior = sample_prior(var, coords={"channel": ["A", "B", "C"]}, draws=25) + + assert isinstance(prior, xr.Dataset) + assert "var" not in prior + + prior_with = sample_prior( + var, + coords={"channel": ["A", "B", "C"]}, + draws=25, + wrap=True, + ) + + assert isinstance(prior_with, xr.Dataset) + assert "var" in prior_with + + +def test_create_prior_with_arbitrary() -> None: + dist = Prior( + "Normal", + mu=Arbitrary(dims=("channel",)), + sigma=1, + dims=("channel", "geo"), + ) + + coords = { + "channel": ["C1", "C2", "C3"], + "geo": ["G1", "G2"], + } + with pm.Model(coords=coords) as model: + dist.create_variable("var") + + assert "var_mu" in model + var_mu = model["var_mu"] + + assert fast_eval(var_mu).shape == (len(coords["channel"]),) + + +def test_censored_is_variable_factory() -> None: + normal = Prior("Normal") + censored_normal = Censored(normal, lower=0) + + assert isinstance(censored_normal, VariableFactory) + + +@pytest.mark.parametrize( + "dims, expected_dims", + [ + ("channel", ("channel",)), + (("channel", "geo"), ("channel", "geo")), + ], + ids=["string", "tuple"], +) +def test_censored_dims_from_distribution(dims, expected_dims) -> None: + normal = Prior("Normal", dims=dims) + censored_normal = Censored(normal, lower=0) + + assert censored_normal.dims == expected_dims + + +def test_censored_variables_created() -> None: + normal = Prior("Normal", mu=Prior("Normal"), dims="dim") + censored_normal = Censored(normal, lower=0) + + coords = {"dim": range(3)} + with pm.Model(coords=coords) as model: + censored_normal.create_variable("var") + + var_names = ["var", "var_mu"] + assert set(var.name for var in model.unobserved_RVs) == set(var_names) + dims = [(3,), ()] + for var_name, dim in zip(var_names, dims, strict=False): + assert fast_eval(model[var_name]).shape == dim + + +def test_censored_sample_prior() -> None: + normal = Prior("Normal", dims="channel") + censored_normal = Censored(normal, lower=0) + + coords = {"channel": ["A", "B", "C"]} + prior = censored_normal.sample_prior(coords=coords, samples=25) + + assert isinstance(prior, xr.Dataset) + assert prior.sizes == {"chain": 1, "draw": 25, "channel": 3} + + +def test_censored_to_graph() -> None: + normal = Prior("Normal", dims="channel") + censored_normal = Censored(normal, lower=0) + + G = censored_normal.to_graph() + assert isinstance(G, Digraph) + + +def test_censored_likelihood_variable() -> None: + normal = Prior("Normal", sigma=Prior("HalfNormal"), dims="channel") + censored_normal = Censored(normal, lower=0) + + coords = {"channel": range(3)} + with pm.Model(coords=coords) as model: + mu = pm.Normal("mu") + variable = censored_normal.create_likelihood_variable( + name="likelihood", + mu=mu, + observed=[1, 2, 3], + ) + + assert isinstance(variable, pt.TensorVariable) + assert model.observed_RVs == [variable] + assert "likelihood_sigma" in model + + +def test_censored_likelihood_unsupported_distribution() -> None: + cauchy = Prior("Cauchy") + censored_cauchy = Censored(cauchy, lower=0) + + with pm.Model(): + mu = pm.Normal("mu") + with pytest.raises(UnsupportedDistributionError): + 
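+            # pm.Cauchy is parameterized by alpha and beta, so there is no
+            # "mu" for create_likelihood_variable to set.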
censored_cauchy.create_likelihood_variable( + name="likelihood", + mu=mu, + observed=1, + ) + + +def test_censored_likelihood_already_has_mu() -> None: + normal = Prior("Normal", mu=Prior("Normal"), sigma=Prior("HalfNormal")) + censored_normal = Censored(normal, lower=0) + + with pm.Model(): + mu = pm.Normal("mu") + with pytest.raises(MuAlreadyExistsError): + censored_normal.create_likelihood_variable( + name="likelihood", + mu=mu, + observed=1, + ) + + +def test_censored_to_dict() -> None: + normal = Prior("Normal", mu=0, sigma=1, dims="channel") + censored_normal = Censored(normal, lower=0) + + data = censored_normal.to_dict() + assert data == { + "class": "Censored", + "data": {"dist": normal.to_dict(), "lower": 0, "upper": float("inf")}, + } + + +def test_deserialize_censored() -> None: + data = { + "class": "Censored", + "data": { + "dist": { + "dist": "Normal", + }, + "lower": 0, + "upper": float("inf"), + }, + } + + instance = deserialize(data) + assert isinstance(instance, Censored) + assert isinstance(instance.distribution, Prior) + assert instance.lower == 0 + assert instance.upper == float("inf") + + +class ArbitrarySerializable(Arbitrary): + def to_dict(self): + return {"dims": self.dims} + + +@pytest.fixture +def arbitrary_serialized_data() -> dict: + return {"dims": ("channel",)} + + +def test_create_prior_with_arbitrary_serializable(arbitrary_serialized_data) -> None: + dist = Prior( + "Normal", + mu=ArbitrarySerializable(dims=("channel",)), + sigma=1, + dims=("channel", "geo"), + ) + + assert dist.to_dict() == { + "dist": "Normal", + "kwargs": { + "mu": arbitrary_serialized_data, + "sigma": 1, + }, + "dims": ("channel", "geo"), + } + + +@pytest.fixture +def register_arbitrary_deserialization(): + register_deserialization( + lambda data: isinstance(data, dict) and data.keys() == {"dims"}, + lambda data: ArbitrarySerializable(**data), + ) + + yield + + DESERIALIZERS.pop() + + +def test_deserialize_arbitrary_within_prior( + arbitrary_serialized_data, + register_arbitrary_deserialization, +) -> None: + data = { + "dist": "Normal", + "kwargs": { + "mu": arbitrary_serialized_data, + "sigma": 1, + }, + "dims": ("channel", "geo"), + } + + dist = deserialize(data) + assert isinstance(dist["mu"], ArbitrarySerializable) + assert dist["mu"].dims == ("channel",) + + +def test_censored_with_tensor_variable() -> None: + normal = Prior("Normal", dims="channel") + lower = pt.as_tensor_variable([0, 1, 2]) + censored_normal = Censored(normal, lower=lower) + + assert censored_normal.to_dict() == { + "class": "Censored", + "data": { + "dist": normal.to_dict(), + "lower": [0, 1, 2], + "upper": float("inf"), + }, + } + + +def test_censored_dims_setter() -> None: + normal = Prior("Normal", dims="channel") + censored_normal = Censored(normal, lower=0) + censored_normal.dims = "date" + assert normal.dims == ("date",) + + +class ModelData(NamedTuple): + mu: float + observed: list[float] + + +@pytest.fixture(scope="session") +def model_data() -> ModelData: + return ModelData(mu=0, observed=[0, 1, 2, 3, 4]) + + +@pytest.fixture(scope="session") +def normal_model_with_censored_API(model_data) -> pm.Model: + coords = {"idx": range(len(model_data.observed))} + with pm.Model(coords=coords) as model: + sigma = Prior("HalfNormal") + normal = Prior("Normal", sigma=sigma, dims="idx") + Censored(normal, lower=0).create_likelihood_variable( + "censored_normal", + mu=model_data.mu, + observed=model_data.observed, + ) + + return model + + +@pytest.fixture(scope="session") +def 
normal_model_with_censored_logp(normal_model_with_censored_API): + return normal_model_with_censored_API.compile_logp() + + +@pytest.fixture(scope="session") +def expected_normal_model(model_data) -> pm.Model: + n_points = len(model_data.observed) + with pm.Model() as expected_model: + sigma = pm.HalfNormal("censored_normal_sigma") + normal = pm.Normal.dist(mu=model_data.mu, sigma=sigma, shape=n_points) + pm.Censored( + "censored_normal", + normal, + lower=0, + upper=np.inf, + observed=model_data.observed, + ) + + return expected_model + + +@pytest.fixture(scope="session") +def expected_normal_model_logp(expected_normal_model): + return expected_normal_model.compile_logp() + + +@pytest.mark.parametrize("sigma_log__", [-10, -5, -2.5, 0, 2.5, 5, 10]) +def test_censored_normal_logp( + sigma_log__, + normal_model_with_censored_logp, + expected_normal_model_logp, +) -> None: + points = {"censored_normal_sigma_log__": sigma_log__} + normal_model_logp = normal_model_with_censored_logp(points) + expected_model_logp = expected_normal_model_logp(points) + np.testing.assert_allclose(normal_model_logp, expected_model_logp) + + +@pytest.mark.parametrize( + "mu", + [ + 0, + np.arange(10), + ], + ids=["scalar", "vector"], +) +def test_censored_logp(mu) -> None: + n_points = 10 + observed = np.zeros(n_points) + coords = {"idx": range(n_points)} + with pm.Model(coords=coords) as model: + normal = Prior("Normal", dims="idx") + Censored(normal, lower=0).create_likelihood_variable( + "censored_normal", + observed=observed, + mu=mu, + ) + logp = model.compile_logp() + + with pm.Model() as expected_model: + pm.Censored( + "censored_normal", + pm.Normal.dist(mu=mu, sigma=1, shape=n_points), + lower=0, + upper=np.inf, + observed=observed, + ) + expected_logp = expected_model.compile_logp() + + point = {} + np.testing.assert_allclose(logp(point), expected_logp(point)) + + +def test_scaled_initializes_correctly() -> None: + """Test that the Scaled class initializes correctly.""" + normal = Prior("Normal", mu=0, sigma=1) + scaled = Scaled(normal, factor=2.0) + + assert scaled.dist == normal + assert scaled.factor == 2.0 + + +def test_scaled_dims_property() -> None: + """Test that the dims property returns the dimensions of the underlying distribution.""" + normal = Prior("Normal", mu=0, sigma=1, dims="channel") + scaled = Scaled(normal, factor=2.0) + + assert scaled.dims == ("channel",) + + # Test with multiple dimensions + normal.dims = ("channel", "geo") + assert scaled.dims == ("channel", "geo") + + +def test_scaled_create_variable() -> None: + """Test that the create_variable method properly scales the variable.""" + normal = Prior("Normal", mu=0, sigma=1) + scaled = Scaled(normal, factor=2.0) + + with pm.Model() as model: + scaled_var = scaled.create_variable("scaled_var") + + # Check that both the scaled and unscaled variables exist + assert "scaled_var" in model + assert "scaled_var_unscaled" in model + + # The deterministic node should be the scaled variable + assert model["scaled_var"] == scaled_var + + +def test_scaled_creates_correct_dimensions() -> None: + """Test that the scaled variable has the correct dimensions.""" + normal = Prior("Normal", dims="channel") + scaled = Scaled(normal, factor=2.0) + + coords = {"channel": ["A", "B", "C"]} + with pm.Model(coords=coords): + scaled_var = scaled.create_variable("scaled_var") + + # Check that the scaled variable has the correct dimensions + assert fast_eval(scaled_var).shape == (3,) + + +def test_scaled_applies_factor() -> None: + """Test that the scaling 
factor is correctly applied.""" + normal = Prior("Normal", mu=0, sigma=1) + factor = 3.5 + scaled = Scaled(normal, factor=factor) + + # Sample from prior to verify scaling + prior = sample_prior(scaled, samples=10, name="scaled_var") + df_prior = prior.to_dataframe() + + # Check that scaled values are original values times the factor + unscaled_values = df_prior["scaled_var_unscaled"].to_numpy() + scaled_values = df_prior["scaled_var"].to_numpy() + np.testing.assert_allclose(scaled_values, unscaled_values * factor) + + +def test_scaled_with_tensor_factor() -> None: + """Test that the Scaled class works with a tensor factor.""" + normal = Prior("Normal", mu=0, sigma=1) + factor = pt.as_tensor_variable(2.5) + scaled = Scaled(normal, factor=factor) + + # Sample from prior to verify tensor scaling + prior = sample_prior(scaled, samples=10, name="scaled_var") + df_prior = prior.to_dataframe() + + # Check that scaled values are original values times the factor + unscaled_values = df_prior["scaled_var_unscaled"].to_numpy() + scaled_values = df_prior["scaled_var"].to_numpy() + np.testing.assert_allclose(scaled_values, unscaled_values * 2.5) + + +def test_scaled_with_hierarchical_prior() -> None: + """Test that the Scaled class works with hierarchical priors.""" + normal = Prior("Normal", mu=Prior("Normal"), sigma=Prior("HalfNormal"), dims="channel") + scaled = Scaled(normal, factor=2.0) + + coords = {"channel": ["A", "B", "C"]} + with pm.Model(coords=coords) as model: + scaled.create_variable("scaled_var") + + # Check that all necessary variables were created + assert "scaled_var" in model + assert "scaled_var_unscaled" in model + assert "scaled_var_unscaled_mu" in model + assert "scaled_var_unscaled_sigma" in model + + +def test_scaled_sample_prior() -> None: + """Test that sample_prior works with the Scaled class.""" + normal = Prior("Normal", dims="channel") + scaled = Scaled(normal, factor=2.0) + + coords = {"channel": ["A", "B", "C"]} + prior = sample_prior(scaled, coords=coords, draws=25, name="scaled_var") + + assert isinstance(prior, xr.Dataset) + assert prior.sizes == {"chain": 1, "draw": 25, "channel": 3} + assert "scaled_var" in prior + assert "scaled_var_unscaled" in prior
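+
+
+# A minimal round-trip sketch layered on behavior already exercised above: it
+# assumes Censored.to_dict() produces exactly the payload shown in
+# test_censored_to_dict, and that the Censored deserializer checked in
+# test_deserialize_censored is registered by default. If either assumption
+# changes, update or drop this test.
+def test_censored_to_dict_roundtrip() -> None:
+    normal = Prior("Normal", mu=0, sigma=1, dims="channel")
+    censored_normal = Censored(normal, lower=0)
+
+    # Serialize, then feed the payload straight back through deserialize().
+    recovered = deserialize(censored_normal.to_dict())
+
+    assert isinstance(recovered, Censored)
+    assert isinstance(recovered.distribution, Prior)
+    assert recovered.lower == 0
+    assert recovered.upper == float("inf")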