diff --git a/.github/workflows/test_package.yml b/.github/workflows/test_package.yml new file mode 100644 index 0000000..6bdd95e --- /dev/null +++ b/.github/workflows/test_package.yml @@ -0,0 +1,46 @@ +name: Test diffopt + +on: + workflow_dispatch: null + schedule: + # Runs "every Monday at noon UTC" + - cron: '0 12 * * 1' + push: + branches: + - main + pull_request: null + +jobs: + build: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ['3.9', '3.10', '3.11'] + steps: + - uses: actions/checkout@v2 + - name: Set up Python ${{ matrix.python-version }} + uses: conda-incubator/setup-miniconda@v2 + with: + activate-environment: test${{ matrix.python-version}} + python-version: ${{ matrix.python-version }} + channels: conda-forge,defaults + channel-priority: strict + show-channel-urls: true + miniforge-version: latest + + - name: Install dependencies + shell: bash -l {0} + run: | + conda install -yq jax + conda install -yq pip pytest pytest-cov flake8 + + - name: Install package + shell: bash -l {0} + run: | + pip install -e . + + - name: Run tests + shell: bash -l {0} + run: | + export PYTHONWARNINGS=error + pytest -v diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000..0458016 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,36 @@ +# Contributing to `diffopt` + +Thank you for your interest in contributing to this project. All questions and ideas for improvement are welcome and can be made through opening an issue or pull request. + +Before contributing, familiarize yourself with our resources: + +- [Source Code](https://github.com/AlanPearl/diffopt) +- [Documentation](https://diffopt.readthedocs.io) + +## Issues + +You can open an [issue](https://github.com/AlanPearl/diffopt/issues) if you: + +- Have encountered a bug or issue when using the software +- Would like to see a new feature +- Are seeking support that could not be resolved by reading the documentation + +## Pull Requests + +If you would like to directly submit your own change to the software, thank you! Here's how: + +- Fork [this repository](https://github.com/AlanPearl/diffopt). +- Please remember to include a concise, self-contained unit test in your pull request. Ensure that all tests pass (see [Manual Testing](#manual-testing)). +- Open a [pull request](https://github.com/AlanPearl/diffopt/pulls). + +## Manual Testing + +Make sure you have installed diffopt as described in the [docs](https://diffopt.readthedocs.io/en/latest/installation.html). To run all tests from the main directory: + +```bash +pip install pytest +pytest . +mpirun -n 2 pytest . +``` + +Note that unit tests requiring `mpi4py` installation are not automatically tested by GitHub workflows. Therefore, running these tests manually with `mpi4py` installed is necessary to assure that all tests pass. diff --git a/diffopt/kdescent/descent.py b/diffopt/kdescent/descent.py index 8069e23..6517feb 100644 --- a/diffopt/kdescent/descent.py +++ b/diffopt/kdescent/descent.py @@ -12,7 +12,8 @@ def adam(lossfunc, guess, nsteps=100, param_bounds=None, - learning_rate=0.01, randkey=1, const_randkey=False, **other_kwargs): + learning_rate=0.01, randkey=1, const_randkey=False, + thin=1, progress=True, **other_kwargs): """ Perform gradient descent @@ -36,6 +37,11 @@ def adam(lossfunc, guess, nsteps=100, param_bounds=None, const_randkey : bool, optional By default (False), randkey is regenerated at each gradient descent iteration. 
Remove this behavior by setting const_randkey=True + thin : int, optional + Return parameters for every `thin` iterations, by default 1. Set + `thin=0` to only return final parameters + progress : bool, optional + Display tqdm progress bar, by default True Returns ------- @@ -46,7 +52,7 @@ def adam(lossfunc, guess, nsteps=100, param_bounds=None, if param_bounds is None: return adam_unbounded( lossfunc, guess, nsteps, learning_rate, randkey, - const_randkey, **other_kwargs) + const_randkey, thin, progress, **other_kwargs) assert len(guess) == len(param_bounds) if hasattr(param_bounds, "tolist"): @@ -60,14 +66,15 @@ def ulossfunc(uparams, *args, **kwargs): init_uparams = apply_transforms(guess, param_bounds) uparams = adam_unbounded( ulossfunc, init_uparams, nsteps, learning_rate, randkey, - const_randkey, **other_kwargs) + const_randkey, thin, progress, **other_kwargs) params = apply_inverse_transforms(uparams.T, param_bounds).T return params def adam_unbounded(lossfunc, guess, nsteps=100, learning_rate=0.01, - randkey=1, const_randkey=False, **other_kwargs): + randkey=1, const_randkey=False, + thin=1, progress=True, **other_kwargs): kwargs = {**other_kwargs} if randkey is not None: randkey = keygen.init_randkey(randkey) @@ -78,13 +85,18 @@ def adam_unbounded(lossfunc, guess, nsteps=100, learning_rate=0.01, opt = optax.adam(learning_rate) solver = jaxopt.OptaxSolver(opt=opt, fun=lossfunc, maxiter=nsteps) state = solver.init_state(guess, **kwargs) - params = [guess] - for _ in tqdm.trange(nsteps, desc="Adam Gradient Descent Progress"): + params = [] + params_i = guess + for i in tqdm.trange(nsteps, disable=not progress, + desc="Adam Gradient Descent Progress"): if randkey is not None: randkey, key_i = jax.random.split(randkey) kwargs["randkey"] = key_i - params_i, state = solver.update(params[-1], state, **kwargs) - params.append(params_i) + params_i, state = solver.update(params_i, state, **kwargs) + if i == nsteps - 1 or (thin and i % thin == thin - 1): + params.append(params_i) + if not thin: + params = params[-1] return jnp.array(params) diff --git a/diffopt/multigrad/adam.py b/diffopt/multigrad/adam.py index 078c47e..d87bc01 100644 --- a/diffopt/multigrad/adam.py +++ b/diffopt/multigrad/adam.py @@ -25,12 +25,12 @@ N_RANKS = 1 -def trange_no_tqdm(n, desc=None): +def trange_no_tqdm(n, desc=None, disable=False): return range(n) -def trange_with_tqdm(n, desc="Adam Gradient Descent Progress"): - return tqdm.trange(n, desc=desc) +def trange_with_tqdm(n, desc="Adam Gradient Descent Progress", disable=False): + return tqdm.trange(n, desc=desc, disable=disable) adam_trange = trange_no_tqdm if tqdm is None else trange_with_tqdm @@ -49,27 +49,32 @@ def _master_wrapper(params, logloss_and_grad_fn, data, randkey=None): return loss, grad -def _adam_optimizer(params, fn, fn_data, nsteps, learning_rate, randkey=None): +def _adam_optimizer(params, fn, fn_data, nsteps, learning_rate, randkey=None, + thin=1, progress=True): kwargs = {} # Note: Might be recommended to use optax instead of jax.example_libraries opt_init, opt_update, get_params = jax_opt.adam(learning_rate) opt_state = opt_init(params) - param_steps = [params] - for step in adam_trange(nsteps): + param_steps = [] + for step in adam_trange(nsteps, disable=not progress): if randkey is not None: randkey, key_i = jax.random.split(randkey) kwargs["randkey"] = key_i _, grad = fn(params, *fn_data, **kwargs) opt_state = opt_update(step, grad, opt_state) params = get_params(opt_state) - param_steps.append(params) + if step == nsteps - 1 or (thin and 
step % thin == thin - 1): + param_steps.append(params) + if not thin: + param_steps = param_steps[-1] return jnp.array(param_steps) def run_adam_unbounded(logloss_and_grad_fn, params, data, nsteps=100, - learning_rate=0.01, randkey=None): + learning_rate=0.01, randkey=None, + thin=1, progress=True): """Run the adam optimizer on a loss function with a custom gradient. Parameters @@ -88,6 +93,11 @@ def run_adam_unbounded(logloss_and_grad_fn, params, data, nsteps=100, randkey : int | PRNG Key If given, a new PRNG Key will be generated at each iteration and be passed to `logloss_and_grad_fn` under the "randkey" kwarg + thin : int, optional + Return parameters for every `thin` iterations, by default 1. Set + `thin=0` to only return final parameters + progress : bool, optional + Display tqdm progress bar, by default True Returns ------- @@ -104,7 +114,7 @@ def run_adam_unbounded(logloss_and_grad_fn, params, data, nsteps=100, fn_data = (logloss_and_grad_fn, data) params = _adam_optimizer(params, fn, fn_data, nsteps, learning_rate, - randkey=randkey) + randkey=randkey, thin=thin, progress=progress) if COMM is not None: COMM.bcast("exit", root=0) @@ -131,7 +141,7 @@ def run_adam_unbounded(logloss_and_grad_fn, params, data, nsteps=100, def run_adam(logloss_and_grad_fn, params, data, nsteps=100, param_bounds=None, - learning_rate=0.01, randkey=None): + learning_rate=0.01, randkey=None, thin=1, progress=True): """Run the adam optimizer on a loss function with a custom gradient. Parameters @@ -153,6 +163,11 @@ def run_adam(logloss_and_grad_fn, params, data, nsteps=100, param_bounds=None, randkey : int | PRNG Key If given, a new PRNG Key will be generated at each iteration and be passed to `logloss_and_grad_fn` under the "randkey" kwarg + thin : int, optional + Return parameters for every `thin` iterations, by default 1. 
Set + `thin=0` to only return final parameters + progress : bool, optional + Display tqdm progress bar, by default True Returns ------- @@ -162,7 +177,8 @@ def run_adam(logloss_and_grad_fn, params, data, nsteps=100, param_bounds=None, if param_bounds is None: return run_adam_unbounded( logloss_and_grad_fn, params, data, nsteps=nsteps, - learning_rate=learning_rate, randkey=randkey) + learning_rate=learning_rate, randkey=randkey, + thin=thin, progress=progress) assert len(params) == len(param_bounds) if hasattr(param_bounds, "tolist"): @@ -182,7 +198,8 @@ def unbound_loss_and_grad(uparams, *args, **kwargs): uparams = apply_trans(params) final_uparams = run_adam_unbounded( - unbound_loss_and_grad, uparams, data, nsteps, learning_rate, randkey) + unbound_loss_and_grad, uparams, data, nsteps, learning_rate, randkey, + thin, progress) if RANK == 0: final_params = invert_trans(final_uparams.T).T diff --git a/diffopt/multigrad/bfgs.py b/diffopt/multigrad/bfgs.py index 302875e..93d65eb 100644 --- a/diffopt/multigrad/bfgs.py +++ b/diffopt/multigrad/bfgs.py @@ -18,19 +18,19 @@ N_RANKS = 1 -def trange_no_tqdm(n, desc=None): +def trange_no_tqdm(n, desc=None, disable=False): return range(n) -def trange_with_tqdm(n, desc="BFGS Gradient Descent Progress"): - return tqdm.trange(n, desc=desc, leave=True) +def trange_with_tqdm(n, desc="BFGS Gradient Descent Progress", disable=False): + return tqdm.trange(n, desc=desc, leave=True, disable=disable) bfgs_trange = trange_no_tqdm if tqdm is None else trange_with_tqdm def run_bfgs(loss_and_grad_fn, params, maxsteps=100, param_bounds=None, - randkey=None, comm=COMM): + randkey=None, progress=True, comm=COMM): """Run the adam optimizer on a loss function with a custom gradient. Parameters @@ -46,6 +46,8 @@ def run_bfgs(loss_and_grad_fn, params, maxsteps=100, param_bounds=None, `None` as the bound for each unbounded parameter, by default None randkey : int | PRNG Key (default=None) This will be passed to `logloss_and_grad_fn` under the "randkey" kwarg + progress : bool, optional + Display tqdm progress bar, by default True comm : MPI Communicator (default=COMM_WORLD) Communicator between all desired MPI ranks @@ -66,7 +68,7 @@ def run_bfgs(loss_and_grad_fn, params, maxsteps=100, param_bounds=None, kwargs["randkey"] = randkey if comm is None or comm.rank == 0: - pbar = bfgs_trange(maxsteps) + pbar = bfgs_trange(maxsteps, disable=not progress) # Wrap loss_and_grad function with commands to the worker ranks def loss_and_grad_fn_root(params): diff --git a/diffopt/multigrad/multigrad.py b/diffopt/multigrad/multigrad.py index b8b374b..2397c87 100644 --- a/diffopt/multigrad/multigrad.py +++ b/diffopt/multigrad/multigrad.py @@ -224,7 +224,8 @@ def calc_loss_from_sumstats(self, sumstats, sumstats_aux=None, # NOTE: Never jit this method because it uses mpi4py def run_simple_grad_descent(self: Any, guess, - nsteps=100, learning_rate=0.01): + nsteps=100, learning_rate=0.01, + thin=1, progress=True): """ Descend the gradient with a fixed learning rate to optimize parameters, given an initial guess. Stochasticity not allowed. @@ -237,6 +238,11 @@ def run_simple_grad_descent(self: Any, guess, The number of steps to take. learning_rate : float (default=0.001) The fixed learning rate. + thin : int, optional + Return parameters for every `thin` iterations, by default 1. 
Set + `thin=0` to only return final parameters + progress : bool, optional + Display tqdm progress bar, by default True Returns ------- @@ -253,12 +259,14 @@ def run_simple_grad_descent(self: Any, guess, learning_rate=learning_rate, loss_and_grad_func=self.calc_loss_and_grad_from_params, has_aux=False, + thin=thin, + progress=progress ) # NOTE: Never jit this method because it uses mpi4py def run_adam(self: Any, guess, nsteps=100, param_bounds=None, learning_rate=0.01, randkey=None, const_randkey=False, - comm=None): + thin=1, progress=True, comm=None): """ Run adam to descend the gradient and optimize the model parameters, given an initial guess. Stochasticity is allowed if randkey is passed. @@ -280,6 +288,11 @@ def run_adam(self: Any, guess, nsteps=100, param_bounds=None, const_randkey : bool (default=False) By default, randkey is regenerated at each gradient descent iteration. Remove this behavior by setting const_randkey=True + thin : int, optional + Return parameters for every `thin` iterations, by default 1. Set + `thin=0` to only return final parameters + progress : bool, optional + Display tqdm progress bar, by default True Returns ------- @@ -301,14 +314,14 @@ def loss_and_grad_fn(x, _, **kw): params_steps = run_adam( loss_and_grad_fn, params=guess, data=None, nsteps=nsteps, param_bounds=param_bounds, learning_rate=learning_rate, - randkey=randkey + randkey=randkey, thin=thin, progress=progress ) return jnp.asarray(comm.bcast(params_steps, root=0)) # NOTE: Never jit this method because it uses mpi4py def run_bfgs(self: Any, guess, maxsteps=100, param_bounds=None, - randkey=None, comm=None): + randkey=None, progress=True, comm=None): """ Run BFGS to descend the gradient and optimize the model parameters, given an initial guess. Stochasticity must be held fixed via a random @@ -327,6 +340,8 @@ def run_bfgs(self: Any, guess, maxsteps=100, param_bounds=None, Since BFGS requires a deterministic function, this key will be passed to `calc_loss_and_grad_from_params()` as the "randkey" kwarg as a constant at every iteration + progress : bool, optional + Display tqdm progress bar, by default True Returns ------- @@ -349,7 +364,8 @@ def run_bfgs(self: Any, guess, maxsteps=100, param_bounds=None, comm = self.comm if comm is None else comm return run_bfgs( self.calc_loss_and_grad_from_params, guess, maxsteps=maxsteps, - param_bounds=param_bounds, randkey=randkey, comm=comm) + param_bounds=param_bounds, randkey=randkey, + progress=progress, comm=comm) def run_lhs_param_scan(self, xmins, xmaxs, n_dim, num_evaluations, seed=None, randkey=None): @@ -581,22 +597,27 @@ def calc_loss_and_grad_from_params(self, params): # NOTE: Never jit this method because it uses mpi4py def run_simple_grad_descent(self, guess, - nsteps=100, learning_rate=0.01): + nsteps=100, learning_rate=0.01, + thin=1, progress=True): return OnePointModel.run_simple_grad_descent( - self, guess, nsteps, learning_rate) + self, guess, nsteps, learning_rate, thin=thin, progress=progress) # NOTE: Never jit this method because it uses mpi4py - def run_bfgs(self, guess, maxsteps=100, param_bounds=None, randkey=None): + def run_bfgs(self, guess, maxsteps=100, param_bounds=None, randkey=None, + progress=True): return OnePointModel.run_bfgs( self, guess, maxsteps, param_bounds=param_bounds, - randkey=randkey, comm=self.main_comm) + randkey=randkey, progress=progress, + comm=self.main_comm) # NOTE: Never jit this method because it uses mpi4py def run_adam(self, guess, nsteps=100, param_bounds=None, - learning_rate=0.01, randkey=None, 
const_randkey=False): + learning_rate=0.01, randkey=None, const_randkey=False, + thin=1, progress=True): return OnePointModel.run_adam( self, guess, nsteps, param_bounds, learning_rate, randkey, - const_randkey=const_randkey, comm=self.main_comm) + const_randkey=const_randkey, thin=thin, progress=progress, + comm=self.main_comm) def __hash__(self): if isinstance(self.models, OnePointModel): diff --git a/diffopt/multigrad/tests/smf_example/smf_grad_descent.py b/diffopt/multigrad/tests/smf_example/smf_grad_descent.py index 173d603..91e2af8 100644 --- a/diffopt/multigrad/tests/smf_example/smf_grad_descent.py +++ b/diffopt/multigrad/tests/smf_example/smf_grad_descent.py @@ -5,7 +5,10 @@ from typing import NamedTuple from dataclasses import dataclass -from mpi4py import MPI +try: + from mpi4py import MPI +except ImportError: + MPI = None import jax.scipy from jax import numpy as jnp import numpy as np @@ -21,11 +24,15 @@ class ParamTuple(NamedTuple): # Generate fake HMF as power law (truncated so that the SMF has a knee) def load_halo_masses(num_halos=10_000, slope=-2, mmin=10.0 ** 10, qmax=0.95): + if MPI is None: + size, rank = 1, 0 + else: + size, rank = MPI.COMM_WORLD.size, MPI.COMM_WORLD.rank q = jnp.linspace(0, qmax, num_halos) mhalo = mmin * (1 - q) ** (1/(slope+1)) # Assign different halos to different MPI processes - return np.array_split(mhalo, MPI.COMM_WORLD.size)[MPI.COMM_WORLD.rank] + return np.array_split(mhalo, size)[rank] # SMF helper functions @@ -89,6 +96,7 @@ def calc_loss_from_sumstats(self, sumstats, sumstats_aux=None): parser.add_argument("--learning-rate", type=float, default=1e-3) if __name__ == "__main__": + assert MPI is not None, "MPI must be installed to run this script" args = parser.parse_args() data = dict( log_halo_masses=jnp.log10(load_halo_masses(args.num_halos)), diff --git a/diffopt/multigrad/tests/test_mpi.py b/diffopt/multigrad/tests/test_mpi.py index efc0102..b8d2f7f 100644 --- a/diffopt/multigrad/tests/test_mpi.py +++ b/diffopt/multigrad/tests/test_mpi.py @@ -5,18 +5,20 @@ and `mpiexec -n 10 pytest test_mpi.py` all must pass (the --with-mpi flag shouldn't have any effect) """ -from mpi4py import MPI +try: + from mpi4py import MPI +except ImportError: + MPI = None import jax.numpy as jnp +import unittest from ... 
import multigrad from .smf_example import smf_grad_descent as sgd -comm = MPI.COMM_WORLD -rank = comm.Get_rank() -size = comm.Get_size() - +@unittest.skipIf(MPI is None, "MPI must be installed to run this test") def test_reduce_sum(): + rank, size = MPI.COMM_WORLD.rank, MPI.COMM_WORLD.size # Set value equal to the rank of the process value = jnp.array(rank) @@ -24,7 +26,7 @@ def test_reduce_sum(): result = multigrad.reduce_sum(value) # Gather the results from all processes - gathered_results = jnp.array(comm.allgather(result)) + gathered_results = jnp.array(MPI.COMM_WORLD.allgather(result)) if not rank: # Perform testing only on the rank 0 process @@ -58,7 +60,7 @@ def test_simple_grad_descent_pipeline(): gd_loss, gd_params = gd_iterations.loss, gd_iterations.params assert jnp.isclose(gd_loss[-1], 0.0) assert jnp.allclose(gd_params[-1], jnp.array([*truth])) - assert jnp.allclose(true_gradloss, 0.0, atol=1e-5) + assert jnp.allclose(true_gradloss, 0.0, atol=1e-4) # Calculate grad(loss) with the more memory efficient method loss, dloss_dparams = model.calc_loss_and_grad_from_params(truth) diff --git a/diffopt/multigrad/util.py b/diffopt/multigrad/util.py index c39a161..9d4cf4c 100644 --- a/diffopt/multigrad/util.py +++ b/diffopt/multigrad/util.py @@ -36,12 +36,12 @@ "latin_hypercube_sampler", "scatter_nd"] -def trange_no_tqdm(n, desc=None): +def trange_no_tqdm(n, desc=None, disable=False): return range(n) -def trange_with_tqdm(n, desc=None): - return tqdm.trange(n, desc=desc) +def trange_with_tqdm(n, desc=None, disable=False): + return tqdm.trange(n, desc=desc, disable=disable) trange = trange_no_tqdm if tqdm is None else trange_with_tqdm @@ -85,6 +85,8 @@ def simple_grad_descent( loss_and_grad_func=None, grad_loss_func=None, has_aux=False, + thin=1, + progress=True, **kwargs, ): if loss_and_grad_func is None: @@ -115,13 +117,19 @@ def loopfunc(state, _x): # The below is equivalent to lax.scan without jitting # =================================================== - initstate = (0.0, guess) + state = (0.0, guess) loss, params, aux = [], [], [] - for x in trange(nsteps, desc="Simple Gradient Descent Progress"): - initstate, y = loopfunc(initstate, x) - loss.append(y[0]) - params.append(y[1]) - aux.append(y[2]) + for x in trange(nsteps, desc="Simple Gradient Descent Progress", + disable=not progress): + state, y = loopfunc(state, x) + if x == nsteps - 1 or (thin and x % thin == thin - 1): + loss.append(y[0]) + params.append(y[1]) + aux.append(y[2]) + if not thin: + loss = loss[-1] + params = params[-1] + aux = aux[-1] loss = jnp.array(loss) params = jnp.array(params) if has_aux: diff --git a/diffopt/multiswarm/pso_update.py b/diffopt/multiswarm/pso_update.py index f08f491..0db1588 100644 --- a/diffopt/multiswarm/pso_update.py +++ b/diffopt/multiswarm/pso_update.py @@ -95,7 +95,8 @@ def __init__(self, nparticles, ndim, xlow, xhigh, seed=0, self.social_weight = social_weight self.vmax_frac = vmax_frac - def run_pso(self, lossfunc, nsteps=100, keep_init_random_state=False): + def run_pso(self, lossfunc, nsteps=100, progress=True, + keep_init_random_state=False): """ Run particle swarm optimization (PSO) @@ -106,6 +107,8 @@ def run_pso(self, lossfunc, nsteps=100, keep_init_random_state=False): with signature `lossfunc(x)` where x is an array of shape `(ndim,)` nsteps : int, optional Number of time step iterations, by default 100 + progress : bool, optional + Display tqdm progress bar, by default True keep_init_random_state : bool, optional Set True to be able to rerun an identical run, or False 
(default) to continue a run by manually setting swarm.x_init and swarm.v_init @@ -140,12 +143,12 @@ def run_pso(self, lossfunc, nsteps=100, keep_init_random_state=False): loc_loss_history = [[] for _ in range(self.num_particles_on_this_rank)] start = time() - def trange(x): + def trange(x, disable=False): if self.comm.rank: return range(x) else: - return tqdm.trange(x, desc="PSO Progress") - for _ in trange(nsteps): + return tqdm.trange(x, desc="PSO Progress", disable=disable) + for _ in trange(nsteps, disable=not progress): istep_loss = [None for _ in range(self.num_particles_on_this_rank)] for ip in range(self.num_particles_on_this_rank): update_key = jran.split(particle_keys[ip], 1)[0] diff --git a/docs/requirements.txt b/docs/requirements.txt index aa7d3c2..33ee71c 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,5 +1,6 @@ sphinx_rtd_theme nbsphinx +myst_parser IPython matplotlib seaborn diff --git a/docs/source/conf.py b/docs/source/conf.py index 39e2718..174274d 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -25,6 +25,7 @@ "sphinx.ext.napoleon", "sphinx.ext.viewcode", "nbsphinx", + "myst_parser", ] templates_path = ['_templates'] diff --git a/docs/source/include_contributing.rst b/docs/source/include_contributing.rst new file mode 100644 index 0000000..7a661a3 --- /dev/null +++ b/docs/source/include_contributing.rst @@ -0,0 +1,2 @@ +.. include:: ../../CONTRIBUTING.md + :parser: myst_parser.sphinx_ diff --git a/docs/source/index.rst b/docs/source/index.rst index ee233cf..39344d5 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -20,6 +20,7 @@ Overview -------- * :doc:`installation` +* :doc:`include_contributing` * :doc:`reference` :mod:`multigrad` @@ -55,6 +56,7 @@ Indices and tables :hidden: installation.rst + include_contributing.md reference.rst diff --git a/docs/source/multigrad/smf_gradient_descent.py b/docs/source/multigrad/smf_gradient_descent.py index c31f739..f90e9f8 100644 --- a/docs/source/multigrad/smf_gradient_descent.py +++ b/docs/source/multigrad/smf_gradient_descent.py @@ -1,4 +1,9 @@ -from mpi4py import MPI +try: + from mpi4py import MPI + COMM_WORLD = MPI.COMM_WORLD +except ImportError: + MPI = None + COMM_WORLD = None import jax.scipy from jax import numpy as jnp import numpy as np @@ -7,13 +12,19 @@ from diffopt import multigrad -def load_halo_masses(num_halos=10_000, comm=MPI.COMM_WORLD): +def load_halo_masses(num_halos=10_000, comm=COMM_WORLD): + if comm is None: + size = 1 + rank = 0 + else: + size = comm.size + rank = comm.rank # Generate fake halo masses between 10^10 < M_h < 10^11 as a power law quantile = jnp.linspace(0, 0.9, num_halos) mhalo = 1e10 / (1 - quantile) # Assign halos evenly across given MPI ranks (only one rank for now) - return np.array_split(mhalo, comm.size)[comm.rank] + return np.array_split(mhalo, size)[rank] # Compute one bin of the stellar mass function (SMF) @@ -92,6 +103,7 @@ def calc_loss_from_sumstats(self, sumstats): if __name__ == "__main__": + assert MPI is not None, "MPI must be installed to run this script" volume = 1.0 smf_bin_edges = jnp.linspace(9, 10, 11) true_params = jnp.array([-2.0, -0.5])
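
The new `thin` and `progress` keywords added to `kdescent`'s `adam` above can be exercised with a toy loss. This is a minimal sketch, not part of the patch: the quadratic `lossfunc`, `guess`, and step counts are invented for illustration, and the function is imported directly from the module path shown in the diff (`diffopt/kdescent/descent.py`).

```python
import jax.numpy as jnp
from diffopt.kdescent.descent import adam  # module path as in this patch

def lossfunc(params, randkey=None):
    # Toy quadratic loss. It accepts `randkey` because adam() passes a fresh
    # PRNG key at every iteration unless const_randkey=True.
    return jnp.sum((params - 3.0) ** 2)

guess = jnp.zeros(2)

# Store every 10th iterate (plus the final one) and silence the tqdm bar:
steps = adam(lossfunc, guess, nsteps=100, thin=10, progress=False)
print(steps.shape)   # (10, 2) -- one row per stored iteration

# thin=0 keeps only the final parameters:
final = adam(lossfunc, guess, nsteps=100, thin=0, progress=False)
print(final.shape)   # (2,)
```

Note a behavior change implied by the diff: the returned history no longer includes the initial guess (the loop now starts from an empty list), so `steps[0]` is the first stored update rather than `guess`.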
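
The thinning condition shared by `kdescent/descent.py`, `multigrad/adam.py`, and `multigrad/util.py` keeps an iterate when `step == nsteps - 1 or (thin and step % thin == thin - 1)`. A small self-contained check of which indices that retains:

```python
nsteps = 10
for thin in (1, 3, 0):
    kept = [step for step in range(nsteps)
            if step == nsteps - 1 or (thin and step % thin == thin - 1)]
    print(f"thin={thin}: {kept}")

# thin=1: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]   (every iteration)
# thin=3: [2, 5, 8, 9]                     (every 3rd iteration, plus the final one)
# thin=0: [9]                              (final iteration only)
```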
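
The lower-level `run_adam` and `run_bfgs` helpers gain the same knobs. The sketch below assumes a single process (the MPI branches are bypassed when `mpi4py` is absent or there is one rank); the `loss_and_grad` function and `target` data are invented for illustration, and `run_bfgs`'s return value, which is not shown in this patch, is not inspected.

```python
import jax
import jax.numpy as jnp
from diffopt.multigrad.adam import run_adam
from diffopt.multigrad.bfgs import run_bfgs

def loss_and_grad(params, data, **kwargs):
    # `**kwargs` absorbs the optional "randkey" keyword that the optimizer
    # may forward; this toy loss ignores it.
    return jax.value_and_grad(lambda p: jnp.sum((p - data) ** 2))(params)

target = jnp.array([1.0, -2.0])
guess = jnp.zeros(2)

# Keep every 20th iterate (plus the final one) without a progress bar:
param_steps = run_adam(loss_and_grad, params=guess, data=target,
                       nsteps=200, learning_rate=0.05,
                       thin=20, progress=False)
# param_steps.shape == (10, 2)

# run_bfgs takes a loss-and-grad function of the parameters only and now
# also accepts progress=False:
result = run_bfgs(lambda p, **kw: loss_and_grad(p, target, **kw),
                  guess, maxsteps=50, progress=False, comm=None)
```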
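
The tqdm wrappers in `multigrad/adam.py`, `multigrad/bfgs.py`, and `multigrad/util.py` now accept `disable` in both the tqdm and the no-tqdm fallback, so callers can pass `disable=not progress` unconditionally. Reproduced standalone as a sketch of that design choice:

```python
try:
    import tqdm
except ImportError:
    tqdm = None

def trange_no_tqdm(n, desc=None, disable=False):
    # The fallback accepts (and ignores) the same keywords as the tqdm version.
    return range(n)

def trange_with_tqdm(n, desc="Progress", disable=False):
    return tqdm.trange(n, desc=desc, disable=disable)

trange = trange_no_tqdm if tqdm is None else trange_with_tqdm

for _ in trange(5, disable=True):   # runs silently whether or not tqdm is installed
    pass
```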
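
Several files in this patch switch to importing `mpi4py` inside a `try`/`except` so that single-process runs (and the GitHub workflow, which does not install `mpi4py`) still work. The combined pattern, including the `unittest.skipIf` guard used in `test_mpi.py` and the rank-based split used in the SMF examples, looks like this; the `values` array is purely illustrative.

```python
import unittest

import numpy as np

try:
    from mpi4py import MPI
    COMM = MPI.COMM_WORLD
except ImportError:   # mpi4py is optional: fall back to a single "rank"
    MPI = None
    COMM = None

size = 1 if COMM is None else COMM.size
rank = 0 if COMM is None else COMM.rank

# Give each rank an even share of the data (everything lands on rank 0 when
# running without MPI), as in load_halo_masses().
values = np.arange(100.0)
my_values = np.array_split(values, size)[rank]

@unittest.skipIf(MPI is None, "MPI must be installed to run this test")
def test_needs_mpi():
    assert MPI.COMM_WORLD.size >= 1
```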