From 91c7b69caad981818c626637e78c98cb4a5d2d15 Mon Sep 17 00:00:00 2001 From: haeussma <83341109+haeussma@users.noreply.github.com> Date: Sat, 10 May 2025 00:06:59 +0200 Subject: [PATCH 1/6] Refactor Dataset file handling and error management - Updated import for EnzymeMLDocument to use read_enzymeml function. - Added error handling for unsupported file formats, raising ValueError for unknown formats and NotImplementedError for OMEX files. - Enhanced model validation by raising ValueError if the provided model is not of type Model. - Changed model parameter to be optional in the plotting function. --- catalax/dataset/dataset.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/catalax/dataset/dataset.py b/catalax/dataset/dataset.py index 5fbab5e..6b63553 100644 --- a/catalax/dataset/dataset.py +++ b/catalax/dataset/dataset.py @@ -18,7 +18,7 @@ import pandas as pd from jax import Array from pydantic import BaseModel, Field -from pyenzyme import EnzymeMLDocument +from pyenzyme import EnzymeMLDocument, read_enzymeml from .croissant import extract_record_set, json_lines_to_dict from .measurement import Measurement @@ -266,6 +266,10 @@ def read_enzymeml(cls, path: Path | str) -> "Dataset": Returns: Dataset: The Dataset object. + + Raises: + ValueError: If the file format is not supported. + NotImplementedError: If an OMEX file is given; OMEX files are not supported yet. """ if not isinstance(path, Path): @@ -273,9 +277,9 @@ def read_enzymeml(cls, path: Path | str) -> "Dataset": # If it ends with .json, it's a v2 file if path.suffix == ".json": - enzmldoc = EnzymeMLDocument.read(path) + enzmldoc = read_enzymeml(path) elif path.suffix == ".omex": - enzmldoc = EnzymeMLDocument.from_sbml(path) + raise NotImplementedError("OMEX files are not supported yet.") else: raise ValueError( "Unknown file format. Please provide a .json or .omex file." @@ -401,10 +405,10 @@ def from_model(cls, model: "Model"): Returns: Dataset: The dataset object. 
""" - from ..model import Model - assert isinstance(model, Model), "Expected a Model object." + if not isinstance(model, Model): + raise ValueError(f"Expected a Model object. Got {type(model)}") return cls( id=model.name, @@ -500,7 +504,7 @@ def plot( path: Optional[str] = None, measurement_ids: List[str] = [], figsize: Tuple[int, int] = (5, 3), - model: "Model" = None, + model: Optional["Model"] = None, ): """Plots all measurements in the dataset. From 9c580908be4cc4b71608f0edcd8290ee73e5d645 Mon Sep 17 00:00:00 2001 From: haeussma <83341109+haeussma@users.noreply.github.com> Date: Sat, 10 May 2025 00:07:29 +0200 Subject: [PATCH 2/6] Update Measurement class documentation and adjust plot title font size - Changed the model parameter in the plot function to be optional. - Reduced the font size of the plot title for better visual consistency. --- catalax/dataset/measurement.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/catalax/dataset/measurement.py b/catalax/dataset/measurement.py index bab63a6..3beb8e9 100644 --- a/catalax/dataset/measurement.py +++ b/catalax/dataset/measurement.py @@ -339,7 +339,7 @@ def plot( Args: show (bool): Whether to show the plot. Defaults to True. ax (Optional[plt.Axes]): The axes to plot the data on. Defaults to None. - model (Model): The model to plot the fit of. Defaults to None. + model (Optional[Model]): The model to plot the fit of. Defaults to None. 
""" is_subplot = ax is not None @@ -379,7 +379,7 @@ def plot( ax.grid(alpha=0.3, linestyle="--") ax.set_xlabel("Time", fontsize=12) ax.set_ylabel("Concentration", fontsize=12) - ax.set_title(init_title, fontsize=12) + ax.set_title(init_title, fontsize=8) ax.legend(loc="center left", bbox_to_anchor=(1, 0.5)) From 8d002903c3586102d1f7f33ca44e1884ddaf52cf Mon Sep 17 00:00:00 2001 From: haeussma <83341109+haeussma@users.noreply.github.com> Date: Sat, 10 May 2025 00:08:10 +0200 Subject: [PATCH 3/6] Enhance Model class with improved EnzymeML handling - Updated the import statement for EnzymeMLDocument to include read_enzymeml. - Refactored the from_enzymeml method to accept an EnzymeMLDocument directly. - Added a new read_enzymeml class method for initializing a model from an EnzymeML document. - Improved error handling for equation input in add_ode method, ensuring it accepts either a string or a SymPy expression. - Adjusted species addition logic in add_ode to handle species mapping more effectively. --- catalax/model/model.py | 57 +++++++++++++++++++++++++++++------------- 1 file changed, 40 insertions(+), 17 deletions(-) diff --git a/catalax/model/model.py b/catalax/model/model.py index d5bea6c..9498a49 100644 --- a/catalax/model/model.py +++ b/catalax/model/model.py @@ -14,7 +14,7 @@ from dotted_dict import DottedDict from jax import Array from pydantic import ConfigDict, Field, PrivateAttr, field_validator -from pyenzyme import EnzymeMLDocument +from pyenzyme import EnzymeMLDocument, read_enzymeml from sympy import Expr, Matrix, Symbol, symbols, sympify from catalax.mcmc import priors @@ -75,7 +75,7 @@ class Model(CatalaxBase): def add_ode( self, species: str, - equation: str, # type: ignore + equation: str, observable: bool = True, species_map: Optional[Dict[str, str]] = None, ): @@ -103,7 +103,7 @@ def add_ode( equation (str): The equation that describes the dynamics of the species. 
Raises: - ValueError: _description_ + ValueError: If the equation is not a string or a SymPy expression. """ if any(str(ode.species.name) == species for ode in self.odes.values()): @@ -112,13 +112,22 @@ def add_ode( ) if isinstance(equation, str): - equation: Expr = sympify(equation) + sympy_equation: Expr = sympify(equation) + elif isinstance(equation, Expr): + sympy_equation = equation + else: + raise ValueError( + f"Equation must be a string or a SymPy expression, got {type(equation)}" + ) if species not in self.species: - self.add_species(name=species, species_map=species_map) + if species_map: # not None and not empty + self.add_species(**species_map) + else: + self.add_species(species_string=species) self.odes[species] = ODE( - equation=equation, + equation=sympy_equation, species=self.species[species], observable=observable, ) @@ -774,9 +783,12 @@ def from_dict(cls, model_dict: Dict): return model @classmethod - def from_enzymeml(cls, path: Path | str, name: str | None = None): - with open(path, "r") as f: - enzmldoc = EnzymeMLDocument.model_validate_json(f.read()) + def from_enzymeml(cls, enzmldoc: EnzymeMLDocument, name: str | None = None): + """Initializes a model from an EnzymeML document. + + Args: + enzmldoc (EnzymeMLDocument): The EnzymeML document to create the model from. 
+ """ if name is None: name = enzmldoc.name @@ -797,31 +809,42 @@ def from_enzymeml(cls, path: Path | str, name: str | None = None): # create regex match objects for all species species_regex = re.compile(r"|".join(map(re.escape, all_species))) - for reaction in enzmldoc.reactions: # get species from reaction match_species = species_regex.findall(reaction.kinetic_law.equation) - # add species to model for species in set(match_species): - model.add_species(species) + model.add_species(", ".join(match_species)) # add ode to model model.add_ode( species=reaction.kinetic_law.species_id, equation=reaction.kinetic_law.equation, - observable=True - if reaction.kinetic_law.species_id in observables - else False, + observable=True, ) # add for all model.species that don't have an ode, a constant rate of 0 - for species in model.species: - if species not in model.odes: + for species in non_observables: + if species not in model.species: model.add_ode(species=species, equation="0", observable=False) return model + @classmethod + def read_enzymeml(cls, path: Path | str, name: str | None = None): + """Initializes a model from the "reactions" section of an EnzymeML document. + + Args: + path (Path | str): Path to the EnzymeML document. + name (str | None, optional): Name of the model. Defaults to None. + + Returns: + Model: Resulting model instance. + """ + enzmldoc = read_enzymeml(path) + + return cls.from_enzymeml(enzmldoc, name) + def update_enzymeml_parameters(self, enzmldoc: EnzymeMLDocument): """Updates model parameters of enzymeml document with model parameters. Existing parameters will be updated, non-existing parameters will be added. 
From 10d0e7867f4ad8f19a15fda688909646c60ced1c Mon Sep 17 00:00:00 2001 From: haeussma <83341109+haeussma@users.noreply.github.com> Date: Sat, 10 May 2025 00:08:38 +0200 Subject: [PATCH 4/6] Update version to 0.4.2 in pyproject.toml --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 6e15c93..f4c2f1f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "catalax" -version = "0.4.1" +version = "0.4.2" description = "A JAX-based framework for (neural) ODE modelling in biocatalysis." authors = ["Jan Range "] license = "MIT License" From dcc3a60d002c238db75fc17ef8d45e3d925d73cd Mon Sep 17 00:00:00 2001 From: haeussma <83341109+haeussma@users.noreply.github.com> Date: Sat, 10 May 2025 10:20:58 +0200 Subject: [PATCH 5/6] Refactor imports and define __all__ in mcmc module - Organized import statements for clarity and consistency. - Added __all__ to specify public API for the mcmc module, including run_mcmc, plotting functions, and priors. --- catalax/mcmc/__init__.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/catalax/mcmc/__init__.py b/catalax/mcmc/__init__.py index 4958ee0..fa5a2d3 100644 --- a/catalax/mcmc/__init__.py +++ b/catalax/mcmc/__init__.py @@ -1,8 +1,17 @@ -from .mcmc import run_mcmc -from .plotting import plot_corner, plot_posterior, plot_trace, plot_forest -from . import priors - import arviz as az +from . 
import priors +from .mcmc import run_mcmc +from .plotting import plot_corner, plot_forest, plot_posterior, plot_trace + # Set plotting style az.style.use("arviz-doc") + +__all__ = [ + "run_mcmc", + "plot_corner", + "plot_posterior", + "plot_trace", + "plot_forest", + "priors", +] From b33dd5178149b59d1bfeb9b5cc67b171cbe9a376 Mon Sep 17 00:00:00 2001 From: haeussma <83341109+haeussma@users.noreply.github.com> Date: Sat, 10 May 2025 10:21:35 +0200 Subject: [PATCH 6/6] Enhance plot_corner function with warnings handling and additional parameters - Added warnings handling to suppress the UserWarnings raised during plotting due to layout control in the corner.corner API. - Introduced a new parameter, digits_after_decimal, to customize the number of decimal places in plot titles. - Updated function documentation to reflect changes in parameters and return type. --- catalax/mcmc/plotting.py | 44 +++++++++++++++++++++++++--------------- 1 file changed, 28 insertions(+), 16 deletions(-) diff --git a/catalax/mcmc/plotting.py b/catalax/mcmc/plotting.py index bb6ca88..b64c8ea 100644 --- a/catalax/mcmc/plotting.py +++ b/catalax/mcmc/plotting.py @@ -1,3 +1,4 @@ +import warnings from typing import Dict, Tuple import arviz as az @@ -11,33 +12,40 @@ def plot_corner( mcmc: MCMC, quantiles: Tuple[float, float, float] = (0.16, 0.5, 0.84), -): + digits_after_decimal: int = 2, +) -> plt.Figure: """Plots the correlation between the parameters. Args: mcmc (MCMC): The MCMC object to plot. - model (Model): The model to infer the parameters of. + quantiles (Tuple[float, float, float]): The quantiles to plot. Defaults to (0.16, 0.5, 0.84). + digits_after_decimal (int): The number of digits to show after the decimal point. Defaults to 2. + + Returns: + plt.Figure: The figure. 
""" data = az.from_numpyro(mcmc) - fig = corner.corner( - data, - plot_contours=False, - quantiles=list(quantiles), - bins=20, - show_titles=True, - title_kwargs={"fontsize": 12}, - divergences=True, - use_math_text=False, - var_names=[var for var in mcmc.get_samples().keys() if var != "sigma"], - ) - fig.tight_layout() + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=UserWarning) + fig = corner.corner( + data, + plot_contours=False, + quantiles=list(quantiles), + bins=20, + show_titles=True, + title_kwargs={"fontsize": 12}, + divergences=True, + use_math_text=False, + var_names=[var for var in mcmc.get_samples().keys() if var != "sigma"], + title_fmt=f".{digits_after_decimal}f", + ) return fig -def plot_posterior(mcmc, model, **kwargs) -> None: +def plot_posterior(mcmc: MCMC, model, **kwargs): """Plots the posterior distribution of the given bayes analysis""" inf_data = az.from_numpyro(mcmc) @@ -47,7 +55,11 @@ def plot_posterior(mcmc, model, **kwargs) -> None: def plot_credibility_interval( - mcmc, model, initial_condition: Dict[str, float], time: jax.Array, dt0: float = 0.1 + mcmc: MCMC, + model, + initial_condition: Dict[str, float], + time: jax.Array, + dt0: float = 0.1, ) -> None: """Plots the credibility interval for a single simulation"""