Skip to content

Finalise backend-agnosticity of Distributions #33

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 42 commits into
base: main
Choose a base branch
from

Conversation

willGraham01
Copy link
Collaborator

@willGraham01 willGraham01 commented Mar 21, 2025

Creates the Translator class, which expands on the BackendAgnostic ABC. A Translator is aware of the backend_obj (backend object), and also aware of the signature of the required _frontend_methods.

On instantiation, one can pass a sequence of Translation objects - essentially dataclass wrappers for the arguments to convert_signature - to a Translator. For each _frontend_method, the Translator stores the map between the arguments the frontend method takes, and the arguments the corresponding backend method takes. This means that the Translator can be used by the frontend, and will have the expected syntax, but also takes care of the necessary mapping of the arguments provided by the frontend to those that the backend takes. There is a simple example of this in action in the tests for the Translator class, however the tests for different backends demonstrate how this functionality is envisioned to be used.

The Distribution class now inherits from Translator. The NativeDistribution class has also been introduced, for us to use as the "base" for any standard distributions that we want to provide via our jax-default backend. This class simply defaults all the "translations" to the identity map (IE, the arguments and method names of the backend are already what we expect of the frontend).

Users should use the Distribution class when they want to use a backend that is different to our "standard" backend (jax). If they envision using multiple distributions from the same backend, they can create a derived class to reduce the number of times they need to provide the mapping information, for example

from numpyro.distributions.continuous import MultivariateNormal

from causalprog.distributions.base import Distribution
from causalprog.backend.translation import Translation

NUMPYRO_TRANSLATIONS = (
    Translation(backend_method="sample", frontend_method="sample", param_map = {"seed": "rng_key"},
    ...
)

class NumpyroDistribution(Distribution):

    def __init__(self, *, backend, label: str) -> None:
        super().__init__(*NUMPYRO_TRANSLATIONS, backend=backend, label=label)      

# Specialized classes are further possible,
# if the user so desires.

class NumpyroNormal(NumpyroDistribution):

    def __init__(self, mean, cov, *, label: str) -> None:
        super().__init__(backend=MultivariateNormal(mean, cov), label=label)

import jax.numpy as jnp

mean, cov = jnp.array(...), jnp.array(...)
normal = NumpyroNormal(mean, cov, label="Numpyro Normal")
normal.sample(...) # Works with frontend syntax, but calls `numpyro` functionality.

Other Changes

  • convert_signature now returns the function that maps the frontend arguments to their backend counterparts, rather than the function that does this and then evaluates the backend function. This is so that we can recycle the static identity map method of Translator.

@willGraham01 willGraham01 force-pushed the wgraham/signature-converting branch from 5b28001 to 775489f Compare March 21, 2025 14:49
Copy link
Collaborator Author

@willGraham01 willGraham01 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Self-review with some typos

@willGraham01 willGraham01 marked this pull request as ready for review March 24, 2025 10:05
@willGraham01 willGraham01 requested review from mscroggs and removed request for mscroggs March 24, 2025 10:05
@willGraham01 willGraham01 marked this pull request as draft March 24, 2025 10:07
@willGraham01 willGraham01 marked this pull request as ready for review March 24, 2025 10:20
Base automatically changed from wgraham/signature-converting to main March 24, 2025 14:50
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant