Skip to content

Commit

Permalink
Merge pull request #170 from astro-informatics/mmg/healpix-fft-tests
Browse files Browse the repository at this point in the history
Tests for consistency of HEALPix FFT and IFFT implementations
  • Loading branch information
matt-graham authored Dec 4, 2023
2 parents 01f0b9a + 4fe1afc commit 4ef9c67
Showing 1 changed file with 49 additions and 0 deletions.
49 changes: 49 additions & 0 deletions tests/test_healpix_ffts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import numpy as np
import healpy as hp
import pytest
from jax import config
from s2fft.sampling import s2_samples as samples
from s2fft.utils.healpix_ffts import (
healpix_fft_jax,
healpix_fft_numpy,
healpix_ifft_jax,
healpix_ifft_numpy,
)


config.update("jax_enable_x64", True)


nside_to_test = [4, 5]
reality_to_test = [False, True]


@pytest.mark.parametrize("nside", nside_to_test)
@pytest.mark.parametrize("reality", reality_to_test)
def test_healpix_fft_jax_numpy_consistency(flm_generator, nside, reality):
L = 2 * nside
# Generate a random bandlimited signal
flm = flm_generator(L=L, reality=reality)
flm_hp = samples.flm_2d_to_hp(flm, L)
f = hp.sphtfunc.alm2map(flm_hp, nside, lmax=L - 1)
# Test consistency
assert np.allclose(
healpix_fft_numpy(f, L, nside, reality), healpix_fft_jax(f, L, nside, reality)
)


@pytest.mark.parametrize("nside", nside_to_test)
@pytest.mark.parametrize("reality", reality_to_test)
def test_healpix_ifft_jax_numpy_consistency(flm_generator, nside, reality):
L = 2 * nside
# Generate a random bandlimited signal
flm = flm_generator(L=L, reality=reality)
flm_hp = samples.flm_2d_to_hp(flm, L)
f = hp.sphtfunc.alm2map(flm_hp, nside, lmax=L - 1)
ftm = healpix_fft_numpy(f, L, nside, reality)
ftm_copy = np.copy(ftm)
# Test consistency
assert np.allclose(
healpix_ifft_numpy(ftm, L, nside, reality),
healpix_ifft_jax(ftm_copy, L, nside, reality),
)

0 comments on commit 4ef9c67

Please sign in to comment.