-
Notifications
You must be signed in to change notification settings - Fork 72
Add deterministic advi #564
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 16 commits
488bd9c
f46f1cd
894f62b
a1afaf6
637fc3b
3e397f7
d954ec7
aad9f21
ef3d86b
6bf92ef
32aff46
138f8c2
7073a7d
609aef7
b611d51
f17a090
9ab2e1e
a8a53f3
bdee446
3fcafb6
ad46b07
6cd0184
d648105
9d18f80
9f86d4f
3b090ca
93cd831
7b84872
7cd407e
cb070aa
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
| Original file line number | Diff line number | Diff line change | ||||||
|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,228 @@ | ||||||||
| from collections import defaultdict | ||||||||
| from typing import Tuple, Optional | ||||||||
|
|
||||||||
| import pymc | ||||||||
| from pymc import Model | ||||||||
| import arviz as az | ||||||||
| import numpy as np | ||||||||
| from scipy.optimize import minimize | ||||||||
| import pytensor | ||||||||
| import pytensor.tensor as pt | ||||||||
| from pytensor.tensor.variable import TensorVariable | ||||||||
| import xarray | ||||||||
|
|
||||||||
| from pymc import join_nonshared_inputs, DictToArrayBijection | ||||||||
| from pymc.util import get_default_varnames | ||||||||
| from pymc.backends.arviz import ( | ||||||||
| apply_function_over_dataset, | ||||||||
| PointFunc, | ||||||||
| coords_and_dims_for_inferencedata, | ||||||||
| ) | ||||||||
| from pymc_extras.inference.laplace_approx.scipy_interface import ( | ||||||||
| _compile_functions_for_scipy_optimize, | ||||||||
| ) | ||||||||
| from pymc_extras.inference.laplace_approx.laplace import unstack_laplace_draws | ||||||||
|
|
||||||||
|
|
||||||||
| def fit_deterministic_advi( | ||||||||
| model: Optional[Model] = None, | ||||||||
| n_fixed_draws: int = 30, | ||||||||
| random_seed: int = 2, | ||||||||
|
||||||||
| random_seed: int = 2, | |
| random_seed: int | None = None, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It should actually be the RANDOM_SEED type we use in other places in PyMC, from which you can take an integer seed. It also allows users to pass generators
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hmm good point about stochastic default. In my work, I've found it quite nice to have it fixed, since it makes everything deterministic -- running it multiple times always gives the same result. But if it's more in keeping with pymc expectations to have the default stochastic, I'm happy to do it.
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
should use np.random.default_rng? RandomState is legacy numpy
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| means = var_params[:n_params] | |
| log_sds = var_params[n_params:] | |
| means , log_sds= pt.split(var_params, 2) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
split requests the size of each bucket so it would be pt.split(axis=0, split_sizes=[n_params, n_params])?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @ricardoV94 , I've switched to your suggestion. The only thing I had to change was to also pass n_splits, so it's now:
means, log_sds = pt.split(
var_params, axis=0, splits_size=[n_params, n_params], n_splits=2
)
martiningram marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -40,6 +40,7 @@ dependencies = [ | |
| "better-optimize>=0.1.5", | ||
| "pydantic>=2.0.0", | ||
| "preliz>=0.20.0", | ||
| "jax>=0.7.0" | ||
|
||
| ] | ||
|
|
||
| [project.optional-dependencies] | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For brevity? Just a suggestion