|
2 | 2 |
|
3 | 3 | from __future__ import annotations |
4 | 4 |
|
5 | | -import contextlib |
6 | 5 | import copy |
7 | 6 | import itertools |
8 | 7 | import logging |
|
47 | 46 | from .exporters.sundials.cxxcodeprinter import csc_matrix |
48 | 47 | from .importers.utils import ( |
49 | 48 | ObservableTransformation, |
50 | | - SBMLException, |
51 | 49 | _default_simplify, |
52 | 50 | amici_time_symbol, |
53 | 51 | smart_subs_dict, |
@@ -340,6 +338,10 @@ def algebraic_states(self) -> list[AlgebraicState]: |
340 | 338 | """Get all algebraic states.""" |
341 | 339 | return self._algebraic_states |
342 | 340 |
|
| 341 | + def algebraic_equations(self) -> list[AlgebraicEquation]: |
| 342 | + """Get all algebraic equations.""" |
| 343 | + return self._algebraic_equations |
| 344 | + |
343 | 345 | def observables(self) -> list[Observable]: |
344 | 346 | """Get all observables.""" |
345 | 347 | return self._observables |
@@ -404,165 +406,6 @@ def states(self) -> list[State]: |
404 | 406 | """Get all states.""" |
405 | 407 | return self._differential_states + self._algebraic_states |
406 | 408 |
|
407 | | - def _process_sbml_rate_of(self) -> None: |
408 | | - """Substitute any SBML-rateOf constructs in the model equations""" |
409 | | - from sbmlmath import rate_of as rate_of_func |
410 | | - |
411 | | - species_sym_to_xdot = dict( |
412 | | - zip(self.sym("x"), self.sym("xdot"), strict=True) |
413 | | - ) |
414 | | - species_sym_to_idx = {x: i for i, x in enumerate(self.sym("x"))} |
415 | | - |
416 | | - def get_rate(symbol: sp.Symbol): |
417 | | - """Get rate of change of the given symbol""" |
418 | | - if symbol.find(rate_of_func): |
419 | | - raise SBMLException("Nesting rateOf() is not allowed.") |
420 | | - |
421 | | - # Replace all rateOf(some_species) by their respective xdot equation |
422 | | - with contextlib.suppress(KeyError): |
423 | | - return self._eqs["xdot"][species_sym_to_idx[symbol]] |
424 | | - |
425 | | - # For anything other than a state, rateOf(.) is 0 or invalid |
426 | | - return 0 |
427 | | - |
428 | | - # replace rateOf-instances in xdot by xdot symbols |
429 | | - made_substitutions = False |
430 | | - for i_state in range(len(self.eq("xdot"))): |
431 | | - if rate_ofs := self._eqs["xdot"][i_state].find(rate_of_func): |
432 | | - self._eqs["xdot"][i_state] = self._eqs["xdot"][i_state].subs( |
433 | | - { |
434 | | - # either the rateOf argument is a state, or it's 0 |
435 | | - rate_of: species_sym_to_xdot.get(rate_of.args[0], 0) |
436 | | - for rate_of in rate_ofs |
437 | | - } |
438 | | - ) |
439 | | - made_substitutions = True |
440 | | - |
441 | | - if made_substitutions: |
442 | | - # substitute in topological order |
443 | | - subs = toposort_symbols( |
444 | | - dict(zip(self.sym("xdot"), self.eq("xdot"), strict=True)) |
445 | | - ) |
446 | | - self._eqs["xdot"] = smart_subs_dict(self.eq("xdot"), subs) |
447 | | - |
448 | | - # replace rateOf-instances in w by xdot equation |
449 | | - # here we may need toposort, as xdot may depend on w |
450 | | - made_substitutions = False |
451 | | - for i_expr in range(len(self.eq("w"))): |
452 | | - new, replacement = self._eqs["w"][i_expr].replace( |
453 | | - rate_of_func, get_rate, map=True |
454 | | - ) |
455 | | - if replacement: |
456 | | - self._eqs["w"][i_expr] = new |
457 | | - made_substitutions = True |
458 | | - |
459 | | - if made_substitutions: |
460 | | - # Sort expressions in self._expressions, w symbols, and w equations |
461 | | - # in topological order. Ideally, this would already happen before |
462 | | - # adding the expressions to the model, but at that point, we don't |
463 | | - # have access to xdot yet. |
464 | | - # NOTE: elsewhere, conservations law expressions are expected to |
465 | | - # occur before any other w expressions, so we must maintain their |
466 | | - # position |
467 | | - # toposort everything but conservation law expressions, |
468 | | - # then prepend conservation laws |
469 | | - w_sorted = toposort_symbols( |
470 | | - dict( |
471 | | - zip( |
472 | | - self.sym("w")[self.num_cons_law() :, :], |
473 | | - self.eq("w")[self.num_cons_law() :, :], |
474 | | - strict=True, |
475 | | - ) |
476 | | - ) |
477 | | - ) |
478 | | - w_sorted = ( |
479 | | - dict( |
480 | | - zip( |
481 | | - self.sym("w")[: self.num_cons_law(), :], |
482 | | - self.eq("w")[: self.num_cons_law(), :], |
483 | | - strict=True, |
484 | | - ) |
485 | | - ) |
486 | | - | w_sorted |
487 | | - ) |
488 | | - old_syms = tuple(self._syms["w"]) |
489 | | - topo_expr_syms = tuple(w_sorted.keys()) |
490 | | - new_order = [old_syms.index(s) for s in topo_expr_syms] |
491 | | - self._expressions = [self._expressions[i] for i in new_order] |
492 | | - self._syms["w"] = sp.Matrix(topo_expr_syms) |
493 | | - self._eqs["w"] = sp.Matrix(list(w_sorted.values())) |
494 | | - |
495 | | - # replace rateOf-instances in x0 by xdot equation |
496 | | - # indices of state variables whose x0 was modified |
497 | | - changed_indices = [] |
498 | | - for i_state in range(len(self.eq("x0"))): |
499 | | - new, replacement = self._eqs["x0"][i_state].replace( |
500 | | - rate_of_func, get_rate, map=True |
501 | | - ) |
502 | | - if replacement: |
503 | | - self._eqs["x0"][i_state] = new |
504 | | - changed_indices.append(i_state) |
505 | | - if changed_indices: |
506 | | - # Replace any newly introduced state variables |
507 | | - # by their x0 expressions. |
508 | | - # Also replace any newly introduced `w` symbols by their |
509 | | - # expressions (after `w` was toposorted above). |
510 | | - subs = toposort_symbols( |
511 | | - dict(zip(self.sym("x_rdata"), self.eq("x0"), strict=True)) |
512 | | - ) |
513 | | - subs = dict(zip(self._syms["w"], self.eq("w"), strict=True)) | subs |
514 | | - for i_state in changed_indices: |
515 | | - self._eqs["x0"][i_state] = smart_subs_dict( |
516 | | - self._eqs["x0"][i_state], subs |
517 | | - ) |
518 | | - |
519 | | - for component in chain( |
520 | | - self.observables(), |
521 | | - self.events(), |
522 | | - self._algebraic_equations, |
523 | | - ): |
524 | | - if rate_ofs := component.get_val().find(rate_of_func): |
525 | | - if isinstance(component, Event): |
526 | | - # TODO froot(...) can currently not depend on `w`, so this substitution fails for non-zero rates |
527 | | - # see, e.g., sbml test case 01293 |
528 | | - raise SBMLException( |
529 | | - "AMICI does currently not support rateOf(.) inside event trigger functions." |
530 | | - ) |
531 | | - |
532 | | - if isinstance(component, AlgebraicEquation): |
533 | | - # TODO IDACalcIC fails with |
534 | | - # "The linesearch algorithm failed: step too small or too many backtracks." |
535 | | - # see, e.g., sbml test case 01482 |
536 | | - raise SBMLException( |
537 | | - "AMICI does currently not support rateOf(.) inside AlgebraicRules." |
538 | | - ) |
539 | | - |
540 | | - component.set_val( |
541 | | - component.get_val().subs( |
542 | | - { |
543 | | - rate_of: get_rate(rate_of.args[0]) |
544 | | - for rate_of in rate_ofs |
545 | | - } |
546 | | - ) |
547 | | - ) |
548 | | - |
549 | | - for event in self.events(): |
550 | | - state_update = event.get_state_update( |
551 | | - x=self.sym("x"), x_old=self.sym("x") |
552 | | - ) |
553 | | - if state_update is None: |
554 | | - continue |
555 | | - |
556 | | - for i_state in range(len(state_update)): |
557 | | - if rate_ofs := state_update[i_state].find(rate_of_func): |
558 | | - raise SBMLException( |
559 | | - "AMICI does currently not support rateOf(.) inside event state updates." |
560 | | - ) |
561 | | - # TODO here we need xdot sym, not eqs |
562 | | - # event._state_update[i_state] = event._state_update[i_state].subs( |
563 | | - # {rate_of: get_rate(rate_of.args[0]) for rate_of in rate_ofs} |
564 | | - # ) |
565 | | - |
566 | 409 | def add_component( |
567 | 410 | self, component: ModelQuantity, insert_first: bool | None = False |
568 | 411 | ) -> None: |
@@ -1271,9 +1114,7 @@ def generate_basic_variables(self) -> None: |
1271 | 1114 | Generates the symbolic identifiers for all variables in |
1272 | 1115 | ``DEModel._variable_prototype`` |
1273 | 1116 | """ |
1274 | | - # We need to process events and Heaviside functions in the ``DEModel`, |
1275 | | - # before adding it to DEExporter |
1276 | | - self.parse_events() |
| 1117 | + self._reorder_events() |
1277 | 1118 |
|
1278 | 1119 | for var in self._variable_prototype: |
1279 | 1120 | if var not in self._syms: |
@@ -1335,7 +1176,11 @@ def parse_events(self) -> None: |
1335 | 1176 | for event in self.events(): |
1336 | 1177 | event.set_val(event.get_val().subs(w_toposorted)) |
1337 | 1178 |
|
1338 | | - # re-order events - first those that require root tracking, then the others |
| 1179 | + def _reorder_events(self) -> None: |
| 1180 | + """ |
| 1181 | + Re-order events - first those that require root tracking, |
| 1182 | + then the others. |
| 1183 | + """ |
1339 | 1184 | constant_syms = set(self.sym("k")) | set(self.sym("p")) |
1340 | 1185 | self._events = list( |
1341 | 1186 | chain( |
@@ -2694,22 +2539,7 @@ def _process_hybridization(self, hybridization: dict) -> None: |
2694 | 2539 | self._observables = [self._observables[i] for i in new_order] |
2695 | 2540 |
|
2696 | 2541 | if added_expressions: |
2697 | | - # toposort expressions |
2698 | | - w_sorted = toposort_symbols( |
2699 | | - dict( |
2700 | | - zip( |
2701 | | - self.sym("w"), |
2702 | | - self.eq("w"), |
2703 | | - strict=True, |
2704 | | - ) |
2705 | | - ) |
2706 | | - ) |
2707 | | - old_syms = tuple(self._syms["w"]) |
2708 | | - topo_expr_syms = tuple(w_sorted.keys()) |
2709 | | - new_order = [old_syms.index(s) for s in topo_expr_syms] |
2710 | | - self._expressions = [self._expressions[i] for i in new_order] |
2711 | | - self._syms["w"] = sp.Matrix(topo_expr_syms) |
2712 | | - self._eqs["w"] = sp.Matrix(list(w_sorted.values())) |
| 2542 | + self.toposort_expressions() |
2713 | 2543 |
|
2714 | 2544 | def get_explicit_roots(self) -> set[sp.Expr]: |
2715 | 2545 | """ |
@@ -2752,3 +2582,51 @@ def has_event_assignments(self) -> bool: |
2752 | 2582 | boolean indicating if event assignments are present |
2753 | 2583 | """ |
2754 | 2584 | return any(event.updates_state for event in self._events) |
| 2585 | + |
| 2586 | + def toposort_expressions(self) -> dict[sp.Symbol, sp.Expr]: |
| 2587 | + """ |
| 2588 | + Sort expressions in topological order. |
| 2589 | +
|
| 2590 | + :return: |
| 2591 | + dict of expression symbols to expressions in topological order |
| 2592 | + """ |
| 2593 | + # ensure no symbols or equations that depend on `w` have been generated |
| 2594 | + # yet, otherwise the re-ordering might break dependencies |
| 2595 | + if ( |
| 2596 | + generated := set(self._syms) |
| 2597 | + | set(self._eqs) |
| 2598 | + | set(self._sparsesyms) |
| 2599 | + | set(self._sparseeqs) |
| 2600 | + ) - {"w", "p", "k", "x", "x_rdata"}: |
| 2601 | + raise AssertionError( |
| 2602 | + "This function must be called before computing any " |
| 2603 | + "derivatives. The following symbols/equations are already " |
| 2604 | + f"generated: {generated}" |
| 2605 | + ) |
| 2606 | + |
| 2607 | + # NOTE: elsewhere, conservations law expressions are expected to |
| 2608 | + # occur before any other w expressions, so we must maintain their |
| 2609 | + # position. |
| 2610 | + # toposort everything but conservation law expressions, |
| 2611 | + # then prepend conservation laws |
| 2612 | + |
| 2613 | + w_toposorted = toposort_symbols( |
| 2614 | + { |
| 2615 | + e.get_sym(): e.get_val() |
| 2616 | + for e in self.expressions()[self.num_cons_law() :] |
| 2617 | + } |
| 2618 | + ) |
| 2619 | + |
| 2620 | + w_toposorted = { |
| 2621 | + e.get_sym(): e.get_val() |
| 2622 | + for e in self.expressions()[: self.num_cons_law()] |
| 2623 | + } | w_toposorted |
| 2624 | + |
| 2625 | + old_syms = tuple(e.get_sym() for e in self.expressions()) |
| 2626 | + topo_expr_syms = tuple(w_toposorted) |
| 2627 | + new_order = [old_syms.index(s) for s in topo_expr_syms] |
| 2628 | + self._expressions = [self._expressions[i] for i in new_order] |
| 2629 | + self._syms["w"] = sp.Matrix(topo_expr_syms) |
| 2630 | + self._eqs["w"] = sp.Matrix(list(w_toposorted.values())) |
| 2631 | + |
| 2632 | + return w_toposorted |
0 commit comments