Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 10 additions & 6 deletions catalax/dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import pandas as pd
from jax import Array
from pydantic import BaseModel, Field
from pyenzyme import EnzymeMLDocument
from pyenzyme import EnzymeMLDocument, read_enzymeml

from .croissant import extract_record_set, json_lines_to_dict
from .measurement import Measurement
Expand Down Expand Up @@ -266,16 +266,20 @@ def read_enzymeml(cls, path: Path | str) -> "Dataset":

Returns:
Dataset: The Dataset object.

Raises:
ValueError: If the file format is not supported.
NotImplementedError: If OMEX files are not supported yet.
"""

if not isinstance(path, Path):
path = Path(path)

# If it ends with .json, it's a v2 file
if path.suffix == ".json":
enzmldoc = EnzymeMLDocument.read(path)
enzmldoc = read_enzymeml(path)
elif path.suffix == ".omex":
enzmldoc = EnzymeMLDocument.from_sbml(path)
raise NotImplementedError("OMEX files are not supported yet.")
else:
raise ValueError(
"Unknown file format. Please provide a .json or .omex file."
Expand Down Expand Up @@ -401,10 +405,10 @@ def from_model(cls, model: "Model"):
Returns:
Dataset: The dataset object.
"""

from ..model import Model

assert isinstance(model, Model), "Expected a Model object."
if not isinstance(model, Model):
raise ValueError(f"Expected a Model object. Got {type(model)}")

return cls(
id=model.name,
Expand Down Expand Up @@ -500,7 +504,7 @@ def plot(
path: Optional[str] = None,
measurement_ids: List[str] = [],
figsize: Tuple[int, int] = (5, 3),
model: "Model" = None,
model: Optional["Model"] = None,
):
"""Plots all measurements in the dataset.

Expand Down
4 changes: 2 additions & 2 deletions catalax/dataset/measurement.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,7 @@ def plot(
Args:
show (bool): Whether to show the plot. Defaults to True.
ax (Optional[plt.Axes]): The axes to plot the data on. Defaults to None.
model (Model): The model to plot the fit of. Defaults to None.
model (Optional[Model]): The model to plot the fit of. Defaults to None.
"""

is_subplot = ax is not None
Expand Down Expand Up @@ -379,7 +379,7 @@ def plot(
ax.grid(alpha=0.3, linestyle="--")
ax.set_xlabel("Time", fontsize=12)
ax.set_ylabel("Concentration", fontsize=12)
ax.set_title(init_title, fontsize=12)
ax.set_title(init_title, fontsize=8)

ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))

Expand Down
17 changes: 13 additions & 4 deletions catalax/mcmc/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,17 @@
from .mcmc import run_mcmc
from .plotting import plot_corner, plot_posterior, plot_trace, plot_forest
from . import priors

import arviz as az

from . import priors
from .mcmc import run_mcmc
from .plotting import plot_corner, plot_forest, plot_posterior, plot_trace

# Set plotting style
az.style.use("arviz-doc")

__all__ = [
"run_mcmc",
"plot_corner",
"plot_posterior",
"plot_trace",
"plot_forest",
"priors",
]
44 changes: 28 additions & 16 deletions catalax/mcmc/plotting.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import warnings
from typing import Dict, Tuple

import arviz as az
Expand All @@ -11,33 +12,40 @@
def plot_corner(
mcmc: MCMC,
quantiles: Tuple[float, float, float] = (0.16, 0.5, 0.84),
):
digits_after_decimal: int = 2,
) -> plt.Figure:
"""Plots the correlation between the parameters.

Args:
mcmc (MCMC): The MCMC object to plot.
model (Model): The model to infer the parameters of.
quantiles (Tuple[float, float, float]): The quantiles to plot. Defaults to (0.16, 0.5, 0.84).
digits_after_decimal (int): The number of digits to show after the decimal point. Defaults to 2.

Returns:
plt.Figure: The figure.
"""

data = az.from_numpyro(mcmc)
fig = corner.corner(
data,
plot_contours=False,
quantiles=list(quantiles),
bins=20,
show_titles=True,
title_kwargs={"fontsize": 12},
divergences=True,
use_math_text=False,
var_names=[var for var in mcmc.get_samples().keys() if var != "sigma"],
)

fig.tight_layout()
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=UserWarning)
fig = corner.corner(
data,
plot_contours=False,
quantiles=list(quantiles),
bins=20,
show_titles=True,
title_kwargs={"fontsize": 12},
divergences=True,
use_math_text=False,
var_names=[var for var in mcmc.get_samples().keys() if var != "sigma"],
title_fmt=f".{digits_after_decimal}f",
)

return fig


def plot_posterior(mcmc, model, **kwargs) -> None:
def plot_posterior(mcmc: MCMC, model, **kwargs):
"""Plots the posterior distribution of the given bayes analysis"""

inf_data = az.from_numpyro(mcmc)
Expand All @@ -47,7 +55,11 @@ def plot_posterior(mcmc, model, **kwargs) -> None:


def plot_credibility_interval(
mcmc, model, initial_condition: Dict[str, float], time: jax.Array, dt0: float = 0.1
mcmc: MCMC,
model,
initial_condition: Dict[str, float],
time: jax.Array,
dt0: float = 0.1,
) -> None:
"""Plots the credibility interval for a single simulation"""

Expand Down
57 changes: 40 additions & 17 deletions catalax/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from dotted_dict import DottedDict
from jax import Array
from pydantic import ConfigDict, Field, PrivateAttr, field_validator
from pyenzyme import EnzymeMLDocument
from pyenzyme import EnzymeMLDocument, read_enzymeml
from sympy import Expr, Matrix, Symbol, symbols, sympify

from catalax.mcmc import priors
Expand Down Expand Up @@ -75,7 +75,7 @@ class Model(CatalaxBase):
def add_ode(
self,
species: str,
equation: str, # type: ignore
equation: str,
observable: bool = True,
species_map: Optional[Dict[str, str]] = None,
):
Expand Down Expand Up @@ -103,7 +103,7 @@ def add_ode(
equation (str): The equation that describes the dynamics of the species.

Raises:
ValueError: _description_
ValueError: If the species is not a string or a SymPy expression.
"""

if any(str(ode.species.name) == species for ode in self.odes.values()):
Expand All @@ -112,13 +112,22 @@ def add_ode(
)

if isinstance(equation, str):
equation: Expr = sympify(equation)
sympy_equation: Expr = sympify(equation)
elif isinstance(equation, Expr):
sympy_equation = equation
else:
raise ValueError(
f"Equation must be a string or a SymPy expression, got {type(equation)}"
)

if species not in self.species:
self.add_species(name=species, species_map=species_map)
if species_map: # not None and not empty
self.add_species(**species_map)
else:
self.add_species(species_string=species)

self.odes[species] = ODE(
equation=equation,
equation=sympy_equation,
species=self.species[species],
observable=observable,
)
Expand Down Expand Up @@ -774,9 +783,12 @@ def from_dict(cls, model_dict: Dict):
return model

@classmethod
def from_enzymeml(cls, path: Path | str, name: str | None = None):
with open(path, "r") as f:
enzmldoc = EnzymeMLDocument.model_validate_json(f.read())
def from_enzymeml(cls, enzmldoc: EnzymeMLDocument, name: str | None = None):
"""Initializes a model from an EnzymeML document.

Args:
enzmldoc (EnzymeMLDocument): The EnzymeML document to create the model from.
"""

if name is None:
name = enzmldoc.name
Expand All @@ -797,31 +809,42 @@ def from_enzymeml(cls, path: Path | str, name: str | None = None):

# create regex match objects for all species
species_regex = re.compile(r"|".join(map(re.escape, all_species)))

for reaction in enzmldoc.reactions:
# get species from reaction
match_species = species_regex.findall(reaction.kinetic_law.equation)

# add species to model
for species in set(match_species):
model.add_species(species)
model.add_species(", ".join(match_species))

# add ode to model
model.add_ode(
species=reaction.kinetic_law.species_id,
equation=reaction.kinetic_law.equation,
observable=True
if reaction.kinetic_law.species_id in observables
else False,
observable=True,
)

# add for all model.species that don't have an ode, a constant rate of 0
for species in model.species:
if species not in model.odes:
for species in non_observables:
if species not in model.species:
model.add_ode(species=species, equation="0", observable=False)

return model

@classmethod
def read_enzymeml(cls, path: Path | str, name: str | None = None):
"""Initializes a model from the "reactions" section of an EnzymeML document.

Args:
path (Path | str): Path to the EnzymeML document.
name (str | None, optional): Name of the model. Defaults to None.

Returns:
Model: Resulting model instance.
"""
enzmldoc = read_enzymeml(path)

return cls.from_enzymeml(enzmldoc, name)

def update_enzymeml_parameters(self, enzmldoc: EnzymeMLDocument):
"""Updates model parameters of enzymeml document with model parameters.
Existing parameters will be updated, non-existing parameters will be added.
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "catalax"
version = "0.4.1"
version = "0.4.2"
description = "A JAX-based framework for (neural) ODE modelling in biocatalysis."
authors = ["Jan Range <[email protected]>"]
license = "MIT License"
Expand Down