Skip to content

Commit d8c41f3

Browse files
authored
Refactor rateOf handling (#3036)
* doc: Skip petab-sciml for now GitHub Action runners run out of disk space when installing petab-sciml with all its huge dependencies. Don't install that for now. So far, it's not used anywhere for the documentation build as far as I can see. This won't prevent enabling intersphinx later on. * Refactor rateOf handling Pull rateOf-handling out of DEModel and keep it along other SBML processing where it belongs. This is easier to follow and prevents some lingering issues with the old approach due to the xdot / w interdependencies. This also handles rateOf expressions in some additional, previously unsupported places like event assignments.
1 parent f28efae commit d8c41f3

File tree

3 files changed

+193
-185
lines changed

3 files changed

+193
-185
lines changed

python/sdist/amici/de_model.py

Lines changed: 59 additions & 181 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
from __future__ import annotations
44

5-
import contextlib
65
import copy
76
import itertools
87
import logging
@@ -47,7 +46,6 @@
4746
from .exporters.sundials.cxxcodeprinter import csc_matrix
4847
from .importers.utils import (
4948
ObservableTransformation,
50-
SBMLException,
5149
_default_simplify,
5250
amici_time_symbol,
5351
smart_subs_dict,
@@ -340,6 +338,10 @@ def algebraic_states(self) -> list[AlgebraicState]:
340338
"""Get all algebraic states."""
341339
return self._algebraic_states
342340

341+
def algebraic_equations(self) -> list[AlgebraicEquation]:
342+
"""Get all algebraic equations."""
343+
return self._algebraic_equations
344+
343345
def observables(self) -> list[Observable]:
344346
"""Get all observables."""
345347
return self._observables
@@ -404,165 +406,6 @@ def states(self) -> list[State]:
404406
"""Get all states."""
405407
return self._differential_states + self._algebraic_states
406408

407-
def _process_sbml_rate_of(self) -> None:
408-
"""Substitute any SBML-rateOf constructs in the model equations"""
409-
from sbmlmath import rate_of as rate_of_func
410-
411-
species_sym_to_xdot = dict(
412-
zip(self.sym("x"), self.sym("xdot"), strict=True)
413-
)
414-
species_sym_to_idx = {x: i for i, x in enumerate(self.sym("x"))}
415-
416-
def get_rate(symbol: sp.Symbol):
417-
"""Get rate of change of the given symbol"""
418-
if symbol.find(rate_of_func):
419-
raise SBMLException("Nesting rateOf() is not allowed.")
420-
421-
# Replace all rateOf(some_species) by their respective xdot equation
422-
with contextlib.suppress(KeyError):
423-
return self._eqs["xdot"][species_sym_to_idx[symbol]]
424-
425-
# For anything other than a state, rateOf(.) is 0 or invalid
426-
return 0
427-
428-
# replace rateOf-instances in xdot by xdot symbols
429-
made_substitutions = False
430-
for i_state in range(len(self.eq("xdot"))):
431-
if rate_ofs := self._eqs["xdot"][i_state].find(rate_of_func):
432-
self._eqs["xdot"][i_state] = self._eqs["xdot"][i_state].subs(
433-
{
434-
# either the rateOf argument is a state, or it's 0
435-
rate_of: species_sym_to_xdot.get(rate_of.args[0], 0)
436-
for rate_of in rate_ofs
437-
}
438-
)
439-
made_substitutions = True
440-
441-
if made_substitutions:
442-
# substitute in topological order
443-
subs = toposort_symbols(
444-
dict(zip(self.sym("xdot"), self.eq("xdot"), strict=True))
445-
)
446-
self._eqs["xdot"] = smart_subs_dict(self.eq("xdot"), subs)
447-
448-
# replace rateOf-instances in w by xdot equation
449-
# here we may need toposort, as xdot may depend on w
450-
made_substitutions = False
451-
for i_expr in range(len(self.eq("w"))):
452-
new, replacement = self._eqs["w"][i_expr].replace(
453-
rate_of_func, get_rate, map=True
454-
)
455-
if replacement:
456-
self._eqs["w"][i_expr] = new
457-
made_substitutions = True
458-
459-
if made_substitutions:
460-
# Sort expressions in self._expressions, w symbols, and w equations
461-
# in topological order. Ideally, this would already happen before
462-
# adding the expressions to the model, but at that point, we don't
463-
# have access to xdot yet.
464-
# NOTE: elsewhere, conservations law expressions are expected to
465-
# occur before any other w expressions, so we must maintain their
466-
# position
467-
# toposort everything but conservation law expressions,
468-
# then prepend conservation laws
469-
w_sorted = toposort_symbols(
470-
dict(
471-
zip(
472-
self.sym("w")[self.num_cons_law() :, :],
473-
self.eq("w")[self.num_cons_law() :, :],
474-
strict=True,
475-
)
476-
)
477-
)
478-
w_sorted = (
479-
dict(
480-
zip(
481-
self.sym("w")[: self.num_cons_law(), :],
482-
self.eq("w")[: self.num_cons_law(), :],
483-
strict=True,
484-
)
485-
)
486-
| w_sorted
487-
)
488-
old_syms = tuple(self._syms["w"])
489-
topo_expr_syms = tuple(w_sorted.keys())
490-
new_order = [old_syms.index(s) for s in topo_expr_syms]
491-
self._expressions = [self._expressions[i] for i in new_order]
492-
self._syms["w"] = sp.Matrix(topo_expr_syms)
493-
self._eqs["w"] = sp.Matrix(list(w_sorted.values()))
494-
495-
# replace rateOf-instances in x0 by xdot equation
496-
# indices of state variables whose x0 was modified
497-
changed_indices = []
498-
for i_state in range(len(self.eq("x0"))):
499-
new, replacement = self._eqs["x0"][i_state].replace(
500-
rate_of_func, get_rate, map=True
501-
)
502-
if replacement:
503-
self._eqs["x0"][i_state] = new
504-
changed_indices.append(i_state)
505-
if changed_indices:
506-
# Replace any newly introduced state variables
507-
# by their x0 expressions.
508-
# Also replace any newly introduced `w` symbols by their
509-
# expressions (after `w` was toposorted above).
510-
subs = toposort_symbols(
511-
dict(zip(self.sym("x_rdata"), self.eq("x0"), strict=True))
512-
)
513-
subs = dict(zip(self._syms["w"], self.eq("w"), strict=True)) | subs
514-
for i_state in changed_indices:
515-
self._eqs["x0"][i_state] = smart_subs_dict(
516-
self._eqs["x0"][i_state], subs
517-
)
518-
519-
for component in chain(
520-
self.observables(),
521-
self.events(),
522-
self._algebraic_equations,
523-
):
524-
if rate_ofs := component.get_val().find(rate_of_func):
525-
if isinstance(component, Event):
526-
# TODO froot(...) can currently not depend on `w`, so this substitution fails for non-zero rates
527-
# see, e.g., sbml test case 01293
528-
raise SBMLException(
529-
"AMICI does currently not support rateOf(.) inside event trigger functions."
530-
)
531-
532-
if isinstance(component, AlgebraicEquation):
533-
# TODO IDACalcIC fails with
534-
# "The linesearch algorithm failed: step too small or too many backtracks."
535-
# see, e.g., sbml test case 01482
536-
raise SBMLException(
537-
"AMICI does currently not support rateOf(.) inside AlgebraicRules."
538-
)
539-
540-
component.set_val(
541-
component.get_val().subs(
542-
{
543-
rate_of: get_rate(rate_of.args[0])
544-
for rate_of in rate_ofs
545-
}
546-
)
547-
)
548-
549-
for event in self.events():
550-
state_update = event.get_state_update(
551-
x=self.sym("x"), x_old=self.sym("x")
552-
)
553-
if state_update is None:
554-
continue
555-
556-
for i_state in range(len(state_update)):
557-
if rate_ofs := state_update[i_state].find(rate_of_func):
558-
raise SBMLException(
559-
"AMICI does currently not support rateOf(.) inside event state updates."
560-
)
561-
# TODO here we need xdot sym, not eqs
562-
# event._state_update[i_state] = event._state_update[i_state].subs(
563-
# {rate_of: get_rate(rate_of.args[0]) for rate_of in rate_ofs}
564-
# )
565-
566409
def add_component(
567410
self, component: ModelQuantity, insert_first: bool | None = False
568411
) -> None:
@@ -1271,9 +1114,7 @@ def generate_basic_variables(self) -> None:
12711114
Generates the symbolic identifiers for all variables in
12721115
``DEModel._variable_prototype``
12731116
"""
1274-
# We need to process events and Heaviside functions in the ``DEModel`,
1275-
# before adding it to DEExporter
1276-
self.parse_events()
1117+
self._reorder_events()
12771118

12781119
for var in self._variable_prototype:
12791120
if var not in self._syms:
@@ -1335,7 +1176,11 @@ def parse_events(self) -> None:
13351176
for event in self.events():
13361177
event.set_val(event.get_val().subs(w_toposorted))
13371178

1338-
# re-order events - first those that require root tracking, then the others
1179+
def _reorder_events(self) -> None:
1180+
"""
1181+
Re-order events - first those that require root tracking,
1182+
then the others.
1183+
"""
13391184
constant_syms = set(self.sym("k")) | set(self.sym("p"))
13401185
self._events = list(
13411186
chain(
@@ -2694,22 +2539,7 @@ def _process_hybridization(self, hybridization: dict) -> None:
26942539
self._observables = [self._observables[i] for i in new_order]
26952540

26962541
if added_expressions:
2697-
# toposort expressions
2698-
w_sorted = toposort_symbols(
2699-
dict(
2700-
zip(
2701-
self.sym("w"),
2702-
self.eq("w"),
2703-
strict=True,
2704-
)
2705-
)
2706-
)
2707-
old_syms = tuple(self._syms["w"])
2708-
topo_expr_syms = tuple(w_sorted.keys())
2709-
new_order = [old_syms.index(s) for s in topo_expr_syms]
2710-
self._expressions = [self._expressions[i] for i in new_order]
2711-
self._syms["w"] = sp.Matrix(topo_expr_syms)
2712-
self._eqs["w"] = sp.Matrix(list(w_sorted.values()))
2542+
self.toposort_expressions()
27132543

27142544
def get_explicit_roots(self) -> set[sp.Expr]:
27152545
"""
@@ -2752,3 +2582,51 @@ def has_event_assignments(self) -> bool:
27522582
boolean indicating if event assignments are present
27532583
"""
27542584
return any(event.updates_state for event in self._events)
2585+
2586+
def toposort_expressions(self) -> dict[sp.Symbol, sp.Expr]:
2587+
"""
2588+
Sort expressions in topological order.
2589+
2590+
:return:
2591+
dict of expression symbols to expressions in topological order
2592+
"""
2593+
# ensure no symbols or equations that depend on `w` have been generated
2594+
# yet, otherwise the re-ordering might break dependencies
2595+
if (
2596+
generated := set(self._syms)
2597+
| set(self._eqs)
2598+
| set(self._sparsesyms)
2599+
| set(self._sparseeqs)
2600+
) - {"w", "p", "k", "x", "x_rdata"}:
2601+
raise AssertionError(
2602+
"This function must be called before computing any "
2603+
"derivatives. The following symbols/equations are already "
2604+
f"generated: {generated}"
2605+
)
2606+
2607+
# NOTE: elsewhere, conservations law expressions are expected to
2608+
# occur before any other w expressions, so we must maintain their
2609+
# position.
2610+
# toposort everything but conservation law expressions,
2611+
# then prepend conservation laws
2612+
2613+
w_toposorted = toposort_symbols(
2614+
{
2615+
e.get_sym(): e.get_val()
2616+
for e in self.expressions()[self.num_cons_law() :]
2617+
}
2618+
)
2619+
2620+
w_toposorted = {
2621+
e.get_sym(): e.get_val()
2622+
for e in self.expressions()[: self.num_cons_law()]
2623+
} | w_toposorted
2624+
2625+
old_syms = tuple(e.get_sym() for e in self.expressions())
2626+
topo_expr_syms = tuple(w_toposorted)
2627+
new_order = [old_syms.index(s) for s in topo_expr_syms]
2628+
self._expressions = [self._expressions[i] for i in new_order]
2629+
self._syms["w"] = sp.Matrix(topo_expr_syms)
2630+
self._eqs["w"] = sp.Matrix(list(w_toposorted.values()))
2631+
2632+
return w_toposorted

python/sdist/amici/importers/pysb/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -398,6 +398,7 @@ def ode_model_from_pysb_importer(
398398

399399
_process_stoichiometric_matrix(model, ode, fixed_parameters)
400400

401+
ode.parse_events()
401402
ode.generate_basic_variables()
402403

403404
return ode

0 commit comments

Comments
 (0)