Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

better custom autodiff #407

Merged
merged 4 commits into from
Jul 5, 2023
Merged

better custom autodiff #407

merged 4 commits into from
Jul 5, 2023

Conversation

patrick-kidger
Copy link
Owner

  • Now using symbolic zeros in filter_custom_{jvp,vjp}. This also brings with it an API tweak for backward compatibility. (E.g. filter_custom_jvp.defjvp -> filter_custom_jvp.def_jvp for the new behaviour.)
  • Removed unnecessary lax.select in eqxi.while_loop when not vmap'ing. This should improve speed slightly when not vmap'ing.
  • Improved speed of backprop through eqxi.while_loop buffers, in particular avoiding Big performance discrepancy between JAX and TensorFlow with in-place updates jax-ml/jax#10197. In particular this can be used to reduce some programs with quadratic runtime down to just linear runtime.

@patrick-kidger
Copy link
Owner Author

TODO: benchmark effect on Diffrax.

@patrick-kidger patrick-kidger merged commit b7c6f63 into main Jul 5, 2023
2 checks passed
@patrick-kidger patrick-kidger deleted the better-custom-autodiff branch July 5, 2023 08:21
@patrick-kidger patrick-kidger restored the better-custom-autodiff branch July 5, 2023 08:24
@patrick-kidger patrick-kidger deleted the better-custom-autodiff branch July 5, 2023 08:24
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant