diff --git a/.github/workflows/test_package.yml b/.github/workflows/test_package.yml
new file mode 100644
index 0000000..6bdd95e
--- /dev/null
+++ b/.github/workflows/test_package.yml
@@ -0,0 +1,46 @@
+name: Test diffopt
+
+on:
+  workflow_dispatch: null  # allow manual runs from the Actions tab
+  schedule:
+    # Runs "every Monday at noon UTC"
+    - cron: '0 12 * * 1'
+  push:
+    branches:
+      - main
+  pull_request: null
+
+jobs:
+  build:
+    runs-on: ubuntu-latest
+    strategy:
+      matrix:
+        python-version: ['3.9', '3.10', '3.11']  # quoted so 3.10 != 3.1
+    steps:
+      - uses: actions/checkout@v4
+      - name: Set up Python ${{ matrix.python-version }}
+        uses: conda-incubator/setup-miniconda@v3
+        with:
+          activate-environment: test${{ matrix.python-version }}
+          python-version: ${{ matrix.python-version }}
+          channels: conda-forge,defaults
+          channel-priority: strict
+          show-channel-urls: true
+          miniforge-version: latest
+
+      - name: Install dependencies
+        shell: bash -l {0}  # login shell so the conda env is activated
+        run: |
+          conda install -yq jax
+          conda install -yq pip pytest pytest-cov flake8
+
+      - name: Install package
+        shell: bash -l {0}
+        run: |
+          pip install -e .
+
+      - name: Run tests  # PYTHONWARNINGS=error turns warnings into failures
+        shell: bash -l {0}
+        run: |
+          export PYTHONWARNINGS=error
+          pytest -v
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
new file mode 100644
index 0000000..0458016
--- /dev/null
+++ b/CONTRIBUTING.md
@@ -0,0 +1,36 @@
+# Contributing to `diffopt`
+
+Thank you for your interest in contributing to this project. All questions and ideas for improvement are welcome; please share them by opening an issue or a pull request.
+
+Before contributing, familiarize yourself with our resources:
+
+- [Source Code](https://github.com/AlanPearl/diffopt)
+- [Documentation](https://diffopt.readthedocs.io)
+
+## Issues
+
+You can open an [issue](https://github.com/AlanPearl/diffopt/issues) if you:
+
+- Have encountered a bug or issue when using the software
+- Would like to see a new feature
+- Are seeking support that could not be resolved by reading the documentation
+
+## Pull Requests
+
+If you would like to directly submit your own change to the software, thank you! Here's how:
+
+- Fork [this repository](https://github.com/AlanPearl/diffopt).
+- Please remember to include a concise, self-contained unit test in your pull request. Ensure that all tests pass (see [Manual Testing](#manual-testing)).
+- Open a [pull request](https://github.com/AlanPearl/diffopt/pulls).
+
+## Manual Testing
+
+Make sure you have installed diffopt as described in the [docs](https://diffopt.readthedocs.io/en/latest/installation.html). To run all tests from the main directory:
+
+```bash
+pip install pytest
+pytest .
+mpirun -n 2 pytest .
+```
+
+Note that unit tests that require `mpi4py` are not run automatically by the GitHub workflows. Therefore, you must run these tests manually with `mpi4py` installed to ensure that all tests pass.
diff --git a/diffopt/kdescent/descent.py b/diffopt/kdescent/descent.py
index 8069e23..6517feb 100644
--- a/diffopt/kdescent/descent.py
+++ b/diffopt/kdescent/descent.py
@@ -12,7 +12,8 @@
 
 
 def adam(lossfunc, guess, nsteps=100, param_bounds=None,
-         learning_rate=0.01, randkey=1, const_randkey=False, **other_kwargs):
+         learning_rate=0.01, randkey=1, const_randkey=False,
+         thin=1, progress=True, **other_kwargs):
     """
     Perform gradient descent
 
@@ -36,6 +37,11 @@ def adam(lossfunc, guess, nsteps=100, param_bounds=None,
     const_randkey : bool, optional
         By default (False), randkey is regenerated at each gradient descent
         iteration. Remove this behavior by setting const_randkey=True
+    thin : int, optional
+        Return parameters for every `thin` iterations, by default 1. Set
+        `thin=0` to only return final parameters
+    progress : bool, optional
+        Display tqdm progress bar, by default True
 
     Returns
     -------
@@ -46,7 +52,7 @@ def adam(lossfunc, guess, nsteps=100, param_bounds=None,
     if param_bounds is None:
         return adam_unbounded(
             lossfunc, guess, nsteps, learning_rate, randkey,
-            const_randkey, **other_kwargs)
+            const_randkey, thin, progress, **other_kwargs)
 
     assert len(guess) == len(param_bounds)
     if hasattr(param_bounds, "tolist"):
@@ -60,14 +66,15 @@ def ulossfunc(uparams, *args, **kwargs):
     init_uparams = apply_transforms(guess, param_bounds)
     uparams = adam_unbounded(
         ulossfunc, init_uparams, nsteps, learning_rate, randkey,
-        const_randkey, **other_kwargs)
+        const_randkey, thin, progress, **other_kwargs)
     params = apply_inverse_transforms(uparams.T, param_bounds).T
 
     return params
 
 
 def adam_unbounded(lossfunc, guess, nsteps=100, learning_rate=0.01,
-                   randkey=1, const_randkey=False, **other_kwargs):
+                   randkey=1, const_randkey=False,
+                   thin=1, progress=True, **other_kwargs):
     kwargs = {**other_kwargs}
     if randkey is not None:
         randkey = keygen.init_randkey(randkey)
@@ -78,13 +85,18 @@ def adam_unbounded(lossfunc, guess, nsteps=100, learning_rate=0.01,
     opt = optax.adam(learning_rate)
     solver = jaxopt.OptaxSolver(opt=opt, fun=lossfunc, maxiter=nsteps)
     state = solver.init_state(guess, **kwargs)
-    params = [guess]
-    for _ in tqdm.trange(nsteps, desc="Adam Gradient Descent Progress"):
+    params = []
+    params_i = guess
+    for i in tqdm.trange(nsteps, disable=not progress,
+                         desc="Adam Gradient Descent Progress"):
         if randkey is not None:
             randkey, key_i = jax.random.split(randkey)
             kwargs["randkey"] = key_i
-        params_i, state = solver.update(params[-1], state, **kwargs)
-        params.append(params_i)
+        params_i, state = solver.update(params_i, state, **kwargs)
+        if i == nsteps - 1 or (thin and i % thin == thin - 1):
+            params.append(params_i)
+    if not thin:
+        params = params[-1]
 
     return jnp.array(params)
 
diff --git a/diffopt/multigrad/adam.py b/diffopt/multigrad/adam.py
index 078c47e..d87bc01 100644
--- a/diffopt/multigrad/adam.py
+++ b/diffopt/multigrad/adam.py
@@ -25,12 +25,12 @@
     N_RANKS = 1
 
 
-def trange_no_tqdm(n, desc=None):
+def trange_no_tqdm(n, desc=None, disable=False):
     return range(n)
 
 
-def trange_with_tqdm(n, desc="Adam Gradient Descent Progress"):
-    return tqdm.trange(n, desc=desc)
+def trange_with_tqdm(n, desc="Adam Gradient Descent Progress", disable=False):
+    return tqdm.trange(n, desc=desc, disable=disable)
 
 
 adam_trange = trange_no_tqdm if tqdm is None else trange_with_tqdm
@@ -49,27 +49,32 @@ def _master_wrapper(params, logloss_and_grad_fn, data, randkey=None):
     return loss, grad
 
 
-def _adam_optimizer(params, fn, fn_data, nsteps, learning_rate, randkey=None):
+def _adam_optimizer(params, fn, fn_data, nsteps, learning_rate, randkey=None,
+                    thin=1, progress=True):
     kwargs = {}
     # Note: Might be recommended to use optax instead of jax.example_libraries
     opt_init, opt_update, get_params = jax_opt.adam(learning_rate)
     opt_state = opt_init(params)
 
-    param_steps = [params]
-    for step in adam_trange(nsteps):
+    param_steps = []
+    for step in adam_trange(nsteps, disable=not progress):
         if randkey is not None:
             randkey, key_i = jax.random.split(randkey)
             kwargs["randkey"] = key_i
         _, grad = fn(params, *fn_data, **kwargs)
         opt_state = opt_update(step, grad, opt_state)
         params = get_params(opt_state)
-        param_steps.append(params)
+        if step == nsteps - 1 or (thin and step % thin == thin - 1):
+            param_steps.append(params)
+    if not thin:
+        param_steps = param_steps[-1]
 
     return jnp.array(param_steps)
 
 
 def run_adam_unbounded(logloss_and_grad_fn, params, data, nsteps=100,
-                       learning_rate=0.01, randkey=None):
+                       learning_rate=0.01, randkey=None,
+                       thin=1, progress=True):
     """Run the adam optimizer on a loss function with a custom gradient.
 
     Parameters
@@ -88,6 +93,11 @@ def run_adam_unbounded(logloss_and_grad_fn, params, data, nsteps=100,
     randkey : int | PRNG Key
         If given, a new PRNG Key will be generated at each iteration and be
         passed to `logloss_and_grad_fn` under the "randkey" kwarg
+    thin : int, optional
+        Return parameters for every `thin` iterations, by default 1. Set
+        `thin=0` to only return final parameters
+    progress : bool, optional
+        Display tqdm progress bar, by default True
 
     Returns
     -------
@@ -104,7 +114,7 @@ def run_adam_unbounded(logloss_and_grad_fn, params, data, nsteps=100,
         fn_data = (logloss_and_grad_fn, data)
 
         params = _adam_optimizer(params, fn, fn_data, nsteps, learning_rate,
-                                 randkey=randkey)
+                                 randkey=randkey, thin=thin, progress=progress)
 
         if COMM is not None:
             COMM.bcast("exit", root=0)
@@ -131,7 +141,7 @@ def run_adam_unbounded(logloss_and_grad_fn, params, data, nsteps=100,
 
 
 def run_adam(logloss_and_grad_fn, params, data, nsteps=100, param_bounds=None,
-             learning_rate=0.01, randkey=None):
+             learning_rate=0.01, randkey=None, thin=1, progress=True):
     """Run the adam optimizer on a loss function with a custom gradient.
 
     Parameters
@@ -153,6 +163,11 @@ def run_adam(logloss_and_grad_fn, params, data, nsteps=100, param_bounds=None,
     randkey : int | PRNG Key
         If given, a new PRNG Key will be generated at each iteration and be
         passed to `logloss_and_grad_fn` under the "randkey" kwarg
+    thin : int, optional
+        Return parameters for every `thin` iterations, by default 1. Set
+        `thin=0` to only return final parameters
+    progress : bool, optional
+        Display tqdm progress bar, by default True
 
     Returns
     -------
@@ -162,7 +177,8 @@ def run_adam(logloss_and_grad_fn, params, data, nsteps=100, param_bounds=None,
     if param_bounds is None:
         return run_adam_unbounded(
             logloss_and_grad_fn, params, data, nsteps=nsteps,
-            learning_rate=learning_rate, randkey=randkey)
+            learning_rate=learning_rate, randkey=randkey,
+            thin=thin, progress=progress)
 
     assert len(params) == len(param_bounds)
     if hasattr(param_bounds, "tolist"):
@@ -182,7 +198,8 @@ def unbound_loss_and_grad(uparams, *args, **kwargs):
 
     uparams = apply_trans(params)
     final_uparams = run_adam_unbounded(
-        unbound_loss_and_grad, uparams, data, nsteps, learning_rate, randkey)
+        unbound_loss_and_grad, uparams, data, nsteps, learning_rate, randkey,
+        thin, progress)
 
     if RANK == 0:
         final_params = invert_trans(final_uparams.T).T
diff --git a/diffopt/multigrad/bfgs.py b/diffopt/multigrad/bfgs.py
index 302875e..93d65eb 100644
--- a/diffopt/multigrad/bfgs.py
+++ b/diffopt/multigrad/bfgs.py
@@ -18,19 +18,19 @@
     N_RANKS = 1
 
 
-def trange_no_tqdm(n, desc=None):
+def trange_no_tqdm(n, desc=None, disable=False):
     return range(n)
 
 
-def trange_with_tqdm(n, desc="BFGS Gradient Descent Progress"):
-    return tqdm.trange(n, desc=desc, leave=True)
+def trange_with_tqdm(n, desc="BFGS Gradient Descent Progress", disable=False):
+    return tqdm.trange(n, desc=desc, leave=True, disable=disable)
 
 
 bfgs_trange = trange_no_tqdm if tqdm is None else trange_with_tqdm
 
 
 def run_bfgs(loss_and_grad_fn, params, maxsteps=100, param_bounds=None,
-             randkey=None, comm=COMM):
+             randkey=None, progress=True, comm=COMM):
     """Run the adam optimizer on a loss function with a custom gradient.
 
     Parameters
@@ -46,6 +46,8 @@ def run_bfgs(loss_and_grad_fn, params, maxsteps=100, param_bounds=None,
         `None` as the bound for each unbounded parameter, by default None
     randkey : int | PRNG Key (default=None)
         This will be passed to `logloss_and_grad_fn` under the "randkey" kwarg
+    progress : bool, optional
+        Display tqdm progress bar, by default True
     comm : MPI Communicator (default=COMM_WORLD)
         Communicator between all desired MPI ranks
 
@@ -66,7 +68,7 @@ def run_bfgs(loss_and_grad_fn, params, maxsteps=100, param_bounds=None,
         kwargs["randkey"] = randkey
 
     if comm is None or comm.rank == 0:
-        pbar = bfgs_trange(maxsteps)
+        pbar = bfgs_trange(maxsteps, disable=not progress)
 
         # Wrap loss_and_grad function with commands to the worker ranks
         def loss_and_grad_fn_root(params):
diff --git a/diffopt/multigrad/multigrad.py b/diffopt/multigrad/multigrad.py
index b8b374b..2397c87 100644
--- a/diffopt/multigrad/multigrad.py
+++ b/diffopt/multigrad/multigrad.py
@@ -224,7 +224,8 @@ def calc_loss_from_sumstats(self, sumstats, sumstats_aux=None,
 
     # NOTE: Never jit this method because it uses mpi4py
     def run_simple_grad_descent(self: Any, guess,
-                                nsteps=100, learning_rate=0.01):
+                                nsteps=100, learning_rate=0.01,
+                                thin=1, progress=True):
         """
         Descend the gradient with a fixed learning rate to optimize parameters,
         given an initial guess. Stochasticity not allowed.
@@ -237,6 +238,11 @@ def run_simple_grad_descent(self: Any, guess,
             The number of steps to take.
         learning_rate : float (default=0.001)
             The fixed learning rate.
+        thin : int, optional
+            Return parameters for every `thin` iterations, by default 1. Set
+            `thin=0` to only return final parameters
+        progress : bool, optional
+            Display tqdm progress bar, by default True
 
         Returns
         -------
@@ -253,12 +259,14 @@ def run_simple_grad_descent(self: Any, guess,
             learning_rate=learning_rate,
             loss_and_grad_func=self.calc_loss_and_grad_from_params,
             has_aux=False,
+            thin=thin,
+            progress=progress
         )
 
     # NOTE: Never jit this method because it uses mpi4py
     def run_adam(self: Any, guess, nsteps=100, param_bounds=None,
                  learning_rate=0.01, randkey=None, const_randkey=False,
-                 comm=None):
+                 thin=1, progress=True, comm=None):
         """
         Run adam to descend the gradient and optimize the model parameters,
         given an initial guess. Stochasticity is allowed if randkey is passed.
@@ -280,6 +288,11 @@ def run_adam(self: Any, guess, nsteps=100, param_bounds=None,
         const_randkey : bool (default=False)
             By default, randkey is regenerated at each gradient descent
             iteration. Remove this behavior by setting const_randkey=True
+        thin : int, optional
+            Return parameters for every `thin` iterations, by default 1. Set
+            `thin=0` to only return final parameters
+        progress : bool, optional
+            Display tqdm progress bar, by default True
 
         Returns
         -------
@@ -301,14 +314,14 @@ def loss_and_grad_fn(x, _, **kw):
         params_steps = run_adam(
             loss_and_grad_fn, params=guess, data=None, nsteps=nsteps,
             param_bounds=param_bounds, learning_rate=learning_rate,
-            randkey=randkey
+            randkey=randkey, thin=thin, progress=progress
         )
 
         return jnp.asarray(comm.bcast(params_steps, root=0))
 
     # NOTE: Never jit this method because it uses mpi4py
     def run_bfgs(self: Any, guess, maxsteps=100, param_bounds=None,
-                 randkey=None, comm=None):
+                 randkey=None, progress=True, comm=None):
         """
         Run BFGS to descend the gradient and optimize the model parameters,
         given an initial guess. Stochasticity must be held fixed via a random
@@ -327,6 +340,8 @@ def run_bfgs(self: Any, guess, maxsteps=100, param_bounds=None,
             Since BFGS requires a deterministic function, this key will be
             passed to `calc_loss_and_grad_from_params()` as the "randkey" kwarg
             as a constant at every iteration
+        progress : bool, optional
+            Display tqdm progress bar, by default True
 
         Returns
         -------
@@ -349,7 +364,8 @@ def run_bfgs(self: Any, guess, maxsteps=100, param_bounds=None,
         comm = self.comm if comm is None else comm
         return run_bfgs(
             self.calc_loss_and_grad_from_params, guess, maxsteps=maxsteps,
-            param_bounds=param_bounds, randkey=randkey, comm=comm)
+            param_bounds=param_bounds, randkey=randkey,
+            progress=progress, comm=comm)
 
     def run_lhs_param_scan(self, xmins, xmaxs, n_dim,
                            num_evaluations, seed=None, randkey=None):
@@ -581,22 +597,27 @@ def calc_loss_and_grad_from_params(self, params):
 
     # NOTE: Never jit this method because it uses mpi4py
     def run_simple_grad_descent(self, guess,
-                                nsteps=100, learning_rate=0.01):
+                                nsteps=100, learning_rate=0.01,
+                                thin=1, progress=True):
         return OnePointModel.run_simple_grad_descent(
-            self, guess, nsteps, learning_rate)
+            self, guess, nsteps, learning_rate, thin=thin, progress=progress)
 
     # NOTE: Never jit this method because it uses mpi4py
-    def run_bfgs(self, guess, maxsteps=100, param_bounds=None, randkey=None):
+    def run_bfgs(self, guess, maxsteps=100, param_bounds=None, randkey=None,
+                 progress=True):
         return OnePointModel.run_bfgs(
             self, guess, maxsteps, param_bounds=param_bounds,
-            randkey=randkey, comm=self.main_comm)
+            randkey=randkey, progress=progress,
+            comm=self.main_comm)
 
     # NOTE: Never jit this method because it uses mpi4py
     def run_adam(self, guess, nsteps=100, param_bounds=None,
-                 learning_rate=0.01, randkey=None, const_randkey=False):
+                 learning_rate=0.01, randkey=None, const_randkey=False,
+                 thin=1, progress=True):
         return OnePointModel.run_adam(
             self, guess, nsteps, param_bounds, learning_rate, randkey,
-            const_randkey=const_randkey, comm=self.main_comm)
+            const_randkey=const_randkey, thin=thin, progress=progress,
+            comm=self.main_comm)
 
     def __hash__(self):
         if isinstance(self.models, OnePointModel):
diff --git a/diffopt/multigrad/tests/smf_example/smf_grad_descent.py b/diffopt/multigrad/tests/smf_example/smf_grad_descent.py
index 173d603..91e2af8 100644
--- a/diffopt/multigrad/tests/smf_example/smf_grad_descent.py
+++ b/diffopt/multigrad/tests/smf_example/smf_grad_descent.py
@@ -5,7 +5,10 @@
 from typing import NamedTuple
 from dataclasses import dataclass
 
-from mpi4py import MPI
+try:
+    from mpi4py import MPI
+except ImportError:
+    MPI = None
 import jax.scipy
 from jax import numpy as jnp
 import numpy as np
@@ -21,11 +24,15 @@ class ParamTuple(NamedTuple):
 
 # Generate fake HMF as power law (truncated so that the SMF has a knee)
 def load_halo_masses(num_halos=10_000, slope=-2, mmin=10.0 ** 10, qmax=0.95):
+    if MPI is None:
+        size, rank = 1, 0
+    else:
+        size, rank = MPI.COMM_WORLD.size, MPI.COMM_WORLD.rank
     q = jnp.linspace(0, qmax, num_halos)
     mhalo = mmin * (1 - q) ** (1/(slope+1))
 
     # Assign different halos to different MPI processes
-    return np.array_split(mhalo, MPI.COMM_WORLD.size)[MPI.COMM_WORLD.rank]
+    return np.array_split(mhalo, size)[rank]
 
 
 # SMF helper functions
@@ -89,6 +96,7 @@ def calc_loss_from_sumstats(self, sumstats, sumstats_aux=None):
 parser.add_argument("--learning-rate", type=float, default=1e-3)
 
 if __name__ == "__main__":
+    assert MPI is not None, "MPI must be installed to run this script"
     args = parser.parse_args()
     data = dict(
         log_halo_masses=jnp.log10(load_halo_masses(args.num_halos)),
diff --git a/diffopt/multigrad/tests/test_mpi.py b/diffopt/multigrad/tests/test_mpi.py
index efc0102..b8d2f7f 100644
--- a/diffopt/multigrad/tests/test_mpi.py
+++ b/diffopt/multigrad/tests/test_mpi.py
@@ -5,18 +5,20 @@
 and `mpiexec -n 10 pytest test_mpi.py` all must pass
 (the --with-mpi flag shouldn't have any effect)
 """
-from mpi4py import MPI
+try:
+    from mpi4py import MPI
+except ImportError:
+    MPI = None
 import jax.numpy as jnp
 
+import unittest
 from ... import multigrad
 from .smf_example import smf_grad_descent as sgd
 
-comm = MPI.COMM_WORLD
-rank = comm.Get_rank()
-size = comm.Get_size()
-
 
+@unittest.skipIf(MPI is None, "MPI must be installed to run this test")
 def test_reduce_sum():
+    rank, size = MPI.COMM_WORLD.rank, MPI.COMM_WORLD.size
     # Set value equal to the rank of the process
     value = jnp.array(rank)
 
@@ -24,7 +26,7 @@ def test_reduce_sum():
     result = multigrad.reduce_sum(value)
 
     # Gather the results from all processes
-    gathered_results = jnp.array(comm.allgather(result))
+    gathered_results = jnp.array(MPI.COMM_WORLD.allgather(result))
 
     if not rank:
         # Perform testing only on the rank 0 process
@@ -58,7 +60,7 @@ def test_simple_grad_descent_pipeline():
     gd_loss, gd_params = gd_iterations.loss, gd_iterations.params
     assert jnp.isclose(gd_loss[-1], 0.0)
     assert jnp.allclose(gd_params[-1], jnp.array([*truth]))
-    assert jnp.allclose(true_gradloss, 0.0, atol=1e-5)
+    assert jnp.allclose(true_gradloss, 0.0, atol=1e-4)
 
     # Calculate grad(loss) with the more memory efficient method
     loss, dloss_dparams = model.calc_loss_and_grad_from_params(truth)
diff --git a/diffopt/multigrad/util.py b/diffopt/multigrad/util.py
index c39a161..9d4cf4c 100644
--- a/diffopt/multigrad/util.py
+++ b/diffopt/multigrad/util.py
@@ -36,12 +36,12 @@
            "latin_hypercube_sampler", "scatter_nd"]
 
 
-def trange_no_tqdm(n, desc=None):
+def trange_no_tqdm(n, desc=None, disable=False):
     return range(n)
 
 
-def trange_with_tqdm(n, desc=None):
-    return tqdm.trange(n, desc=desc)
+def trange_with_tqdm(n, desc=None, disable=False):
+    return tqdm.trange(n, desc=desc, disable=disable)
 
 
 trange = trange_no_tqdm if tqdm is None else trange_with_tqdm
@@ -85,6 +85,8 @@ def simple_grad_descent(
     loss_and_grad_func=None,
     grad_loss_func=None,
     has_aux=False,
+    thin=1,
+    progress=True,
     **kwargs,
 ):
     if loss_and_grad_func is None:
@@ -115,13 +117,19 @@ def loopfunc(state, _x):
 
     # The below is equivalent to lax.scan without jitting
     # ===================================================
-    initstate = (0.0, guess)
+    state = (0.0, guess)
     loss, params, aux = [], [], []
-    for x in trange(nsteps, desc="Simple Gradient Descent Progress"):
-        initstate, y = loopfunc(initstate, x)
-        loss.append(y[0])
-        params.append(y[1])
-        aux.append(y[2])
+    for x in trange(nsteps, desc="Simple Gradient Descent Progress",
+                    disable=not progress):
+        state, y = loopfunc(state, x)
+        if x == nsteps - 1 or (thin and x % thin == thin - 1):
+            loss.append(y[0])
+            params.append(y[1])
+            aux.append(y[2])
+    if not thin:
+        loss = loss[-1]
+        params = params[-1]
+        aux = aux[-1]
     loss = jnp.array(loss)
     params = jnp.array(params)
     if has_aux:
diff --git a/diffopt/multiswarm/pso_update.py b/diffopt/multiswarm/pso_update.py
index f08f491..0db1588 100644
--- a/diffopt/multiswarm/pso_update.py
+++ b/diffopt/multiswarm/pso_update.py
@@ -95,7 +95,8 @@ def __init__(self, nparticles, ndim, xlow, xhigh, seed=0,
         self.social_weight = social_weight
         self.vmax_frac = vmax_frac
 
-    def run_pso(self, lossfunc, nsteps=100, keep_init_random_state=False):
+    def run_pso(self, lossfunc, nsteps=100, progress=True,
+                keep_init_random_state=False):
         """
         Run particle swarm optimization (PSO)
 
@@ -106,6 +107,8 @@ def run_pso(self, lossfunc, nsteps=100, keep_init_random_state=False):
             with signature `lossfunc(x)` where x is an array of shape `(ndim,)`
         nsteps : int, optional
             Number of time step iterations, by default 100
+        progress : bool, optional
+            Display tqdm progress bar, by default True
         keep_init_random_state : bool, optional
             Set True to be able to rerun an identical run, or False (default)
             to continue a run by manually setting swarm.x_init and swarm.v_init
@@ -140,12 +143,12 @@ def run_pso(self, lossfunc, nsteps=100, keep_init_random_state=False):
         loc_loss_history = [[] for _ in range(self.num_particles_on_this_rank)]
         start = time()
 
-        def trange(x):
+        def trange(x, disable=False):
             if self.comm.rank:
                 return range(x)
             else:
-                return tqdm.trange(x, desc="PSO Progress")
-        for _ in trange(nsteps):
+                return tqdm.trange(x, desc="PSO Progress", disable=disable)
+        for _ in trange(nsteps, disable=not progress):
             istep_loss = [None for _ in range(self.num_particles_on_this_rank)]
             for ip in range(self.num_particles_on_this_rank):
                 update_key = jran.split(particle_keys[ip], 1)[0]
diff --git a/docs/requirements.txt b/docs/requirements.txt
index aa7d3c2..33ee71c 100644
--- a/docs/requirements.txt
+++ b/docs/requirements.txt
@@ -1,5 +1,6 @@
 sphinx_rtd_theme
 nbsphinx
+myst_parser
 IPython
 matplotlib
 seaborn
diff --git a/docs/source/conf.py b/docs/source/conf.py
index 39e2718..174274d 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -25,6 +25,7 @@
     "sphinx.ext.napoleon",
     "sphinx.ext.viewcode",
     "nbsphinx",
+    "myst_parser",
 ]
 
 templates_path = ['_templates']
diff --git a/docs/source/include_contributing.rst b/docs/source/include_contributing.rst
new file mode 100644
index 0000000..7a661a3
--- /dev/null
+++ b/docs/source/include_contributing.rst
@@ -0,0 +1,2 @@
+.. include:: ../../CONTRIBUTING.md
+   :parser: myst_parser.sphinx_
diff --git a/docs/source/index.rst b/docs/source/index.rst
index ee233cf..39344d5 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -20,6 +20,7 @@ Overview
 --------
 
 * :doc:`installation`
+* :doc:`include_contributing`
 * :doc:`reference`
 
 :mod:`multigrad`
@@ -55,6 +56,7 @@ Indices and tables
    :hidden:
 
    installation.rst
+   include_contributing.rst
    reference.rst
 
 
diff --git a/docs/source/multigrad/smf_gradient_descent.py b/docs/source/multigrad/smf_gradient_descent.py
index c31f739..f90e9f8 100644
--- a/docs/source/multigrad/smf_gradient_descent.py
+++ b/docs/source/multigrad/smf_gradient_descent.py
@@ -1,4 +1,9 @@
-from mpi4py import MPI
+try:
+    from mpi4py import MPI
+    COMM_WORLD = MPI.COMM_WORLD
+except ImportError:
+    MPI = None
+    COMM_WORLD = None
 import jax.scipy
 from jax import numpy as jnp
 import numpy as np
@@ -7,13 +12,19 @@
 from diffopt import multigrad
 
 
-def load_halo_masses(num_halos=10_000, comm=MPI.COMM_WORLD):
+def load_halo_masses(num_halos=10_000, comm=COMM_WORLD):
+    if comm is None:
+        size = 1
+        rank = 0
+    else:
+        size = comm.size
+        rank = comm.rank
     # Generate fake halo masses between 10^10 < M_h < 10^11 as a power law
     quantile = jnp.linspace(0, 0.9, num_halos)
     mhalo = 1e10 / (1 - quantile)
 
     # Assign halos evenly across given MPI ranks (only one rank for now)
-    return np.array_split(mhalo, comm.size)[comm.rank]
+    return np.array_split(mhalo, size)[rank]
 
 
 # Compute one bin of the stellar mass function (SMF)
@@ -92,6 +103,7 @@ def calc_loss_from_sumstats(self, sumstats):
 
 
 if __name__ == "__main__":
+    assert MPI is not None, "MPI must be installed to run this script"
     volume = 1.0
     smf_bin_edges = jnp.linspace(9, 10, 11)
     true_params = jnp.array([-2.0, -0.5])