-
Notifications
You must be signed in to change notification settings - Fork 7
fix: ensure we properly indicate to jax that beta does not have grads for moffat and spergel #185
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
Merging this PR will not alter performance
|
|
@EiffL can you look at these stop_gradient calls and tell me if they make sense? |
|
Look reasonable, and I think this should solve the gradient problems, but I'd have to run my test script to be sure, which I can only do tonight (I've been surprised by similar things in the past, like the stop gradient not actually preventing the computation of some gradients on the path). |
|
It doesn't appear to help the benchmarks so I am guessing the gradient problems are not helped. |
|
GPU Benchmark: clarify-moffat-beta-deriv vs main vs besssel_improv |
|
doesnt seem to help |
|
Yep. The benchmark tests on the cpu seem to be a reasonably accurate predictor. We should still merge this pr after a few changes. |
I updated the doc string and I added a stop_gradient call to help JAX along instead of asking it to propagate zeros everywhere.
closes #184