Merged

Commits (30)
488bd9c  Add first version of deterministic ADVI (Aug 1, 2025)
f46f1cd  Update API (Aug 12, 2025)
894f62b  Add a notebook example (Aug 14, 2025)
a1afaf6  Merge branch 'main' into add_basic_deterministic_advi (Aug 14, 2025)
637fc3b  Add to API and add a docstring (Aug 14, 2025)
3e397f7  Change import in notebook (Aug 14, 2025)
d954ec7  Add jax to dependencies (Aug 14, 2025)
aad9f21  Add pytensor version (Aug 16, 2025)
ef3d86b  Fix handling of pymc model (Aug 16, 2025)
6bf92ef  Add (probably suboptimal) handling of the two backends (Aug 16, 2025)
32aff46  Add transformation (Aug 18, 2025)
138f8c2  Follow Ricardo's advice to simplify the transformation step (Aug 19, 2025)
7073a7d  Fix naming bug (Aug 19, 2025)
609aef7  Document and clean up (Aug 19, 2025)
b611d51  Merge branch 'main' into add_basic_deterministic_advi (Aug 19, 2025)
f17a090  Fix example (Aug 19, 2025)
9ab2e1e  Update pymc_extras/inference/deterministic_advi/dadvi.py (martiningram, Aug 20, 2025)
a8a53f3  Respond to comments (Aug 20, 2025)
bdee446  Fix with pre commit checks (Aug 20, 2025)
3fcafb6  Update pymc_extras/inference/deterministic_advi/dadvi.py (martiningram, Aug 28, 2025)
ad46b07  Implement suggestions (Aug 28, 2025)
6cd0184  Rename parameter because it's duplicated otherwise (Aug 28, 2025)
d648105  Rename to be consistent in use of dadvi (Aug 28, 2025)
9d18f80  Rename to `optimizer_method` and drop jac=True (Aug 28, 2025)
9f86d4f  Add jac=True back in since trust-ncg complained (Aug 28, 2025)
3b090ca  Make hessp and jac optional (Aug 28, 2025)
93cd831  Harmonize naming with existing code (Aug 28, 2025)
7b84872  Fix example (Aug 29, 2025)
7cd407e  Switch to `better_optimize` (Aug 29, 2025)
cb070aa  Replace with pt.split (Aug 29, 2025)
975 changes: 975 additions & 0 deletions notebooks/deterministic_advi_example.ipynb

Large diffs are not rendered by default.

9 changes: 8 additions & 1 deletion pymc_extras/inference/__init__.py
@@ -16,5 +16,12 @@
from pymc_extras.inference.laplace_approx.find_map import find_MAP
from pymc_extras.inference.laplace_approx.laplace import fit_laplace
from pymc_extras.inference.pathfinder.pathfinder import fit_pathfinder
from pymc_extras.inference.deterministic_advi.dadvi import fit_deterministic_advi

__all__ = ["find_MAP", "fit", "fit_laplace", "fit_pathfinder"]
__all__ = [
"find_MAP",
"fit",
"fit_laplace",
"fit_pathfinder",
"fit_deterministic_advi",
Member commented:

Suggested change:
    "fit_deterministic_advi",
    "fit_dadvi",

For brevity? Just a suggestion
]
pymc_extras/inference/deterministic_advi/__init__.py
Empty file.
228 changes: 228 additions & 0 deletions pymc_extras/inference/deterministic_advi/dadvi.py
@@ -0,0 +1,228 @@
from collections import defaultdict
from typing import Tuple, Optional

import pymc
from pymc import Model
import arviz as az
import numpy as np
from scipy.optimize import minimize
import pytensor
import pytensor.tensor as pt
from pytensor.tensor.variable import TensorVariable
import xarray

from pymc import join_nonshared_inputs, DictToArrayBijection
from pymc.util import get_default_varnames
from pymc.backends.arviz import (
    apply_function_over_dataset,
    PointFunc,
    coords_and_dims_for_inferencedata,
)
from pymc_extras.inference.laplace_approx.scipy_interface import (
    _compile_functions_for_scipy_optimize,
)
from pymc_extras.inference.laplace_approx.laplace import unstack_laplace_draws


def fit_deterministic_advi(
    model: Optional[Model] = None,
    n_fixed_draws: int = 30,
    random_seed: int = 2,
Member commented:

default should be stochastic?

Suggested change:
    random_seed: int = 2,
    random_seed: int | None = None,

Member commented:

It should actually be the RANDOM_SEED type we use in other places in PyMC, from which you can take an integer seed. It also allows users to pass generators.

Contributor Author commented:

Hmm, good point about the stochastic default. In my work, I've found it quite nice to have it fixed, since it makes everything deterministic -- running it multiple times always gives the same result. But if it's more in keeping with pymc expectations to have the default stochastic, I'm happy to do it.

    n_draws: int = 1000,
    keep_untransformed: bool = False,
):
"""
Does inference using deterministic ADVI (automatic differentiation
variational inference).
For full details see the paper cited in the references:
https://www.jmlr.org/papers/v25/23-1015.html
Parameters
----------
model : pm.Model
The PyMC model to be fit. If None, the current model context is used.
n_fixed_draws : int
The number of fixed draws to use for the optimisation. More
draws will result in more accurate estimates, but also
increase inference time. Usually, the default of 30 is a good
tradeoff.between speed and accuracy.
random_seed: int
The random seed to use for the fixed draws. Running the optimisation
twice with the same seed should arrive at the same result.
n_draws: int
The number of draws to return from the variational approximation.
keep_untransformed: bool
Whether or not to keep the unconstrained variables (such as
logs of positive-constrained parameters) in the output.
Returns
-------
:class:`~arviz.InferenceData`
The inference data containing the results of the DADVI algorithm.
References
----------
Giordano, R., Ingram, M., & Broderick, T. (2024). Black Box
Variational Inference with a Deterministic Objective: Faster, More
Accurate, and Even More Black Box. Journal of Machine Learning
Research, 25(18), 1–39.
"""

    model = pymc.modelcontext(model) if model is None else model

    initial_point_dict = model.initial_point()
    n_params = DictToArrayBijection.map(initial_point_dict).data.shape[0]

    var_params, objective = create_dadvi_graph(
        model,
        n_fixed_draws=n_fixed_draws,
        random_seed=random_seed,
        n_params=n_params,
    )

    f_fused, f_hessp = _compile_functions_for_scipy_optimize(
        objective,
        [var_params],
        compute_grad=True,
        compute_hessp=True,
        compute_hess=False,
    )

    result = minimize(
        f_fused, np.zeros(2 * n_params), method="trust-ncg", jac=True, hessp=f_hessp
    )

    opt_var_params = result.x
    opt_means, opt_log_sds = np.split(opt_var_params, 2)

    # Make the draws:
    draws_raw = np.random.randn(n_draws, n_params)
    draws = opt_means + draws_raw * np.exp(opt_log_sds)
    draws_arviz = unstack_laplace_draws(draws, model, chains=1, draws=n_draws)

    transformed_draws = transform_draws(
        draws_arviz, model, keep_untransformed=keep_untransformed
    )

    return transformed_draws


def create_dadvi_graph(
    model: Model,
    n_params: int,
    n_fixed_draws: int = 30,
    random_seed: int = 2,
) -> Tuple[TensorVariable, TensorVariable]:
"""
Sets up the DADVI graph in pytensor and returns it.
Parameters
----------
model : pm.Model
The PyMC model to be fit.
n_params: int
The total number of parameters in the model.
n_fixed_draws : int
The number of fixed draws to use.
random_seed: int
The random seed to use for the fixed draws.
Returns
-------
Tuple[TensorVariable, TensorVariable]
A tuple whose first element contains the variational parameters,
and whose second contains the DADVI objective.
"""

    # Make the fixed draws
    state = np.random.RandomState(random_seed)
Member commented:

should use np.random.default_rng? RandomState is legacy numpy
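
A sketch of what that suggested swap could look like (hypothetical, not part of this diff):

    rng = np.random.default_rng(random_seed)
    draws = rng.standard_normal((n_fixed_draws, n_params))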

    draws = state.randn(n_fixed_draws, n_params)

    inputs = model.continuous_value_vars + model.discrete_value_vars
    initial_point_dict = model.initial_point()
    logp = model.logp()

    # Graph in terms of a flat input
    [logp], flat_input = join_nonshared_inputs(
        point=initial_point_dict, outputs=[logp], inputs=inputs
    )

    var_params = pt.vector(name="eta", shape=(2 * n_params,))

    means = var_params[:n_params]
    log_sds = var_params[n_params:]
Member commented:

Suggested change:
    means = var_params[:n_params]
    log_sds = var_params[n_params:]
    means, log_sds = pt.split(var_params, 2)

Member commented:

split requests the size of each bucket so it would be pt.split(axis=0, split_sizes=[n_params, n_params])?

Contributor Author commented:

Thanks @ricardoV94, I've switched to your suggestion. The only thing I had to change was to also pass n_splits, so it's now:

    means, log_sds = pt.split(
        var_params, axis=0, splits_size=[n_params, n_params], n_splits=2
    )

    draw = pt.vector(name="draw", shape=(n_params,))
    sample = means + pt.exp(log_sds) * draw

    # Graph in terms of a single sample
    logp_draw = pytensor.clone_replace(logp, replace={flat_input: sample})
    draw_matrix = pt.constant(draws)

    # Vectorise
    logp_vectorized_draws = pytensor.graph.vectorize_graph(
        logp_draw, replace={draw: draw_matrix}
    )

    mean_log_density = pt.mean(logp_vectorized_draws)
    entropy = pt.sum(log_sds)

    objective = -mean_log_density - entropy

    return var_params, objective
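
For readers checking this against the paper, the graph above encodes the fixed-draw DADVI objective. With fixed standard-normal draws z_1, ..., z_M, variational means mu, and log standard deviations log sigma (the notation below is mine, not from the diff):

$$
\mathcal{L}(\mu, \log\sigma) \;=\; -\frac{1}{M}\sum_{m=1}^{M} \log p\!\left(\mu + e^{\log\sigma} \odot z_m\right) \;-\; \sum_i \log\sigma_i
$$

i.e. the negative Monte Carlo estimate of the expected log joint density minus the Gaussian entropy term (up to an additive constant), evaluated on the same M fixed draws at every optimization step.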


def transform_draws(
    unstacked_draws: xarray.Dataset,
    model: Model,
    keep_untransformed: bool = False,
):
"""
Transforms the unconstrained draws back into the constrained space.
Parameters
----------
unstacked_draws : xarray.Dataset
The draws to constrain back into the original space.
model : Model
The PyMC model the variables were derived from.
n_draws: int
The number of draws to return from the variational approximation.
keep_untransformed: bool
Whether or not to keep the unconstrained variables in the output.
Returns
-------
:class:`~arviz.InferenceData`
Draws from the original constrained parameters.
"""

    filtered_var_names = model.unobserved_value_vars
    vars_to_sample = list(
        get_default_varnames(filtered_var_names, include_transformed=keep_untransformed)
    )
    fn = pytensor.function(model.value_vars, vars_to_sample)
    point_func = PointFunc(fn)

    coords, dims = coords_and_dims_for_inferencedata(model)

    transformed_result = apply_function_over_dataset(
        point_func,
        unstacked_draws,
        output_var_names=[x.name for x in vars_to_sample],
        coords=coords,
        dims=dims,
    )

    return transformed_result
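
For orientation, here is a minimal usage sketch of the function added above; the toy model and data are illustrative only and are not taken from the PR's notebook:

    import numpy as np
    import pymc as pm

    from pymc_extras.inference import fit_deterministic_advi

    # Toy data: 200 observations from a normal with unknown mean and scale.
    rng = np.random.default_rng(0)
    observed = rng.normal(loc=1.0, scale=2.0, size=200)

    with pm.Model() as model:
        mu = pm.Normal("mu", 0.0, 10.0)
        sigma = pm.HalfNormal("sigma", 5.0)
        pm.Normal("y", mu=mu, sigma=sigma, observed=observed)

        # Defaults: 30 fixed draws and seed 2, per the signature above.
        # Returns draws from the fitted mean-field approximation, using
        # the current model context.
        idata = fit_deterministic_advi(n_draws=1000)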
5 changes: 5 additions & 0 deletions pymc_extras/inference/fit.py
@@ -40,3 +40,8 @@ def fit(method: str, **kwargs) -> az.InferenceData:
        from pymc_extras.inference import fit_laplace

        return fit_laplace(**kwargs)

    if method == "deterministic_advi":
        from pymc_extras.inference import fit_deterministic_advi

        return fit_deterministic_advi(**kwargs)
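
The same code path is reachable through the generic dispatcher in this file; a minimal sketch (the keyword arguments are examples only):

    from pymc_extras.inference import fit

    with model:
        # Forwards **kwargs to fit_deterministic_advi.
        idata = fit(method="deterministic_advi", n_fixed_draws=30, n_draws=1000)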
1 change: 1 addition & 0 deletions pyproject.toml
@@ -40,6 +40,7 @@ dependencies = [
"better-optimize>=0.1.5",
"pydantic>=2.0.0",
"preliz>=0.20.0",
"jax>=0.7.0"
Member commented:

jax is an optional dependency on the project

]

[project.optional-dependencies]