Skip to content

Commit 0a56cac

Browse files
SiegeLordEx authored and tensorflower-gardener committed
FunMC: Add the AIS kernel for use with SMC.
PiperOrigin-RevId: 721109723
1 parent 55d191e commit 0a56cac

File tree

3 files changed

+279
-2
lines changed

3 files changed

+279
-2
lines changed

spinoffs/fun_mc/fun_mc/smc.py

+193
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from typing import Any, Callable, Generic, Protocol, TypeVar, runtime_checkable
1818

1919
from fun_mc import backend
20+
from fun_mc import fun_mc_lib as fun_mc
2021
from fun_mc import types
2122

2223
jax = backend.jax
@@ -34,11 +35,16 @@
3435
BoolScalar = types.BoolScalar
3536
IntScalar = types.IntScalar
3637
FloatScalar = types.FloatScalar
38+
PotentialFn = types.PotentialFn
39+
3740
State = TypeVar('State')
3841
Extra = TypeVar('Extra')
42+
KernelExtra = TypeVar('KernelExtra')
3943
T = TypeVar('T')
4044

4145
__all__ = [
46+
'annealed_importance_sampling_kernel',
47+
'AnnealedImportanceSamplingKernelExtra',
4248
'conditional_systematic_resampling',
4349
'effective_sample_size_predicate',
4450
'ParticleGatherFn',
@@ -518,6 +524,193 @@ def dont_resample(
518524
return smc_state, smc_extra
519525

520526

527+
@runtime_checkable
528+
class AnnealedImportanceSamplingMCMCKernel(Protocol[State, Extra, KernelExtra]):
529+
"""Function that decides whether to resample."""
530+
531+
def __call__(
532+
self,
533+
state: State,
534+
step: IntScalar,
535+
target_log_prob_fn: PotentialFn[Extra],
536+
seed: Seed,
537+
) -> tuple[State, KernelExtra]:
538+
"""Return boolean indicating whether to resample.
539+
540+
Note that resampling happens before stepping the kernel.
541+
542+
Args:
543+
state: State at step `t`.
544+
step: The timestep, `t`.
545+
target_log_prob_fn: Target distribution corresponding to `t`.
546+
seed: PRNG seed.
547+
548+
Returns:
549+
new_state: New state, targeting `target_log_prob_fn`.
550+
extra: Extra information from the kernel.
551+
"""
552+
553+
554+
@util.dataclass
555+
class AnnealedImportanceSamplingKernelExtra(Generic[KernelExtra, Extra]):
556+
"""Extra outputs from the AIS kernel.
557+
558+
Attributes:
559+
kernel_extra: Extra outputs from the inner kernel.
560+
next_state_extra: Extra output from the next step's target log prob
561+
function.
562+
cur_state_extra: Extra output from the current step's target log prob
563+
function.
564+
"""
565+
566+
kernel_extra: KernelExtra
567+
cur_state_extra: Extra
568+
next_state_extra: Extra
569+
570+
571+
@types.runtime_typed
572+
def annealed_importance_sampling_kernel(
573+
state: State,
574+
step: IntScalar,
575+
seed: Seed,
576+
kernel: AnnealedImportanceSamplingMCMCKernel[State, Extra, KernelExtra],
577+
make_target_log_probability_fn: Callable[[IntScalar], PotentialFn[Extra]],
578+
) -> tuple[
579+
State,
580+
tuple[
581+
Float[Array, 'num_particles'],
582+
AnnealedImportanceSamplingKernelExtra[KernelExtra, Extra],
583+
],
584+
]:
585+
"""SMC kernel that implements Annealed Importance Sampling.
586+
587+
Annealed Importance Sampling (AIS)[1] can be interpreted as a special case of
588+
SMC with a particular choice of forward and reverse kernels:
589+
```none
590+
r_t = k_t(x_{t + 1} | x_t) p_t(x_t) / p_t(x_{t + 1})
591+
q_t = k_{t - 1}(x_t | x_{t - 1})
592+
```
593+
where `k_t` is an MCMC kernel that has `p_t` invariant. This causes the
594+
incremental weight equation to be particularly simple:
595+
```none
596+
iw_t = p_t(x_t) / p_{t - 1}(x_t)
597+
```
598+
Unfortunately, the reverse kernel is not optimal, so the annealing schedule
599+
needs to be fine. The original formulation from [1] does not do resampling,
600+
but enabling it will usually reduce the variance of the estimator.
601+
602+
Args:
603+
state: The previous particle state, `x_{t - 1}^{1:K}`.
604+
step: The previous timestep, `t - 1`.
605+
seed: PRNG seed.
606+
kernel: The inner MCMC kernel. It takes the current state, the timestep, the
607+
target distribution and the seed and generates an approximate sample from
608+
`p_t` where `t` is the passed-in timestep.
609+
make_target_log_probability_fn: A function that, given a timestep, returns
610+
the target distribution `p_t` where `t` is the passed-in timestep.
611+
612+
Returns:
613+
state: The new particles, `x_t^{1:K}`.
614+
extra: A 2-tuple of:
615+
incremental_log_weights: The incremental log weight at timestep t,
616+
`iw_t^{1:K}`.
617+
ais_extra: An `AnnealedImportanceSamplingKernelExtra` holding the inner kernel's extra and the extras of the current and next target log prob functions.
618+
619+
#### Example
620+
621+
In this example we estimate the normalizing constant ratio between `tlp_1`
622+
and `tlp_2`.
623+
624+
```python
625+
def tlp_1(x):
626+
return -(x**2) / 2.0, ()
627+
628+
def tlp_2(x):
629+
return -((x - 2) ** 2) / 2 / 16.0, ()
630+
631+
@jax.jit
632+
def kernel(smc_state, seed):
633+
smc_seed, seed = jax.random.split(seed, 2)
634+
635+
def inner_kernel(state, stage, tlp_fn, seed):
636+
f = jnp.array(stage, state.dtype) / num_steps
637+
hmc_state = fun_mc.hamiltonian_monte_carlo_init(state, tlp_fn)
638+
hmc_state, _ = fun_mc.hamiltonian_monte_carlo_step(
639+
hmc_state,
640+
tlp_fn,
641+
step_size=f * 4.0 + (1.0 - f) * 1.0,
642+
num_integrator_steps=1,
643+
seed=seed,
644+
)
645+
return hmc_state.state, ()
646+
647+
smc_state, _ = smc.sequential_monte_carlo_step(
648+
smc_state,
649+
kernel=functools.partial(
650+
smc.annealed_importance_sampling_kernel,
651+
kernel=inner_kernel,
652+
make_target_log_probability_fn=functools.partial(
653+
fun_mc.geometric_annealing_path,
654+
num_stages=num_steps,
655+
initial_target_log_prob_fn=tlp_1,
656+
final_target_log_prob_fn=tlp_2,
657+
),
658+
),
659+
seed=smc_seed,
660+
)
661+
662+
return (smc_state, seed), ()
663+
664+
num_steps = 100
665+
num_particles = 400
666+
init_seed, smc_seed = jax.random.split(jax.random.PRNGKey(0))
667+
init_state = jax.random.normal(init_seed, [num_particles])
668+
669+
(smc_state, _), _ = fun_mc.trace(
670+
(
671+
smc.sequential_monte_carlo_init(
672+
init_state,
673+
weight_dtype=init_state.dtype,
674+
),
675+
smc_seed,
676+
),
677+
kernel,
678+
num_steps,
679+
)
680+
681+
weights = jnp.exp(smc_state.log_weights)
682+
# Should be close to 4.
683+
print('estimated z2/z1', weights.mean())
684+
# Should be close to 2.
685+
print('estimated mean', (jax.nn.softmax(smc_state.log_weights)
686+
* smc_state.state).sum())
687+
```
688+
689+
#### References
690+
691+
[1]: Neal, Radford M. (1998) Annealed Importance Sampling.
692+
https://arxiv.org/abs/physics/9803008
693+
"""
694+
new_state, kernel_extra = kernel(
695+
state, step, make_target_log_probability_fn(step), seed
696+
)
697+
tlp_num, num_extra = fun_mc.call_potential_fn(
698+
make_target_log_probability_fn(step + 1), new_state
699+
)
700+
tlp_denom, denom_extra = fun_mc.call_potential_fn(
701+
make_target_log_probability_fn(step), new_state
702+
)
703+
extra = AnnealedImportanceSamplingKernelExtra(
704+
kernel_extra=kernel_extra,
705+
cur_state_extra=denom_extra,
706+
next_state_extra=num_extra,
707+
)
708+
return new_state, (
709+
tlp_num - tlp_denom,
710+
extra,
711+
)
712+
713+
521714
def _smart_cond(
522715
pred: BoolScalar,
523716
true_fn: Callable[..., T],

spinoffs/fun_mc/fun_mc/smc_test.py

+67
Original file line numberDiff line numberDiff line change
@@ -1184,6 +1184,73 @@ def kernel(smc_state, seed):
11841184
self.assertAllClose(gt_log_evidence, log_evidence, rtol=0.01)
11851185
self.assertAllClose(gt_log_evidence, log_evidence, atol=0.2)
11861186

1187+
def test_annealed_importance_sampling(self):
1188+
def tlp_1(x):
1189+
return -0.5 * x**2, ()
1190+
1191+
def tlp_2(x):
1192+
return (-0.5 * (x - 2) ** 2) / 16.0, ()
1193+
1194+
@jax.jit
1195+
def kernel(smc_state, seed):
1196+
smc_seed, seed = util.split_seed(seed, 2)
1197+
1198+
def inner_kernel(state, step, tlp_fn, seed):
1199+
f = jnp.array(step, state.dtype) / num_steps
1200+
hmc_state = fun_mc.hamiltonian_monte_carlo_init(state, tlp_fn)
1201+
hmc_state, _ = fun_mc.hamiltonian_monte_carlo_step(
1202+
hmc_state,
1203+
tlp_fn,
1204+
step_size=f * 4.0 + (1.0 - f) * 1.0,
1205+
num_integrator_steps=1,
1206+
seed=seed,
1207+
)
1208+
return hmc_state.state, ()
1209+
1210+
smc_state, _ = smc.sequential_monte_carlo_step(
1211+
smc_state,
1212+
kernel=functools.partial(
1213+
smc.annealed_importance_sampling_kernel,
1214+
kernel=inner_kernel,
1215+
make_target_log_probability_fn=functools.partial(
1216+
fun_mc.geometric_annealing_path,
1217+
num_stages=num_steps,
1218+
initial_target_log_prob_fn=tlp_1,
1219+
final_target_log_prob_fn=tlp_2,
1220+
),
1221+
),
1222+
seed=smc_seed,
1223+
)
1224+
1225+
return (smc_state, seed), ()
1226+
1227+
num_steps = 1000
1228+
num_particles = 1000
1229+
init_seed, smc_seed = util.split_seed(_test_seed(), 2)
1230+
init_state = util.random_normal([num_particles], self._dtype, init_seed)
1231+
1232+
(smc_state, _), _ = fun_mc.trace(
1233+
(
1234+
smc.sequential_monte_carlo_init(
1235+
init_state,
1236+
weight_dtype=self._dtype,
1237+
),
1238+
smc_seed,
1239+
),
1240+
kernel,
1241+
num_steps,
1242+
)
1243+
1244+
weights = jnp.exp(smc_state.log_weights)
1245+
# 4 because tlp_2 has stddev of 4 while tlp_1 has stddev of 1.
1246+
self.assertAllClose(4.0, jnp.mean(weights), atol=0.1)
1247+
1248+
normed_weights = jax.nn.softmax(smc_state.log_weights)
1249+
mean = jnp.sum(normed_weights * smc_state.state)
1250+
variance = jnp.sum(normed_weights * (smc_state.state - mean) ** 2)
1251+
self.assertAllClose(2.0, mean, atol=0.3)
1252+
self.assertAllClose(16.0, variance, rtol=0.2)
1253+
11871254

11881255
@test_util.multi_backend_test(globals(), 'smc_test')
11891256
class SMCTest32(SMCTest):

spinoffs/fun_mc/fun_mc/types.py

+19-2
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
# ============================================================================
1515
"""Various types used in FunMC."""
1616

17-
from typing import Callable, TypeAlias, TypeVar
17+
from typing import Callable, Protocol, TypeAlias, TypeVar, runtime_checkable
1818

1919
import jaxtyping
2020
from fun_mc import backend
@@ -29,6 +29,7 @@
2929
'FloatScalar',
3030
'Int',
3131
'IntScalar',
32+
'PotentialFn',
3233
'runtime_typed',
3334
'Seed',
3435
]
@@ -42,8 +43,24 @@
4243
BoolScalar: TypeAlias = bool | Bool[Array, '']
4344
IntScalar: TypeAlias = int | Int[Array, '']
4445
FloatScalar: TypeAlias = float | Float[Array, '']
45-
4646
F = TypeVar('F', bound=Callable)
47+
_Extra = TypeVar('_Extra')
48+
49+
50+
@runtime_checkable
51+
class PotentialFn(Protocol[_Extra]):
52+
"""Maps state to an array of float.
53+
54+
If the state has a leading dimension, the same dimension is present in the
55+
returned values as well.
56+
"""
57+
58+
def __call__(
59+
self,
60+
*args,
61+
**kwargs,
62+
) -> tuple[Float[Array, '...'], _Extra]:
63+
"""Potential function."""
4764

4865

4966
def runtime_typed(f: F) -> F:

0 commit comments

Comments
 (0)