This library contains JAX and PyTorch implementations of neural ODEs and Bayesian layers for stochastic variational inference. A rudimentary JAX implementation of differentiable SDE solvers is also provided; refer to torchsde [2] for a full set of differentiable SDE solvers in PyTorch, and to torchdiffeq [3] for differentiable ODE solvers.
Continuous-depth hidden unit trajectories in Neural ODE vs uncertain posterior dynamics SDE-BNN.
This library runs on jax==0.1.77 and torch==1.6.0. To install all other requirements:
pip install -r requirements.txt
Note: Package versions may change; refer to the official JAX installation instructions.
The jaxsde library contains SDE solvers in the Ito and Stratonovich form.
Solvers of different orders can be specified with method={euler_maruyama|milstein|euler_heun} (strong orders 0.5|1|0.5 in general, and 1|1|1 for an SDE with additive noise).
Stochastic adjoint (sdeint_ito) training mode does not work efficiently yet; use sdeint_ito_fixed_grid for now.
Trade off solver speed for precision during training or inference by adjusting --nsteps <# steps>.
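To make the solver options above concrete, here is a minimal, framework-agnostic sketch of fixed-grid Euler-Maruyama and Milstein steps for a scalar Ito SDE dy = f(y, t) dt + g(y, t) dW. It is not the jaxsde implementation: the function name sdeint_fixed_grid_sketch, the f(y, t) / g(y, t) call signatures, and the dg argument (the derivative of g with respect to y) are assumptions made for illustration only.

```python
import math
import random


def sdeint_fixed_grid_sketch(f, g, y0, ts, rng, method="euler_maruyama", dg=None):
    """Integrate the scalar Ito SDE dy = f(y, t) dt + g(y, t) dW over the grid ts.

    Hypothetical helper for illustration; not part of jaxsde.
    """
    if method not in ("euler_maruyama", "milstein"):
        raise ValueError(f"unknown method: {method}")
    if method == "milstein" and dg is None:
        raise ValueError("milstein needs dg, the derivative of g with respect to y")

    y = y0
    for t0, t1 in zip(ts[:-1], ts[1:]):
        dt = t1 - t0
        # Brownian increment over [t0, t1], distributed as N(0, dt).
        dW = rng.gauss(0.0, math.sqrt(dt))
        # Euler-Maruyama step: drift times dt plus diffusion times dW.
        step = f(y, t0) * dt + g(y, t0) * dW
        if method == "milstein":
            # Milstein correction; it is zero whenever g does not depend on y.
            step += 0.5 * g(y, t0) * dg(y, t0) * (dW * dW - dt)
        y = y + step
    return y


# Usage: an SDE with multiplicative noise, dy = -0.5 y dt + 0.3 y dW.
ts = [0.05 * i for i in range(21)]  # 20 steps on [0, 1]
y1 = sdeint_fixed_grid_sketch(
    lambda y, t: -0.5 * y,
    lambda y, t: 0.3 * y,
    1.0,
    ts,
    random.Random(0),
    method="milstein",
    dg=lambda y, t: 0.3,
)
```

Because the Milstein correction is multiplied by the derivative of g, the two methods coincide for additive noise, which is consistent with Euler-Maruyama reaching strong order 1 in that case. Using more grid points in ts plays the same role as a larger --nsteps: more solver work for a smaller discretization error.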
Default solver: Backpropagation through the solver.
from jaxsde.jaxsde.sdeint import sdeint_ito_fixed_grid
y1 = sdeint_ito_fixed_grid(f, g, y0, ts, rng, fw_params, method="euler_maruyama")
Stochastic adjoint: Uses O(1) memory by solving an adjoint SDE in the backward pass instead of backpropagating through the solver.
from jaxsde.jaxsde.sdeint import sdeint_ito
y1 = sdeint_ito(f, g, y0, ts, rng, fw_params, method="milstein")
Implementation of composable Bayesian layers in the stax API.
Our SDE Bayesian layers can be used with the SDEBNN block composed with multiple parameterizations of time-dependent layers in diffeq_layers.
The sticking-the-landing (STL) trick can be enabled during training with --stl to improve the convergence rate.
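The general idea behind sticking-the-landing can be illustrated on a one-dimensional Gaussian, where the gradients can be written by hand. The sketch below is not how --stl is implemented in this repository; the function name elbo_grad_mu and the Gaussian setup (approximate posterior q = N(mu, sigma^2), target p = N(m, s^2)) are assumptions chosen to keep the example self-contained. STL keeps only the gradient that flows through the sample z and drops the score term, which has zero mean but adds variance.

```python
def elbo_grad_mu(eps, mu, sigma, m, s, stl):
    """Single-sample reparameterized gradient of log p(z) - log q(z) w.r.t. mu.

    Hypothetical helper for illustration, with q = N(mu, sigma^2),
    p = N(m, s^2), and the sample z = mu + sigma * eps.
    """
    z = mu + sigma * eps
    dlogp_dz = -(z - m) / s ** 2
    dlogq_dz = -(z - mu) / sigma ** 2
    # Path term: gradient flowing through the sample z (dz/dmu = 1).
    path_term = dlogp_dz - dlogq_dz
    if stl:
        # STL: the parameters inside log q are treated as constants,
        # so only the path term remains.
        return path_term
    # Score term: d log q / d mu with z held fixed. Its expectation is zero.
    score_term = (z - mu) / sigma ** 2
    return path_term - score_term


# When q already matches p, the STL gradient is zero for every sample,
# while the standard estimator still fluctuates with eps.
samples = [-1.3, 0.2, 2.0]
stl_grads = [elbo_grad_mu(e, 0.5, 2.0, 0.5, 2.0, stl=True) for e in samples]
std_grads = [elbo_grad_mu(e, 0.5, 2.0, 0.5, 2.0, stl=False) for e in samples]
```

Both estimators have the same expectation; the difference is that the STL version has vanishing variance as the approximate posterior approaches the target, which is the usual argument for faster convergence.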
Augment the inputs by a custom amount with --aug <integer>, and set the number of samples averaged over with --nsamples <integer>.
If memory constraints pose a problem, train with gradient accumulation (--acc_grad) and gradient checkpointing (--remat).
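As a rough illustration of why gradient accumulation saves memory without changing the result, the sketch below averages gradients computed on micro-batches and compares them with the full-batch gradient. It is a toy scalar least-squares example with hand-written gradients; the helper names grad_mse and accumulated_grad are hypothetical, and how --acc_grad is implemented in the training scripts is not shown here.

```python
def grad_mse(w, xs, ys):
    """Gradient of mean((w * x - y)^2) with respect to the scalar weight w."""
    n = len(xs)
    return sum(2.0 * (w * x - y) * x for x, y in zip(xs, ys)) / n


def accumulated_grad(w, xs, ys, micro_batch_size):
    """Average the gradients of equally sized micro-batches.

    Only one micro-batch is processed at a time, so peak memory scales with
    micro_batch_size rather than with the full batch size.
    Assumes len(xs) is divisible by micro_batch_size.
    """
    total, num_chunks = 0.0, 0
    for i in range(0, len(xs), micro_batch_size):
        total += grad_mse(w, xs[i:i + micro_batch_size], ys[i:i + micro_batch_size])
        num_chunks += 1
    return total / num_chunks


xs = [1.0, 2.0, 3.0, 4.0]
ys = [2.0, 4.0, 6.0, 8.0]
full = grad_mse(0.5, xs, ys)
accum = accumulated_grad(0.5, xs, ys, micro_batch_size=2)
```

For a mean loss and equally sized micro-batches, full and accum agree up to floating-point rounding. Gradient checkpointing is a separate technique that recomputes intermediate activations during the backward pass instead of storing them, trading extra compute for lower memory.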
Samples from SDEBNN-learned predictive prior and posterior density distributions.
All examples can be swapped in with different vision datasets. For better readability, tensorboard logging has been excluded (see torchbnn instead).
python examples/jax/sdebnn_toy1d.py --ds cos --activn swish --loss laplace --kl_scale 1. --diff_const 0.2 --driftw_scale 0.1 --aug_dim 2 --stl --prior_dw ou
To train an SDEBNN model:
python examples/jax/sdebnn_classification.py --output <output directory> --model sdenet --aug 2 --nblocks 2-2-2 --diff_coef 0.2 --fx_dim 64 --fw_dims 2-64-2 --nsteps 20 --nsamples 1
To train a ResNet baseline, specify --model resnet, and for a Bayesian ResNet baseline, specify --meanfield_sdebnn.
A PyTorch implementation of the Brax framework powered by the torchsde backend.
All examples can be swapped in with different vision datasets and include tensorboard logging for critical metrics.
python examples/torch/sdebnn_toy1d.py --output_dir <dst_path>
Arbitrarily expressive approximate posteriors from learning non-Gaussian marginals.
All hyperparameters can be found in the training script. Train with adjoint mode for memory-efficient backpropagation and adaptive mode for adaptive computation (and set --adjoint_adaptive True if training with both adjoint and adaptive modes).
python examples/torch/sdebnn_classification.py --train-dir <output directory> --data cifar10 --dt 0.05 --method midpoint --adjoint True --adaptive True --adjoint_adaptive True --inhomogeneous True
[1] Winnie Xu, Ricky T. Q. Chen, Xuechen Li, David Duvenaud. "Infinitely Deep Bayesian Neural Networks with Stochastic Differential Equations." Preprint 2021. [arxiv]
[2] Xuechen Li, Ting-Kam Leonard Wong, Ricky T. Q. Chen, David Duvenaud. "Scalable Gradients for Stochastic Differential Equations." AISTATS 2020. [arxiv]
[3] Ricky T. Q. Chen, Yulia Rubanova, Jesse Bettencourt, David Duvenaud. "Neural Ordinary Differential Equations." NeurIPS 2018. [arxiv]
If you found this library useful in your research, please consider citing
@article{xu2021sdebnn,
title={Infinitely Deep Bayesian Neural Networks with Stochastic Differential Equations},
author={Xu, Winnie and Chen, Ricky T. Q. and Li, Xuechen and Duvenaud, David},
journal = {International Conference on Artificial Intelligence and Statistics},
year={2022}
}