diff --git a/src/causalprog/algorithms/expectation.py b/src/causalprog/algorithms/expectation.py index 10d87fd..6ad8345 100644 --- a/src/causalprog/algorithms/expectation.py +++ b/src/causalprog/algorithms/expectation.py @@ -1,6 +1,6 @@ """Algorithms for estimating the expectation and standard deviation.""" -import numpy as np +import jax import numpy.typing as npt from causalprog.graph import Graph @@ -8,8 +8,10 @@ def sample( graph: Graph, - outcome_node_label: str | None = None, - samples: int = 1000, + outcome_node_label: str, + samples: int, + *, + rng_key: jax.Array, ) -> npt.NDArray[float]: """Sample data from a graph.""" if outcome_node_label is None: @@ -18,24 +20,30 @@ def sample( nodes = graph.roots_down_to_outcome(outcome_node_label) values: dict[str, npt.NDArray[float]] = {} - for node in nodes: - values[node.label] = node.sample(values, samples) + keys = jax.random.split(rng_key, len(nodes)) + + for node, key in zip(nodes, keys, strict=False): + values[node.label] = node.sample(values, samples, rng_key=key) return values[outcome_node_label] def expectation( graph: Graph, - outcome_node_label: str | None = None, - samples: int = 1000, + outcome_node_label: str, + samples: int, + *, + rng_key: jax.Array, ) -> float: """Estimate the expectation of a graph.""" - return sample(graph, outcome_node_label, samples).mean() + return sample(graph, outcome_node_label, samples, rng_key=rng_key).mean() def standard_deviation( graph: Graph, - outcome_node_label: str | None = None, - samples: int = 1000, + outcome_node_label: str, + samples: int, + *, + rng_key: jax.Array, ) -> float: """Estimate the standard deviation of a graph.""" - return np.std(sample(graph, outcome_node_label, samples)) + return sample(graph, outcome_node_label, samples, rng_key=rng_key).std() diff --git a/src/causalprog/distribution/family.py b/src/causalprog/distribution/family.py index c3e150d..3ee2175 100644 --- a/src/causalprog/distribution/family.py +++ b/src/causalprog/distribution/family.py @@ -39,8 +39,8 @@ class DistributionFamily(Generic[CreatesDistribution], Labelled): @property def _member(self) -> Callable[..., Distribution]: """Constructor method for family members, given parameters.""" - return lambda *parameters: Distribution( - self._family(*parameters), + return lambda **parameters: Distribution( + self._family(**parameters), backend_translator=self._family_translator, ) @@ -67,13 +67,13 @@ def __init__( self._family = backend_family self._family_translator = backend_translator - def construct(self, *parameters: ArrayLike) -> Distribution: + def construct(self, **parameters: ArrayLike) -> Distribution: """ Create a distribution from an explicit set of parameters. Args: - *parameters (ArrayLike): Parameters that define a member of this family, + **parameters (ArrayLike): Parameters that define a member of this family, passed as sequential arguments. """ - return self._member(*parameters) + return self._member(**parameters) diff --git a/src/causalprog/distribution/normal.py b/src/causalprog/distribution/normal.py index 349cffa..a8dfbb6 100644 --- a/src/causalprog/distribution/normal.py +++ b/src/causalprog/distribution/normal.py @@ -58,6 +58,8 @@ def __init__(self, mean: ArrayCompatible, cov: ArrayCompatible) -> None: cov (ArrayCompatible): Matrix of covariates, $\Sigma$. 
""" + mean = jnp.atleast_1d(mean) + cov = jnp.atleast_2d(cov) super().__init__(_Normal(mean, cov), label=f"({mean.ndim}-dim) Normal") @@ -76,7 +78,7 @@ def __init__(self) -> None: """Create a family of normal distributions.""" super().__init__(Normal, family_name="Normal") - def construct(self, mean: ArrayCompatible, cov: ArrayCompatible) -> Normal: + def construct(self, mean: ArrayCompatible, cov: ArrayCompatible) -> Normal: # type: ignore # noqa: PGH003 r""" Construct a normal distribution with the given mean and covariates. @@ -85,4 +87,4 @@ def construct(self, mean: ArrayCompatible, cov: ArrayCompatible) -> Normal: cov (ArrayCompatible): Matrix of covariates, $\Sigma$. """ - return super().construct(mean, cov) + return super().construct(mean=mean, cov=cov) diff --git a/src/causalprog/graph/node.py b/src/causalprog/graph/node.py index bc51157..05fc183 100644 --- a/src/causalprog/graph/node.py +++ b/src/causalprog/graph/node.py @@ -3,48 +3,17 @@ from __future__ import annotations import typing -from abc import ABC, abstractmethod +from abc import abstractmethod +import jax import numpy as np if typing.TYPE_CHECKING: import numpy.typing as npt -from causalprog._abc.labelled import Labelled - - -class Distribution(ABC): - """Placeholder class.""" - - @abstractmethod - def sample( - self, sampled_dependencies: dict[str, npt.NDArray[float]], samples: int - ) -> npt.NDArray[float]: - """Sample.""" - - -class NormalDistribution(Distribution): - """Normal distribution.""" + from causalprog.distribution.family import DistributionFamily - def __init__(self, mean: str | float = 0.0, std_dev: str | float = 1.0) -> None: - """Initialise.""" - self.mean = mean - self.std_dev = std_dev - - def sample( - self, sampled_dependencies: dict[str, npt.NDArray[float]], samples: int - ) -> npt.NDArray[float]: - """Sample a normal distribution with mean 1.""" - values = np.random.normal(0.0, 1.0, samples) # noqa: NPY002 - if isinstance(self.std_dev, str): - values *= sampled_dependencies[self.std_dev] - else: - values *= self.std_dev - if isinstance(self.mean, str): - values += sampled_dependencies[self.mean] - else: - values += self.mean - return values +from causalprog._abc.labelled import Labelled class Node(Labelled): @@ -57,7 +26,10 @@ def __init__(self, label: str, *, is_outcome: bool = False) -> None: @abstractmethod def sample( - self, sampled_dependencies: dict[str, npt.NDArray[float]], samples: int + self, + sampled_dependencies: dict[str, npt.NDArray[float]], + samples: int, + rng_key: jax.Array, ) -> float: """Sample a value from the node.""" @@ -72,20 +44,40 @@ class DistributionNode(Node): def __init__( self, - distribution: Distribution, + distribution: DistributionFamily, label: str, *, + parameters: dict[str, str] | None = None, + constant_parameters: dict[str, float] | None = None, is_outcome: bool = False, ) -> None: """Initialise.""" self._dist = distribution + self._constant_parameters = constant_parameters if constant_parameters else {} + self._parameters = parameters if parameters else {} super().__init__(label, is_outcome=is_outcome) def sample( - self, sampled_dependencies: dict[str, npt.NDArray[float]], samples: int - ) -> float: + self, + sampled_dependencies: dict[str, npt.NDArray[float]], + samples: int, + rng_key: jax.Array, + ) -> npt.NDArray[float]: """Sample a value from the node.""" - return self._dist.sample(sampled_dependencies, samples) + if not self._parameters: + concrete_dist = self._dist.construct(**self._constant_parameters) + return concrete_dist.sample(rng_key, 
samples) + output = np.zeros(samples) + new_key = jax.random.split(rng_key, samples) + for sample in range(samples): + parameters = { + i: sampled_dependencies[j][sample] for i, j in self._parameters.items() + } + concrete_dist = self._dist.construct( + **parameters, **self._constant_parameters + ) + output[sample] = concrete_dist.sample(new_key[sample], 1)[0][0] + return output def __repr__(self) -> str: return f'DistributionNode("{self.label}")' diff --git a/tests/test_distributions/conftest.py b/tests/conftest.py similarity index 85% rename from tests/test_distributions/conftest.py rename to tests/conftest.py index e245de4..6400019 100644 --- a/tests/test_distributions/conftest.py +++ b/tests/conftest.py @@ -15,7 +15,7 @@ def rng_key(seed: int): @pytest.fixture -def n_dim_std_normal(request) -> tuple[Array, Array]: +def n_dim_std_normal(request) -> dict[str, Array]: """ Mean and covariance matrix of the n-dimensional standard normal distribution. @@ -24,4 +24,4 @@ def n_dim_std_normal(request) -> tuple[Array, Array]: n_dims = request.param mean = jnp.array([0.0] * n_dims) cov = jnp.diag(jnp.array([1.0] * n_dims)) - return mean, cov + return {"mean": mean, "cov": cov} diff --git a/tests/test_distributions/test_family.py b/tests/test_distributions/test_family.py index 9a95942..72a29e7 100644 --- a/tests/test_distributions/test_family.py +++ b/tests/test_distributions/test_family.py @@ -19,8 +19,10 @@ def test_builder_matches_backend(n_dim_std_normal) -> None: mnv = distrax.MultivariateNormalFullCovariance mnv_family = DistributionFamily(mnv, SampleTranslator(rng_key="seed")) - via_family = mnv_family.construct(*n_dim_std_normal) - via_backend = mnv(*n_dim_std_normal) + via_family = mnv_family.construct( + loc=n_dim_std_normal["mean"], covariance_matrix=n_dim_std_normal["cov"] + ) + via_backend = mnv(n_dim_std_normal["mean"], n_dim_std_normal["cov"]) assert via_backend.kl_divergence(via_family.get_dist()) == pytest.approx(0.0) assert via_family.get_dist().kl_divergence(via_backend) == pytest.approx(0.0) diff --git a/tests/test_graph.py b/tests/test_graph.py index ca42200..7ad091f 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -2,16 +2,19 @@ import re +import jax import numpy as np import pytest import causalprog +from causalprog.distribution.normal import NormalFamily +from causalprog.graph import DistributionNode, Graph def test_label(): - d = causalprog.graph.node.NormalDistribution() - node = causalprog.graph.DistributionNode(d, "X") - node2 = causalprog.graph.DistributionNode(d, "Y") + d = NormalFamily() + node = DistributionNode(d, "X") + node2 = DistributionNode(d, "Y") node_copy = node assert node.label == node_copy.label == "X" @@ -23,12 +26,12 @@ def test_label(): def test_duplicate_label(): - d = causalprog.graph.node.NormalDistribution() + d = NormalFamily() - graph = causalprog.graph.Graph("G0") - graph.add_node(causalprog.graph.DistributionNode(d, "X")) + graph = Graph("G0") + graph.add_node(DistributionNode(d, "X")) with pytest.raises(ValueError, match=re.escape("Duplicate node label: X")): - graph.add_node(causalprog.graph.DistributionNode(d, "X")) + graph.add_node(DistributionNode(d, "X")) @pytest.mark.parametrize( @@ -38,12 +41,12 @@ def test_duplicate_label(): def test_build_graph(*, use_labels: bool) -> None: root_label = "root" outcome_label = "outcome_label" - d = causalprog.graph.node.NormalDistribution() + d = NormalFamily() - root_node = causalprog.graph.DistributionNode(d, root_label) - outcome_node = causalprog.graph.DistributionNode(d, 
outcome_label, is_outcome=True) + root_node = DistributionNode(d, root_label) + outcome_node = DistributionNode(d, outcome_label, is_outcome=True) - graph = causalprog.graph.Graph("G0") + graph = Graph("G0") graph.add_node(root_node) graph.add_node(outcome_node) @@ -56,16 +59,16 @@ def test_build_graph(*, use_labels: bool) -> None: def test_roots_down_to_outcome() -> None: - d = causalprog.graph.node.NormalDistribution() + d = NormalFamily() - graph = causalprog.graph.Graph("G0") + graph = Graph("G0") - u = causalprog.graph.DistributionNode(d, "U") - v = causalprog.graph.DistributionNode(d, "V") - w = causalprog.graph.DistributionNode(d, "W") - x = causalprog.graph.DistributionNode(d, "X") - y = causalprog.graph.DistributionNode(d, "Y") - z = causalprog.graph.DistributionNode(d, "Z") + u = DistributionNode(d, "U") + v = DistributionNode(d, "V") + w = DistributionNode(d, "W") + x = DistributionNode(d, "X") + y = DistributionNode(d, "Y") + z = DistributionNode(d, "Z") graph.add_node(u) graph.add_node(v) @@ -95,13 +98,13 @@ def test_roots_down_to_outcome() -> None: def test_cycle() -> None: - d = causalprog.graph.node.NormalDistribution() + d = NormalFamily() - node0 = causalprog.graph.DistributionNode(d, "X") - node1 = causalprog.graph.DistributionNode(d, "Y") - node2 = causalprog.graph.DistributionNode(d, "Z") + node0 = DistributionNode(d, "X") + node1 = DistributionNode(d, "Y") + node2 = DistributionNode(d, "Z") - graph = causalprog.graph.Graph("G0") + graph = Graph("G0") graph.add_edge(node0, node1) graph.add_edge(node1, node2) graph.add_edge(node2, node0) @@ -113,27 +116,61 @@ def test_cycle() -> None: @pytest.mark.parametrize( ("mean", "stdev", "samples", "rtol"), [ - pytest.param(1.0, 1.0, 10, 1, id="std normal, 10 samples"), - pytest.param(2.0, 0.8, 1000, 1e-1, id="non-standard normal, 100 samples"), - pytest.param(1.0, 1.0, 100000, 1e-2, id="std normal, 10^5 samples"), - pytest.param(1.0, 1.0, 10000000, 1e-3, id="std normal, 10^7 samples"), + pytest.param(1.0, 1.0, 10, 1, id="N(mean=1, stdev=1), 10 samples"), + pytest.param(2.0, 0.8, 1000, 1e-1, id="N(mean=2, stdev=0.8), 1000 samples"), + pytest.param(1.0, 0.8, 100000, 1e-2, id="N(mean=1, stdev=0.8), 10^5 samples"), + pytest.param(1.0, 1.2, 10000000, 1e-3, id="N(mean=1, stdev=1.2), 10^7 samples"), ], ) -def test_single_normal_node(samples, rtol, mean, stdev): - normal = causalprog.graph.node.NormalDistribution(mean, stdev) - node = causalprog.graph.DistributionNode(normal, "X", is_outcome=True) +def test_single_normal_node(samples, rtol, mean, stdev, rng_key): + node = DistributionNode( + NormalFamily(), + "X", + constant_parameters={"mean": mean, "cov": stdev**2}, + is_outcome=True, + ) - graph = causalprog.graph.Graph("G0") + graph = Graph("G0") graph.add_node(node) + # To compensate for rng-key splitting in sample methods, note the "split" key + # that is actually used to draw the samples from the distribution, so we can + # attempt to replicate its behaviour explicitly. 
+ key = jax.random.split(rng_key, 1)[0] + what_we_should_get = jax.random.multivariate_normal( + key, jax.numpy.atleast_1d(mean), jax.numpy.atleast_2d(stdev**2), shape=samples + ) + expected_mean = what_we_should_get.mean() + expected_std_dev = what_we_should_get.std() + + # Check within hand-computation assert np.isclose( - causalprog.algorithms.expectation(graph, samples=samples), mean, rtol=rtol + causalprog.algorithms.expectation( + graph, outcome_node_label="X", samples=samples, rng_key=rng_key + ), + mean, + rtol=rtol, ) assert np.isclose( - causalprog.algorithms.standard_deviation(graph, samples=samples), + causalprog.algorithms.standard_deviation( + graph, outcome_node_label="X", samples=samples, rng_key=rng_key + ), stdev, rtol=rtol, ) + # Check within computational distance + assert np.isclose( + causalprog.algorithms.expectation( + graph, outcome_node_label="X", samples=samples, rng_key=rng_key + ), + expected_mean, + ) + assert np.isclose( + causalprog.algorithms.standard_deviation( + graph, outcome_node_label="X", samples=samples, rng_key=rng_key + ), + expected_std_dev, + ) @pytest.mark.parametrize( @@ -165,20 +202,37 @@ def test_single_normal_node(samples, rtol, mean, stdev): ), ], ) -def test_two_node_graph(samples, rtol, mean, stdev, stdev2): - normal = causalprog.graph.node.NormalDistribution(mean, stdev) - normal2 = causalprog.graph.node.NormalDistribution("UX", stdev2) - +def test_two_node_graph(samples, rtol, mean, stdev, stdev2, rng_key): + if samples > 100: # noqa: PLR2004 + pytest.xfail("Test currently too slow") graph = causalprog.graph.Graph("G0") - graph.add_node(causalprog.graph.DistributionNode(normal, "UX")) - graph.add_node(causalprog.graph.DistributionNode(normal2, "X", is_outcome=True)) + graph.add_node( + DistributionNode( + NormalFamily(), "UX", constant_parameters={"mean": mean, "cov": stdev**2} + ) + ) + graph.add_node( + DistributionNode( + NormalFamily(), + "X", + parameters={"mean": "UX"}, + constant_parameters={"cov": stdev2**2}, + is_outcome=True, + ) + ) graph.add_edge("UX", "X") assert np.isclose( - causalprog.algorithms.expectation(graph, samples=samples), mean, rtol=rtol + causalprog.algorithms.expectation( + graph, outcome_node_label="X", samples=samples, rng_key=rng_key + ), + mean, + rtol=rtol, ) assert np.isclose( - causalprog.algorithms.standard_deviation(graph, samples=samples), + causalprog.algorithms.standard_deviation( + graph, outcome_node_label="X", samples=samples, rng_key=rng_key + ), np.sqrt(stdev**2 + stdev2**2), rtol=rtol, )
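---

Not part of the patch: a minimal sketch of the keyword-only construct interface changed in distribution/family.py and distribution/normal.py above. It uses only names introduced in this diff (NormalFamily, construct(mean=..., cov=...), Distribution.sample(rng_key, samples)); the concrete mean/cov values and PRNG seed are illustrative. Note that the raw distrax-backed family in test_family.py instead takes the backend's own keywords (loc, covariance_matrix).

import jax
import jax.numpy as jnp

from causalprog.distribution.normal import NormalFamily

# Parameters are now passed by keyword; Normal.construct forwards mean/cov explicitly,
# and Normal.__init__ promotes them with atleast_1d/atleast_2d.
family = NormalFamily()
dist = family.construct(mean=jnp.array([0.0]), cov=jnp.eye(1))
draws = dist.sample(jax.random.PRNGKey(0), 100)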
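Also not part of the patch: a usage sketch of the reworked graph/sampling API, mirroring test_two_node_graph above. Everything here comes from the diff (DistributionNode's parameters/constant_parameters, the explicit outcome_node_label, samples, and rng_key arguments); the node labels, parameter values, and sample count are illustrative.

import jax

import causalprog
from causalprog.distribution.normal import NormalFamily
from causalprog.graph import DistributionNode, Graph

# UX ~ N(mean=1, cov=1); X ~ N(mean=UX, cov=0.25), i.e. X's mean is the sampled UX value.
graph = Graph("G0")
graph.add_node(
    DistributionNode(
        NormalFamily(), "UX", constant_parameters={"mean": 1.0, "cov": 1.0}
    )
)
graph.add_node(
    DistributionNode(
        NormalFamily(),
        "X",
        parameters={"mean": "UX"},          # parameter drawn from a sampled dependency
        constant_parameters={"cov": 0.25},  # fixed parameter (stdev**2)
        is_outcome=True,
    )
)
graph.add_edge("UX", "X")

# Sampling now requires an explicit outcome label, sample count, and JAX PRNG key.
rng_key = jax.random.PRNGKey(0)
mean_estimate = causalprog.algorithms.expectation(
    graph, outcome_node_label="X", samples=100, rng_key=rng_key
)
std_estimate = causalprog.algorithms.standard_deviation(
    graph, outcome_node_label="X", samples=100, rng_key=rng_key
)

The sample count is kept small here because DistributionNode.sample falls back to a per-sample Python loop whenever a node has non-constant parameters, which is why test_two_node_graph above xfails beyond 100 samples.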