diff --git a/.github/workflows/python_package.yaml b/.github/workflows/python_package.yaml index 399fdbd8..7c0d30ff 100644 --- a/.github/workflows/python_package.yaml +++ b/.github/workflows/python_package.yaml @@ -20,7 +20,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.9", "3.10", "3.11", "3.12"] + python-version: ["3.10", "3.11", "3.12"] steps: - uses: actions/checkout@v4 @@ -35,6 +35,8 @@ jobs: python -m pip install --upgrade pip python -m pip install pytest pytest-codspeed python -m pip install . + # temp pin until 0.5 is on conda + python -m pip install "jax<0.5.0" - name: Test with pytest run: | diff --git a/tests/jax/test_benchmarks.py b/tests/jax/test_benchmarks.py index 7ef21332..d00cad23 100644 --- a/tests/jax/test_benchmarks.py +++ b/tests/jax/test_benchmarks.py @@ -212,3 +212,29 @@ def test_benchmark_spergel_kvalue(benchmark, kind): benchmark, kind, lambda: _run_spergel_bench_kvalue_jit().block_until_ready() ) print(f"time: {dt:0.4g} ms", end=" ") + + +@jax.jit +def _run_spergel_bench_init(): + return jgs.Spergel(nu=-0.6, half_light_radius=3.4).scale_radius + + +@pytest.mark.parametrize("kind", ["compile", "run"]) +def test_benchmark_spergel_init(benchmark, kind): + dt = _run_benchmarks( + benchmark, kind, lambda: _run_spergel_bench_init().block_until_ready() + ) + print(f"time: {dt:0.4g} ms", end=" ") + + +@jax.jit +def _run_gaussian_bench_init(): + return jgs.Gaussian(half_light_radius=3.4).sigma + + +@pytest.mark.parametrize("kind", ["compile", "run"]) +def test_benchmark_gaussian_init(benchmark, kind): + dt = _run_benchmarks( + benchmark, kind, lambda: _run_gaussian_bench_init().block_until_ready() + ) + print(f"time: {dt:0.4g} ms", end=" ")