Skip to content

Commit 7c2e2d7

Browse files
authored
Refactor steady-state simulation (#2855)
Some changes to make it easier to implement event handling during pre-equilibration later on. * Move getWrmsState to lambda * Change `t` in `writeSolution` from pointer to reference (missed one overload previously) * non-const Solver
1 parent 81316c6 commit 7c2e2d7

File tree

7 files changed

+34
-43
lines changed

7 files changed

+34
-43
lines changed

include/amici/forwardproblem.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -222,8 +222,7 @@ class EventHandlingSimulator {
222222
*/
223223
EventHandlingSimulator(
224224
gsl::not_null<Model*> model, gsl::not_null<Solver*> solver,
225-
gsl::not_null<FwdSimWorkspace*> ws,
226-
gsl::not_null<std::vector<realtype>*> dJzdx
225+
gsl::not_null<FwdSimWorkspace*> ws, std::vector<realtype>* dJzdx
227226
)
228227
: model_(model)
229228
, solver_(solver)

include/amici/solver.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -651,7 +651,7 @@ class Solver {
651651
* @param xQ quadrature
652652
*/
653653
void writeSolution(
654-
realtype* t, AmiVector& x, AmiVector& dx, AmiVectorArray& sx,
654+
realtype& t, AmiVector& x, AmiVector& dx, AmiVectorArray& sx,
655655
AmiVector& xQ
656656
) const;
657657

include/amici/steadystateproblem.h

Lines changed: 5 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -239,9 +239,8 @@ class SteadystateProblem {
239239
* @param it Index of the current output time point.
240240
* @param t0 Initial time for the steady state simulation.
241241
*/
242-
void workSteadyStateProblem(
243-
Solver const& solver, Model& model, int it, realtype t0
244-
);
242+
void
243+
workSteadyStateProblem(Solver& solver, Model& model, int it, realtype t0);
245244

246245
/**
247246
* @brief Compute the gradient via adjoint steady state sensitivities.
@@ -375,8 +374,7 @@ class SteadystateProblem {
375374
* @param it Index of the current output time point.
376375
* @param t0 Initial time for the steady state simulation.
377376
*/
378-
void
379-
findSteadyState(Solver const& solver, Model& model, int it, realtype t0);
377+
void findSteadyState(Solver& solver, Model& model, int it, realtype t0);
380378

381379
/**
382380
* @brief Try to determine the steady state by using Newton's method.
@@ -396,7 +394,7 @@ class SteadystateProblem {
396394
* successfully, or if it failed.
397395
*/
398396
SteadyStateStatus findSteadyStateBySimulation(
399-
Solver const& solver, Model& model, int it, realtype t0
397+
Solver& solver, Model& model, int it, realtype t0
400398
);
401399

402400
/**
@@ -438,13 +436,6 @@ class SteadystateProblem {
438436
bool tried_newton_1, bool tried_simulation, bool tried_newton_2
439437
) const;
440438

441-
/**
442-
* @brief Checks steady-state convergence for state variables
443-
* @param model Model instance
444-
* @return weighted root mean squared residuals of the RHS
445-
*/
446-
realtype getWrmsState(Model& model);
447-
448439
/**
449440
* @brief Checks convergence for state sensitivities
450441
* @param model Model instance
@@ -460,7 +451,7 @@ class SteadystateProblem {
460451
* @param model Model instance.
461452
* simulation.
462453
*/
463-
void runSteadystateSimulationFwd(Solver const& solver, Model& model);
454+
void runSteadystateSimulationFwd(Solver& solver, Model& model);
464455

465456
/**
466457
* @brief Launch backward simulation if Newton solver or linear system solve

src/forwardproblem.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -450,14 +450,15 @@ void EventHandlingSimulator::store_event(ExpData const* edata) {
450450
continue;
451451
}
452452

453-
if (edata && solver_->computingASA())
453+
if (edata && solver_->computingASA()) {
454+
Expects(dJzdx_ != nullptr);
454455
model_->getAdjointStateEventUpdate(
455456
slice(
456457
*dJzdx_, ws_->nroots.at(ie), model_->nx_solver * model_->nJ
457458
),
458459
ie, ws_->nroots.at(ie), t_, ws_->x, *edata
459460
);
460-
461+
}
461462
ws_->nroots.at(ie)++;
462463
}
463464

src/model.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,11 @@
77

88
#include <algorithm>
99
#include <cassert>
10-
#include <sstream>
1110
#include <cmath>
1211
#include <cstring>
1312
#include <numeric>
1413
#include <regex>
14+
#include <sstream>
1515
#include <utility>
1616

1717
namespace amici {

src/solver.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1307,15 +1307,15 @@ void Solver::resetMutableMemory(
13071307
}
13081308

13091309
void Solver::writeSolution(
1310-
realtype* t, AmiVector& x, AmiVector& dx, AmiVectorArray& sx, AmiVector& xQ
1310+
realtype& t, AmiVector& x, AmiVector& dx, AmiVectorArray& sx, AmiVector& xQ
13111311
) const {
1312-
*t = gett();
1312+
t = gett();
13131313
if (quad_initialized_)
1314-
xQ.copy(getQuadrature(*t));
1314+
xQ.copy(getQuadrature(t));
13151315
if (sens_initialized_)
1316-
sx.copy(getStateSensitivity(*t));
1317-
x.copy(getState(*t));
1318-
dx.copy(getDerivativeState(*t));
1316+
sx.copy(getStateSensitivity(t));
1317+
x.copy(getState(t));
1318+
dx.copy(getDerivativeState(t));
13191319
}
13201320

13211321
void Solver::writeSolution(

src/steadystateproblem.cpp

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ SteadystateProblem::SteadystateProblem(
173173
}
174174

175175
void SteadystateProblem::workSteadyStateProblem(
176-
Solver const& solver, Model& model, int it, realtype t0
176+
Solver& solver, Model& model, int it, realtype t0
177177
) {
178178
if (model.ne > 0) {
179179
solver.logger->log(
@@ -280,7 +280,7 @@ void SteadystateProblem::workSteadyStateBackwardProblem(
280280
}
281281

282282
void SteadystateProblem::findSteadyState(
283-
Solver const& solver, Model& model, int it, realtype t0
283+
Solver& solver, Model& model, int it, realtype t0
284284
) {
285285
steady_state_status_.resize(3, SteadyStateStatus::not_run);
286286
// Turn off Newton's method if 'integrationOnly' approach is chosen for
@@ -365,7 +365,7 @@ void SteadystateProblem::findSteadyStateByNewtonsMethod(
365365
}
366366

367367
SteadyStateStatus SteadystateProblem::findSteadyStateBySimulation(
368-
Solver const& solver, Model& model, int const it, realtype const t0
368+
Solver& solver, Model& model, int const it, realtype const t0
369369
) {
370370
try {
371371
if (it < 0) {
@@ -551,17 +551,6 @@ void SteadystateProblem::getQuadratureBySimulation(
551551
throw AmiException(errorString.c_str());
552552
}
553553

554-
realtype SteadystateProblem::getWrmsState(Model& model) {
555-
updateRightHandSide(model);
556-
557-
if (newton_step_conv_) {
558-
newtons_method_.compute_step(ws_->xdot, {ws_->t, ws_->x, ws_->dx});
559-
return wrms_computer_x_.wrms(newtons_method_.get_delta(), ws_->x);
560-
}
561-
562-
return wrms_computer_x_.wrms(ws_->xdot, ws_->x);
563-
}
564-
565554
realtype
566555
SteadystateProblem::getWrmsFSA(Model& model, WRMSComputer& wrms_computer_sx) {
567556
// Forward sensitivities: Compute weighted error norm for their RHS
@@ -601,7 +590,7 @@ bool SteadystateProblem::checkSteadyStateSuccess() const {
601590
}
602591

603592
void SteadystateProblem::runSteadystateSimulationFwd(
604-
Solver const& solver, Model& model
593+
Solver& solver, Model& model
605594
) {
606595
if (model.nx_solver == 0)
607596
return;
@@ -643,14 +632,25 @@ void SteadystateProblem::runSteadystateSimulationFwd(
643632
sensi_converged = []() { return true; };
644633
}
645634

635+
// Returns the WRMS for the current state
636+
auto get_wrms_state = [&]() {
637+
updateRightHandSide(model);
638+
if (newton_step_conv_) {
639+
newtons_method_.compute_step(ws_->xdot, {ws_->t, ws_->x, ws_->dx});
640+
return wrms_computer_x_.wrms(newtons_method_.get_delta(), ws_->x);
641+
}
642+
643+
return wrms_computer_x_.wrms(ws_->xdot, ws_->x);
644+
};
645+
646646
int& sim_steps = numsteps_.at(1);
647647
int convergence_check_frequency = newton_step_conv_ ? 25 : 1;
648648

649649
while (true) {
650650
if (sim_steps % convergence_check_frequency == 0) {
651651
// Check for convergence (already before simulation, since we might
652652
// start in steady state)
653-
wrms_ = getWrmsState(model);
653+
wrms_ = get_wrms_state();
654654
if (wrms_ < conv_thresh && sensi_converged()) {
655655
break;
656656
}
@@ -756,7 +756,7 @@ void SteadystateProblem::runSteadystateSimulationBwd(
756756
solver.step(std::max(final_state_.t, 1.0) * 10);
757757

758758
solver.writeSolution(
759-
&final_state_.t, xB_, final_state_.dx, final_state_.sx, xQ_
759+
final_state_.t, xB_, final_state_.dx, final_state_.sx, xQ_
760760
);
761761
}
762762
}

0 commit comments

Comments
 (0)