Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Create test_package workflow #4

Merged
merged 5 commits
Nov 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 46 additions & 0 deletions .github/workflows/test_package.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
name: Test diffopt

on:
  workflow_dispatch: null
  schedule:
    # Runs every Monday at noon UTC
    - cron: '0 12 * * 1'
  push:
    branches:
      - main
  pull_request: null

jobs:
  build:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: ['3.9', '3.10', '3.11']
    steps:
      # checkout@v2 runs on the deprecated Node 12 runtime; v4 is current
      # and has the same default behavior for this use case.
      - uses: actions/checkout@v4
      - name: Set up Python ${{ matrix.python-version }}
        # setup-miniconda@v2 is likewise on a deprecated runtime; v3 keeps
        # the same inputs used below.
        uses: conda-incubator/setup-miniconda@v3
        with:
          activate-environment: test${{ matrix.python-version }}
          python-version: ${{ matrix.python-version }}
          channels: conda-forge,defaults
          channel-priority: strict
          show-channel-urls: true
          miniforge-version: latest

      - name: Install dependencies
        # Login shell so the conda environment is activated for each step.
        shell: bash -l {0}
        run: |
          conda install -yq jax
          conda install -yq pip pytest pytest-cov flake8

      - name: Install package
        shell: bash -l {0}
        run: |
          pip install -e .

      - name: Run tests
        shell: bash -l {0}
        run: |
          # Promote all warnings to errors so deprecations fail the build.
          export PYTHONWARNINGS=error
          pytest -v
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,6 @@ Parallelization and optimization of differentiable and many-parameter models

## Documentation
Online documentation is available at [diffopt.readthedocs.io](https://diffopt.readthedocs.io/en/latest).

## Manual Testing
Unit tests requiring `mpi4py` installation are not automatically tested by GitHub workflows. To run all tests, install `mpi4py` locally, and run `pytest` from the root directory. Additionally, all tests must pass with `mpirun -n 4 pytest` etc. (up to the maximum number of tasks that can be run on your machine).
12 changes: 10 additions & 2 deletions diffopt/multigrad/tests/smf_example/smf_grad_descent.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@
from typing import NamedTuple
from dataclasses import dataclass

from mpi4py import MPI
try:
from mpi4py import MPI
except ImportError:
MPI = None
import jax.scipy
from jax import numpy as jnp
import numpy as np
Expand All @@ -21,11 +24,15 @@ class ParamTuple(NamedTuple):

# Generate fake HMF as power law (truncated so that the SMF has a knee)
def load_halo_masses(num_halos=10_000, slope=-2, mmin=10.0 ** 10, qmax=0.95):
    """Return this process's share of a fake power-law halo mass function.

    Parameters
    ----------
    num_halos : int
        Total number of halos generated across all MPI ranks.
    slope : int or float
        Power-law slope of the HMF; must not be -1 (division by zero below).
    mmin : float
        Minimum halo mass (the mass at quantile q = 0).
    qmax : float
        Largest quantile sampled; keeping it below 1 truncates the power
        law so the resulting SMF has a knee.

    Returns
    -------
    array
        The contiguous chunk of halo masses assigned to this process.
    """
    # Fall back to a serial (single-rank) layout when mpi4py is absent.
    if MPI is None:
        size, rank = 1, 0
    else:
        size, rank = MPI.COMM_WORLD.size, MPI.COMM_WORLD.rank
    # Inverse-CDF sampling of a truncated power law.
    q = jnp.linspace(0, qmax, num_halos)
    mhalo = mmin * (1 - q) ** (1 / (slope + 1))

    # Assign different halos to different MPI processes
    return np.array_split(mhalo, size)[rank]


# SMF helper functions
Expand Down Expand Up @@ -89,6 +96,7 @@ def calc_loss_from_sumstats(self, sumstats, sumstats_aux=None):
parser.add_argument("--learning-rate", type=float, default=1e-3)

if __name__ == "__main__":
assert MPI is not None, "MPI must be installed to run this script"
args = parser.parse_args()
data = dict(
log_halo_masses=jnp.log10(load_halo_masses(args.num_halos)),
Expand Down
14 changes: 8 additions & 6 deletions diffopt/multigrad/tests/test_mpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,26 +5,28 @@
and `mpiexec -n 10 pytest test_mpi.py` all must pass
(the --with-mpi flag shouldn't have any effect)
"""
from mpi4py import MPI
try:
from mpi4py import MPI
except ImportError:
MPI = None
import jax.numpy as jnp

import unittest
from ... import multigrad
from .smf_example import smf_grad_descent as sgd

comm = MPI.COMM_WORLD
rank = comm.Get_rank()
size = comm.Get_size()


@unittest.skipIf(MPI is None, "MPI must be installed to run this test")
def test_reduce_sum():
rank, size = MPI.COMM_WORLD.rank, MPI.COMM_WORLD.size
# Set value equal to the rank of the process
value = jnp.array(rank)

# Reduce the sum of the values across all ranks
result = multigrad.reduce_sum(value)

# Gather the results from all processes
gathered_results = jnp.array(comm.allgather(result))
gathered_results = jnp.array(MPI.COMM_WORLD.allgather(result))

if not rank:
# Perform testing only on the rank 0 process
Expand Down
18 changes: 15 additions & 3 deletions docs/source/multigrad/smf_gradient_descent.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
from mpi4py import MPI
try:
from mpi4py import MPI
COMM_WORLD = MPI.COMM_WORLD
except ImportError:
MPI = None
COMM_WORLD = None
import jax.scipy
from jax import numpy as jnp
import numpy as np
Expand All @@ -7,13 +12,19 @@
from diffopt import multigrad


def load_halo_masses(num_halos=10_000, comm=COMM_WORLD):
    """Return this rank's share of a fake power-law halo mass catalog.

    Parameters
    ----------
    num_halos : int
        Total number of halos generated across all ranks.
    comm : MPI communicator or None
        Halos are split evenly over the communicator's ranks. ``None``
        (the default when mpi4py is not installed) behaves like a
        single-rank communicator.

    Returns
    -------
    array
        The contiguous chunk of halo masses assigned to this process.
    """
    # Without MPI, act as rank 0 of a one-rank world.
    if comm is None:
        size = 1
        rank = 0
    else:
        size = comm.size
        rank = comm.rank
    # Generate fake halo masses between 10^10 < M_h < 10^11 as a power law
    quantile = jnp.linspace(0, 0.9, num_halos)
    mhalo = 1e10 / (1 - quantile)

    # Assign halos evenly across given MPI ranks (only one rank for now)
    return np.array_split(mhalo, size)[rank]


# Compute one bin of the stellar mass function (SMF)
Expand Down Expand Up @@ -92,6 +103,7 @@ def calc_loss_from_sumstats(self, sumstats):


if __name__ == "__main__":
assert MPI is not None, "MPI must be installed to run this script"
volume = 1.0
smf_bin_edges = jnp.linspace(9, 10, 11)
true_params = jnp.array([-2.0, -0.5])
Expand Down
Loading