-
-
Notifications
You must be signed in to change notification settings - Fork 129
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Added JumpStepWrapper #484
base: main
Are you sure you want to change the base?
Conversation
78b122a
to
0eac356
Compare
I now also added the functionality to revisit rejected steps. In addition, I also imporved the runtime of Also I think there was a bug in the PID controller, where it would sometimes reject a step, but have diffrax/diffrax/_step_size_controller/adaptive.py Lines 569 to 574 in 501bed5
I think possibly something smaller than just self.safety would make even more sense, I feel like if a step is rejected the next step should be at least 0.5x smaller. But I'm not an expert.
I added a test for revisiting steps and it all seems to work. I also sprinkled in a bunch of I think I commented the code quite well, so hopefully you can easily notice if I made a mistake somewhere. P.S.: Sorry for bombarding you with PRs. As far as I'm concerned this one is very low priority, I can use the code even if it isn't merged into diffrax proper. |
501bed5
to
d022ac1
Compare
d022ac1
to
4702380
Compare
Hi @patrick-kidger, diffrax/benchmarks/jump_step_timing.py Lines 126 to 128 in 345e23a
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Okay, quick first pass at a review!
self.controller = controller | ||
self.step_ts = _none_or_array(step_ts) | ||
self.jump_ts = _none_or_array(jump_ts) | ||
self.rejected_step_buffer_len = rejected_step_buffer_len |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we default this to max_steps
, do you think? (Maybe it is set as None
?)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In practice it rarely needs to be more than 10, so max_steps
would be way overkill in my opinion. But I agree, using None
to turn it off is more transparent.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yup, but this is just a buffer of scalar times, so even something of size max_step
should be cheap. I'm primarily concerned with user convenience here -- as high as max_steps
isn't a strong feeling, but I think it definitely needs to larger than 10 -- I've definitely seen that many rejected steps in a row before at the start of a solve. Maybe 100 instead?
(EDIT: see below on buffer size.)
i_rjct = eqx.error_if( | ||
i_rjct, | ||
i_rjct < 0, | ||
"Maximum number of rejected steps reached. " | ||
"Consider increasing JumpStepWrapper.rejected_step_buffer_len.", | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
And in particular we should be able to skip this check if we set this value to max_steps
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'll run some timing tests to see if a long buffer significantly hinders performance. If not, I'm happy to change it to max_steps
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So it seems that using a long buffer doesn't impact performance at all. So I think it could be fine to just make a binary parameter revisit_rejected
and if it is True
, then the buffer length is max_steps
. However when vmaping over this that could create a significant but unnecessary memory overhead, so maybe giving the option to change this buffer length might be reasonable. What do you think?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah, good point about the vmap case.
(On buffer length: see below.)
"Maximum number of rejected steps reached. " | ||
"Consider increasing JumpStepWrapper.rejected_step_buffer_len.", | ||
) | ||
rjct_buff = jnp.where(keep_step, rjct_buff, rjct_buff.at[i_rjct].set(t1)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note that this is a very expensive way to describe this operation! You're copying the whole buffer. XLA will sometimes optimize this out -- because I added that optimization to it! -- but not always.
Better is to do rjct_buff.at[i_rjct].set(jnp.where(keep_step, rjct_buff[i_rjct], t1))
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Other than that, I think we may need to extend the API here slightly -- we should be able to mark state like this as being a buffer for the purposes of:
Line 621 in 0679807
final_state = outer_while_loop( |
which is needed to avoid spurious copies during backpropagation.
(You can see that both of these comments are basically us having to work around limitations of the XLA compiler.)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The first one makes sense, I should have seen this.
I don't really know what you want me to do in your second comment. And frankly diffeqsolve
is something I haven't even started digesting yet. Are you telling me rejected_buffer
should be one of the outer_buffers
, meaning that I should make it an instance of SaveState
or sth like that? I would apprecaite a bit more guidance.
Also damn how you managed to write all this code is beyond me. Even trying to begin understanding it seems a lot! Very impressive!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Haha, you're too kind!
As for my second comment -- I think I've realised that I was wrong. Let me explain. Our backpropagation involves saving copies of our state in checkpoints. Let's suppose we set RecursiveCheckpointAdjoint(checkpoints=max_steps)
, so that's O(max_steps)
memory right? Well, not quite: our updating buffer here is potentially of length max_steps
(as per the debate above), and we're saving a copy of it in every checkpont, so we'd actually be using O(max_steps^2)
memory! That's not acceptable.
The simple solution to this will just be to set the size of this buffer to e.g. 100 by default, and just allow those copies to be made. And given the behaviour you have here -- in which you potentially overwrite values -- then that is actually what's necessary as well.
As for the complicated solution that I was wrong about: let's consider the case of SaveAt(steps=True)
. This also involves a buffer of length max_steps
, that we save into as we go along. Fortunately, this one has a useful extra property, which is that we never overwrite a value. That means we don't actually need to copy our buffer for every checkpoint! We can use a single buffer that is shared across all checkpoints, getting gradually filled in. To support this case then we actually have a special argument eqxi.while_loop(..., buffers=...)
, to declare which elements of our loop state have this behaviour. Unfortunately that's not the case here because we do overwrite the values. (And side-note the presence of this buffers
parameter is the reason I've not made this public API in Equinox, because the buffer-ness is completely unchecked and it's very easy to shoot yourself in the foot.)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Okay, I see. Thanks for the in-depth explanation! So let's see if I understand this correctly. If this was not getting rewritten, then I should make it register as a buffer in the outer while loop of diffeqsolve. But, because it does get rewritten, I should not do that(??). Still, I am curious, if I did want to register it as a buffer, how would I accomplish that? Is it indeed by making it an instance of SaveState
, or is it something else entirely?
Other than that, should I keep it an Optional[Int]
and just add something like this to the docstring:
For most SDEs, setting this to `100` should be sufficient, but if more consecutive steps are rejected, then an error will be raised.
?
Thanks for the review! I made all the edits I could and I left some comments where I need guidance (no hurry though, this is not high priority for me). Also, should I get rid of |
0050fa2
to
c3c4dcf
Compare
If it's easy to do that in a separate commit afterwards then I would say yes. A separate commit just so it's easy to revert if it turns out we were wrong about something here :D |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Okay, I'm really sorry for taking so long to get around to this one! Some other work projects got in the way for a bit. (But on the plus side I have a few more open source projects in the pipe, keep an eye out for those ;) ) This is a really useful PR that I very much want to see in.
I've just done another revivew, LMK what you think!
jump_ts: Optional[Array] | ||
inner_state: _ControllerState | ||
|
||
def get(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'd argue that this kind of dataclass-to-tuple conversion is an antipattern. One of the points of having a dataclass over a tuple is that you can do myclass.some_attribute
to unambiguously get the appropriate attribute -- without needing to implicitly rely on the correct unpacking order when doing ..., some_attribute, ... = myclass
as you do in your callsite of this .get()
.
That said not a strong feeling here, just a general for-your-own-learning kind of comment.
|
||
def _get_t(i: IntScalarLike, ts: Array) -> RealScalarLike: | ||
i_min_len = jnp.minimum(i, len(ts) - 1) | ||
return jnp.where(i == len(ts), jnp.inf, ts[i_min_len]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Given the inf
here -- can you add a test for using this with a backward solve with t0 > t1
? Just to make sure that we're correctly handling that case.
@@ -39,7 +39,7 @@ def solve(t0): | |||
err = captured.err.strip() | |||
assert re.match("0.00%|[ ]+|", err.split("\r", 1)[0]) | |||
assert re.match("100.00%|█+|", err.rsplit("\r", 1)[1]) | |||
assert captured.err.count("\r") == num_lines | |||
assert captured.err.count("\r") - num_lines in [0, 1] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this should be fixed in recent versions of dev
. (I recall having to fix an issue here recently that came from a JAX version bump.)
|
||
|
||
if TYPE_CHECKING: | ||
from typing import ClassVar as AbstractVar | ||
pass |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You can remove this TYPE_CHECKING
block I think.
self.step_ts = _none_or_array(step_ts) | ||
self.jump_ts = _none_or_array(jump_ts) | ||
if rejected_step_buffer_len is not None: | ||
assert rejected_step_buffer_len > 0 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
assert
-> ValueError
as it's a user error not a developer error.
# Upcast step_ts to the same dtype as t0, t1 | ||
step_ts = upcast_or_raise( | ||
self.step_ts, | ||
jnp.zeros((), tdtype), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's adjust upcast_or_raise
to accept a dtype here?
keep_step, | ||
next_t0, | ||
next_t1, | ||
_, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think discarding here is correct. We should do the right thing even if we have a doubly-nested JumpStepWrapper(JumpStepWraper(PIDController(...), ...), ...)
.
if step_ts is not None: | ||
# If we stepped to `t1 == step_ts[i_step]` and kept the step, then we | ||
# increment i_step and move on to the next t in step_ts. | ||
step_inc_cond = keep_step & (t1 == _get_t(i_step, step_ts)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think I'm comfortable with this ==
between floating point numbers.
More generally speaking I think there could be cases in which the times being passed here do not perfectly align with the times that the adaptive step size controller suggested on the previous step (e.g. because of further wrapping of the step size controller), so I think this kind of logic is wrong anyway. I think you need something more like the jump_ts
branch below, where you just want to snap i_step
to the correct value. (Nothing that the correct value should be determinable statically, we only have state here for efficiency purposes.)
next_jump_t = _get_t(i_jump, jump_ts) | ||
jump_inc_cond = keep_step & (t1 >= eqxi.prevbefore(next_jump_t)) | ||
i_jump = jnp.where(jump_inc_cond, i_jump + 1, i_jump) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Likewise, if someone else is fiddling with the times then this seems to me like it might be fragile.
Recalling the previous implementation with its inefficient use of searchsorted. I think the robust approach here might be to write something with the same API as that, but whose implementation is just a simple linear search forwards or backwards from the current position (which is a 'hint' about where to start searching).
Most of the time that will just iterate once and be done, as here. But in the edge cases it should now do the right thing.
Thanks for the review, Patrick! I'll probably make the fixes sometime in the coming week. I am also making progress on the ML examples for the Single-seed paper, but it is slower now, due to my internship. |
Hi Patrick,
I factored the
jump_ts
andstep_ts
out of thePIDController
intoJumpStepWrapper
(I'm not very set on this name, lmk if you have ideas). I also made it behave as we discussed in #483. In particular, the following three rules are maintained:t1-t0 <= prev_dt
(this is checked viaeqx.error_if
), with inequality only if the step was clipped or if we hit the end of the integration interval (we do not explicitly check for that).next_dt
must be>=prev_dt
.next_dt
must be< t1-t0
.We achieve this in a very simple way here:
diffrax/diffrax/_step_size_controller/jump_step_wrapper.py
Lines 119 to 123 in 78b122a
The next step is to add a parameter
JumpStepWrapper.revisit_rejected_steps
which does what you expect. That will appear in a future commit in this same PR.