
Refactor glass.fields.iternorm #486

Open
Saransh-cpp opened this issue Jan 23, 2025 · 3 comments
Assignees: Saransh-cpp
Labels: api (An (incompatible) API change), performance (Performance improvements or regressions)

Comments

@Saransh-cpp (Member)

The function heavily mutates arrays in place and should be refactored to work with immutable arrays and data structures. Moreover, the correlations do not need to be regular; they could be ragged to save memory, so we should investigate #382 in parallel with this issue.

Saransh-cpp added the api and performance labels on Jan 23, 2025
Saransh-cpp self-assigned this on Jan 23, 2025
@Saransh-cpp (Member, Author)

I have been looking into the code and the paper, and I think it would be best to have a separate JAX implementation of the function. The in-place array mutations are a smart choice for CPUs, as they keep the memory overhead low. Changing the behavior to create copies would hurt performance on CPUs, and IIRC we had a discussion that GLASS should run smoothly on both architectures, as each is required for specific needs.

JAX's index-update syntax does create array copies when run eagerly on CPU, but not when JIT-compiled. So a JAX implementation would be accelerated on GPUs and, with JIT compilation, memory-efficient on CPUs.
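For illustration, here is a minimal sketch (function and argument names are hypothetical, not from glass) of how an in-place shift-and-insert step over the history of previous samples could be written with JAX's functional index-update syntax; eagerly this copies, but under `jax.jit` XLA can elide the copies and write into the buffer in place:

```python
import jax
import jax.numpy as jnp


def shift_and_insert(history, new_col):
    """Hypothetical sketch: age the columns of `history` by one
    generation and insert `new_col` at the front, without mutating
    the input array.

    Run eagerly on CPU this makes copies; under jax.jit, XLA can fuse
    the roll and the index update into in-place buffer writes.
    """
    history = jnp.roll(history, shift=1, axis=1)  # oldest column wraps to front
    return history.at[:, 0].set(new_col)          # ...and is overwritten


shift_and_insert_jit = jax.jit(shift_and_insert)
```

The eager and jitted versions produce the same values; only the memory behavior differs.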

@Saransh-cpp (Member, Author)

cc: @ntessore

@ntessore (Collaborator)

Maybe it makes sense to split this issue into two separate tasks:


On the science side: figure out a way to have different values of ncorr for different $\ell$ (as in $a_{\ell m}$). If we write $a_{\ell m}^{(g)}$ for the alms of "generation" $g$, then we currently have a rectangular array of previous samples:

$$\left.\begin{matrix} a_{00}^{(1)} & \ldots & a_{00}^{(\mathtt{ncorr})} \\ a_{10}^{(1)} & \ldots & a_{10}^{(\mathtt{ncorr})} \\ a_{20}^{(1)} & \ldots & a_{20}^{(\mathtt{ncorr})} \\ \vdots & \vdots & \vdots \end{matrix}\right\} \; m=0$$

$$\left.\begin{matrix} a_{11}^{(1)} & \ldots & a_{11}^{(\mathtt{ncorr})} \\ a_{21}^{(1)} & \ldots & a_{21}^{(\mathtt{ncorr})} \\ \vdots & \vdots & \vdots \end{matrix}\right\} \; m=1$$

$$\left.\begin{matrix} a_{22}^{(1)} & \ldots & a_{22}^{(\mathtt{ncorr})} \\ \vdots & \vdots & \vdots \end{matrix}\right\} \; m=2$$

If we go to $\ell$-dependent values of $\mathtt{ncorr}_\ell$ then this turns into a complicated ragged array:

$$\left.\begin{matrix} a_{00}^{(1)} & \ldots & \ldots & \ldots & a_{00}^{(\mathtt{ncorr}_0)} \\ a_{10}^{(1)} & \ldots & \ldots & a_{10}^{(\mathtt{ncorr}_1)} \\ a_{20}^{(1)} & \ldots & a_{20}^{(\mathtt{ncorr}_2)} \\ \vdots & \vdots & \vdots \end{matrix}\right\} \; m=0$$

$$\left.\begin{matrix} a_{11}^{(1)} & \ldots & \ldots & a_{11}^{(\mathtt{ncorr}_1)} \\ a_{21}^{(1)} & \ldots & a_{21}^{(\mathtt{ncorr}_2)} \\ \vdots & \vdots & \vdots \end{matrix}\right\} \; m=1$$

$$\left.\begin{matrix} a_{22}^{(1)} & \ldots & a_{22}^{(\mathtt{ncorr}_2)} \\ \vdots & \vdots & \vdots \end{matrix}\right\} \; m=2$$

As you can see, this is made more awkward by the HEALPix order of the alms, where we have blocks in $m$, not $\ell$.
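One conceivable way to handle the raggedness, sketched below with hypothetical names and values, is to pack the per-$\ell$ histories into a single flat buffer with offsets computed from $\mathtt{ncorr}_\ell$ (CSR-style), rather than padding a rectangular array out to $\max_\ell \mathtt{ncorr}_\ell$:

```python
import numpy as np

# Illustrative only: the ncorr values and helper name are hypothetical.
ncorr_ell = np.array([4, 3, 2])  # ncorr_0, ncorr_1, ncorr_2
offsets = np.concatenate(([0], np.cumsum(ncorr_ell)))


def history_slice(flat, ell):
    """View of the ncorr_ell previous samples stored for degree ell."""
    return flat[offsets[ell] : offsets[ell + 1]]


flat = np.zeros(offsets[-1])
history_slice(flat, 1)[:] = 1.0  # touches only the ncorr_1 = 3 entries
```

The offsets depend only on $\ell$, so the same lookup could be reused inside each HEALPix $m$-block, at the cost of a non-contiguous walk through the blocks.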


On the GPU side: make the code run on GPU/JAX. This should happen with whatever solution to the above we come up with.
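As a rough sketch of how the iterative structure could map onto JAX primitives (the step function below is a stand-in, not the actual iternorm recurrence), `jax.lax.scan` compiles the whole loop into a single fused program that runs on either CPU or GPU:

```python
import jax
import jax.numpy as jnp


def step(carry, x):
    # Stand-in recurrence; the real iternorm step would go here.
    new_carry = 0.5 * carry + x
    return new_carry, new_carry  # (next carry, per-step output)


@jax.jit
def run(init, xs):
    # scan threads `carry` through the inputs and stacks the outputs
    final, ys = jax.lax.scan(step, init, xs)
    return final, ys


final, ys = run(0.0, jnp.ones(5))
```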
