diff --git a/.github/workflows/test_package.yml b/.github/workflows/test_package.yml
new file mode 100644
index 0000000..6bdd95e
--- /dev/null
+++ b/.github/workflows/test_package.yml
@@ -0,0 +1,46 @@
+name: Test diffopt
+
+on:
+  workflow_dispatch: null
+  schedule:
+    # Runs "every Monday at noon UTC"
+    - cron: '0 12 * * 1'
+  push:
+    branches:
+      - main
+  pull_request: null
+
+jobs:
+  build:
+    runs-on: ubuntu-latest
+    strategy:
+      matrix:
+        python-version: ['3.9', '3.10', '3.11']
+    steps:
+      - uses: actions/checkout@v2
+      - name: Set up Python ${{ matrix.python-version }}
+        uses: conda-incubator/setup-miniconda@v2
+        with:
+          activate-environment: test${{ matrix.python-version }}
+          python-version: ${{ matrix.python-version }}
+          channels: conda-forge,defaults
+          channel-priority: strict
+          show-channel-urls: true
+          miniforge-version: latest
+
+      - name: Install dependencies
+        shell: bash -l {0}
+        run: |
+          conda install -yq jax
+          conda install -yq pip pytest pytest-cov flake8
+
+      - name: Install package
+        shell: bash -l {0}
+        run: |
+          pip install -e .
+
+      - name: Run tests
+        shell: bash -l {0}
+        run: |
+          export PYTHONWARNINGS=error
+          pytest -v
diff --git a/README.md b/README.md
index 2117c8c..f4ef274 100644
--- a/README.md
+++ b/README.md
@@ -6,3 +6,6 @@ Parallelization and optimization of differentiable and many-parameter models
 
 ## Documentation
 Online documentation is available at [diffopt.readthedocs.io](https://diffopt.readthedocs.io/en/latest).
+
+## Manual Testing
+Unit tests that require an `mpi4py` installation are not run automatically by the GitHub workflow. To run all tests, install `mpi4py` locally and run `pytest` from the root directory. Additionally, all tests must pass with `mpirun -n 4 pytest`, etc. (up to the maximum number of tasks that can be run on your machine).
\ No newline at end of file
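Note: the manual procedure described in the new README section can be scripted. Below is a minimal sketch, not part of this diff, assuming `mpi4py` and an `mpirun` launcher are available locally; the helper name `run_mpi_test_sweep` and the default `max_tasks=4` are illustrative.

```python
# Illustrative only (not part of this diff): reproduce the manual test
# procedure from the README -- plain pytest first, then pytest under
# mpirun with an increasing number of tasks. Assumes mpi4py and an
# "mpirun" launcher are installed; adjust max_tasks for your machine.
import subprocess
import sys


def run_mpi_test_sweep(max_tasks=4):
    commands = [[sys.executable, "-m", "pytest"]]
    commands += [
        ["mpirun", "-n", str(n), sys.executable, "-m", "pytest"]
        for n in range(2, max_tasks + 1)
    ]
    for cmd in commands:
        print("running:", " ".join(cmd))
        # check=True makes a failing test run raise CalledProcessError
        subprocess.run(cmd, check=True)


if __name__ == "__main__":
    run_mpi_test_sweep()
```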
diff --git a/diffopt/multigrad/tests/smf_example/smf_grad_descent.py b/diffopt/multigrad/tests/smf_example/smf_grad_descent.py
index 173d603..91e2af8 100644
--- a/diffopt/multigrad/tests/smf_example/smf_grad_descent.py
+++ b/diffopt/multigrad/tests/smf_example/smf_grad_descent.py
@@ -5,7 +5,10 @@
 from typing import NamedTuple
 from dataclasses import dataclass
 
-from mpi4py import MPI
+try:
+    from mpi4py import MPI
+except ImportError:
+    MPI = None
 import jax.scipy
 from jax import numpy as jnp
 import numpy as np
@@ -21,11 +24,15 @@ class ParamTuple(NamedTuple):
 
 # Generate fake HMF as power law (truncated so that the SMF has a knee)
 def load_halo_masses(num_halos=10_000, slope=-2, mmin=10.0 ** 10, qmax=0.95):
+    if MPI is None:
+        size, rank = 1, 0
+    else:
+        size, rank = MPI.COMM_WORLD.size, MPI.COMM_WORLD.rank
     q = jnp.linspace(0, qmax, num_halos)
     mhalo = mmin * (1 - q) ** (1/(slope+1))
 
     # Assign different halos to different MPI processes
-    return np.array_split(mhalo, MPI.COMM_WORLD.size)[MPI.COMM_WORLD.rank]
+    return np.array_split(mhalo, size)[rank]
 
 
 # SMF helper functions
@@ -89,6 +96,7 @@ def calc_loss_from_sumstats(self, sumstats, sumstats_aux=None):
 parser.add_argument("--learning-rate", type=float, default=1e-3)
 
 if __name__ == "__main__":
+    assert MPI is not None, "MPI must be installed to run this script"
     args = parser.parse_args()
     data = dict(
         log_halo_masses=jnp.log10(load_halo_masses(args.num_halos)),
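The guard pattern introduced above generalizes to any script that should degrade gracefully to a serial run when `mpi4py` is absent. A minimal standalone sketch, assuming only `numpy` is installed; the helper names `get_rank_and_size` and `split_work` are illustrative, not from this diff:

```python
# Minimal standalone sketch of the optional-dependency guard used above:
# import mpi4py if available, otherwise fall back to a single-process run.
try:
    from mpi4py import MPI
except ImportError:
    MPI = None

import numpy as np


def get_rank_and_size():
    """Return (rank, size); (0, 1) when mpi4py is not installed."""
    if MPI is None:
        return 0, 1
    return MPI.COMM_WORLD.rank, MPI.COMM_WORLD.size


def split_work(data):
    # Each process keeps only its own chunk; without MPI this is a
    # no-op split into a single chunk.
    rank, size = get_rank_and_size()
    return np.array_split(data, size)[rank]


if __name__ == "__main__":
    print(split_work(np.arange(10)))
```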
diff --git a/diffopt/multigrad/tests/test_mpi.py b/diffopt/multigrad/tests/test_mpi.py
index 6606357..b8d2f7f 100644
--- a/diffopt/multigrad/tests/test_mpi.py
+++ b/diffopt/multigrad/tests/test_mpi.py
@@ -5,18 +5,20 @@
 and `mpiexec -n 10 pytest test_mpi.py` all must pass
 (the --with-mpi flag shouldn't have any effect)
 """
-from mpi4py import MPI
+try:
+    from mpi4py import MPI
+except ImportError:
+    MPI = None
 import jax.numpy as jnp
+import unittest
 
 from ... import multigrad
 from .smf_example import smf_grad_descent as sgd
 
-comm = MPI.COMM_WORLD
-rank = comm.Get_rank()
-size = comm.Get_size()
-
 
+@unittest.skipIf(MPI is None, "MPI must be installed to run this test")
 def test_reduce_sum():
+    rank, size = MPI.COMM_WORLD.rank, MPI.COMM_WORLD.size
     # Set value equal to the rank of the process
     value = jnp.array(rank)
 
@@ -24,7 +26,7 @@
     result = multigrad.reduce_sum(value)
 
     # Gather the results from all processes
-    gathered_results = jnp.array(comm.allgather(result))
+    gathered_results = jnp.array(MPI.COMM_WORLD.allgather(result))
 
     if not rank:
         # Perform testing only on the rank 0 process
diff --git a/docs/source/multigrad/smf_gradient_descent.py b/docs/source/multigrad/smf_gradient_descent.py
index c31f739..f90e9f8 100644
--- a/docs/source/multigrad/smf_gradient_descent.py
+++ b/docs/source/multigrad/smf_gradient_descent.py
@@ -1,4 +1,9 @@
-from mpi4py import MPI
+try:
+    from mpi4py import MPI
+    COMM_WORLD = MPI.COMM_WORLD
+except ImportError:
+    MPI = None
+    COMM_WORLD = None
 import jax.scipy
 from jax import numpy as jnp
 import numpy as np
@@ -7,13 +12,19 @@
 from diffopt import multigrad
 
 
-def load_halo_masses(num_halos=10_000, comm=MPI.COMM_WORLD):
+def load_halo_masses(num_halos=10_000, comm=COMM_WORLD):
+    if comm is None:
+        size = 1
+        rank = 0
+    else:
+        size = comm.size
+        rank = comm.rank
     # Generate fake halo masses between 10^10 < M_h < 10^11 as a power law
     quantile = jnp.linspace(0, 0.9, num_halos)
     mhalo = 1e10 / (1 - quantile)
 
     # Assign halos evenly across given MPI ranks (only one rank for now)
-    return np.array_split(mhalo, comm.size)[comm.rank]
+    return np.array_split(mhalo, size)[rank]
 
 
 # Compute one bin of the stellar mass function (SMF)
@@ -92,6 +103,7 @@ def calc_loss_from_sumstats(self, sumstats):
 
 
 if __name__ == "__main__":
+    assert MPI is not None, "MPI must be installed to run this script"
     volume = 1.0
     smf_bin_edges = jnp.linspace(9, 10, 11)
     true_params = jnp.array([-2.0, -0.5])
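For reference, the skip mechanism adopted in `test_mpi.py` works with plain pytest-style functions because pytest honors the `unittest.SkipTest` exception raised by a function wrapped with `unittest.skipIf`. A minimal sketch of the same pattern; the test name `test_rank_is_within_world_size` is illustrative, not from this diff:

```python
# Sketch of the skip pattern used in test_mpi.py: the decorator keeps
# the test body (which touches MPI attributes) from running when
# mpi4py is not installed; pytest reports the test as skipped.
import unittest

try:
    from mpi4py import MPI
except ImportError:
    MPI = None


@unittest.skipIf(MPI is None, "MPI must be installed to run this test")
def test_rank_is_within_world_size():
    rank, size = MPI.COMM_WORLD.rank, MPI.COMM_WORLD.size
    assert 0 <= rank < size
```

Deferring the `MPI.COMM_WORLD` lookups into the test body, as the diff does, matters here: module-level attribute access would raise at import time whenever `mpi4py` is missing, before the skip decorator could take effect.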