
Commit 1c86f15

Merge software changes into paper branch (#6)

* Add thin and progress kwargs to fitters
* Create test_package workflow (#4)
* Create test_package workflow
* Skip unit tests requiring mpi4py if not installed
* Add manual testing instructions to README
* Create CONTRIBUTING.md and move testing instructions here
* Link to contributing guidelines in docs
1 parent a2b6e83 commit 1c86f15

File tree

15 files changed (+234, -61 lines)

.github/workflows/test_package.yml

+46
New file:

```yaml
name: Test diffopt

on:
  workflow_dispatch: null
  schedule:
    # Runs "every Monday at noon UTC"
    - cron: '0 12 * * 1'
  push:
    branches:
      - main
  pull_request: null

jobs:
  build:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: ['3.9', '3.10', '3.11']
    steps:
      - uses: actions/checkout@v2
      - name: Set up Python ${{ matrix.python-version }}
        uses: conda-incubator/setup-miniconda@v2
        with:
          activate-environment: test${{ matrix.python-version}}
          python-version: ${{ matrix.python-version }}
          channels: conda-forge,defaults
          channel-priority: strict
          show-channel-urls: true
          miniforge-version: latest

      - name: Install dependencies
        shell: bash -l {0}
        run: |
          conda install -yq jax
          conda install -yq pip pytest pytest-cov flake8

      - name: Install package
        shell: bash -l {0}
        run: |
          pip install -e .

      - name: Run tests
        shell: bash -l {0}
        run: |
          export PYTHONWARNINGS=error
          pytest -v
```

CONTRIBUTING.md

+36
New file:

````markdown
# Contributing to `diffopt`

Thank you for your interest in contributing to this project. All questions and ideas for improvement are welcome and can be made through opening an issue or pull request.

Before contributing, familiarize yourself with our resources:

- [Source Code](https://github.com/AlanPearl/diffopt)
- [Documentation](https://diffopt.readthedocs.io)

## Issues

You can open an [issue](https://github.com/AlanPearl/diffopt/issues) if you:

- Have encountered a bug or issue when using the software
- Would like to see a new feature
- Are seeking support that could not be resolved by reading the documentation

## Pull Requests

If you would like to directly submit your own change to the software, thank you! Here's how:

- Fork [this repository](https://github.com/AlanPearl/diffopt).
- Please remember to include a concise, self-contained unit test in your pull request. Ensure that all tests pass (see [Manual Testing](#manual-testing)).
- Open a [pull request](https://github.com/AlanPearl/diffopt/pulls).

## Manual Testing

Make sure you have installed diffopt as described in the [docs](https://diffopt.readthedocs.io/en/latest/installation.html). To run all tests from the main directory:

```bash
pip install pytest
pytest .
mpirun -n 2 pytest .
```

Note that unit tests requiring `mpi4py` installation are not automatically tested by GitHub workflows. Therefore, running these tests manually with `mpi4py` installed is necessary to assure that all tests pass.
````
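As a rough sketch of the skip pattern that keeps such tests from failing in the non-MPI CI workflow above, a test can guard itself with pytest's `importorskip`; the test below is hypothetical and not taken from the diffopt test suite.

```python
# Hypothetical example of an mpi4py-guarded unit test; not part of diffopt itself.
import pytest


def test_allreduce_rank_sum():
    # Skip cleanly when mpi4py is not installed (e.g., in the CI workflow above);
    # run manually via `mpirun -n 2 pytest .` to exercise the MPI code paths.
    MPI = pytest.importorskip("mpi4py.MPI")

    comm = MPI.COMM_WORLD
    # Summing every rank index should give 0 + 1 + ... + (size - 1) on all ranks.
    total = comm.allreduce(comm.rank, op=MPI.SUM)
    assert total == comm.size * (comm.size - 1) // 2
```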

diffopt/kdescent/descent.py

+20 -8
```diff
@@ -12,7 +12,8 @@
 
 
 def adam(lossfunc, guess, nsteps=100, param_bounds=None,
-         learning_rate=0.01, randkey=1, const_randkey=False, **other_kwargs):
+         learning_rate=0.01, randkey=1, const_randkey=False,
+         thin=1, progress=True, **other_kwargs):
     """
     Perform gradient descent
 
@@ -36,6 +37,11 @@ def adam(lossfunc, guess, nsteps=100, param_bounds=None,
     const_randkey : bool, optional
         By default (False), randkey is regenerated at each gradient descent
         iteration. Remove this behavior by setting const_randkey=True
+    thin : int, optional
+        Return parameters for every `thin` iterations, by default 1. Set
+        `thin=0` to only return final parameters
+    progress : bool, optional
+        Display tqdm progress bar, by default True
 
     Returns
     -------
@@ -46,7 +52,7 @@
     if param_bounds is None:
         return adam_unbounded(
             lossfunc, guess, nsteps, learning_rate, randkey,
-            const_randkey, **other_kwargs)
+            const_randkey, thin, progress, **other_kwargs)
 
     assert len(guess) == len(param_bounds)
     if hasattr(param_bounds, "tolist"):
@@ -60,14 +66,15 @@ def ulossfunc(uparams, *args, **kwargs):
     init_uparams = apply_transforms(guess, param_bounds)
     uparams = adam_unbounded(
         ulossfunc, init_uparams, nsteps, learning_rate, randkey,
-        const_randkey, **other_kwargs)
+        const_randkey, thin, progress, **other_kwargs)
     params = apply_inverse_transforms(uparams.T, param_bounds).T
 
     return params
 
 
 def adam_unbounded(lossfunc, guess, nsteps=100, learning_rate=0.01,
-                   randkey=1, const_randkey=False, **other_kwargs):
+                   randkey=1, const_randkey=False,
+                   thin=1, progress=True, **other_kwargs):
     kwargs = {**other_kwargs}
     if randkey is not None:
         randkey = keygen.init_randkey(randkey)
@@ -78,13 +85,18 @@ def adam_unbounded(lossfunc, guess, nsteps=100, learning_rate=0.01,
     opt = optax.adam(learning_rate)
     solver = jaxopt.OptaxSolver(opt=opt, fun=lossfunc, maxiter=nsteps)
     state = solver.init_state(guess, **kwargs)
-    params = [guess]
-    for _ in tqdm.trange(nsteps, desc="Adam Gradient Descent Progress"):
+    params = []
+    params_i = guess
+    for i in tqdm.trange(nsteps, disable=not progress,
+                         desc="Adam Gradient Descent Progress"):
         if randkey is not None:
             randkey, key_i = jax.random.split(randkey)
             kwargs["randkey"] = key_i
-        params_i, state = solver.update(params[-1], state, **kwargs)
-        params.append(params_i)
+        params_i, state = solver.update(params_i, state, **kwargs)
+        if i == nsteps - 1 or (thin and i % thin == thin - 1):
+            params.append(params_i)
+    if not thin:
+        params = params[-1]
 
     return jnp.array(params)
 
```
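A minimal usage sketch of the new `thin` and `progress` keywords follows; the quadratic loss function is made up for illustration, and the import path is inferred from the file path above rather than from the diffopt docs.

```python
# Illustrative sketch only; lossfunc and guess are invented for this example.
import jax.numpy as jnp
from diffopt.kdescent.descent import adam  # import path inferred from the file path above


def lossfunc(params, randkey=None):
    # Accepts `randkey` because adam() passes a fresh PRNG key at each
    # iteration unless randkey=None is given.
    return jnp.sum((params - 3.0) ** 2)


guess = jnp.zeros(2)

# thin=10 keeps every 10th parameter vector (the final one is always kept);
# progress=False hides the tqdm progress bar.
history = adam(lossfunc, guess, nsteps=100, thin=10, progress=False)

# thin=0 returns only the final parameters.
best = adam(lossfunc, guess, nsteps=100, thin=0, progress=False)
```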

diffopt/multigrad/adam.py

+29 -12
```diff
@@ -25,12 +25,12 @@
 N_RANKS = 1
 
 
-def trange_no_tqdm(n, desc=None):
+def trange_no_tqdm(n, desc=None, disable=False):
     return range(n)
 
 
-def trange_with_tqdm(n, desc="Adam Gradient Descent Progress"):
-    return tqdm.trange(n, desc=desc)
+def trange_with_tqdm(n, desc="Adam Gradient Descent Progress", disable=False):
+    return tqdm.trange(n, desc=desc, disable=disable)
 
 
 adam_trange = trange_no_tqdm if tqdm is None else trange_with_tqdm
@@ -49,27 +49,32 @@ def _master_wrapper(params, logloss_and_grad_fn, data, randkey=None):
     return loss, grad
 
 
-def _adam_optimizer(params, fn, fn_data, nsteps, learning_rate, randkey=None):
+def _adam_optimizer(params, fn, fn_data, nsteps, learning_rate, randkey=None,
+                    thin=1, progress=True):
     kwargs = {}
     # Note: Might be recommended to use optax instead of jax.example_libraries
     opt_init, opt_update, get_params = jax_opt.adam(learning_rate)
     opt_state = opt_init(params)
 
-    param_steps = [params]
-    for step in adam_trange(nsteps):
+    param_steps = []
+    for step in adam_trange(nsteps, disable=not progress):
         if randkey is not None:
             randkey, key_i = jax.random.split(randkey)
             kwargs["randkey"] = key_i
         _, grad = fn(params, *fn_data, **kwargs)
         opt_state = opt_update(step, grad, opt_state)
         params = get_params(opt_state)
-        param_steps.append(params)
+        if step == nsteps - 1 or (thin and step % thin == thin - 1):
+            param_steps.append(params)
+    if not thin:
+        param_steps = param_steps[-1]
 
     return jnp.array(param_steps)
 
 
 def run_adam_unbounded(logloss_and_grad_fn, params, data, nsteps=100,
-                       learning_rate=0.01, randkey=None):
+                       learning_rate=0.01, randkey=None,
+                       thin=1, progress=True):
     """Run the adam optimizer on a loss function with a custom gradient.
 
     Parameters
@@ -88,6 +93,11 @@ def run_adam_unbounded(logloss_and_grad_fn, params, data, nsteps=100,
     randkey : int | PRNG Key
         If given, a new PRNG Key will be generated at each iteration and be
         passed to `logloss_and_grad_fn` under the "randkey" kwarg
+    thin : int, optional
+        Return parameters for every `thin` iterations, by default 1. Set
+        `thin=0` to only return final parameters
+    progress : bool, optional
+        Display tqdm progress bar, by default True
 
     Returns
     -------
@@ -104,7 +114,7 @@
     fn_data = (logloss_and_grad_fn, data)
 
     params = _adam_optimizer(params, fn, fn_data, nsteps, learning_rate,
-                             randkey=randkey)
+                             randkey=randkey, thin=thin, progress=progress)
 
     if COMM is not None:
         COMM.bcast("exit", root=0)
@@ -131,7 +141,7 @@
 
 
 def run_adam(logloss_and_grad_fn, params, data, nsteps=100, param_bounds=None,
-             learning_rate=0.01, randkey=None):
+             learning_rate=0.01, randkey=None, thin=1, progress=True):
     """Run the adam optimizer on a loss function with a custom gradient.
 
     Parameters
@@ -153,6 +163,11 @@ def run_adam(logloss_and_grad_fn, params, data, nsteps=100, param_bounds=None,
     randkey : int | PRNG Key
         If given, a new PRNG Key will be generated at each iteration and be
        passed to `logloss_and_grad_fn` under the "randkey" kwarg
+    thin : int, optional
+        Return parameters for every `thin` iterations, by default 1. Set
+        `thin=0` to only return final parameters
+    progress : bool, optional
+        Display tqdm progress bar, by default True
 
     Returns
     -------
@@ -162,7 +177,8 @@
     if param_bounds is None:
         return run_adam_unbounded(
             logloss_and_grad_fn, params, data, nsteps=nsteps,
-            learning_rate=learning_rate, randkey=randkey)
+            learning_rate=learning_rate, randkey=randkey,
+            thin=thin, progress=progress)
 
     assert len(params) == len(param_bounds)
     if hasattr(param_bounds, "tolist"):
@@ -182,7 +198,8 @@ def unbound_loss_and_grad(uparams, *args, **kwargs):
 
     uparams = apply_trans(params)
     final_uparams = run_adam_unbounded(
-        unbound_loss_and_grad, uparams, data, nsteps, learning_rate, randkey)
+        unbound_loss_and_grad, uparams, data, nsteps, learning_rate, randkey,
+        thin, progress)
 
     if RANK == 0:
         final_params = invert_trans(final_uparams.T).T
```
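A minimal single-process sketch of calling `run_adam` with the new keywords; the least-squares loss-and-gradient function below is invented for illustration, and no MPI communicator is assumed.

```python
# Illustrative sketch only; the loss function is made up for this example.
import jax
import jax.numpy as jnp
from diffopt.multigrad.adam import run_adam


def logloss_and_grad_fn(params, data, randkey=None):
    # run_adam expects a callable returning (loss, gradient);
    # here JAX computes both for a simple least-squares loss.
    def loss(p):
        return jnp.sum((p - data) ** 2)
    return jax.value_and_grad(loss)(params)


data = jnp.array([1.0, 2.0])
params = jnp.zeros(2)

# Keep every 5th parameter vector and silence the progress bar.
history = run_adam(logloss_and_grad_fn, params, data,
                   nsteps=50, thin=5, progress=False)
```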

diffopt/multigrad/bfgs.py

+7 -5
```diff
@@ -18,19 +18,19 @@
 N_RANKS = 1
 
 
-def trange_no_tqdm(n, desc=None):
+def trange_no_tqdm(n, desc=None, disable=False):
     return range(n)
 
 
-def trange_with_tqdm(n, desc="BFGS Gradient Descent Progress"):
-    return tqdm.trange(n, desc=desc, leave=True)
+def trange_with_tqdm(n, desc="BFGS Gradient Descent Progress", disable=False):
+    return tqdm.trange(n, desc=desc, leave=True, disable=disable)
 
 
 bfgs_trange = trange_no_tqdm if tqdm is None else trange_with_tqdm
 
 
 def run_bfgs(loss_and_grad_fn, params, maxsteps=100, param_bounds=None,
-             randkey=None, comm=COMM):
+             randkey=None, progress=True, comm=COMM):
     """Run the adam optimizer on a loss function with a custom gradient.
 
     Parameters
@@ -46,6 +46,8 @@ def run_bfgs(loss_and_grad_fn, params, maxsteps=100, param_bounds=None,
         `None` as the bound for each unbounded parameter, by default None
     randkey : int | PRNG Key (default=None)
         This will be passed to `logloss_and_grad_fn` under the "randkey" kwarg
+    progress : bool, optional
+        Display tqdm progress bar, by default True
     comm : MPI Communicator (default=COMM_WORLD)
         Communicator between all desired MPI ranks
 
@@ -66,7 +68,7 @@
         kwargs["randkey"] = randkey
 
     if comm is None or comm.rank == 0:
-        pbar = bfgs_trange(maxsteps)
+        pbar = bfgs_trange(maxsteps, disable=not progress)
 
     # Wrap loss_and_grad function with commands to the worker ranks
     def loss_and_grad_fn_root(params):
```
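A similar single-process sketch for `run_bfgs` with the new `progress` keyword; the quadratic loss below is invented for illustration.

```python
# Illustrative sketch only; the loss function is made up for this example.
import jax
import jax.numpy as jnp
from diffopt.multigrad.bfgs import run_bfgs


def loss_and_grad_fn(params):
    # run_bfgs expects a callable returning (loss, gradient).
    def loss(p):
        return jnp.sum((p - 1.0) ** 2)
    return jax.value_and_grad(loss)(params)


params = jnp.zeros(3)

# progress=False suppresses the BFGS tqdm progress bar on the root rank.
result = run_bfgs(loss_and_grad_fn, params, maxsteps=50, progress=False)
```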
