Refactor glass.fields.iternorm
#486
Comments
I have been looking into the code and the paper, and I think it would be best to have a separate JAX implementation of the function. The in-place array mutations are a smart choice on CPUs because they keep the memory overhead low. Changing the behavior to create copies would hurt performance on CPUs, and IIRC, we agreed that GLASS should run smoothly on both architectures, since each is needed for specific use cases. JAX's indexed-update syntax (x.at[idx].set(y)) does create a copy of the array when run eagerly, but not when JIT compiled: inside jit, XLA can perform the update in place as long as the original array is not reused afterwards. So a JAX implementation would be accelerated on GPUs and, with JIT compilation, memory efficient on CPUs.
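To make the copy-versus-JIT point concrete, here is a minimal sketch (not GLASS code) of the same circular-buffer write done both ways. It assumes iternorm-style code keeps a fixed-size buffer of its k most recent rows; `fill_numpy` and `fill_jax` are hypothetical helpers for illustration only.

```python
import numpy as np
import jax
import jax.numpy as jnp


def fill_numpy(rows: np.ndarray, k: int) -> np.ndarray:
    """NumPy style: one preallocated buffer, mutated in place (hypothetical helper)."""
    buf = np.zeros((k, rows.shape[1]))
    for i, row in enumerate(rows):
        buf[i % k] = row  # in-place write, no new allocation
    return buf


def fill_jax(rows: jnp.ndarray, k: int) -> jnp.ndarray:
    """JAX style: every update returns a new array value (hypothetical helper)."""
    buf = jnp.zeros((k, rows.shape[1]), dtype=rows.dtype)

    def body(i, buf):
        # .at[].set() is functional. Eagerly it materialises a new array;
        # under jit the old `buf` is never used again, so XLA can update in place.
        return buf.at[i % k].set(rows[i])

    return jax.lax.fori_loop(0, rows.shape[0], body, buf)


# k determines the buffer shape, so it must be static under jit.
fill_jax_jit = jax.jit(fill_jax, static_argnums=1)
```

Both versions produce the same buffer; only the memory behaviour differs. Across separate calls, jax.jit's donate_argnums option can additionally let XLA reuse an input buffer instead of allocating a new one.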
cc: @ntessore
Maybe it makes sense to split this issue into two separate tasks:

On the science side: figure out a way to have different values of […] If we go to […] As you can see, this is made more awkward by the HEALPix order of the alms, where we have blocks in […]

On the GPU side: make the code run on GPU/JAX. This should happen with whatever solution to the above we come up with.
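For context on the HEALPix ordering point: in the HEALPix/healpy convention, the alm array for a given lmax is stored m-major, with a_lm at index m*(2*lmax+1-m)//2 + l (the formula behind healpy's Alm.getidx). Assuming the science task is about using a different lmax for different fields (the text above is incomplete, so this is a guess), the modes of a smaller lmax are not a contiguous prefix of the larger array, and truncation needs a gather rather than a slice. The minimal NumPy sketch below computes those gather indices; `alm_index` and `truncation_indices` are hypothetical helpers, not GLASS or healpy API.

```python
import numpy as np


def alm_index(lmax: int, l: int, m: int) -> int:
    """Index of a_lm in a HEALPix-ordered alm array (same formula as healpy's Alm.getidx)."""
    return m * (2 * lmax + 1 - m) // 2 + l


def truncation_indices(lmax: int, new_lmax: int) -> np.ndarray:
    """Positions, in an array ordered for lmax, of every (l, m) with l <= new_lmax,
    listed in HEALPix order for new_lmax (hypothetical helper)."""
    return np.array(
        [alm_index(lmax, l, m)
         for m in range(new_lmax + 1)        # outer loop over m blocks
         for l in range(m, new_lmax + 1)],   # l runs from m to new_lmax in each block
        dtype=np.intp,
    )
```

For lmax = 2 truncated to new_lmax = 1 this gives [0, 1, 3]: the (l, m) = (1, 1) mode sits at index 3 because the whole m = 0 block, including (2, 0), comes first. Such an index array could be computed once and applied with ordinary fancy indexing, which also works inside jax.jit.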
Add your issue here
The function relies heavily on in-place array mutation, and it should be refactored to work with immutable arrays and data structures. Moreover, the correlations do not need to be stored as a regular array; they could be irregular/ragged to save memory, so we should investigate #382 in parallel with this issue.
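A minimal NumPy sketch of one way to store ragged correlations, assuming each entry is a 1D array whose length varies (for example, spectra truncated at different lmax): pack everything into one flat buffer plus an offsets array (CSR-style) instead of zero-padding to a regular 2D array. `pack_ragged`, `get_row` and `pad_regular` are hypothetical names for illustration, not necessarily what #382 proposes.

```python
import numpy as np


def pack_ragged(arrays):
    """Pack 1D arrays of different lengths into (values, offsets), CSR style (hypothetical)."""
    lengths = np.array([len(a) for a in arrays], dtype=np.intp)
    offsets = np.zeros(len(arrays) + 1, dtype=np.intp)
    np.cumsum(lengths, out=offsets[1:])  # offsets[i]:offsets[i+1] spans array i
    values = np.concatenate(arrays) if arrays else np.empty(0)
    return values, offsets


def get_row(values, offsets, i):
    """Return the i-th original array as a view into the flat buffer (no copy)."""
    return values[offsets[i]:offsets[i + 1]]


def pad_regular(arrays):
    """Regular alternative for comparison: zero-pad every array to the longest one."""
    width = max(len(a) for a in arrays)
    out = np.zeros((len(arrays), width))
    for i, a in enumerate(arrays):
        out[i, :len(a)] = a
    return out
```

With three arrays of lengths 5, 3 and 2, the packed form stores 10 values while the padded form stores 15. JAX has no native ragged array type, so a flat buffer with offsets (plus segment reductions such as jax.ops.segment_sum where per-row sums are needed) is one common workaround there as well.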