Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merge software changes into paper branch #6

Merged
merged 5 commits into from
Dec 6, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Create test_package workflow (#4)
* Create test_package workflow

* Skip unit tests requiring mpi4py if not installed

* Add manual testing instructions to README
AlanPearl authored Nov 26, 2024
commit 43d4b57c38cf7a79fa2e068cdbb65dfecf3b53f3
46 changes: 46 additions & 0 deletions .github/workflows/test_package.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
name: Test diffopt

on:
workflow_dispatch: null
schedule:
# Runs "every Monday at noon UTC"
- cron: '0 12 * * 1'
push:
branches:
- main
pull_request: null

jobs:
build:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ['3.9', '3.10', '3.11']
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: conda-incubator/setup-miniconda@v2
with:
activate-environment: test${{ matrix.python-version}}
python-version: ${{ matrix.python-version }}
channels: conda-forge,defaults
channel-priority: strict
show-channel-urls: true
miniforge-version: latest

- name: Install dependencies
shell: bash -l {0}
run: |
conda install -yq jax
conda install -yq pip pytest pytest-cov flake8
- name: Install package
shell: bash -l {0}
run: |
pip install -e .
- name: Run tests
shell: bash -l {0}
run: |
export PYTHONWARNINGS=error
pytest -v
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -6,3 +6,6 @@ Parallelization and optimization of differentiable and many-parameter models

## Documentation
Online documentation is available at [diffopt.readthedocs.io](https://diffopt.readthedocs.io/en/latest).

## Manual Testing
Unit tests requiring `mpi4py` installation are not automatically tested by GitHub workflows. To run all tests, install `mpi4py` locally, and run `pytest` from the root directory. Additionally, all tests must pass with `mpirun -n 4 pytest` etc. (up to the maximum number of tasks that can be run on your machine).
12 changes: 10 additions & 2 deletions diffopt/multigrad/tests/smf_example/smf_grad_descent.py
Original file line number Diff line number Diff line change
@@ -5,7 +5,10 @@
from typing import NamedTuple
from dataclasses import dataclass

from mpi4py import MPI
try:
from mpi4py import MPI
except ImportError:
MPI = None
import jax.scipy
from jax import numpy as jnp
import numpy as np
@@ -21,11 +24,15 @@ class ParamTuple(NamedTuple):

# Generate fake HMF as power law (truncated so that the SMF has a knee)
def load_halo_masses(num_halos=10_000, slope=-2, mmin=10.0 ** 10, qmax=0.95):
if MPI is None:
size, rank = 1, 0
else:
size, rank = MPI.COMM_WORLD.size, MPI.COMM_WORLD.rank
q = jnp.linspace(0, qmax, num_halos)
mhalo = mmin * (1 - q) ** (1/(slope+1))

# Assign different halos to different MPI processes
return np.array_split(mhalo, MPI.COMM_WORLD.size)[MPI.COMM_WORLD.rank]
return np.array_split(mhalo, size)[rank]


# SMF helper functions
@@ -89,6 +96,7 @@ def calc_loss_from_sumstats(self, sumstats, sumstats_aux=None):
parser.add_argument("--learning-rate", type=float, default=1e-3)

if __name__ == "__main__":
assert MPI is not None, "MPI must be installed to run this script"
args = parser.parse_args()
data = dict(
log_halo_masses=jnp.log10(load_halo_masses(args.num_halos)),
14 changes: 8 additions & 6 deletions diffopt/multigrad/tests/test_mpi.py
Original file line number Diff line number Diff line change
@@ -5,26 +5,28 @@
and `mpiexec -n 10 pytest test_mpi.py` all must pass
(the --with-mpi flag shouldn't have any effect)
"""
from mpi4py import MPI
try:
from mpi4py import MPI
except ImportError:
MPI = None
import jax.numpy as jnp

import unittest
from ... import multigrad
from .smf_example import smf_grad_descent as sgd

comm = MPI.COMM_WORLD
rank = comm.Get_rank()
size = comm.Get_size()


@unittest.skipIf(MPI is None, "MPI must be installed to run this test")
def test_reduce_sum():
rank, size = MPI.COMM_WORLD.rank, MPI.COMM_WORLD.size
# Set value equal to the rank of the process
value = jnp.array(rank)

# Reduce the sum of the values across all ranks
result = multigrad.reduce_sum(value)

# Gather the results from all processes
gathered_results = jnp.array(comm.allgather(result))
gathered_results = jnp.array(MPI.COMM_WORLD.allgather(result))

if not rank:
# Perform testing only on the rank 0 process
18 changes: 15 additions & 3 deletions docs/source/multigrad/smf_gradient_descent.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
from mpi4py import MPI
try:
from mpi4py import MPI
COMM_WORLD = MPI.COMM_WORLD
except ImportError:
MPI = None
COMM_WORLD = None
import jax.scipy
from jax import numpy as jnp
import numpy as np
@@ -7,13 +12,19 @@
from diffopt import multigrad


def load_halo_masses(num_halos=10_000, comm=MPI.COMM_WORLD):
def load_halo_masses(num_halos=10_000, comm=COMM_WORLD):
if comm is None:
size = 1
rank = 0
else:
size = comm.size
rank = comm.rank
# Generate fake halo masses between 10^10 < M_h < 10^11 as a power law
quantile = jnp.linspace(0, 0.9, num_halos)
mhalo = 1e10 / (1 - quantile)

# Assign halos evenly across given MPI ranks (only one rank for now)
return np.array_split(mhalo, comm.size)[comm.rank]
return np.array_split(mhalo, size)[rank]


# Compute one bin of the stellar mass function (SMF)
@@ -92,6 +103,7 @@ def calc_loss_from_sumstats(self, sumstats):


if __name__ == "__main__":
assert MPI is not None, "MPI must be installed to run this script"
volume = 1.0
smf_bin_edges = jnp.linspace(9, 10, 11)
true_params = jnp.array([-2.0, -0.5])