diff --git a/benchmarks/jump_step_timing.py b/benchmarks/jump_step_timing.py new file mode 100644 index 00000000..9250de4f --- /dev/null +++ b/benchmarks/jump_step_timing.py @@ -0,0 +1,116 @@ +from warnings import simplefilter + + +simplefilter(action="ignore", category=FutureWarning) + +import timeit +from functools import partial + +import diffrax +import equinox as eqx +import jax +import jax.numpy as jnp +import jax.random as jr +from old_pid_controller import OldPIDController + + +t0 = 0 +t1 = 5 +dt0 = 0.5 +y0 = 1.0 +drift = diffrax.ODETerm(lambda t, y, args: -0.2 * y) + + +def diffusion_vf(t, y, args): + return jnp.ones((), dtype=y.dtype) + + +def get_terms(key): + bm = diffrax.VirtualBrownianTree(t0, t1, 2**-5, (), key) + diffusion = diffrax.ControlTerm(diffusion_vf, bm) + return diffrax.MultiTerm(drift, diffusion) + + +solver = diffrax.Heun() +step_ts = jnp.linspace(t0, t1, 129, endpoint=True) +pid_controller = diffrax.PIDController( + rtol=0, atol=1e-3, dtmin=2**-9, dtmax=1.0, pcoeff=0.3, icoeff=0.7 +) +new_controller = diffrax.JumpStepWrapper( + pid_controller, + step_ts=step_ts, + rejected_step_buffer_len=None, +) +old_controller = OldPIDController( + rtol=0, atol=1e-3, dtmin=2**-9, dtmax=1.0, pcoeff=0.3, icoeff=0.7, step_ts=step_ts +) + + +@eqx.filter_jit +@partial(jax.vmap, in_axes=(0, None)) +def solve(key, controller): + term = get_terms(key) + return diffrax.diffeqsolve( + term, + solver, + t0, + t1, + dt0, + y0, + stepsize_controller=controller, + saveat=diffrax.SaveAt(ts=step_ts), + ) + + +num_samples = 100 +keys = jr.split(jr.PRNGKey(0), num_samples) + + +def do_timing(controller): + @jax.jit + @eqx.debug.assert_max_traces(max_traces=1) + def time_controller_fun(): + sols = solve(keys, controller) + assert sols.ys is not None + assert sols.ys.shape == (num_samples, len(step_ts)) + return sols.ys + + def time_controller(): + jax.block_until_ready(time_controller_fun()) + + return min(timeit.repeat(time_controller, number=3, repeat=20)) + + +time_new = do_timing(new_controller) + +time_old = do_timing(old_controller) + +print(f"New controller: {time_new:.5} s, Old controller: {time_old:.5} s") + +# How expensive is revisiting rejected steps? 
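+# (When `rejected_step_buffer_len` is set, the wrapper must step to the exact
+# times of previously rejected steps as well; this matters for SDEs, where the
+# Brownian motion has already been sampled at those times. See the explanation
+# in diffrax/_step_size_controller/jump_step_wrapper.py.)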
+revisiting_controller_short = diffrax.JumpStepWrapper( + pid_controller, + step_ts=step_ts, + rejected_step_buffer_len=10, +) + +revisiting_controller_long = diffrax.JumpStepWrapper( + pid_controller, + step_ts=step_ts, + rejected_step_buffer_len=4096, +) + +time_revisiting_short = do_timing(revisiting_controller_short) +time_revisiting_long = do_timing(revisiting_controller_long) + +print( + f"Revisiting controller\n" + f"with buffer len 10: {time_revisiting_short:.5} s\n" + f"with buffer len 4096: {time_revisiting_long:.5} s" +) + +# ======= RESULTS ======= +# New controller: 0.23506 s, Old controller: 0.30735 s +# Revisiting controller +# with buffer len 10: 0.23636 s +# with buffer len 4096: 0.23965 s diff --git a/benchmarks/old_pid_controller.py b/benchmarks/old_pid_controller.py new file mode 100644 index 00000000..f6d78098 --- /dev/null +++ b/benchmarks/old_pid_controller.py @@ -0,0 +1,414 @@ +from collections.abc import Callable +from typing import cast, Optional, TypeVar + +import equinox as eqx +import equinox.internal as eqxi +import jax +import jax.lax as lax +import jax.numpy as jnp +import jax.tree_util as jtu +import lineax.internal as lxi +import optimistix as optx +from diffrax import AbstractTerm, ODETerm, RESULTS +from diffrax._custom_types import ( + Args, + BoolScalarLike, + IntScalarLike, + RealScalarLike, + VF, + Y, +) +from diffrax._misc import static_select, upcast_or_raise +from diffrax._step_size_controller import AbstractAdaptiveStepSizeController +from equinox.internal import ω +from jaxtyping import Array, PyTree, Real +from lineax.internal import complex_to_real_dtype + + +ω = cast(Callable, ω) + + +def _select_initial_step( + terms: PyTree[AbstractTerm], + t0: RealScalarLike, + y0: Y, + args: Args, + func: Callable[ + [PyTree[AbstractTerm], RealScalarLike, Y, Args], + VF, + ], + error_order: RealScalarLike, + rtol: RealScalarLike, + atol: RealScalarLike, + norm: Callable[[PyTree], RealScalarLike], +) -> RealScalarLike: + # TODO: someone needs to figure out an initial step size algorithm for SDEs. 
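+    # For ODEs, the code below appears to follow the empirical starting-step
+    # heuristic of Hairer, Norsett & Wanner ("Solving Ordinary Differential
+    # Equations I", Section II.4): pick a trial step h0 from the ratio of the
+    # scaled norms d0 = |y0| and d1 = |f0|, take one Euler step, then refine
+    # using the second-derivative estimate d2 = |f1 - f0| / h0.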
+ if not isinstance(terms, ODETerm): + return 0.01 + + def fn(carry): + t, y, _h0, _d1, _f, _ = carry + f = func(terms, t, y, args) + return t, y, _h0, _d1, _f, f + + def intermediate(carry): + _, _, _, _, _, f0 = carry + d0 = norm((y0**ω / scale**ω).ω) + d1 = norm((f0**ω / scale**ω).ω) + _cond = (d0 < 1e-5) | (d1 < 1e-5) + _d1 = jnp.where(_cond, 1, d1) + h0 = jnp.where(_cond, 1e-6, 0.01 * (d0 / _d1)) + t1 = t0 + h0 + y1 = (y0**ω + h0 * f0**ω).ω + return t1, y1, h0, d1, f0, f0 + + scale = (atol + ω(y0).call(jnp.abs) * rtol).ω + dummy_h = t0 + dummy_d = eqxi.eval_empty(norm, y0) + dummy_f = eqxi.eval_empty(lambda: func(terms, t0, y0, args)) + _, _, h0, d1, f0, f1 = eqxi.scan_trick( + fn, [intermediate], (t0, y0, dummy_h, dummy_d, dummy_f, dummy_f) + ) + d2 = norm(((f1**ω - f0**ω) / scale**ω).ω) / h0 + max_d = jnp.maximum(d1, d2) + h1 = jnp.where( + max_d <= 1e-15, + jnp.maximum(1e-6, h0 * 1e-3), + (0.01 / max_d) ** (1 / error_order), + ) + return jnp.minimum(100 * h0, h1) + + +_ControllerState = TypeVar("_ControllerState") +_Dt0 = TypeVar("_Dt0", None, RealScalarLike, Optional[RealScalarLike]) + +_PidState = tuple[ + BoolScalarLike, BoolScalarLike, RealScalarLike, RealScalarLike, RealScalarLike +] + + +def _none_or_array(x): + if x is None: + return None + else: + return jnp.asarray(x) + + +class OldPIDController( + AbstractAdaptiveStepSizeController[_PidState, Optional[RealScalarLike]] +): + r"""See the doc of diffrax.PIDController for more information.""" + + rtol: RealScalarLike + atol: RealScalarLike + pcoeff: RealScalarLike = 0 + icoeff: RealScalarLike = 1 + dcoeff: RealScalarLike = 0 + dtmin: Optional[RealScalarLike] = None + dtmax: Optional[RealScalarLike] = None + force_dtmin: bool = True + step_ts: Optional[Real[Array, " steps"]] = eqx.field( + default=None, converter=_none_or_array + ) + jump_ts: Optional[Real[Array, " jumps"]] = eqx.field( + default=None, converter=_none_or_array + ) + factormin: RealScalarLike = 0.2 + factormax: RealScalarLike = 10.0 + norm: Callable[[PyTree], RealScalarLike] = optx.rms_norm + safety: RealScalarLike = 0.9 + error_order: Optional[RealScalarLike] = None + + def __check_init__(self): + if self.jump_ts is not None and not jnp.issubdtype( + self.jump_ts.dtype, jnp.inexact + ): + raise ValueError( + f"jump_ts must be floating point, not {self.jump_ts.dtype}" + ) + + def wrap(self, direction: IntScalarLike): + step_ts = None if self.step_ts is None else self.step_ts * direction + jump_ts = None if self.jump_ts is None else self.jump_ts * direction + return eqx.tree_at( + lambda s: (s.step_ts, s.jump_ts), + self, + (step_ts, jump_ts), + is_leaf=lambda x: x is None, + ) + + def init( + self, + terms: PyTree[AbstractTerm], + t0: RealScalarLike, + t1: RealScalarLike, + y0: Y, + dt0: Optional[RealScalarLike], + args: Args, + func: Callable[[PyTree[AbstractTerm], RealScalarLike, Y, Args], VF], + error_order: Optional[RealScalarLike], + ) -> tuple[RealScalarLike, _PidState]: + del t1 + if dt0 is None: + error_order = self._get_error_order(error_order) + dt0 = _select_initial_step( + terms, + t0, + y0, + args, + func, + error_order, + self.rtol, + self.atol, + self.norm, + ) + + dt0 = lax.stop_gradient(dt0) + if self.dtmax is not None: + dt0 = jnp.minimum(dt0, self.dtmax) + if self.dtmin is None: + at_dtmin = jnp.array(False) + else: + at_dtmin = dt0 <= self.dtmin + dt0 = jnp.maximum(dt0, self.dtmin) + + t1 = self._clip_step_ts(t0, t0 + dt0) + t1, jump_next_step = self._clip_jump_ts(t0, t1) + + y_leaves = jtu.tree_leaves(y0) + if len(y_leaves) == 0: + y_dtype 
= lxi.default_floating_dtype() + else: + y_dtype = jnp.result_type(*y_leaves) + return t1, ( + jump_next_step, + at_dtmin, + dt0, + jnp.array(1.0, dtype=complex_to_real_dtype(y_dtype)), + jnp.array(1.0, dtype=complex_to_real_dtype(y_dtype)), + ) + + def adapt_step_size( + self, + t0: RealScalarLike, + t1: RealScalarLike, + y0: Y, + y1_candidate: Y, + args: Args, + y_error: Optional[Y], + error_order: RealScalarLike, + controller_state: _PidState, + ) -> tuple[ + BoolScalarLike, + RealScalarLike, + RealScalarLike, + BoolScalarLike, + _PidState, + RESULTS, + ]: + del args + if y_error is None and y0 is not None: + # y0 is not None check is included to handle the edge case that the state + # is just a trivial `None` PyTree. In this case `y_error` has the same + # PyTree structure and thus overlaps with our special usage of `None` to + # indicate a lack of error estimate. + raise RuntimeError( + "Cannot use adaptive step sizes with a solver that does not provide " + "error estimates." + ) + ( + made_jump, + at_dtmin, + prev_dt, + prev_inv_scaled_error, + prev_prev_inv_scaled_error, + ) = controller_state + error_order = self._get_error_order(error_order) + prev_dt = jnp.where(made_jump, prev_dt, t1 - t0) + + # + # Figure out how things went on the last step: error, and whether to + # accept/reject it. + # + + def _scale(_y0, _y1_candidate, _y_error): + # In case the solver steps into a region for which the vector field isn't + # defined. + _nan = jnp.isnan(_y1_candidate).any() + _y1_candidate = jnp.where(_nan, _y0, _y1_candidate) + _y = jnp.maximum(jnp.abs(_y0), jnp.abs(_y1_candidate)) + with jax.numpy_dtype_promotion("standard"): + return _y_error / (self.atol + _y * self.rtol) + + scaled_error = self.norm(jtu.tree_map(_scale, y0, y1_candidate, y_error)) + keep_step = scaled_error < 1 + if self.dtmin is not None: + keep_step = keep_step | at_dtmin + # Make sure it's not a Python scalar and thus getting a ZeroDivisionError. + inv_scaled_error = 1 / jnp.asarray(scaled_error) + inv_scaled_error = lax.stop_gradient( + inv_scaled_error + ) # See note in init above. + # Note: if you ever remove this lax.stop_gradient, then you'll need to do a lot + # of work to get safe gradients through these operations. + # When `inv_scaled_error` has a (non-symbolic) zero cotangent, and `y_error` + # is either zero or inf, then we get a `0 * inf = nan` on the backward pass. + + # + # Adjust next step size + # + + _zero_coeff = lambda c: isinstance(c, (int, float)) and c == 0 + coeff1 = (self.icoeff + self.pcoeff + self.dcoeff) / error_order + coeff2 = -cast(RealScalarLike, self.pcoeff + 2 * self.dcoeff) / error_order + coeff3 = self.dcoeff / error_order + factor1 = 1 if _zero_coeff(coeff1) else inv_scaled_error**coeff1 + factor2 = 1 if _zero_coeff(coeff2) else prev_inv_scaled_error**coeff2 + factor3 = 1 if _zero_coeff(coeff3) else prev_prev_inv_scaled_error**coeff3 + factormin = jnp.where(keep_step, 1, self.factormin) + factor = jnp.clip( + self.safety * factor1 * factor2 * factor3, + min=factormin, + max=self.factormax, + ) + # Once again, see above. In case we have gradients on {i,p,d}coeff. + # (Probably quite common for them to have zero tangents if passed across + # a grad API boundary as part of a larger model.) + factor = lax.stop_gradient(factor) + factor = eqxi.nondifferentiable(factor) + dt = prev_dt * factor.astype(jnp.result_type(prev_dt)) + + # E.g. we failed an implicit step, so y_error=inf, so inv_scaled_error=0, + # so factor=factormin, and we shrunk our step. 
+ # If we're using a PI or PID controller we shouldn't then force shrinking on + # the next or next two steps as well! + pred = (inv_scaled_error == 0) | jnp.isinf(inv_scaled_error) + inv_scaled_error = jnp.where(pred, 1, inv_scaled_error) + + # + # Clip next step size based on dtmin/dtmax + # + + result = RESULTS.successful + if self.dtmax is not None: + dt = jnp.minimum(dt, self.dtmax) + if self.dtmin is None: + at_dtmin = jnp.array(False) + else: + if not self.force_dtmin: + result = RESULTS.where(dt < self.dtmin, RESULTS.dt_min_reached, result) + at_dtmin = dt <= self.dtmin + dt = jnp.maximum(dt, self.dtmin) + + # + # Clip next step size based on step_ts/jump_ts + # + + if jnp.issubdtype(jnp.result_type(t1), jnp.inexact): + # Two nextafters. If made_jump then t1 = prevbefore(jump location) + # so now _t1 = nextafter(jump location) + # This is important because we don't know whether or not the jump is as a + # result of a left- or right-discontinuity, so we have to skip the jump + # location altogether. + _t1 = static_select(made_jump, eqxi.nextafter(eqxi.nextafter(t1)), t1) + else: + _t1 = t1 + next_t0 = jnp.where(keep_step, _t1, t0) + next_t1 = self._clip_step_ts(next_t0, next_t0 + dt) + next_t1, next_made_jump = self._clip_jump_ts(next_t0, next_t1) + + inv_scaled_error = jnp.where(keep_step, inv_scaled_error, prev_inv_scaled_error) + prev_inv_scaled_error = jnp.where( + keep_step, prev_inv_scaled_error, prev_prev_inv_scaled_error + ) + controller_state = ( + next_made_jump, + at_dtmin, + dt, + inv_scaled_error, + prev_inv_scaled_error, + ) + return keep_step, next_t0, next_t1, made_jump, controller_state, result + + def _get_error_order(self, error_order: Optional[RealScalarLike]) -> RealScalarLike: + # Attribute takes priority, if the user knows the correct error order better + # than our guess. + error_order = error_order if self.error_order is None else self.error_order + if error_order is None: + raise ValueError( + "The order of convergence for the solver has not been specified; pass " + "`PIDController(..., error_order=...)` manually instead. If solving " + "an ODE then this should be equal to the (global) order plus one. If " + "solving an SDE then should be equal to the (global) order plus 0.5." + ) + return error_order + + def _clip_step_ts(self, t0: RealScalarLike, t1: RealScalarLike) -> RealScalarLike: + if self.step_ts is None: + return t1 + + step_ts0 = upcast_or_raise( + self.step_ts, + t0, + "`PIDController.step_ts`", + "time (the result type of `t0`, `t1`, `dt0`, `SaveAt(ts=...)` etc.)", + ) + step_ts1 = upcast_or_raise( + self.step_ts, + t1, + "`PIDController.step_ts`", + "time (the result type of `t0`, `t1`, `dt0`, `SaveAt(ts=...)` etc.)", + ) + # TODO: it should be possible to switch this O(nlogn) for just O(n) by keeping + # track of where we were last, and using that as a hint for the next search. + t0_index = jnp.searchsorted(step_ts0, t0, side="right") + t1_index = jnp.searchsorted(step_ts1, t1, side="right") + # This minimum may or may not actually be necessary. The left branch is taken + # iff t0_index < t1_index <= len(self.step_ts), so all valid t0_index s must + # already satisfy the minimum. + # However, that branch is actually executed unconditionally and then where'd, + # so we clamp it just to be sure we're not hitting undefined behaviour. 
+ t1 = jnp.where( + t0_index < t1_index, + step_ts1[jnp.minimum(t0_index, len(self.step_ts) - 1)], + t1, + ) + return t1 + + def _clip_jump_ts( + self, t0: RealScalarLike, t1: RealScalarLike + ) -> tuple[RealScalarLike, BoolScalarLike]: + if self.jump_ts is None: + return t1, False + assert jnp.issubdtype(self.jump_ts.dtype, jnp.inexact) + if not jnp.issubdtype(jnp.result_type(t0), jnp.inexact): + raise ValueError( + "`t0`, `t1`, `dt0` must be floating point when specifying `jump_ts`. " + f"Got {jnp.result_type(t0)}." + ) + if not jnp.issubdtype(jnp.result_type(t1), jnp.inexact): + raise ValueError( + "`t0`, `t1`, `dt0` must be floating point when specifying `jump_ts`. " + f"Got {jnp.result_type(t1)}." + ) + jump_ts0 = upcast_or_raise( + self.jump_ts, + t0, + "`PIDController.jump_ts`", + "time (the result type of `t0`, `t1`, `dt0`, `SaveAt(ts=...)` etc.)", + ) + jump_ts1 = upcast_or_raise( + self.jump_ts, + t1, + "`PIDController.jump_ts`", + "time (the result type of `t0`, `t1`, `dt0`, `SaveAt(ts=...)` etc.)", + ) + t0_index = jnp.searchsorted(jump_ts0, t0, side="right") + t1_index = jnp.searchsorted(jump_ts1, t1, side="right") + next_made_jump = t0_index < t1_index + t1 = jnp.where( + next_made_jump, + eqxi.prevbefore(jump_ts1[jnp.minimum(t0_index, len(self.jump_ts) - 1)]), + t1, + ) + return t1, next_made_jump diff --git a/diffrax/__init__.py b/diffrax/__init__.py index 67b4ca50..dc93c879 100644 --- a/diffrax/__init__.py +++ b/diffrax/__init__.py @@ -117,6 +117,7 @@ AbstractAdaptiveStepSizeController as AbstractAdaptiveStepSizeController, AbstractStepSizeController as AbstractStepSizeController, ConstantStepSize as ConstantStepSize, + JumpStepWrapper as JumpStepWrapper, PIDController as PIDController, StepTo as StepTo, ) diff --git a/diffrax/_step_size_controller/__init__.py b/diffrax/_step_size_controller/__init__.py index 18d19c00..7aa2b0b3 100644 --- a/diffrax/_step_size_controller/__init__.py +++ b/diffrax/_step_size_controller/__init__.py @@ -1,6 +1,9 @@ -from .adaptive import ( +from .adaptive_base import ( AbstractAdaptiveStepSizeController as AbstractAdaptiveStepSizeController, - PIDController as PIDController, ) from .base import AbstractStepSizeController as AbstractStepSizeController from .constant import ConstantStepSize as ConstantStepSize, StepTo as StepTo +from .jump_step_wrapper import JumpStepWrapper as JumpStepWrapper +from .pid import ( + PIDController as PIDController, +) diff --git a/diffrax/_step_size_controller/adaptive_base.py b/diffrax/_step_size_controller/adaptive_base.py new file mode 100644 index 00000000..7f984e05 --- /dev/null +++ b/diffrax/_step_size_controller/adaptive_base.py @@ -0,0 +1,42 @@ +from collections.abc import Callable +from typing import Optional, TypeVar + +from equinox import AbstractVar +from jaxtyping import PyTree + +from .._custom_types import RealScalarLike +from .base import AbstractStepSizeController + + +_ControllerState = TypeVar("_ControllerState") +_Dt0 = TypeVar("_Dt0", None, RealScalarLike, Optional[RealScalarLike]) + + +class AbstractAdaptiveStepSizeController( + AbstractStepSizeController[_ControllerState, _Dt0] +): + """Indicates an adaptive step size controller. + + Accepts tolerances `rtol` and `atol`. When used in conjunction with an implicit + solver ([`diffrax.AbstractImplicitSolver`][]), then these tolerances will + automatically be used as the tolerances for the nonlinear solver passed to the + implicit solver, if they are not specified manually. 
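+
+    Subclasses are expected to provide `rtol`, `atol` and `norm` attributes,
+    declared below as abstract fields.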
+ """ + + rtol: AbstractVar[RealScalarLike] + atol: AbstractVar[RealScalarLike] + norm: AbstractVar[Callable[[PyTree], RealScalarLike]] + + def __check_init__(self): + if self.rtol is None or self.atol is None: + raise ValueError( + "The default values for `rtol` and `atol` were removed in Diffrax " + "version 0.1.0. (As the choice of tolerance is nearly always " + "something that you, as an end user, should make an explicit choice " + "about.)\n" + "If you want to match the previous defaults then specify " + "`rtol=1e-3`, `atol=1e-6`. For example:\n" + "```\n" + "diffrax.PIDController(rtol=1e-3, atol=1e-6)\n" + "```\n" + ) diff --git a/diffrax/_step_size_controller/jump_step_wrapper.py b/diffrax/_step_size_controller/jump_step_wrapper.py new file mode 100644 index 00000000..2f3cb4f7 --- /dev/null +++ b/diffrax/_step_size_controller/jump_step_wrapper.py @@ -0,0 +1,424 @@ +from collections.abc import Callable +from typing import Generic, Optional, TYPE_CHECKING, TypeVar + +import equinox as eqx +import equinox.internal as eqxi +import jax +import jax.numpy as jnp +from jaxtyping import Array, PyTree, Real + +from .._custom_types import ( + Args, + BoolScalarLike, + IntScalarLike, + RealScalarLike, + VF, + Y, +) +from .._misc import static_select, upcast_or_raise +from .._solution import RESULTS +from .._term import AbstractTerm +from .adaptive_base import AbstractAdaptiveStepSizeController +from .base import AbstractStepSizeController + + +_ControllerState = TypeVar("_ControllerState") +_Dt0 = TypeVar("_Dt0", None, RealScalarLike, Optional[RealScalarLike]) + + +class _JumpStepState(eqx.Module, Generic[_ControllerState]): + made_jump: BoolScalarLike + prev_dt: RealScalarLike + step_index: IntScalarLike + jump_index: IntScalarLike + rejected_index: IntScalarLike + rejected_buffer: Optional[Array] + step_ts: Optional[Array] + jump_ts: Optional[Array] + inner_state: _ControllerState + + def get(self): + return ( + self.made_jump, + self.prev_dt, + self.step_index, + self.jump_index, + self.rejected_index, + self.rejected_buffer, + self.step_ts, + self.jump_ts, + self.inner_state, + ) + + +def _none_or_array(x): + if x is None: + return None + else: + return jnp.asarray(x) + + +def _get_t(i: IntScalarLike, ts: Array) -> RealScalarLike: + i_min_len = jnp.minimum(i, len(ts) - 1) + return jnp.where(i == len(ts), jnp.inf, ts[i_min_len]) + + +def _clip_ts( + t0: RealScalarLike, + t1: RealScalarLike, + i: IntScalarLike, + ts: Optional[Array], + check_inexact: bool, +) -> tuple[RealScalarLike, BoolScalarLike]: + if ts is None: + return t1, False + + if check_inexact: + assert jnp.issubdtype(ts.dtype, jnp.inexact) + if not jnp.issubdtype(jnp.result_type(t0), jnp.inexact): + raise ValueError( + "`t0`, `t1`, `dt0` must be floating point when specifying `jump_ts`. " + f"Got {jnp.result_type(t0)}." + ) + if not jnp.issubdtype(jnp.result_type(t1), jnp.inexact): + raise ValueError( + "`t0`, `t1`, `dt0` must be floating point when specifying `jump_ts`. " + f"Got {jnp.result_type(t1)}." 
+            )
+
+    _t1 = _get_t(i, ts)
+    next_made_jump = _t1 <= t1
+    _t1 = jnp.where(next_made_jump, _t1, t1)
+    return _t1, next_made_jump
+
+
+def _find_index(t: RealScalarLike, ts: Optional[Array]) -> IntScalarLike:
+    if ts is None:
+        return 0
+
+    ts = upcast_or_raise(
+        ts,
+        t,
+        "`JumpStepWrapper.step_ts`",
+        "time (the result type of `t0`, `t1`, `dt0`, `SaveAt(ts=...)` etc.)",
+    )
+    return jnp.searchsorted(ts, t, side="right")
+
+
+def _revisit_rejected(
+    t0: RealScalarLike,
+    t1: RealScalarLike,
+    i_reject: IntScalarLike,
+    rejected_buffer: Optional[Array],
+) -> RealScalarLike:
+    if rejected_buffer is None:
+        return t1
+    _t1 = _get_t(i_reject, rejected_buffer)
+    _t1 = jnp.minimum(_t1, t1)
+    return _t1
+
+
+# EXPLANATION OF STEP_TS AND JUMP_TS
+# ----------------------------------
+# `step_ts` and `jump_ts` are used to force the solver to step to certain times.
+# They mostly act in the same way, except that when we hit an element of `jump_ts`,
+# the controller must return `made_jump = True`, so that `diffeqsolve` knows that
+# the vector field has a discontinuity at that point. In addition, the exact time
+# of the jump will be skipped over using `eqxi.prevbefore` and `eqxi.nextafter`.
+# Here is an explanation of the two (we use `step_ts` as the example, but the same
+# applies to `jump_ts`):
+#
+# If `step_ts` is not None, we assume it is a sorted array of times.
+# At the start of the run, the init function finds the smallest index `i_step` such
+# that `step_ts[i_step] > t0`. At init and after each step of the solver, the
+# controller will propose a step t1_next, and we will clip it to
+# `t1_next = min(t1_next, step_ts[i_step])`.
+# At the start of the next step, if the step ended at t1 == step_ts[i_step] and
+# if the controller decides to keep the step, then this time has been successfully
+# stepped to and we increment `i_step` by 1.
+# We use a convenience function _get_t(i, ts) which returns ts[i] if i < len(ts) and
+# infinity otherwise.
+
+# EXPLANATION OF REVISITING REJECTED STEPS
+# ----------------------------------------
+# We use a "stack" of rejected steps, composed of a buffer `rejected_buffer` of
+# length `rejected_step_buffer_len` and a counter `i_reject`. The stack consists of
+# all the items in `rejected_buffer[i_reject:]`, with `rejected_buffer[i_reject]`
+# being the top of the stack.
+# When `i_reject == rejected_step_buffer_len`, the stack is empty.
+# At the start of the run, `i_reject = rejected_step_buffer_len`. Each time a step
+# is rejected, `i_reject -= 1` and `rejected_buffer[i_reject] = t1`. Each time a
+# step ends at `t1 == rejected_buffer[i_reject]`, we increment `i_reject` by 1
+# (even if the step was rejected, in which case we re-add `t1` to the stack
+# immediately).
+# We clip the next step to `t1_next = min(t1_next, rejected_buffer[i_reject])`.
+# If `i_reject < 0` then an error is raised.
+
+
+class JumpStepWrapper(
+    AbstractStepSizeController[_JumpStepState[_ControllerState], _Dt0]
+):
+    """Wraps an existing step controller and adds the ability to specify `step_ts`
+    and `jump_ts`. The former are times to which the controller should step and the
+    latter are times at which the vector field has a discontinuity (jump)."""
+
+    controller: AbstractAdaptiveStepSizeController[_ControllerState, _Dt0]
+    step_ts: Optional[Real[Array, " steps"]]
+    jump_ts: Optional[Real[Array, " jumps"]]
+    rejected_step_buffer_len: Optional[int] = eqx.field(static=True)
+    callback_on_reject: Optional[Callable] = eqx.field(static=True)
+
+    @eqxi.doc_remove_args("_callback_on_reject")
+    def __init__(
+        self,
+        controller,
+        step_ts=None,
+        jump_ts=None,
+        rejected_step_buffer_len=None,
+        _callback_on_reject=None,
+    ):
+        r"""
+        **Arguments**:
+
+        - `controller`: The controller to wrap.
+            Can be any [`diffrax.AbstractAdaptiveStepSizeController`][].
+        - `step_ts`: Denotes extra times that must be stepped to.
+        - `jump_ts`: Denotes extra times that must be stepped to, and at which the
+            vector field has a known discontinuity. (This is used to force FSAL
+            solvers to re-evaluate the vector field.)
+        - `rejected_step_buffer_len`: The length of the buffer storing rejected
+            steps. Can either be `None` or a positive integer. If it is a positive
+            integer, then the controller will revisit rejected steps. This is
+            useful for SDEs, where the solution is guaranteed to be correct if the
+            SDE is evaluated at all times at which the Brownian motion (BM) is
+            evaluated. Since the BM is also evaluated at rejected steps, we must
+            later evaluate the SDE at these times as well.
+        """
+        self.controller = controller
+        self.step_ts = _none_or_array(step_ts)
+        self.jump_ts = _none_or_array(jump_ts)
+        if rejected_step_buffer_len is not None:
+            assert rejected_step_buffer_len > 0
+        self.rejected_step_buffer_len = rejected_step_buffer_len
+        self.callback_on_reject = _callback_on_reject
+
+    def __check_init__(self):
+        if self.jump_ts is not None and not jnp.issubdtype(
+            self.jump_ts.dtype, jnp.inexact
+        ):
+            raise ValueError(
+                f"jump_ts must be floating point, not {self.jump_ts.dtype}"
+            )
+
+    def wrap(self, direction: IntScalarLike):
+        step_ts = None if self.step_ts is None else self.step_ts * direction
+        jump_ts = None if self.jump_ts is None else self.jump_ts * direction
+        controller = self.controller.wrap(direction)
+        return eqx.tree_at(
+            lambda s: (s.step_ts, s.jump_ts, s.controller),
+            self,
+            (step_ts, jump_ts, controller),
+            is_leaf=lambda x: x is None,
+        )
+
+    def init(
+        self,
+        terms: PyTree[AbstractTerm],
+        t0: RealScalarLike,
+        t1: RealScalarLike,
+        y0: Y,
+        dt0: _Dt0,
+        args: Args,
+        func: Callable[[PyTree[AbstractTerm], RealScalarLike, Y, Args], VF],
+        error_order: Optional[RealScalarLike],
+    ) -> tuple[RealScalarLike, _JumpStepState[_ControllerState]]:
+        t1, inner_state = self.controller.init(
+            terms, t0, t1, y0, dt0, args, func, error_order
+        )
+        dt_proposal = t1 - t0
+        tdtype = jnp.result_type(t0, t1)
+
+        if self.step_ts is None:
+            step_ts = None
+        else:
+            # Upcast step_ts to the same dtype as t0, t1
+            step_ts = upcast_or_raise(
+                self.step_ts,
+                jnp.zeros((), tdtype),
+                "`JumpStepWrapper.step_ts`",
+                "time (the result type of `t0`, `t1`, `dt0`, `SaveAt(ts=...)` etc.)",
+            )
+
+        if self.jump_ts is None:
+            jump_ts = None
+        else:
+            # Upcast jump_ts to the same dtype as t0, t1
+            jump_ts = upcast_or_raise(
+                self.jump_ts,
+                jnp.zeros((), tdtype),
+                "`JumpStepWrapper.jump_ts`",
+                "time (the result type of `t0`, `t1`, `dt0`, `SaveAt(ts=...)` etc.)",
+            )
+
+        if self.rejected_step_buffer_len is None:
+            rejected_buffer = None
+            i_reject = jnp.asarray(0)
+        else:
+            rejected_buffer = jnp.zeros(
+                (self.rejected_step_buffer_len,) +
jnp.shape(t1), dtype=tdtype + ) + # rejected_buffer[len(rejected_buffer)] = jnp.inf (see def of _get_t) + i_reject = jnp.asarray(self.rejected_step_buffer_len) + + # Find index of first element of step_ts/jump_ts greater than t0 + i_step = _find_index(t0, step_ts) + i_jump = _find_index(t0, jump_ts) + # Clip t1 to the next element of step_ts or jump_ts + t1, _ = _clip_ts(t0, t1, i_step, step_ts, False) + t1, jump_next_step = _clip_ts(t0, t1, i_jump, jump_ts, True) + + state = _JumpStepState( + jump_next_step, + dt_proposal, + i_step, + i_jump, + i_reject, + rejected_buffer, + step_ts, + jump_ts, + inner_state, + ) + + return t1, state + + def adapt_step_size( + self, + t0: RealScalarLike, + t1: RealScalarLike, + y0: Y, + y1_candidate: Y, + args: Args, + y_error: Optional[Y], + error_order: RealScalarLike, + controller_state: _JumpStepState[_ControllerState], + ) -> tuple[ + BoolScalarLike, + RealScalarLike, + RealScalarLike, + BoolScalarLike, + _JumpStepState[_ControllerState], + RESULTS, + ]: + ( + made_jump, + prev_dt, + i_step, + i_jump, + i_reject, + rejected_buffer, + step_ts, + jump_ts, + inner_state, + ) = controller_state.get() + + # Let the controller do its thing + ( + keep_step, + next_t0, + next_t1, + _, + inner_state, + result, + ) = self.controller.adapt_step_size( + t0, t1, y0, y1_candidate, args, y_error, error_order, inner_state + ) + + # This is just a logging utility for testing purposes + if self.callback_on_reject is not None: + jax.debug.callback(self.callback_on_reject, keep_step, t1) + + # Check whether we stepped over an element of step_ts/jump_ts/rejected_buffer + # This is all still bookkeeping for the PREVIOUS STEP. + if step_ts is not None: + # If we stepped to `t1 == step_ts[i_step]` and kept the step, then we + # increment i_step and move on to the next t in step_ts. + step_inc_cond = keep_step & (t1 == _get_t(i_step, step_ts)) + i_step = jnp.where(step_inc_cond, i_step + 1, i_step) + + if jump_ts is not None: + next_jump_t = _get_t(i_jump, jump_ts) + jump_inc_cond = keep_step & (t1 >= eqxi.prevbefore(next_jump_t)) + i_jump = jnp.where(jump_inc_cond, i_jump + 1, i_jump) + + if self.rejected_step_buffer_len is not None: + assert rejected_buffer is not None + # If the step ended at t1==rejected_buffer[i_reject], then we have + # successfully stepped to this time and we increment i_reject. + # We increment i_reject even if the step was rejected, because we will + # re-add the rejected time to the buffer immediately. + rjct_inc_cond = t1 == _get_t(i_reject, rejected_buffer) + i_reject = jnp.where(rjct_inc_cond, i_reject + 1, i_reject) + + # If the step was rejected, then we need to store the rejected time in the + # rejected buffer and decrement the rejected index. + i_reject = jnp.where(keep_step, i_reject, i_reject - 1) + i_reject = eqx.error_if( + i_reject, + i_reject < 0, + "Maximum number of rejected steps reached. " + "Consider increasing JumpStepWrapper.rejected_step_buffer_len.", + ) + clipped_i = jnp.clip(i_reject, 0, self.rejected_step_buffer_len - 1) + update_rejected_t = jnp.where(keep_step, rejected_buffer[clipped_i], t1) + rejected_buffer = rejected_buffer.at[clipped_i].set(update_rejected_t) + + # Now move on to the NEXT STEP + dt_proposal = next_t1 - next_t0 + # The following line is so that in case prev_dt was intended to be large, + # but then clipped to very small (because of step_ts or jump_ts), we don't + # want it to stick to very small steps (e.g. the PID controller can only + # increase steps by a factor of 10 at a time). 
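+        # E.g. (illustrative numbers): if `prev_dt` was 1.0 but the step got
+        # clipped to 1e-3 and was kept, the inner controller can propose at most
+        # ~`factormax * 1e-3` for the next step; taking the maximum with `prev_dt`
+        # lets the next step resume from 1.0 instead.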
+ dt_proposal = jnp.where( + keep_step, jnp.maximum(dt_proposal, prev_dt), dt_proposal + ) + new_prev_dt = dt_proposal + next_t1 = next_t0 + dt_proposal + + # If t1 hit a jump point, and the step was kept then we need to set + # `next_t0 = nextafter(nextafter(t1))` to ensure that we really skip + # over the jump and don't evaluate the vector field at the discontinuity. + if jnp.issubdtype(jnp.result_type(next_t0), jnp.inexact): + # Two nextafters. If made_jump then t1 = prevbefore(jump location) + # so now _t1 = nextafter(jump location) + # This is important because we don't know whether or not the jump is as a + # result of a left- or right-discontinuity, so we have to skip the jump + # location altogether. + jump_keep = made_jump & keep_step + next_t0 = static_select( + jump_keep, eqxi.nextafter(eqxi.nextafter(next_t0)), next_t0 + ) + + if TYPE_CHECKING: + assert isinstance( + next_t0, RealScalarLike + ), f"type(next_t0) = {type(next_t0)}" + + # Clip the step to the next element of jump_ts or step_ts or + # rejected_buffer. Important to do jump_ts last because otherwise + # jump_next_step could be a false positive. + next_t1 = _revisit_rejected(next_t0, next_t1, i_reject, rejected_buffer) + next_t1, _ = _clip_ts(next_t0, next_t1, i_step, step_ts, False) + next_t1, jump_next_step = _clip_ts(next_t0, next_t1, i_jump, jump_ts, True) + + state = _JumpStepState( + jump_next_step, + new_prev_dt, + i_step, + i_jump, + i_reject, + rejected_buffer, + step_ts, + jump_ts, + inner_state, + ) + + return keep_step, next_t0, next_t1, made_jump, state, result diff --git a/diffrax/_step_size_controller/adaptive.py b/diffrax/_step_size_controller/pid.py similarity index 75% rename from diffrax/_step_size_controller/adaptive.py rename to diffrax/_step_size_controller/pid.py index 9d181c95..58949719 100644 --- a/diffrax/_step_size_controller/adaptive.py +++ b/diffrax/_step_size_controller/pid.py @@ -1,6 +1,6 @@ import typing from collections.abc import Callable -from typing import cast, Optional, TYPE_CHECKING, TypeVar +from typing import cast, Optional, TYPE_CHECKING import equinox as eqx import equinox.internal as eqxi @@ -10,15 +10,14 @@ import jax.tree_util as jtu import lineax.internal as lxi import optimistix as optx -from jaxtyping import Real +from equinox.internal import ω if TYPE_CHECKING: - from typing import ClassVar as AbstractVar + pass else: - from equinox import AbstractVar -from equinox.internal import ω -from jaxtyping import Array, PyTree + pass +from jaxtyping import PyTree from lineax.internal import complex_to_real_dtype from .._custom_types import ( @@ -29,15 +28,30 @@ VF, Y, ) -from .._misc import static_select, upcast_or_raise from .._solution import RESULTS from .._term import AbstractTerm, ODETerm -from .base import AbstractStepSizeController +from .adaptive_base import AbstractAdaptiveStepSizeController +from .jump_step_wrapper import JumpStepWrapper ω = cast(Callable, ω) +# We use a metaclass for backwards compatibility. When a user calls +# PIDController(... step_ts=s, jump_ts=j) this should return a +# JumpStepWrapper(PIDController(...), s, j). 
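+# For example:
+#     PIDController(rtol=1e-4, atol=1e-6, step_ts=ts)
+# behaves like
+#     JumpStepWrapper(PIDController(rtol=1e-4, atol=1e-6), step_ts=ts)
+# (this equivalence is checked by `test_pid_meta` in
+# test/test_adaptive_stepsize_controller.py).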
+module_meta = type(eqx.Module) + + +class PIDMeta(module_meta): + def __call__(cls, *args, **kwargs): + step_ts = kwargs.pop("step_ts", None) + jump_ts = kwargs.pop("jump_ts", None) + if step_ts is not None or jump_ts is not None: + return JumpStepWrapper(cls(*args, **kwargs), step_ts, jump_ts) + return super().__call__(*args, **kwargs) + + def _select_initial_step( terms: PyTree[AbstractTerm], t0: RealScalarLike, @@ -89,50 +103,8 @@ def intermediate(carry): return jnp.minimum(100 * h0, h1) -_ControllerState = TypeVar("_ControllerState") -_Dt0 = TypeVar("_Dt0", None, RealScalarLike, Optional[RealScalarLike]) - - -class AbstractAdaptiveStepSizeController( - AbstractStepSizeController[_ControllerState, _Dt0] -): - """Indicates an adaptive step size controller. - - Accepts tolerances `rtol` and `atol`. When used in conjunction with an implicit - solver ([`diffrax.AbstractImplicitSolver`][]), then these tolerances will - automatically be used as the tolerances for the nonlinear solver passed to the - implicit solver, if they are not specified manually. - """ - - rtol: AbstractVar[RealScalarLike] - atol: AbstractVar[RealScalarLike] - norm: AbstractVar[Callable[[PyTree], RealScalarLike]] - - def __check_init__(self): - if self.rtol is None or self.atol is None: - raise ValueError( - "The default values for `rtol` and `atol` were removed in Diffrax " - "version 0.1.0. (As the choice of tolerance is nearly always " - "something that you, as an end user, should make an explicit choice " - "about.)\n" - "If you want to match the previous defaults then specify " - "`rtol=1e-3`, `atol=1e-6`. For example:\n" - "```\n" - "diffrax.PIDController(rtol=1e-3, atol=1e-6)\n" - "```\n" - ) - - -_PidState = tuple[ - BoolScalarLike, BoolScalarLike, RealScalarLike, RealScalarLike, RealScalarLike -] - - -def _none_or_array(x): - if x is None: - return None - else: - return jnp.asarray(x) +# _PidState = (at_dtmin, prev_inv_scaled_error, prev_prev_inv_scaled_error) +_PidState = tuple[BoolScalarLike, RealScalarLike, RealScalarLike] if TYPE_CHECKING: @@ -157,7 +129,8 @@ def __repr__(self): # TODO: we don't currently offer a limiter, or a variant accept/reject scheme, as given # in Soderlind and Wang 2006. class PIDController( - AbstractAdaptiveStepSizeController[_PidState, Optional[RealScalarLike]] + AbstractAdaptiveStepSizeController[_PidState, Optional[RealScalarLike]], + metaclass=PIDMeta, ): r"""Adapts the step size to produce a solution accurate to a given tolerance. The tolerance is calculated as `atol + rtol * y` for the evolving solution `y`. 
@@ -353,35 +326,14 @@ def dynamics(t, y, args): dtmin: Optional[RealScalarLike] = None dtmax: Optional[RealScalarLike] = None force_dtmin: bool = True - step_ts: Optional[Real[Array, " steps"]] = eqx.field( - default=None, converter=_none_or_array - ) - jump_ts: Optional[Real[Array, " jumps"]] = eqx.field( - default=None, converter=_none_or_array - ) factormin: RealScalarLike = 0.2 factormax: RealScalarLike = 10.0 norm: Callable[[PyTree], RealScalarLike] = rms_norm safety: RealScalarLike = 0.9 error_order: Optional[RealScalarLike] = None - def __check_init__(self): - if self.jump_ts is not None and not jnp.issubdtype( - self.jump_ts.dtype, jnp.inexact - ): - raise ValueError( - f"jump_ts must be floating point, not {self.jump_ts.dtype}" - ) - def wrap(self, direction: IntScalarLike): - step_ts = None if self.step_ts is None else self.step_ts * direction - jump_ts = None if self.jump_ts is None else self.jump_ts * direction - return eqx.tree_at( - lambda s: (s.step_ts, s.jump_ts), - self, - (step_ts, jump_ts), - is_leaf=lambda x: x is None, - ) + return self def init( self, @@ -450,20 +402,18 @@ def init( at_dtmin = dt0 <= self.dtmin dt0 = jnp.maximum(dt0, self.dtmin) - t1 = self._clip_step_ts(t0, t0 + dt0) - t1, jump_next_step = self._clip_jump_ts(t0, t1) + t1 = t0 + dt0 y_leaves = jtu.tree_leaves(y0) if len(y_leaves) == 0: y_dtype = lxi.default_floating_dtype() else: y_dtype = jnp.result_type(*y_leaves) + real_dtype = complex_to_real_dtype(y_dtype) return t1, ( - jump_next_step, at_dtmin, - dt0, - jnp.array(1.0, dtype=complex_to_real_dtype(y_dtype)), - jnp.array(1.0, dtype=complex_to_real_dtype(y_dtype)), + jnp.array(1.0, dtype=real_dtype), + jnp.array(1.0, dtype=real_dtype), ) def adapt_step_size( @@ -543,22 +493,12 @@ def adapt_step_size( "error estimates." ) ( - made_jump, at_dtmin, - prev_dt, prev_inv_scaled_error, prev_prev_inv_scaled_error, ) = controller_state error_order = self._get_error_order(error_order) - # t1 - t0 is the step we actually took, so that's usually what we mean by the - # "previous dt". - # However if we made a jump then this t1 was clipped relatively to what it - # could have been, so for guessing the next step size it's probably better to - # use the size the step would have been, had there been no jump. - # There are cases in which something besides the step size controller modifies - # the step locations t0, t1; most notably the main integration routine clipping - # steps when we're right at the end of the interval. - prev_dt = jnp.where(made_jump, prev_dt, t1 - t0) + prev_dt = t1 - t0 # # Figure out how things went on the last step: error, and whether to @@ -576,7 +516,9 @@ def _scale(_y0, _y1_candidate, _y_error): scaled_error = self.norm(jtu.tree_map(_scale, y0, y1_candidate, y_error)) keep_step = scaled_error < 1 + # Automatically keep the step if we're at dtmin. if self.dtmin is not None: + at_dtmin = at_dtmin | (prev_dt <= self.dtmin) keep_step = keep_step | at_dtmin # Make sure it's not a Python scalar and thus getting a ZeroDivisionError. inv_scaled_error = 1 / jnp.asarray(scaled_error) @@ -600,10 +542,12 @@ def _scale(_y0, _y1_candidate, _y_error): factor2 = 1 if _zero_coeff(coeff2) else prev_inv_scaled_error**coeff2 factor3 = 1 if _zero_coeff(coeff3) else prev_prev_inv_scaled_error**coeff3 factormin = jnp.where(keep_step, 1, self.factormin) + # If the step is not kept, next step must be smaller, so factor must be <1. 
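+        # (With the default `safety=0.9`, the clip below therefore bounds `factor`
+        # by `safety < 1`, so a rejected step always shrinks.)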
+ factormax = jnp.where(keep_step, self.factormax, self.safety) factor = jnp.clip( self.safety * factor1 * factor2 * factor3, min=factormin, - max=self.factormax, + max=factormax, ) # Once again, see above. In case we have gradients on {i,p,d}coeff. # (Probably quite common for them to have zero tangents if passed across @@ -634,35 +578,20 @@ def _scale(_y0, _y1_candidate, _y_error): at_dtmin = dt <= self.dtmin dt = jnp.maximum(dt, self.dtmin) - # - # Clip next step size based on step_ts/jump_ts - # - - if jnp.issubdtype(jnp.result_type(t1), jnp.inexact): - # Two nextafters. If made_jump then t1 = prevbefore(jump location) - # so now _t1 = nextafter(jump location) - # This is important because we don't know whether or not the jump is as a - # result of a left- or right-discontinuity, so we have to skip the jump - # location altogether. - _t1 = static_select(made_jump, eqxi.nextafter(eqxi.nextafter(t1)), t1) - else: - _t1 = t1 - next_t0 = jnp.where(keep_step, _t1, t0) - next_t1 = self._clip_step_ts(next_t0, next_t0 + dt) - next_t1, next_made_jump = self._clip_jump_ts(next_t0, next_t1) + next_t0 = jnp.where(keep_step, t1, t0) + next_t1 = next_t0 + dt inv_scaled_error = jnp.where(keep_step, inv_scaled_error, prev_inv_scaled_error) prev_inv_scaled_error = jnp.where( keep_step, prev_inv_scaled_error, prev_prev_inv_scaled_error ) controller_state = ( - next_made_jump, at_dtmin, - dt, inv_scaled_error, prev_inv_scaled_error, ) - return keep_step, next_t0, next_t1, made_jump, controller_state, result + # made_jump is handled by JumpStepWrapper, so we automatically set it to False + return keep_step, next_t0, next_t1, False, controller_state, result def _get_error_order(self, error_order: Optional[RealScalarLike]) -> RealScalarLike: # Attribute takes priority, if the user knows the correct error order better @@ -677,76 +606,6 @@ def _get_error_order(self, error_order: Optional[RealScalarLike]) -> RealScalarL ) return error_order - def _clip_step_ts(self, t0: RealScalarLike, t1: RealScalarLike) -> RealScalarLike: - if self.step_ts is None: - return t1 - - step_ts0 = upcast_or_raise( - self.step_ts, - t0, - "`PIDController.step_ts`", - "time (the result type of `t0`, `t1`, `dt0`, `SaveAt(ts=...)` etc.)", - ) - step_ts1 = upcast_or_raise( - self.step_ts, - t1, - "`PIDController.step_ts`", - "time (the result type of `t0`, `t1`, `dt0`, `SaveAt(ts=...)` etc.)", - ) - # TODO: it should be possible to switch this O(nlogn) for just O(n) by keeping - # track of where we were last, and using that as a hint for the next search. - t0_index = jnp.searchsorted(step_ts0, t0, side="right") - t1_index = jnp.searchsorted(step_ts1, t1, side="right") - # This minimum may or may not actually be necessary. The left branch is taken - # iff t0_index < t1_index <= len(self.step_ts), so all valid t0_index s must - # already satisfy the minimum. - # However, that branch is actually executed unconditionally and then where'd, - # so we clamp it just to be sure we're not hitting undefined behaviour. - t1 = jnp.where( - t0_index < t1_index, - step_ts1[jnp.minimum(t0_index, len(self.step_ts) - 1)], - t1, - ) - return t1 - - def _clip_jump_ts( - self, t0: RealScalarLike, t1: RealScalarLike - ) -> tuple[RealScalarLike, BoolScalarLike]: - if self.jump_ts is None: - return t1, False - assert jnp.issubdtype(self.jump_ts.dtype, jnp.inexact) - if not jnp.issubdtype(jnp.result_type(t0), jnp.inexact): - raise ValueError( - "`t0`, `t1`, `dt0` must be floating point when specifying `jump_ts`. " - f"Got {jnp.result_type(t0)}." 
- ) - if not jnp.issubdtype(jnp.result_type(t1), jnp.inexact): - raise ValueError( - "`t0`, `t1`, `dt0` must be floating point when specifying `jump_ts`. " - f"Got {jnp.result_type(t1)}." - ) - jump_ts0 = upcast_or_raise( - self.jump_ts, - t0, - "`PIDController.jump_ts`", - "time (the result type of `t0`, `t1`, `dt0`, `SaveAt(ts=...)` etc.)", - ) - jump_ts1 = upcast_or_raise( - self.jump_ts, - t1, - "`PIDController.jump_ts`", - "time (the result type of `t0`, `t1`, `dt0`, `SaveAt(ts=...)` etc.)", - ) - t0_index = jnp.searchsorted(jump_ts0, t0, side="right") - t1_index = jnp.searchsorted(jump_ts1, t1, side="right") - next_made_jump = t0_index < t1_index - t1 = jnp.where( - next_made_jump, - eqxi.prevbefore(jump_ts1[jnp.minimum(t0_index, len(self.jump_ts) - 1)]), - t1, - ) - return t1, next_made_jump - PIDController.__init__.__doc__ = """**Arguments:** @@ -761,10 +620,6 @@ def _clip_jump_ts( - `force_dtmin`: How to handle the step size hitting the minimum. If `True` then the step size is clipped to `dtmin`. If `False` then the differential equation solve halts with an error. -- `step_ts`: Denotes extra times that must be stepped to. -- `jump_ts`: Denotes extra times that must be stepped to, and at which the vector field - has a known discontinuity. (This is used to force FSAL solvers so re-evaluate the - vector field.) - `factormin`: Minimum amount a step size can be decreased relative to the previous step. - `factormax`: Maximum amount a step size can be increased relative to the previous diff --git a/test/test_adaptive_stepsize_controller.py b/test/test_adaptive_stepsize_controller.py index 4cc996c8..1a3c9b33 100644 --- a/test/test_adaptive_stepsize_controller.py +++ b/test/test_adaptive_stepsize_controller.py @@ -4,7 +4,9 @@ import equinox as eqx import jax import jax.numpy as jnp +import jax.random as jr import jax.tree_util as jtu +import pytest from jaxtyping import Array from .helpers import tree_allclose @@ -17,7 +19,8 @@ def test_step_ts(): t1 = 5 dt0 = None y0 = 1.0 - stepsize_controller = diffrax.PIDController(rtol=1e-4, atol=1e-6, step_ts=[3, 4]) + pid_controller = diffrax.PIDController(rtol=1e-4, atol=1e-6) + stepsize_controller = diffrax.JumpStepWrapper(pid_controller, step_ts=[3, 4]) saveat = diffrax.SaveAt(steps=True) sol = diffrax.diffeqsolve( term, @@ -50,7 +53,8 @@ def vector_field(t, y, args): saveat = diffrax.SaveAt(steps=True) def run(**kwargs): - stepsize_controller = diffrax.PIDController(rtol=1e-4, atol=1e-6, **kwargs) + pid_controller = diffrax.PIDController(rtol=1e-4, atol=1e-6) + stepsize_controller = diffrax.JumpStepWrapper(pid_controller, **kwargs) return diffrax.diffeqsolve( term, solver, @@ -75,13 +79,66 @@ def run(**kwargs): assert 8 in cast(Array, sol.ts) -def test_backprop(): +def test_revisit_steps(): + t0 = 0 + t1 = 5 + dt0 = 0.5 + y0 = 1.0 + drift = diffrax.ODETerm(lambda t, y, args: -0.2 * y) + + def diffusion_vf(t, y, args): + return jnp.ones((), dtype=y.dtype) + + bm = diffrax.VirtualBrownianTree(t0, t1, 2**-8, (), jr.key(0)) + diffusion = diffrax.ControlTerm(diffusion_vf, bm) + term = diffrax.MultiTerm(drift, diffusion) + solver = diffrax.Heun() + pid_controller = diffrax.PIDController( + rtol=0, atol=1e-3, dtmin=2**-7, pcoeff=0.5, icoeff=0.8 + ) + + rejected_ts = [] + + def callback_fun(keep_step, t1): + if not keep_step: + rejected_ts.append(t1) + + stepsize_controller = diffrax.JumpStepWrapper( + pid_controller, + step_ts=[3, 4], + rejected_step_buffer_len=10, + _callback_on_reject=callback_fun, + ) + saveat = diffrax.SaveAt(steps=True) + sol = 
diffrax.diffeqsolve( + term, + solver, + t0, + t1, + dt0, + y0, + stepsize_controller=stepsize_controller, + saveat=saveat, + ) + + assert len(rejected_ts) > 10 + # check if all rejected ts are in the array sol.ts + assert all([t in sol.ts for t in rejected_ts]) + assert 3 in cast(Array, sol.ts) + assert 4 in cast(Array, sol.ts) + + +@pytest.mark.parametrize("use_jump_step", [True, False]) +def test_backprop(use_jump_step): + t0 = jnp.asarray(0, dtype=jnp.float64) + t1 = jnp.asarray(1, dtype=jnp.float64) + @eqx.filter_jit @eqx.filter_grad def run(ys, controller, state): y0, y1_candidate, y_error = ys _, tprev, tnext, _, state, _ = controller.adapt_step_size( - 0, 1, y0, y1_candidate, None, y_error, 5, state + t0, t1, y0, y1_candidate, None, y_error, 5, state ) with jax.numpy_dtype_promotion("standard"): return tprev + tnext + sum(jnp.sum(x) for x in jtu.tree_leaves(state)) @@ -90,12 +147,16 @@ def run(ys, controller, state): y1_candidate = jnp.array(2.0) term = diffrax.ODETerm(lambda t, y, args: -y) solver = diffrax.Tsit5() - stepsize_controller = diffrax.PIDController(rtol=1e-4, atol=1e-4) - _, state = stepsize_controller.init(term, 0, 1, y0, 0.1, None, solver.func, 5) + controller = diffrax.PIDController(rtol=1e-4, atol=1e-4) + if use_jump_step: + controller = diffrax.JumpStepWrapper( + controller, step_ts=[0.5], rejected_step_buffer_len=20 + ) + _, state = controller.init(term, t0, t1, y0, 0.1, None, solver.func, 5) for y_error in (jnp.array(0.0), jnp.array(3.0), jnp.array(jnp.inf)): ys = (y0, y1_candidate, y_error) - grads = run(ys, stepsize_controller, state) + grads = run(ys, controller, state) assert not any(jnp.isnan(grad).any() for grad in grads) @@ -113,9 +174,11 @@ def run(t): t1 = 1 dt0 = None y0 = 1.0 - stepsize_controller = diffrax.PIDController( - rtol=1e-8, atol=1e-8, step_ts=t[None] + pid_controller = diffrax.PIDController( + rtol=1e-8, + atol=1e-8, ) + stepsize_controller = diffrax.JumpStepWrapper(pid_controller, step_ts=t[None]) def forcing(s): return jnp.where(s < t, 0, 1) @@ -139,3 +202,17 @@ def forcing(s): finite_diff = (r(0.5) - r(0.5 - eps)) / eps autodiff = jax.jit(jax.grad(run))(0.5) assert tree_allclose(finite_diff, autodiff) + + +def test_pid_meta(): + ts = jnp.array([3, 4], dtype=jnp.float64) + pid1 = diffrax.PIDController(rtol=1e-4, atol=1e-6) + pid2 = diffrax.PIDController(rtol=1e-4, atol=1e-6, step_ts=ts) + pid3 = diffrax.PIDController(rtol=1e-4, atol=1e-6, step_ts=ts, jump_ts=ts) + assert not isinstance(pid1, diffrax.JumpStepWrapper) + assert isinstance(pid1, diffrax.PIDController) + assert isinstance(pid2, diffrax.JumpStepWrapper) + assert isinstance(pid3, diffrax.JumpStepWrapper) + assert all(pid2.step_ts == ts) + assert all(pid3.step_ts == ts) + assert all(pid3.jump_ts == ts) diff --git a/test/test_progress_meter.py b/test/test_progress_meter.py index 76028169..b7d5c4c8 100644 --- a/test/test_progress_meter.py +++ b/test/test_progress_meter.py @@ -39,7 +39,7 @@ def solve(t0): err = captured.err.strip() assert re.match("0.00%|[ ]+|", err.split("\r", 1)[0]) assert re.match("100.00%|█+|", err.rsplit("\r", 1)[1]) - assert captured.err.count("\r") == num_lines + assert captured.err.count("\r") - num_lines in [0, 1] assert captured.err.count("\n") == 1