@@ -17,6 +17,7 @@
 from typing import Any, Callable, Generic, Protocol, TypeVar, runtime_checkable

 from fun_mc import backend
+from fun_mc import fun_mc_lib as fun_mc
 from fun_mc import types

 jax = backend.jax
@@ -34,11 +35,16 @@
 BoolScalar = types.BoolScalar
 IntScalar = types.IntScalar
 FloatScalar = types.FloatScalar
+PotentialFn = types.PotentialFn
+
 State = TypeVar('State')
 Extra = TypeVar('Extra')
+KernelExtra = TypeVar('KernelExtra')
 T = TypeVar('T')

 __all__ = [
+    'annealed_importance_sampling_kernel',
+    'AnnealedImportanceSamplingKernelExtra',
     'conditional_systematic_resampling',
     'effective_sample_size_predicate',
     'ParticleGatherFn',
@@ -518,6 +524,193 @@ def dont_resample(
   return smc_state, smc_extra


+@runtime_checkable
+class AnnealedImportanceSamplingMCMCKernel(Protocol[State, Extra, KernelExtra]):
+  """MCMC kernel used for Annealed Importance Sampling."""
+
+  def __call__(
+      self,
+      state: State,
+      step: IntScalar,
+      target_log_prob_fn: PotentialFn[Extra],
+      seed: Seed,
+  ) -> tuple[State, KernelExtra]:
+    """Steps the MCMC kernel targeting `target_log_prob_fn`.
+
+    Note that resampling happens before stepping the kernel.
+
+    Args:
+      state: State at step `t`.
+      step: The timestep, `t`.
+      target_log_prob_fn: Target distribution corresponding to `t`.
+      seed: PRNG seed.
+
+    Returns:
+      new_state: New state, targeting `target_log_prob_fn`.
+      extra: Extra information from the kernel.
+    """
+
+
+@util.dataclass
+class AnnealedImportanceSamplingKernelExtra(Generic[KernelExtra, Extra]):
+  """Extra outputs from the AIS kernel.
+
+  Attributes:
+    kernel_extra: Extra outputs from the inner kernel.
+    next_state_extra: Extra output from the next step's target log prob
+      function.
+    cur_state_extra: Extra output from the current step's target log prob
+      function.
+  """
+
+  kernel_extra: KernelExtra
+  cur_state_extra: Extra
+  next_state_extra: Extra
+
+
+@types.runtime_typed
+def annealed_importance_sampling_kernel(
+    state: State,
+    step: IntScalar,
+    seed: Seed,
+    kernel: AnnealedImportanceSamplingMCMCKernel[State, Extra, KernelExtra],
+    make_target_log_probability_fn: Callable[[IntScalar], PotentialFn[Extra]],
+) -> tuple[
+    State,
+    tuple[
+        Float[Array, 'num_particles'],
+        AnnealedImportanceSamplingKernelExtra[KernelExtra, Extra],
+    ],
+]:
+  """SMC kernel that implements Annealed Importance Sampling.
+
+  Annealed Importance Sampling (AIS) [1] can be interpreted as a special case
+  of SMC with a particular choice of forward and reverse kernels:
+  ```none
+  r_t = k_t(x_{t + 1} | x_t) p_t(x_t) / p_t(x_{t + 1})
+  q_t = k_{t - 1}(x_t | x_{t - 1})
+  ```
+  where `k_t` is an MCMC kernel that leaves `p_t` invariant. This makes the
+  incremental weight equation particularly simple:
+  ```none
+  iw_t = p_t(x_t) / p_{t - 1}(x_t)
+  ```
+  Unfortunately, the reverse kernel is not optimal, so the annealing schedule
+  needs to be fine-grained. The original formulation from [1] does not
+  resample, but enabling resampling will usually reduce the variance of the
+  estimator.
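+
+  In log space, this is exactly the quantity returned by this kernel as the
+  incremental log weight:
+  ```none
+  log iw_t = log p_t(x_t) - log p_{t - 1}(x_t)
+  ```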
+
+  Args:
+    state: The previous particle state, `x_{t - 1}^{1:K}`.
+    step: The previous timestep, `t - 1`.
+    seed: PRNG seed.
+    kernel: The inner MCMC kernel. It takes the current state, the timestep,
+      the target distribution and the seed, and generates an approximate
+      sample from `p_t`, where `t` is the passed-in timestep.
+    make_target_log_probability_fn: A function that, given a timestep `t`,
+      returns the target distribution `p_t`.
+
+  Returns:
+    state: The new particles, `x_t^{1:K}`.
+    extra: A 2-tuple of:
+      incremental_log_weights: The incremental log weight at timestep `t`,
+        `iw_t^{1:K}`.
+      kernel_extra: Extra information returned by the kernel.
+
+  #### Example
+
+  In this example we estimate the normalizing constant ratio between `tlp_1`
+  and `tlp_2`.
+
+  ```python
+  def tlp_1(x):
+    return -(x**2) / 2.0, ()
+
+  def tlp_2(x):
+    return -((x - 2) ** 2) / 2 / 16.0, ()
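+  # tlp_1 is a standard normal log density; tlp_2 is a normal with mean 2 and
+  # standard deviation 4, so the normalizing constant ratio Z2 / Z1 is 4.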
+
+  @jax.jit
+  def kernel(smc_state, seed):
+    smc_seed, seed = jax.random.split(seed, 2)
+
+    def inner_kernel(state, stage, tlp_fn, seed):
+      f = jnp.array(stage, state.dtype) / num_steps
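+      # f ramps from 0 to 1 over the annealing schedule; use it to scale the
+      # HMC step size from 1.0 up to 4.0 as the target distribution widens.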
+      hmc_state = fun_mc.hamiltonian_monte_carlo_init(state, tlp_fn)
+      hmc_state, _ = fun_mc.hamiltonian_monte_carlo_step(
+          hmc_state,
+          tlp_fn,
+          step_size=f * 4.0 + (1.0 - f) * 1.0,
+          num_integrator_steps=1,
+          seed=seed,
+      )
+      return hmc_state.state, ()
+
+    smc_state, _ = smc.sequential_monte_carlo_step(
+        smc_state,
+        kernel=functools.partial(
+            smc.annealed_importance_sampling_kernel,
+            kernel=inner_kernel,
+            make_target_log_probability_fn=functools.partial(
+                fun_mc.geometric_annealing_path,
+                num_stages=num_steps,
+                initial_target_log_prob_fn=tlp_1,
+                final_target_log_prob_fn=tlp_2,
+            ),
+        ),
+        seed=smc_seed,
+    )
+
+    return (smc_state, seed), ()
+
+  num_steps = 100
+  num_particles = 400
+  init_seed, seed = jax.random.split(jax.random.PRNGKey(0))
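+  # Initialize the particles with exact samples from the initial target,
+  # tlp_1 (a standard normal).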
+  init_state = jax.random.normal(init_seed, [num_particles])
+
+  (smc_state, _), _ = fun_mc.trace(
+      (
+          smc.sequential_monte_carlo_init(
+              init_state,
+              weight_dtype=init_state.dtype,
+          ),
+          seed,
+      ),
+      kernel,
+      num_steps,
+  )
+
+  weights = jnp.exp(smc_state.log_weights)
+  # Should be close to 4.
+  print('estimated z2/z1', weights.mean())
+  # Should be close to 2.
+  print('estimated mean',
+        (jax.nn.softmax(smc_state.log_weights) * smc_state.state).sum())
+  ```
+
+  #### References
+
+  [1]: Neal, Radford M. (1998) Annealed Importance Sampling.
+       https://arxiv.org/abs/physics/9803008
+  """
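+  # Move the particles with the MCMC kernel targeting the distribution at
+  # `step`.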
+  new_state, kernel_extra = kernel(
+      state, step, make_target_log_probability_fn(step), seed
+  )
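+  # Evaluate the moved particles under the next and the current targets; the
+  # difference of the log densities is the incremental log weight.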
+  tlp_num, num_extra = fun_mc.call_potential_fn(
+      make_target_log_probability_fn(step + 1), new_state
+  )
+  tlp_denom, denom_extra = fun_mc.call_potential_fn(
+      make_target_log_probability_fn(step), new_state
+  )
+  extra = AnnealedImportanceSamplingKernelExtra(
+      kernel_extra=kernel_extra,
+      cur_state_extra=denom_extra,
+      next_state_extra=num_extra,
+  )
+  return new_state, (
+      tlp_num - tlp_denom,
+      extra,
+  )
+
+
 def _smart_cond(
     pred: BoolScalar,
     true_fn: Callable[..., T],