Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added JumpStepWrapper #484

Open
wants to merge 5 commits into
base: main
Choose a base branch
from

Conversation

andyElking
Copy link
Contributor

Hi Patrick,

I factored the jump_ts and step_ts out of the PIDController into JumpStepWrapper (I'm not very set on this name, lmk if you have ideas). I also made it behave as we discussed in #483. In particular, the following three rules are maintained:

  1. We always have t1-t0 <= prev_dt (this is checked via eqx.error_if), with inequality only if the step was clipped or if we hit the end of the integration interval (we do not explicitly check for that).
  2. If the step was accepted, then next_dt must be >=prev_dt.
  3. If the step was rejected, then next_dt must be < t1-t0.

We achieve this in a very simple way here:

dt_proposal = next_t1 - next_t0
dt_proposal = jnp.where(
keep_step, jnp.maximum(dt_proposal, prev_dt), dt_proposal
)
new_prev_dt = dt_proposal

The next step is to add a parameter JumpStepWrapper.revisit_rejected_steps which does what you expect. That will appear in a future commit in this same PR.

@andyElking
Copy link
Contributor Author

I now also added the functionality to revisit rejected steps. In addition, I also imporved the runtime of step_ts and jump_ts, because the controller no longer searches the whole array each time, but keeps an index of where in the array it was previously.

Also I think there was a bug in the PID controller, where it would sometimes reject a step, but have factor>1. To remedy this I modified the following:

factormax = jnp.where(keep_step, self.factormax, self.safety)
factor = jnp.clip(
self.safety * factor1 * factor2 * factor3,
min=factormin,
max=factormax,
)

I think possibly something smaller than just self.safety would make even more sense, I feel like if a step is rejected the next step should be at least 0.5x smaller. But I'm not an expert.

I added a test for revisiting steps and it all seems to work. I also sprinkled in a bunch of eqx.error_if statements to make sure the necessary invariants are always maintained. But this is a bit experimental, so maybe there are some bugs I didn't test for.

I think I commented the code quite well, so hopefully you can easily notice if I made a mistake somewhere.

P.S.: Sorry for bombarding you with PRs. As far as I'm concerned this one is very low priority, I can use the code even if it isn't merged into diffrax proper.

@andyElking
Copy link
Contributor Author

Hi @patrick-kidger,
I got rid of some eqx.error_ifs that I added to my JumpStepWrapper and redid the timing benchmarks. My new implementation was already faster than the old PIDController before, but now this is way more significant, especially when step_ts is long (think >100). Surprisingly, it is faster even when it has to revisit rejected steps. See

# ======= RESULTS =======
# New controller: 0.22829 s, Old controller: 0.31039 s
# Revisiting controller: 0.23212 s

Copy link
Owner

@patrick-kidger patrick-kidger left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, quick first pass at a review!

diffrax/_step_size_controller/adaptive.py Show resolved Hide resolved
self.controller = controller
self.step_ts = _none_or_array(step_ts)
self.jump_ts = _none_or_array(jump_ts)
self.rejected_step_buffer_len = rejected_step_buffer_len
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we default this to max_steps, do you think? (Maybe it is set as None?)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In practice it rarely needs to be more than 10, so max_steps would be way overkill in my opinion. But I agree, using None to turn it off is more transparent.

Copy link
Owner

@patrick-kidger patrick-kidger Nov 3, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yup, but this is just a buffer of scalar times, so even something of size max_step should be cheap. I'm primarily concerned with user convenience here -- as high as max_steps isn't a strong feeling, but I think it definitely needs to larger than 10 -- I've definitely seen that many rejected steps in a row before at the start of a solve. Maybe 100 instead?

(EDIT: see below on buffer size.)

diffrax/_step_size_controller/jump_step_wrapper.py Outdated Show resolved Hide resolved
diffrax/_step_size_controller/jump_step_wrapper.py Outdated Show resolved Hide resolved
diffrax/_step_size_controller/jump_step_wrapper.py Outdated Show resolved Hide resolved
diffrax/_step_size_controller/jump_step_wrapper.py Outdated Show resolved Hide resolved
Comment on lines 352 to 357
i_rjct = eqx.error_if(
i_rjct,
i_rjct < 0,
"Maximum number of rejected steps reached. "
"Consider increasing JumpStepWrapper.rejected_step_buffer_len.",
)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And in particular we should be able to skip this check if we set this value to max_steps.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll run some timing tests to see if a long buffer significantly hinders performance. If not, I'm happy to change it to max_steps.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So it seems that using a long buffer doesn't impact performance at all. So I think it could be fine to just make a binary parameter revisit_rejected and if it is True, then the buffer length is max_steps. However when vmaping over this that could create a significant but unnecessary memory overhead, so maybe giving the option to change this buffer length might be reasonable. What do you think?

Copy link
Owner

@patrick-kidger patrick-kidger Nov 3, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, good point about the vmap case.

(On buffer length: see below.)

"Maximum number of rejected steps reached. "
"Consider increasing JumpStepWrapper.rejected_step_buffer_len.",
)
rjct_buff = jnp.where(keep_step, rjct_buff, rjct_buff.at[i_rjct].set(t1))
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that this is a very expensive way to describe this operation! You're copying the whole buffer. XLA will sometimes optimize this out -- because I added that optimization to it! -- but not always.

Better is to do rjct_buff.at[i_rjct].set(jnp.where(keep_step, rjct_buff[i_rjct], t1))

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Other than that, I think we may need to extend the API here slightly -- we should be able to mark state like this as being a buffer for the purposes of:

final_state = outer_while_loop(

which is needed to avoid spurious copies during backpropagation.

(You can see that both of these comments are basically us having to work around limitations of the XLA compiler.)

Copy link
Contributor Author

@andyElking andyElking Aug 18, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The first one makes sense, I should have seen this.

I don't really know what you want me to do in your second comment. And frankly diffeqsolve is something I haven't even started digesting yet. Are you telling me rejected_buffer should be one of the outer_buffers, meaning that I should make it an instance of SaveState or sth like that? I would apprecaite a bit more guidance.

Also damn how you managed to write all this code is beyond me. Even trying to begin understanding it seems a lot! Very impressive!

Copy link
Owner

@patrick-kidger patrick-kidger Nov 3, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Haha, you're too kind!

As for my second comment -- I think I've realised that I was wrong. Let me explain. Our backpropagation involves saving copies of our state in checkpoints. Let's suppose we set RecursiveCheckpointAdjoint(checkpoints=max_steps), so that's O(max_steps) memory right? Well, not quite: our updating buffer here is potentially of length max_steps (as per the debate above), and we're saving a copy of it in every checkpont, so we'd actually be using O(max_steps^2) memory! That's not acceptable.

The simple solution to this will just be to set the size of this buffer to e.g. 100 by default, and just allow those copies to be made. And given the behaviour you have here -- in which you potentially overwrite values -- then that is actually what's necessary as well.

As for the complicated solution that I was wrong about: let's consider the case of SaveAt(steps=True). This also involves a buffer of length max_steps, that we save into as we go along. Fortunately, this one has a useful extra property, which is that we never overwrite a value. That means we don't actually need to copy our buffer for every checkpoint! We can use a single buffer that is shared across all checkpoints, getting gradually filled in. To support this case then we actually have a special argument eqxi.while_loop(..., buffers=...), to declare which elements of our loop state have this behaviour. Unfortunately that's not the case here because we do overwrite the values. (And side-note the presence of this buffers parameter is the reason I've not made this public API in Equinox, because the buffer-ness is completely unchecked and it's very easy to shoot yourself in the foot.)

Copy link
Contributor Author

@andyElking andyElking Nov 3, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, I see. Thanks for the in-depth explanation! So let's see if I understand this correctly. If this was not getting rewritten, then I should make it register as a buffer in the outer while loop of diffeqsolve. But, because it does get rewritten, I should not do that(??). Still, I am curious, if I did want to register it as a buffer, how would I accomplish that? Is it indeed by making it an instance of SaveState, or is it something else entirely?

Other than that, should I keep it an Optional[Int] and just add something like this to the docstring:

For most SDEs, setting this to `100` should be sufficient, but if more consecutive steps are rejected, then an error will be raised.

?

@andyElking
Copy link
Contributor Author

Thanks for the review! I made all the edits I could and I left some comments where I need guidance (no hurry though, this is not high priority for me). Also, should I get rid of prev_dt entirely, as you suggested in #483?

@patrick-kidger
Copy link
Owner

Also, should I get rid of prev_dt entirely, as you suggested in #483?

If it's easy to do that in a separate commit afterwards then I would say yes. A separate commit just so it's easy to revert if it turns out we were wrong about something here :D

Copy link
Owner

@patrick-kidger patrick-kidger left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, I'm really sorry for taking so long to get around to this one! Some other work projects got in the way for a bit. (But on the plus side I have a few more open source projects in the pipe, keep an eye out for those ;) ) This is a really useful PR that I very much want to see in.

I've just done another revivew, LMK what you think!

jump_ts: Optional[Array]
inner_state: _ControllerState

def get(self):
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd argue that this kind of dataclass-to-tuple conversion is an antipattern. One of the points of having a dataclass over a tuple is that you can do myclass.some_attribute to unambiguously get the appropriate attribute -- without needing to implicitly rely on the correct unpacking order when doing ..., some_attribute, ... = myclass as you do in your callsite of this .get().

That said not a strong feeling here, just a general for-your-own-learning kind of comment.


def _get_t(i: IntScalarLike, ts: Array) -> RealScalarLike:
i_min_len = jnp.minimum(i, len(ts) - 1)
return jnp.where(i == len(ts), jnp.inf, ts[i_min_len])
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Given the inf here -- can you add a test for using this with a backward solve with t0 > t1? Just to make sure that we're correctly handling that case.

@@ -39,7 +39,7 @@ def solve(t0):
err = captured.err.strip()
assert re.match("0.00%|[ ]+|", err.split("\r", 1)[0])
assert re.match("100.00%|█+|", err.rsplit("\r", 1)[1])
assert captured.err.count("\r") == num_lines
assert captured.err.count("\r") - num_lines in [0, 1]
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this should be fixed in recent versions of dev. (I recall having to fix an issue here recently that came from a JAX version bump.)



if TYPE_CHECKING:
from typing import ClassVar as AbstractVar
pass
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can remove this TYPE_CHECKING block I think.

self.step_ts = _none_or_array(step_ts)
self.jump_ts = _none_or_array(jump_ts)
if rejected_step_buffer_len is not None:
assert rejected_step_buffer_len > 0
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

assert -> ValueError as it's a user error not a developer error.

# Upcast step_ts to the same dtype as t0, t1
step_ts = upcast_or_raise(
self.step_ts,
jnp.zeros((), tdtype),
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's adjust upcast_or_raise to accept a dtype here?

keep_step,
next_t0,
next_t1,
_,
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think discarding here is correct. We should do the right thing even if we have a doubly-nested JumpStepWrapper(JumpStepWraper(PIDController(...), ...), ...).

if step_ts is not None:
# If we stepped to `t1 == step_ts[i_step]` and kept the step, then we
# increment i_step and move on to the next t in step_ts.
step_inc_cond = keep_step & (t1 == _get_t(i_step, step_ts))
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think I'm comfortable with this == between floating point numbers.

More generally speaking I think there could be cases in which the times being passed here do not perfectly align with the times that the adaptive step size controller suggested on the previous step (e.g. because of further wrapping of the step size controller), so I think this kind of logic is wrong anyway. I think you need something more like the jump_ts branch below, where you just want to snap i_step to the correct value. (Nothing that the correct value should be determinable statically, we only have state here for efficiency purposes.)

Comment on lines +348 to +350
next_jump_t = _get_t(i_jump, jump_ts)
jump_inc_cond = keep_step & (t1 >= eqxi.prevbefore(next_jump_t))
i_jump = jnp.where(jump_inc_cond, i_jump + 1, i_jump)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Likewise, if someone else is fiddling with the times then this seems to me like it might be fragile.

Recalling the previous implementation with its inefficient use of searchsorted. I think the robust approach here might be to write something with the same API as that, but whose implementation is just a simple linear search forwards or backwards from the current position (which is a 'hint' about where to start searching).

Most of the time that will just iterate once and be done, as here. But in the edge cases it should now do the right thing.

@andyElking
Copy link
Contributor Author

Thanks for the review, Patrick! I'll probably make the fixes sometime in the coming week. I am also making progress on the ML examples for the Single-seed paper, but it is slower now, due to my internship.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants