Add BAOA to SGMCMC #119

Merged (6 commits), Feb 18, 2025
2 changes: 2 additions & 0 deletions docs/api/index.md
@@ -34,6 +34,8 @@ Monte Carlo (SGHMC) algorithm from [Chen et al, 2014](https://arxiv.org/abs/1402
- [`sgmcmc.sgnht`](sgmcmc/sgnht.md) implements the stochastic gradient Nosé-Hoover
thermostat (SGNHT) algorithm from [Ding et al, 2014](https://proceedings.neurips.cc/paper/2014/file/21fe5b8ba755eeaece7a450849876228-Paper.pdf),
(SGHMC with adaptive friction coefficient).
- [`sgmcmc.baoa`](sgmcmc/baoa.md) implements the BAOA integrator for SGHMC
from [Leimkuhler and Matthews, 2015 - p271](https://link.springer.com/book/10.1007/978-3-319-16375-8).

For an overview and unifying framework for SGMCMC methods, see [Ma et al, 2015](https://arxiv.org/abs/1506.04696).

3 changes: 3 additions & 0 deletions docs/api/sgmcmc/baoa.md
@@ -0,0 +1,3 @@
# BAOA

::: posteriors.sgmcmc.baoa
1 change: 1 addition & 0 deletions mkdocs.yml
@@ -80,6 +80,7 @@ nav:
- api/sgmcmc/sgld.md
- api/sgmcmc/sghmc.md
- api/sgmcmc/sgnht.md
- api/sgmcmc/baoa.md
- VI:
- Dense: api/vi/dense.md
- Diag: api/vi/diag.md
1 change: 1 addition & 0 deletions posteriors/sgmcmc/__init__.py
@@ -1,3 +1,4 @@
from posteriors.sgmcmc import sgld
from posteriors.sgmcmc import sghmc
from posteriors.sgmcmc import sgnht
from posteriors.sgmcmc import baoa
172 changes: 172 additions & 0 deletions posteriors/sgmcmc/baoa.py
@@ -0,0 +1,172 @@
from typing import Any
from functools import partial
import torch
from torch.func import grad_and_value
from optree import tree_map
from tensordict import TensorClass, NonTensorData

from posteriors.types import TensorTree, Transform, LogProbFn
from posteriors.tree_utils import flexi_tree_map, tree_insert_
from posteriors.utils import is_scalar, CatchAuxError


def build(
log_posterior: LogProbFn,
lr: float,
alpha: float = 0.01,
sigma: float = 1.0,
temperature: float = 1.0,
momenta: TensorTree | float | None = None,
) -> Transform:
"""Builds BAOA transform.

Algorithm from [Leimkuhler and Matthews, 2015 - p271](https://link.springer.com/book/10.1007/978-3-319-16375-8).

BAOA is conjugate to BAOAB (in Leimkuhler and Matthews' terminology) but requires
only a single gradient evaluation per iteration.
The two are equivalent when analyzing functions of the parameter trajectory.
Unlike BAOAB, BAOA is not reversible, but since we don't apply Metropolis-Hastings
or momenta reversal, the algorithm remains functionally identical to BAOAB.

\\begin{align}
m_{t+1/2} &= m_t + ε \\nabla \\log p(θ_t, \\text{batch}), \\\\
θ_{t+1/2} &= θ_t + (ε / 2) σ^{-2} m_{t+1/2}, \\\\
m_{t+1} &= e^{-ε γ} m_{t+1/2} + N(0, ζ^2 σ^2), \\\\
θ_{t+1} &= θ_{t+1/2} + (ε / 2) σ^{-2} m_{t+1} \\
\\end{align}

for learning rate $\\epsilon$, temperature $T$, transformed friction $γ = α σ^{-2}$
and transformed noise variance $ζ^2 = T(1 - e^{-2γε})$.

Targets $p_T(θ, m) \\propto \\exp( (\\log p(θ) - \\frac{1}{2σ^2} m^Tm) / T)$
with temperature $T$.

The log posterior and temperature are recommended to be [constructed in tandem](../../log_posteriors.md)
to ensure robust scaling for a large amount of data and variable batch size.

Args:
log_posterior: Function that takes parameters and input batch and
returns the log posterior value (which can be unnormalised)
as well as auxiliary information, e.g. from the model call.
lr: Learning rate.
alpha: Friction coefficient.
sigma: Standard deviation of momenta target distribution.
temperature: Temperature of the joint parameter + momenta distribution.
momenta: Initial momenta. Can be tree like params or scalar.
Defaults to random iid samples from N(0, 1).

Returns:
BAOA transform instance.
"""
init_fn = partial(init, momenta=momenta)
update_fn = partial(
update,
log_posterior=log_posterior,
lr=lr,
alpha=alpha,
sigma=sigma,
temperature=temperature,
)
return Transform(init_fn, update_fn)


class BAOAState(TensorClass["frozen"]):
"""State encoding params and momenta for BAOA.

Attributes:
params: Parameters.
momenta: Momenta for each parameter.
log_posterior: Log posterior evaluation.
aux: Auxiliary information from the log_posterior call.
"""

params: TensorTree
momenta: TensorTree
log_posterior: torch.Tensor = torch.tensor([])
aux: NonTensorData = None


def init(params: TensorTree, momenta: TensorTree | float | None = None) -> BAOAState:
"""Initialise momenta for BAOA.

Args:
params: Parameters for which to initialise.
momenta: Initial momenta. Can be tree like params or scalar.
Defaults to random iid samples from N(0, 1).

Returns:
Initial BAOAState containing momenta.
"""
if momenta is None:
momenta = tree_map(
lambda x: torch.randn_like(x, requires_grad=x.requires_grad),
params,
)
elif is_scalar(momenta):
momenta = tree_map(
lambda x: torch.full_like(x, momenta, requires_grad=x.requires_grad),
params,
)

return BAOAState(params, momenta)


def update(
state: BAOAState,
batch: Any,
log_posterior: LogProbFn,
lr: float,
alpha: float = 0.01,
sigma: float = 1.0,
temperature: float = 1.0,
inplace: bool = False,
) -> BAOAState:
"""Updates parameters and momenta for BAOA.

Algorithm from [Leimkuhler and Matthews, 2015 - p271](https://link.springer.com/book/10.1007/978-3-319-16375-8).

See [build](baoa.md#posteriors.sgmcmc.baoa.build) for more details.

Args:
state: BAOAState containing params and momenta.
batch: Data batch to be sent to log_posterior.
log_posterior: Function that takes parameters and input batch and
returns the log posterior value (which can be unnormalised)
as well as auxiliary information, e.g. from the model call.
lr: Learning rate.
alpha: Friction coefficient.
sigma: Standard deviation of momenta target distribution.
temperature: Temperature of the joint parameter + momenta distribution.
inplace: Whether to modify state in place.

Returns:
Updated state
(which contains pointers to the input state tensors if inplace=True).
"""
with torch.no_grad(), CatchAuxError():
grads, (log_post, aux) = grad_and_value(log_posterior, has_aux=True)(
state.params, batch
)

prec = sigma**-2
gamma = torch.tensor(alpha * prec)
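# zeta2 holds the noise standard deviation ζ = sqrt(T (1 - e^{-2 γ ε})) from the docstring.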
zeta2 = (temperature * (1 - torch.exp(-2 * gamma * lr))) ** 0.5
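# One BAOA iteration below: full B (momentum) kick, half A (position) drift,
# O (Ornstein-Uhlenbeck) momentum refresh with noise std ζ σ, then another half A drift.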

def BB_step(m, g):
return m + lr * g

def A_step(p, m):
return p + (lr / 2) * prec * m

def O_step(m):
return torch.exp(-gamma * lr) * m + zeta2 * sigma * torch.randn_like(m)

momenta = flexi_tree_map(BB_step, state.momenta, grads, inplace=inplace)
params = flexi_tree_map(A_step, state.params, momenta, inplace=inplace)
momenta = flexi_tree_map(O_step, momenta, inplace=inplace)
params = flexi_tree_map(A_step, params, momenta, inplace=inplace)

if inplace:
tree_insert_(state.log_posterior, log_post.detach())
return state.replace(aux=NonTensorData(aux))
return BAOAState(params, momenta, log_post.detach(), aux)
84 changes: 84 additions & 0 deletions tests/sgmcmc/test_baoa.py
@@ -0,0 +1,84 @@
from functools import partial
import torch
from optree import tree_all, tree_map
from optree.integration.torch import tree_ravel

from posteriors.sgmcmc import baoa

from tests.scenarios import batch_normal_log_prob


def test_baoa():
torch.manual_seed(42)
target_mean = {"a": torch.randn(2, 1) + 10, "b": torch.randn(1, 1) + 10}
target_sds = tree_map(lambda x: torch.randn_like(x).abs(), target_mean)

target_mean_flat = tree_ravel(target_mean)[0]
target_cov = torch.diag(tree_ravel(target_sds)[0] ** 2)

batch = torch.arange(10).reshape(-1, 1)

batch_normal_log_prob_spec = partial(
batch_normal_log_prob, mean=target_mean, sd_diag=target_sds
)

n_steps = 10000
lr = 1e-2
alpha = 1.0

params = tree_map(lambda x: torch.zeros_like(x), target_mean)
init_params_copy = tree_map(lambda x: x.clone(), params)

sampler = baoa.build(batch_normal_log_prob_spec, lr=lr, alpha=alpha)

# Test inplace = False
baoa_state = sampler.init(params)
log_posts = []
all_params = tree_map(lambda x: x.unsqueeze(0), params)

for _ in range(n_steps):
baoa_state = sampler.update(baoa_state, batch, inplace=False)

all_params = tree_map(
lambda x, y: torch.cat((x, y.unsqueeze(0))), all_params, baoa_state.params
)

log_posts.append(baoa_state.log_posterior.item())

burnin = 1000
all_params_flat = torch.vmap(lambda x: tree_ravel(x)[0])(all_params)
sampled_mean = all_params_flat[burnin:].mean(0)
sampled_cov = torch.cov(all_params_flat[burnin:].T)

assert log_posts[-1] > log_posts[0]
assert torch.allclose(sampled_mean, target_mean_flat, atol=1e-0, rtol=1e-1)
assert torch.allclose(sampled_cov, target_cov, atol=1e-0, rtol=1e-1)
assert tree_all(
tree_map(lambda x, y: torch.all(x == y), params, init_params_copy)
)  # Check that the parameters are not updated

# Test inplace = True
baoa_state = sampler.init(params, momenta=0.0)
log_posts = []
all_params = tree_map(lambda x: x.unsqueeze(0), params)

for _ in range(n_steps):
baoa_state = sampler.update(baoa_state, batch, inplace=True)

all_params = tree_map(
lambda x, y: torch.cat((x, y.unsqueeze(0))), all_params, baoa_state.params
)

log_posts.append(baoa_state.log_posterior.item())

burnin = 1000
all_params_flat = torch.vmap(lambda x: tree_ravel(x)[0])(all_params)
sampled_mean = all_params_flat[burnin:].mean(0)
sampled_cov = torch.cov(all_params_flat[burnin:].T)

assert log_posts[-1] > log_posts[0]
assert torch.allclose(sampled_mean, target_mean_flat, atol=1e-0, rtol=1e-1)
assert torch.allclose(sampled_cov, target_cov, atol=1e-0, rtol=1e-1)
assert tree_all(
tree_map(lambda x, y: torch.all(x != y), params, init_params_copy)
)  # Check that the parameters are updated
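
For context, here is a minimal usage sketch of the new `posteriors.sgmcmc.baoa` transform, following the `build`/`init`/`update` pattern exercised in the test above; the toy `log_posterior`, the `"theta"` parameter name and the step counts are illustrative and not part of the PR.

```python
import torch

from posteriors.sgmcmc import baoa


def log_posterior(params, batch):
    # Toy target: independent standard normal on each parameter; batch is unused.
    log_post = -0.5 * (params["theta"] ** 2).sum()
    return log_post, None  # second element is the auxiliary output expected by posteriors


params = {"theta": torch.zeros(3)}
batch = None

transform = baoa.build(log_posterior, lr=1e-2, alpha=1.0)
state = transform.init(params)

samples = []
for _ in range(5_000):
    state = transform.update(state, batch, inplace=False)
    samples.append(state.params["theta"].clone())

# After burn-in the empirical standard deviation should be roughly 1 for this target.
print(torch.stack(samples[1_000:]).std(dim=0))
```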