Matching the performance of `jax.lax.scan` in adjoint calculation #274
Comments
First of all, thank you for providing such a careful benchmark. Second, sorry for taking so long to get back to you -- tackling this turned out to be an interesting problem, and it took longer than I thought! (It also turned up two XLA bugs along the way: jax-ml/jax#16663, jax-ml/jax#16661.) Anyway, I'm happy to say that as of #276, the performance is now much closer. On my V100:

Note that I did make one change to the benchmark to ensure a fair comparison: I switched Diffrax's

As for the changes that I made: most of the overhead turned out to be due to the extra complexity of the recursive checkpointing (as opposed to simply checkpointing on every step). The relevant changes are in patrick-kidger/equinox#415. This improvement will appear in the next release of Diffrax. There's clearly still a small discrepancy on the backward pass -- it looks like that needs some more careful profiling. Let me know what this looks like for your actual use-case!
Thanks a lot, that looks great! I'm on holiday until the end of next week, but as soon as I'm back I'll give it a try on the actual simulator.
Sorry this took a while, but I just tested it in the simulator and it works amazingly 😄 I do have a couple of warnings at startup due to the deprecation of

Also... did I just get a runtime XLA error based on array values?? What is this new amazing sorcery 😍?? Thanks again for fixing this!
Marvellous!
Yup, next release of Diffrax (in the next few days) will avoid that code path.
Yes you did! Equinox recently added public support for these. Documentation available here. Indeed, "sorcery" is the appropriate word, since this is something new under the JAX sun.
Just released the new version of Diffrax. I think everything discussed here should now work, fast and without warnings. As such I'm closing this, but please do re-open it if you think this isn't fixed.
Hello again!
I am still on my quest to add proper integration + adjoint calculation (and checkpointing) to my wave simulator 😄
I appreciate that `diffrax` offers a range of methods for calculating adjoints, each with its own trade-off between computational complexity and memory requirements. However, for smaller simulations it might be beneficial to maximize checkpoint usage and potentially save the entire ODE trajectory for reverse-mode AD. This approach takes full advantage of the GPU memory, thereby reducing computation times.

My understanding is that this can be achieved by using `RecursiveCheckpointAdjoint` with a large value of `checkpoints`, potentially as high as the number of steps in the forward integrator. I've attempted to implement this without much success: while I am obtaining the correct numerical results, the computation times are far longer than expected.
Here is an MWE:
Where I get the following timings on an RTX 4000:
As expected, for `scan` the AD calculation takes roughly twice the execution time of the forward pass. This becomes almost exactly 2x if the `jax.checkpoint` decorator is removed.

For the forward pass of `SemiImplicitEuler`, the timings I get are approximately twice those of the scan alone. However, this could easily be attributed to the more sophisticated implementation of the `diffrax` integrator, so overall that's completely fine.

However, the timings for performing AD are about 7x those required by the scan method. In a more complex example within my simulator, it can reach up to 30x the time required by the equivalent scan integrator.
Am I missing something about the correct approach to calculating the adjoint?
Also, I'm not sure whether `RecursiveCheckpointAdjoint` is using the same `solver` as the forward integrator (based on my understanding of the documentation, it isn't), and I can't seem to find a way to pass a specific `solver` to it. Is it necessary to define a new class derived from `AbstractAdjoint` with a custom `loop` method to achieve this?

Thanks a lot!