From e77205bda9dd91aabb72f87e9340550419e9cdb1 Mon Sep 17 00:00:00 2001 From: yallup Date: Wed, 18 Sep 2024 14:23:28 +0100 Subject: [PATCH 01/14] Revert "Slice" --- blackjax/__init__.py | 8 - blackjax/mcmc/__init__.py | 2 - blackjax/mcmc/univariate_slice.py | 278 ---------------------------- blackjax/ns/__init__.py | 6 - blackjax/ns/base.py | 153 --------------- blackjax/ns/inner_kernel.py | 245 ------------------------ blackjax/ns/mh.py | 59 ------ blackjax/ns/rejection.py | 91 --------- blackjax/smc/inner_kernel_tuning.py | 1 - test_ns.py | 182 ------------------ test_slice.py | 100 ---------- 11 files changed, 1125 deletions(-) delete mode 100644 blackjax/mcmc/univariate_slice.py delete mode 100644 blackjax/ns/__init__.py delete mode 100644 blackjax/ns/base.py delete mode 100644 blackjax/ns/inner_kernel.py delete mode 100644 blackjax/ns/mh.py delete mode 100644 blackjax/ns/rejection.py delete mode 100644 test_ns.py delete mode 100644 test_slice.py diff --git a/blackjax/__init__.py b/blackjax/__init__.py index e707e8727..dfdcfc545 100644 --- a/blackjax/__init__.py +++ b/blackjax/__init__.py @@ -22,7 +22,6 @@ from .mcmc import nuts as _nuts from .mcmc import periodic_orbital, random_walk from .mcmc import rmhmc as _rmhmc -from .mcmc import univariate_slice as _slice from .mcmc.random_walk import additive_step_random_walk as _additive_step_random_walk from .mcmc.random_walk import ( irmh_as_top_level_api, @@ -37,8 +36,6 @@ from .smc import adaptive_tempered from .smc import inner_kernel_tuning as _inner_kernel_tuning from .smc import tempered -from .ns import rejection -from .ns import inner_kernel from .vi import meanfield_vi as _meanfield_vi from .vi import pathfinder as _pathfinder from .vi import schrodinger_follmer as _schrodinger_follmer @@ -113,7 +110,6 @@ def generate_top_level_api_from(module): mclmc = generate_top_level_api_from(_mclmc) elliptical_slice = generate_top_level_api_from(_elliptical_slice) -univariate_slice = generate_top_level_api_from(_slice) ghmc = generate_top_level_api_from(_ghmc) barker_proposal = generate_top_level_api_from(barker) @@ -127,10 +123,6 @@ def generate_top_level_api_from(module): smc_family = [tempered_smc, adaptive_tempered_smc] "Step_fn returning state has a .particles attribute" -# NS -rejection_ns = generate_top_level_api_from(rejection) -inner_kernel_ns = generate_top_level_api_from(inner_kernel) - # stochastic gradient mcmc sgld = generate_top_level_api_from(_sgld) sghmc = generate_top_level_api_from(_sghmc) diff --git a/blackjax/mcmc/__init__.py b/blackjax/mcmc/__init__.py index fdb9bf898..6e207741d 100644 --- a/blackjax/mcmc/__init__.py +++ b/blackjax/mcmc/__init__.py @@ -10,7 +10,6 @@ periodic_orbital, random_walk, rmhmc, - univariate_slice, ) __all__ = [ @@ -25,5 +24,4 @@ "marginal_latent_gaussian", "random_walk", "mclmc", - "univariate_slice", ] diff --git a/blackjax/mcmc/univariate_slice.py b/blackjax/mcmc/univariate_slice.py deleted file mode 100644 index 607ed573e..000000000 --- a/blackjax/mcmc/univariate_slice.py +++ /dev/null @@ -1,278 +0,0 @@ -# Copyright 2020- The Blackjax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Public API for the Slice sampling Kernel""" -from typing import Callable, NamedTuple - -import jax -import jax.numpy as jnp -from jax import random - -from blackjax.base import SamplingAlgorithm -from blackjax.types import ArrayLikeTree, ArrayTree, PRNGKey - -__all__ = [ - "SliceState", - # "SliceInfo", - "init", - "build_kernel", - "as_top_level_api", -] - - -def as_top_level_api( - loglikelihood_fn: Callable, - *, - n_doublings: int = 5, -) -> SamplingAlgorithm: - """Implements the (basic) user interface for the Slice sampling kernel. - - Examples - -------- - - A slice sampling kernel can be initialized like this: - - .. code:: - - slice = blackjax.slice(logdensity_fn, n_doublings) - state = slice.init(position) - new_state, info = slice.step(rng_key, state) - - We can JIT-compile the step function for better performance - - .. code:: - - step = jax.jit(slice.step) - new_state, info = step(rng_key, state) - - Parameters - ---------- - logdensity_fn: Callable - the unnormalized posterior distribution we wish to sample from. - n_doublings: int - maximal number of slice expansions. - - Returns - ------- - A ``MCMCSamplingAlgorithm``. - """ - - kernel = build_kernel(n_doublings) - - def init_fn(position: ArrayLikeTree, rng_key=None): - del rng_key - return init(position, loglikelihood_fn) - - def step_fn(rng_key: PRNGKey, state): - return kernel(rng_key, state, loglikelihood_fn) - - return SamplingAlgorithm(init_fn, step_fn) - - -class SliceState(NamedTuple): - position: ArrayTree - logdensity: ArrayTree - widths: ArrayTree - n: int - - -# class SliceInfo(NamedTuple): -# widths: jnp.ndarray -# n: jnp.ndarray - - -def init(position: ArrayTree, logdensity_fn: Callable): - logdensity = logdensity_fn(position) - widths = jax.tree.map(lambda x: jnp.full(x.shape, 0.01), position) - return SliceState(position, jnp.atleast_1d(logdensity), widths, 0) - - -def build_kernel(n_doublings: int) -> Callable: - """Instantiate a slice sampling kernel. - - Implementation according to [1]. Doubling implementation inspired - by Tensorflow probability's implementation. Performs a univariate update in - each dimension. - - Parameters - ---------- - n_doublings: int - maximal number of slice expansions - - References - ------- - [1] Radford M. Neal "Slice sampling", - The Annals of Statistics, Ann. Statist. 31(3), 705-767, (June 2003) - """ - - def one_step(rng_key: PRNGKey, state: SliceState, logdensity_fn: Callable): - proposal_generator = slice_proposal(logdensity_fn, n_doublings) - return proposal_generator(rng_key, state) - - return one_step - - -def slice_proposal(logdensity_fn, n_doublings) -> Callable: - def generate(rng_key, state): - order_key, rng_key = random.split(rng_key) - n = state.n - positions, unravel_fn = jax.flatten_util.ravel_pytree(state.position) - widths, _ = jax.flatten_util.ravel_pytree(state.widths) - - def conditional_proposal(rng_key, idx): - return _sample_conditionally( - rng_key, logdensity_fn, idx, positions, widths, n_doublings - ) - - def body_fn(carry, rn): - seed, idx = rn - positions, widths = carry - xi, wi = conditional_proposal(seed, idx) - positions = positions.at[idx].set(xi) - nw = widths[idx] + (wi - widths[idx]) / (n + 1) - widths = widths.at[idx].set(nw) - return (positions, widths), (positions, widths) - - order = random.choice( - order_key, - jnp.arange(len(positions)), - shape=(len(positions),), - replace=False, - ) - - keys = random.split(rng_key, len(positions)) - (new_positions, new_widths), _ = jax.lax.scan( - body_fn, (positions, widths), (keys, order) - ) - - new_positions = unravel_fn(new_positions) - new_widths = unravel_fn(new_widths) - new_state = SliceState( - new_positions, - jnp.atleast_1d(logdensity_fn(new_positions)), - new_widths, - n + 1, - ) - # new_info = SliceInfo(new_widths, n + 1.0) - return new_state, _ - - return generate - - -def _sample_conditionally(seed, logdensity_fn, idx, positions, widths, n_doublings): - def cond_lp_fn(xi_to_set): - return logdensity_fn(positions.at[idx].set(xi_to_set)) - - key, seed1, seed2 = random.split(seed, 3) - x0, w0 = positions[idx], widths[idx] - y = cond_lp_fn(x0) - random.exponential(key) - left, right, _ = _doubling_fn(seed1, y, x0, cond_lp_fn, w0, n_doublings) - x1 = _shrinkage_fn(seed2, y, x0, cond_lp_fn, left, right, w0) - return x1, right - left - - -def _doubling_fn(rng, y, x0, cond_lp_fn, w, n_doublings): - key1, key2 = random.split(rng, 2) - left = x0 - w * random.uniform(key1) - - K = n_doublings + 1 - left_expands = random.bernoulli(key2, 0.5, (K,)) - width_multipliers = 2 ** jnp.arange(0, K, dtype=jnp.int32) - widths = width_multipliers * w - left_increments = jnp.cumsum(widths * left_expands) - - lefts = left - left_increments - rights = left + widths - left_lps = jax.vmap(cond_lp_fn)(lefts) - right_lps = jax.vmap(cond_lp_fn)(rights) - - both_ok = jnp.logical_and(left_lps < y, right_lps < y) - best_interval_idx = _best_interval(both_ok.astype(jnp.int32)) - - return ( - lefts[best_interval_idx], - rights[best_interval_idx], - both_ok[best_interval_idx], - ) - - -def _best_interval(x): - k = x.shape[0] - mults = jnp.arange(2 * k, k, -1, dtype=x.dtype) - shifts = jnp.arange(k, dtype=x.dtype) - indices = jnp.argmax(mults * x + shifts).astype(x.dtype) - return indices - - -def _shrinkage_fn(seed, y, x0, cond_lp_fn, left, right, w): - def cond_fn(state): - *_, found = state - return jnp.logical_not(found) - - def body_fn(state): - x1, left, right, seed, _ = state - key, seed = random.split(seed) - v = random.uniform(key) - x1 = left + v * (right - left) - - found = jnp.logical_and( - y < cond_lp_fn(x1), - _accept_fn(y, x1, x0, cond_lp_fn, left, right, w), - ) - - left = jnp.where(x1 < x0, x1, left) - right = jnp.where(x1 >= x0, x1, right) - - return x1, left, right, seed, found - - key, seed = random.split(seed) - v = random.uniform(key) - x1 = left + v * (right - left) - x1, left, right, seed, _ = jax.lax.while_loop( - cond_fn, body_fn, (x1, left, right, seed, False) - ) - return x1 - - -def _accept_fn(y, x1, x0, cond_lp_fn, left, right, w): - def cond_fn(state): - _, _, left, right, w, _, is_acceptable = state - return jnp.logical_and(right - left > 1.1 * w, is_acceptable) - - def body_fn(state): - x1, x0, left, right, w, D, _ = state - mid = (left + right) / 2 - D = jnp.logical_or( - jnp.logical_or( - jnp.logical_and(x0 < mid, x1 >= mid), - jnp.logical_and(x0 >= mid, x1 < mid), - ), - D, - ) - right = jnp.where(x1 < mid, mid, right) - left = jnp.where(x1 >= mid, mid, left) - - left_is_not_acceptable = y >= cond_lp_fn(left) - right_is_not_acceptable = y >= cond_lp_fn(right) - interval_is_not_acceptable = jnp.logical_and( - left_is_not_acceptable, right_is_not_acceptable - ) - is_still_acceptable = jnp.logical_not( - jnp.logical_and(D, interval_is_not_acceptable) - ) - return x1, x0, left, right, w, D, is_still_acceptable - - *_, is_acceptable = jax.lax.while_loop( - cond_fn, body_fn, (x1, x0, left, right, w, False, True) - ) - return is_acceptable diff --git a/blackjax/ns/__init__.py b/blackjax/ns/__init__.py deleted file mode 100644 index a5e69892c..000000000 --- a/blackjax/ns/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -from . import rejection, inner_kernel - -__all__ = [ - "rejection", - "inner_kernel", -] diff --git a/blackjax/ns/base.py b/blackjax/ns/base.py deleted file mode 100644 index 29775f17b..000000000 --- a/blackjax/ns/base.py +++ /dev/null @@ -1,153 +0,0 @@ -# Copyright 2024- Will Handley & David Yallup -from typing import Callable, NamedTuple, Optional - -import jax -import jax.numpy as jnp - -import blackjax.ns.base as base -from blackjax.base import SamplingAlgorithm -from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey - - -class NSState(NamedTuple): - """State of the Nested Sampler. - - Live points must be a ArrayTree, each leave represents a variable from the posterior, - being an array of size `(nlive, ...)`. - - Examples (three particles): - - Single univariate posterior: - [ Array([[1.], [1.2], [3.4]]) ] - - Single bivariate posterior: - [ Array([[1,2], [3,4], [5,6]]) ] - - Two variables, each univariate: - [ Array([[1.], [1.2], [3.4]]), - Array([[50.], [51], [55]]) ] - - Two variables, first one bivariate, second one 4-variate: - [ Array([[1., 2.], [1.2, 0.5], [3.4, 50]]), - Array([[50., 51., 52., 51], [51., 52., 52. ,54.], [55., 60, 60, 70]]) ] - """ - - particles: ArrayTree - logL: Array # The log-likelihood of the particles - logL_birth: Array # The hard likelihood threshold of each particle at birth - logL_star: float # The current hard likelihood threshold - create_parameters: ArrayTree # NOTE num_repeats? - # delete_parameters: ArrayTree # NOTE num_repeats? - - -class NSInfo(NamedTuple): - """Additional information on the NS step.""" - - particles: ArrayTree - logL: Array # The log-likelihood of the particles - logL_birth: Array # The hard likelihood threshold of each particle at birth - - -def init(particles: ArrayLikeTree, logL_fn, init_create_params): - logL_star = -jnp.inf - num_particles = jax.tree_util.tree_flatten(particles)[0][0].shape[0] - logL_birth = logL_star * jnp.ones(num_particles) - logL = logL_fn(particles) - return NSState(particles, logL, logL_birth, logL_star, init_create_params) - - -def build_kernel( - log_density_fn: Callable, - logL_fn: Callable, - create_fn: Callable, - delete_fn: Callable, -) -> Callable: - r"""Build a Nested Sampling by running a creation and deletion step. - - Parameters - ---------- - logL_fn: Callable - A function that assigns a weight to the particles. - create_fn: Callable - Function that takes an array of keys and particles and returns - new particles. - delete_fn: Callable - Function that takes an array of keys and particles and deletes some - particles. - - Returns - ------- - A callable that takes a rng_key and a NSState that contains the current state - of the chain and that returns a new state of the chain along with - information about the transition. - - """ - - def kernel( - rng_key: PRNGKey, state: base.NSState - ) -> tuple[base.NSState, base.NSInfo]: - # Create new particles - create_parameters = state.create_parameters - # particles, create_info = create_fn(rng_key, state.particles, logL_fn, state.logL_star, create_parameters) - # num_particles = jax.tree_util.tree_flatten(particles)[0][0].shape[0] # Not good jax -- improve with sgpt - logL_birth = state.logL_star - val, dead_idx = delete_fn(rng_key, state.logL, state.create_parameters) - - dead_particles = jax.tree.map(lambda x: x[dead_idx], state.particles) - dead_logL = state.logL[dead_idx] - dead_logL_birth = state.logL_birth[dead_idx] - - new_particles, new_particles_logL = create_fn( - rng_key, dead_particles, log_density_fn, logL_fn, -val.min(), create_parameters - ) - logL_births = -val.min() * jnp.ones(dead_idx.shape) - - particles = state.particles.at[dead_idx].set(new_particles) - logL = state.logL.at[dead_idx].set(new_particles_logL) - logL_birth = state.logL_birth.at[dead_idx].set(logL_births) - logL_star = state.logL.min() - - return base.NSState( - particles, - logL, - logL_birth, - logL_star, - state.create_parameters, - # state.delete_parameters, - ), base.NSInfo(dead_particles, dead_logL, dead_logL_birth) - - return kernel - - -def as_top_level_api( - logL_fn: Callable, - create_fn: Callable, - delete_fn: Callable, -) -> SamplingAlgorithm: - """Implements the (basic) user interface for the Adaptive Tempered SMC kernel. - - Parameters - ---------- - logL_fn: Callable - A function that assigns a weight to the particles. - create_fn: Callable - Function that takes an array of keys and particles and returns - new particles. - delete_fn: Callable - Function that takes an array of keys and particles and deletes some - particles. - Returns - ------- - A ``SamplingAlgorithm``. - - """ - kernel = build_kernel( - logL_fn, - create_fn, - delete_fn, - ) - - def init_fn(position: ArrayLikeTree, rng_key=None): - del rng_key - return init(position, logL_fn, {}) - - def step_fn(rng_key: PRNGKey, state): - return kernel(rng_key, state) - - return SamplingAlgorithm(init_fn, step_fn) diff --git a/blackjax/ns/inner_kernel.py b/blackjax/ns/inner_kernel.py deleted file mode 100644 index 9b8af206d..000000000 --- a/blackjax/ns/inner_kernel.py +++ /dev/null @@ -1,245 +0,0 @@ -# Copyright 2024- Will Handley & David Yallup -from typing import Callable, NamedTuple, Optional -from typing import Callable, Dict, NamedTuple, Tuple -import jax -import jax.numpy as jnp - -import blackjax.ns.base as base -from blackjax.base import SamplingAlgorithm -from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey -from blackjax.ns.base import NSInfo, NSState -from blackjax.ns.base import init as init_base -from functools import partial -from blackjax.smc.inner_kernel_tuning import StateWithParameterOverride - -__all__ = ["init", "as_top_level_api", "build_kernel"] - - -class NSState(NamedTuple): - """State of the Nested Sampler.""" - - particles: ArrayTree - logL: Array # The log-likelihood of the particles - logL_birth: ( - Array # The hard likelihood threshold of each particle at birth - ) - logL_star: float # The current hard likelihood threshold - - -class NSInfo(NamedTuple): - """Additional information on the NS step.""" - - particles: ArrayTree - logL: Array # The log-likelihood of the particles - logL_birth: ( - Array # The hard likelihood threshold of each particle at birth - ) - - -def init_base(particles: ArrayLikeTree, loglikelihood_fn): - logL_star = -jnp.inf - num_particles = jax.tree_util.tree_flatten(particles)[0][0].shape[0] - logL_birth = logL_star * jnp.ones(num_particles) - logL = loglikelihood_fn(particles) - return NSState(particles, logL, logL_birth, logL_star) - - -def init(position, loglikelihood_fn, initial_parameter_value): - return StateWithParameterOverride( - init_base(position, loglikelihood_fn), initial_parameter_value - ) - - -def build_kernel( - logprior_fn: Callable, - loglikelihood_fn: Callable, - delete_fn: Callable, - contour_fn: Callable, - mcmc_step_fn: Callable, - mcmc_init_fn: Callable, - mcmc_parameter_update_fn: Callable[ - [NSState, NSInfo], Dict[str, ArrayTree] - ], - num_mcmc_steps: int = 10, - **extra_parameters, -) -> Callable: - r"""Build a Nested Sampling by running a creation and deletion step. - - Parameters - ---------- - logL_fn: Callable - A function that assigns a weight to the particles. - create_fn: Callable - Function that takes an array of keys and particles and returns - new particles. - delete_fn: Callable - Function that takes an array of keys and particles and deletes some - particles. - - Returns - ------- - A callable that takes a rng_key and a NSState that contains the current state - of the chain and that returns a new state of the chain along with - information about the transition. - - """ - - def kernel( - rng_key: PRNGKey, - state: base.NSState, - **extra_step_parameters, - ) -> tuple[base.NSState, base.NSInfo]: - - logL_birth = state.sampler_state.logL_star - val, dead_idx = delete_fn(state.sampler_state.logL) - - dead_particles = jax.tree.map( - lambda x: x[dead_idx], state.sampler_state.particles - ) - dead_logL = state.sampler_state.logL[dead_idx] - dead_logL_birth = state.sampler_state.logL_birth[dead_idx] - - shared_mcmc_parameters = {} - unshared_mcmc_parameters = {} - for k, v in extra_step_parameters.items(): - if v.shape[0] == 1: - shared_mcmc_parameters[k] = v[0, ...] - else: - unshared_mcmc_parameters[k] = v - - shared_mcmc_step_fn = partial( - mcmc_step_fn, logdensity=logprior_fn, **state.parameter_override - ) - - contour_check_fn = lambda x: x <= -val.min() - - def particle_map(xs): - xs,rng = xs - state = mcmc_init_fn(xs, logprior_fn) - - def chain_scan(carry, xs): - """Middle loop to scan over required MCMC steps.""" - - def cond_fun(carry): - # _, _, logL, MHaccept = carry - _, _, logL = carry - - return contour_check_fn(logL) #& jnp.logical_not(MHaccept) - - def inner_chain(carry): - """Inner most while to check steps are in contour""" - # key, state, _, _ = carry - key, state, _ = carry - rng_key, subkey = jax.random.split(key) - new_state, info = shared_mcmc_step_fn(subkey, state) - logL = loglikelihood_fn(new_state.position) - # return rng_key, new_state, logL, info.is_accepted - return rng_key, new_state, logL - - state, _ = carry - rng_key, step_key = jax.random.split(xs[0]) - # _, state, logL, _ = jax.lax.while_loop( - # cond_fun, inner_chain, (step_key, state, -jnp.inf, False) - # ) - _, state, logL = jax.lax.while_loop( - cond_fun, inner_chain, (step_key, state, -jnp.inf) - ) - return (state, logL), (rng_key, state, logL) - - (fs, fl), (rng, s, l) = jax.lax.scan( - chain_scan, (state, -jnp.inf), (rng, jnp.zeros(rng.shape[0])) - ) - return fs.position, fl - - scan_keys = jax.random.split( - rng_key, (*dead_idx.shape, num_mcmc_steps) - ) - - # particle_map((dead_particles[0], scan_keys[0])) - - new_pos,new_logl = jax.pmap(particle_map)((dead_particles, scan_keys)) - logL_births = -val.min() * jnp.ones(dead_idx.shape) - - particles = state.sampler_state.particles.at[dead_idx].set( - new_pos.squeeze() - ) - logL = state.sampler_state.logL.at[dead_idx].set(new_logl.squeeze()) - logL_birth = state.sampler_state.logL_birth.at[dead_idx].set( - logL_births - ) - logL_star = state.sampler_state.logL.min() - - state = NSState( - particles, - logL, - logL_birth, - logL_star, - ) - info = NSInfo(dead_particles, dead_logL, dead_logL_birth) - new_parameter_override = mcmc_parameter_update_fn(state, info) - return StateWithParameterOverride(state, new_parameter_override), info - - return kernel - - -def delete_fn(logL, n_delete): - val, idx = jax.lax.top_k(-logL, n_delete) - return val, idx - - -def contour_fn(logL, lstar): - return logL <= lstar - - -def as_top_level_api( - logprior_fn: Callable, - loglikelihood_fn: Callable, - mcmc_step_fn: Callable, - mcmc_init_fn: Callable, - mcmc_parameter_update_fn: Callable[ - [NSState, NSInfo], Dict[str, ArrayTree] - ], - mcmc_initial_parameters: dict, - num_mcmc_steps: int = 10, - n_delete: int = 1, - **extra_parameters, -) -> SamplingAlgorithm: - """Implements the (basic) user interface for the Adaptive Tempered SMC kernel. - - Parameters - ---------- - logL_fn: Callable - A function that assigns a weight to the particles. - create_fn: Callable - Function that takes an array of keys and particles and returns - new particles. - delete_fn: Callable - Function that takes an array of keys and particles and deletes some - particles. - Returns - ------- - A ``SamplingAlgorithm``. - - """ - delete_func = partial(delete_fn, n_delete=n_delete) - - kernel = build_kernel( - logprior_fn, - loglikelihood_fn, - delete_func, - contour_fn, - mcmc_step_fn, - mcmc_init_fn, - mcmc_parameter_update_fn, - num_mcmc_steps, - **extra_parameters, - ) - - def init_fn(position: ArrayLikeTree, rng_key=None): - del rng_key - return init(position, loglikelihood_fn, mcmc_initial_parameters) - - def step_fn(rng_key: PRNGKey, state, **extra_parameters): - return kernel(rng_key, state, **extra_parameters) - - return SamplingAlgorithm(init_fn, step_fn) diff --git a/blackjax/ns/mh.py b/blackjax/ns/mh.py deleted file mode 100644 index 16ce7a1c0..000000000 --- a/blackjax/ns/mh.py +++ /dev/null @@ -1,59 +0,0 @@ -# Copyright 2024- Will Handley & David Yallup -from typing import Callable - -import jax -import jax.numpy as jnp - -import blackjax.ns.base as base -from blackjax.base import SamplingAlgorithm -from blackjax.types import ArrayLikeTree, PRNGKey - -from blackjax.ns.base import init, build_kernel - -__all__ = ["init", "as_top_level_api", "build_kernel"] - -def create_fn(rng_key, particles, logL_fn, logL_star, create_parameters): - num_particles, ndims = jax.tree_util.tree_flatten(particles)[0][0].shape - - def cond_fun(carry): - - _, logL, _, _ = carry - return logL <= logL_star - - def body_fun(carry, xs): - rng_key, _, _, mh_accept = carry - rng_key, subkey = jax.random.split(rng_key) - particle = jax.random.uniform(subkey, (ndims,)) - mh_accept = jnp.logical_or(jax.random.uniform(subkey) > 0.5, mh_accept) - logL = logL_fn(particle) - return (rng_key, logL, particle, mh_accept), particle - - - new_particles = jax.lax.scan(body_fun, (rng_key, -jnp.inf, jnp.zeros(ndims), jnp.zeros(ndims, dtype=bool)), particles) - - init_val = (rng_key, -jnp.inf, jnp.zeros(ndims), jnp.zeros(ndims, dtype=bool)) - final_rng_key, final_logL, final_particle = jax.lax.while_loop(cond_fun, body_fun, init_val) - - return jnp.array([final_particle]), { "logL": jnp.array([final_logL]) } - - -def delete_fn(rng_key, logL, delete_parameters): - idx = logL > logL.min() - return idx - - -def as_top_level_api( - logL_fn: Callable, -) -> SamplingAlgorithm: - """Implements a rejection sampling nested sampling algo - - Parameters - ---------- - logL_fn: Callable - A function that assigns a weight to the particles. - Returns - ------- - A ``SamplingAlgorithm``. - - """ - return base.as_top_level_api(logL_fn, create_fn, delete_fn) diff --git a/blackjax/ns/rejection.py b/blackjax/ns/rejection.py deleted file mode 100644 index f4b25093a..000000000 --- a/blackjax/ns/rejection.py +++ /dev/null @@ -1,91 +0,0 @@ -# Copyright 2024- Will Handley & David Yallup -from functools import partial -from typing import Callable - -import jax -import jax.numpy as jnp - -import blackjax.ns.base as base -from blackjax.base import SamplingAlgorithm -from blackjax.ns.base import build_kernel, init -from blackjax.types import ArrayLikeTree, PRNGKey - -__all__ = ["init", "as_top_level_api", "build_kernel"] - - -def create_fn(rng_key, particles, prior, logL_fn, logL_star, create_parameters): - # num_particles, ndims = jax.tree_util.tree_flatten(particles)[0][0].shape - # ndims = jax.tree_util.tree_flatten(particles)[0][0].shape - num_particles, ndims = particles.shape - - def body_fun(carry, xs): - def cond_fun(carry): - _, logL, _ = carry - return logL <= logL_star - - def inner_body(carry): - rng_key, _, _ = carry - rng_key, subkey = jax.random.split(rng_key) - particle = prior(seed=subkey) - logL = logL_fn(particle) - return rng_key, logL, particle - - rng_key = carry - rng_key, step_rng = jax.random.split(rng_key) - _, final_logL, particle = jax.lax.while_loop( - cond_fun, inner_body, (step_rng, -jnp.inf, jnp.zeros(ndims)) - ) - return rng_key, (particle, final_logL) - - logLs = jnp.ones(num_particles) * -jnp.inf - rng_key, init_key = jax.random.split(rng_key) - _, new_particles = jax.lax.scan(body_fun, init_key, (particles, logLs)) - - return new_particles[0], new_particles[1] - - -def delete_fn(rng_key, logL, delete_parameters, n_delete): - val, idx = jax.lax.top_k(-logL, n_delete) - return val, idx - - -def as_top_level_api( - logPrior_fn: Callable, - logL_fn: Callable, - n_delete: int = 1, -) -> SamplingAlgorithm: - """Implements the (basic) user interface for the Adaptive Tempered SMC kernel. - - Parameters - ---------- - logL_fn: Callable - A function that assigns a weight to the particles. - create_fn: Callable - Function that takes an array of keys and particles and returns - new particles. - delete_fn: Callable - Function that takes an array of keys and particles and deletes some - particles. - Returns - ------- - A ``SamplingAlgorithm``. - - """ - - delete_func = partial(delete_fn, n_delete=n_delete) - - kernel = build_kernel( - logPrior_fn, - logL_fn, - create_fn, - delete_func, - ) - - def init_fn(position: ArrayLikeTree, rng_key=None): - del rng_key - return init(position, logL_fn, {}) - - def step_fn(rng_key: PRNGKey, state): - return kernel(rng_key, state) - - return SamplingAlgorithm(init_fn, step_fn) diff --git a/blackjax/smc/inner_kernel_tuning.py b/blackjax/smc/inner_kernel_tuning.py index 6f7d627c5..2a63fd1ce 100644 --- a/blackjax/smc/inner_kernel_tuning.py +++ b/blackjax/smc/inner_kernel_tuning.py @@ -121,7 +121,6 @@ def as_top_level_api( """ - kernel = build_kernel( smc_algorithm, logprior_fn, diff --git a/test_ns.py b/test_ns.py deleted file mode 100644 index b169cdfbc..000000000 --- a/test_ns.py +++ /dev/null @@ -1,182 +0,0 @@ -import multiprocessing -import os -from datetime import date - -num_cores = multiprocessing.cpu_count() -os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count={}".format(num_cores) - -import anesthetic as ns -import distrax -import jax -import jax.numpy as jnp -import matplotlib.pyplot as plt -import numpy as np -from jax.scipy.stats import multivariate_normal - -import blackjax -import blackjax.progress_bar -from blackjax import irmh -from blackjax.progress_bar import progress_bar_scan -from blackjax.smc.tuning.from_particles import ( - mass_matrix_from_particles, - particles_covariance_matrix, - particles_means, - particles_stds, -) - -################################################################################## -# Setup the problem -################################################################################## - -rng_key = jax.random.PRNGKey(2) -d = 5 - -np.random.seed(1) -C = np.random.randn(d, d) * 0.05 -like_cov = C @ C.T -like_mean = np.random.randn(d) * 2 - - -def loglikelihood(x): - return multivariate_normal.logpdf(x, mean=like_mean, cov=like_cov) - - -n_samples = 500 -n_delete = num_cores -rng_key, init_key, sample_key = jax.random.split(rng_key, 3) - -prior = distrax.MultivariateNormalDiag(loc=jnp.zeros(d), scale_diag=jnp.ones(d)) - - -################################################################################## -# Configure the NS kernel -################################################################################## - -kernel = irmh.build_kernel() - - -def mcmc_step_fn(key, state, logdensity, means, cov): - proposal_distribution = lambda key: jax.random.multivariate_normal(key, means, cov) - - def proposal_logdensity_fn(proposal, state): - return jax.scipy.stats.multivariate_normal.logpdf( - state.position, mean=means, cov=cov - ).squeeze() - - return kernel(key, state, logdensity, proposal_distribution, proposal_logdensity_fn) - - -def mcmc_parameter_update_fn(state, info): - cov = jnp.atleast_2d(particles_covariance_matrix(state.particles)) - mean = particles_means(state.particles) - return {"means": mean, "cov": cov} - - -initial_state = prior._sample_n(rng_key, n_samples) -means = particles_means(initial_state) -cov = particles_covariance_matrix(initial_state) -init_params = {"means": means, "cov": cov} - - -""" -Setup the Nested Sampling algorithm -Provide compulsary functions: -logprior_fn: log prior density function -loglikelihood_fn: log likelihood density function #TODO combine the two with logl as a mask - -mcmc_step_fn: inner MCMC algorithm step function to evolve the particles -mcmc_init_fn: corresponding initialization function for the inner kernel -mcmc_parameter_update_fn: function to tune the parameters of the mcmc step -mcmc_initial_parameters: initial parameters for the inner kernel -- effectively call the parameter update fn on the initial pop - -Specific settings for the NS algorithm: -n_delete: number of points to delete at each iteration - jax will pmap over this, so it is detected automatically in this script as the number of available cpu cores -num mcmc steps: number of successful steps to take in the inner kernel - n_repeats in polychord language -""" -algo = blackjax.inner_kernel_ns( - logprior_fn=lambda x: prior.log_prob(x).sum().squeeze(), - loglikelihood_fn=loglikelihood, - mcmc_step_fn=mcmc_step_fn, - mcmc_init_fn=blackjax.rmh.init, - mcmc_parameter_update_fn=mcmc_parameter_update_fn, - n_delete=n_delete, - mcmc_initial_parameters=init_params, - num_mcmc_steps=5, -) - -# Initialize the ns state -state = algo.init(initial_state, loglikelihood) - - -# request 1000 steps of the NS kernel, currently this is fixed, and compresses for n_delete * n_steps rounds -# simplest design pattern is to put this in an outer while loop, and break when some convergence criteria is met -# currently there is no safety check in this compression so it can hang with too many steps, or not a good enough inner kernel -n_steps = 1000 - - -@progress_bar_scan(n_steps) -def one_step(carry, xs): - state, k = carry - k, subk = jax.random.split(k, 2) - state, dead_point = algo.step(subk, state) - return (state, k), dead_point - - -################################################################################## -# run the ns kernel -################################################################################## - -iterations = jnp.arange(n_steps) -(live, _), dead = jax.lax.scan((one_step), (state, rng_key), iterations) - - -# comment out the above scan and uncomment this for debugging -# with jax.disable_jit(): -# for i in range(10): -# rng_key, sample_key = jax.random.split(rng_key) -# state, info = algo.step(sample_key, state) - -################################################################################## -# Collect the samples into anesthetic objects -################################################################################## - -dead_points = dead.particles.squeeze() -live_points = live.sampler_state.particles.squeeze() -# live_logL = live.sampler_state.logL - - -samples = ns.NestedSamples( - data=np.concatenate([live_points, dead_points.reshape(-1, d)], axis=0), - logL=np.concatenate([live.sampler_state.logL, dead.logL.squeeze().reshape(-1)]), - logL_birth=np.concatenate( - [live.sampler_state.logL_birth, dead.logL_birth.squeeze().reshape(-1)] - ), -) -samples.to_csv("samples.csv") -lzs = samples.logZ(100) -# print(samples.logZ()) -print(f"logZ = {lzs.mean():.2f} ± {lzs.std():.2f}") -from lsbi.model import ReducedLinearModel - -model = ReducedLinearModel( - mu_L=like_mean, - Sigma_L=like_cov, - logLmax=loglikelihood(like_mean), -) - -print(f"True logZ = {model.logZ():.2f}") -a = samples.set_beta(0.0).plot_2d(np.arange(d), figsize=(10, 10)) -# samples.plot_2d(a) -ns.MCMCSamples(model.posterior().rvs(200)).plot_2d(a) -samples.plot_2d(a) -samples.to_csv("post.csv") -a.iloc[0, 0].legend( - ["Prior", "Truth", "NS"], loc="lower left", bbox_to_anchor=(0, 1), ncol=3 -) -plt.suptitle( - f"NS logZ = {lzs.mean():.2f} ± {lzs.std():.2f}, true logZ = {model.logZ():.2f}" -) -plt.savefig("post.pdf") -plt.savefig("post.png", dpi=300) -plt.show() diff --git a/test_slice.py b/test_slice.py deleted file mode 100644 index 5b3db057a..000000000 --- a/test_slice.py +++ /dev/null @@ -1,100 +0,0 @@ -import os -from datetime import date - -import anesthetic as ns -import distrax -import jax -import jax.numpy as jnp -import matplotlib.pyplot as plt -import numpy as np -from jax.scipy.stats import multivariate_normal - -import blackjax -import blackjax.progress_bar -from blackjax import univariate_slice -from blackjax.progress_bar import progress_bar_scan - -################################################################################## -# Setup the problem -################################################################################## - -rng_key = jax.random.PRNGKey(2) -d = 2 - -np.random.seed(1) -C = np.random.randn(d, d) * 0.1 -like_cov = C @ C.T -like_mean = np.random.randn(d) * 2 - - -def loglikelihood(x): - return multivariate_normal.logpdf(x, mean=like_mean, cov=like_cov) - - -n_samples = 500 -rng_key, init_key, sample_key = jax.random.split(rng_key, 3) - -prior = distrax.MultivariateNormalDiag(loc=jnp.zeros(d), scale_diag=jnp.ones(d)) - - -################################################################################## -# Configure the NS kernel -################################################################################## - - -def log_density(x): - return prior.log_prob(x) + loglikelihood(x) - - -algo = univariate_slice(log_density, n_doublings=10) -initial_state = prior.sample(seed=rng_key) -state = algo.init(initial_state) - -n_steps = 500 - - -@progress_bar_scan(n_steps) -def one_step(carry, xs): - state, k = carry - k, subk = jax.random.split(k, 2) - state, info = algo.step(subk, state) - return (state, k), info - - -################################################################################## -# run the ns kernel -################################################################################## - -iterations = jnp.arange(n_steps) -(live, _), dead = jax.lax.scan((one_step), (state, rng_key), iterations) - - -# comment out the above scan and uncomment this for debugging -# with jax.disable_jit(): -# for i in range(10): -# rng_key, sample_key = jax.random.split(rng_key) -# state, info = algo.step(sample_key, state) - -################################################################################## -# Collect the samples into anesthetic objects -################################################################################## - -samples = ns.MCMCSamples(jnp.concatenate(dead[0], axis=0)) -from lsbi.model import ReducedLinearModel - -model = ReducedLinearModel( - mu_L=like_mean, - Sigma_L=like_cov, - logLmax=loglikelihood(like_mean), -) - -print(f"True logZ = {model.logZ():.2f}") -# a = samples.set_beta(0.0).plot_2d(np.arange(d), figsize=(10, 10)) -# samples.plot_2d(a) -a = ns.MCMCSamples(model.posterior().rvs(200)).plot_2d(np.arange(d), figsize=(10, 10)) -samples.plot_2d(a) -samples.to_csv("post.csv") -a.iloc[0, 0].legend(["Truth", "NS"], loc="lower left", bbox_to_anchor=(0, 1), ncol=3) -plt.savefig("post.pdf") -plt.savefig("post.png", dpi=300) -plt.show() From df87345fb71b1890363769a3ae50f0725b99bda7 Mon Sep 17 00:00:00 2001 From: Reuben Date: Thu, 2 Jan 2025 10:18:52 -0500 Subject: [PATCH 02/14] Adjusted MCLMC (#675) * TESTS * TESTS * UPDATE DOCSTRING * ADD STREAMING VERSION * ADD PRECONDITIONING TO MCLMC * ADD PRECONDITIONING TO TUNING FOR MCLMC * UPDATE GITIGNORE * UPDATE GITIGNORE * UPDATE TESTS * UPDATE TESTS * ADD DOCSTRING * ADD TEST * STREAMING AVERAGE * ADD TEST * REFACTOR RUN_INFERENCE_ALGORITHM * UPDATE DOCSTRING * Precommit * CLEAN TESTS * FIX BAD MERGE * ADJUSTED MCLMC * REMOVE BENCHMARKS: * ADD ADJUSTED MCLMC * GITIGNORE * PRECOMMIT CLEAN UP * FIX SPELLING, ADD OMELYAN, EXPORT COEFFICIENTS * TEMPORARILY ADD BENCHMARKS * ADD ADJUSTED MCLMC TUNING * CLEAN * UNIFY ADJUSTED MCLMC AND MCHMC * ADD INITIAL_POSITION * FIX TEST * CLEAN UP * REMOVE BENCHMARKS * ADD TEST * REMOVE BENCHMARKS * MODIFY WINDOW ADAPTATION TO TAKE INTEGRATOR * MODIFY WINDOW ADAPTATION TO TAKE INTEGRATOR * BUG FIX * CHANGE PRECISION * CHANGE PRECISION * ADD OMELYAN TEST * ADD ADJUSTED MCLMC TEST * ADD ADJUSTED MCLMC TEST * RENAME O * UPDATE STREAMING AVG * UPDATE STREAMING AVG * FIX MERGE * UPDATE PR * RENAME STD_MAT * RENAME STD_MAT * RENAME STD_MAT * MERGE MAIN * REMOVE COEFFICIENT EXPORTS * REMOVE COEFFICIENT EXPORTS * RESOLVE MYPY ISSUE * RESOLVE MYPY ISSUE * RETURN EXPECTATION HISTORY * FIX KWARG BUG * FIX KWARG BUG * FIX KWARG BUG IN ADJUSTED MCLMC * MAKE WINDOW ADAPTATION TAKE INTEGRATOR AS ARGUMENT * L_proposal_factor * SPLIT TUNING FOR AMCLMC INTO SEPARATE FILE * SPLIT TUNING FOR AMCLMC INTO SEPARATE FILE * RENAME STREAMING_AVERAGE_UPDATE ARGS IN ADJUSTED MCLMC ADAPTATION * diagnostics * fix bugs * FIX MINOR TUNING BUGS * UPDATE TUNING * UPDATE TUNING * UPDATE TUNING * names * test * tuning * update * ready for test * ready for test * ready for test * Update blackjax/adaptation/adjusted_mclmc_adaptation.py Co-authored-by: Junpeng Lao * edit --------- Co-authored-by: Junpeng Lao --- blackjax/__init__.py | 4 + .../adaptation/adjusted_mclmc_adaptation.py | 373 ++++++++++++++++++ blackjax/adaptation/mclmc_adaptation.py | 8 +- blackjax/mcmc/__init__.py | 2 + blackjax/mcmc/adjusted_mclmc.py | 257 ++++++++++++ blackjax/mcmc/integrators.py | 10 +- tests/mcmc/test_sampling.py | 99 ++++- 7 files changed, 748 insertions(+), 5 deletions(-) create mode 100644 blackjax/adaptation/adjusted_mclmc_adaptation.py create mode 100644 blackjax/mcmc/adjusted_mclmc.py diff --git a/blackjax/__init__.py b/blackjax/__init__.py index 5858c34aa..6a0de3809 100644 --- a/blackjax/__init__.py +++ b/blackjax/__init__.py @@ -3,6 +3,7 @@ from blackjax._version import __version__ +from .adaptation.adjusted_mclmc_adaptation import adjusted_mclmc_find_L_and_step_size from .adaptation.chees_adaptation import chees_adaptation from .adaptation.mclmc_adaptation import mclmc_find_L_and_step_size from .adaptation.meads_adaptation import meads_adaptation @@ -11,6 +12,7 @@ from .base import SamplingAlgorithm, VIAlgorithm from .diagnostics import effective_sample_size as ess from .diagnostics import potential_scale_reduction as rhat +from .mcmc import adjusted_mclmc as _adjusted_mclmc from .mcmc import barker from .mcmc import dynamic_hmc as _dynamic_hmc from .mcmc import elliptical_slice as _elliptical_slice @@ -110,6 +112,7 @@ def generate_top_level_api_from(module): additive_step_random_walk.register_factory("normal_random_walk", normal_random_walk) mclmc = generate_top_level_api_from(_mclmc) +adjusted_mclmc = generate_top_level_api_from(_adjusted_mclmc) elliptical_slice = generate_top_level_api_from(_elliptical_slice) ghmc = generate_top_level_api_from(_ghmc) barker_proposal = generate_top_level_api_from(barker) @@ -160,6 +163,7 @@ def generate_top_level_api_from(module): "chees_adaptation", "pathfinder_adaptation", "mclmc_find_L_and_step_size", # mclmc adaptation + "adjusted_mclmc_find_L_and_step_size", # adjusted mclmc adaptation "ess", # diagnostics "rhat", ] diff --git a/blackjax/adaptation/adjusted_mclmc_adaptation.py b/blackjax/adaptation/adjusted_mclmc_adaptation.py new file mode 100644 index 000000000..f5d54e5c9 --- /dev/null +++ b/blackjax/adaptation/adjusted_mclmc_adaptation.py @@ -0,0 +1,373 @@ +import jax +import jax.numpy as jnp +from jax.flatten_util import ravel_pytree + +from blackjax.adaptation.mclmc_adaptation import MCLMCAdaptationState +from blackjax.adaptation.step_size import ( + DualAveragingAdaptationState, + dual_averaging_adaptation, +) +from blackjax.diagnostics import effective_sample_size +from blackjax.util import incremental_value_update, pytree_size + +Lratio_lowerbound = 0.0 +Lratio_upperbound = 2.0 + + +def adjusted_mclmc_find_L_and_step_size( + mclmc_kernel, + num_steps, + state, + rng_key, + target, + frac_tune1=0.1, + frac_tune2=0.1, + frac_tune3=0.0, + diagonal_preconditioning=True, + params=None, + max="avg", + num_windows=1, + tuning_factor=1.3, +): + """ + Finds the optimal value of the parameters for the MH-MCHMC algorithm. + + Parameters + ---------- + mclmc_kernel + The kernel function used for the MCMC algorithm. + num_steps + The number of MCMC steps that will subsequently be run, after tuning. + state + The initial state of the MCMC algorithm. + rng_key + The random number generator key. + target + The target acceptance rate for the step size adaptation. + frac_tune1 + The fraction of tuning for the first step of the adaptation. + frac_tune2 + The fraction of tuning for the second step of the adaptation. + frac_tune3 + The fraction of tuning for the third step of the adaptation. + diagonal_preconditioning + Whether to do diagonal preconditioning (i.e. a mass matrix) + params + Initial params to start tuning from (optional) + max + whether to calculate L from maximum or average eigenvalue. Average is advised. + num_windows + how many iterations of the tuning are carried out + tuning_factor + multiplicative factor for L + + + Returns + ------- + A tuple containing the final state of the MCMC algorithm and the final hyperparameters. + """ + + frac_tune1 /= num_windows + frac_tune2 /= num_windows + frac_tune3 /= num_windows + + dim = pytree_size(state.position) + if params is None: + params = MCLMCAdaptationState( + jnp.sqrt(dim), jnp.sqrt(dim) * 0.2, sqrt_diag_cov=jnp.ones((dim,)) + ) + + part1_key, part2_key = jax.random.split(rng_key, 2) + + for i in range(num_windows): + window_key = jax.random.fold_in(part1_key, i) + (state, params, eigenvector) = adjusted_mclmc_make_L_step_size_adaptation( + kernel=mclmc_kernel, + dim=dim, + frac_tune1=frac_tune1, + frac_tune2=frac_tune2, + target=target, + diagonal_preconditioning=diagonal_preconditioning, + max=max, + tuning_factor=tuning_factor, + )(state, params, num_steps, window_key) + + if frac_tune3 != 0: + for i in range(num_windows): + part2_key = jax.random.fold_in(part2_key, i) + part2_key1, part2_key2 = jax.random.split(part2_key, 2) + + state, params = adjusted_mclmc_make_adaptation_L( + mclmc_kernel, + frac=frac_tune3, + Lfactor=0.5, + max=max, + eigenvector=eigenvector, + )(state, params, num_steps, part2_key1) + + (state, params, _) = adjusted_mclmc_make_L_step_size_adaptation( + kernel=mclmc_kernel, + dim=dim, + frac_tune1=frac_tune1, + frac_tune2=0, + target=target, + fix_L_first_da=True, + diagonal_preconditioning=diagonal_preconditioning, + max=max, + tuning_factor=tuning_factor, + )(state, params, num_steps, part2_key2) + + return state, params + + +def adjusted_mclmc_make_L_step_size_adaptation( + kernel, + dim, + frac_tune1, + frac_tune2, + target, + diagonal_preconditioning, + fix_L_first_da=False, + max="avg", + tuning_factor=1.0, +): + """Adapts the stepsize and L of the MCLMC kernel. Designed for adjusted MCLMC""" + + def dual_avg_step(fix_L, update_da): + """does one step of the dynamics and updates the estimate of the posterior size and optimal stepsize""" + + def step(iteration_state, weight_and_key): + mask, rng_key = weight_and_key + ( + previous_state, + params, + (adaptive_state, step_size_max), + previous_weight_and_average, + ) = iteration_state + + avg_num_integration_steps = params.L / params.step_size + + state, info = kernel( + rng_key=rng_key, + state=previous_state, + avg_num_integration_steps=avg_num_integration_steps, + step_size=params.step_size, + sqrt_diag_cov=params.sqrt_diag_cov, + ) + + # step updating + success, state, step_size_max, energy_change = handle_nans( + previous_state, + state, + params.step_size, + step_size_max, + info.energy, + ) + + with_mask = lambda x, y: mask * x + (1 - mask) * y + + log_step_size, log_step_size_avg, step, avg_error, mu = update_da( + adaptive_state, info.acceptance_rate + ) + + adaptive_state = DualAveragingAdaptationState( + with_mask(log_step_size, adaptive_state.log_step_size), + with_mask(log_step_size_avg, adaptive_state.log_step_size_avg), + with_mask(step, adaptive_state.step), + with_mask(avg_error, adaptive_state.avg_error), + with_mask(mu, adaptive_state.mu), + ) + + step_size = jax.lax.clamp( + 1e-5, jnp.exp(adaptive_state.log_step_size), params.L / 1.1 + ) + adaptive_state = adaptive_state._replace(log_step_size=jnp.log(step_size)) + + x = ravel_pytree(state.position)[0] + + # update the running average of x, x^2 + previous_weight_and_average = incremental_value_update( + expectation=jnp.array([x, jnp.square(x)]), + incremental_val=previous_weight_and_average, + weight=(1 - mask) * success * step_size, + zero_prevention=mask, + ) + + params = params._replace(step_size=with_mask(step_size, params.step_size)) + if not fix_L: + params = params._replace( + L=with_mask(params.L * (step_size / params.step_size), params.L), + ) + + state_position = state.position + + return ( + state, + params, + (adaptive_state, step_size_max), + previous_weight_and_average, + ), ( + info, + state_position, + ) + + return step + + def step_size_adaptation(mask, state, params, keys, fix_L, initial_da, update_da): + return jax.lax.scan( + dual_avg_step(fix_L, update_da), + init=( + state, + params, + (initial_da(params.step_size), jnp.inf), # step size max + (0.0, jnp.array([jnp.zeros(dim), jnp.zeros(dim)])), + ), + xs=(mask, keys), + ) + + def L_step_size_adaptation(state, params, num_steps, rng_key): + num_steps1, num_steps2 = int(num_steps * frac_tune1), int( + num_steps * frac_tune2 + ) + + check_key, rng_key = jax.random.split(rng_key, 2) + + rng_key_pass1, rng_key_pass2 = jax.random.split(rng_key, 2) + L_step_size_adaptation_keys_pass1 = jax.random.split( + rng_key_pass1, num_steps1 + num_steps2 + ) + L_step_size_adaptation_keys_pass2 = jax.random.split(rng_key_pass2, num_steps1) + + # determine which steps to ignore in the streaming average + mask = 1 - jnp.concatenate((jnp.zeros(num_steps1), jnp.ones(num_steps2))) + + initial_da, update_da, final_da = dual_averaging_adaptation(target=target) + + ( + (state, params, (dual_avg_state, step_size_max), (_, average)), + (info, position_samples), + ) = step_size_adaptation( + mask, + state, + params, + L_step_size_adaptation_keys_pass1, + fix_L=fix_L_first_da, + initial_da=initial_da, + update_da=update_da, + ) + + final_stepsize = final_da(dual_avg_state) + params = params._replace(step_size=final_stepsize) + + # determine L + eigenvector = None + if num_steps2 != 0.0: + x_average, x_squared_average = average[0], average[1] + variances = x_squared_average - jnp.square(x_average) + + if max == "max": + contract = lambda x: jnp.sqrt(jnp.max(x) * dim) * tuning_factor + + elif max == "avg": + contract = lambda x: jnp.sqrt(jnp.sum(x)) * tuning_factor + + else: + raise ValueError("max should be either 'max' or 'avg'") + + change = jax.lax.clamp( + Lratio_lowerbound, + contract(variances) / params.L, + Lratio_upperbound, + ) + params = params._replace( + L=params.L * change, step_size=params.step_size * change + ) + if diagonal_preconditioning: + params = params._replace( + sqrt_diag_cov=jnp.sqrt(variances), L=jnp.sqrt(dim) + ) + + initial_da, update_da, final_da = dual_averaging_adaptation(target=target) + ( + (state, params, (dual_avg_state, step_size_max), (_, average)), + (info, params_history), + ) = step_size_adaptation( + jnp.ones(num_steps1), + state, + params, + L_step_size_adaptation_keys_pass2, + fix_L=True, + update_da=update_da, + initial_da=initial_da, + ) + + params = params._replace(step_size=final_da(dual_avg_state)) + + return state, params, eigenvector + + return L_step_size_adaptation + + +def adjusted_mclmc_make_adaptation_L( + kernel, frac, Lfactor, max="avg", eigenvector=None +): + """determine L by the autocorrelations (around 10 effective samples are needed for this to be accurate)""" + + def adaptation_L(state, params, num_steps, key): + num_steps = int(num_steps * frac) + adaptation_L_keys = jax.random.split(key, num_steps) + + def step(state, key): + next_state, _ = kernel( + rng_key=key, + state=state, + step_size=params.step_size, + avg_num_integration_steps=params.L / params.step_size, + sqrt_diag_cov=params.sqrt_diag_cov, + ) + return next_state, next_state.position + + state, samples = jax.lax.scan( + f=step, + init=state, + xs=adaptation_L_keys, + ) + + if max == "max": + contract = jnp.min + else: + contract = jnp.mean + + flat_samples = jax.vmap(lambda x: ravel_pytree(x)[0])(samples) + + if eigenvector is not None: + flat_samples = jnp.expand_dims( + jnp.einsum("ij,j", flat_samples, eigenvector), 1 + ) + + # number of effective samples per 1 actual sample + ess = contract(effective_sample_size(flat_samples[None, ...])) / num_steps + + return state, params._replace( + L=jnp.clip( + Lfactor * params.L / jnp.mean(ess), max=params.L * Lratio_upperbound + ) + ) + + return adaptation_L + + +def handle_nans(previous_state, next_state, step_size, step_size_max, kinetic_change): + """if there are nans, let's reduce the stepsize, and not update the state. The + function returns the old state in this case.""" + + reduced_step_size = 0.8 + p, unravel_fn = ravel_pytree(next_state.position) + nonans = jnp.all(jnp.isfinite(p)) + state, step_size, kinetic_change = jax.tree_util.tree_map( + lambda new, old: jax.lax.select(nonans, jnp.nan_to_num(new), old), + (next_state, step_size_max, kinetic_change), + (previous_state, step_size * reduced_step_size, 0.0), + ) + + return nonans, state, step_size, kinetic_change diff --git a/blackjax/adaptation/mclmc_adaptation.py b/blackjax/adaptation/mclmc_adaptation.py index 831586201..8452b6171 100644 --- a/blackjax/adaptation/mclmc_adaptation.py +++ b/blackjax/adaptation/mclmc_adaptation.py @@ -77,6 +77,8 @@ def mclmc_find_L_and_step_size( The trust in the estimate of optimal stepsize. num_effective_samples The number of effective samples for the MCMC algorithm. + diagonal_preconditioning + Whether to do diagonal preconditioning (i.e. a mass matrix) Returns ------- @@ -85,10 +87,10 @@ def mclmc_find_L_and_step_size( Example ------- .. code:: - kernel = lambda std_mat : blackjax.mcmc.mclmc.build_kernel( + kernel = lambda sqrt_diag_cov : blackjax.mcmc.mclmc.build_kernel( logdensity_fn=logdensity_fn, integrator=integrator, - std_mat=std_mat, + sqrt_diag_cov=sqrt_diag_cov, ) ( @@ -137,7 +139,7 @@ def make_L_step_size_adaptation( trust_in_estimate=1.5, num_effective_samples=150, ): - """Adapts the stepsize and L of the MCLMC kernel. Designed for the unadjusted MCLMC""" + """Adapts the stepsize and L of the MCLMC kernel. Designed for unadjusted MCLMC""" decay_rate = (num_effective_samples - 1.0) / (num_effective_samples + 1.0) diff --git a/blackjax/mcmc/__init__.py b/blackjax/mcmc/__init__.py index 6e207741d..1e1317684 100644 --- a/blackjax/mcmc/__init__.py +++ b/blackjax/mcmc/__init__.py @@ -1,4 +1,5 @@ from . import ( + adjusted_mclmc, barker, elliptical_slice, ghmc, @@ -24,4 +25,5 @@ "marginal_latent_gaussian", "random_walk", "mclmc", + "adjusted_mclmc", ] diff --git a/blackjax/mcmc/adjusted_mclmc.py b/blackjax/mcmc/adjusted_mclmc.py new file mode 100644 index 000000000..81fbc2835 --- /dev/null +++ b/blackjax/mcmc/adjusted_mclmc.py @@ -0,0 +1,257 @@ +# Copyright 2020- The Blackjax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Public API for the Metropolis Hastings Microcanonical Hamiltonian Monte Carlo (MHMCHMC) Kernel. This is closely related to the Microcanonical Langevin Monte Carlo (MCLMC) Kernel, which is an unadjusted method. This kernel adds a Metropolis-Hastings correction to the MCLMC kernel. It also only refreshes the momentum variable after each MH step, rather than during the integration of the trajectory. Hence "Hamiltonian" and not "Langevin".""" +from typing import Callable, Union + +import jax +import jax.numpy as jnp + +import blackjax.mcmc.integrators as integrators +from blackjax.base import SamplingAlgorithm +from blackjax.mcmc.dynamic_hmc import DynamicHMCState, halton_sequence +from blackjax.mcmc.hmc import HMCInfo +from blackjax.mcmc.proposal import static_binomial_sampling +from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey +from blackjax.util import generate_unit_vector + +__all__ = ["init", "build_kernel", "as_top_level_api"] + + +def init(position: ArrayLikeTree, logdensity_fn: Callable, random_generator_arg: Array): + logdensity, logdensity_grad = jax.value_and_grad(logdensity_fn)(position) + return DynamicHMCState(position, logdensity, logdensity_grad, random_generator_arg) + + +def build_kernel( + integration_steps_fn, + integrator: Callable = integrators.isokinetic_mclachlan, + divergence_threshold: float = 1000, + next_random_arg_fn: Callable = lambda key: jax.random.split(key)[1], + sqrt_diag_cov=1.0, +): + """Build a Dynamic MHMCHMC kernel where the number of integration steps is chosen randomly. + + Parameters + ---------- + integrator + The integrator to use to integrate the Hamiltonian dynamics. + divergence_threshold + Value of the difference in energy above which we consider that the transition is divergent. + next_random_arg_fn + Function that generates the next `random_generator_arg` from its previous value. + integration_steps_fn + Function that generates the next pseudo or quasi-random number of integration steps in the + sequence, given the current `random_generator_arg`. Needs to return an `int`. + + Returns + ------- + A kernel that takes a rng_key and a Pytree that contains the current state + of the chain and that returns a new state of the chain along with + information about the transition. + """ + + def kernel( + rng_key: PRNGKey, + state: DynamicHMCState, + logdensity_fn: Callable, + step_size: float, + L_proposal_factor: float = jnp.inf, + ) -> tuple[DynamicHMCState, HMCInfo]: + """Generate a new sample with the MHMCHMC kernel.""" + + num_integration_steps = integration_steps_fn(state.random_generator_arg) + + key_momentum, key_integrator = jax.random.split(rng_key, 2) + momentum = generate_unit_vector(key_momentum, state.position) + proposal, info, _ = adjusted_mclmc_proposal( + integrator=integrators.with_isokinetic_maruyama( + integrator(logdensity_fn=logdensity_fn, sqrt_diag_cov=sqrt_diag_cov) + ), + step_size=step_size, + L_proposal_factor=L_proposal_factor * (num_integration_steps * step_size), + num_integration_steps=num_integration_steps, + divergence_threshold=divergence_threshold, + )( + key_integrator, + integrators.IntegratorState( + state.position, momentum, state.logdensity, state.logdensity_grad + ), + ) + + return ( + DynamicHMCState( + proposal.position, + proposal.logdensity, + proposal.logdensity_grad, + next_random_arg_fn(state.random_generator_arg), + ), + info, + ) + + return kernel + + +def as_top_level_api( + logdensity_fn: Callable, + step_size: float, + L_proposal_factor: float = jnp.inf, + sqrt_diag_cov=1.0, + *, + divergence_threshold: int = 1000, + integrator: Callable = integrators.isokinetic_mclachlan, + next_random_arg_fn: Callable = lambda key: jax.random.split(key)[1], + integration_steps_fn: Callable = lambda key: jax.random.randint(key, (), 1, 10), +) -> SamplingAlgorithm: + """Implements the (basic) user interface for the dynamic MHMCHMC kernel. + + Parameters + ---------- + logdensity_fn + The log-density function we wish to draw samples from. + step_size + The value to use for the step size in the symplectic integrator. + divergence_threshold + The absolute value of the difference in energy between two states above + which we say that the transition is divergent. The default value is + commonly found in other libraries, and yet is arbitrary. + integrator + (algorithm parameter) The symplectic integrator to use to integrate the trajectory. + next_random_arg_fn + Function that generates the next `random_generator_arg` from its previous value. + integration_steps_fn + Function that generates the next pseudo or quasi-random number of integration steps in the + sequence, given the current `random_generator_arg`. + + + Returns + ------- + A ``SamplingAlgorithm``. + """ + + kernel = build_kernel( + integration_steps_fn=integration_steps_fn, + integrator=integrator, + next_random_arg_fn=next_random_arg_fn, + sqrt_diag_cov=sqrt_diag_cov, + divergence_threshold=divergence_threshold, + ) + + def init_fn(position: ArrayLikeTree, rng_key: Array): + return init(position, logdensity_fn, rng_key) + + def update_fn(rng_key: PRNGKey, state): + return kernel( + rng_key, + state, + logdensity_fn, + step_size, + L_proposal_factor, + ) + + return SamplingAlgorithm(init_fn, update_fn) # type: ignore[arg-type] + + +def adjusted_mclmc_proposal( + integrator: Callable, + step_size: Union[float, ArrayLikeTree], + L_proposal_factor: float, + num_integration_steps: int = 1, + divergence_threshold: float = 1000, + *, + sample_proposal: Callable = static_binomial_sampling, +) -> Callable: + """Vanilla MHMCHMC algorithm. + + The algorithm integrates the trajectory applying a integrator + `num_integration_steps` times in one direction to get a proposal and uses a + Metropolis-Hastings acceptance step to either reject or accept this + proposal. This is what people usually refer to when they talk about "the + HMC algorithm". + + Parameters + ---------- + integrator + integrator used to build the trajectory step by step. + kinetic_energy + Function that computes the kinetic energy. + step_size + Size of the integration step. + num_integration_steps + Number of times we run the integrator to build the trajectory + divergence_threshold + Threshold above which we say that there is a divergence. + + Returns + ------- + A kernel that generates a new chain state and information about the transition. + + """ + + def step(i, vars): + state, kinetic_energy, rng_key = vars + rng_key, next_rng_key = jax.random.split(rng_key) + next_state, next_kinetic_energy = integrator( + state, step_size, L_proposal_factor, rng_key + ) + + return next_state, kinetic_energy + next_kinetic_energy, next_rng_key + + def build_trajectory(state, num_integration_steps, rng_key): + return jax.lax.fori_loop( + 0 * num_integration_steps, num_integration_steps, step, (state, 0, rng_key) + ) + + def generate( + rng_key, state: integrators.IntegratorState + ) -> tuple[integrators.IntegratorState, HMCInfo, ArrayTree]: + """Generate a new chain state.""" + end_state, kinetic_energy, rng_key = build_trajectory( + state, num_integration_steps, rng_key + ) + + new_energy = -end_state.logdensity + delta_energy = -state.logdensity + end_state.logdensity - kinetic_energy + delta_energy = jnp.where(jnp.isnan(delta_energy), -jnp.inf, delta_energy) + is_diverging = -delta_energy > divergence_threshold + sampled_state, info = sample_proposal(rng_key, delta_energy, state, end_state) + do_accept, p_accept, other_proposal_info = info + + info = HMCInfo( + state.momentum, + p_accept, + do_accept, + is_diverging, + new_energy, + end_state, + num_integration_steps, + ) + + return sampled_state, info, other_proposal_info + + return generate + + +def rescale(mu): + """returns s, such that + round(U(0, 1) * s + 0.5) + has expected value mu. + """ + k = jnp.floor(2 * mu - 1) + x = k * (mu - 0.5 * (k + 1)) / (k + 1 - mu) + return k + x + + +def trajectory_length(t, mu): + s = rescale(mu) + return jnp.rint(0.5 + halton_sequence(t) * s) diff --git a/blackjax/mcmc/integrators.py b/blackjax/mcmc/integrators.py index e9d19e3dc..593683ca4 100644 --- a/blackjax/mcmc/integrators.py +++ b/blackjax/mcmc/integrators.py @@ -414,11 +414,19 @@ def partially_refresh_momentum(momentum, rng_key, step_size, L): ------- momentum with random change in angle """ + m, unravel_fn = ravel_pytree(momentum) dim = m.shape[0] nu = jnp.sqrt((jnp.exp(2 * step_size / L) - 1.0) / dim) z = nu * normal(rng_key, shape=m.shape, dtype=m.dtype) - return unravel_fn((m + z) / jnp.linalg.norm(m + z)) + new_momentum = unravel_fn((m + z) / jnp.linalg.norm(m + z)) + # return new_momentum + return jax.lax.cond( + jnp.isinf(L), + lambda _: momentum, + lambda _: new_momentum, + operand=None, + ) def with_isokinetic_maruyama(integrator): diff --git a/tests/mcmc/test_sampling.py b/tests/mcmc/test_sampling.py index 98572cabc..474f67293 100644 --- a/tests/mcmc/test_sampling.py +++ b/tests/mcmc/test_sampling.py @@ -1,5 +1,4 @@ """Test the accuracy of the MCMC kernels.""" - import functools import itertools @@ -15,6 +14,7 @@ import blackjax.diagnostics as diagnostics import blackjax.mcmc.random_walk from blackjax.adaptation.base import get_filter_adapt_info_fn, return_all_adapt_info +from blackjax.mcmc.adjusted_mclmc import rescale from blackjax.mcmc.integrators import isokinetic_mclachlan from blackjax.util import run_inference_algorithm @@ -146,6 +146,78 @@ def run_mclmc( return samples + def run_adjusted_mclmc( + self, + logdensity_fn, + num_steps, + initial_position, + key, + diagonal_preconditioning=False, + ): + integrator = isokinetic_mclachlan + + init_key, tune_key, run_key = jax.random.split(key, 3) + + initial_state = blackjax.mcmc.adjusted_mclmc.init( + position=initial_position, + logdensity_fn=logdensity_fn, + random_generator_arg=init_key, + ) + + kernel = lambda rng_key, state, avg_num_integration_steps, step_size, sqrt_diag_cov: blackjax.mcmc.adjusted_mclmc.build_kernel( + integrator=integrator, + integration_steps_fn=lambda k: jnp.ceil( + jax.random.uniform(k) * rescale(avg_num_integration_steps) + ), + sqrt_diag_cov=sqrt_diag_cov, + )( + rng_key=rng_key, + state=state, + step_size=step_size, + logdensity_fn=logdensity_fn, + ) + + target_acc_rate = 0.65 + + ( + blackjax_state_after_tuning, + blackjax_mclmc_sampler_params, + ) = blackjax.adjusted_mclmc_find_L_and_step_size( + mclmc_kernel=kernel, + num_steps=num_steps, + state=initial_state, + rng_key=tune_key, + target=target_acc_rate, + frac_tune1=0.1, + frac_tune2=0.1, + frac_tune3=0.1, + diagonal_preconditioning=diagonal_preconditioning, + ) + + step_size = blackjax_mclmc_sampler_params.step_size + L = blackjax_mclmc_sampler_params.L + + alg = blackjax.adjusted_mclmc( + logdensity_fn=logdensity_fn, + step_size=step_size, + integration_steps_fn=lambda key: jnp.ceil( + jax.random.uniform(key) * rescale(L / step_size) + ), + integrator=integrator, + sqrt_diag_cov=blackjax_mclmc_sampler_params.sqrt_diag_cov, + ) + + _, out = run_inference_algorithm( + rng_key=run_key, + initial_state=blackjax_state_after_tuning, + inference_algorithm=alg, + num_steps=num_steps, + transform=lambda state, _: state.position, + progress_bar=False, + ) + + return out + @parameterized.parameters( itertools.product( regression_test_cases, [True, False], window_adaptation_filters @@ -262,6 +334,31 @@ def test_mclmc(self): np.testing.assert_allclose(np.mean(scale_samples), 1.0, atol=1e-2) np.testing.assert_allclose(np.mean(coefs_samples), 3.0, atol=1e-2) + def test_adjusted_mclmc(self): + """Test the MCLMC kernel.""" + + init_key0, init_key1, inference_key = jax.random.split(self.key, 3) + x_data = jax.random.normal(init_key0, shape=(1000, 1)) + y_data = 3 * x_data + jax.random.normal(init_key1, shape=x_data.shape) + + logposterior_fn_ = functools.partial( + self.regression_logprob, x=x_data, preds=y_data + ) + logdensity_fn = lambda x: logposterior_fn_(**x) + + states = self.run_adjusted_mclmc( + initial_position={"coefs": 1.0, "log_scale": 1.0}, + logdensity_fn=logdensity_fn, + key=inference_key, + num_steps=10000, + ) + + coefs_samples = states["coefs"][3000:] + scale_samples = np.exp(states["log_scale"][3000:]) + + np.testing.assert_allclose(np.mean(scale_samples), 1.0, atol=1e-2) + np.testing.assert_allclose(np.mean(coefs_samples), 3.0, atol=1e-2) + def test_mclmc_preconditioning(self): class IllConditionedGaussian: """Gaussian distribution. Covariance matrix has eigenvalues equally spaced in log-space, going from 1/condition_bnumber^1/2 to condition_number^1/2.""" From fc539ca195d81cfa1147e8efd54198c96771a490 Mon Sep 17 00:00:00 2001 From: Carlos Iguaran Date: Thu, 16 Jan 2025 13:39:27 -0300 Subject: [PATCH 03/14] SMC Pretuning (#765) * extracting taking last * test passing * layering * example * more * Adding another example * tests in place * rolling back changes * Adding test for num_mcmc_steps * format * better test coverage * linter * Flake8 * black * implementation[ * partial posteriors implementation * rolling back some changes * linter * fixing test * adding reference * typo * exposing in top level api * reruning precommit * up to now * one step working * fixes * tests passing * checkpoint tests passing * more * tests passing, implementation in place * tests passing * rounding * adding to init * rollbacks * rollback * rollback * docs * precommit * removing extra parameter * code review updates --- blackjax/__init__.py | 3 +- blackjax/smc/from_mcmc.py | 29 +- blackjax/smc/pretuning.py | 346 ++++++++++++++++++++++++ blackjax/smc/tempered.py | 13 +- blackjax/smc/tuning/from_kernel_info.py | 1 + blackjax/smc/tuning/from_particles.py | 2 +- tests/smc/test_pretuning.py | 235 ++++++++++++++++ 7 files changed, 615 insertions(+), 14 deletions(-) create mode 100644 blackjax/smc/pretuning.py create mode 100644 tests/smc/test_pretuning.py diff --git a/blackjax/__init__.py b/blackjax/__init__.py index 6a0de3809..81f8ebd2e 100644 --- a/blackjax/__init__.py +++ b/blackjax/__init__.py @@ -38,6 +38,7 @@ from .smc import adaptive_tempered from .smc import inner_kernel_tuning as _inner_kernel_tuning from .smc import partial_posteriors_path as _partial_posteriors_smc +from .smc import pretuning as _pretuning from .smc import tempered from .vi import meanfield_vi as _meanfield_vi from .vi import pathfinder as _pathfinder @@ -124,7 +125,7 @@ def generate_top_level_api_from(module): tempered_smc = generate_top_level_api_from(tempered) inner_kernel_tuning = generate_top_level_api_from(_inner_kernel_tuning) partial_posteriors_smc = generate_top_level_api_from(_partial_posteriors_smc) - +pretuning = generate_top_level_api_from(_pretuning) smc_family = [tempered_smc, adaptive_tempered_smc, partial_posteriors_smc] "Step_fn returning state has a .particles attribute" diff --git a/blackjax/smc/from_mcmc.py b/blackjax/smc/from_mcmc.py index 0e60b5968..75e5c34a6 100644 --- a/blackjax/smc/from_mcmc.py +++ b/blackjax/smc/from_mcmc.py @@ -8,6 +8,23 @@ from blackjax.types import PRNGKey +def unshared_parameters_and_step_fn(mcmc_parameters, mcmc_step_fn): + """Splits MCMC parameters into two dictionaries. The shared dictionary + represents the parameters common to all chains, and the unshared are + different per chain. + Binds the step fn using the shared parameters. + """ + shared_mcmc_parameters = {} + unshared_mcmc_parameters = {} + for k, v in mcmc_parameters.items(): + if v.shape[0] == 1: + shared_mcmc_parameters[k] = v[0, ...] + else: + unshared_mcmc_parameters[k] = v + shared_mcmc_step_fn = partial(mcmc_step_fn, **shared_mcmc_parameters) + return unshared_mcmc_parameters, shared_mcmc_step_fn + + def build_kernel( mcmc_step_fn: Callable, mcmc_init_fn: Callable, @@ -34,15 +51,9 @@ def step( logposterior_fn: Callable, log_weights_fn: Callable, ) -> tuple[smc.base.SMCState, smc.base.SMCInfo]: - shared_mcmc_parameters = {} - unshared_mcmc_parameters = {} - for k, v in mcmc_parameters.items(): - if v.shape[0] == 1: - shared_mcmc_parameters[k] = v[0, ...] - else: - unshared_mcmc_parameters[k] = v - - shared_mcmc_step_fn = partial(mcmc_step_fn, **shared_mcmc_parameters) + unshared_mcmc_parameters, shared_mcmc_step_fn = unshared_parameters_and_step_fn( + mcmc_parameters, mcmc_step_fn + ) update_fn, num_resampled = update_strategy( mcmc_init_fn, diff --git a/blackjax/smc/pretuning.py b/blackjax/smc/pretuning.py new file mode 100644 index 000000000..f489a0dc2 --- /dev/null +++ b/blackjax/smc/pretuning.py @@ -0,0 +1,346 @@ +from typing import Callable, Dict, List, NamedTuple, Optional, Tuple + +import jax +import jax.numpy as jnp +import jax.random +from jax._src.flatten_util import ravel_pytree + +from blackjax import SamplingAlgorithm, smc +from blackjax.smc.base import SMCInfo, update_and_take_last +from blackjax.smc.from_mcmc import build_kernel as smc_from_mcmc +from blackjax.smc.from_mcmc import unshared_parameters_and_step_fn +from blackjax.smc.inner_kernel_tuning import StateWithParameterOverride +from blackjax.smc.resampling import stratified +from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey +from blackjax.util import generate_gaussian_noise + + +class SMCInfoWithParameterDistribution(NamedTuple): + """Stores both the sampling status and also a dictionary + with parameter names as keys and (n_particles, *) arrays as values. + The latter represents a parameter per chain for the next mutation step. + """ + + smc_info: SMCInfo + parameter_override: Dict[str, ArrayTree] + + +def esjd(m): + """Implements ESJD (expected squared jumping distance). Inner Mahalanobis distance + is computed using the Cholesky decomposition of M=LLt, and then inverting L. + Whenever M is symmetrical definite positive then it must exist a Cholesky Decomposition. + For example, if M is the Covariance Matrix of Metropolis-Hastings or + the Inverse Mass Matrix of Hamiltonian Monte Carlo. + """ + L = jnp.linalg.cholesky(m) + + def measure(previous_position, next_position, acceptance_probability): + difference = ravel_pytree(previous_position)[0] - ravel_pytree(next_position)[0] + difference_by_matrix = jnp.matmul(L, difference) + norm = jnp.linalg.norm(difference_by_matrix, 2) + return acceptance_probability * jnp.power(norm, 2) + + return jax.vmap(measure) + + +def update_parameter_distribution( + key: PRNGKey, + previous_param_samples: ArrayLikeTree, + previous_particles: ArrayLikeTree, + latest_particles: ArrayLikeTree, + measure_of_chain_mixing: Callable, + alpha: float, + sigma_parameters: ArrayLikeTree, + acceptance_probability: Array, +): + """Given an existing parameter distribution that was used to mutate previous_particles + into latest_particles, updates that parameter distribution by resampling from previous_param_samples after adding + noise to those samples. The weights used are a linear function of the measure of chain mixing. + Only works with float parameters, not integers. + See Equation 4 in https://arxiv.org/pdf/1005.1193.pdf + + Parameters + ---------- + previous_param_samples: + samples of the parameters of SMC inner MCMC chains. To be updated. + previous_particles: + particles from which the kernel step started + latest_particles: + particles after the step was performed + measure_of_chain_mixing: Callable + a callable that can compute a performance measure per chain + alpha: + a scalar to add to the weighting. See paper for details + sigma_parameters: + noise to add to the population of parameters to mutate them. must have the same shape of + previous_param_samples. + acceptance_probability: + the energy difference for each of the chains when taking a step from previous_particles + into latest_particles. + """ + noise_key, resampling_key = jax.random.split(key, 2) + + noises = jax.tree.map( + lambda x, s: generate_gaussian_noise(noise_key, x.astype("float32"), sigma=s), + previous_param_samples, + sigma_parameters, + ) + new_samples = jax.tree.map(lambda x, y: x + y, noises, previous_param_samples) + + chain_mixing_measurement = measure_of_chain_mixing( + previous_particles, latest_particles, acceptance_probability + ) + weights = alpha + chain_mixing_measurement + weights = weights / jnp.sum(weights) + resampling_idx = stratified(resampling_key, weights, len(chain_mixing_measurement)) + return ( + jax.tree.map(lambda x: x[resampling_idx], new_samples), + chain_mixing_measurement, + ) + + +def build_pretune( + mcmc_init_fn: Callable, + mcmc_step_fn: Callable, + alpha: float, + sigma_parameters: ArrayLikeTree, + n_particles: int, + performance_of_chain_measure_factory: Callable = lambda state: esjd( + state.parameter_override["inverse_mass_matrix"] + ), + natural_parameters: Optional[List[str]] = None, + positive_parameters: Optional[List[str]] = None, +): + """Implements Buchholz et al https://arxiv.org/pdf/1808.07730 pretuning procedure. + The goal is to maintain a probability distribution of parameters, in order + to assign different values to each inner MCMC chain. + To have performant parameters for the distribution at step t, it takes a single step, measures + the chain mixing, and reweights the probability distribution of parameters accordingly. + Note that although similar, this strategy is different than inner_kernel_tuning. The latter updates + the parameters based on the particles and transition information after the SMC step is executed. This + implementation runs a single MCMC step which gets discarded, to then proceed with the SMC step execution. + """ + if natural_parameters is None: + round_to_integer_fn = lambda x: x + else: + + def round_to_integer_fn(x): + for k in natural_parameters: + x[k] = jax.tree.map(lambda a: jnp.abs(jnp.round(a).astype(int)), x[k]) + return x + + if positive_parameters is None: + make_positive_fn = lambda x: x + else: + + def make_positive_fn(x): + for k in positive_parameters: + x[k] = jax.tree.map(jnp.abs, x[k]) + return x + + def pretune(key, state, logposterior): + unshared_mcmc_parameters, shared_mcmc_step_fn = unshared_parameters_and_step_fn( + state.parameter_override, mcmc_step_fn + ) + + one_step_fn, _ = update_and_take_last( + mcmc_init_fn, logposterior, shared_mcmc_step_fn, 1, n_particles + ) + + new_state, info = one_step_fn( + jax.random.split(key, n_particles), + state.sampler_state.particles, + unshared_mcmc_parameters, + ) + + performance_of_chain_measure = performance_of_chain_measure_factory(state) + + ( + new_parameter_distribution, + chain_mixing_measurement, + ) = update_parameter_distribution( + key, + previous_param_samples={ + key: state.parameter_override[key] for key in sigma_parameters + }, + previous_particles=state.sampler_state.particles, + latest_particles=new_state, + measure_of_chain_mixing=performance_of_chain_measure, + alpha=alpha, + sigma_parameters=sigma_parameters, + acceptance_probability=info.acceptance_rate, + ) + + return ( + make_positive_fn(round_to_integer_fn(new_parameter_distribution)), + chain_mixing_measurement, + ) + + def pretune_and_update(key, state: StateWithParameterOverride, logposterior): + """ + Updates the parameters that need to be pretuned and returns the rest. + """ + new_parameter_distribution, chain_mixing_measurement = pretune( + key, state, logposterior + ) + old_parameter_distribution = state.parameter_override + updated_parameter_distribution = old_parameter_distribution + for k in new_parameter_distribution: + updated_parameter_distribution[k] = new_parameter_distribution[k] + + return updated_parameter_distribution + + return pretune_and_update + + +def build_kernel( + smc_algorithm, + logprior_fn: Callable, + loglikelihood_fn: Callable, + mcmc_step_fn: Callable, + mcmc_init_fn: Callable, + resampling_fn: Callable, + pretune_fn: Callable, + num_mcmc_steps: int = 10, + update_strategy=update_and_take_last, + **extra_parameters, +) -> Callable: + """In the context of an SMC sampler (whose step_fn returning state has a .particles attribute), there's an inner + MCMC that is used to perturbate/update each of the particles. This adaptation tunes some parameter of that MCMC, + based on particles. The parameter type must be a valid JAX type. + + Parameters + ---------- + smc_algorithm + Either blackjax.adaptive_tempered_smc or blackjax.tempered_smc (or any other implementation of + a sampling algorithm that returns an SMCState and SMCInfo pair). + logprior_fn + A function that computes the log density of the prior distribution + loglikelihood_fn + A function that returns the probability at a given position. + mcmc_step_fn: + The transition kernel, should take as parameters the dictionary output of mcmc_parameter_update_fn. + mcmc_step_fn(rng_key, state, tempered_logposterior_fn, **mcmc_parameter_update_fn()) + mcmc_init_fn + A callable that initializes the inner kernel + pretune_fn: + A callable that can update the probability distribution of parameters. + extra_parameters: + parameters to be used for the creation of the smc_algorithm. + """ + delegate = smc_from_mcmc(mcmc_step_fn, mcmc_init_fn, resampling_fn, update_strategy) + + def pretuned_step( + rng_key: PRNGKey, + state, + num_mcmc_steps: int, + mcmc_parameters: dict, + logposterior_fn: Callable, + log_weights_fn: Callable, + ) -> tuple[smc.base.SMCState, SMCInfoWithParameterDistribution]: + """Wraps the output of smc.from_mcmc.build_kernel into a pretuning + step method. + This one should be a subtype of the former, in the sense that a usage of the former + can be replaced with an instance of this one. + """ + + pretune_key, step_key = jax.random.split(rng_key, 2) + pretuned_parameters = pretune_fn( + pretune_key, + StateWithParameterOverride(state, mcmc_parameters), + logposterior_fn, + ) + state, info = delegate( + rng_key, + state, + num_mcmc_steps, + pretuned_parameters, + logposterior_fn, + log_weights_fn, + ) + return state, SMCInfoWithParameterDistribution(info, pretuned_parameters) + + def kernel( + rng_key: PRNGKey, state: StateWithParameterOverride, **extra_step_parameters + ) -> Tuple[StateWithParameterOverride, SMCInfo]: + extra_parameters["update_particles_fn"] = pretuned_step + step_fn = smc_algorithm( + logprior_fn=logprior_fn, + loglikelihood_fn=loglikelihood_fn, + mcmc_step_fn=mcmc_step_fn, + mcmc_init_fn=mcmc_init_fn, + mcmc_parameters=state.parameter_override, + resampling_fn=resampling_fn, + num_mcmc_steps=num_mcmc_steps, + **extra_parameters, + ).step + new_state, info = step_fn(rng_key, state.sampler_state, **extra_step_parameters) + return ( + StateWithParameterOverride(new_state, info.parameter_override), + info.smc_info, + ) + + return kernel + + +def init(alg_init_fn, position, initial_parameter_value): + return StateWithParameterOverride(alg_init_fn(position), initial_parameter_value) + + +def as_top_level_api( + smc_algorithm, + logprior_fn: Callable, + loglikelihood_fn: Callable, + mcmc_step_fn: Callable, + mcmc_init_fn: Callable, + resampling_fn: Callable, + num_mcmc_steps: int, + initial_parameter_value: ArrayLikeTree, + pretune_fn: Callable, + **extra_parameters, +): + """In the context of an SMC sampler (whose step_fn returning state has a .particles attribute), there's an inner + MCMC that is used to perturbate/update each of the particles. This adaptation tunes some parameter of that MCMC, + based on particles. The parameter type must be a valid JAX type. + + Parameters + ---------- + smc_algorithm + Either blackjax.adaptive_tempered_smc or blackjax.tempered_smc (or any other implementation of + a sampling algorithm that returns an SMCState and SMCInfo pair). + logprior_fn + A function that computes the log density of the prior distribution + loglikelihood_fn + A function that returns the probability at a given position. + mcmc_step_fn: + The transition kernel, should take as parameters the dictionary output of mcmc_parameter_update_fn. + mcmc_step_fn(rng_key, state, tempered_logposterior_fn, **mcmc_parameter_update_fn()) + mcmc_init_fn + A callable that initializes the inner kernel + pretune_fn: + A callable that can update the probability distribution of parameters. + extra_parameters: + parameters to be used for the creation of the smc_algorithm. + """ + + kernel = build_kernel( + smc_algorithm, + logprior_fn, + loglikelihood_fn, + mcmc_step_fn, + mcmc_init_fn, + resampling_fn, + pretune_fn, + num_mcmc_steps, + **extra_parameters, + ) + + def init_fn(position, rng_key=None): + del rng_key + return init(smc_algorithm.init, position, initial_parameter_value) + + def step_fn( + rng_key: PRNGKey, state, **extra_step_parameters + ) -> Tuple[StateWithParameterOverride, SMCInfo]: + return kernel(rng_key, state, **extra_step_parameters) + + return SamplingAlgorithm(init_fn, step_fn) diff --git a/blackjax/smc/tempered.py b/blackjax/smc/tempered.py index 88539deaa..350037f9c 100644 --- a/blackjax/smc/tempered.py +++ b/blackjax/smc/tempered.py @@ -55,6 +55,7 @@ def build_kernel( mcmc_init_fn: Callable, resampling_fn: Callable, update_strategy: Callable = update_and_take_last, + update_particles_fn: Optional[Callable] = None, ) -> Callable: """Build the base Tempered SMC kernel. @@ -92,8 +93,12 @@ def build_kernel( information about the transition. """ - delegate = smc_from_mcmc.build_kernel( - mcmc_step_fn, mcmc_init_fn, resampling_fn, update_strategy + update_particles = ( + smc_from_mcmc.build_kernel( + mcmc_step_fn, mcmc_init_fn, resampling_fn, update_strategy + ) + if update_particles_fn is None + else update_particles_fn ) def kernel( @@ -135,7 +140,7 @@ def tempered_logposterior_fn(position: ArrayLikeTree) -> float: tempered_loglikelihood = state.lmbda * loglikelihood_fn(position) return logprior + tempered_loglikelihood - smc_state, info = delegate( + smc_state, info = update_particles( rng_key, state, num_mcmc_steps, @@ -162,6 +167,7 @@ def as_top_level_api( resampling_fn: Callable, num_mcmc_steps: Optional[int] = 10, update_strategy=update_and_take_last, + update_particles_fn=None, ) -> SamplingAlgorithm: """Implements the (basic) user interface for the Adaptive Tempered SMC kernel. @@ -196,6 +202,7 @@ def as_top_level_api( mcmc_init_fn, resampling_fn, update_strategy, + update_particles_fn, ) def init_fn(position: ArrayLikeTree, rng_key=None): diff --git a/blackjax/smc/tuning/from_kernel_info.py b/blackjax/smc/tuning/from_kernel_info.py index a039e66c1..fa2c7054c 100644 --- a/blackjax/smc/tuning/from_kernel_info.py +++ b/blackjax/smc/tuning/from_kernel_info.py @@ -1,4 +1,5 @@ """ +static (all particles get the same value) strategies to tune the parameters of mcmc kernels used within smc, based on MCMC states """ diff --git a/blackjax/smc/tuning/from_particles.py b/blackjax/smc/tuning/from_particles.py index 4c8ca98da..279a718cb 100755 --- a/blackjax/smc/tuning/from_particles.py +++ b/blackjax/smc/tuning/from_particles.py @@ -1,5 +1,5 @@ """ -strategies to tune the parameters of mcmc kernels +static (all particles get the same value) strategies to tune the parameters of mcmc kernels used within SMC, based on particles. """ import jax diff --git a/tests/smc/test_pretuning.py b/tests/smc/test_pretuning.py new file mode 100644 index 000000000..a677c99ae --- /dev/null +++ b/tests/smc/test_pretuning.py @@ -0,0 +1,235 @@ +import unittest + +import chex +import jax +import jax.numpy as jnp +import numpy as np +from absl.testing import absltest + +import blackjax +from blackjax.smc import extend_params, resampling +from blackjax.smc.pretuning import ( + build_pretune, + esjd, + init, + update_parameter_distribution, +) +from tests.smc import SMCLinearRegressionTestCase + + +class TestMeasureOfChainMixing(unittest.TestCase): + previous_position = np.array([jnp.array([10.0, 15.0]), jnp.array([3.0, 4.0])]) + + next_position = np.array([jnp.array([20.0, 30.0]), jnp.array([9.0, 12.0])]) + + def test_measure_of_chain_mixing_identity(self): + """ + Given identity matrix and 1. acceptance probability + then the mixing is the square of norm 2. + """ + m = np.eye(2) + + acceptance_probabilities = np.array([1.0, 1.0]) + chain_mixing = esjd(m)( + self.previous_position, self.next_position, acceptance_probabilities + ) + np.testing.assert_allclose(chain_mixing[0], 325) + np.testing.assert_allclose(chain_mixing[1], 100) + + def test_measure_of_chain_mixing_with_non_1_acceptance_rate(self): + """ + Given identity matrix + then the mixing is the square of norm 2. multiplied by the acceptance rate + """ + m = np.eye(2) + + acceptance_probabilities = np.array([0.5, 0.2]) + chain_mixing = esjd(m)( + self.previous_position, self.next_position, acceptance_probabilities + ) + np.testing.assert_allclose(chain_mixing[0], 162.5) + np.testing.assert_allclose(chain_mixing[1], 20) + + def test_measure_of_chain_mixing(self): + m = np.array([[3, 0], [0, 5]]) + + previous_position = np.array([jnp.array([10.0, 15.0]), jnp.array([3.0, 4.0])]) + + next_position = np.array([jnp.array([20.0, 30.0]), jnp.array([9.0, 12.0])]) + + acceptance_probabilities = np.array([1.0, 1.0]) + + chain_mixing = esjd(m)( + previous_position, next_position, acceptance_probabilities + ) + + assert chain_mixing.shape == (2,) + np.testing.assert_allclose(chain_mixing[0], 10 * 10 * 3 + 15 * 15 * 5) + np.testing.assert_allclose(chain_mixing[1], 6 * 6 * 3 + 8 * 8 * 5) + + +class TestUpdateParameterDistribution(chex.TestCase): + def setUp(self): + super().setUp() + self.key = jax.random.key(42) + self.previous_position = np.array( + [jnp.array([10.0, 15.0]), jnp.array([10.0, 15.0]), jnp.array([3.0, 4.0])] + ) + self.next_position = np.array( + [jnp.array([20.0, 30.0]), jnp.array([10.0, 15.0]), jnp.array([9.0, 12.0])] + ) + + def test_update_param_distribution(self): + """ + Given an extremely good mixing on one chain, + and that the alpha parameter is 0, then the parameters + of that chain with a slight mutation due to noise are reused. + """ + + ( + new_parameter_distribution, + chain_mixing_measurement, + ) = update_parameter_distribution( + self.key, + jnp.array([1.0, 2.0, 3.0]), + self.previous_position, + self.next_position, + measure_of_chain_mixing=lambda x, y, z: jnp.array([1.0, 0.0, 0.0]), + alpha=0, + sigma_parameters=0.0001, + acceptance_probability=None, + ) + + np.testing.assert_allclose( + new_parameter_distribution, + np.array([1, 1, 1], dtype="float32"), + rtol=1e-3, + ) + np.testing.assert_allclose( + chain_mixing_measurement, + np.array([1, 0, 0], dtype="float32"), + rtol=1e-6, + ) + + def test_update_multi_sigmas(self): + """ + When we have multiple parameters, the performance is attached to its combination + so sampling must work accordingly. + """ + ( + new_parameter_distribution, + chain_mixing_measurement, + ) = update_parameter_distribution( + self.key, + { + "param_a": jnp.array([1.0, 2.0, 3.0]), + "param_b": jnp.array([[5.0, 6.0], [6.0, 7.0], [4.0, 5.0]]), + }, + self.previous_position, + self.next_position, + measure_of_chain_mixing=lambda x, y, z: jnp.array([1.0, 0.0, 0.0]), + alpha=0, + sigma_parameters={"param_a": 0.0001, "param_b": 0.00001}, + acceptance_probability=None, + ) + print(chain_mixing_measurement) + np.testing.assert_allclose(chain_mixing_measurement, np.array([1.0, 0, 0])) + + np.testing.assert_allclose( + new_parameter_distribution["param_a"], jnp.array([1.0, 1.0, 1.0]), atol=0.1 + ) + np.testing.assert_allclose( + new_parameter_distribution["param_b"], + jnp.array([[5.0, 6.0], [5.0, 6.0], [5.0, 6.0]]), + atol=0.1, + ) + + +class PretuningSMCTest(SMCLinearRegressionTestCase): + def setUp(self): + super().setUp() + self.key = jax.random.key(42) + + @chex.variants(with_jit=True) + def test_linear_regression(self): + ( + init_particles, + logprior_fn, + loglikelihood_fn, + ) = self.particles_prior_loglikelihood() + + num_particles = 100 + sampling_key, step_size_key, integration_steps_key = jax.random.split( + self.key, 3 + ) + integration_steps_distribution = jnp.round( + jax.random.uniform( + integration_steps_key, (num_particles,), minval=1, maxval=100 + ) + ).astype(int) + + step_sizes_distribution = jax.random.uniform( + step_size_key, (num_particles,), minval=0, maxval=0.1 + ) + + # Fixes inverse_mass_matrix and distribution for the other two parameters. + initial_parameters = dict( + inverse_mass_matrix=extend_params(jnp.eye(2)), + step_size=step_sizes_distribution, + num_integration_steps=integration_steps_distribution, + ) + assert initial_parameters["step_size"].shape == (num_particles,) + assert initial_parameters["num_integration_steps"].shape == (num_particles,) + + pretune = build_pretune( + blackjax.hmc.init, + blackjax.hmc.build_kernel(), + alpha=1, + n_particles=num_particles, + sigma_parameters={"step_size": 0.01, "num_integration_steps": 2}, + natural_parameters=["num_integration_steps"], + positive_parameters=["step_size"], + ) + + step = blackjax.smc.pretuning.build_kernel( + blackjax.tempered_smc, + logprior_fn, + loglikelihood_fn, + blackjax.hmc.build_kernel(), + blackjax.hmc.init, + resampling.systematic, + num_mcmc_steps=10, + pretune_fn=pretune, + ) + + initial_state = init( + blackjax.tempered_smc.init, init_particles, initial_parameters + ) + smc_kernel = self.variant(step) + + def body_fn(carry, lmbda): + i, state = carry + subkey = jax.random.fold_in(self.key, i) + new_state, info = smc_kernel(subkey, state, lmbda=lmbda) + return (i + 1, new_state), (new_state, info) + + num_tempering_steps = 10 + lambda_schedule = np.logspace(-5, 0, num_tempering_steps) + + (_, result), _ = jax.lax.scan(body_fn, (0, initial_state), lambda_schedule) + self.assert_linear_regression_test_case(result.sampler_state) + assert set(result.parameter_override.keys()) == { + "step_size", + "num_integration_steps", + "inverse_mass_matrix", + } + assert result.parameter_override["step_size"].shape == (num_particles,) + assert result.parameter_override["num_integration_steps"].shape == ( + num_particles, + ) + assert all(result.parameter_override["step_size"] > 0) + assert all(result.parameter_override["num_integration_steps"] > 0) + + +if __name__ == "__main__": + absltest.main() From a0812beddc9d3b80fa46e324905c697f25ba1324 Mon Sep 17 00:00:00 2001 From: Junpeng Lao Date: Mon, 20 Jan 2025 17:43:56 +0100 Subject: [PATCH 04/14] Remove meeting scheduling (#768) * Remove meeting scheduling * Fix tests --- .github/workflows/schedule-meeting.yml | 18 ------------------ tests/mcmc/test_integrators.py | 21 +++++++++++++++++---- tests/mcmc/test_proposal.py | 22 ++++++++-------------- tests/mcmc/test_sampling.py | 17 +++++++++-------- tests/smc/test_smc.py | 2 +- 5 files changed, 35 insertions(+), 45 deletions(-) delete mode 100644 .github/workflows/schedule-meeting.yml diff --git a/.github/workflows/schedule-meeting.yml b/.github/workflows/schedule-meeting.yml deleted file mode 100644 index 0575bd20f..000000000 --- a/.github/workflows/schedule-meeting.yml +++ /dev/null @@ -1,18 +0,0 @@ -# Open a Meeting issue the 25th day of the month. -# Meetings happen on the first Friday of the month -name: Open a meeting issue -on: - schedule: - - cron: '0 0 20 * *' - workflow_dispatch: - -jobs: - create-meeting-issue: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v3 - - uses: JasonEtco/create-an-issue@v2 - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - with: - filename: .github/ISSUE_TEMPLATE/meeting.md diff --git a/tests/mcmc/test_integrators.py b/tests/mcmc/test_integrators.py index c38009e5e..362496629 100644 --- a/tests/mcmc/test_integrators.py +++ b/tests/mcmc/test_integrators.py @@ -77,10 +77,23 @@ def kinetic_energy(p, position=None): "c": jnp.ones((2, 1)), } _, unravel_fn = ravel_pytree(mvnormal_position_init) -key0, key1 = jax.random.split(jax.random.key(52)) -mvnormal_momentum_init = unravel_fn(jax.random.normal(key0, (6,))) -a = jax.random.normal(key1, (6, 6)) -cov = jnp.matmul(a.T, a) +mvnormal_momentum_init = { + "a": jnp.asarray(0.53288144), + "b": jnp.asarray([0.25310317, 1.3788314, -0.13486017]), + "c": jnp.asarray([[-0.59082425], [1.2088736]]), +} + +cov = jnp.asarray( + [ + [5.9959664, 1.1494889, -1.0420643, -0.6328479, -0.20363973, 2.1600752], + [1.1494889, 1.3504763, -0.3601517, -0.98311526, 1.1569028, -1.4185406], + [-1.0420643, -0.3601517, 6.3011055, -2.0662997, -0.10126236, 1.2898219], + [-0.6328479, -0.98311526, -2.0662997, 4.82699, -2.575554, 2.5724294], + [-0.20363973, 1.1569028, -0.10126236, -2.575554, 3.35319, -2.9411654], + [2.1600752, -1.4185406, 1.2898219, 2.5724294, -2.9411654, 6.3740206], + ] +) + # Validated numerically mvnormal_position_end = unravel_fn( jnp.asarray([0.38887993, 0.85231394, 2.7879136, 3.0339851, 0.5856687, 1.9291426]) diff --git a/tests/mcmc/test_proposal.py b/tests/mcmc/test_proposal.py index 3a0c3ac38..391a66656 100644 --- a/tests/mcmc/test_proposal.py +++ b/tests/mcmc/test_proposal.py @@ -2,6 +2,7 @@ import jax import numpy as np import pytest +from absl.testing import parameterized from jax import numpy as jnp from blackjax.mcmc.random_walk import normal @@ -10,25 +11,18 @@ class TestNormalProposalDistribution(chex.TestCase): def setUp(self): super().setUp() - self.key = jax.random.key(20220611) + self.key = jax.random.key(20250120) - def test_normal_univariate(self): + @parameterized.parameters([10.0, 15000.0]) + def test_normal_univariate(self, initial_position): """ Move samples are generated in the univariate case, with std following sigma, and independently of the position. """ - key1, key2 = jax.random.split(self.key) + keys = jax.random.split(self.key, 200) proposal = normal(sigma=jnp.array([1.0])) - samples_from_initial_position = [ - proposal(key, jnp.array([10.0])) for key in jax.random.split(key1, 100) - ] - samples_from_another_position = [ - proposal(key, jnp.array([15000.0])) for key in jax.random.split(key2, 100) - ] - - for samples in [samples_from_initial_position, samples_from_another_position]: - np.testing.assert_allclose(0.0, np.mean(samples), rtol=1e-2, atol=1e-1) - np.testing.assert_allclose(1.0, np.std(samples), rtol=1e-2, atol=1e-1) + samples = [proposal(key, jnp.array([initial_position])) for key in keys] + self._check_mean_and_std(jnp.array([0.0]), jnp.array([1.0]), samples) def test_normal_multivariate(self): proposal = normal(sigma=jnp.array([1.0, 2.0])) @@ -61,7 +55,7 @@ def _check_mean_and_std(expected_mean, expected_std, samples): ) np.testing.assert_allclose( expected_std, - np.sqrt(np.diag(np.cov(np.array(samples).T))), + np.sqrt(np.diag(np.atleast_2d(np.cov(np.array(samples).T)))), rtol=1e-2, atol=1e-1, ) diff --git a/tests/mcmc/test_sampling.py b/tests/mcmc/test_sampling.py index 474f67293..653ba0dac 100644 --- a/tests/mcmc/test_sampling.py +++ b/tests/mcmc/test_sampling.py @@ -1,4 +1,5 @@ """Test the accuracy of the MCMC kernels.""" + import functools import itertools @@ -331,8 +332,8 @@ def test_mclmc(self): coefs_samples = states["coefs"][3000:] scale_samples = np.exp(states["log_scale"][3000:]) - np.testing.assert_allclose(np.mean(scale_samples), 1.0, atol=1e-2) - np.testing.assert_allclose(np.mean(coefs_samples), 3.0, atol=1e-2) + np.testing.assert_allclose(np.mean(scale_samples), 1.0, rtol=1e-2, atol=1e-1) + np.testing.assert_allclose(np.mean(coefs_samples), 3.0, rtol=1e-2, atol=1e-1) def test_adjusted_mclmc(self): """Test the MCLMC kernel.""" @@ -356,8 +357,8 @@ def test_adjusted_mclmc(self): coefs_samples = states["coefs"][3000:] scale_samples = np.exp(states["log_scale"][3000:]) - np.testing.assert_allclose(np.mean(scale_samples), 1.0, atol=1e-2) - np.testing.assert_allclose(np.mean(coefs_samples), 3.0, atol=1e-2) + np.testing.assert_allclose(np.mean(scale_samples), 1.0, rtol=1e-2, atol=1e-1) + np.testing.assert_allclose(np.mean(coefs_samples), 3.0, rtol=1e-2, atol=1e-1) def test_mclmc_preconditioning(self): class IllConditionedGaussian: @@ -607,8 +608,8 @@ def test_barker(self): coefs_samples = states["coefs"][3000:] scale_samples = np.exp(states["log_scale"][3000:]) - np.testing.assert_allclose(np.mean(scale_samples), 1.0, atol=1e-2) - np.testing.assert_allclose(np.mean(coefs_samples), 3.0, atol=1e-2) + np.testing.assert_allclose(np.mean(scale_samples), 1.0, rtol=1e-2, atol=1e-1) + np.testing.assert_allclose(np.mean(coefs_samples), 3.0, rtol=1e-2, atol=1e-1) class SGMCMCTest(chex.TestCase): @@ -861,7 +862,7 @@ def test_irmh(self): @chex.all_variants(with_pmap=False) def test_nuts(self): inference_algorithm = blackjax.nuts( - self.normal_logprob, step_size=4.0, inverse_mass_matrix=jnp.array([1.0]) + self.normal_logprob, step_size=1.0, inverse_mass_matrix=jnp.array([1.0]) ) initial_state = inference_algorithm.init(jnp.array(3.0)) @@ -1021,7 +1022,7 @@ def test_barker(self): }, { "algorithm": blackjax.barker_proposal, - "parameters": {"step_size": 0.5}, + "parameters": {"step_size": 0.45}, "is_mass_matrix_diagonal": None, }, ] diff --git a/tests/smc/test_smc.py b/tests/smc/test_smc.py index b0e86e0b0..769078c8d 100644 --- a/tests/smc/test_smc.py +++ b/tests/smc/test_smc.py @@ -79,7 +79,7 @@ def test_smc_waste_free(self): {}, ) same_for_all_params = dict( - step_size=1e-2, inverse_mass_matrix=jnp.eye(1), num_integration_steps=50 + step_size=1e-2, inverse_mass_matrix=jnp.eye(1), num_integration_steps=100 ) hmc_kernel = functools.partial( blackjax.hmc.build_kernel(), **same_for_all_params From 4d4eae01b7cd4b002ceff746e239cf81937096e3 Mon Sep 17 00:00:00 2001 From: Reuben Date: Tue, 21 Jan 2025 09:15:08 -0500 Subject: [PATCH 05/14] Adjusted MCLMC (#771) * test CI * test CI * test CI: add static * test CI: add static * test CI: add static tests * Revert "test CI: add static" This reverts commit 2db919d90fa729261a817115caa5b1e76d3708e0. * Revert "test CI: add static" This reverts commit fa6558f2c8e254909b0ed663f2208da458230ac1. * test CI: add static tests * test CI: add static tests * test CI: add static tests * test CI: old tests * test CI: old tests * test CI: old tests with addition * test CI: old tests with addition of num tuning steps --- blackjax/__init__.py | 2 + .../adaptation/adjusted_mclmc_adaptation.py | 73 +++-- blackjax/adaptation/mclmc_adaptation.py | 30 +- blackjax/mcmc/__init__.py | 2 + blackjax/mcmc/adjusted_mclmc.py | 79 +++--- blackjax/mcmc/adjusted_mclmc_dynamic.py | 259 ++++++++++++++++++ blackjax/mcmc/integrators.py | 12 +- blackjax/mcmc/mclmc.py | 8 +- tests/mcmc/test_integrators.py | 4 +- tests/mcmc/test_sampling.py | 131 +++++++-- 10 files changed, 487 insertions(+), 113 deletions(-) create mode 100644 blackjax/mcmc/adjusted_mclmc_dynamic.py diff --git a/blackjax/__init__.py b/blackjax/__init__.py index 81f8ebd2e..ef5eabd79 100644 --- a/blackjax/__init__.py +++ b/blackjax/__init__.py @@ -13,6 +13,7 @@ from .diagnostics import effective_sample_size as ess from .diagnostics import potential_scale_reduction as rhat from .mcmc import adjusted_mclmc as _adjusted_mclmc +from .mcmc import adjusted_mclmc_dynamic as _adjusted_mclmc_dynamic from .mcmc import barker from .mcmc import dynamic_hmc as _dynamic_hmc from .mcmc import elliptical_slice as _elliptical_slice @@ -113,6 +114,7 @@ def generate_top_level_api_from(module): additive_step_random_walk.register_factory("normal_random_walk", normal_random_walk) mclmc = generate_top_level_api_from(_mclmc) +adjusted_mclmc_dynamic = generate_top_level_api_from(_adjusted_mclmc_dynamic) adjusted_mclmc = generate_top_level_api_from(_adjusted_mclmc) elliptical_slice = generate_top_level_api_from(_elliptical_slice) ghmc = generate_top_level_api_from(_ghmc) diff --git a/blackjax/adaptation/adjusted_mclmc_adaptation.py b/blackjax/adaptation/adjusted_mclmc_adaptation.py index f5d54e5c9..408c31383 100644 --- a/blackjax/adaptation/adjusted_mclmc_adaptation.py +++ b/blackjax/adaptation/adjusted_mclmc_adaptation.py @@ -74,14 +74,20 @@ def adjusted_mclmc_find_L_and_step_size( dim = pytree_size(state.position) if params is None: params = MCLMCAdaptationState( - jnp.sqrt(dim), jnp.sqrt(dim) * 0.2, sqrt_diag_cov=jnp.ones((dim,)) + jnp.sqrt(dim), jnp.sqrt(dim) * 0.2, inverse_mass_matrix=jnp.ones((dim,)) ) part1_key, part2_key = jax.random.split(rng_key, 2) + total_num_tuning_integrator_steps = 0 for i in range(num_windows): window_key = jax.random.fold_in(part1_key, i) - (state, params, eigenvector) = adjusted_mclmc_make_L_step_size_adaptation( + ( + state, + params, + eigenvector, + num_tuning_integrator_steps, + ) = adjusted_mclmc_make_L_step_size_adaptation( kernel=mclmc_kernel, dim=dim, frac_tune1=frac_tune1, @@ -90,22 +96,38 @@ def adjusted_mclmc_find_L_and_step_size( diagonal_preconditioning=diagonal_preconditioning, max=max, tuning_factor=tuning_factor, - )(state, params, num_steps, window_key) + )( + state, params, num_steps, window_key + ) + total_num_tuning_integrator_steps += num_tuning_integrator_steps if frac_tune3 != 0: for i in range(num_windows): part2_key = jax.random.fold_in(part2_key, i) part2_key1, part2_key2 = jax.random.split(part2_key, 2) - state, params = adjusted_mclmc_make_adaptation_L( + ( + state, + params, + num_tuning_integrator_steps, + ) = adjusted_mclmc_make_adaptation_L( mclmc_kernel, frac=frac_tune3, Lfactor=0.5, max=max, eigenvector=eigenvector, - )(state, params, num_steps, part2_key1) + )( + state, params, num_steps, part2_key1 + ) + + total_num_tuning_integrator_steps += num_tuning_integrator_steps - (state, params, _) = adjusted_mclmc_make_L_step_size_adaptation( + ( + state, + params, + _, + num_tuning_integrator_steps, + ) = adjusted_mclmc_make_L_step_size_adaptation( kernel=mclmc_kernel, dim=dim, frac_tune1=frac_tune1, @@ -115,9 +137,13 @@ def adjusted_mclmc_find_L_and_step_size( diagonal_preconditioning=diagonal_preconditioning, max=max, tuning_factor=tuning_factor, - )(state, params, num_steps, part2_key2) + )( + state, params, num_steps, part2_key2 + ) + + total_num_tuning_integrator_steps += num_tuning_integrator_steps - return state, params + return state, params, total_num_tuning_integrator_steps def adjusted_mclmc_make_L_step_size_adaptation( @@ -152,7 +178,7 @@ def step(iteration_state, weight_and_key): state=previous_state, avg_num_integration_steps=avg_num_integration_steps, step_size=params.step_size, - sqrt_diag_cov=params.sqrt_diag_cov, + inverse_mass_matrix=params.inverse_mass_matrix, ) # step updating @@ -256,6 +282,7 @@ def L_step_size_adaptation(state, params, num_steps, rng_key): update_da=update_da, ) + num_tuning_integrator_steps = info.num_integration_steps.sum() final_stepsize = final_da(dual_avg_state) params = params._replace(step_size=final_stepsize) @@ -283,9 +310,7 @@ def L_step_size_adaptation(state, params, num_steps, rng_key): L=params.L * change, step_size=params.step_size * change ) if diagonal_preconditioning: - params = params._replace( - sqrt_diag_cov=jnp.sqrt(variances), L=jnp.sqrt(dim) - ) + params = params._replace(inverse_mass_matrix=variances, L=jnp.sqrt(dim)) initial_da, update_da, final_da = dual_averaging_adaptation(target=target) ( @@ -301,9 +326,11 @@ def L_step_size_adaptation(state, params, num_steps, rng_key): initial_da=initial_da, ) + num_tuning_integrator_steps += info.num_integration_steps.sum() + params = params._replace(step_size=final_da(dual_avg_state)) - return state, params, eigenvector + return state, params, eigenvector, num_tuning_integrator_steps return L_step_size_adaptation @@ -318,16 +345,16 @@ def adaptation_L(state, params, num_steps, key): adaptation_L_keys = jax.random.split(key, num_steps) def step(state, key): - next_state, _ = kernel( + next_state, info = kernel( rng_key=key, state=state, step_size=params.step_size, avg_num_integration_steps=params.L / params.step_size, - sqrt_diag_cov=params.sqrt_diag_cov, + inverse_mass_matrix=params.inverse_mass_matrix, ) - return next_state, next_state.position + return next_state, (next_state.position, info) - state, samples = jax.lax.scan( + state, (samples, info) = jax.lax.scan( f=step, init=state, xs=adaptation_L_keys, @@ -348,10 +375,14 @@ def step(state, key): # number of effective samples per 1 actual sample ess = contract(effective_sample_size(flat_samples[None, ...])) / num_steps - return state, params._replace( - L=jnp.clip( - Lfactor * params.L / jnp.mean(ess), max=params.L * Lratio_upperbound - ) + return ( + state, + params._replace( + L=jnp.clip( + Lfactor * params.L / jnp.mean(ess), max=params.L * Lratio_upperbound + ) + ), + info.num_integration_steps.sum(), ) return adaptation_L diff --git a/blackjax/adaptation/mclmc_adaptation.py b/blackjax/adaptation/mclmc_adaptation.py index 8452b6171..60fd46359 100644 --- a/blackjax/adaptation/mclmc_adaptation.py +++ b/blackjax/adaptation/mclmc_adaptation.py @@ -30,13 +30,13 @@ class MCLMCAdaptationState(NamedTuple): The momentum decoherent rate for the MCLMC algorithm. step_size The step size used for the MCLMC algorithm. - sqrt_diag_cov + inverse_mass_matrix A matrix used for preconditioning. """ L: float step_size: float - sqrt_diag_cov: float + inverse_mass_matrix: float def mclmc_find_L_and_step_size( @@ -87,10 +87,10 @@ def mclmc_find_L_and_step_size( Example ------- .. code:: - kernel = lambda sqrt_diag_cov : blackjax.mcmc.mclmc.build_kernel( + kernel = lambda inverse_mass_matrix : blackjax.mcmc.mclmc.build_kernel( logdensity_fn=logdensity_fn, integrator=integrator, - sqrt_diag_cov=sqrt_diag_cov, + inverse_mass_matrix=inverse_mass_matrix, ) ( @@ -106,7 +106,7 @@ def mclmc_find_L_and_step_size( """ dim = pytree_size(state.position) params = MCLMCAdaptationState( - jnp.sqrt(dim), jnp.sqrt(dim) * 0.25, sqrt_diag_cov=jnp.ones((dim,)) + jnp.sqrt(dim), jnp.sqrt(dim) * 0.25, inverse_mass_matrix=jnp.ones((dim,)) ) part1_key, part2_key = jax.random.split(rng_key, 2) @@ -123,10 +123,10 @@ def mclmc_find_L_and_step_size( if frac_tune3 != 0: state, params = make_adaptation_L( - mclmc_kernel(params.sqrt_diag_cov), frac=frac_tune3, Lfactor=0.4 + mclmc_kernel(params.inverse_mass_matrix), frac=frac_tune3, Lfactor=0.4 )(state, params, num_steps, part2_key) - return state, params + return state, params, num_steps * (frac_tune1 + frac_tune2 + frac_tune3) def make_L_step_size_adaptation( @@ -152,7 +152,7 @@ def predictor(previous_state, params, adaptive_state, rng_key): rng_key, nan_key = jax.random.split(rng_key) # dynamics - next_state, info = kernel(params.sqrt_diag_cov)( + next_state, info = kernel(params.inverse_mass_matrix)( rng_key=rng_key, state=previous_state, L=params.L, @@ -247,15 +247,15 @@ def L_step_size_adaptation(state, params, num_steps, rng_key): L = params.L # determine L - sqrt_diag_cov = params.sqrt_diag_cov + inverse_mass_matrix = params.inverse_mass_matrix if num_steps2 > 1: x_average, x_squared_average = average[0], average[1] variances = x_squared_average - jnp.square(x_average) L = jnp.sqrt(jnp.sum(variances)) if diagonal_preconditioning: - sqrt_diag_cov = jnp.sqrt(variances) - params = params._replace(sqrt_diag_cov=sqrt_diag_cov) + inverse_mass_matrix = variances + params = params._replace(inverse_mass_matrix=inverse_mass_matrix) L = jnp.sqrt(dim) # readjust the stepsize @@ -265,7 +265,7 @@ def L_step_size_adaptation(state, params, num_steps, rng_key): xs=(jnp.ones(steps), keys), state=state, params=params ) - return state, MCLMCAdaptationState(L, params.step_size, sqrt_diag_cov) + return state, MCLMCAdaptationState(L, params.step_size, inverse_mass_matrix) return L_step_size_adaptation @@ -274,8 +274,8 @@ def make_adaptation_L(kernel, frac, Lfactor): """determine L by the autocorrelations (around 10 effective samples are needed for this to be accurate)""" def adaptation_L(state, params, num_steps, key): - num_steps = int(num_steps * frac) - adaptation_L_keys = jax.random.split(key, num_steps) + num_steps_3 = int(num_steps * frac) + adaptation_L_keys = jax.random.split(key, num_steps_3) def step(state, key): next_state, _ = kernel( @@ -297,7 +297,7 @@ def step(state, key): ess = effective_sample_size(flat_samples[None, ...]) return state, params._replace( - L=Lfactor * params.step_size * jnp.mean(num_steps / ess) + L=Lfactor * params.step_size * jnp.mean(num_steps_3 / ess) ) return adaptation_L diff --git a/blackjax/mcmc/__init__.py b/blackjax/mcmc/__init__.py index 1e1317684..8acb28274 100644 --- a/blackjax/mcmc/__init__.py +++ b/blackjax/mcmc/__init__.py @@ -1,5 +1,6 @@ from . import ( adjusted_mclmc, + adjusted_mclmc_dynamic, barker, elliptical_slice, ghmc, @@ -25,5 +26,6 @@ "marginal_latent_gaussian", "random_walk", "mclmc", + "adjusted_mclmc_dynamic", "adjusted_mclmc", ] diff --git a/blackjax/mcmc/adjusted_mclmc.py b/blackjax/mcmc/adjusted_mclmc.py index 81fbc2835..f390402f2 100644 --- a/blackjax/mcmc/adjusted_mclmc.py +++ b/blackjax/mcmc/adjusted_mclmc.py @@ -11,7 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Public API for the Metropolis Hastings Microcanonical Hamiltonian Monte Carlo (MHMCHMC) Kernel. This is closely related to the Microcanonical Langevin Monte Carlo (MCLMC) Kernel, which is an unadjusted method. This kernel adds a Metropolis-Hastings correction to the MCLMC kernel. It also only refreshes the momentum variable after each MH step, rather than during the integration of the trajectory. Hence "Hamiltonian" and not "Langevin".""" +"""Public API for the Metropolis Hastings Microcanonical Hamiltonian Monte Carlo (MHMCHMC) Kernel. This is closely related to the Microcanonical Langevin Monte Carlo (MCLMC) Kernel, which is an unadjusted method. This kernel adds a Metropolis-Hastings correction to the MCLMC kernel. It also only refreshes the momentum variable after each MH step, rather than during the integration of the trajectory. Hence "Hamiltonian" and not "Langevin". + +NOTE: For best performance, we recommend using adjusted_mclmc_dynamic instead of this module, which is primarily intended for use in parallelized versions of the algorithm. + +""" from typing import Callable, Union import jax @@ -19,28 +23,26 @@ import blackjax.mcmc.integrators as integrators from blackjax.base import SamplingAlgorithm -from blackjax.mcmc.dynamic_hmc import DynamicHMCState, halton_sequence -from blackjax.mcmc.hmc import HMCInfo +from blackjax.mcmc.hmc import HMCInfo, HMCState from blackjax.mcmc.proposal import static_binomial_sampling -from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey +from blackjax.types import ArrayLikeTree, ArrayTree, PRNGKey from blackjax.util import generate_unit_vector __all__ = ["init", "build_kernel", "as_top_level_api"] -def init(position: ArrayLikeTree, logdensity_fn: Callable, random_generator_arg: Array): +def init(position: ArrayLikeTree, logdensity_fn: Callable): logdensity, logdensity_grad = jax.value_and_grad(logdensity_fn)(position) - return DynamicHMCState(position, logdensity, logdensity_grad, random_generator_arg) + return HMCState(position, logdensity, logdensity_grad) def build_kernel( - integration_steps_fn, + logdensity_fn: Callable, integrator: Callable = integrators.isokinetic_mclachlan, divergence_threshold: float = 1000, - next_random_arg_fn: Callable = lambda key: jax.random.split(key)[1], - sqrt_diag_cov=1.0, + inverse_mass_matrix=1.0, ): - """Build a Dynamic MHMCHMC kernel where the number of integration steps is chosen randomly. + """Build an MHMCHMC kernel where the number of integration steps is chosen randomly. Parameters ---------- @@ -63,20 +65,20 @@ def build_kernel( def kernel( rng_key: PRNGKey, - state: DynamicHMCState, - logdensity_fn: Callable, + state: HMCState, step_size: float, + num_integration_steps: int, L_proposal_factor: float = jnp.inf, - ) -> tuple[DynamicHMCState, HMCInfo]: + ) -> tuple[HMCState, HMCInfo]: """Generate a new sample with the MHMCHMC kernel.""" - num_integration_steps = integration_steps_fn(state.random_generator_arg) - key_momentum, key_integrator = jax.random.split(rng_key, 2) momentum = generate_unit_vector(key_momentum, state.position) proposal, info, _ = adjusted_mclmc_proposal( integrator=integrators.with_isokinetic_maruyama( - integrator(logdensity_fn=logdensity_fn, sqrt_diag_cov=sqrt_diag_cov) + integrator( + logdensity_fn=logdensity_fn, inverse_mass_matrix=inverse_mass_matrix + ) ), step_size=step_size, L_proposal_factor=L_proposal_factor * (num_integration_steps * step_size), @@ -90,11 +92,10 @@ def kernel( ) return ( - DynamicHMCState( + HMCState( proposal.position, proposal.logdensity, proposal.logdensity_grad, - next_random_arg_fn(state.random_generator_arg), ), info, ) @@ -106,14 +107,13 @@ def as_top_level_api( logdensity_fn: Callable, step_size: float, L_proposal_factor: float = jnp.inf, - sqrt_diag_cov=1.0, + inverse_mass_matrix=1.0, *, divergence_threshold: int = 1000, integrator: Callable = integrators.isokinetic_mclachlan, - next_random_arg_fn: Callable = lambda key: jax.random.split(key)[1], - integration_steps_fn: Callable = lambda key: jax.random.randint(key, (), 1, 10), + num_integration_steps, ) -> SamplingAlgorithm: - """Implements the (basic) user interface for the dynamic MHMCHMC kernel. + """Implements the (basic) user interface for the MHMCHMC kernel. Parameters ---------- @@ -140,23 +140,23 @@ def as_top_level_api( """ kernel = build_kernel( - integration_steps_fn=integration_steps_fn, + logdensity_fn=logdensity_fn, integrator=integrator, - next_random_arg_fn=next_random_arg_fn, - sqrt_diag_cov=sqrt_diag_cov, + inverse_mass_matrix=inverse_mass_matrix, divergence_threshold=divergence_threshold, ) - def init_fn(position: ArrayLikeTree, rng_key: Array): - return init(position, logdensity_fn, rng_key) + def init_fn(position: ArrayLikeTree, rng_key=None): + del rng_key + return init(position, logdensity_fn) def update_fn(rng_key: PRNGKey, state): return kernel( - rng_key, - state, - logdensity_fn, - step_size, - L_proposal_factor, + rng_key=rng_key, + state=state, + step_size=step_size, + num_integration_steps=num_integration_steps, + L_proposal_factor=L_proposal_factor, ) return SamplingAlgorithm(init_fn, update_fn) # type: ignore[arg-type] @@ -240,18 +240,3 @@ def generate( return sampled_state, info, other_proposal_info return generate - - -def rescale(mu): - """returns s, such that - round(U(0, 1) * s + 0.5) - has expected value mu. - """ - k = jnp.floor(2 * mu - 1) - x = k * (mu - 0.5 * (k + 1)) / (k + 1 - mu) - return k + x - - -def trajectory_length(t, mu): - s = rescale(mu) - return jnp.rint(0.5 + halton_sequence(t) * s) diff --git a/blackjax/mcmc/adjusted_mclmc_dynamic.py b/blackjax/mcmc/adjusted_mclmc_dynamic.py new file mode 100644 index 000000000..1a69e1a28 --- /dev/null +++ b/blackjax/mcmc/adjusted_mclmc_dynamic.py @@ -0,0 +1,259 @@ +# Copyright 2020- The Blackjax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Public API for the Metropolis Hastings Microcanonical Hamiltonian Monte Carlo (MHMCHMC) Kernel. This is closely related to the Microcanonical Langevin Monte Carlo (MCLMC) Kernel, which is an unadjusted method. This kernel adds a Metropolis-Hastings correction to the MCLMC kernel. It also only refreshes the momentum variable after each MH step, rather than during the integration of the trajectory. Hence "Hamiltonian" and not "Langevin".""" +from typing import Callable, Union + +import jax +import jax.numpy as jnp + +import blackjax.mcmc.integrators as integrators +from blackjax.base import SamplingAlgorithm +from blackjax.mcmc.dynamic_hmc import DynamicHMCState, halton_sequence +from blackjax.mcmc.hmc import HMCInfo +from blackjax.mcmc.proposal import static_binomial_sampling +from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey +from blackjax.util import generate_unit_vector + +__all__ = ["init", "build_kernel", "as_top_level_api"] + + +def init(position: ArrayLikeTree, logdensity_fn: Callable, random_generator_arg: Array): + logdensity, logdensity_grad = jax.value_and_grad(logdensity_fn)(position) + return DynamicHMCState(position, logdensity, logdensity_grad, random_generator_arg) + + +def build_kernel( + integration_steps_fn, + integrator: Callable = integrators.isokinetic_mclachlan, + divergence_threshold: float = 1000, + next_random_arg_fn: Callable = lambda key: jax.random.split(key)[1], + inverse_mass_matrix=1.0, +): + """Build a Dynamic MHMCHMC kernel where the number of integration steps is chosen randomly. + + Parameters + ---------- + integrator + The integrator to use to integrate the Hamiltonian dynamics. + divergence_threshold + Value of the difference in energy above which we consider that the transition is divergent. + next_random_arg_fn + Function that generates the next `random_generator_arg` from its previous value. + integration_steps_fn + Function that generates the next pseudo or quasi-random number of integration steps in the + sequence, given the current `random_generator_arg`. Needs to return an `int`. + + Returns + ------- + A kernel that takes a rng_key and a Pytree that contains the current state + of the chain and that returns a new state of the chain along with + information about the transition. + """ + + def kernel( + rng_key: PRNGKey, + state: DynamicHMCState, + logdensity_fn: Callable, + step_size: float, + L_proposal_factor: float = jnp.inf, + ) -> tuple[DynamicHMCState, HMCInfo]: + """Generate a new sample with the MHMCHMC kernel.""" + + num_integration_steps = integration_steps_fn(state.random_generator_arg) + + key_momentum, key_integrator = jax.random.split(rng_key, 2) + momentum = generate_unit_vector(key_momentum, state.position) + proposal, info, _ = adjusted_mclmc_proposal( + integrator=integrators.with_isokinetic_maruyama( + integrator( + logdensity_fn=logdensity_fn, inverse_mass_matrix=inverse_mass_matrix + ) + ), + step_size=step_size, + L_proposal_factor=L_proposal_factor * (num_integration_steps * step_size), + num_integration_steps=num_integration_steps, + divergence_threshold=divergence_threshold, + )( + key_integrator, + integrators.IntegratorState( + state.position, momentum, state.logdensity, state.logdensity_grad + ), + ) + + return ( + DynamicHMCState( + proposal.position, + proposal.logdensity, + proposal.logdensity_grad, + next_random_arg_fn(state.random_generator_arg), + ), + info, + ) + + return kernel + + +def as_top_level_api( + logdensity_fn: Callable, + step_size: float, + L_proposal_factor: float = jnp.inf, + inverse_mass_matrix=1.0, + *, + divergence_threshold: int = 1000, + integrator: Callable = integrators.isokinetic_mclachlan, + next_random_arg_fn: Callable = lambda key: jax.random.split(key)[1], + integration_steps_fn: Callable = lambda key: jax.random.randint(key, (), 1, 10), +) -> SamplingAlgorithm: + """Implements the (basic) user interface for the dynamic MHMCHMC kernel. + + Parameters + ---------- + logdensity_fn + The log-density function we wish to draw samples from. + step_size + The value to use for the step size in the symplectic integrator. + divergence_threshold + The absolute value of the difference in energy between two states above + which we say that the transition is divergent. The default value is + commonly found in other libraries, and yet is arbitrary. + integrator + (algorithm parameter) The symplectic integrator to use to integrate the trajectory. + next_random_arg_fn + Function that generates the next `random_generator_arg` from its previous value. + integration_steps_fn + Function that generates the next pseudo or quasi-random number of integration steps in the + sequence, given the current `random_generator_arg`. + + + Returns + ------- + A ``SamplingAlgorithm``. + """ + + kernel = build_kernel( + integration_steps_fn=integration_steps_fn, + integrator=integrator, + next_random_arg_fn=next_random_arg_fn, + inverse_mass_matrix=inverse_mass_matrix, + divergence_threshold=divergence_threshold, + ) + + def init_fn(position: ArrayLikeTree, rng_key: Array): + return init(position, logdensity_fn, rng_key) + + def update_fn(rng_key: PRNGKey, state): + return kernel( + rng_key, + state, + logdensity_fn, + step_size, + L_proposal_factor, + ) + + return SamplingAlgorithm(init_fn, update_fn) # type: ignore[arg-type] + + +def adjusted_mclmc_proposal( + integrator: Callable, + step_size: Union[float, ArrayLikeTree], + L_proposal_factor: float, + num_integration_steps: int = 1, + divergence_threshold: float = 1000, + *, + sample_proposal: Callable = static_binomial_sampling, +) -> Callable: + """Vanilla MHMCHMC algorithm. + + The algorithm integrates the trajectory applying a integrator + `num_integration_steps` times in one direction to get a proposal and uses a + Metropolis-Hastings acceptance step to either reject or accept this + proposal. This is what people usually refer to when they talk about "the + HMC algorithm". + + Parameters + ---------- + integrator + integrator used to build the trajectory step by step. + kinetic_energy + Function that computes the kinetic energy. + step_size + Size of the integration step. + num_integration_steps + Number of times we run the integrator to build the trajectory + divergence_threshold + Threshold above which we say that there is a divergence. + + Returns + ------- + A kernel that generates a new chain state and information about the transition. + + """ + + def step(i, vars): + state, kinetic_energy, rng_key = vars + rng_key, next_rng_key = jax.random.split(rng_key) + next_state, next_kinetic_energy = integrator( + state, step_size, L_proposal_factor, rng_key + ) + + return next_state, kinetic_energy + next_kinetic_energy, next_rng_key + + def build_trajectory(state, num_integration_steps, rng_key): + return jax.lax.fori_loop( + 0 * num_integration_steps, num_integration_steps, step, (state, 0, rng_key) + ) + + def generate( + rng_key, state: integrators.IntegratorState + ) -> tuple[integrators.IntegratorState, HMCInfo, ArrayTree]: + """Generate a new chain state.""" + end_state, kinetic_energy, rng_key = build_trajectory( + state, num_integration_steps, rng_key + ) + + new_energy = -end_state.logdensity + delta_energy = -state.logdensity + end_state.logdensity - kinetic_energy + delta_energy = jnp.where(jnp.isnan(delta_energy), -jnp.inf, delta_energy) + is_diverging = -delta_energy > divergence_threshold + sampled_state, info = sample_proposal(rng_key, delta_energy, state, end_state) + do_accept, p_accept, other_proposal_info = info + + info = HMCInfo( + state.momentum, + p_accept, + do_accept, + is_diverging, + new_energy, + end_state, + num_integration_steps, + ) + + return sampled_state, info, other_proposal_info + + return generate + + +def rescale(mu): + """returns s, such that + round(U(0, 1) * s + 0.5) + has expected value mu. + """ + k = jnp.floor(2 * mu - 1) + x = k * (mu - 0.5 * (k + 1)) / (k + 1 - mu) + return k + x + + +def trajectory_length(t, mu): + s = rescale(mu) + return jnp.rint(0.5 + halton_sequence(t) * s) diff --git a/blackjax/mcmc/integrators.py b/blackjax/mcmc/integrators.py index 593683ca4..0effa204e 100644 --- a/blackjax/mcmc/integrators.py +++ b/blackjax/mcmc/integrators.py @@ -311,7 +311,9 @@ def _normalized_flatten_array(x, tol=1e-13): return jnp.where(norm > tol, x / norm, x), norm -def esh_dynamics_momentum_update_one_step(sqrt_diag_cov=1.0): +def esh_dynamics_momentum_update_one_step(inverse_mass_matrix=1.0): + sqrt_inverse_mass_matrix = jnp.sqrt(inverse_mass_matrix) + def update( momentum: ArrayTree, logdensity_grad: ArrayTree, @@ -330,7 +332,7 @@ def update( logdensity_grad = logdensity_grad flatten_grads, unravel_fn = ravel_pytree(logdensity_grad) - flatten_grads = flatten_grads * sqrt_diag_cov + flatten_grads = flatten_grads * sqrt_inverse_mass_matrix flatten_momentum, _ = ravel_pytree(momentum) dims = flatten_momentum.shape[0] normalized_gradient, gradient_norm = _normalized_flatten_array(flatten_grads) @@ -342,7 +344,7 @@ def update( + 2 * zeta * flatten_momentum ) new_momentum_normalized, _ = _normalized_flatten_array(new_momentum_raw) - gr = unravel_fn(new_momentum_normalized * sqrt_diag_cov) + gr = unravel_fn(new_momentum_normalized * sqrt_inverse_mass_matrix) next_momentum = unravel_fn(new_momentum_normalized) kinetic_energy_change = ( delta @@ -374,11 +376,11 @@ def format_isokinetic_state_output( def generate_isokinetic_integrator(coefficients): def isokinetic_integrator( - logdensity_fn: Callable, sqrt_diag_cov: ArrayTree = 1.0 + logdensity_fn: Callable, inverse_mass_matrix: ArrayTree = 1.0 ) -> GeneralIntegrator: position_update_fn = euclidean_position_update_fn(logdensity_fn) one_step = generalized_two_stage_integrator( - esh_dynamics_momentum_update_one_step(sqrt_diag_cov), + esh_dynamics_momentum_update_one_step(inverse_mass_matrix), position_update_fn, coefficients, format_output_fn=format_isokinetic_state_output, diff --git a/blackjax/mcmc/mclmc.py b/blackjax/mcmc/mclmc.py index e7a69849b..ff9638a1f 100644 --- a/blackjax/mcmc/mclmc.py +++ b/blackjax/mcmc/mclmc.py @@ -60,7 +60,7 @@ def init(position: ArrayLike, logdensity_fn, rng_key): ) -def build_kernel(logdensity_fn, sqrt_diag_cov, integrator): +def build_kernel(logdensity_fn, inverse_mass_matrix, integrator): """Build a HMC kernel. Parameters @@ -81,7 +81,7 @@ def build_kernel(logdensity_fn, sqrt_diag_cov, integrator): """ step = with_isokinetic_maruyama( - integrator(logdensity_fn=logdensity_fn, sqrt_diag_cov=sqrt_diag_cov) + integrator(logdensity_fn=logdensity_fn, inverse_mass_matrix=inverse_mass_matrix) ) def kernel( @@ -107,7 +107,7 @@ def as_top_level_api( L, step_size, integrator=isokinetic_mclachlan, - sqrt_diag_cov=1.0, + inverse_mass_matrix=1.0, ) -> SamplingAlgorithm: """The general mclmc kernel builder (:meth:`blackjax.mcmc.mclmc.build_kernel`, alias `blackjax.mclmc.build_kernel`) can be cumbersome to manipulate. Since most users only need to specify the kernel @@ -155,7 +155,7 @@ def as_top_level_api( A ``SamplingAlgorithm``. """ - kernel = build_kernel(logdensity_fn, sqrt_diag_cov, integrator) + kernel = build_kernel(logdensity_fn, inverse_mass_matrix, integrator) def init_fn(position: ArrayLike, rng_key: PRNGKey): return init(position, logdensity_fn, rng_key) diff --git a/tests/mcmc/test_integrators.py b/tests/mcmc/test_integrators.py index 362496629..fd7af4450 100644 --- a/tests/mcmc/test_integrators.py +++ b/tests/mcmc/test_integrators.py @@ -251,7 +251,7 @@ def test_esh_momentum_update(self, dims): # Efficient implementation update_stable = self.variant( - esh_dynamics_momentum_update_one_step(sqrt_diag_cov=1.0) + esh_dynamics_momentum_update_one_step(inverse_mass_matrix=1.0) ) next_momentum1, *_ = update_stable(momentum, gradient, step_size, 1.0) np.testing.assert_array_almost_equal(next_momentum, next_momentum1) @@ -276,7 +276,7 @@ def test_isokinetic_velocity_verlet(self): next_state, kinetic_energy_change = step(initial_state, step_size) # explicit integration - op1 = esh_dynamics_momentum_update_one_step(sqrt_diag_cov=1.0) + op1 = esh_dynamics_momentum_update_one_step(inverse_mass_matrix=1.0) op2 = integrators.euclidean_position_update_fn(logdensity_fn) position, momentum, _, logdensity_grad = initial_state momentum, kinetic_grad, kinetic_energy_change0 = op1( diff --git a/tests/mcmc/test_sampling.py b/tests/mcmc/test_sampling.py index 653ba0dac..4d8a9fa61 100644 --- a/tests/mcmc/test_sampling.py +++ b/tests/mcmc/test_sampling.py @@ -15,7 +15,7 @@ import blackjax.diagnostics as diagnostics import blackjax.mcmc.random_walk from blackjax.adaptation.base import get_filter_adapt_info_fn, return_all_adapt_info -from blackjax.mcmc.adjusted_mclmc import rescale +from blackjax.mcmc.adjusted_mclmc_dynamic import rescale from blackjax.mcmc.integrators import isokinetic_mclachlan from blackjax.util import run_inference_algorithm @@ -113,15 +113,16 @@ def run_mclmc( position=initial_position, logdensity_fn=logdensity_fn, rng_key=init_key ) - kernel = lambda sqrt_diag_cov: blackjax.mcmc.mclmc.build_kernel( + kernel = lambda inverse_mass_matrix: blackjax.mcmc.mclmc.build_kernel( logdensity_fn=logdensity_fn, integrator=blackjax.mcmc.mclmc.isokinetic_mclachlan, - sqrt_diag_cov=sqrt_diag_cov, + inverse_mass_matrix=inverse_mass_matrix, ) ( blackjax_state_after_tuning, blackjax_mclmc_sampler_params, + _, ) = blackjax.mclmc_find_L_and_step_size( mclmc_kernel=kernel, num_steps=num_steps, @@ -134,7 +135,7 @@ def run_mclmc( logdensity_fn, L=blackjax_mclmc_sampler_params.L, step_size=blackjax_mclmc_sampler_params.step_size, - sqrt_diag_cov=blackjax_mclmc_sampler_params.sqrt_diag_cov, + inverse_mass_matrix=blackjax_mclmc_sampler_params.inverse_mass_matrix, ) _, samples = run_inference_algorithm( @@ -159,18 +160,18 @@ def run_adjusted_mclmc( init_key, tune_key, run_key = jax.random.split(key, 3) - initial_state = blackjax.mcmc.adjusted_mclmc.init( + initial_state = blackjax.mcmc.adjusted_mclmc_dynamic.init( position=initial_position, logdensity_fn=logdensity_fn, random_generator_arg=init_key, ) - kernel = lambda rng_key, state, avg_num_integration_steps, step_size, sqrt_diag_cov: blackjax.mcmc.adjusted_mclmc.build_kernel( + kernel = lambda rng_key, state, avg_num_integration_steps, step_size, inverse_mass_matrix: blackjax.mcmc.adjusted_mclmc_dynamic.build_kernel( integrator=integrator, integration_steps_fn=lambda k: jnp.ceil( jax.random.uniform(k) * rescale(avg_num_integration_steps) ), - sqrt_diag_cov=sqrt_diag_cov, + inverse_mass_matrix=inverse_mass_matrix, )( rng_key=rng_key, state=state, @@ -183,6 +184,7 @@ def run_adjusted_mclmc( ( blackjax_state_after_tuning, blackjax_mclmc_sampler_params, + _, ) = blackjax.adjusted_mclmc_find_L_and_step_size( mclmc_kernel=kernel, num_steps=num_steps, @@ -198,14 +200,82 @@ def run_adjusted_mclmc( step_size = blackjax_mclmc_sampler_params.step_size L = blackjax_mclmc_sampler_params.L - alg = blackjax.adjusted_mclmc( + alg = blackjax.adjusted_mclmc_dynamic( logdensity_fn=logdensity_fn, step_size=step_size, integration_steps_fn=lambda key: jnp.ceil( jax.random.uniform(key) * rescale(L / step_size) ), integrator=integrator, - sqrt_diag_cov=blackjax_mclmc_sampler_params.sqrt_diag_cov, + inverse_mass_matrix=blackjax_mclmc_sampler_params.inverse_mass_matrix, + ) + + _, out = run_inference_algorithm( + rng_key=run_key, + initial_state=blackjax_state_after_tuning, + inference_algorithm=alg, + num_steps=num_steps, + transform=lambda state, _: state.position, + progress_bar=False, + ) + + return out + + def run_adjusted_mclmc_static( + self, + logdensity_fn, + num_steps, + initial_position, + key, + diagonal_preconditioning=False, + ): + integrator = isokinetic_mclachlan + + init_key, tune_key, run_key = jax.random.split(key, 3) + + initial_state = blackjax.mcmc.adjusted_mclmc.init( + position=initial_position, + logdensity_fn=logdensity_fn, + ) + + kernel = lambda rng_key, state, avg_num_integration_steps, step_size, inverse_mass_matrix: blackjax.mcmc.adjusted_mclmc.build_kernel( + integrator=integrator, + inverse_mass_matrix=inverse_mass_matrix, + logdensity_fn=logdensity_fn, + )( + rng_key=rng_key, + state=state, + step_size=step_size, + num_integration_steps=avg_num_integration_steps, + ) + + target_acc_rate = 0.9 + + ( + blackjax_state_after_tuning, + blackjax_mclmc_sampler_params, + _, + ) = blackjax.adjusted_mclmc_find_L_and_step_size( + mclmc_kernel=kernel, + num_steps=num_steps, + state=initial_state, + rng_key=tune_key, + target=target_acc_rate, + frac_tune1=0.1, + frac_tune2=0.1, + frac_tune3=0.1, + diagonal_preconditioning=diagonal_preconditioning, + ) + + step_size = blackjax_mclmc_sampler_params.step_size + L = blackjax_mclmc_sampler_params.L + + alg = blackjax.adjusted_mclmc( + logdensity_fn=logdensity_fn, + step_size=step_size, + num_integration_steps=L / step_size, + integrator=integrator, + inverse_mass_matrix=blackjax_mclmc_sampler_params.inverse_mass_matrix, ) _, out = run_inference_algorithm( @@ -360,6 +430,31 @@ def test_adjusted_mclmc(self): np.testing.assert_allclose(np.mean(scale_samples), 1.0, rtol=1e-2, atol=1e-1) np.testing.assert_allclose(np.mean(coefs_samples), 3.0, rtol=1e-2, atol=1e-1) + def test_adjusted_mclmc_static(self): + """Test the MCLMC kernel.""" + + init_key0, init_key1, inference_key = jax.random.split(self.key, 3) + x_data = jax.random.normal(init_key0, shape=(1000, 1)) + y_data = 3 * x_data + jax.random.normal(init_key1, shape=x_data.shape) + + logposterior_fn_ = functools.partial( + self.regression_logprob, x=x_data, preds=y_data + ) + logdensity_fn = lambda x: logposterior_fn_(**x) + + states = self.run_adjusted_mclmc_static( + initial_position={"coefs": 1.0, "log_scale": 1.0}, + logdensity_fn=logdensity_fn, + key=inference_key, + num_steps=10000, + ) + + coefs_samples = states["coefs"][3000:] + scale_samples = np.exp(states["log_scale"][3000:]) + + np.testing.assert_allclose(np.mean(scale_samples), 1.0, rtol=1e-2, atol=1e-1) + np.testing.assert_allclose(np.mean(coefs_samples), 3.0, rtol=1e-2, atol=1e-1) + def test_mclmc_preconditioning(self): class IllConditionedGaussian: """Gaussian distribution. Covariance matrix has eigenvalues equally spaced in log-space, going from 1/condition_bnumber^1/2 to condition_number^1/2.""" @@ -400,7 +495,7 @@ def __init__(self, d, condition_number): integrator = isokinetic_mclachlan - def get_sqrt_diag_cov(): + def get_inverse_mass_matrix(): init_key, tune_key = jax.random.split(key) initial_position = model.sample_init(init_key) @@ -411,16 +506,13 @@ def get_sqrt_diag_cov(): rng_key=init_key, ) - kernel = lambda sqrt_diag_cov: blackjax.mcmc.mclmc.build_kernel( + kernel = lambda inverse_mass_matrix: blackjax.mcmc.mclmc.build_kernel( logdensity_fn=model.logdensity_fn, integrator=integrator, - sqrt_diag_cov=sqrt_diag_cov, + inverse_mass_matrix=inverse_mass_matrix, ) - ( - _, - blackjax_mclmc_sampler_params, - ) = blackjax.mclmc_find_L_and_step_size( + (_, blackjax_mclmc_sampler_params, _) = blackjax.mclmc_find_L_and_step_size( mclmc_kernel=kernel, num_steps=num_steps, state=initial_state, @@ -428,13 +520,14 @@ def get_sqrt_diag_cov(): diagonal_preconditioning=True, ) - return blackjax_mclmc_sampler_params.sqrt_diag_cov + return blackjax_mclmc_sampler_params.inverse_mass_matrix - sqrt_diag_cov = get_sqrt_diag_cov() + inverse_mass_matrix = get_inverse_mass_matrix() assert ( jnp.abs( jnp.dot( - (sqrt_diag_cov**2) / jnp.linalg.norm(sqrt_diag_cov**2), + (inverse_mass_matrix**2) + / jnp.linalg.norm(inverse_mass_matrix**2), eigs / jnp.linalg.norm(eigs), ) - 1 From a053bed036f53557cadb905bc23aad1c65b82a88 Mon Sep 17 00:00:00 2001 From: Carlos Iguaran Date: Wed, 22 Jan 2025 10:31:06 -0300 Subject: [PATCH 06/14] test in place (#772) --- tests/smc/test_pretuning.py | 99 ++++++++++++++++++++++++++++--------- 1 file changed, 75 insertions(+), 24 deletions(-) diff --git a/tests/smc/test_pretuning.py b/tests/smc/test_pretuning.py index a677c99ae..d24996eaf 100644 --- a/tests/smc/test_pretuning.py +++ b/tests/smc/test_pretuning.py @@ -145,13 +145,85 @@ def test_update_multi_sigmas(self): ) +def tuned_adaptive_tempered_inference_loop(kernel, rng_key, initial_state): + def cond(carry): + _, state, *_ = carry + return state.sampler_state.lmbda < 1 + + def body(carry): + i, state, curr_loglikelihood = carry + subkey = jax.random.fold_in(rng_key, i) + state, info = kernel(subkey, state) + return i + 1, state, curr_loglikelihood + info.log_likelihood_increment + + total_iter, final_state, log_likelihood = jax.lax.while_loop( + cond, body, (0, initial_state, 0.0) + ) + return final_state + + class PretuningSMCTest(SMCLinearRegressionTestCase): def setUp(self): super().setUp() self.key = jax.random.key(42) @chex.variants(with_jit=True) - def test_linear_regression(self): + def test_tempered(self): + step_provider = lambda logprior_fn, loglikelihood_fn, pretune: blackjax.smc.pretuning.build_kernel( + blackjax.tempered_smc, + logprior_fn, + loglikelihood_fn, + blackjax.hmc.build_kernel(), + blackjax.hmc.init, + resampling.systematic, + num_mcmc_steps=10, + pretune_fn=pretune, + ) + + def loop(smc_kernel, init_particles, initial_parameters): + initial_state = init( + blackjax.tempered_smc.init, init_particles, initial_parameters + ) + + def body_fn(carry, lmbda): + i, state = carry + subkey = jax.random.fold_in(self.key, i) + new_state, info = smc_kernel(subkey, state, lmbda=lmbda) + return (i + 1, new_state), (new_state, info) + + num_tempering_steps = 10 + lambda_schedule = np.logspace(-5, 0, num_tempering_steps) + + (_, result), _ = jax.lax.scan(body_fn, (0, initial_state), lambda_schedule) + return result + + self.linear_regression_test_case(step_provider, loop) + + @chex.variants(with_jit=True) + def test_adaptive_tempered(self): + step_provider = lambda logprior_fn, loglikelihood_fn, pretune: blackjax.smc.pretuning.build_kernel( + blackjax.adaptive_tempered_smc, + logprior_fn, + loglikelihood_fn, + blackjax.hmc.build_kernel(), + blackjax.hmc.init, + resampling.systematic, + num_mcmc_steps=10, + pretune_fn=pretune, + target_ess=0.5, + ) + + def loop(smc_kernel, init_particles, initial_parameters): + initial_state = init( + blackjax.tempered_smc.init, init_particles, initial_parameters + ) + return tuned_adaptive_tempered_inference_loop( + smc_kernel, self.key, initial_state + ) + + self.linear_regression_test_case(step_provider, loop) + + def linear_regression_test_case(self, step_provider, loop): ( init_particles, logprior_fn, @@ -191,32 +263,11 @@ def test_linear_regression(self): positive_parameters=["step_size"], ) - step = blackjax.smc.pretuning.build_kernel( - blackjax.tempered_smc, - logprior_fn, - loglikelihood_fn, - blackjax.hmc.build_kernel(), - blackjax.hmc.init, - resampling.systematic, - num_mcmc_steps=10, - pretune_fn=pretune, - ) + step = step_provider(logprior_fn, loglikelihood_fn, pretune) - initial_state = init( - blackjax.tempered_smc.init, init_particles, initial_parameters - ) smc_kernel = self.variant(step) - def body_fn(carry, lmbda): - i, state = carry - subkey = jax.random.fold_in(self.key, i) - new_state, info = smc_kernel(subkey, state, lmbda=lmbda) - return (i + 1, new_state), (new_state, info) - - num_tempering_steps = 10 - lambda_schedule = np.logspace(-5, 0, num_tempering_steps) - - (_, result), _ = jax.lax.scan(body_fn, (0, initial_state), lambda_schedule) + result = loop(smc_kernel, init_particles, initial_parameters) self.assert_linear_regression_test_case(result.sampler_state) assert set(result.parameter_override.keys()) == { "step_size", From 3f0cbb7956765f1c25e0785beb1b6ad8a7092038 Mon Sep 17 00:00:00 2001 From: Hugo Simon-Onfroy <85559558+hsimonfroy@users.noreply.github.com> Date: Wed, 19 Feb 2025 00:14:35 +0100 Subject: [PATCH 07/14] MCLMC adaptation total num steps and initial guess (#778) * total_num_tuning_integrator_steps * Initial params for MCLMC adaptation --- blackjax/adaptation/mclmc_adaptation.py | 34 +++++++++++++++++-------- blackjax/diagnostics.py | 3 +++ 2 files changed, 27 insertions(+), 10 deletions(-) diff --git a/blackjax/adaptation/mclmc_adaptation.py b/blackjax/adaptation/mclmc_adaptation.py index 60fd46359..fa644898a 100644 --- a/blackjax/adaptation/mclmc_adaptation.py +++ b/blackjax/adaptation/mclmc_adaptation.py @@ -51,6 +51,7 @@ def mclmc_find_L_and_step_size( trust_in_estimate=1.5, num_effective_samples=150, diagonal_preconditioning=True, + params=None, ): """ Finds the optimal value of the parameters for the MCLMC algorithm. @@ -79,6 +80,8 @@ def mclmc_find_L_and_step_size( The number of effective samples for the MCMC algorithm. diagonal_preconditioning Whether to do diagonal preconditioning (i.e. a mass matrix) + params + Initial params to start tuning from (optional) Returns ------- @@ -105,10 +108,19 @@ def mclmc_find_L_and_step_size( ) """ dim = pytree_size(state.position) - params = MCLMCAdaptationState( - jnp.sqrt(dim), jnp.sqrt(dim) * 0.25, inverse_mass_matrix=jnp.ones((dim,)) - ) + if params is None: + params = MCLMCAdaptationState( + jnp.sqrt(dim), jnp.sqrt(dim) * 0.25, inverse_mass_matrix=jnp.ones((dim,)) + ) + part1_key, part2_key = jax.random.split(rng_key, 2) + total_num_tuning_integrator_steps = 0 + + num_steps1, num_steps2 = round(num_steps * frac_tune1), round( + num_steps * frac_tune2 + ) + num_steps2 += diagonal_preconditioning * (num_steps2 // 3) + num_steps3 = round(num_steps * frac_tune3) state, params = make_L_step_size_adaptation( kernel=mclmc_kernel, @@ -120,13 +132,15 @@ def mclmc_find_L_and_step_size( num_effective_samples=num_effective_samples, diagonal_preconditioning=diagonal_preconditioning, )(state, params, num_steps, part1_key) + total_num_tuning_integrator_steps += num_steps1 + num_steps2 - if frac_tune3 != 0: + if num_steps3 >= 2: # at least 2 samples for ESS estimation state, params = make_adaptation_L( mclmc_kernel(params.inverse_mass_matrix), frac=frac_tune3, Lfactor=0.4 )(state, params, num_steps, part2_key) + total_num_tuning_integrator_steps += num_steps3 - return state, params, num_steps * (frac_tune1 + frac_tune2 + frac_tune3) + return state, params, total_num_tuning_integrator_steps def make_L_step_size_adaptation( @@ -225,10 +239,10 @@ def step(iteration_state, weight_and_key): )[0] def L_step_size_adaptation(state, params, num_steps, rng_key): - num_steps1, num_steps2 = ( - int(num_steps * frac_tune1) + 1, - int(num_steps * frac_tune2) + 1, + num_steps1, num_steps2 = round(num_steps * frac_tune1), round( + num_steps * frac_tune2 ) + L_step_size_adaptation_keys = jax.random.split( rng_key, num_steps1 + num_steps2 + 1 ) @@ -259,7 +273,7 @@ def L_step_size_adaptation(state, params, num_steps, rng_key): L = jnp.sqrt(dim) # readjust the stepsize - steps = num_steps2 // 3 # we do some small number of steps + steps = round(num_steps2 / 3) # we do some small number of steps keys = jax.random.split(final_key, steps) state, params, _, (_, average) = run_steps( xs=(jnp.ones(steps), keys), state=state, params=params @@ -274,7 +288,7 @@ def make_adaptation_L(kernel, frac, Lfactor): """determine L by the autocorrelations (around 10 effective samples are needed for this to be accurate)""" def adaptation_L(state, params, num_steps, key): - num_steps_3 = int(num_steps * frac) + num_steps_3 = round(num_steps * frac) adaptation_L_keys = jax.random.split(key, num_steps_3) def step(state, key): diff --git a/blackjax/diagnostics.py b/blackjax/diagnostics.py index 93480302e..257ce759c 100644 --- a/blackjax/diagnostics.py +++ b/blackjax/diagnostics.py @@ -115,6 +115,9 @@ def effective_sample_size( sample_axis = sample_axis if sample_axis >= 0 else len(input_shape) + sample_axis num_chains = input_shape[chain_axis] num_samples = input_shape[sample_axis] + assert ( + num_samples > 1 + ), f"The input array must have at least 2 samples, got only {num_samples}." mean_across_chain = input_array.mean(axis=sample_axis, keepdims=True) # Compute autocovariance estimates for every lag for the input array using FFT. From 7e4241ff7ecda51f855dbc7d03103b70c09567ad Mon Sep 17 00:00:00 2001 From: Carlos Iguaran Date: Wed, 19 Feb 2025 11:56:46 -0300 Subject: [PATCH 08/14] SMC: Joint tuning and pretuning (#776) * impl * rename * docs --------- Co-authored-by: Junpeng Lao --- blackjax/smc/inner_kernel_tuning.py | 50 +++++++- blackjax/smc/pretuning.py | 12 +- blackjax/smc/tuning/from_particles.py | 12 +- tests/smc/test_inner_kernel_tuning.py | 172 +++++++++++++++++++++----- 4 files changed, 201 insertions(+), 45 deletions(-) diff --git a/blackjax/smc/inner_kernel_tuning.py b/blackjax/smc/inner_kernel_tuning.py index 2a63fd1ce..334a1488c 100644 --- a/blackjax/smc/inner_kernel_tuning.py +++ b/blackjax/smc/inner_kernel_tuning.py @@ -1,5 +1,7 @@ from typing import Callable, Dict, NamedTuple, Tuple +import jax + from blackjax.base import SamplingAlgorithm from blackjax.smc.base import SMCInfo, SMCState from blackjax.types import ArrayTree, PRNGKey @@ -28,8 +30,11 @@ def build_kernel( mcmc_step_fn: Callable, mcmc_init_fn: Callable, resampling_fn: Callable, - mcmc_parameter_update_fn: Callable[[SMCState, SMCInfo], Dict[str, ArrayTree]], + mcmc_parameter_update_fn: Callable[ + [PRNGKey, SMCState, SMCInfo], Dict[str, ArrayTree] + ], num_mcmc_steps: int = 10, + smc_returns_state_with_parameter_override=False, **extra_parameters, ) -> Callable: """In the context of an SMC sampler (whose step_fn returning state has a .particles attribute), there's an inner @@ -40,7 +45,8 @@ def build_kernel( ---------- smc_algorithm Either blackjax.adaptive_tempered_smc or blackjax.tempered_smc (or any other implementation of - a sampling algorithm that returns an SMCState and SMCInfo pair). + a sampling algorithm that returns an SMCState and SMCInfo pair). It is also possible for this + to return an StateWithParameterOverride, in such case smc_returns_state_with_parameter_override needs to be True logprior_fn A function that computes the log density of the prior distribution loglikelihood_fn @@ -54,7 +60,30 @@ def build_kernel( A callable that takes the SMCState and SMCInfo at step i and constructs a parameter to be used by the inner kernel in i+1 iteration. extra_parameters: parameters to be used for the creation of the smc_algorithm. + smc_returns_state_with_parameter_override: + a boolean indicating that the underlying smc_algorithm returns a smc_returns_state_with_parameter_override. + this is used in order to compose different adaptation mechanisms, such as pretuning with tuning. """ + if smc_returns_state_with_parameter_override: + + def extract_state_for_delegate(state): + return state + + def compose_new_state(new_state, new_parameter_override): + composed_parameter_override = ( + new_state.parameter_override | new_parameter_override + ) + return StateWithParameterOverride( + new_state.sampler_state, composed_parameter_override + ) + + else: + + def extract_state_for_delegate(state): + return state.sampler_state + + def compose_new_state(new_state, new_parameter_override): + return StateWithParameterOverride(new_state, new_parameter_override) def kernel( rng_key: PRNGKey, state: StateWithParameterOverride, **extra_step_parameters @@ -69,9 +98,14 @@ def kernel( num_mcmc_steps=num_mcmc_steps, **extra_parameters, ).step - new_state, info = step_fn(rng_key, state.sampler_state, **extra_step_parameters) - new_parameter_override = mcmc_parameter_update_fn(new_state, info) - return StateWithParameterOverride(new_state, new_parameter_override), info + parameter_update_key, step_key = jax.random.split(rng_key, 2) + new_state, info = step_fn( + step_key, extract_state_for_delegate(state), **extra_step_parameters + ) + new_parameter_override = mcmc_parameter_update_fn( + parameter_update_key, new_state, info + ) + return compose_new_state(new_state, new_parameter_override), info return kernel @@ -83,9 +117,12 @@ def as_top_level_api( mcmc_step_fn: Callable, mcmc_init_fn: Callable, resampling_fn: Callable, - mcmc_parameter_update_fn: Callable[[SMCState, SMCInfo], Dict[str, ArrayTree]], + mcmc_parameter_update_fn: Callable[ + [PRNGKey, SMCState, SMCInfo], Dict[str, ArrayTree] + ], initial_parameter_value, num_mcmc_steps: int = 10, + smc_returns_state_with_parameter_override=False, **extra_parameters, ) -> SamplingAlgorithm: """In the context of an SMC sampler (whose step_fn returning state @@ -130,6 +167,7 @@ def as_top_level_api( resampling_fn, mcmc_parameter_update_fn, num_mcmc_steps, + smc_returns_state_with_parameter_override, **extra_parameters, ) diff --git a/blackjax/smc/pretuning.py b/blackjax/smc/pretuning.py index f489a0dc2..374b8f425 100644 --- a/blackjax/smc/pretuning.py +++ b/blackjax/smc/pretuning.py @@ -99,15 +99,21 @@ def update_parameter_distribution( ) +def default_measure_factory(state): + inverse_mass_matrix = state.parameter_override["inverse_mass_matrix"] + if not (len(inverse_mass_matrix.shape) == 3 and inverse_mass_matrix.shape[0] == 1): + raise ValueError("ESJD only works if chains share the inverse_mass_matrix.") + + return esjd(inverse_mass_matrix[0]) + + def build_pretune( mcmc_init_fn: Callable, mcmc_step_fn: Callable, alpha: float, sigma_parameters: ArrayLikeTree, n_particles: int, - performance_of_chain_measure_factory: Callable = lambda state: esjd( - state.parameter_override["inverse_mass_matrix"] - ), + performance_of_chain_measure_factory: Callable = default_measure_factory, natural_parameters: Optional[List[str]] = None, positive_parameters: Optional[List[str]] = None, ): diff --git a/blackjax/smc/tuning/from_particles.py b/blackjax/smc/tuning/from_particles.py index 279a718cb..505e7f3a1 100755 --- a/blackjax/smc/tuning/from_particles.py +++ b/blackjax/smc/tuning/from_particles.py @@ -12,7 +12,7 @@ "particles_means", "particles_stds", "particles_covariance_matrix", - "mass_matrix_from_particles", + "inverse_mass_matrix_from_particles", ] @@ -28,18 +28,16 @@ def particles_covariance_matrix(particles): return jnp.cov(particles_as_rows(particles), ddof=0, rowvar=False) -def mass_matrix_from_particles(particles) -> Array: +def inverse_mass_matrix_from_particles(particles) -> Array: """ Implements tuning from section 3.1 from https://arxiv.org/pdf/1808.07730.pdf - Computing a mass matrix to be used in HMC from particles. - Given the particles covariance matrix, set all non-diagonal elements as zero, - take the inverse, and keep the diagonal. + Computing an inverse mass matrix to be used in HMC from particles. Returns ------- - A mass Matrix + An inverse mass matrix """ - return jnp.diag(1.0 / jnp.var(particles_as_rows(particles), axis=0)) + return jnp.diag(jnp.var(particles_as_rows(particles), axis=0)) def particles_as_rows(particles): diff --git a/tests/smc/test_inner_kernel_tuning.py b/tests/smc/test_inner_kernel_tuning.py index 7d6190af5..d7daaf839 100644 --- a/tests/smc/test_inner_kernel_tuning.py +++ b/tests/smc/test_inner_kernel_tuning.py @@ -15,9 +15,10 @@ from blackjax.mcmc.random_walk import build_irmh from blackjax.smc import extend_params from blackjax.smc.inner_kernel_tuning import as_top_level_api as inner_kernel_tuning +from blackjax.smc.pretuning import build_pretune from blackjax.smc.tuning.from_kernel_info import update_scale_from_acceptance_rate from blackjax.smc.tuning.from_particles import ( - mass_matrix_from_particles, + inverse_mass_matrix_from_particles, particles_as_rows, particles_covariance_matrix, particles_means, @@ -93,7 +94,7 @@ def smc_inner_kernel_tuning_test_case( proposal_factory = MagicMock() proposal_factory.return_value = 100 - def mcmc_parameter_update_fn(state, info): + def mcmc_parameter_update_fn(key, state, info): return extend_params({"mean": 100}) prior = lambda x: stats.norm.logpdf(x) @@ -186,30 +187,30 @@ def setUp(self): self.key = jax.random.key(42) def test_inverse_mass_matrix_from_particles(self): - inverse_mass_matrix = mass_matrix_from_particles( + inverse_mass_matrix = inverse_mass_matrix_from_particles( np.array([np.array(10.0), np.array(3.0)]) ) np.testing.assert_allclose( - inverse_mass_matrix, np.diag(np.array([0.08163])), rtol=1e-4 + inverse_mass_matrix, np.diag(np.array([12.25])), rtol=1e-4 ) def test_inverse_mass_matrix_from_multivariate_particles(self): - inverse_mass_matrix = mass_matrix_from_particles( + inverse_mass_matrix = inverse_mass_matrix_from_particles( np.array([jnp.array([10.0, 15.0]), jnp.array([3.0, 4.0])]) ) np.testing.assert_allclose( - inverse_mass_matrix, np.diag(np.array([0.081633, 0.033058])), rtol=1e-4 + inverse_mass_matrix, np.diag(np.array([12.25, 30.25])), rtol=1e-4 ) def test_inverse_mass_matrix_from_multivariable_particles(self): var1 = np.array([jnp.array([10.0, 15.0]), jnp.array([3.0, 4.0])]) var2 = np.array([jnp.array([10.0]), jnp.array([3.0])]) init_particles = {"var1": var1, "var2": var2} - mass_matrix = mass_matrix_from_particles(init_particles) + mass_matrix = inverse_mass_matrix_from_particles(init_particles) assert mass_matrix.shape == (3, 3) np.testing.assert_allclose( np.diag(mass_matrix), - np.array([0.081633, 0.033058, 0.081633], dtype="float32"), + np.array([12.25, 30.25, 12.25], dtype="float32"), rtol=1e-4, ) @@ -217,10 +218,10 @@ def test_inverse_mass_matrix_from_multivariable_univariate_particles(self): var1 = np.array([3.0, 2.0]) var2 = np.array([10.0, 3.0]) init_particles = {"var1": var1, "var2": var2} - mass_matrix = mass_matrix_from_particles(init_particles) + mass_matrix = inverse_mass_matrix_from_particles(init_particles) assert mass_matrix.shape == (2, 2) np.testing.assert_allclose( - np.diag(mass_matrix), np.array([4, 0.081633], dtype="float32"), rtol=1e-4 + np.diag(mass_matrix), np.array([0.25, 12.25], dtype="float32"), rtol=1e-4 ) @@ -279,10 +280,12 @@ def test_with_adaptive_tempered(self): loglikelihood_fn, ) = self.particles_prior_loglikelihood() - def parameter_update(state, info): + def parameter_update(key, state, info): return extend_params( { - "inverse_mass_matrix": mass_matrix_from_particles(state.particles), + "inverse_mass_matrix": inverse_mass_matrix_from_particles( + state.particles + ), "step_size": 10e-2, "num_integration_steps": 50, }, @@ -308,21 +311,7 @@ def parameter_update(state, info): ) init_state = init(init_particles) smc_kernel = self.variant(step) - - def inference_loop(kernel, rng_key, initial_state): - def cond(carry): - _, state = carry - return state.sampler_state.lmbda < 1 - - def body(carry): - i, state = carry - subkey = jax.random.fold_in(rng_key, i) - state, _ = kernel(subkey, state) - return i + 1, state - - return jax.lax.while_loop(cond, body, (0, initial_state)) - - _, state = inference_loop(smc_kernel, self.key, init_state) + _, state = adaptive_tempered_loop(smc_kernel, self.key, init_state) assert state.parameter_override["inverse_mass_matrix"].shape == (1, 2, 2) self.assert_linear_regression_test_case(state.sampler_state) @@ -336,10 +325,12 @@ def test_with_tempered_smc(self): loglikelihood_fn, ) = self.particles_prior_loglikelihood() - def parameter_update(state, info): + def parameter_update(key, state, info): return extend_params( { - "inverse_mass_matrix": mass_matrix_from_particles(state.particles), + "inverse_mass_matrix": inverse_mass_matrix_from_particles( + state.particles + ), "step_size": 10e-2, "num_integration_steps": 50, }, @@ -393,5 +384,128 @@ def test_particles_as_rows(self): np.testing.assert_array_equal(np.arange(3 * 5 + 2), flatten_particles[0]) +def adaptive_tempered_loop(kernel, rng_key, initial_state): + def cond(carry): + _, state = carry + return state.sampler_state.lmbda < 1 + + def body(carry): + i, state = carry + subkey = jax.random.fold_in(rng_key, i) + state, _ = kernel(subkey, state) + return i + 1, state + + return jax.lax.while_loop(cond, body, (0, initial_state)) + + +class MultipleTuningTest(SMCLinearRegressionTestCase): + def setUp(self): + super().setUp() + self.key = jax.random.key(42) + + @chex.all_variants(with_pmap=False) + def test_tuning_pretuning(self): + """ + Tests that we can apply tuning on some parameters + and pretuning in some others at the same time. + """ + + ( + init_particles, + logprior_fn, + loglikelihood_fn, + ) = self.particles_prior_loglikelihood() + + n_particles = 100 + dimentions = 2 + + step_size_key, integration_steps_key = jax.random.split(self.key, 2) + + # Set initial samples for integration steps and step sizes. + integration_steps_distribution = jnp.round( + jax.random.uniform( + integration_steps_key, (n_particles,), minval=1, maxval=50 + ) + ).astype(int) + + step_sizes_distribution = jax.random.uniform( + step_size_key, (n_particles,), minval=1e-1 / 2, maxval=1e-1 * 2 + ) + + # Fixes inverse_mass_matrix and distribution for the other two parameters. + initial_parameters = dict( + inverse_mass_matrix=extend_params(jnp.eye(dimentions)), + step_size=step_sizes_distribution, + num_integration_steps=integration_steps_distribution, + ) + + pretune = build_pretune( + blackjax.hmc.init, + blackjax.hmc.build_kernel(), + alpha=2, + n_particles=n_particles, + sigma_parameters={ + "step_size": jnp.array(0.1), + "num_integration_steps": jnp.array(2.0), + }, + natural_parameters=["num_integration_steps"], + positive_parameters=["step_size"], + ) + + def pretuning_factory( + logprior_fn, + loglikelihood_fn, + mcmc_step_fn, + mcmc_init_fn, + mcmc_parameters, + resampling_fn, + num_mcmc_steps, + initial_parameter_value, + target_ess, + ): + # we need to wrap the pretuning into a factory, which is what + # the inner_kernel_tuning expects + return blackjax.pretuning( + blackjax.adaptive_tempered_smc, + logprior_fn, + loglikelihood_fn, + mcmc_step_fn, + mcmc_init_fn, + resampling_fn, + num_mcmc_steps, + initial_parameter_value, + pretune, + target_ess=target_ess, + ) + + def mcmc_parameter_update_fn(key, state, info): + imm = inverse_mass_matrix_from_particles(state.sampler_state.particles) + return {"inverse_mass_matrix": extend_params(imm)} + + step = blackjax.smc.inner_kernel_tuning.build_kernel( + pretuning_factory, + logprior_fn, + loglikelihood_fn, + blackjax.hmc.build_kernel(), + blackjax.hmc.init, + resampling.systematic, + mcmc_parameter_update_fn=mcmc_parameter_update_fn, + initial_parameter_value=initial_parameters, + num_mcmc_steps=10, + target_ess=0.5, + smc_returns_state_with_parameter_override=True, + ) + + def init(position): + return blackjax.smc.inner_kernel_tuning.init( + blackjax.adaptive_tempered_smc.init, position, initial_parameters + ) + + init_state = init(init_particles) + smc_kernel = self.variant(step) + _, state = adaptive_tempered_loop(smc_kernel, self.key, init_state) + self.assert_linear_regression_test_case(state.sampler_state) + + if __name__ == "__main__": absltest.main() From 060c99a4278c7a76d35e509465dc00acf8e11ed5 Mon Sep 17 00:00:00 2001 From: Reuben Date: Sat, 26 Apr 2025 11:46:02 -0400 Subject: [PATCH 09/14] Energy error monitoring (#784) * energy error monitoring * energy error monitoring * jnp abs --- blackjax/mcmc/mclmc.py | 48 +++++++++++++++++++++++++++++++++++------- 1 file changed, 40 insertions(+), 8 deletions(-) diff --git a/blackjax/mcmc/mclmc.py b/blackjax/mcmc/mclmc.py index ff9638a1f..d4a235770 100644 --- a/blackjax/mcmc/mclmc.py +++ b/blackjax/mcmc/mclmc.py @@ -15,6 +15,7 @@ from typing import Callable, NamedTuple import jax +import jax.numpy as jnp from blackjax.base import SamplingAlgorithm from blackjax.mcmc.integrators import ( @@ -60,7 +61,13 @@ def init(position: ArrayLike, logdensity_fn, rng_key): ) -def build_kernel(logdensity_fn, inverse_mass_matrix, integrator): +def build_kernel( + logdensity_fn, + inverse_mass_matrix, + integrator, + desired_energy_var_max_ratio=jnp.inf, + desired_energy_var=5e-4, +): """Build a HMC kernel. Parameters @@ -91,14 +98,33 @@ def kernel( state, step_size, L, rng_key ) - return IntegratorState( - position, momentum, logdensity, logdensitygrad - ), MCLMCInfo( - logdensity=logdensity, - energy_change=kinetic_change - logdensity + state.logdensity, - kinetic_change=kinetic_change, + energy_error = kinetic_change - logdensity + state.logdensity + + eev_max_per_dim = desired_energy_var_max_ratio * desired_energy_var + ndims = pytree_size(position) + + new_state, new_info = jax.lax.cond( + jnp.abs(energy_error) > jnp.sqrt(ndims * eev_max_per_dim), + lambda: ( + state, + MCLMCInfo( + logdensity=state.logdensity, + energy_change=0.0, + kinetic_change=0.0, + ), + ), + lambda: ( + IntegratorState(position, momentum, logdensity, logdensitygrad), + MCLMCInfo( + logdensity=logdensity, + energy_change=energy_error, + kinetic_change=kinetic_change, + ), + ), ) + return new_state, new_info + return kernel @@ -108,6 +134,7 @@ def as_top_level_api( step_size, integrator=isokinetic_mclachlan, inverse_mass_matrix=1.0, + desired_energy_var_max_ratio=jnp.inf, ) -> SamplingAlgorithm: """The general mclmc kernel builder (:meth:`blackjax.mcmc.mclmc.build_kernel`, alias `blackjax.mclmc.build_kernel`) can be cumbersome to manipulate. Since most users only need to specify the kernel @@ -155,7 +182,12 @@ def as_top_level_api( A ``SamplingAlgorithm``. """ - kernel = build_kernel(logdensity_fn, inverse_mass_matrix, integrator) + kernel = build_kernel( + logdensity_fn, + inverse_mass_matrix, + integrator, + desired_energy_var_max_ratio=desired_energy_var_max_ratio, + ) def init_fn(position: ArrayLike, rng_key: PRNGKey): return init(position, logdensity_fn, rng_key) From 56df032f43f9a4a000ea97f34a2f5f0529e6dc1b Mon Sep 17 00:00:00 2001 From: Junpeng Lao Date: Sun, 27 Apr 2025 10:20:36 +0200 Subject: [PATCH 10/14] ping Jaxopt version to unbreak test (#789) * ping Jaxopt version to unbreak test * lower jaxopt version --- pyproject.toml | 2 +- requirements-doc.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 0739361e2..cbd2cefd2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,7 +34,7 @@ dependencies = [ "fastprogress>=1.0.0", "jax>=0.4.16", "jaxlib>=0.4.16", - "jaxopt>=0.8", + "jaxopt<=0.8.3", "optax>=0.1.7", "typing-extensions>=4.4.0", ] diff --git a/requirements-doc.txt b/requirements-doc.txt index 83af1ffe3..fe8089cf4 100644 --- a/requirements-doc.txt +++ b/requirements-doc.txt @@ -6,7 +6,7 @@ flax ipython jax>=0.4.25 jaxlib>=0.4.25 -jaxopt +jaxopt<=0.8.3 jupytext myst_nb>=1.0.0 numba From 4326d576c7464e94cb248fb9fd2a582576fee72d Mon Sep 17 00:00:00 2001 From: AdamOrmondroyd Date: Tue, 3 Jun 2025 10:01:50 +0100 Subject: [PATCH 11/14] Guarantee covariance matrix is 2d --- blackjax/ns/nss.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/blackjax/ns/nss.py b/blackjax/ns/nss.py index 3dcf658a5..aebd69d0f 100644 --- a/blackjax/ns/nss.py +++ b/blackjax/ns/nss.py @@ -133,7 +133,7 @@ def compute_covariance_from_particles( `cov_pytree` is equivalent to `jax.vmap(unravel_fn)(M_flat)`. This means each leaf of `cov_pytree` will have a shape `(D, *leaf_original_dims)`. """ - cov_matrix = particles_covariance_matrix(state.particles) + cov_matrix = jnp.atleast_2d(particles_covariance_matrix(state.particles)) single_particle = get_first_row(state.particles) _, unravel_fn = ravel_pytree(single_particle) cov_pytree = jax.vmap(unravel_fn)(cov_matrix) From 35d6213250b395edc8d6472cdd3dc1c2aad552a6 Mon Sep 17 00:00:00 2001 From: Will Handley Date: Sat, 14 Jun 2025 07:16:28 +0100 Subject: [PATCH 12/14] Add direction vector to SliceInfo for enhanced slice sampling diagnostics MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude --- blackjax/mcmc/ss.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/blackjax/mcmc/ss.py b/blackjax/mcmc/ss.py index 857d7950b..e0191bb0e 100644 --- a/blackjax/mcmc/ss.py +++ b/blackjax/mcmc/ss.py @@ -84,6 +84,8 @@ class SliceInfo(NamedTuple): acceptable sample. evals The total number of log-density evaluations performed during the step. + d + The direction vector used for the slice sampling step. """ constraint: Array = jnp.array([]) @@ -91,6 +93,7 @@ class SliceInfo(NamedTuple): r_steps: int = 0 s_steps: int = 0 evals: int = 0 + d: ArrayTree = None def init(position: ArrayTree, logdensity_fn: Callable) -> SliceState: @@ -167,6 +170,7 @@ def kernel( r_steps=hs_info.r_steps, s_steps=hs_info.s_steps, evals=vs_info.evals + hs_info.evals, + d=d, ) return new_state, info From 63446c1ce19504358798ae96333946de90a26029 Mon Sep 17 00:00:00 2001 From: Will Handley Date: Sat, 14 Jun 2025 13:27:01 +0100 Subject: [PATCH 13/14] Add test files for nested sampling and slice sampling MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add comprehensive test suite for nested sampling functionality - Add unit tests for slice sampling implementation - Tests cover core algorithms and edge cases 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude --- tests/mcmc/test_slice_sampling.py | 646 +++++++++++----- tests/mcmc/test_slice_sampling_units.py | 404 ++++++++++ tests/ns/test_nested_sampling.py | 950 ++++++++++-------------- tests/ns/test_nested_sampling_units.py | 829 +++++++++++++++++++++ 4 files changed, 2099 insertions(+), 730 deletions(-) create mode 100644 tests/mcmc/test_slice_sampling_units.py create mode 100644 tests/ns/test_nested_sampling_units.py diff --git a/tests/mcmc/test_slice_sampling.py b/tests/mcmc/test_slice_sampling.py index 54c2a721f..ab36ed9b6 100644 --- a/tests/mcmc/test_slice_sampling.py +++ b/tests/mcmc/test_slice_sampling.py @@ -1,4 +1,6 @@ -"""Test the Slice Sampling algorithm""" +"""Test the Hit-and-Run Slice Sampling algorithm.""" +import functools + import chex import jax import jax.numpy as jnp @@ -9,228 +11,532 @@ from blackjax.mcmc import ss -def logdensity_fn(x): - """Standard normal density""" - return stats.norm.logpdf(x).sum() - - -def multimodal_logdensity(x): - """Mixture of two Gaussians""" - mode1 = stats.norm.logpdf(x - 2.0) - mode2 = stats.norm.logpdf(x + 2.0) - return jnp.logaddexp(mode1, mode2).sum() - +class SliceSamplingCoreTest(chex.TestCase): + """Test core slice sampling functionality.""" -def constrained_logdensity(x): - """Truncated normal (x > 0)""" - return jnp.where(x > 0, stats.norm.logpdf(x), -jnp.inf).sum() - - -class SliceSamplingTest(chex.TestCase): def setUp(self): super().setUp() self.key = jax.random.key(42) - def test_slice_init(self): - """Test slice sampler initialization""" - position = jnp.array([1.0, 2.0]) - state = ss.init(position, logdensity_fn) - + def logdensity_normal(self, x): + """Standard multivariate normal.""" + return stats.norm.logpdf(x).sum() + + def logdensity_constrained(self, x): + """Constrained to positive values.""" + return jnp.where(jnp.all(x > 0), stats.norm.logpdf(x).sum(), -jnp.inf) + + def test_slice_state_structure(self): + """Test SliceState structure and initialization.""" + position = jnp.array([1.0, -0.5, 2.0]) + state = ss.init(position, self.logdensity_normal) + + # Check structure + self.assertIsInstance(state, ss.SliceState) chex.assert_trees_all_close(state.position, position) - expected_logdensity = logdensity_fn(position) - chex.assert_trees_all_close(state.logdensity, expected_logdensity) + + # Check logdensity is computed correctly + expected_logdens = self.logdensity_normal(position) + chex.assert_trees_all_close(state.logdensity, expected_logdens) + + # Check default logslice + self.assertEqual(state.logslice, jnp.inf) + + def test_slice_info_structure(self): + """Test SliceInfo structure.""" + info = ss.SliceInfo( + constraint=jnp.array([1.0, 2.0]), + l_steps=3, + r_steps=5, + s_steps=7, + evals=15, + d=jnp.array([0.5, -0.2]) + ) + + chex.assert_shape(info.constraint, (2,)) + self.assertEqual(info.l_steps, 3) + self.assertEqual(info.r_steps, 5) + self.assertEqual(info.s_steps, 7) + self.assertEqual(info.evals, 15) + chex.assert_shape(info.d, (2,)) def test_vertical_slice(self): - """Test vertical slice height sampling""" - key = jax.random.key(123) + """Test vertical slice height sampling.""" position = jnp.array([0.0]) - state = ss.init(position, logdensity_fn) - - # Sample many slice heights - keys = jax.random.split(key, 1000) - new_state, info = jax.vmap(ss.vertical_slice, in_axes=(0, None))(keys, state) - - # Heights should be below log density at position - logdens_at_pos = logdensity_fn(position) - self.assertTrue(jnp.all(new_state.logslice <= logdens_at_pos)) - - # Heights should be reasonably distributed - mean_height = jnp.mean(new_state.logslice) - expected_mean = logdens_at_pos - 1.0 # E[log(U)] = -1 for U~Uniform(0,1) + state = ss.init(position, self.logdensity_normal) + + # Test multiple samples + n_samples = 1000 + keys = jax.random.split(self.key, n_samples) + + new_states, infos = jax.vmap(ss.vertical_slice, in_axes=(0, None))(keys, state) + + # Heights should be below current logdensity + logdens_at_pos = self.logdensity_normal(position) + self.assertTrue(jnp.all(new_states.logslice <= logdens_at_pos)) + + # Mean should be approximately logdens - 1 (E[log(U)] = -1) + mean_height = jnp.mean(new_states.logslice) + expected_mean = logdens_at_pos - 1.0 chex.assert_trees_all_close(mean_height, expected_mean, atol=0.1) + + # Check info structure + self.assertTrue(jnp.all(infos.evals == 0)) # Vertical slice doesn't eval logdensity @parameterized.parameters([1, 2, 5]) def test_slice_sampling_dimensions(self, ndim): - """Test slice sampling in different dimensions""" - key = jax.random.key(456) + """Test slice sampling in different dimensions.""" position = jnp.zeros(ndim) - - # Simple step function - def stepper_fn(x, d, t): - return x + t * d - - # Build kernel - def direction_fn(rng_key): - return jax.random.normal(rng_key, (ndim,)) - - kernel = ss.build_hrss_kernel(direction_fn, stepper_fn) - state = ss.init(position, logdensity_fn) - - # Take one step - new_state, info = kernel(key, state, logdensity_fn) - + state = ss.init(position, self.logdensity_normal) + + # Test with simple direction and stepper + direction = jax.random.normal(self.key, (ndim,)) + direction = direction / jnp.linalg.norm(direction) + + kernel = ss.build_kernel(ss.default_stepper_fn) + + def dummy_constraint_fn(x): + return jnp.array([]) + + new_state, info = kernel( + self.key, state, self.logdensity_normal, direction, + dummy_constraint_fn, jnp.array([]), jnp.array([]) + ) + chex.assert_shape(new_state.position, (ndim,)) self.assertIsInstance(new_state.logdensity, (float, jax.Array)) + self.assertIsInstance(info, ss.SliceInfo) - def test_constrained_slice_sampling(self): - """Test slice sampling with constraints""" - key = jax.random.key(789) - position = jnp.array([1.0]) # Start in valid region + def test_1d_slice_sampling(self): + """Test 1D slice sampling (edge case for JAX shapes).""" + position = jnp.array(0.5) # 1D scalar + state = ss.init(position, lambda x: -0.5 * x**2) + + direction = jnp.array(1.0) # 1D direction + kernel = ss.build_kernel(ss.default_stepper_fn) + + def dummy_constraint_fn(x): + return jnp.array([]) + + new_state, info = kernel( + self.key, state, lambda x: -0.5 * x**2, direction, + dummy_constraint_fn, jnp.array([]), jnp.array([]) + ) + + # Check it runs without shape errors + self.assertIsInstance(new_state.logdensity, (float, jax.Array)) + self.assertIsInstance(info.evals, (int, jax.Array)) - def stepper_fn(x, d, t): - return x + t * d + def test_default_stepper_fn(self): + """Test default stepper function.""" + x = jnp.array([1.0, 2.0, -1.5]) + d = jnp.array([0.5, -0.3, 0.8]) + t = 2.5 + + result = ss.default_stepper_fn(x, d, t) + expected = x + t * d + + chex.assert_trees_all_close(result, expected) - kernel = ss.build_kernel(stepper_fn) - state = ss.init(position, constrained_logdensity) + def test_stepper_fn_with_pytrees(self): + """Test stepper function with PyTree inputs.""" + x = {"a": jnp.array([1.0, 2.0]), "b": jnp.array([-0.5])} + d = {"a": jnp.array([0.3, -0.2]), "b": jnp.array([0.7])} + t = 1.5 + + result = ss.default_stepper_fn(x, d, t) + + chex.assert_trees_all_close(result["a"], x["a"] + t * d["a"]) + chex.assert_trees_all_close(result["b"], x["b"] + t * d["b"]) - # Direction pointing outward - direction = jnp.array([1.0]) - # Constraint function - def constraint_fn(x): - return jnp.array([]) # No additional constraints for this test +class SliceSamplingConstraintsTest(chex.TestCase): + """Test slice sampling with constraints.""" + + def setUp(self): + super().setUp() + self.key = jax.random.key(123) + def test_constrained_sampling(self): + """Test slice sampling respects constraints.""" + # Start in valid region (x > 0) + position = jnp.array([1.0, 2.0]) + + def constrained_logdens(x): + return jnp.where(jnp.all(x > 0), -0.5 * jnp.sum(x**2), -jnp.inf) + + state = ss.init(position, constrained_logdens) + direction = jnp.array([1.0, -0.5]) # Could lead outside valid region + + kernel = ss.build_kernel(ss.default_stepper_fn) + + # Test with constraint function + def constraint_fn(x): + return x # Return position values to check > 0 + + constraint_thresholds = jnp.array([0.0, 0.0]) # Must be > 0 + strict_flags = jnp.array([True, True]) # Strict inequality + new_state, info = kernel( - key, - state, - constrained_logdensity, - direction, - constraint_fn, - jnp.array([]), - jnp.array([]), + self.key, state, constrained_logdens, direction, + constraint_fn, constraint_thresholds, strict_flags ) - + # Should remain in valid region self.assertTrue(jnp.all(new_state.position > 0)) + self.assertFalse(jnp.isneginf(new_state.logdensity)) + + def test_constraint_evaluation_ordering(self): + """Test that constraints are evaluated correctly.""" + position = jnp.array([0.5]) + + def logdens(x): + return -0.5 * x**2 + + state = ss.init(position, logdens) + direction = jnp.array([1.0]) + + kernel = ss.build_kernel(ss.default_stepper_fn) + + # Constraint that evaluates a simple function + def constraint_fn(x): + return jnp.array([x[0]**2]) # Square of position + + constraint_threshold = jnp.array([0.25]) # x^2 > 0.25, so |x| > 0.5 + strict_flag = jnp.array([True]) + + new_state, info = kernel( + self.key, state, logdens, direction, + constraint_fn, constraint_threshold, strict_flag + ) + + # Check constraint is satisfied + constraint_val = constraint_fn(new_state.position) + self.assertTrue(jnp.all(constraint_val > constraint_threshold)) + + def test_multiple_constraints(self): + """Test multiple constraints simultaneously.""" + position = jnp.array([1.0, 1.5]) + + def logdens(x): + return -0.5 * jnp.sum(x**2) + + state = ss.init(position, logdens) + direction = jnp.array([0.7, -0.3]) + + kernel = ss.build_kernel(ss.default_stepper_fn) + + def constraint_fn(x): + return jnp.array([x[0], x[1], jnp.sum(x)]) # Multiple constraints + + constraints = jnp.array([0.2, 0.1, 1.0]) # x[0] > 0.2, x[1] > 0.1, sum > 1.0 + strict = jnp.array([True, True, False]) # Mixed strict/non-strict + + new_state, info = kernel( + self.key, state, logdens, direction, + constraint_fn, constraints, strict + ) + + # Check all constraints are satisfied + constraint_vals = constraint_fn(new_state.position) + self.assertTrue(constraint_vals[0] > constraints[0]) # Strict + self.assertTrue(constraint_vals[1] > constraints[1]) # Strict + self.assertTrue(constraint_vals[2] >= constraints[2]) # Non-strict - def test_default_direction_generation(self): - """Test default direction generation function""" - key = jax.random.key(101112) - cov = jnp.eye(3) * 2.0 - - direction = ss.sample_direction_from_covariance(key, cov) - chex.assert_shape(direction, (3,)) +class HitAndRunSliceSamplingTest(chex.TestCase): + """Test Hit-and-Run Slice Sampling functionality.""" - # Direction should be normalized in Mahalanobis sense + def setUp(self): + super().setUp() + self.key = jax.random.key(456) + + def logdensity_normal(self, x): + return stats.norm.logpdf(x).sum() + + def test_direction_generation_from_covariance(self): + """Test direction generation from covariance matrix.""" + ndim = 3 + cov = jnp.array([[2.0, 0.5, 0.0], + [0.5, 1.5, -0.3], + [0.0, -0.3, 1.0]]) + + direction = ss.sample_direction_from_covariance(self.key, cov) + + chex.assert_shape(direction, (ndim,)) + + # Check Mahalanobis normalization invcov = jnp.linalg.inv(cov) mahal_norm = jnp.sqrt(jnp.einsum("i,ij,j", direction, invcov, direction)) chex.assert_trees_all_close(mahal_norm, 1.0, atol=1e-6) - def test_hrss_top_level_api(self): - """Test hit-and-run slice sampling top-level API""" - cov = jnp.eye(2) - algorithm = ss.hrss_as_top_level_api(logdensity_fn, cov) + def test_direction_generation_identity_covariance(self): + """Test direction generation with identity covariance.""" + ndim = 4 + cov = jnp.eye(ndim) + + direction = ss.sample_direction_from_covariance(self.key, cov) + + chex.assert_shape(direction, (ndim,)) + + # With identity covariance, should be unit normalized + euclidean_norm = jnp.linalg.norm(direction) + chex.assert_trees_all_close(euclidean_norm, 1.0, atol=1e-6) + + def test_hrss_kernel_construction(self): + """Test HRSS kernel construction.""" + def direction_fn(rng_key): + return jax.random.normal(rng_key, (2,)) + + kernel = ss.build_hrss_kernel(direction_fn, ss.default_stepper_fn) + + self.assertTrue(callable(kernel)) + + # Test kernel execution + position = jnp.array([0.0, 1.0]) + state = ss.init(position, self.logdensity_normal) + + new_state, info = kernel(self.key, state, self.logdensity_normal) + + chex.assert_shape(new_state.position, (2,)) + self.assertIsInstance(info, ss.SliceInfo) - # Check it returns a SamplingAlgorithm + def test_hrss_top_level_api(self): + """Test hit-and-run slice sampling top-level API.""" + ndim = 2 + cov = jnp.eye(ndim) * 1.5 + + algorithm = ss.hrss_as_top_level_api(self.logdensity_normal, cov) + + # Check it's a proper SamplingAlgorithm self.assertIsInstance(algorithm, blackjax.base.SamplingAlgorithm) - - # Test init and step functions - position = jnp.array([0.0, 0.0]) + self.assertTrue(hasattr(algorithm, "init")) + self.assertTrue(hasattr(algorithm, "step")) + + # Test initialization + position = jnp.array([1.0, -0.5]) state = algorithm.init(position) + + self.assertIsInstance(state, ss.SliceState) + chex.assert_trees_all_close(state.position, position) + + # Test step + new_state, info = algorithm.step(self.key, state) + + chex.assert_shape(new_state.position, (ndim,)) + self.assertIsInstance(info, ss.SliceInfo) - key = jax.random.key(123) - new_state, info = algorithm.step(key, state) + def test_hrss_1d_case(self): + """Test HRSS with 1D problem.""" + cov = jnp.array([[1.0]]) # 1x1 covariance matrix + + def logdens_1d(x): + return -0.5 * x**2 + + algorithm = ss.hrss_as_top_level_api(logdens_1d, cov) + + position = jnp.array([0.5]) + state = algorithm.init(position) + + new_state, info = algorithm.step(self.key, state) + + chex.assert_shape(new_state.position, (1,)) + self.assertIsInstance(new_state.logdensity, (float, jax.Array)) - chex.assert_shape(new_state.position, (2,)) - def test_slice_sampling_statistical_correctness(self): - """Test that slice sampling produces correct statistics""" - n_samples = 100 # Reduced significantly for faster testing - key = jax.random.key(42) +class SliceSamplingStatisticalTest(chex.TestCase): + """Statistical correctness tests for slice sampling.""" - # Use HRSS for sampling from standard normal + def setUp(self): + super().setUp() + self.key = jax.random.key(789) + + def test_slice_sampling_mean_estimation(self): + """Test that HRSS correctly estimates mean of target distribution.""" + # Target: standard normal, should have mean ≈ 0 + def logdens(x): + return stats.norm.logpdf(x).sum() + cov = jnp.eye(1) - algorithm = ss.hrss_as_top_level_api(logdensity_fn, cov) - - # Run inference - initial_position = jnp.array([0.0]) - initial_state = algorithm.init(initial_position) - - # Simple sampling loop with progress tracking + algorithm = ss.hrss_as_top_level_api(logdens, cov) + + # Run short chain + n_samples = 200 # Modest for testing + position = jnp.array([0.0]) + state = algorithm.init(position) + samples = [] - state = initial_state - keys = jax.random.split(key, n_samples) - + keys = jax.random.split(self.key, n_samples) + for i, sample_key in enumerate(keys): state, info = algorithm.step(sample_key, state) - samples.append(state.position) - # Early exit if we get stuck - if i > 0 and jnp.isnan(state.position).any(): - break - - if len(samples) < 10: # If we got very few samples, skip statistical test - self.skipTest("Not enough samples generated") - + if i >= 50: # Skip some burn-in + samples.append(state.position[0]) + samples = jnp.array(samples) - - # Check basic properties - self.assertFalse(jnp.isnan(samples).any(), "Samples contain NaN") - self.assertFalse(jnp.isinf(samples).any(), "Samples contain Inf") - - # Very loose statistical checks for small sample size + + # Basic sanity checks + self.assertFalse(jnp.any(jnp.isnan(samples))) + self.assertFalse(jnp.any(jnp.isinf(samples))) + + # Statistical checks (very loose for small sample size) sample_mean = jnp.mean(samples) sample_std = jnp.std(samples) + + # Mean should be reasonable + self.assertLess(jnp.abs(sample_mean), 0.5) # Loose bound + + # Standard deviation should be positive and reasonable + self.assertGreater(sample_std, 0.1) + self.assertLess(sample_std, 3.0) + + def test_slice_sampling_multimodal(self): + """Test slice sampling on multimodal distribution.""" + def logdens_bimodal(x): + # Mixture of two Gaussians at -2 and +2 + mode1 = stats.norm.logpdf(x - 2.0) + mode2 = stats.norm.logpdf(x + 2.0) + return jnp.logaddexp(mode1, mode2).sum() + + cov = jnp.eye(1) * 4.0 # Wider proposals for multimodal + algorithm = ss.hrss_as_top_level_api(logdens_bimodal, cov) + + # Run chain + n_samples = 100 + position = jnp.array([1.0]) # Start near one mode + state = algorithm.init(position) + + samples = [] + keys = jax.random.split(self.key, n_samples) + + for sample_key in keys: + state, info = algorithm.step(sample_key, state) + samples.append(state.position[0]) + + samples = jnp.array(samples) + + # Check basic properties + self.assertFalse(jnp.any(jnp.isnan(samples))) + sample_range = jnp.max(samples) - jnp.min(samples) + self.assertGreater(sample_range, 1.0) # Should explore reasonable range + + def test_slice_info_diagnostics(self): + """Test that slice info provides useful diagnostics.""" + def logdens(x): + return -0.5 * jnp.sum(x**2) + + cov = jnp.eye(2) + algorithm = ss.hrss_as_top_level_api(logdens, cov) + + position = jnp.array([0.0, 0.0]) + state = algorithm.init(position) + + # Collect diagnostics from multiple steps + infos = [] + keys = jax.random.split(self.key, 20) + + for sample_key in keys: + state, info = algorithm.step(sample_key, state) + infos.append(info) + + # Check diagnostic fields + l_steps = jnp.array([info.l_steps for info in infos]) + r_steps = jnp.array([info.r_steps for info in infos]) + s_steps = jnp.array([info.s_steps for info in infos]) + evals = jnp.array([info.evals for info in infos]) + + # All should be non-negative + self.assertTrue(jnp.all(l_steps >= 0)) + self.assertTrue(jnp.all(r_steps >= 0)) + self.assertTrue(jnp.all(s_steps >= 0)) + self.assertTrue(jnp.all(evals >= 0)) + + # Total evaluations should be sum of expansion + shrinking + expected_evals = l_steps + r_steps + s_steps + chex.assert_trees_all_close(evals, expected_evals) + + # Direction vectors should be present + directions = jnp.array([info.d for info in infos]) + chex.assert_shape(directions, (20, 2)) + self.assertFalse(jnp.any(jnp.isnan(directions))) + + +class SliceSamplingEdgeCasesTest(chex.TestCase): + """Test edge cases and robustness.""" - # Just check that mean is reasonable and std is positive - self.assertLess(abs(sample_mean), 2.0, "Mean is unreasonably far from 0") - self.assertGreater(sample_std, 0.1, "Standard deviation is too small") - self.assertLess(sample_std, 5.0, "Standard deviation is too large") - - def test_default_stepper_fn(self): - """Test default stepper function""" - x = jnp.array([1.0, 2.0]) - d = jnp.array([0.5, -0.5]) - t = 2.0 - - result = ss.default_stepper_fn(x, d, t) - expected = x + t * d - - chex.assert_trees_all_close(result, expected) - - def test_slice_info_structure(self): - """Test that SliceInfo contains expected fields""" - key = jax.random.key(789) + def setUp(self): + super().setUp() + self.key = jax.random.key(101112) + + def test_zero_covariance_matrix(self): + """Test behavior with singular covariance matrix.""" + # This should handle gracefully or raise informative error + cov = jnp.zeros((2, 2)) + + # JAX's linalg.inv will produce NaN/Inf for singular matrices + # rather than raising an error, so check for that + try: + direction = ss.sample_direction_from_covariance(self.key, cov) + # If it doesn't raise, check for NaN/Inf + self.assertTrue(jnp.isnan(direction).any() or jnp.isinf(direction).any()) + except (ValueError, RuntimeError): + # This is also acceptable behavior + pass + + def test_very_peaked_distribution(self): + """Test with very peaked/narrow distribution.""" + def logdens_peaked(x): + return -100.0 * jnp.sum(x**2) # Very narrow + + cov = jnp.eye(1) * 0.01 # Small proposals + algorithm = ss.hrss_as_top_level_api(logdens_peaked, cov) + + position = jnp.array([0.01]) + state = algorithm.init(position) + + # Should handle without numerical issues + new_state, info = algorithm.step(self.key, state) + + self.assertFalse(jnp.isnan(new_state.logdensity)) + self.assertFalse(jnp.isinf(new_state.logdensity)) + + def test_large_step_proposals(self): + """Test with very large step proposals.""" + def logdens(x): + return -0.5 * jnp.sum(x**2) + + cov = jnp.eye(1) * 100.0 # Very large proposals + algorithm = ss.hrss_as_top_level_api(logdens, cov) + position = jnp.array([0.0]) - - def stepper_fn(x, d, t): - return x + t * d - - kernel = ss.build_kernel(stepper_fn) - state = ss.init(position, logdensity_fn) + state = algorithm.init(position) + + # Should still work (though possibly inefficient) + new_state, info = algorithm.step(self.key, state) + + self.assertFalse(jnp.isnan(new_state.position).any()) + self.assertGreater(info.evals, 0) # Should do some work + + def test_empty_constraint_arrays(self): + """Test with empty constraint arrays.""" + position = jnp.array([1.0]) + state = ss.init(position, lambda x: -0.5 * x**2) direction = jnp.array([1.0]) - - def constraint_fn(x): + + kernel = ss.build_kernel(ss.default_stepper_fn) + + def empty_constraint_fn(x): return jnp.array([]) - + + # Should handle empty constraints gracefully new_state, info = kernel( - key, - state, - logdensity_fn, - direction, - constraint_fn, - jnp.array([]), - jnp.array([]), + self.key, state, lambda x: -0.5 * x**2, direction, + empty_constraint_fn, jnp.array([]), jnp.array([]) ) - - # Check that info has expected structure - self.assertIsInstance(info, ss.SliceInfo) - self.assertTrue(hasattr(info, "constraint")) - self.assertTrue(hasattr(info, "l_steps")) - self.assertTrue(hasattr(info, "r_steps")) - self.assertTrue(hasattr(info, "s_steps")) - self.assertTrue(hasattr(info, "evals")) + + self.assertIsInstance(new_state, ss.SliceState) + chex.assert_shape(info.constraint, (0,)) if __name__ == "__main__": - absltest.main() + absltest.main() \ No newline at end of file diff --git a/tests/mcmc/test_slice_sampling_units.py b/tests/mcmc/test_slice_sampling_units.py new file mode 100644 index 000000000..3c81e6cd5 --- /dev/null +++ b/tests/mcmc/test_slice_sampling_units.py @@ -0,0 +1,404 @@ +"""Unit tests for slice sampling components.""" +import chex +import jax +import jax.numpy as jnp +from absl.testing import absltest, parameterized + +from blackjax.mcmc import ss + + +class SliceStateTest(chex.TestCase): + """Test SliceState data structure.""" + + def test_slice_state_creation(self): + """Test SliceState creation and default values.""" + position = jnp.array([1.0, 2.0]) + logdensity = -3.5 + + # Test with default logslice + state = ss.SliceState(position, logdensity) + chex.assert_trees_all_close(state.position, position) + self.assertEqual(state.logdensity, logdensity) + self.assertEqual(state.logslice, jnp.inf) + + # Test with explicit logslice + logslice = -1.2 + state = ss.SliceState(position, logdensity, logslice) + self.assertEqual(state.logslice, logslice) + + def test_slice_state_replace(self): + """Test SliceState _replace method.""" + state = ss.SliceState(jnp.array([1.0]), -2.0, -5.0) + + new_state = state._replace(logslice=-3.0) + self.assertEqual(new_state.logslice, -3.0) + self.assertEqual(new_state.logdensity, -2.0) # Unchanged + chex.assert_trees_all_close(new_state.position, state.position) + + +class SliceInfoTest(chex.TestCase): + """Test SliceInfo data structure.""" + + def test_slice_info_creation(self): + """Test SliceInfo creation and default values.""" + # Test with defaults + info = ss.SliceInfo() + chex.assert_shape(info.constraint, (0,)) + self.assertEqual(info.l_steps, 0) + self.assertEqual(info.r_steps, 0) + self.assertEqual(info.s_steps, 0) + self.assertEqual(info.evals, 0) + self.assertIsNone(info.d) + + # Test with explicit values + constraint = jnp.array([1.0, 2.0]) + direction = jnp.array([0.5, -0.3]) + info = ss.SliceInfo( + constraint=constraint, + l_steps=3, r_steps=5, s_steps=7, evals=15, + d=direction + ) + chex.assert_trees_all_close(info.constraint, constraint) + self.assertEqual(info.l_steps, 3) + self.assertEqual(info.r_steps, 5) + self.assertEqual(info.s_steps, 7) + self.assertEqual(info.evals, 15) + chex.assert_trees_all_close(info.d, direction) + + +class InitFunctionTest(chex.TestCase): + """Test slice sampling initialization.""" + + def setUp(self): + super().setUp() + self.logdensity_fn = lambda x: -0.5 * jnp.sum(x**2) + + @parameterized.parameters([ + (jnp.array([0.0]),), + (jnp.array([1.5, -2.0]),), + (jnp.array([[1.0, 2.0], [3.0, 4.0]]),), + ]) + def test_init_shapes(self, position): + """Test init with different position shapes.""" + state = ss.init(position, self.logdensity_fn) + + chex.assert_trees_all_close(state.position, position) + expected_logdens = self.logdensity_fn(position) + chex.assert_trees_all_close(state.logdensity, expected_logdens) + self.assertEqual(state.logslice, jnp.inf) + + def test_init_with_pytree(self): + """Test init with PyTree position.""" + position = {"a": jnp.array([1.0, 2.0]), "b": jnp.array([3.0])} + + def logdens_pytree(x): + return -0.5 * (jnp.sum(x["a"]**2) + jnp.sum(x["b"]**2)) + + state = ss.init(position, logdens_pytree) + + chex.assert_trees_all_close(state.position, position) + expected_logdens = logdens_pytree(position) + self.assertEqual(state.logdensity, expected_logdens) + + +class VerticalSliceTest(chex.TestCase): + """Test vertical slice function.""" + + def setUp(self): + super().setUp() + self.key = jax.random.key(42) + + def test_vertical_slice_height_bounds(self): + """Test that slice height is always below current logdensity.""" + position = jnp.array([0.0]) + logdensity = -1.5 + state = ss.SliceState(position, logdensity) + + # Test multiple samples + keys = jax.random.split(self.key, 100) + new_states, infos = jax.vmap(ss.vertical_slice, in_axes=(0, None))(keys, state) + + # All slice heights should be <= logdensity + self.assertTrue(jnp.all(new_states.logslice <= logdensity)) + + # Info should have zero evaluations (vertical slice doesn't eval logdensity) + self.assertTrue(jnp.all(infos.evals == 0)) + + def test_vertical_slice_deterministic_bound(self): + """Test that slice height has correct statistical properties.""" + position = jnp.array([0.0]) + logdensity = -2.0 + state = ss.SliceState(position, logdensity) + + # Generate many samples + n_samples = 5000 + keys = jax.random.split(self.key, n_samples) + new_states, _ = jax.vmap(ss.vertical_slice, in_axes=(0, None))(keys, state) + + # Mean of log(U) where U ~ Uniform(0,1) is -1 + mean_height = jnp.mean(new_states.logslice) + expected_mean = logdensity - 1.0 + + # Should be close to expected mean (loose tolerance for finite sample) + self.assertAlmostEqual(mean_height, expected_mean, delta=0.1) + + def test_vertical_slice_preserves_position(self): + """Test that vertical slice preserves position and logdensity.""" + position = jnp.array([1.5, -0.5]) + logdensity = -3.2 + state = ss.SliceState(position, logdensity) + + new_state, info = ss.vertical_slice(self.key, state) + + chex.assert_trees_all_close(new_state.position, position) + self.assertEqual(new_state.logdensity, logdensity) + self.assertNotEqual(new_state.logslice, jnp.inf) # Should be updated + + +class StepperFunctionTest(chex.TestCase): + """Test stepper function.""" + + def test_default_stepper_array(self): + """Test default stepper with arrays.""" + x = jnp.array([1.0, 2.0]) + d = jnp.array([0.5, -0.3]) + t = 2.5 + + result = ss.default_stepper_fn(x, d, t) + expected = x + t * d + + chex.assert_trees_all_close(result, expected) + + def test_default_stepper_scalar(self): + """Test default stepper with scalars.""" + x = 3.0 + d = -1.2 + t = 0.8 + + result = ss.default_stepper_fn(x, d, t) + expected = x + t * d + + self.assertEqual(result, expected) + + def test_default_stepper_pytree(self): + """Test default stepper with PyTree.""" + x = {"a": jnp.array([1.0, 2.0]), "b": jnp.array([3.0])} + d = {"a": jnp.array([0.1, -0.2]), "b": jnp.array([0.5])} + t = 1.5 + + result = ss.default_stepper_fn(x, d, t) + + chex.assert_trees_all_close(result["a"], x["a"] + t * d["a"]) + chex.assert_trees_all_close(result["b"], x["b"] + t * d["b"]) + + def test_stepper_zero_step(self): + """Test stepper with zero step size.""" + x = jnp.array([1.0, 2.0, 3.0]) + d = jnp.array([10.0, -5.0, 2.0]) + t = 0.0 + + result = ss.default_stepper_fn(x, d, t) + chex.assert_trees_all_close(result, x) + + +class DirectionSamplingTest(chex.TestCase): + """Test direction sampling functions.""" + + def setUp(self): + super().setUp() + self.key = jax.random.key(123) + + def test_sample_direction_identity_covariance(self): + """Test direction sampling with identity covariance.""" + ndim = 3 + cov = jnp.eye(ndim) + + direction = ss.sample_direction_from_covariance(self.key, cov) + + chex.assert_shape(direction, (ndim,)) + + # With identity covariance, should be unit normalized + norm = jnp.linalg.norm(direction) + chex.assert_trees_all_close(norm, 1.0, atol=1e-6) + + def test_sample_direction_scaled_covariance(self): + """Test direction sampling with scaled covariance.""" + ndim = 2 + scale = 4.0 + cov = jnp.eye(ndim) * scale + + direction = ss.sample_direction_from_covariance(self.key, cov) + + chex.assert_shape(direction, (ndim,)) + + # Check Mahalanobis normalization + invcov = jnp.linalg.inv(cov) + mahal_norm = jnp.sqrt(jnp.einsum("i,ij,j", direction, invcov, direction)) + chex.assert_trees_all_close(mahal_norm, 1.0, atol=1e-6) + + def test_sample_direction_general_covariance(self): + """Test direction sampling with general covariance matrix.""" + cov = jnp.array([[2.0, 0.5], [0.5, 1.0]]) + + direction = ss.sample_direction_from_covariance(self.key, cov) + + chex.assert_shape(direction, (2,)) + + # Check Mahalanobis normalization + invcov = jnp.linalg.inv(cov) + mahal_norm = jnp.sqrt(jnp.einsum("i,ij,j", direction, invcov, direction)) + chex.assert_trees_all_close(mahal_norm, 1.0, atol=1e-6) + + def test_sample_direction_1d(self): + """Test direction sampling for 1D case.""" + cov = jnp.array([[2.0]]) + + direction = ss.sample_direction_from_covariance(self.key, cov) + + chex.assert_shape(direction, (1,)) + + # Check Mahalanobis normalization (should be 1) + invcov = jnp.linalg.inv(cov) + mahal_norm = jnp.sqrt(jnp.einsum("i,ij,j", direction, invcov, direction)) + chex.assert_trees_all_close(mahal_norm, 1.0, atol=1e-6) + + def test_sample_direction_multiple_samples(self): + """Test that multiple direction samples are different.""" + cov = jnp.eye(2) + keys = jax.random.split(self.key, 10) + + directions = jax.vmap(ss.sample_direction_from_covariance, in_axes=(0, None))(keys, cov) + + chex.assert_shape(directions, (10, 2)) + + # All should be unit normalized + norms = jnp.linalg.norm(directions, axis=1) + chex.assert_trees_all_close(norms, jnp.ones(10), atol=1e-6) + + # Should not all be the same + std_of_directions = jnp.std(directions, axis=0) + self.assertTrue(jnp.all(std_of_directions > 0.1)) # Some variation expected + + +class HorizontalSliceTest(chex.TestCase): + """Test horizontal slice function directly.""" + + def setUp(self): + super().setUp() + self.key = jax.random.key(456) + + def test_horizontal_slice_basic(self): + """Test horizontal slice basic functionality.""" + position = jnp.array([0.5]) + logdensity = -0.5 * position**2 + logslice = -2.0 + state = ss.SliceState(position, logdensity, logslice) + + direction = jnp.array([1.0]) + + def logdens_fn(x): + return -0.5 * x**2 + + def constraint_fn(x): + return jnp.array([]) + + new_state, info = ss.horizontal_slice( + self.key, state, direction, ss.default_stepper_fn, + logdens_fn, constraint_fn, jnp.array([]), jnp.array([]) + ) + + self.assertIsInstance(new_state, ss.SliceState) + self.assertIsInstance(info, ss.SliceInfo) + self.assertGreater(info.evals, 0) # Should have done some evaluations + + def test_horizontal_slice_with_constraints(self): + """Test horizontal slice with constraints.""" + position = jnp.array([1.0]) + state = ss.SliceState(position, -0.5, -1.0) + direction = jnp.array([1.0]) + + def logdens_fn(x): + return -0.5 * x**2 + + def constraint_fn(x): + return jnp.array([x[0]]) # Must be positive + + constraint_thresholds = jnp.array([0.0]) + strict_flags = jnp.array([True]) + + new_state, info = ss.horizontal_slice( + self.key, state, direction, ss.default_stepper_fn, + logdens_fn, constraint_fn, constraint_thresholds, strict_flags + ) + + # Should satisfy constraints + self.assertTrue(jnp.all(new_state.position > 0)) + self.assertGreater(info.l_steps + info.r_steps + info.s_steps, 0) + + def test_horizontal_slice_info_completeness(self): + """Test that horizontal slice returns complete info.""" + position = jnp.array([0.0]) + state = ss.SliceState(position, 0.0, -1.0) + direction = jnp.array([1.0]) + + def logdens_fn(x): + return -x**2 + + def constraint_fn(x): + return jnp.array([x[0]**2]) + + new_state, info = ss.horizontal_slice( + self.key, state, direction, ss.default_stepper_fn, + logdens_fn, constraint_fn, jnp.array([0.1]), jnp.array([True]) + ) + + # Check all info fields are populated + self.assertIsInstance(info.l_steps, (int, jax.Array)) + self.assertIsInstance(info.r_steps, (int, jax.Array)) + self.assertIsInstance(info.s_steps, (int, jax.Array)) + self.assertIsInstance(info.evals, (int, jax.Array)) + chex.assert_shape(info.constraint, (1,)) + + # Total evaluations should equal sum of steps + self.assertEqual(info.evals, info.l_steps + info.r_steps + info.s_steps) + + +class KernelBuildingTest(chex.TestCase): + """Test kernel building functions.""" + + def test_build_kernel_callable(self): + """Test that build_kernel returns a callable.""" + def simple_stepper(x, d, t): + return x + t * d + + kernel = ss.build_kernel(simple_stepper) + self.assertTrue(callable(kernel)) + + def test_build_hrss_kernel_callable(self): + """Test that build_hrss_kernel returns a callable.""" + def direction_fn(rng_key): + return jax.random.normal(rng_key, (2,)) + + def simple_stepper(x, d, t): + return x + t * d + + kernel = ss.build_hrss_kernel(direction_fn, simple_stepper) + self.assertTrue(callable(kernel)) + + def test_hrss_top_level_api_structure(self): + """Test top-level API returns correct structure.""" + def simple_logdens(x): + return -0.5 * jnp.sum(x**2) + + cov = jnp.eye(2) + algorithm = ss.hrss_as_top_level_api(simple_logdens, cov) + + # Should have init and step methods + self.assertTrue(hasattr(algorithm, "init")) + self.assertTrue(hasattr(algorithm, "step")) + self.assertTrue(callable(algorithm.init)) + self.assertTrue(callable(algorithm.step)) + + +if __name__ == "__main__": + absltest.main() \ No newline at end of file diff --git a/tests/ns/test_nested_sampling.py b/tests/ns/test_nested_sampling.py index 4280a9d8c..c24831121 100644 --- a/tests/ns/test_nested_sampling.py +++ b/tests/ns/test_nested_sampling.py @@ -1,4 +1,4 @@ -"""Test the Nested Sampling algorithms""" +"""Test the Nested Sampling algorithms.""" import functools import chex @@ -7,663 +7,493 @@ import jax.scipy.stats as stats from absl.testing import absltest, parameterized +import blackjax from blackjax.ns import adaptive, base, nss, utils -def gaussian_logprior(x): - """Standard normal prior""" - return stats.norm.logpdf(x).sum() +class NestedSamplingBaseTest(chex.TestCase): + """Test base nested sampling functionality.""" - -def gaussian_loglikelihood(x): - """Gaussian likelihood with offset""" - return stats.norm.logpdf(x - 1.0).sum() - - -def uniform_logprior_2d(x): - """Uniform prior on [-5, 5]^2""" - return jnp.where(jnp.all(jnp.abs(x) <= 5.0), 0.0, -jnp.inf) - - -def gaussian_mixture_loglikelihood(x): - """2D Gaussian mixture for multi-modal testing""" - mixture1 = stats.norm.logpdf(x - jnp.array([2.0, 0.0])).sum() - mixture2 = stats.norm.logpdf(x - jnp.array([-2.0, 0.0])).sum() - return jnp.logaddexp(mixture1, mixture2) - - -class NestedSamplingTest(chex.TestCase): def setUp(self): super().setUp() self.key = jax.random.key(42) - def test_base_ns_init(self): - """Test basic NS initialization""" - key = jax.random.key(123) - num_live = 50 + def logprior_uniform(self, x): + """Uniform prior on [-3, 3].""" + return jnp.where(jnp.all(jnp.abs(x) <= 3.0), 0.0, -jnp.inf) - # Generate initial particles - particles = jax.random.normal(key, (num_live,)) + def loglikelihood_gaussian(self, x): + """Standard Gaussian likelihood.""" + return stats.norm.logpdf(x).sum() - # Initialize NS state - state = base.init(particles, gaussian_logprior, gaussian_loglikelihood) + def logprior_gaussian(self, x): + """Standard Gaussian prior.""" + return stats.norm.logpdf(x).sum() - # Check state structure - chex.assert_shape(state.particles, (num_live,)) + def test_ns_state_structure(self): + """Test NSState has correct structure.""" + num_live = 50 + ndim = 2 + particles = jax.random.normal(self.key, (num_live, ndim)) + + state = base.init(particles, self.logprior_uniform, self.loglikelihood_gaussian) + + # Check shapes + chex.assert_shape(state.particles, (num_live, ndim)) chex.assert_shape(state.loglikelihood, (num_live,)) chex.assert_shape(state.logprior, (num_live,)) chex.assert_shape(state.pid, (num_live,)) - - # Check that loglikelihood and logprior are properly computed - expected_loglik = jax.vmap(gaussian_loglikelihood)(particles) - expected_logprior = jax.vmap(gaussian_logprior)(particles) - + + # Check values are computed correctly + expected_loglik = jax.vmap(self.loglikelihood_gaussian)(particles) + expected_logprior = jax.vmap(self.logprior_uniform)(particles) + chex.assert_trees_all_close(state.loglikelihood, expected_loglik) chex.assert_trees_all_close(state.logprior, expected_logprior) - - def test_delete_fn(self): - """Test particle deletion function""" - key = jax.random.key(456) + + # Check particle IDs are unique + self.assertEqual(len(jnp.unique(state.pid)), num_live) + + def test_ns_info_structure(self): + """Test NSInfo structure.""" + particles = jnp.array([[1.0, 2.0], [3.0, 4.0]]) + loglik = jnp.array([0.1, 0.2]) + loglik_birth = jnp.array([-jnp.inf, 0.05]) + logprior = jnp.array([-1.0, -1.1]) + + info = base.NSInfo( + particles=particles, + loglikelihood=loglik, + loglikelihood_birth=loglik_birth, + logprior=logprior, + inner_kernel_info={} + ) + + chex.assert_shape(info.particles, (2, 2)) + chex.assert_shape(info.loglikelihood, (2,)) + chex.assert_shape(info.loglikelihood_birth, (2,)) + chex.assert_shape(info.logprior, (2,)) + + @parameterized.parameters([1, 3, 5]) + def test_delete_fn(self, num_delete): + """Test particle deletion function.""" num_live = 20 - num_delete = 3 - - particles = jax.random.normal(key, (num_live,)) - state = base.init(particles, gaussian_logprior, gaussian_loglikelihood) - - dead_idx, target_idx, start_idx = base.delete_fn(key, state, num_delete) - - # Check correct number of deletions + particles = jax.random.normal(self.key, (num_live, 2)) + state = base.init(particles, self.logprior_uniform, self.loglikelihood_gaussian) + + dead_idx, target_idx, start_idx = base.delete_fn(self.key, state, num_delete) + + # Check shapes chex.assert_shape(dead_idx, (num_delete,)) chex.assert_shape(target_idx, (num_delete,)) chex.assert_shape(start_idx, (num_delete,)) - + # Check that worst particles are selected - worst_loglik = jnp.sort(state.loglikelihood)[:num_delete] - selected_loglik = state.loglikelihood[dead_idx] - chex.assert_trees_all_close(jnp.sort(selected_loglik), worst_loglik) - - @parameterized.parameters([1, 2, 5]) - def test_ns_step_consistency(self, num_delete): - """Test NS step maintains particle count""" - key = jax.random.key(789) - num_live = 50 - - particles = jax.random.normal(key, (num_live, 2)) - state = base.init( - particles, uniform_logprior_2d, gaussian_mixture_loglikelihood - ) - - # Mock inner kernel for testing - def mock_inner_kernel( - rng_key, inner_state, logprior_fn, loglikelihood_fn, loglikelihood_0, params - ): - # Simple random walk for testing - new_pos = ( - inner_state["position"] - + jax.random.normal(rng_key, inner_state["position"].shape) * 0.1 - ) - new_logprior = logprior_fn(new_pos) - new_loglik = loglikelihood_fn(new_pos) - - new_inner_state = { - "position": new_pos, - "logprior": new_logprior, - "loglikelihood": new_loglik, - } - return new_inner_state, {} + worst_indices = jnp.argsort(state.loglikelihood)[:num_delete] + chex.assert_trees_all_close(jnp.sort(dead_idx), jnp.sort(worst_indices)) + + # Check indices are valid + self.assertTrue(jnp.all(dead_idx >= 0)) + self.assertTrue(jnp.all(dead_idx < num_live)) + self.assertTrue(jnp.all(target_idx >= 0)) + self.assertTrue(jnp.all(target_idx < num_live)) + + def test_1d_basic_functionality(self): + """Test 1D case to catch shape issues.""" + num_live = 30 + particles = jax.random.uniform(self.key, (num_live,), minval=-3, maxval=3) + + def logprior_1d(x): + return jnp.where((x >= -3) & (x <= 3), -jnp.log(6.0), -jnp.inf) + + def loglik_1d(x): + return -0.5 * x**2 + + state = base.init(particles, logprior_1d, loglik_1d) + + chex.assert_shape(state.particles, (num_live,)) + chex.assert_shape(state.loglikelihood, (num_live,)) + self.assertFalse(jnp.any(jnp.isnan(state.loglikelihood))) + self.assertFalse(jnp.any(jnp.isinf(state.logprior))) - delete_fn = functools.partial(base.delete_fn, num_delete=num_delete) + def test_kernel_construction(self): + """Test that kernel can be constructed.""" + def mock_inner_kernel(rng_key, inner_state, logprior_fn, loglik_fn, loglik_0, params): + # Simple mock that just returns the input state + return inner_state, {} + + delete_fn = functools.partial(base.delete_fn, num_delete=1) kernel = base.build_kernel( - uniform_logprior_2d, - gaussian_mixture_loglikelihood, + self.logprior_uniform, + self.loglikelihood_gaussian, delete_fn, - mock_inner_kernel, + mock_inner_kernel ) - - # Test that the kernel can be constructed with mock components - # Full execution would require more complex mocking of inner kernel behavior + self.assertTrue(callable(kernel)) - # Test delete function works - dead_idx, target_idx, start_idx = base.delete_fn(key, state, num_delete) - chex.assert_shape(dead_idx, (num_delete,)) - chex.assert_shape(target_idx, (num_delete,)) - chex.assert_shape(start_idx, (num_delete,)) - def test_utils_functions(self): - """Test utility functions""" - key = jax.random.key(101112) +class NestedSamplingUtilsTest(chex.TestCase): + """Test nested sampling utility functions.""" - # Create mock dead info - n_dead = 20 - dead_loglik = jnp.sort(jax.random.uniform(key, (n_dead,))) * 10 - 5 - dead_loglik_birth = jnp.full_like(dead_loglik, -jnp.inf) + def setUp(self): + super().setUp() + self.key = jax.random.key(123) - mock_info = base.NSInfo( + def create_mock_info(self, n_dead=50): + """Create mock NSInfo for testing.""" + # Realistic increasing likelihood sequence + base_loglik = jnp.linspace(-5, -1, n_dead) + noise = jax.random.normal(self.key, (n_dead,)) * 0.05 + dead_loglik = jnp.sort(base_loglik + noise) + + # Birth likelihoods + key, subkey = jax.random.split(self.key) + birth_offsets = jax.random.uniform(subkey, (n_dead,)) * 0.2 - 0.1 + dead_loglik_birth = jnp.concatenate([ + jnp.array([-jnp.inf]), # First from prior + dead_loglik[:-1] + birth_offsets[1:] + ]) + dead_loglik_birth = jnp.minimum(dead_loglik_birth, dead_loglik - 0.01) + + return base.NSInfo( particles=jnp.zeros((n_dead, 2)), loglikelihood=dead_loglik, loglikelihood_birth=dead_loglik_birth, logprior=jnp.zeros(n_dead), - inner_kernel_info={}, + inner_kernel_info={} ) - # Test compute_num_live + def test_compute_num_live(self): + """Test computation of number of live points.""" + mock_info = self.create_mock_info(n_dead=30) num_live = utils.compute_num_live(mock_info) - chex.assert_shape(num_live, (n_dead,)) - - # Test logX simulation - logX_seq, logdX_seq = utils.logX(key, mock_info, shape=10) - chex.assert_shape(logX_seq, (n_dead, 10)) - chex.assert_shape(logdX_seq, (n_dead, 10)) - - # Check logX is decreasing - self.assertTrue(jnp.all(logX_seq[1:] <= logX_seq[:-1])) + + chex.assert_shape(num_live, (30,)) + self.assertTrue(jnp.all(num_live >= 1)) + self.assertFalse(jnp.any(jnp.isnan(num_live))) + + def test_logX_simulation(self): + """Test log-volume simulation.""" + mock_info = self.create_mock_info(n_dead=40) + n_samples = 20 + + logX_seq, logdX_seq = utils.logX(self.key, mock_info, shape=n_samples) + + chex.assert_shape(logX_seq, (40, n_samples)) + chex.assert_shape(logdX_seq, (40, n_samples)) + + # Log volumes should be decreasing + for i in range(n_samples): + self.assertTrue(jnp.all(logX_seq[1:, i] <= logX_seq[:-1, i])) + + # No NaN values + self.assertFalse(jnp.any(jnp.isnan(logX_seq))) + + def test_log_weights(self): + """Test log weight computation.""" + mock_info = self.create_mock_info(n_dead=25) + n_samples = 15 + + log_weights_matrix = utils.log_weights(self.key, mock_info, shape=n_samples) + + chex.assert_shape(log_weights_matrix, (25, n_samples)) + + # Most weights should be finite + finite_weights = jnp.isfinite(log_weights_matrix) + finite_fraction = jnp.mean(finite_weights) + self.assertGreater(finite_fraction, 0.5) + + def test_ess_computation(self): + """Test effective sample size computation.""" + mock_info = self.create_mock_info(n_dead=35) + + ess_value = utils.ess(self.key, mock_info) + + self.assertIsInstance(ess_value, (float, jax.Array)) + self.assertGreater(ess_value, 0.0) + self.assertLessEqual(ess_value, 35) + self.assertFalse(jnp.isnan(ess_value)) + + def test_evidence_estimation_simple(self): + """Test evidence estimation for simple case.""" + # Constant likelihood case + n_dead = 30 + loglik_const = -2.0 + + mock_info = base.NSInfo( + particles=jnp.zeros((n_dead, 1)), + loglikelihood=jnp.full(n_dead, loglik_const), + loglikelihood_birth=jnp.full(n_dead, -jnp.inf), + logprior=jnp.zeros(n_dead), # Uniform prior + inner_kernel_info={} + ) + + # Generate evidence estimates + n_samples = 100 + keys = jax.random.split(self.key, n_samples) + + def single_evidence_estimate(rng_key): + log_weights_matrix = utils.log_weights(rng_key, mock_info, shape=10) + return jax.scipy.special.logsumexp(log_weights_matrix, axis=0) + + log_evidence_samples = jax.vmap(single_evidence_estimate)(keys) + log_evidence_samples = log_evidence_samples.flatten() + + # Should be close to the constant likelihood value + mean_estimate = jnp.mean(log_evidence_samples) + self.assertFalse(jnp.isnan(mean_estimate)) + self.assertFalse(jnp.isinf(mean_estimate)) class AdaptiveNestedSamplingTest(chex.TestCase): + """Test adaptive nested sampling.""" + def setUp(self): super().setUp() - self.key = jax.random.key(42) - - def test_adaptive_init(self): - """Test adaptive NS initialization""" - key = jax.random.key(123) - num_live = 30 + self.key = jax.random.key(456) - particles = jax.random.normal(key, (num_live,)) + def logprior_fn(self, x): + return stats.norm.logpdf(x).sum() - def mock_update_params_fn(state, info, current_params): - return {"test_param": 1.0} + def loglik_fn(self, x): + return -0.5 * jnp.sum(x**2) + def test_adaptive_init(self): + """Test adaptive NS initialization.""" + num_live = 25 + particles = jax.random.normal(self.key, (num_live, 2)) + + def mock_update_fn(state, info, params): + return {"test_param": 1.5} + state = adaptive.init( particles, - gaussian_logprior, - gaussian_loglikelihood, - update_inner_kernel_params_fn=mock_update_params_fn, + self.logprior_fn, + self.loglik_fn, + update_inner_kernel_params_fn=mock_update_fn ) - - # Check that inner kernel params were set - self.assertEqual(state.inner_kernel_params["test_param"], 1.0) + + # Check basic structure + chex.assert_shape(state.particles, (num_live, 2)) + + # Check inner kernel params were set + self.assertIn("test_param", state.inner_kernel_params) + self.assertEqual(state.inner_kernel_params["test_param"], 1.5) + + def test_adaptive_kernel_construction(self): + """Test adaptive kernel can be constructed.""" + def mock_inner_kernel(rng_key, inner_state, logprior_fn, loglik_fn, loglik_0, params): + return inner_state, {} + + def mock_update_fn(state, info, params): + return params + + kernel = adaptive.build_kernel( + self.logprior_fn, + self.loglik_fn, + base.delete_fn, + mock_inner_kernel, + mock_update_fn + ) + + self.assertTrue(callable(kernel)) class NestedSliceSamplingTest(chex.TestCase): + """Test nested slice sampling (NSS).""" + def setUp(self): super().setUp() - self.key = jax.random.key(42) + self.key = jax.random.key(789) - def test_nss_direction_functions(self): - """Test NSS direction generation functions""" - key = jax.random.key(456) + def logprior_fn(self, x): + return jnp.where(jnp.all(jnp.abs(x) <= 5.0), 0.0, -jnp.inf) - # Test covariance computation - particles = jax.random.normal(key, (50, 3)) - state = base.init(particles, gaussian_logprior, gaussian_loglikelihood) + def loglik_fn(self, x): + return -0.5 * jnp.sum(x**2) + def test_covariance_computation(self): + """Test covariance computation from particles.""" + num_live = 40 + ndim = 3 + particles = jax.random.normal(self.key, (num_live, ndim)) + state = base.init(particles, self.logprior_fn, self.loglik_fn) + params = nss.compute_covariance_from_particles(state, None, {}) - - # Check that covariance is computed + self.assertIn("cov", params) - cov_pytree = params["cov"] - chex.assert_shape(cov_pytree, (3, 3)) - - # Test direction sampling - direction = nss.sample_direction_from_covariance(key, params) - chex.assert_shape(direction, (3,)) + cov = params["cov"] + chex.assert_shape(cov, (ndim, ndim)) + + # Covariance should be positive semidefinite + eigenvals = jnp.linalg.eigvals(cov) + self.assertTrue(jnp.all(eigenvals >= -1e-10)) + + def test_direction_sampling(self): + """Test direction sampling from covariance.""" + ndim = 4 + cov = jnp.eye(ndim) * 2.0 + params = {"cov": cov} + + direction = nss.sample_direction_from_covariance(self.key, params) + + chex.assert_shape(direction, (ndim,)) + + # Check normalization + invcov = jnp.linalg.inv(cov) + mahal_norm = jnp.sqrt(jnp.einsum("i,ij,j", direction, invcov, direction)) + chex.assert_trees_all_close(mahal_norm, 1.0, atol=1e-6) def test_nss_kernel_construction(self): - """Test NSS kernel can be constructed""" + """Test NSS kernel construction.""" kernel = nss.build_kernel( - gaussian_logprior, gaussian_loglikelihood, num_inner_steps=10 + self.logprior_fn, + self.loglik_fn, + num_inner_steps=5 ) - - # Test that kernel is callable + self.assertTrue(callable(kernel)) + def test_nss_with_1d_problem(self): + """Test NSS with 1D problem (edge case).""" + def logprior_1d(x): + return jnp.where((x >= -2) & (x <= 2), -jnp.log(4.0), -jnp.inf) + + def loglik_1d(x): + return -0.5 * x**2 + + num_live = 20 + particles = jax.random.uniform(self.key, (num_live,), minval=-2, maxval=2) + state = base.init(particles, logprior_1d, loglik_1d) + + params = nss.compute_covariance_from_particles(state, None, {}) + + self.assertIn("cov", params) + cov = params["cov"] + # For 1D, cov should be shaped appropriately for the particle structure + # The key is that it should work without raising shape errors + self.assertFalse(jnp.isnan(cov).any()) + self.assertTrue(jnp.all(cov > 0)) + class NestedSamplingStatisticalTest(chex.TestCase): - """Statistical correctness tests for nested sampling algorithms.""" + """Statistical correctness tests.""" def setUp(self): super().setUp() - self.key = jax.random.key(42) - - def test_1d_gaussian_evidence_estimation(self): - """Test evidence estimation with analytic validation for unnormalized Gaussian.""" - - # Simple case: unnormalized Gaussian likelihood exp(-0.5*x²), uniform prior [-3,3] - prior_a, prior_b = -3.0, 3.0 - - def logprior_fn(x): - return jnp.where( - (x >= prior_a) & (x <= prior_b), -jnp.log(prior_b - prior_a), -jnp.inf - ) - - def loglikelihood_fn(x): - # Unnormalized Gaussian: exp(-0.5 * x²) - return -0.5 * x**2 - - # Analytic evidence: Z = ∫[-3,3] (1/6) * exp(-0.5*x²) dx - # = (1/6) * √(2π) * [Φ(3) - Φ(-3)] - from scipy.stats import norm - - prior_width = prior_b - prior_a - integral_part = jnp.sqrt(2 * jnp.pi) * (norm.cdf(3.0) - norm.cdf(-3.0)) - analytical_evidence = integral_part / prior_width - analytical_log_evidence = jnp.log(analytical_evidence) - - # Generate mock nested sampling data - num_steps = 60 - key = jax.random.key(42) - - # Create positions spanning the prior range - positions = jnp.linspace(prior_a + 0.05, prior_b - 0.05, num_steps).reshape( - -1, 1 - ) - dead_loglik = jax.vmap(loglikelihood_fn)(positions.flatten()) - dead_logprior = jax.vmap(logprior_fn)(positions.flatten()) - - # Sort by likelihood (as NS naturally produces) - sorted_indices = jnp.argsort(dead_loglik) - dead_loglik = dead_loglik[sorted_indices] - positions = positions[sorted_indices] - dead_logprior = dead_logprior[sorted_indices] - - # Birth likelihoods - start from prior - dead_loglik_birth = jnp.full_like(dead_loglik, -jnp.inf) - - # Create NSInfo object - mock_info = base.NSInfo( - particles=positions, - loglikelihood=dead_loglik, - loglikelihood_birth=dead_loglik_birth, - logprior=dead_logprior, - inner_kernel_info={}, - ) - - # Generate many evidence estimates for statistical testing - n_evidence_samples = 500 - key = jax.random.key(789) - keys = jax.random.split(key, n_evidence_samples) - - def single_evidence_estimate(rng_key): - log_weights_matrix = utils.log_weights(rng_key, mock_info, shape=15) - return jax.scipy.special.logsumexp(log_weights_matrix, axis=0) - - # Compute evidence estimates - log_evidence_samples = jax.vmap(single_evidence_estimate)(keys) - log_evidence_samples = log_evidence_samples.flatten() - - # Statistical validation - mean_estimate = jnp.mean(log_evidence_samples) - std_estimate = jnp.std(log_evidence_samples) - - # Check statistical consistency with 95% confidence interval - # For mock data with simplified NS, expect some bias but should be in ballpark - tolerance = 2.0 * std_estimate # 95% CI - bias = jnp.abs(mean_estimate - analytical_log_evidence) - - self.assertLess( - bias, - tolerance, - f"Evidence estimate {mean_estimate:.3f} vs analytic {analytical_log_evidence:.3f} " - f"differs by {bias:.3f}, which exceeds 2σ = {tolerance:.3f}", - ) - - # Also test that individual estimates are reasonable - self.assertFalse( - jnp.any(jnp.isnan(log_evidence_samples)), - "No evidence estimates should be NaN", - ) - self.assertFalse( - jnp.any(jnp.isinf(log_evidence_samples)), - "No evidence estimates should be infinite", - ) - - # Check that estimates are in a reasonable range - self.assertGreater( - mean_estimate, analytical_log_evidence - 1.0, "Mean estimate not too low" - ) - self.assertLess( - mean_estimate, analytical_log_evidence + 1.0, "Mean estimate not too high" - ) - - def test_uniform_prior_evidence(self): - """Test evidence estimation for uniform prior with simple likelihood.""" - - # Setup: Uniform prior on [0, 1], simple likelihood - def logprior_fn(x): - return jnp.where((x >= 0.0) & (x <= 1.0), 0.0, -jnp.inf) - - def loglikelihood_fn(x): - # Simple quadratic likelihood peaked at 0.5 - return -10.0 * (x - 0.5) ** 2 - - # Analytical evidence can be computed numerically for comparison - # Z = integral_0^1 exp(-10(x-0.5)^2) dx ≈ sqrt(π/10) * erf(...) - - num_live = 50 - key = jax.random.key(456) - - # Initialize particles uniformly in [0, 1] - particles = jax.random.uniform(key, (num_live,)) - state = base.init(particles, logprior_fn, loglikelihood_fn) - - # Check that initialization worked correctly - self.assertTrue(jnp.all(state.particles >= 0.0)) - self.assertTrue(jnp.all(state.particles <= 1.0)) - self.assertFalse(jnp.any(jnp.isinf(state.logprior))) - self.assertFalse(jnp.any(jnp.isnan(state.loglikelihood))) - - # Test evidence contribution from live points - logZ_live_contribution = state.logZ_live - self.assertIsInstance(logZ_live_contribution, (float, jax.Array)) - self.assertFalse(jnp.isnan(logZ_live_contribution)) + self.key = jax.random.key(12345) def test_evidence_monotonicity(self): - """Test that evidence estimates are monotonically increasing during NS run.""" - - # Simple setup for testing monotonicity + """Test evidence is monotonically increasing.""" def logprior_fn(x): - return stats.norm.logpdf(x) - - def loglikelihood_fn(x): - return -0.5 * x**2 # Simple quadratic - + return stats.norm.logpdf(x).sum() + + def loglik_fn(x): + return -0.5 * jnp.sum(x**2) + num_live = 30 - key = jax.random.key(789) - - particles = jax.random.normal(key, (num_live,)) - initial_state = base.init(particles, logprior_fn, loglikelihood_fn) - - # Test that we can track evidence during run - logZ_sequence = [initial_state.logZ] - - # Simulate a few evidence updates manually - for i in range(5): - # Simulate removing worst particle and updating evidence - worst_idx = jnp.argmin(initial_state.loglikelihood) - dead_loglik = initial_state.loglikelihood[worst_idx] - - # Update evidence (simplified) - delta_logX = -1.0 / num_live # Approximate volume decrease - new_logZ = jnp.logaddexp(initial_state.logZ, dead_loglik + delta_logX) + particles = jax.random.normal(self.key, (num_live, 2)) + state = base.init(particles, logprior_fn, loglik_fn) + + # Simulate evidence updates + logZ_sequence = [state.logZ] + current_state = state + + for _ in range(5): + worst_idx = jnp.argmin(current_state.loglikelihood) + dead_loglik = current_state.loglikelihood[worst_idx] + + # Approximate volume decrease + delta_logX = -1.0 / num_live + new_logZ = jnp.logaddexp(current_state.logZ, dead_loglik + delta_logX) logZ_sequence.append(new_logZ) - - # Update for next iteration (simplified) - new_loglik = jnp.concatenate( - [ - initial_state.loglikelihood[:worst_idx], - initial_state.loglikelihood[worst_idx + 1 :], - jnp.array([dead_loglik + 0.1]), # Mock new particle - ] + + # Mock update for next iteration + new_loglik = jnp.concatenate([ + current_state.loglikelihood[:worst_idx], + current_state.loglikelihood[worst_idx + 1:], + jnp.array([dead_loglik + 0.1]) + ]) + current_state = current_state._replace( + loglikelihood=new_loglik, + logZ=new_logZ ) - initial_state = initial_state._replace(loglikelihood=new_loglik) - + # Check monotonicity logZ_array = jnp.array(logZ_sequence) differences = logZ_array[1:] - logZ_array[:-1] - self.assertTrue( - jnp.all(differences >= -1e-10), - "Evidence should be monotonically increasing", - ) - - def test_nested_sampling_utils_statistical_properties(self): - """Test statistical properties of nested sampling utility functions.""" - key = jax.random.key(101112) - - # Create realistic mock data - n_dead = 100 - - # Generate realistic loglikelihood sequence (increasing) - base_loglik = jnp.linspace(-10, -1, n_dead) - noise = jax.random.normal(key, (n_dead,)) * 0.1 - dead_loglik = jnp.sort(base_loglik + noise) - - # Create more realistic birth likelihoods that reflect actual NS behavior - # Particles can be born at various levels, not just at previous death - key, subkey = jax.random.split(key) - birth_noise = jax.random.uniform(subkey, (n_dead,)) * 2.0 - 1.0 # [-1, 1] - dead_loglik_birth = jnp.concatenate( - [ - jnp.array([-jnp.inf]), # First particle born from prior - dead_loglik[:-1] + birth_noise[1:] * 0.5, # Others with some variation - ] - ) - # Ensure birth likelihoods don't exceed death likelihoods - dead_loglik_birth = jnp.minimum(dead_loglik_birth, dead_loglik - 0.01) - - mock_info = base.NSInfo( - particles=jnp.zeros((n_dead, 2)), - loglikelihood=dead_loglik, - loglikelihood_birth=dead_loglik_birth, - logprior=jnp.zeros(n_dead), - inner_kernel_info={}, - ) - - # Test compute_num_live - num_live = utils.compute_num_live(mock_info) - chex.assert_shape(num_live, (n_dead,)) - - # Basic sanity checks for number of live points - # NOTE: num_live should NOT be monotonically decreasing in general NS! - # It follows a sawtooth pattern as particles die and are replenished - self.assertTrue( - jnp.all(num_live >= 1), "Should always have at least 1 live point" - ) - self.assertTrue( - jnp.all(num_live <= 1000), # Reasonable upper bound - "Number of live points should be reasonable", - ) - self.assertFalse( - jnp.any(jnp.isnan(num_live)), "Number of live points should not be NaN" - ) - - # Test logX simulation - n_samples = 50 - logX_seq, logdX_seq = utils.logX(key, mock_info, shape=n_samples) - chex.assert_shape(logX_seq, (n_dead, n_samples)) - chex.assert_shape(logdX_seq, (n_dead, n_samples)) - - # Log volumes should be decreasing - self.assertTrue( - jnp.all(logX_seq[1:] <= logX_seq[:-1]), "Log volumes should be decreasing" - ) - - # All log volume elements should be negative (since dX < X) - finite_logdX = logdX_seq[jnp.isfinite(logdX_seq)] - if len(finite_logdX) > 0: - self.assertTrue( - jnp.all(finite_logdX <= 0.0), "Log volume elements should be negative" - ) - - # Test log_weights function - log_weights_matrix = utils.log_weights(key, mock_info, shape=n_samples) - chex.assert_shape(log_weights_matrix, (n_dead, n_samples)) - - # Weights should be finite for most particles - finite_weights = jnp.isfinite(log_weights_matrix) - self.assertGreater( - jnp.sum(finite_weights), - n_dead * n_samples * 0.5, - "Most weights should be finite", - ) - - def test_gaussian_evidence_narrow_prior(self): - """Test evidence estimation with narrow prior for challenging case.""" - - # Setup: Gaussian likelihood with narrow uniform prior (more challenging) - mu_true = 1.2 - sigma_true = 0.6 - prior_a, prior_b = 0.8, 1.6 # Narrow prior around the mean - + self.assertTrue(jnp.all(differences >= -1e-12)) + + def test_gaussian_evidence_analytical(self): + """Test evidence estimation against analytical result.""" + # Setup: Gaussian likelihood with uniform prior + prior_a, prior_b = -2.0, 2.0 + sigma = 1.0 + def logprior_fn(x): - return jnp.where( - (x >= prior_a) & (x <= prior_b), -jnp.log(prior_b - prior_a), -jnp.inf - ) - - def loglikelihood_fn(x): - return -0.5 * ((x - mu_true) / sigma_true) ** 2 - 0.5 * jnp.log( - 2 * jnp.pi * sigma_true**2 - ) - - # Analytic evidence + width = prior_b - prior_a + return jnp.where((x >= prior_a) & (x <= prior_b), -jnp.log(width), -jnp.inf) + + def loglik_fn(x): + return -0.5 * (x / sigma)**2 - 0.5 * jnp.log(2 * jnp.pi * sigma**2) + + # Analytical evidence (truncated Gaussian integral) from scipy.stats import norm - analytical_evidence = ( - norm.cdf((prior_b - mu_true) / sigma_true) - - norm.cdf((prior_a - mu_true) / sigma_true) + norm.cdf(prior_b / sigma) - norm.cdf(prior_a / sigma) ) / (prior_b - prior_a) analytical_log_evidence = jnp.log(analytical_evidence) - - # Generate mock NS data with higher resolution for narrow prior - num_steps = 60 - key = jax.random.key(12345) - - # Dense sampling in the narrow prior region - positions = jnp.linspace(prior_a + 0.01, prior_b - 0.01, num_steps).reshape( - -1, 1 - ) - dead_loglik = jax.vmap(loglikelihood_fn)(positions.flatten()) + + # Mock NS data + n_dead = 50 + positions = jnp.linspace(prior_a + 0.01, prior_b - 0.01, n_dead).reshape(-1, 1) + dead_loglik = jax.vmap(loglik_fn)(positions.flatten()) dead_logprior = jax.vmap(logprior_fn)(positions.flatten()) - + # Sort by likelihood - sorted_indices = jnp.argsort(dead_loglik) - dead_loglik = dead_loglik[sorted_indices] - positions = positions[sorted_indices] - dead_logprior = dead_logprior[sorted_indices] - - # Birth likelihoods - key, subkey = jax.random.split(key) - birth_noise = jax.random.uniform(subkey, (num_steps,)) * 0.3 - 0.15 - dead_loglik_birth = jnp.concatenate( - [jnp.array([-jnp.inf]), dead_loglik[:-1] + birth_noise[1:]] - ) - dead_loglik_birth = jnp.minimum(dead_loglik_birth, dead_loglik - 0.01) - + sorted_idx = jnp.argsort(dead_loglik) + dead_loglik = dead_loglik[sorted_idx] + positions = positions[sorted_idx] + dead_logprior = dead_logprior[sorted_idx] + + dead_loglik_birth = jnp.concatenate([ + jnp.array([-jnp.inf]), + dead_loglik[:-1] - 0.05 + ]) + mock_info = base.NSInfo( particles=positions, loglikelihood=dead_loglik, loglikelihood_birth=dead_loglik_birth, logprior=dead_logprior, - inner_kernel_info={}, + inner_kernel_info={} ) - - # Generate evidence estimates for statistical testing - n_evidence_samples = 800 - key = jax.random.key(555) - keys = jax.random.split(key, n_evidence_samples) - + + # Generate evidence estimates + n_samples = 200 + keys = jax.random.split(self.key, n_samples) + def single_evidence_estimate(rng_key): - log_weights_matrix = utils.log_weights(rng_key, mock_info, shape=15) + log_weights_matrix = utils.log_weights(rng_key, mock_info, shape=10) return jax.scipy.special.logsumexp(log_weights_matrix, axis=0) - + log_evidence_samples = jax.vmap(single_evidence_estimate)(keys) log_evidence_samples = log_evidence_samples.flatten() - + # Statistical validation mean_estimate = jnp.mean(log_evidence_samples) std_estimate = jnp.std(log_evidence_samples) - - # 99% confidence interval test - lower_bound = mean_estimate - 2.576 * std_estimate # 99% CI - upper_bound = mean_estimate + 2.576 * std_estimate - - self.assertGreater( - analytical_log_evidence, - lower_bound, - f"Analytic evidence {analytical_log_evidence:.3f} below 99% CI lower bound {lower_bound:.3f}", - ) - self.assertLess( - analytical_log_evidence, - upper_bound, - f"Analytic evidence {analytical_log_evidence:.3f} above 99% CI upper bound {upper_bound:.3f}", - ) - - def test_evidence_integration_simple_case(self): - """Test evidence calculation for a simple analytical case with constant likelihood.""" - # Test case: uniform prior on [0,2], constant likelihood - # Evidence = ∫[0,2] (1/width) * exp(loglik_constant) dx = exp(loglik_constant) - - loglik_constant = -1.5 - prior_width = 2.0 # Prior on [0, 2] - n_dead = 40 - - # Analytic answer: evidence = ∫[0,2] (1/2) * exp(-1.5) dx = exp(-1.5) - analytical_log_evidence = loglik_constant - - # Mock data: all particles have same likelihood (constant function) - dead_loglik = jnp.full(n_dead, loglik_constant) - dead_loglik_birth = jnp.full(n_dead, -jnp.inf) # All from prior - - mock_info = base.NSInfo( - particles=jnp.zeros((n_dead, 1)), - loglikelihood=dead_loglik, - loglikelihood_birth=dead_loglik_birth, - logprior=jnp.full( - n_dead, -jnp.log(prior_width) - ), # Uniform prior log density - inner_kernel_info={}, - ) - - # Generate many evidence estimates - n_samples = 500 - key = jax.random.key(999) - keys = jax.random.split(key, n_samples) - - def single_evidence_estimate(rng_key): - log_weights_matrix = utils.log_weights(rng_key, mock_info, shape=25) - return jax.scipy.special.logsumexp(log_weights_matrix, axis=0) - - log_evidence_samples = jax.vmap(single_evidence_estimate)(keys) - log_evidence_samples = log_evidence_samples.flatten() - - mean_estimate = jnp.mean(log_evidence_samples) - std_estimate = jnp.std(log_evidence_samples) - - # For constant likelihood case, should be very accurate - # 95% confidence interval - lower_bound = mean_estimate - 1.96 * std_estimate - upper_bound = mean_estimate + 1.96 * std_estimate - - self.assertGreater( - analytical_log_evidence, - lower_bound, - f"Analytic evidence {analytical_log_evidence:.3f} below 95% CI", - ) - self.assertLess( - analytical_log_evidence, - upper_bound, - f"Analytic evidence {analytical_log_evidence:.3f} above 95% CI", - ) - - def test_effective_sample_size_calculation(self): - """Test effective sample size calculation.""" - key = jax.random.key(67890) - - # Create mock data with varying weights - n_dead = 50 - dead_loglik = jax.random.uniform(key, (n_dead,)) * 5 - 10 # Range [-10, -5] - dead_loglik_birth = jnp.full(n_dead, -jnp.inf) - - mock_info = base.NSInfo( - particles=jnp.zeros((n_dead, 1)), - loglikelihood=jnp.sort(dead_loglik), # Ensure increasing - loglikelihood_birth=dead_loglik_birth, - logprior=jnp.zeros(n_dead), - inner_kernel_info={}, - ) - - # Calculate ESS - ess_value = utils.ess(key, mock_info) - - # ESS should be positive and reasonable - self.assertIsInstance(ess_value, (float, jax.Array)) - self.assertGreater(ess_value, 0.0, "ESS should be positive") - self.assertLessEqual( - ess_value, n_dead, "ESS should not exceed number of samples" - ) - self.assertFalse(jnp.isnan(ess_value), "ESS should not be NaN") + + # For mock data, we expect some bias, so use looser bounds + # This is primarily testing that the utilities work, not exact accuracy + self.assertFalse(jnp.isnan(mean_estimate)) + self.assertFalse(jnp.isinf(mean_estimate)) + + # Very loose bounds - mainly checking it's in the right ballpark + self.assertGreater(mean_estimate, analytical_log_evidence - 3.0) + self.assertLess(mean_estimate, analytical_log_evidence + 3.0) if __name__ == "__main__": - absltest.main() + absltest.main() \ No newline at end of file diff --git a/tests/ns/test_nested_sampling_units.py b/tests/ns/test_nested_sampling_units.py new file mode 100644 index 000000000..43c0ce6a2 --- /dev/null +++ b/tests/ns/test_nested_sampling_units.py @@ -0,0 +1,829 @@ +"""Unit tests for nested sampling components.""" +import chex +import jax +import jax.numpy as jnp +from absl.testing import absltest, parameterized + +from blackjax.ns import base, nss, utils + + +class NSStateTest(chex.TestCase): + """Test NSState data structure.""" + + def test_ns_state_creation(self): + """Test NSState creation.""" + particles = jnp.array([[1.0, 2.0], [3.0, 4.0]]) + loglik = jnp.array([0.1, 0.2]) + loglik_birth = jnp.array([-jnp.inf, 0.05]) + logprior = jnp.array([-1.0, -1.1]) + pid = jnp.array([0, 1]) + logX = -2.0 + logZ = -5.0 + logZ_live = -3.0 + inner_kernel_params = {} + + state = base.NSState( + particles=particles, + loglikelihood=loglik, + loglikelihood_birth=loglik_birth, + logprior=logprior, + pid=pid, + logX=logX, + logZ=logZ, + logZ_live=logZ_live, + inner_kernel_params=inner_kernel_params + ) + + chex.assert_trees_all_close(state.particles, particles) + chex.assert_trees_all_close(state.loglikelihood, loglik) + chex.assert_trees_all_close(state.loglikelihood_birth, loglik_birth) + chex.assert_trees_all_close(state.logprior, logprior) + chex.assert_trees_all_close(state.pid, pid) + self.assertEqual(state.logX, logX) + self.assertEqual(state.logZ, logZ) + self.assertEqual(state.logZ_live, logZ_live) + self.assertEqual(state.inner_kernel_params, inner_kernel_params) + + def test_ns_state_replace(self): + """Test NSState _replace method.""" + state = base.NSState( + particles=jnp.array([[1.0], [2.0]]), + loglikelihood=jnp.array([0.1, 0.2]), + loglikelihood_birth=jnp.array([-jnp.inf, 0.05]), + logprior=jnp.array([-1.0, -1.1]), + pid=jnp.array([0, 1]), + logX=-2.0, + logZ=-5.0, + logZ_live=-3.0, + inner_kernel_params={} + ) + + new_logZ = -4.5 + new_state = state._replace(logZ=new_logZ) + + self.assertEqual(new_state.logZ, new_logZ) + self.assertEqual(new_state.logZ_live, -3.0) # Unchanged + self.assertEqual(new_state.logX, -2.0) # Unchanged + chex.assert_trees_all_close(new_state.particles, state.particles) + + +class NSInfoTest(chex.TestCase): + """Test NSInfo data structure.""" + + def test_ns_info_creation(self): + """Test NSInfo creation.""" + particles = jnp.array([[1.0, 2.0], [3.0, 4.0]]) + loglik = jnp.array([0.1, 0.2]) + loglik_birth = jnp.array([-jnp.inf, 0.05]) + logprior = jnp.array([-1.0, -1.1]) + kernel_info = {"test": "value"} + + info = base.NSInfo( + particles=particles, + loglikelihood=loglik, + loglikelihood_birth=loglik_birth, + logprior=logprior, + inner_kernel_info=kernel_info + ) + + chex.assert_trees_all_close(info.particles, particles) + chex.assert_trees_all_close(info.loglikelihood, loglik) + chex.assert_trees_all_close(info.loglikelihood_birth, loglik_birth) + chex.assert_trees_all_close(info.logprior, logprior) + self.assertEqual(info.inner_kernel_info, kernel_info) + + +class InitFunctionTest(chex.TestCase): + """Test NS initialization function.""" + + def setUp(self): + super().setUp() + self.logprior_fn = lambda x: -0.5 * jnp.sum(x**2) + self.loglik_fn = lambda x: -jnp.sum(x**2) + + @parameterized.parameters([10, 50, 100]) + def test_init_particle_count(self, num_live): + """Test initialization with different numbers of live points.""" + particles = jax.random.normal(jax.random.key(42), (num_live, 2)) + + state = base.init(particles, self.logprior_fn, self.loglik_fn) + + chex.assert_shape(state.particles, (num_live, 2)) + chex.assert_shape(state.loglikelihood, (num_live,)) + chex.assert_shape(state.logprior, (num_live,)) + chex.assert_shape(state.pid, (num_live,)) + + def test_init_1d_particles(self): + """Test initialization with 1D particles.""" + num_live = 20 + particles = jax.random.normal(jax.random.key(42), (num_live,)) + + def logprior_1d(x): + return -0.5 * x**2 + + def loglik_1d(x): + return -x**2 + + state = base.init(particles, logprior_1d, loglik_1d) + + chex.assert_shape(state.particles, (num_live,)) + chex.assert_shape(state.loglikelihood, (num_live,)) + chex.assert_shape(state.logprior, (num_live,)) + + def test_init_computes_correct_values(self): + """Test that init computes loglikelihood and logprior correctly.""" + particles = jnp.array([[1.0, 0.0], [0.0, 1.0], [-1.0, -1.0]]) + + state = base.init(particles, self.logprior_fn, self.loglik_fn) + + # Check computed values match manual computation + expected_logprior = jax.vmap(self.logprior_fn)(particles) + expected_loglik = jax.vmap(self.loglik_fn)(particles) + + chex.assert_trees_all_close(state.logprior, expected_logprior) + chex.assert_trees_all_close(state.loglikelihood, expected_loglik) + + def test_init_particle_ids_unique(self): + """Test that particle IDs are unique.""" + num_live = 15 + particles = jax.random.normal(jax.random.key(42), (num_live, 3)) + + state = base.init(particles, self.logprior_fn, self.loglik_fn) + + unique_ids = jnp.unique(state.pid) + self.assertEqual(len(unique_ids), num_live) + + def test_init_with_pytree_particles(self): + """Test initialization with PyTree particles.""" + num_live = 10 + particles = { + "x": jax.random.normal(jax.random.key(42), (num_live, 2)), + "y": jax.random.normal(jax.random.key(43), (num_live,)) + } + + def logprior_pytree(p): + return -0.5 * (jnp.sum(p["x"]**2) + p["y"]**2) + + def loglik_pytree(p): + return -(jnp.sum(p["x"]**2) + p["y"]**2) + + state = base.init(particles, logprior_pytree, loglik_pytree) + + chex.assert_shape(state.particles["x"], (num_live, 2)) + chex.assert_shape(state.particles["y"], (num_live,)) + chex.assert_shape(state.loglikelihood, (num_live,)) + + +class DeleteFunctionTest(chex.TestCase): + """Test particle deletion function.""" + + def setUp(self): + super().setUp() + self.key = jax.random.key(42) + + def create_test_state(self, num_live=20): + """Helper to create test state.""" + particles = jax.random.normal(self.key, (num_live, 2)) + logprior_fn = lambda x: -0.5 * jnp.sum(x**2) + loglik_fn = lambda x: -jnp.sum(x**2) + return base.init(particles, logprior_fn, loglik_fn) + + @parameterized.parameters([1, 3, 5, 10]) + def test_delete_fn_shapes(self, num_delete): + """Test delete function returns correct shapes.""" + state = self.create_test_state(num_live=20) + + dead_idx, target_idx, start_idx = base.delete_fn(self.key, state, num_delete) + + chex.assert_shape(dead_idx, (num_delete,)) + chex.assert_shape(target_idx, (num_delete,)) + chex.assert_shape(start_idx, (num_delete,)) + + def test_delete_fn_selects_worst(self): + """Test that delete function selects worst particles.""" + state = self.create_test_state(num_live=20) + num_delete = 3 + + dead_idx, _, _ = base.delete_fn(self.key, state, num_delete) + + # Should select particles with lowest likelihood + worst_indices = jnp.argsort(state.loglikelihood)[:num_delete] + selected_indices = jnp.sort(dead_idx) + expected_indices = jnp.sort(worst_indices) + + chex.assert_trees_all_close(selected_indices, expected_indices) + + def test_delete_fn_valid_indices(self): + """Test that delete function returns valid indices.""" + num_live = 15 + state = self.create_test_state(num_live=num_live) + num_delete = 4 + + dead_idx, target_idx, start_idx = base.delete_fn(self.key, state, num_delete) + + # All indices should be valid + self.assertTrue(jnp.all(dead_idx >= 0)) + self.assertTrue(jnp.all(dead_idx < num_live)) + self.assertTrue(jnp.all(target_idx >= 0)) + self.assertTrue(jnp.all(target_idx < num_live)) + self.assertTrue(jnp.all(start_idx >= 0)) + self.assertTrue(jnp.all(start_idx < num_live)) + + def test_delete_fn_no_duplicates(self): + """Test that delete function doesn't return duplicate indices.""" + state = self.create_test_state(num_live=20) + num_delete = 5 + + dead_idx, target_idx, start_idx = base.delete_fn(self.key, state, num_delete) + + # Dead indices should be unique + self.assertEqual(len(jnp.unique(dead_idx)), num_delete) + + +class NSKernelExecutionTest(chex.TestCase): + """Test full NS kernel execution.""" + + def setUp(self): + super().setUp() + self.key = jax.random.key(999) + + def test_kernel_full_execution(self): + """Test full NS kernel execution workflow.""" + # Create a simple mock inner kernel + def mock_inner_kernel(rng_key, inner_state, logprior_fn, loglik_fn, loglik_0, params): + # Simple random walk that respects the likelihood constraint + pos = inner_state.position + new_pos = pos + jax.random.normal(rng_key, pos.shape) * 0.1 + new_loglik = loglik_fn(new_pos) + new_logprior = logprior_fn(new_pos) + + # Accept if likelihood is above threshold, otherwise return original + accept = new_loglik >= loglik_0 + final_pos = jnp.where(accept, new_pos, pos) + final_loglik = jnp.where(accept, new_loglik, inner_state.loglikelihood) + final_logprior = jnp.where(accept, new_logprior, inner_state.logprior) + + new_inner_state = base.PartitionedState(final_pos, final_logprior, final_loglik) + return new_inner_state, {"accepted": accept} + + # Set up test functions + def logprior_fn(x): + return -0.5 * jnp.sum(x**2) + + def loglik_fn(x): + return -jnp.sum(x**2) + + # Create initial state + num_live = 10 + particles = jax.random.normal(self.key, (num_live, 2)) * 0.5 + state = base.init(particles, logprior_fn, loglik_fn) + + # Build kernel with delete function + def delete_fn(rng_key, state): + # Delete 1 worst particle + dead_idx = jnp.array([jnp.argmin(state.loglikelihood)]) + target_idx = jnp.array([0]) # Replace with first particle + start_idx = jnp.array([0]) # Start from first particle + return dead_idx, target_idx, start_idx + + kernel = base.build_kernel(logprior_fn, loglik_fn, delete_fn, mock_inner_kernel) + + # Execute kernel + new_state, info = kernel(self.key, state) + + # Check that state is updated correctly + self.assertIsInstance(new_state, base.NSState) + self.assertIsInstance(info, base.NSInfo) + + # Should still have same number of particles + chex.assert_shape(new_state.particles, (num_live, 2)) + chex.assert_shape(new_state.loglikelihood, (num_live,)) + + # Evidence should be updated + self.assertNotEqual(new_state.logZ, state.logZ) + + # Info should contain dead particle information + chex.assert_shape(info.particles, (1, 2)) # 1 dead particle + chex.assert_shape(info.loglikelihood, (1,)) + + +class RuntimeInfoUpdateTest(chex.TestCase): + """Test runtime info update functions.""" + + def setUp(self): + super().setUp() + self.key = jax.random.key(555) + + def test_update_ns_runtime_info(self): + """Test update_ns_runtime_info function.""" + # Test data + logX = -2.0 + logZ = -5.0 + loglikelihood = jnp.array([-1.0, -1.5, -2.0, -2.5]) # Live points + dead_loglikelihood = jnp.array([-3.0, -3.2]) # Dead points + + new_logX, new_logZ, new_logZ_live = base.update_ns_runtime_info( + logX, logZ, loglikelihood, dead_loglikelihood + ) + + # Check types and finiteness + self.assertIsInstance(new_logX, (float, jax.Array)) + self.assertIsInstance(new_logZ, (float, jax.Array)) + self.assertIsInstance(new_logZ_live, (float, jax.Array)) + + self.assertFalse(jnp.isnan(new_logX)) + self.assertFalse(jnp.isnan(new_logZ)) + self.assertFalse(jnp.isnan(new_logZ_live)) + + # Evidence should increase (or at least not decrease significantly) + self.assertGreaterEqual(new_logZ, logZ - 1e-10) + + # LogX should decrease (volume shrinking) + self.assertLess(new_logX, logX) + + def test_update_ns_runtime_info_single_particle(self): + """Test runtime update with single particle deletion.""" + logX = -1.0 + logZ = -10.0 + loglikelihood = jnp.array([-2.0, -2.5, -3.0]) + dead_loglikelihood = jnp.array([-4.0]) # Single deletion + + new_logX, new_logZ, new_logZ_live = base.update_ns_runtime_info( + logX, logZ, loglikelihood, dead_loglikelihood + ) + + # Should work with single particle + self.assertFalse(jnp.isnan(new_logX)) + self.assertFalse(jnp.isnan(new_logZ)) + self.assertFalse(jnp.isnan(new_logZ_live)) + + +class PartitionedStateTest(chex.TestCase): + """Test PartitionedState and PartitionedInfo structures.""" + + def test_new_state_and_info(self): + """Test new_state_and_info function.""" + position = jnp.array([1.0, 2.0]) + logprior = -1.5 + loglikelihood = -2.0 + info = {"test": "value"} + + state, returned_info = base.new_state_and_info( + position, logprior, loglikelihood, info + ) + + # Check PartitionedState + self.assertIsInstance(state, base.PartitionedState) + chex.assert_trees_all_close(state.position, position) + self.assertEqual(state.logprior, logprior) + self.assertEqual(state.loglikelihood, loglikelihood) + + # Check PartitionedInfo + self.assertIsInstance(returned_info, base.PartitionedInfo) + chex.assert_trees_all_close(returned_info.position, position) + self.assertEqual(returned_info.logprior, logprior) + self.assertEqual(returned_info.loglikelihood, loglikelihood) + self.assertEqual(returned_info.info, info) + + +class UtilityFunctionsTest(chex.TestCase): + """Test utility functions.""" + + def setUp(self): + super().setUp() + self.key = jax.random.key(123) + + def create_mock_info(self, n_dead=30): + """Helper to create mock NSInfo.""" + # Increasing likelihood sequence + loglik = jnp.linspace(-5, -1, n_dead) + loglik_birth = jnp.concatenate([ + jnp.array([-jnp.inf]), + loglik[:-1] - 0.1 + ]) + + return base.NSInfo( + particles=jnp.zeros((n_dead, 2)), + loglikelihood=loglik, + loglikelihood_birth=loglik_birth, + logprior=jnp.zeros(n_dead), + inner_kernel_info={} + ) + + def test_get_first_row_array(self): + """Test get_first_row with arrays.""" + x = jnp.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + + result = utils.get_first_row(x) + expected = jnp.array([1, 2, 3]) + + chex.assert_trees_all_close(result, expected) + + def test_get_first_row_pytree(self): + """Test get_first_row with PyTree.""" + x = { + "a": jnp.array([[1, 2], [3, 4], [5, 6]]), + "b": jnp.array([10, 20, 30]) + } + + result = utils.get_first_row(x) + + chex.assert_trees_all_close(result["a"], jnp.array([1, 2])) + self.assertEqual(result["b"], 10) + + def test_compute_num_live_shape(self): + """Test compute_num_live returns correct shape.""" + mock_info = self.create_mock_info(n_dead=25) + + num_live = utils.compute_num_live(mock_info) + + chex.assert_shape(num_live, (25,)) + + def test_compute_num_live_values(self): + """Test compute_num_live returns reasonable values.""" + mock_info = self.create_mock_info(n_dead=20) + + num_live = utils.compute_num_live(mock_info) + + # Should be positive + self.assertTrue(jnp.all(num_live >= 1)) + # Should be reasonable (not too large) + self.assertTrue(jnp.all(num_live <= 1000)) + # Should not be NaN + self.assertFalse(jnp.any(jnp.isnan(num_live))) + + def test_logX_shapes(self): + """Test logX returns correct shapes.""" + mock_info = self.create_mock_info(n_dead=15) + n_samples = 10 + + logX_seq, logdX_seq = utils.logX(self.key, mock_info, shape=n_samples) + + chex.assert_shape(logX_seq, (15, n_samples)) + chex.assert_shape(logdX_seq, (15, n_samples)) + + def test_logX_monotonicity(self): + """Test that logX is decreasing.""" + mock_info = self.create_mock_info(n_dead=10) + n_samples = 5 + + logX_seq, _ = utils.logX(self.key, mock_info, shape=n_samples) + + # Each column should be decreasing + for i in range(n_samples): + differences = logX_seq[1:, i] - logX_seq[:-1, i] + self.assertTrue(jnp.all(differences <= 1e-12)) # Allow for numerical precision + + def test_log_weights_shapes(self): + """Test log_weights returns correct shape.""" + mock_info = self.create_mock_info(n_dead=12) + n_samples = 8 + + log_weights = utils.log_weights(self.key, mock_info, shape=n_samples) + + chex.assert_shape(log_weights, (12, n_samples)) + + def test_log_weights_finite(self): + """Test that most log_weights are finite.""" + mock_info = self.create_mock_info(n_dead=20) + n_samples = 5 + + log_weights = utils.log_weights(self.key, mock_info, shape=n_samples) + + # Most weights should be finite + finite_fraction = jnp.mean(jnp.isfinite(log_weights)) + self.assertGreater(finite_fraction, 0.3) # At least 30% should be finite + + def test_ess_properties(self): + """Test ESS computation properties.""" + mock_info = self.create_mock_info(n_dead=30) + + ess = utils.ess(self.key, mock_info) + + # ESS should be positive and finite + self.assertGreater(ess, 0.0) + self.assertFalse(jnp.isnan(ess)) + self.assertFalse(jnp.isinf(ess)) + # ESS should not exceed number of samples + self.assertLessEqual(ess, 30) + + def test_log1mexp_values(self): + """Test log1mexp utility function.""" + # Test values where we know the expected result + x = jnp.array([-0.1, -1.0, -2.0, -10.0]) + + result = utils.log1mexp(x) + + # Should all be finite and negative (since log(1-exp(x)) < 0 for x < 0) + self.assertTrue(jnp.all(jnp.isfinite(result))) + # For large negative x, log(1-exp(x)) ≈ log(1) = 0 + self.assertAlmostEqual(result[-1], 0.0, places=3) # Less strict for numerical precision + + def test_log1mexp_edge_cases(self): + """Test log1mexp edge cases.""" + # Test near the transition point + x_transition = jnp.array([-0.6931472]) # Approximately -log(2) + + result = utils.log1mexp(x_transition) + + self.assertTrue(jnp.isfinite(result)) + self.assertLess(result, 0.0) + + +class NSSAdvancedTest(chex.TestCase): + """Test NSS advanced functionality and missing coverage.""" + + def setUp(self): + super().setUp() + self.key = jax.random.key(777) + + def test_nss_top_level_api(self): + """Test NSS as_top_level_api function.""" + def logprior_fn(x): + return -0.5 * jnp.sum(x**2) + + def loglik_fn(x): + return -jnp.sum(x**2) + + num_live = 20 + + # Test the top-level API + algorithm = nss.as_top_level_api( + logprior_fn, + loglik_fn, + 5 # num_inner_steps + ) + + # Should return a SamplingAlgorithm + self.assertTrue(hasattr(algorithm, "init")) + self.assertTrue(hasattr(algorithm, "step")) + self.assertTrue(callable(algorithm.init)) + self.assertTrue(callable(algorithm.step)) + + # Test initialization - NSS uses adaptive.init which needs different signature + particles = jax.random.normal(self.key, (num_live, 2)) + state = algorithm.init(particles) + + self.assertIsInstance(state, base.NSState) + chex.assert_shape(state.particles, (num_live, 2)) + + def test_nss_inner_kernel_execution(self): + """Test NSS inner kernel execution by building a full kernel.""" + def logprior_fn(x): + return -0.5 * jnp.sum(x**2) + + def loglik_fn(x): + return -jnp.sum(x**2) + + # Build NSS kernel + kernel = nss.build_kernel(logprior_fn, loglik_fn, num_inner_steps=2) + + # Create initial state with proper inner_kernel_params + num_live = 5 + particles = jax.random.normal(self.key, (num_live, 2)) * 0.3 + state = base.init(particles, logprior_fn, loglik_fn) + # NSS needs covariance params + cov_params = nss.compute_covariance_from_particles(state, None, {}) + state = state._replace(inner_kernel_params=cov_params) + + # Execute kernel - this tests the inner kernel execution paths + new_state, info = kernel(self.key, state) + + # Check that state is updated correctly + self.assertIsInstance(new_state, base.NSState) + self.assertIsInstance(info, base.NSInfo) + + # Should still have same number of particles + chex.assert_shape(new_state.particles, (num_live, 2)) + chex.assert_shape(new_state.loglikelihood, (num_live,)) + + # Evidence should be updated + self.assertNotEqual(new_state.logZ, state.logZ) + + def test_nss_compute_covariance_edge_cases(self): + """Test covariance computation edge cases.""" + # Test with very few particles + num_live = 3 + particles = jnp.array([[1.0], [2.0], [3.0]]) # 1D particles + + def logprior_fn(x): + return -0.5 * x**2 + + def loglik_fn(x): + return -x**2 + + state = base.init(particles, logprior_fn, loglik_fn) + + # Should handle small number of particles + params = nss.compute_covariance_from_particles(state, None, {}) + + self.assertIn("cov", params) + cov = params["cov"] + + # Should not be NaN or infinite + self.assertFalse(jnp.isnan(cov).any()) + self.assertFalse(jnp.isinf(cov).any()) + + def test_nss_direction_sampling_edge_cases(self): + """Test direction sampling edge cases.""" + # Test with nearly singular covariance + cov = jnp.array([[1e-6, 0.0], [0.0, 1e-6]]) + params = {"cov": cov} + + direction = nss.sample_direction_from_covariance(self.key, params) + + chex.assert_shape(direction, (2,)) + # Should be finite even with small covariance + self.assertTrue(jnp.all(jnp.isfinite(direction))) + + +class MissingUtilityFunctionsTest(chex.TestCase): + """Test utility functions that were missed in coverage.""" + + def setUp(self): + super().setUp() + self.key = jax.random.key(888) + + def test_missing_utility_functions(self): + """Test utility functions that weren't covered.""" + # Create test data that would exercise missing lines + + # Test with edge case data - use float dtypes + mock_info = base.NSInfo( + particles=jnp.zeros((5, 1)), + loglikelihood=jnp.array([-10.0, -5.0, -3.0, -2.0, -1.0]), # Wide range, float + loglikelihood_birth=jnp.array([-jnp.inf, -15.0, -8.0, -4.0, -2.5]), + logprior=jnp.zeros(5), + inner_kernel_info={} + ) + + # Test functions that might have missing coverage + num_live = utils.compute_num_live(mock_info) + self.assertTrue(jnp.all(num_live >= 1)) + + # Test with different shapes + logX_seq, logdX_seq = utils.logX(self.key, mock_info, shape=3) + chex.assert_shape(logX_seq, (5, 3)) + + # Test log_weights with edge cases + log_weights = utils.log_weights(self.key, mock_info, shape=2) + chex.assert_shape(log_weights, (5, 2)) + + def test_repeat_kernel_decorator(self): + """Test repeat_kernel decorator function.""" + # Simple mock kernel + @utils.repeat_kernel(3) + def mock_kernel(rng_key, state, *args): + # Just update position slightly + new_pos = state["position"] + jax.random.normal(rng_key, state["position"].shape) * 0.01 + new_state = state.copy() + new_state["position"] = new_pos + return new_state, {"step": 1} + + initial_state = {"position": jnp.array([1.0, 2.0])} + + # Test decorated kernel + final_state, infos = mock_kernel(self.key, initial_state) + + # Should have run 3 times (scan packs infos into dict structure) + self.assertIsInstance(final_state, dict) + self.assertIn("position", final_state) + chex.assert_shape(final_state["position"], (2,)) + + # Info structure depends on how scan handles the dict + self.assertIsInstance(infos, dict) + self.assertIn("step", infos) + chex.assert_shape(infos["step"], (3,)) # 3 steps recorded + + +class CompleteCoverageTest(chex.TestCase): + """Tests to achieve 100% coverage on remaining uncovered lines.""" + + def setUp(self): + super().setUp() + self.key = jax.random.key(42) + + def test_combine_dead_info(self): + """Test combine_dead_info function (utils.py lines 233-247).""" + # Create mock dead info list and live state + dead1 = base.NSInfo( + particles=jnp.array([[1.0], [2.0]]), + loglikelihood=jnp.array([0.1, 0.2]), + loglikelihood_birth=jnp.array([-jnp.inf, 0.05]), + logprior=jnp.array([-1.0, -1.1]), + inner_kernel_info={"test": "value1"} + ) + + dead2 = base.NSInfo( + particles=jnp.array([[3.0], [4.0]]), + loglikelihood=jnp.array([0.3, 0.4]), + loglikelihood_birth=jnp.array([0.25, 0.35]), + logprior=jnp.array([-1.2, -1.3]), + inner_kernel_info={"test": "value2"} + ) + + # Mock live state + live = base.NSState( + particles=jnp.array([[5.0], [6.0]]), + loglikelihood=jnp.array([0.5, 0.6]), + loglikelihood_birth=jnp.array([0.45, 0.55]), + logprior=jnp.array([-1.4, -1.5]), + pid=jnp.array([4, 5]), + logX=-2.0, + logZ=-5.0, + logZ_live=-3.0, + inner_kernel_params={} + ) + + combined = utils.combine_dead_info([dead1, dead2], live) + + # Should combine all dead + live particles + expected_total = 2 + 2 + 2 # dead1 + dead2 + live + chex.assert_shape(combined.particles, (expected_total, 1)) + chex.assert_shape(combined.loglikelihood, (expected_total,)) + + def test_sample_particles(self): + """Test sample_particles function (utils.py lines 303-311).""" + # Create mock dead info + mock_info = base.NSInfo( + particles=jnp.array([[1.0], [2.0], [3.0], [4.0]]), + loglikelihood=jnp.array([0.1, 0.3, 0.2, 0.4]), + loglikelihood_birth=jnp.array([-jnp.inf, 0.05, 0.15, 0.25]), + logprior=jnp.array([-1.0, -1.1, -1.2, -1.3]), + inner_kernel_info={} + ) + + # Sample particles + sampled = utils.sample_particles(self.key, mock_info, shape=6) + + chex.assert_shape(sampled, (6, 1)) + + def test_uniform_prior_function(self): + """Test uniform_prior function (utils.py lines 381-398).""" + bounds = {"x": (-2.0, 2.0), "y": (0.0, 1.0)} + num_live = 10 + + particles, logprior_fn = utils.uniform_prior(self.key, bounds, num_live) + + # Check particles structure + self.assertIn("x", particles) + self.assertIn("y", particles) + chex.assert_shape(particles["x"], (num_live,)) + chex.assert_shape(particles["y"], (num_live,)) + + # Check bounds are respected + self.assertTrue(jnp.all(particles["x"] >= -2.0)) + self.assertTrue(jnp.all(particles["x"] <= 2.0)) + self.assertTrue(jnp.all(particles["y"] >= 0.0)) + self.assertTrue(jnp.all(particles["y"] <= 1.0)) + + # Test logprior function + test_params = {"x": 0.5, "y": 0.3} + logprior_val = logprior_fn(test_params) + self.assertIsInstance(logprior_val, (float, jax.Array)) + self.assertTrue(jnp.isfinite(logprior_val)) + + def test_adaptive_kernel_missing_lines(self): + """Test adaptive kernel missing coverage (adaptive.py lines 148-154).""" + from blackjax.ns import adaptive + + def logprior_fn(x): + return -0.5 * jnp.sum(x**2) + + def loglik_fn(x): + return -jnp.sum(x**2) + + def mock_inner_kernel(rng_key, inner_state, logprior_fn, loglik_fn, loglik_0, params): + # Mock that always accepts + new_inner_state = base.PartitionedState( + inner_state.position + 0.01, + inner_state.logprior, + inner_state.loglikelihood + 0.1 + ) + return new_inner_state, {"accepted": True} + + def mock_update_fn(state, info, params): + # This should exercise the missing lines in adaptive.py + return {"updated": True, "cov": jnp.eye(2)} + + # Create state + num_live = 5 + particles = jax.random.normal(self.key, (num_live, 2)) * 0.1 + state = base.init(particles, logprior_fn, loglik_fn) + + # Build adaptive kernel + delete_fn = functools.partial(base.delete_fn, num_delete=1) + kernel = adaptive.build_kernel( + logprior_fn, loglik_fn, delete_fn, mock_inner_kernel, mock_update_fn + ) + + # Execute to test missing lines + new_state, info = kernel(self.key, state) + + self.assertIsInstance(new_state, base.NSState) + # Check that update function was called (adaptive logic) + self.assertIn("updated", new_state.inner_kernel_params) + + +if __name__ == "__main__": + absltest.main() \ No newline at end of file From 7907f86b3e27edc2176a1cdb12d73ff7622a751f Mon Sep 17 00:00:00 2001 From: Will Handley Date: Sat, 14 Jun 2025 18:09:26 +0100 Subject: [PATCH 14/14] Fix test suite failures and improve API documentation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Merge PR #15 fix for 1D covariance matrix shape issue in nss.py - Fix missing imports and utility functions in test files - Ensure constraint_fn and logdensity_fn return consistent shapes: * logdensity functions now return scalars using jnp.sum() * constraint functions return properly typed arrays * Fix empty constraint array dtypes for JAX while_loop compatibility - Improve docstrings with proper type annotations: * Replace vague "positions" terminology with "parameter sample" * Add clear Callable type signatures for all function parameters * Document shape and dtype requirements explicitly - Implement missing utility functions: combine_dead_info, sample_particles - All 111 tests now passing (62 nested sampling + 49 slice sampling) 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude --- blackjax/mcmc/ss.py | 38 +-- blackjax/ns/nss.py | 20 +- blackjax/ns/utils.py | 67 +++++ tests/mcmc/test_slice_sampling.py | 286 +++++++++++--------- tests/mcmc/test_slice_sampling_units.py | 188 +++++++------ tests/ns/test_nested_sampling.py | 215 +++++++-------- tests/ns/test_nested_sampling_units.py | 333 ++++++++++++------------ 7 files changed, 649 insertions(+), 498 deletions(-) diff --git a/blackjax/mcmc/ss.py b/blackjax/mcmc/ss.py index e0191bb0e..e2486081a 100644 --- a/blackjax/mcmc/ss.py +++ b/blackjax/mcmc/ss.py @@ -137,6 +137,9 @@ def build_kernel( the log-density function, direction `d`, constraint function, constraint values, and strict flags, and returns a new `SliceState` and `SliceInfo`. + The log-density function signature: `Callable[[ArrayTree], float]` + The constraint function signature: `Callable[[ArrayTree], Array]` + References ---------- .. [1] Neal, R. M. (2003). Slice sampling. The Annals of Statistics, 31(3), 705-767. @@ -235,23 +238,26 @@ def horizontal_slice( stepper_fn A function `(x0, d, t) -> x_new` that computes a new point by moving `t` units along direction `d` from `x0`. - logdensity_fn + logdensity_fn : Callable[[ArrayTree], float] The log-density function of the target distribution. - constraint_fn - A function that evaluates additional constraints on the position beyond - the target distribution. Takes a position (PyTree) and returns an array - of constraint values. These values are compared against `constraint` - thresholds to determine if a position is acceptable. For example, in - nested sampling, this could evaluate the log-likelihood to ensure it - exceeds a minimum threshold. - constraint - An array of constraint threshold values that must be satisfied. - Each constraint value from `constraint_fn(x)` is compared against the - corresponding threshold in this array. - strict - An array of boolean flags indicating whether each constraint should be - strict (constraint_fn(x) > constraint) or non-strict - (constraint_fn(x) >= constraint). + Takes a parameter sample and returns a scalar log-density value. + constraint_fn : Callable[[ArrayTree], Array] + A function that evaluates additional constraints on parameter samples. + Takes a parameter sample and returns an array of constraint values that + are compared against `constraint` thresholds. For example, in nested + sampling, this could evaluate the log-likelihood to ensure it exceeds + a minimum threshold. + + The output array must have consistent shape and dtype (float32) for all + inputs, matching the `constraint` parameter shape exactly. + Use `jnp.array([], dtype=jnp.float32)` for no constraints. + constraint : Array + Array of constraint threshold values that must be satisfied. + Must have the same shape as the output of `constraint_fn`. + strict : Array + Boolean array indicating whether each constraint should be strict + (constraint_fn(x) > constraint) or non-strict (constraint_fn(x) >= constraint). + Must have the same shape as `constraint`. Returns ------- diff --git a/blackjax/ns/nss.py b/blackjax/ns/nss.py index aebd69d0f..597967a78 100644 --- a/blackjax/ns/nss.py +++ b/blackjax/ns/nss.py @@ -158,10 +158,12 @@ def build_kernel( Parameters ---------- - logprior_fn - A function that computes the log-prior probability of a single particle. - loglikelihood_fn - A function that computes the log-likelihood of a single particle. + logprior_fn : Callable[[ArrayTree], float] + A function that computes the log-prior probability of a parameter sample. + Takes a parameter sample and returns a scalar log-prior value. + loglikelihood_fn : Callable[[ArrayTree], float] + A function that computes the log-likelihood of a parameter sample. + Takes a parameter sample and returns a scalar log-likelihood value. num_inner_steps The number of HRSS steps to run for each new particle generation. This should be a multiple of the dimension of the parameter space. @@ -245,10 +247,12 @@ def as_top_level_api( Parameters ---------- - logprior_fn - A function that computes the log-prior probability of a single particle. - loglikelihood_fn - A function that computes the log-likelihood of a single particle. + logprior_fn : Callable[[ArrayTree], float] + A function that computes the log-prior probability of a parameter sample. + Takes a parameter sample and returns a scalar log-prior value. + loglikelihood_fn : Callable[[ArrayTree], float] + A function that computes the log-likelihood of a parameter sample. + Takes a parameter sample and returns a scalar log-likelihood value. num_inner_steps The number of HRSS steps to run for each new particle generation. This should be a multiple of the dimension of the parameter space. diff --git a/blackjax/ns/utils.py b/blackjax/ns/utils.py index 15f3ef35b..d3394c0cf 100644 --- a/blackjax/ns/utils.py +++ b/blackjax/ns/utils.py @@ -396,3 +396,70 @@ def prior_sample(rng_key): particles = jax.vmap(prior_sample)(init_keys) return particles, logprior_fn + + +def combine_dead_info(dead_info_list, live_info): + """Combine multiple dead info structures with live info. + + Parameters + ---------- + dead_info_list + List of NSInfo structures containing dead particles. + live_info + NSInfo structure containing live particles. + + Returns + ------- + NSInfo + Combined NSInfo structure with all particles. + """ + # Combine all dead particles + all_particles = [] + all_loglikelihood = [] + all_loglikelihood_birth = [] + all_logprior = [] + + for dead_info in dead_info_list: + all_particles.append(dead_info.particles) + all_loglikelihood.append(dead_info.loglikelihood) + all_loglikelihood_birth.append(dead_info.loglikelihood_birth) + all_logprior.append(dead_info.logprior) + + # Add live particles + all_particles.append(live_info.particles) + all_loglikelihood.append(live_info.loglikelihood) + all_loglikelihood_birth.append(live_info.loglikelihood_birth) + all_logprior.append(live_info.logprior) + + # Concatenate all arrays + from blackjax.ns.base import NSInfo + + return NSInfo( + particles=jnp.concatenate(all_particles, axis=0), + loglikelihood=jnp.concatenate(all_loglikelihood), + loglikelihood_birth=jnp.concatenate(all_loglikelihood_birth), + logprior=jnp.concatenate(all_logprior), + inner_kernel_info={}, + ) + + +def sample_particles(rng_key, info, shape): + """Sample particles from NSInfo structure. + + Parameters + ---------- + rng_key + JAX PRNG key. + info + NSInfo structure containing particles. + shape + Number of particles to sample. + + Returns + ------- + ArrayTree + Sampled particles. + """ + num_particles = info.particles.shape[0] + indices = jax.random.choice(rng_key, num_particles, shape=(shape,), replace=True) + return info.particles[indices] diff --git a/tests/mcmc/test_slice_sampling.py b/tests/mcmc/test_slice_sampling.py index ab36ed9b6..f89be12b0 100644 --- a/tests/mcmc/test_slice_sampling.py +++ b/tests/mcmc/test_slice_sampling.py @@ -1,6 +1,4 @@ """Test the Hit-and-Run Slice Sampling algorithm.""" -import functools - import chex import jax import jax.numpy as jnp @@ -30,15 +28,15 @@ def test_slice_state_structure(self): """Test SliceState structure and initialization.""" position = jnp.array([1.0, -0.5, 2.0]) state = ss.init(position, self.logdensity_normal) - + # Check structure self.assertIsInstance(state, ss.SliceState) chex.assert_trees_all_close(state.position, position) - + # Check logdensity is computed correctly expected_logdens = self.logdensity_normal(position) chex.assert_trees_all_close(state.logdensity, expected_logdens) - + # Check default logslice self.assertEqual(state.logslice, jnp.inf) @@ -50,9 +48,9 @@ def test_slice_info_structure(self): r_steps=5, s_steps=7, evals=15, - d=jnp.array([0.5, -0.2]) + d=jnp.array([0.5, -0.2]), ) - + chex.assert_shape(info.constraint, (2,)) self.assertEqual(info.l_steps, 3) self.assertEqual(info.r_steps, 5) @@ -64,45 +62,52 @@ def test_vertical_slice(self): """Test vertical slice height sampling.""" position = jnp.array([0.0]) state = ss.init(position, self.logdensity_normal) - + # Test multiple samples n_samples = 1000 keys = jax.random.split(self.key, n_samples) - + new_states, infos = jax.vmap(ss.vertical_slice, in_axes=(0, None))(keys, state) - + # Heights should be below current logdensity logdens_at_pos = self.logdensity_normal(position) self.assertTrue(jnp.all(new_states.logslice <= logdens_at_pos)) - + # Mean should be approximately logdens - 1 (E[log(U)] = -1) mean_height = jnp.mean(new_states.logslice) expected_mean = logdens_at_pos - 1.0 chex.assert_trees_all_close(mean_height, expected_mean, atol=0.1) - + # Check info structure - self.assertTrue(jnp.all(infos.evals == 0)) # Vertical slice doesn't eval logdensity + self.assertTrue( + jnp.all(infos.evals == 0) + ) # Vertical slice doesn't eval logdensity @parameterized.parameters([1, 2, 5]) def test_slice_sampling_dimensions(self, ndim): """Test slice sampling in different dimensions.""" position = jnp.zeros(ndim) state = ss.init(position, self.logdensity_normal) - + # Test with simple direction and stepper direction = jax.random.normal(self.key, (ndim,)) direction = direction / jnp.linalg.norm(direction) - + kernel = ss.build_kernel(ss.default_stepper_fn) - + def dummy_constraint_fn(x): return jnp.array([]) - + new_state, info = kernel( - self.key, state, self.logdensity_normal, direction, - dummy_constraint_fn, jnp.array([]), jnp.array([]) + self.key, + state, + self.logdensity_normal, + direction, + dummy_constraint_fn, + jnp.array([]), + jnp.array([]), ) - + chex.assert_shape(new_state.position, (ndim,)) self.assertIsInstance(new_state.logdensity, (float, jax.Array)) self.assertIsInstance(info, ss.SliceInfo) @@ -111,18 +116,23 @@ def test_1d_slice_sampling(self): """Test 1D slice sampling (edge case for JAX shapes).""" position = jnp.array(0.5) # 1D scalar state = ss.init(position, lambda x: -0.5 * x**2) - + direction = jnp.array(1.0) # 1D direction kernel = ss.build_kernel(ss.default_stepper_fn) - + def dummy_constraint_fn(x): return jnp.array([]) - + new_state, info = kernel( - self.key, state, lambda x: -0.5 * x**2, direction, - dummy_constraint_fn, jnp.array([]), jnp.array([]) + self.key, + state, + lambda x: -0.5 * x**2, + direction, + dummy_constraint_fn, + jnp.array([]), + jnp.array([]), ) - + # Check it runs without shape errors self.assertIsInstance(new_state.logdensity, (float, jax.Array)) self.assertIsInstance(info.evals, (int, jax.Array)) @@ -132,10 +142,10 @@ def test_default_stepper_fn(self): x = jnp.array([1.0, 2.0, -1.5]) d = jnp.array([0.5, -0.3, 0.8]) t = 2.5 - + result = ss.default_stepper_fn(x, d, t) expected = x + t * d - + chex.assert_trees_all_close(result, expected) def test_stepper_fn_with_pytrees(self): @@ -143,9 +153,9 @@ def test_stepper_fn_with_pytrees(self): x = {"a": jnp.array([1.0, 2.0]), "b": jnp.array([-0.5])} d = {"a": jnp.array([0.3, -0.2]), "b": jnp.array([0.7])} t = 1.5 - + result = ss.default_stepper_fn(x, d, t) - + chex.assert_trees_all_close(result["a"], x["a"] + t * d["a"]) chex.assert_trees_all_close(result["b"], x["b"] + t * d["b"]) @@ -161,27 +171,32 @@ def test_constrained_sampling(self): """Test slice sampling respects constraints.""" # Start in valid region (x > 0) position = jnp.array([1.0, 2.0]) - + def constrained_logdens(x): return jnp.where(jnp.all(x > 0), -0.5 * jnp.sum(x**2), -jnp.inf) - + state = ss.init(position, constrained_logdens) direction = jnp.array([1.0, -0.5]) # Could lead outside valid region - + kernel = ss.build_kernel(ss.default_stepper_fn) - + # Test with constraint function def constraint_fn(x): return x # Return position values to check > 0 - + constraint_thresholds = jnp.array([0.0, 0.0]) # Must be > 0 strict_flags = jnp.array([True, True]) # Strict inequality - + new_state, info = kernel( - self.key, state, constrained_logdens, direction, - constraint_fn, constraint_thresholds, strict_flags + self.key, + state, + constrained_logdens, + direction, + constraint_fn, + constraint_thresholds, + strict_flags, ) - + # Should remain in valid region self.assertTrue(jnp.all(new_state.position > 0)) self.assertFalse(jnp.isneginf(new_state.logdensity)) @@ -189,27 +204,32 @@ def constraint_fn(x): def test_constraint_evaluation_ordering(self): """Test that constraints are evaluated correctly.""" position = jnp.array([0.5]) - + def logdens(x): - return -0.5 * x**2 - + return jnp.sum(-0.5 * x**2) + state = ss.init(position, logdens) direction = jnp.array([1.0]) - + kernel = ss.build_kernel(ss.default_stepper_fn) - + # Constraint that evaluates a simple function def constraint_fn(x): - return jnp.array([x[0]**2]) # Square of position - + return jnp.array([x[0] ** 2]) # Square of position + constraint_threshold = jnp.array([0.25]) # x^2 > 0.25, so |x| > 0.5 strict_flag = jnp.array([True]) - + new_state, info = kernel( - self.key, state, logdens, direction, - constraint_fn, constraint_threshold, strict_flag + self.key, + state, + logdens, + direction, + constraint_fn, + constraint_threshold, + strict_flag, ) - + # Check constraint is satisfied constraint_val = constraint_fn(new_state.position) self.assertTrue(jnp.all(constraint_val > constraint_threshold)) @@ -217,30 +237,29 @@ def constraint_fn(x): def test_multiple_constraints(self): """Test multiple constraints simultaneously.""" position = jnp.array([1.0, 1.5]) - + def logdens(x): return -0.5 * jnp.sum(x**2) - + state = ss.init(position, logdens) direction = jnp.array([0.7, -0.3]) - + kernel = ss.build_kernel(ss.default_stepper_fn) - + def constraint_fn(x): return jnp.array([x[0], x[1], jnp.sum(x)]) # Multiple constraints - + constraints = jnp.array([0.2, 0.1, 1.0]) # x[0] > 0.2, x[1] > 0.1, sum > 1.0 strict = jnp.array([True, True, False]) # Mixed strict/non-strict - + new_state, info = kernel( - self.key, state, logdens, direction, - constraint_fn, constraints, strict + self.key, state, logdens, direction, constraint_fn, constraints, strict ) - + # Check all constraints are satisfied constraint_vals = constraint_fn(new_state.position) self.assertTrue(constraint_vals[0] > constraints[0]) # Strict - self.assertTrue(constraint_vals[1] > constraints[1]) # Strict + self.assertTrue(constraint_vals[1] > constraints[1]) # Strict self.assertTrue(constraint_vals[2] >= constraints[2]) # Non-strict @@ -257,14 +276,12 @@ def logdensity_normal(self, x): def test_direction_generation_from_covariance(self): """Test direction generation from covariance matrix.""" ndim = 3 - cov = jnp.array([[2.0, 0.5, 0.0], - [0.5, 1.5, -0.3], - [0.0, -0.3, 1.0]]) - + cov = jnp.array([[2.0, 0.5, 0.0], [0.5, 1.5, -0.3], [0.0, -0.3, 1.0]]) + direction = ss.sample_direction_from_covariance(self.key, cov) - + chex.assert_shape(direction, (ndim,)) - + # Check Mahalanobis normalization invcov = jnp.linalg.inv(cov) mahal_norm = jnp.sqrt(jnp.einsum("i,ij,j", direction, invcov, direction)) @@ -274,30 +291,31 @@ def test_direction_generation_identity_covariance(self): """Test direction generation with identity covariance.""" ndim = 4 cov = jnp.eye(ndim) - + direction = ss.sample_direction_from_covariance(self.key, cov) - + chex.assert_shape(direction, (ndim,)) - + # With identity covariance, should be unit normalized euclidean_norm = jnp.linalg.norm(direction) chex.assert_trees_all_close(euclidean_norm, 1.0, atol=1e-6) def test_hrss_kernel_construction(self): """Test HRSS kernel construction.""" + def direction_fn(rng_key): return jax.random.normal(rng_key, (2,)) - + kernel = ss.build_hrss_kernel(direction_fn, ss.default_stepper_fn) - + self.assertTrue(callable(kernel)) - + # Test kernel execution position = jnp.array([0.0, 1.0]) state = ss.init(position, self.logdensity_normal) - + new_state, info = kernel(self.key, state, self.logdensity_normal) - + chex.assert_shape(new_state.position, (2,)) self.assertIsInstance(info, ss.SliceInfo) @@ -305,41 +323,41 @@ def test_hrss_top_level_api(self): """Test hit-and-run slice sampling top-level API.""" ndim = 2 cov = jnp.eye(ndim) * 1.5 - + algorithm = ss.hrss_as_top_level_api(self.logdensity_normal, cov) - + # Check it's a proper SamplingAlgorithm self.assertIsInstance(algorithm, blackjax.base.SamplingAlgorithm) self.assertTrue(hasattr(algorithm, "init")) self.assertTrue(hasattr(algorithm, "step")) - + # Test initialization position = jnp.array([1.0, -0.5]) state = algorithm.init(position) - + self.assertIsInstance(state, ss.SliceState) chex.assert_trees_all_close(state.position, position) - + # Test step new_state, info = algorithm.step(self.key, state) - + chex.assert_shape(new_state.position, (ndim,)) self.assertIsInstance(info, ss.SliceInfo) def test_hrss_1d_case(self): """Test HRSS with 1D problem.""" cov = jnp.array([[1.0]]) # 1x1 covariance matrix - + def logdens_1d(x): - return -0.5 * x**2 - + return jnp.sum(-0.5 * x**2) + algorithm = ss.hrss_as_top_level_api(logdens_1d, cov) - + position = jnp.array([0.5]) state = algorithm.init(position) - + new_state, info = algorithm.step(self.key, state) - + chex.assert_shape(new_state.position, (1,)) self.assertIsInstance(new_state.logdensity, (float, jax.Array)) @@ -353,68 +371,70 @@ def setUp(self): def test_slice_sampling_mean_estimation(self): """Test that HRSS correctly estimates mean of target distribution.""" + # Target: standard normal, should have mean ≈ 0 def logdens(x): return stats.norm.logpdf(x).sum() - + cov = jnp.eye(1) algorithm = ss.hrss_as_top_level_api(logdens, cov) - + # Run short chain n_samples = 200 # Modest for testing position = jnp.array([0.0]) state = algorithm.init(position) - + samples = [] keys = jax.random.split(self.key, n_samples) - + for i, sample_key in enumerate(keys): state, info = algorithm.step(sample_key, state) if i >= 50: # Skip some burn-in samples.append(state.position[0]) - + samples = jnp.array(samples) - + # Basic sanity checks self.assertFalse(jnp.any(jnp.isnan(samples))) self.assertFalse(jnp.any(jnp.isinf(samples))) - + # Statistical checks (very loose for small sample size) sample_mean = jnp.mean(samples) sample_std = jnp.std(samples) - + # Mean should be reasonable self.assertLess(jnp.abs(sample_mean), 0.5) # Loose bound - + # Standard deviation should be positive and reasonable self.assertGreater(sample_std, 0.1) self.assertLess(sample_std, 3.0) def test_slice_sampling_multimodal(self): """Test slice sampling on multimodal distribution.""" + def logdens_bimodal(x): # Mixture of two Gaussians at -2 and +2 mode1 = stats.norm.logpdf(x - 2.0) mode2 = stats.norm.logpdf(x + 2.0) return jnp.logaddexp(mode1, mode2).sum() - + cov = jnp.eye(1) * 4.0 # Wider proposals for multimodal algorithm = ss.hrss_as_top_level_api(logdens_bimodal, cov) - + # Run chain n_samples = 100 position = jnp.array([1.0]) # Start near one mode state = algorithm.init(position) - + samples = [] keys = jax.random.split(self.key, n_samples) - + for sample_key in keys: state, info = algorithm.step(sample_key, state) samples.append(state.position[0]) - + samples = jnp.array(samples) - + # Check basic properties self.assertFalse(jnp.any(jnp.isnan(samples))) sample_range = jnp.max(samples) - jnp.min(samples) @@ -422,39 +442,40 @@ def logdens_bimodal(x): def test_slice_info_diagnostics(self): """Test that slice info provides useful diagnostics.""" + def logdens(x): return -0.5 * jnp.sum(x**2) - + cov = jnp.eye(2) algorithm = ss.hrss_as_top_level_api(logdens, cov) - + position = jnp.array([0.0, 0.0]) state = algorithm.init(position) - + # Collect diagnostics from multiple steps infos = [] keys = jax.random.split(self.key, 20) - + for sample_key in keys: state, info = algorithm.step(sample_key, state) infos.append(info) - + # Check diagnostic fields l_steps = jnp.array([info.l_steps for info in infos]) r_steps = jnp.array([info.r_steps for info in infos]) s_steps = jnp.array([info.s_steps for info in infos]) evals = jnp.array([info.evals for info in infos]) - + # All should be non-negative self.assertTrue(jnp.all(l_steps >= 0)) self.assertTrue(jnp.all(r_steps >= 0)) self.assertTrue(jnp.all(s_steps >= 0)) self.assertTrue(jnp.all(evals >= 0)) - + # Total evaluations should be sum of expansion + shrinking expected_evals = l_steps + r_steps + s_steps chex.assert_trees_all_close(evals, expected_evals) - + # Direction vectors should be present directions = jnp.array([info.d for info in infos]) chex.assert_shape(directions, (20, 2)) @@ -472,7 +493,7 @@ def test_zero_covariance_matrix(self): """Test behavior with singular covariance matrix.""" # This should handle gracefully or raise informative error cov = jnp.zeros((2, 2)) - + # JAX's linalg.inv will produce NaN/Inf for singular matrices # rather than raising an error, so check for that try: @@ -485,58 +506,69 @@ def test_zero_covariance_matrix(self): def test_very_peaked_distribution(self): """Test with very peaked/narrow distribution.""" + def logdens_peaked(x): return -100.0 * jnp.sum(x**2) # Very narrow - + cov = jnp.eye(1) * 0.01 # Small proposals algorithm = ss.hrss_as_top_level_api(logdens_peaked, cov) - + position = jnp.array([0.01]) state = algorithm.init(position) - + # Should handle without numerical issues new_state, info = algorithm.step(self.key, state) - + self.assertFalse(jnp.isnan(new_state.logdensity)) self.assertFalse(jnp.isinf(new_state.logdensity)) def test_large_step_proposals(self): """Test with very large step proposals.""" + def logdens(x): return -0.5 * jnp.sum(x**2) - + cov = jnp.eye(1) * 100.0 # Very large proposals algorithm = ss.hrss_as_top_level_api(logdens, cov) - + position = jnp.array([0.0]) state = algorithm.init(position) - + # Should still work (though possibly inefficient) new_state, info = algorithm.step(self.key, state) - + self.assertFalse(jnp.isnan(new_state.position).any()) self.assertGreater(info.evals, 0) # Should do some work def test_empty_constraint_arrays(self): """Test with empty constraint arrays.""" position = jnp.array([1.0]) - state = ss.init(position, lambda x: -0.5 * x**2) + + def scalar_logdens(x): + return jnp.sum(-0.5 * x**2) + + state = ss.init(position, scalar_logdens) direction = jnp.array([1.0]) - + kernel = ss.build_kernel(ss.default_stepper_fn) - + def empty_constraint_fn(x): - return jnp.array([]) - + return jnp.array([], dtype=jnp.float32) + # Should handle empty constraints gracefully new_state, info = kernel( - self.key, state, lambda x: -0.5 * x**2, direction, - empty_constraint_fn, jnp.array([]), jnp.array([]) + self.key, + state, + scalar_logdens, + direction, + empty_constraint_fn, + jnp.array([], dtype=jnp.float32), + jnp.array([], dtype=bool), ) - + self.assertIsInstance(new_state, ss.SliceState) chex.assert_shape(info.constraint, (0,)) if __name__ == "__main__": - absltest.main() \ No newline at end of file + absltest.main() diff --git a/tests/mcmc/test_slice_sampling_units.py b/tests/mcmc/test_slice_sampling_units.py index 3c81e6cd5..812de58eb 100644 --- a/tests/mcmc/test_slice_sampling_units.py +++ b/tests/mcmc/test_slice_sampling_units.py @@ -14,13 +14,13 @@ def test_slice_state_creation(self): """Test SliceState creation and default values.""" position = jnp.array([1.0, 2.0]) logdensity = -3.5 - + # Test with default logslice state = ss.SliceState(position, logdensity) chex.assert_trees_all_close(state.position, position) self.assertEqual(state.logdensity, logdensity) self.assertEqual(state.logslice, jnp.inf) - + # Test with explicit logslice logslice = -1.2 state = ss.SliceState(position, logdensity, logslice) @@ -29,7 +29,7 @@ def test_slice_state_creation(self): def test_slice_state_replace(self): """Test SliceState _replace method.""" state = ss.SliceState(jnp.array([1.0]), -2.0, -5.0) - + new_state = state._replace(logslice=-3.0) self.assertEqual(new_state.logslice, -3.0) self.assertEqual(new_state.logdensity, -2.0) # Unchanged @@ -49,14 +49,17 @@ def test_slice_info_creation(self): self.assertEqual(info.s_steps, 0) self.assertEqual(info.evals, 0) self.assertIsNone(info.d) - + # Test with explicit values constraint = jnp.array([1.0, 2.0]) direction = jnp.array([0.5, -0.3]) info = ss.SliceInfo( constraint=constraint, - l_steps=3, r_steps=5, s_steps=7, evals=15, - d=direction + l_steps=3, + r_steps=5, + s_steps=7, + evals=15, + d=direction, ) chex.assert_trees_all_close(info.constraint, constraint) self.assertEqual(info.l_steps, 3) @@ -73,15 +76,17 @@ def setUp(self): super().setUp() self.logdensity_fn = lambda x: -0.5 * jnp.sum(x**2) - @parameterized.parameters([ - (jnp.array([0.0]),), - (jnp.array([1.5, -2.0]),), - (jnp.array([[1.0, 2.0], [3.0, 4.0]]),), - ]) + @parameterized.parameters( + [ + (jnp.array([0.0]),), + (jnp.array([1.5, -2.0]),), + (jnp.array([[1.0, 2.0], [3.0, 4.0]]),), + ] + ) def test_init_shapes(self, position): """Test init with different position shapes.""" state = ss.init(position, self.logdensity_fn) - + chex.assert_trees_all_close(state.position, position) expected_logdens = self.logdensity_fn(position) chex.assert_trees_all_close(state.logdensity, expected_logdens) @@ -90,12 +95,12 @@ def test_init_shapes(self, position): def test_init_with_pytree(self): """Test init with PyTree position.""" position = {"a": jnp.array([1.0, 2.0]), "b": jnp.array([3.0])} - + def logdens_pytree(x): - return -0.5 * (jnp.sum(x["a"]**2) + jnp.sum(x["b"]**2)) - + return -0.5 * (jnp.sum(x["a"] ** 2) + jnp.sum(x["b"] ** 2)) + state = ss.init(position, logdens_pytree) - + chex.assert_trees_all_close(state.position, position) expected_logdens = logdens_pytree(position) self.assertEqual(state.logdensity, expected_logdens) @@ -113,14 +118,14 @@ def test_vertical_slice_height_bounds(self): position = jnp.array([0.0]) logdensity = -1.5 state = ss.SliceState(position, logdensity) - + # Test multiple samples keys = jax.random.split(self.key, 100) new_states, infos = jax.vmap(ss.vertical_slice, in_axes=(0, None))(keys, state) - + # All slice heights should be <= logdensity self.assertTrue(jnp.all(new_states.logslice <= logdensity)) - + # Info should have zero evaluations (vertical slice doesn't eval logdensity) self.assertTrue(jnp.all(infos.evals == 0)) @@ -129,16 +134,16 @@ def test_vertical_slice_deterministic_bound(self): position = jnp.array([0.0]) logdensity = -2.0 state = ss.SliceState(position, logdensity) - + # Generate many samples n_samples = 5000 keys = jax.random.split(self.key, n_samples) new_states, _ = jax.vmap(ss.vertical_slice, in_axes=(0, None))(keys, state) - + # Mean of log(U) where U ~ Uniform(0,1) is -1 mean_height = jnp.mean(new_states.logslice) expected_mean = logdensity - 1.0 - + # Should be close to expected mean (loose tolerance for finite sample) self.assertAlmostEqual(mean_height, expected_mean, delta=0.1) @@ -147,9 +152,9 @@ def test_vertical_slice_preserves_position(self): position = jnp.array([1.5, -0.5]) logdensity = -3.2 state = ss.SliceState(position, logdensity) - + new_state, info = ss.vertical_slice(self.key, state) - + chex.assert_trees_all_close(new_state.position, position) self.assertEqual(new_state.logdensity, logdensity) self.assertNotEqual(new_state.logslice, jnp.inf) # Should be updated @@ -163,10 +168,10 @@ def test_default_stepper_array(self): x = jnp.array([1.0, 2.0]) d = jnp.array([0.5, -0.3]) t = 2.5 - + result = ss.default_stepper_fn(x, d, t) expected = x + t * d - + chex.assert_trees_all_close(result, expected) def test_default_stepper_scalar(self): @@ -174,10 +179,10 @@ def test_default_stepper_scalar(self): x = 3.0 d = -1.2 t = 0.8 - + result = ss.default_stepper_fn(x, d, t) expected = x + t * d - + self.assertEqual(result, expected) def test_default_stepper_pytree(self): @@ -185,9 +190,9 @@ def test_default_stepper_pytree(self): x = {"a": jnp.array([1.0, 2.0]), "b": jnp.array([3.0])} d = {"a": jnp.array([0.1, -0.2]), "b": jnp.array([0.5])} t = 1.5 - + result = ss.default_stepper_fn(x, d, t) - + chex.assert_trees_all_close(result["a"], x["a"] + t * d["a"]) chex.assert_trees_all_close(result["b"], x["b"] + t * d["b"]) @@ -196,7 +201,7 @@ def test_stepper_zero_step(self): x = jnp.array([1.0, 2.0, 3.0]) d = jnp.array([10.0, -5.0, 2.0]) t = 0.0 - + result = ss.default_stepper_fn(x, d, t) chex.assert_trees_all_close(result, x) @@ -212,11 +217,11 @@ def test_sample_direction_identity_covariance(self): """Test direction sampling with identity covariance.""" ndim = 3 cov = jnp.eye(ndim) - + direction = ss.sample_direction_from_covariance(self.key, cov) - + chex.assert_shape(direction, (ndim,)) - + # With identity covariance, should be unit normalized norm = jnp.linalg.norm(direction) chex.assert_trees_all_close(norm, 1.0, atol=1e-6) @@ -226,11 +231,11 @@ def test_sample_direction_scaled_covariance(self): ndim = 2 scale = 4.0 cov = jnp.eye(ndim) * scale - + direction = ss.sample_direction_from_covariance(self.key, cov) - + chex.assert_shape(direction, (ndim,)) - + # Check Mahalanobis normalization invcov = jnp.linalg.inv(cov) mahal_norm = jnp.sqrt(jnp.einsum("i,ij,j", direction, invcov, direction)) @@ -239,11 +244,11 @@ def test_sample_direction_scaled_covariance(self): def test_sample_direction_general_covariance(self): """Test direction sampling with general covariance matrix.""" cov = jnp.array([[2.0, 0.5], [0.5, 1.0]]) - + direction = ss.sample_direction_from_covariance(self.key, cov) - + chex.assert_shape(direction, (2,)) - + # Check Mahalanobis normalization invcov = jnp.linalg.inv(cov) mahal_norm = jnp.sqrt(jnp.einsum("i,ij,j", direction, invcov, direction)) @@ -252,11 +257,11 @@ def test_sample_direction_general_covariance(self): def test_sample_direction_1d(self): """Test direction sampling for 1D case.""" cov = jnp.array([[2.0]]) - + direction = ss.sample_direction_from_covariance(self.key, cov) - + chex.assert_shape(direction, (1,)) - + # Check Mahalanobis normalization (should be 1) invcov = jnp.linalg.inv(cov) mahal_norm = jnp.sqrt(jnp.einsum("i,ij,j", direction, invcov, direction)) @@ -266,15 +271,17 @@ def test_sample_direction_multiple_samples(self): """Test that multiple direction samples are different.""" cov = jnp.eye(2) keys = jax.random.split(self.key, 10) - - directions = jax.vmap(ss.sample_direction_from_covariance, in_axes=(0, None))(keys, cov) - + + directions = jax.vmap(ss.sample_direction_from_covariance, in_axes=(0, None))( + keys, cov + ) + chex.assert_shape(directions, (10, 2)) - + # All should be unit normalized norms = jnp.linalg.norm(directions, axis=1) chex.assert_trees_all_close(norms, jnp.ones(10), atol=1e-6) - + # Should not all be the same std_of_directions = jnp.std(directions, axis=0) self.assertTrue(jnp.all(std_of_directions > 0.1)) # Some variation expected @@ -293,20 +300,26 @@ def test_horizontal_slice_basic(self): logdensity = -0.5 * position**2 logslice = -2.0 state = ss.SliceState(position, logdensity, logslice) - + direction = jnp.array([1.0]) - + def logdens_fn(x): - return -0.5 * x**2 - + return jnp.sum(-0.5 * x**2) + def constraint_fn(x): - return jnp.array([]) - + return jnp.array([], dtype=jnp.float32) + new_state, info = ss.horizontal_slice( - self.key, state, direction, ss.default_stepper_fn, - logdens_fn, constraint_fn, jnp.array([]), jnp.array([]) + self.key, + state, + direction, + ss.default_stepper_fn, + logdens_fn, + constraint_fn, + jnp.array([], dtype=jnp.float32), + jnp.array([], dtype=bool), ) - + self.assertIsInstance(new_state, ss.SliceState) self.assertIsInstance(info, ss.SliceInfo) self.assertGreater(info.evals, 0) # Should have done some evaluations @@ -316,21 +329,27 @@ def test_horizontal_slice_with_constraints(self): position = jnp.array([1.0]) state = ss.SliceState(position, -0.5, -1.0) direction = jnp.array([1.0]) - + def logdens_fn(x): - return -0.5 * x**2 - + return jnp.sum(-0.5 * x**2) + def constraint_fn(x): return jnp.array([x[0]]) # Must be positive - + constraint_thresholds = jnp.array([0.0]) strict_flags = jnp.array([True]) - + new_state, info = ss.horizontal_slice( - self.key, state, direction, ss.default_stepper_fn, - logdens_fn, constraint_fn, constraint_thresholds, strict_flags + self.key, + state, + direction, + ss.default_stepper_fn, + logdens_fn, + constraint_fn, + constraint_thresholds, + strict_flags, ) - + # Should satisfy constraints self.assertTrue(jnp.all(new_state.position > 0)) self.assertGreater(info.l_steps + info.r_steps + info.s_steps, 0) @@ -340,25 +359,31 @@ def test_horizontal_slice_info_completeness(self): position = jnp.array([0.0]) state = ss.SliceState(position, 0.0, -1.0) direction = jnp.array([1.0]) - + def logdens_fn(x): - return -x**2 - + return jnp.sum(-(x**2)) + def constraint_fn(x): - return jnp.array([x[0]**2]) - + return jnp.array([x[0] ** 2]) + new_state, info = ss.horizontal_slice( - self.key, state, direction, ss.default_stepper_fn, - logdens_fn, constraint_fn, jnp.array([0.1]), jnp.array([True]) + self.key, + state, + direction, + ss.default_stepper_fn, + logdens_fn, + constraint_fn, + jnp.array([0.1]), + jnp.array([True]), ) - + # Check all info fields are populated self.assertIsInstance(info.l_steps, (int, jax.Array)) self.assertIsInstance(info.r_steps, (int, jax.Array)) self.assertIsInstance(info.s_steps, (int, jax.Array)) self.assertIsInstance(info.evals, (int, jax.Array)) chex.assert_shape(info.constraint, (1,)) - + # Total evaluations should equal sum of steps self.assertEqual(info.evals, info.l_steps + info.r_steps + info.s_steps) @@ -368,31 +393,34 @@ class KernelBuildingTest(chex.TestCase): def test_build_kernel_callable(self): """Test that build_kernel returns a callable.""" + def simple_stepper(x, d, t): return x + t * d - + kernel = ss.build_kernel(simple_stepper) self.assertTrue(callable(kernel)) def test_build_hrss_kernel_callable(self): """Test that build_hrss_kernel returns a callable.""" + def direction_fn(rng_key): return jax.random.normal(rng_key, (2,)) - + def simple_stepper(x, d, t): return x + t * d - + kernel = ss.build_hrss_kernel(direction_fn, simple_stepper) self.assertTrue(callable(kernel)) def test_hrss_top_level_api_structure(self): """Test top-level API returns correct structure.""" + def simple_logdens(x): return -0.5 * jnp.sum(x**2) - + cov = jnp.eye(2) algorithm = ss.hrss_as_top_level_api(simple_logdens, cov) - + # Should have init and step methods self.assertTrue(hasattr(algorithm, "init")) self.assertTrue(hasattr(algorithm, "step")) @@ -401,4 +429,4 @@ def simple_logdens(x): if __name__ == "__main__": - absltest.main() \ No newline at end of file + absltest.main() diff --git a/tests/ns/test_nested_sampling.py b/tests/ns/test_nested_sampling.py index c24831121..5a1fc2446 100644 --- a/tests/ns/test_nested_sampling.py +++ b/tests/ns/test_nested_sampling.py @@ -7,7 +7,6 @@ import jax.scipy.stats as stats from absl.testing import absltest, parameterized -import blackjax from blackjax.ns import adaptive, base, nss, utils @@ -35,22 +34,22 @@ def test_ns_state_structure(self): num_live = 50 ndim = 2 particles = jax.random.normal(self.key, (num_live, ndim)) - + state = base.init(particles, self.logprior_uniform, self.loglikelihood_gaussian) - + # Check shapes chex.assert_shape(state.particles, (num_live, ndim)) chex.assert_shape(state.loglikelihood, (num_live,)) chex.assert_shape(state.logprior, (num_live,)) chex.assert_shape(state.pid, (num_live,)) - + # Check values are computed correctly expected_loglik = jax.vmap(self.loglikelihood_gaussian)(particles) expected_logprior = jax.vmap(self.logprior_uniform)(particles) - + chex.assert_trees_all_close(state.loglikelihood, expected_loglik) chex.assert_trees_all_close(state.logprior, expected_logprior) - + # Check particle IDs are unique self.assertEqual(len(jnp.unique(state.pid)), num_live) @@ -60,15 +59,15 @@ def test_ns_info_structure(self): loglik = jnp.array([0.1, 0.2]) loglik_birth = jnp.array([-jnp.inf, 0.05]) logprior = jnp.array([-1.0, -1.1]) - + info = base.NSInfo( particles=particles, loglikelihood=loglik, loglikelihood_birth=loglik_birth, logprior=logprior, - inner_kernel_info={} + inner_kernel_info={}, ) - + chex.assert_shape(info.particles, (2, 2)) chex.assert_shape(info.loglikelihood, (2,)) chex.assert_shape(info.loglikelihood_birth, (2,)) @@ -80,18 +79,18 @@ def test_delete_fn(self, num_delete): num_live = 20 particles = jax.random.normal(self.key, (num_live, 2)) state = base.init(particles, self.logprior_uniform, self.loglikelihood_gaussian) - + dead_idx, target_idx, start_idx = base.delete_fn(self.key, state, num_delete) - + # Check shapes chex.assert_shape(dead_idx, (num_delete,)) chex.assert_shape(target_idx, (num_delete,)) chex.assert_shape(start_idx, (num_delete,)) - + # Check that worst particles are selected worst_indices = jnp.argsort(state.loglikelihood)[:num_delete] chex.assert_trees_all_close(jnp.sort(dead_idx), jnp.sort(worst_indices)) - + # Check indices are valid self.assertTrue(jnp.all(dead_idx >= 0)) self.assertTrue(jnp.all(dead_idx < num_live)) @@ -102,15 +101,15 @@ def test_1d_basic_functionality(self): """Test 1D case to catch shape issues.""" num_live = 30 particles = jax.random.uniform(self.key, (num_live,), minval=-3, maxval=3) - + def logprior_1d(x): return jnp.where((x >= -3) & (x <= 3), -jnp.log(6.0), -jnp.inf) - + def loglik_1d(x): return -0.5 * x**2 - + state = base.init(particles, logprior_1d, loglik_1d) - + chex.assert_shape(state.particles, (num_live,)) chex.assert_shape(state.loglikelihood, (num_live,)) self.assertFalse(jnp.any(jnp.isnan(state.loglikelihood))) @@ -118,18 +117,21 @@ def loglik_1d(x): def test_kernel_construction(self): """Test that kernel can be constructed.""" - def mock_inner_kernel(rng_key, inner_state, logprior_fn, loglik_fn, loglik_0, params): + + def mock_inner_kernel( + rng_key, inner_state, logprior_fn, loglik_fn, loglik_0, params + ): # Simple mock that just returns the input state return inner_state, {} - + delete_fn = functools.partial(base.delete_fn, num_delete=1) kernel = base.build_kernel( self.logprior_uniform, self.loglikelihood_gaussian, delete_fn, - mock_inner_kernel + mock_inner_kernel, ) - + self.assertTrue(callable(kernel)) @@ -146,29 +148,31 @@ def create_mock_info(self, n_dead=50): base_loglik = jnp.linspace(-5, -1, n_dead) noise = jax.random.normal(self.key, (n_dead,)) * 0.05 dead_loglik = jnp.sort(base_loglik + noise) - + # Birth likelihoods key, subkey = jax.random.split(self.key) birth_offsets = jax.random.uniform(subkey, (n_dead,)) * 0.2 - 0.1 - dead_loglik_birth = jnp.concatenate([ - jnp.array([-jnp.inf]), # First from prior - dead_loglik[:-1] + birth_offsets[1:] - ]) + dead_loglik_birth = jnp.concatenate( + [ + jnp.array([-jnp.inf]), # First from prior + dead_loglik[:-1] + birth_offsets[1:], + ] + ) dead_loglik_birth = jnp.minimum(dead_loglik_birth, dead_loglik - 0.01) - + return base.NSInfo( particles=jnp.zeros((n_dead, 2)), loglikelihood=dead_loglik, loglikelihood_birth=dead_loglik_birth, logprior=jnp.zeros(n_dead), - inner_kernel_info={} + inner_kernel_info={}, ) def test_compute_num_live(self): """Test computation of number of live points.""" mock_info = self.create_mock_info(n_dead=30) num_live = utils.compute_num_live(mock_info) - + chex.assert_shape(num_live, (30,)) self.assertTrue(jnp.all(num_live >= 1)) self.assertFalse(jnp.any(jnp.isnan(num_live))) @@ -177,16 +181,16 @@ def test_logX_simulation(self): """Test log-volume simulation.""" mock_info = self.create_mock_info(n_dead=40) n_samples = 20 - + logX_seq, logdX_seq = utils.logX(self.key, mock_info, shape=n_samples) - + chex.assert_shape(logX_seq, (40, n_samples)) chex.assert_shape(logdX_seq, (40, n_samples)) - + # Log volumes should be decreasing for i in range(n_samples): self.assertTrue(jnp.all(logX_seq[1:, i] <= logX_seq[:-1, i])) - + # No NaN values self.assertFalse(jnp.any(jnp.isnan(logX_seq))) @@ -194,11 +198,11 @@ def test_log_weights(self): """Test log weight computation.""" mock_info = self.create_mock_info(n_dead=25) n_samples = 15 - + log_weights_matrix = utils.log_weights(self.key, mock_info, shape=n_samples) - + chex.assert_shape(log_weights_matrix, (25, n_samples)) - + # Most weights should be finite finite_weights = jnp.isfinite(log_weights_matrix) finite_fraction = jnp.mean(finite_weights) @@ -207,9 +211,9 @@ def test_log_weights(self): def test_ess_computation(self): """Test effective sample size computation.""" mock_info = self.create_mock_info(n_dead=35) - + ess_value = utils.ess(self.key, mock_info) - + self.assertIsInstance(ess_value, (float, jax.Array)) self.assertGreater(ess_value, 0.0) self.assertLessEqual(ess_value, 35) @@ -220,26 +224,26 @@ def test_evidence_estimation_simple(self): # Constant likelihood case n_dead = 30 loglik_const = -2.0 - + mock_info = base.NSInfo( particles=jnp.zeros((n_dead, 1)), loglikelihood=jnp.full(n_dead, loglik_const), loglikelihood_birth=jnp.full(n_dead, -jnp.inf), logprior=jnp.zeros(n_dead), # Uniform prior - inner_kernel_info={} + inner_kernel_info={}, ) - + # Generate evidence estimates n_samples = 100 keys = jax.random.split(self.key, n_samples) - + def single_evidence_estimate(rng_key): log_weights_matrix = utils.log_weights(rng_key, mock_info, shape=10) return jax.scipy.special.logsumexp(log_weights_matrix, axis=0) - + log_evidence_samples = jax.vmap(single_evidence_estimate)(keys) log_evidence_samples = log_evidence_samples.flatten() - + # Should be close to the constant likelihood value mean_estimate = jnp.mean(log_evidence_samples) self.assertFalse(jnp.isnan(mean_estimate)) @@ -263,40 +267,43 @@ def test_adaptive_init(self): """Test adaptive NS initialization.""" num_live = 25 particles = jax.random.normal(self.key, (num_live, 2)) - + def mock_update_fn(state, info, params): return {"test_param": 1.5} - + state = adaptive.init( particles, self.logprior_fn, self.loglik_fn, - update_inner_kernel_params_fn=mock_update_fn + update_inner_kernel_params_fn=mock_update_fn, ) - + # Check basic structure chex.assert_shape(state.particles, (num_live, 2)) - + # Check inner kernel params were set self.assertIn("test_param", state.inner_kernel_params) self.assertEqual(state.inner_kernel_params["test_param"], 1.5) def test_adaptive_kernel_construction(self): """Test adaptive kernel can be constructed.""" - def mock_inner_kernel(rng_key, inner_state, logprior_fn, loglik_fn, loglik_0, params): + + def mock_inner_kernel( + rng_key, inner_state, logprior_fn, loglik_fn, loglik_0, params + ): return inner_state, {} - + def mock_update_fn(state, info, params): return params - + kernel = adaptive.build_kernel( self.logprior_fn, self.loglik_fn, base.delete_fn, mock_inner_kernel, - mock_update_fn + mock_update_fn, ) - + self.assertTrue(callable(kernel)) @@ -319,13 +326,13 @@ def test_covariance_computation(self): ndim = 3 particles = jax.random.normal(self.key, (num_live, ndim)) state = base.init(particles, self.logprior_fn, self.loglik_fn) - + params = nss.compute_covariance_from_particles(state, None, {}) - + self.assertIn("cov", params) cov = params["cov"] chex.assert_shape(cov, (ndim, ndim)) - + # Covariance should be positive semidefinite eigenvals = jnp.linalg.eigvals(cov) self.assertTrue(jnp.all(eigenvals >= -1e-10)) @@ -335,11 +342,11 @@ def test_direction_sampling(self): ndim = 4 cov = jnp.eye(ndim) * 2.0 params = {"cov": cov} - + direction = nss.sample_direction_from_covariance(self.key, params) - + chex.assert_shape(direction, (ndim,)) - + # Check normalization invcov = jnp.linalg.inv(cov) mahal_norm = jnp.sqrt(jnp.einsum("i,ij,j", direction, invcov, direction)) @@ -347,28 +354,25 @@ def test_direction_sampling(self): def test_nss_kernel_construction(self): """Test NSS kernel construction.""" - kernel = nss.build_kernel( - self.logprior_fn, - self.loglik_fn, - num_inner_steps=5 - ) - + kernel = nss.build_kernel(self.logprior_fn, self.loglik_fn, num_inner_steps=5) + self.assertTrue(callable(kernel)) def test_nss_with_1d_problem(self): """Test NSS with 1D problem (edge case).""" + def logprior_1d(x): return jnp.where((x >= -2) & (x <= 2), -jnp.log(4.0), -jnp.inf) - + def loglik_1d(x): return -0.5 * x**2 - + num_live = 20 particles = jax.random.uniform(self.key, (num_live,), minval=-2, maxval=2) state = base.init(particles, logprior_1d, loglik_1d) - + params = nss.compute_covariance_from_particles(state, None, {}) - + self.assertIn("cov", params) cov = params["cov"] # For 1D, cov should be shaped appropriately for the particle structure @@ -386,40 +390,42 @@ def setUp(self): def test_evidence_monotonicity(self): """Test evidence is monotonically increasing.""" + def logprior_fn(x): return stats.norm.logpdf(x).sum() - + def loglik_fn(x): return -0.5 * jnp.sum(x**2) - + num_live = 30 particles = jax.random.normal(self.key, (num_live, 2)) state = base.init(particles, logprior_fn, loglik_fn) - + # Simulate evidence updates logZ_sequence = [state.logZ] current_state = state - + for _ in range(5): worst_idx = jnp.argmin(current_state.loglikelihood) dead_loglik = current_state.loglikelihood[worst_idx] - + # Approximate volume decrease delta_logX = -1.0 / num_live new_logZ = jnp.logaddexp(current_state.logZ, dead_loglik + delta_logX) logZ_sequence.append(new_logZ) - + # Mock update for next iteration - new_loglik = jnp.concatenate([ - current_state.loglikelihood[:worst_idx], - current_state.loglikelihood[worst_idx + 1:], - jnp.array([dead_loglik + 0.1]) - ]) + new_loglik = jnp.concatenate( + [ + current_state.loglikelihood[:worst_idx], + current_state.loglikelihood[worst_idx + 1 :], + jnp.array([dead_loglik + 0.1]), + ] + ) current_state = current_state._replace( - loglikelihood=new_loglik, - logZ=new_logZ + loglikelihood=new_loglik, logZ=new_logZ ) - + # Check monotonicity logZ_array = jnp.array(logZ_sequence) differences = logZ_array[1:] - logZ_array[:-1] @@ -430,70 +436,69 @@ def test_gaussian_evidence_analytical(self): # Setup: Gaussian likelihood with uniform prior prior_a, prior_b = -2.0, 2.0 sigma = 1.0 - + def logprior_fn(x): width = prior_b - prior_a return jnp.where((x >= prior_a) & (x <= prior_b), -jnp.log(width), -jnp.inf) - + def loglik_fn(x): - return -0.5 * (x / sigma)**2 - 0.5 * jnp.log(2 * jnp.pi * sigma**2) - + return -0.5 * (x / sigma) ** 2 - 0.5 * jnp.log(2 * jnp.pi * sigma**2) + # Analytical evidence (truncated Gaussian integral) from scipy.stats import norm + analytical_evidence = ( norm.cdf(prior_b / sigma) - norm.cdf(prior_a / sigma) ) / (prior_b - prior_a) analytical_log_evidence = jnp.log(analytical_evidence) - + # Mock NS data n_dead = 50 positions = jnp.linspace(prior_a + 0.01, prior_b - 0.01, n_dead).reshape(-1, 1) dead_loglik = jax.vmap(loglik_fn)(positions.flatten()) dead_logprior = jax.vmap(logprior_fn)(positions.flatten()) - + # Sort by likelihood sorted_idx = jnp.argsort(dead_loglik) dead_loglik = dead_loglik[sorted_idx] positions = positions[sorted_idx] dead_logprior = dead_logprior[sorted_idx] - - dead_loglik_birth = jnp.concatenate([ - jnp.array([-jnp.inf]), - dead_loglik[:-1] - 0.05 - ]) - + + dead_loglik_birth = jnp.concatenate( + [jnp.array([-jnp.inf]), dead_loglik[:-1] - 0.05] + ) + mock_info = base.NSInfo( particles=positions, loglikelihood=dead_loglik, loglikelihood_birth=dead_loglik_birth, logprior=dead_logprior, - inner_kernel_info={} + inner_kernel_info={}, ) - + # Generate evidence estimates n_samples = 200 keys = jax.random.split(self.key, n_samples) - + def single_evidence_estimate(rng_key): log_weights_matrix = utils.log_weights(rng_key, mock_info, shape=10) return jax.scipy.special.logsumexp(log_weights_matrix, axis=0) - + log_evidence_samples = jax.vmap(single_evidence_estimate)(keys) log_evidence_samples = log_evidence_samples.flatten() - + # Statistical validation mean_estimate = jnp.mean(log_evidence_samples) - std_estimate = jnp.std(log_evidence_samples) - + # For mock data, we expect some bias, so use looser bounds # This is primarily testing that the utilities work, not exact accuracy self.assertFalse(jnp.isnan(mean_estimate)) self.assertFalse(jnp.isinf(mean_estimate)) - + # Very loose bounds - mainly checking it's in the right ballpark self.assertGreater(mean_estimate, analytical_log_evidence - 3.0) self.assertLess(mean_estimate, analytical_log_evidence + 3.0) if __name__ == "__main__": - absltest.main() \ No newline at end of file + absltest.main() diff --git a/tests/ns/test_nested_sampling_units.py b/tests/ns/test_nested_sampling_units.py index 43c0ce6a2..5db64ead8 100644 --- a/tests/ns/test_nested_sampling_units.py +++ b/tests/ns/test_nested_sampling_units.py @@ -1,4 +1,6 @@ """Unit tests for nested sampling components.""" +import functools + import chex import jax import jax.numpy as jnp @@ -21,7 +23,7 @@ def test_ns_state_creation(self): logZ = -5.0 logZ_live = -3.0 inner_kernel_params = {} - + state = base.NSState( particles=particles, loglikelihood=loglik, @@ -31,9 +33,9 @@ def test_ns_state_creation(self): logX=logX, logZ=logZ, logZ_live=logZ_live, - inner_kernel_params=inner_kernel_params + inner_kernel_params=inner_kernel_params, ) - + chex.assert_trees_all_close(state.particles, particles) chex.assert_trees_all_close(state.loglikelihood, loglik) chex.assert_trees_all_close(state.loglikelihood_birth, loglik_birth) @@ -55,12 +57,12 @@ def test_ns_state_replace(self): logX=-2.0, logZ=-5.0, logZ_live=-3.0, - inner_kernel_params={} + inner_kernel_params={}, ) - + new_logZ = -4.5 new_state = state._replace(logZ=new_logZ) - + self.assertEqual(new_state.logZ, new_logZ) self.assertEqual(new_state.logZ_live, -3.0) # Unchanged self.assertEqual(new_state.logX, -2.0) # Unchanged @@ -77,15 +79,15 @@ def test_ns_info_creation(self): loglik_birth = jnp.array([-jnp.inf, 0.05]) logprior = jnp.array([-1.0, -1.1]) kernel_info = {"test": "value"} - + info = base.NSInfo( particles=particles, loglikelihood=loglik, loglikelihood_birth=loglik_birth, logprior=logprior, - inner_kernel_info=kernel_info + inner_kernel_info=kernel_info, ) - + chex.assert_trees_all_close(info.particles, particles) chex.assert_trees_all_close(info.loglikelihood, loglik) chex.assert_trees_all_close(info.loglikelihood_birth, loglik_birth) @@ -105,9 +107,9 @@ def setUp(self): def test_init_particle_count(self, num_live): """Test initialization with different numbers of live points.""" particles = jax.random.normal(jax.random.key(42), (num_live, 2)) - + state = base.init(particles, self.logprior_fn, self.loglik_fn) - + chex.assert_shape(state.particles, (num_live, 2)) chex.assert_shape(state.loglikelihood, (num_live,)) chex.assert_shape(state.logprior, (num_live,)) @@ -117,15 +119,15 @@ def test_init_1d_particles(self): """Test initialization with 1D particles.""" num_live = 20 particles = jax.random.normal(jax.random.key(42), (num_live,)) - + def logprior_1d(x): return -0.5 * x**2 - + def loglik_1d(x): - return -x**2 - + return -(x**2) + state = base.init(particles, logprior_1d, loglik_1d) - + chex.assert_shape(state.particles, (num_live,)) chex.assert_shape(state.loglikelihood, (num_live,)) chex.assert_shape(state.logprior, (num_live,)) @@ -133,13 +135,13 @@ def loglik_1d(x): def test_init_computes_correct_values(self): """Test that init computes loglikelihood and logprior correctly.""" particles = jnp.array([[1.0, 0.0], [0.0, 1.0], [-1.0, -1.0]]) - + state = base.init(particles, self.logprior_fn, self.loglik_fn) - + # Check computed values match manual computation expected_logprior = jax.vmap(self.logprior_fn)(particles) expected_loglik = jax.vmap(self.loglik_fn)(particles) - + chex.assert_trees_all_close(state.logprior, expected_logprior) chex.assert_trees_all_close(state.loglikelihood, expected_loglik) @@ -147,9 +149,9 @@ def test_init_particle_ids_unique(self): """Test that particle IDs are unique.""" num_live = 15 particles = jax.random.normal(jax.random.key(42), (num_live, 3)) - + state = base.init(particles, self.logprior_fn, self.loglik_fn) - + unique_ids = jnp.unique(state.pid) self.assertEqual(len(unique_ids), num_live) @@ -158,17 +160,17 @@ def test_init_with_pytree_particles(self): num_live = 10 particles = { "x": jax.random.normal(jax.random.key(42), (num_live, 2)), - "y": jax.random.normal(jax.random.key(43), (num_live,)) + "y": jax.random.normal(jax.random.key(43), (num_live,)), } - + def logprior_pytree(p): - return -0.5 * (jnp.sum(p["x"]**2) + p["y"]**2) - + return -0.5 * (jnp.sum(p["x"] ** 2) + p["y"] ** 2) + def loglik_pytree(p): - return -(jnp.sum(p["x"]**2) + p["y"]**2) - + return -(jnp.sum(p["x"] ** 2) + p["y"] ** 2) + state = base.init(particles, logprior_pytree, loglik_pytree) - + chex.assert_shape(state.particles["x"], (num_live, 2)) chex.assert_shape(state.particles["y"], (num_live,)) chex.assert_shape(state.loglikelihood, (num_live,)) @@ -192,9 +194,9 @@ def create_test_state(self, num_live=20): def test_delete_fn_shapes(self, num_delete): """Test delete function returns correct shapes.""" state = self.create_test_state(num_live=20) - + dead_idx, target_idx, start_idx = base.delete_fn(self.key, state, num_delete) - + chex.assert_shape(dead_idx, (num_delete,)) chex.assert_shape(target_idx, (num_delete,)) chex.assert_shape(start_idx, (num_delete,)) @@ -203,14 +205,14 @@ def test_delete_fn_selects_worst(self): """Test that delete function selects worst particles.""" state = self.create_test_state(num_live=20) num_delete = 3 - + dead_idx, _, _ = base.delete_fn(self.key, state, num_delete) - + # Should select particles with lowest likelihood worst_indices = jnp.argsort(state.loglikelihood)[:num_delete] selected_indices = jnp.sort(dead_idx) expected_indices = jnp.sort(worst_indices) - + chex.assert_trees_all_close(selected_indices, expected_indices) def test_delete_fn_valid_indices(self): @@ -218,9 +220,9 @@ def test_delete_fn_valid_indices(self): num_live = 15 state = self.create_test_state(num_live=num_live) num_delete = 4 - + dead_idx, target_idx, start_idx = base.delete_fn(self.key, state, num_delete) - + # All indices should be valid self.assertTrue(jnp.all(dead_idx >= 0)) self.assertTrue(jnp.all(dead_idx < num_live)) @@ -233,9 +235,9 @@ def test_delete_fn_no_duplicates(self): """Test that delete function doesn't return duplicate indices.""" state = self.create_test_state(num_live=20) num_delete = 5 - + dead_idx, target_idx, start_idx = base.delete_fn(self.key, state, num_delete) - + # Dead indices should be unique self.assertEqual(len(jnp.unique(dead_idx)), num_delete) @@ -249,59 +251,64 @@ def setUp(self): def test_kernel_full_execution(self): """Test full NS kernel execution workflow.""" + # Create a simple mock inner kernel - def mock_inner_kernel(rng_key, inner_state, logprior_fn, loglik_fn, loglik_0, params): + def mock_inner_kernel( + rng_key, inner_state, logprior_fn, loglik_fn, loglik_0, params + ): # Simple random walk that respects the likelihood constraint pos = inner_state.position new_pos = pos + jax.random.normal(rng_key, pos.shape) * 0.1 new_loglik = loglik_fn(new_pos) new_logprior = logprior_fn(new_pos) - + # Accept if likelihood is above threshold, otherwise return original accept = new_loglik >= loglik_0 final_pos = jnp.where(accept, new_pos, pos) final_loglik = jnp.where(accept, new_loglik, inner_state.loglikelihood) final_logprior = jnp.where(accept, new_logprior, inner_state.logprior) - - new_inner_state = base.PartitionedState(final_pos, final_logprior, final_loglik) + + new_inner_state = base.PartitionedState( + final_pos, final_logprior, final_loglik + ) return new_inner_state, {"accepted": accept} # Set up test functions def logprior_fn(x): return -0.5 * jnp.sum(x**2) - + def loglik_fn(x): return -jnp.sum(x**2) - + # Create initial state num_live = 10 particles = jax.random.normal(self.key, (num_live, 2)) * 0.5 state = base.init(particles, logprior_fn, loglik_fn) - + # Build kernel with delete function def delete_fn(rng_key, state): # Delete 1 worst particle dead_idx = jnp.array([jnp.argmin(state.loglikelihood)]) target_idx = jnp.array([0]) # Replace with first particle - start_idx = jnp.array([0]) # Start from first particle + start_idx = jnp.array([0]) # Start from first particle return dead_idx, target_idx, start_idx - + kernel = base.build_kernel(logprior_fn, loglik_fn, delete_fn, mock_inner_kernel) - + # Execute kernel new_state, info = kernel(self.key, state) - + # Check that state is updated correctly self.assertIsInstance(new_state, base.NSState) self.assertIsInstance(info, base.NSInfo) - + # Should still have same number of particles chex.assert_shape(new_state.particles, (num_live, 2)) chex.assert_shape(new_state.loglikelihood, (num_live,)) - + # Evidence should be updated self.assertNotEqual(new_state.logZ, state.logZ) - + # Info should contain dead particle information chex.assert_shape(info.particles, (1, 2)) # 1 dead particle chex.assert_shape(info.loglikelihood, (1,)) @@ -321,23 +328,23 @@ def test_update_ns_runtime_info(self): logZ = -5.0 loglikelihood = jnp.array([-1.0, -1.5, -2.0, -2.5]) # Live points dead_loglikelihood = jnp.array([-3.0, -3.2]) # Dead points - + new_logX, new_logZ, new_logZ_live = base.update_ns_runtime_info( logX, logZ, loglikelihood, dead_loglikelihood ) - + # Check types and finiteness self.assertIsInstance(new_logX, (float, jax.Array)) self.assertIsInstance(new_logZ, (float, jax.Array)) self.assertIsInstance(new_logZ_live, (float, jax.Array)) - + self.assertFalse(jnp.isnan(new_logX)) self.assertFalse(jnp.isnan(new_logZ)) self.assertFalse(jnp.isnan(new_logZ_live)) - + # Evidence should increase (or at least not decrease significantly) self.assertGreaterEqual(new_logZ, logZ - 1e-10) - + # LogX should decrease (volume shrinking) self.assertLess(new_logX, logX) @@ -347,11 +354,11 @@ def test_update_ns_runtime_info_single_particle(self): logZ = -10.0 loglikelihood = jnp.array([-2.0, -2.5, -3.0]) dead_loglikelihood = jnp.array([-4.0]) # Single deletion - + new_logX, new_logZ, new_logZ_live = base.update_ns_runtime_info( logX, logZ, loglikelihood, dead_loglikelihood ) - + # Should work with single particle self.assertFalse(jnp.isnan(new_logX)) self.assertFalse(jnp.isnan(new_logZ)) @@ -367,17 +374,17 @@ def test_new_state_and_info(self): logprior = -1.5 loglikelihood = -2.0 info = {"test": "value"} - + state, returned_info = base.new_state_and_info( position, logprior, loglikelihood, info ) - + # Check PartitionedState self.assertIsInstance(state, base.PartitionedState) chex.assert_trees_all_close(state.position, position) self.assertEqual(state.logprior, logprior) self.assertEqual(state.loglikelihood, loglikelihood) - + # Check PartitionedInfo self.assertIsInstance(returned_info, base.PartitionedInfo) chex.assert_trees_all_close(returned_info.position, position) @@ -397,54 +404,48 @@ def create_mock_info(self, n_dead=30): """Helper to create mock NSInfo.""" # Increasing likelihood sequence loglik = jnp.linspace(-5, -1, n_dead) - loglik_birth = jnp.concatenate([ - jnp.array([-jnp.inf]), - loglik[:-1] - 0.1 - ]) - + loglik_birth = jnp.concatenate([jnp.array([-jnp.inf]), loglik[:-1] - 0.1]) + return base.NSInfo( particles=jnp.zeros((n_dead, 2)), loglikelihood=loglik, loglikelihood_birth=loglik_birth, logprior=jnp.zeros(n_dead), - inner_kernel_info={} + inner_kernel_info={}, ) def test_get_first_row_array(self): """Test get_first_row with arrays.""" x = jnp.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) - + result = utils.get_first_row(x) expected = jnp.array([1, 2, 3]) - + chex.assert_trees_all_close(result, expected) def test_get_first_row_pytree(self): """Test get_first_row with PyTree.""" - x = { - "a": jnp.array([[1, 2], [3, 4], [5, 6]]), - "b": jnp.array([10, 20, 30]) - } - + x = {"a": jnp.array([[1, 2], [3, 4], [5, 6]]), "b": jnp.array([10, 20, 30])} + result = utils.get_first_row(x) - + chex.assert_trees_all_close(result["a"], jnp.array([1, 2])) self.assertEqual(result["b"], 10) def test_compute_num_live_shape(self): """Test compute_num_live returns correct shape.""" mock_info = self.create_mock_info(n_dead=25) - + num_live = utils.compute_num_live(mock_info) - + chex.assert_shape(num_live, (25,)) def test_compute_num_live_values(self): """Test compute_num_live returns reasonable values.""" mock_info = self.create_mock_info(n_dead=20) - + num_live = utils.compute_num_live(mock_info) - + # Should be positive self.assertTrue(jnp.all(num_live >= 1)) # Should be reasonable (not too large) @@ -456,9 +457,9 @@ def test_logX_shapes(self): """Test logX returns correct shapes.""" mock_info = self.create_mock_info(n_dead=15) n_samples = 10 - + logX_seq, logdX_seq = utils.logX(self.key, mock_info, shape=n_samples) - + chex.assert_shape(logX_seq, (15, n_samples)) chex.assert_shape(logdX_seq, (15, n_samples)) @@ -466,30 +467,32 @@ def test_logX_monotonicity(self): """Test that logX is decreasing.""" mock_info = self.create_mock_info(n_dead=10) n_samples = 5 - + logX_seq, _ = utils.logX(self.key, mock_info, shape=n_samples) - + # Each column should be decreasing for i in range(n_samples): differences = logX_seq[1:, i] - logX_seq[:-1, i] - self.assertTrue(jnp.all(differences <= 1e-12)) # Allow for numerical precision + self.assertTrue( + jnp.all(differences <= 1e-12) + ) # Allow for numerical precision def test_log_weights_shapes(self): """Test log_weights returns correct shape.""" mock_info = self.create_mock_info(n_dead=12) n_samples = 8 - + log_weights = utils.log_weights(self.key, mock_info, shape=n_samples) - + chex.assert_shape(log_weights, (12, n_samples)) def test_log_weights_finite(self): """Test that most log_weights are finite.""" mock_info = self.create_mock_info(n_dead=20) n_samples = 5 - + log_weights = utils.log_weights(self.key, mock_info, shape=n_samples) - + # Most weights should be finite finite_fraction = jnp.mean(jnp.isfinite(log_weights)) self.assertGreater(finite_fraction, 0.3) # At least 30% should be finite @@ -497,9 +500,9 @@ def test_log_weights_finite(self): def test_ess_properties(self): """Test ESS computation properties.""" mock_info = self.create_mock_info(n_dead=30) - + ess = utils.ess(self.key, mock_info) - + # ESS should be positive and finite self.assertGreater(ess, 0.0) self.assertFalse(jnp.isnan(ess)) @@ -511,21 +514,23 @@ def test_log1mexp_values(self): """Test log1mexp utility function.""" # Test values where we know the expected result x = jnp.array([-0.1, -1.0, -2.0, -10.0]) - + result = utils.log1mexp(x) - + # Should all be finite and negative (since log(1-exp(x)) < 0 for x < 0) self.assertTrue(jnp.all(jnp.isfinite(result))) # For large negative x, log(1-exp(x)) ≈ log(1) = 0 - self.assertAlmostEqual(result[-1], 0.0, places=3) # Less strict for numerical precision + self.assertAlmostEqual( + result[-1], 0.0, places=3 + ) # Less strict for numerical precision def test_log1mexp_edge_cases(self): """Test log1mexp edge cases.""" # Test near the transition point x_transition = jnp.array([-0.6931472]) # Approximately -log(2) - + result = utils.log1mexp(x_transition) - + self.assertTrue(jnp.isfinite(result)) self.assertLess(result, 0.0) @@ -539,45 +544,43 @@ def setUp(self): def test_nss_top_level_api(self): """Test NSS as_top_level_api function.""" + def logprior_fn(x): return -0.5 * jnp.sum(x**2) - + def loglik_fn(x): return -jnp.sum(x**2) - + num_live = 20 - + # Test the top-level API - algorithm = nss.as_top_level_api( - logprior_fn, - loglik_fn, - 5 # num_inner_steps - ) - + algorithm = nss.as_top_level_api(logprior_fn, loglik_fn, 5) # num_inner_steps + # Should return a SamplingAlgorithm self.assertTrue(hasattr(algorithm, "init")) self.assertTrue(hasattr(algorithm, "step")) self.assertTrue(callable(algorithm.init)) self.assertTrue(callable(algorithm.step)) - + # Test initialization - NSS uses adaptive.init which needs different signature particles = jax.random.normal(self.key, (num_live, 2)) state = algorithm.init(particles) - + self.assertIsInstance(state, base.NSState) chex.assert_shape(state.particles, (num_live, 2)) def test_nss_inner_kernel_execution(self): """Test NSS inner kernel execution by building a full kernel.""" + def logprior_fn(x): return -0.5 * jnp.sum(x**2) - + def loglik_fn(x): return -jnp.sum(x**2) - + # Build NSS kernel kernel = nss.build_kernel(logprior_fn, loglik_fn, num_inner_steps=2) - + # Create initial state with proper inner_kernel_params num_live = 5 particles = jax.random.normal(self.key, (num_live, 2)) * 0.3 @@ -585,41 +588,40 @@ def loglik_fn(x): # NSS needs covariance params cov_params = nss.compute_covariance_from_particles(state, None, {}) state = state._replace(inner_kernel_params=cov_params) - + # Execute kernel - this tests the inner kernel execution paths new_state, info = kernel(self.key, state) - + # Check that state is updated correctly self.assertIsInstance(new_state, base.NSState) self.assertIsInstance(info, base.NSInfo) - + # Should still have same number of particles chex.assert_shape(new_state.particles, (num_live, 2)) chex.assert_shape(new_state.loglikelihood, (num_live,)) - + # Evidence should be updated self.assertNotEqual(new_state.logZ, state.logZ) def test_nss_compute_covariance_edge_cases(self): """Test covariance computation edge cases.""" # Test with very few particles - num_live = 3 particles = jnp.array([[1.0], [2.0], [3.0]]) # 1D particles - + def logprior_fn(x): return -0.5 * x**2 - + def loglik_fn(x): - return -x**2 - + return -(x**2) + state = base.init(particles, logprior_fn, loglik_fn) - + # Should handle small number of particles params = nss.compute_covariance_from_particles(state, None, {}) - + self.assertIn("cov", params) cov = params["cov"] - + # Should not be NaN or infinite self.assertFalse(jnp.isnan(cov).any()) self.assertFalse(jnp.isinf(cov).any()) @@ -629,9 +631,9 @@ def test_nss_direction_sampling_edge_cases(self): # Test with nearly singular covariance cov = jnp.array([[1e-6, 0.0], [0.0, 1e-6]]) params = {"cov": cov} - + direction = nss.sample_direction_from_covariance(self.key, params) - + chex.assert_shape(direction, (2,)) # Should be finite even with small covariance self.assertTrue(jnp.all(jnp.isfinite(direction))) @@ -647,49 +649,55 @@ def setUp(self): def test_missing_utility_functions(self): """Test utility functions that weren't covered.""" # Create test data that would exercise missing lines - + # Test with edge case data - use float dtypes mock_info = base.NSInfo( particles=jnp.zeros((5, 1)), - loglikelihood=jnp.array([-10.0, -5.0, -3.0, -2.0, -1.0]), # Wide range, float + loglikelihood=jnp.array( + [-10.0, -5.0, -3.0, -2.0, -1.0] + ), # Wide range, float loglikelihood_birth=jnp.array([-jnp.inf, -15.0, -8.0, -4.0, -2.5]), logprior=jnp.zeros(5), - inner_kernel_info={} + inner_kernel_info={}, ) - + # Test functions that might have missing coverage num_live = utils.compute_num_live(mock_info) self.assertTrue(jnp.all(num_live >= 1)) - + # Test with different shapes logX_seq, logdX_seq = utils.logX(self.key, mock_info, shape=3) chex.assert_shape(logX_seq, (5, 3)) - + # Test log_weights with edge cases log_weights = utils.log_weights(self.key, mock_info, shape=2) chex.assert_shape(log_weights, (5, 2)) def test_repeat_kernel_decorator(self): """Test repeat_kernel decorator function.""" + # Simple mock kernel @utils.repeat_kernel(3) def mock_kernel(rng_key, state, *args): # Just update position slightly - new_pos = state["position"] + jax.random.normal(rng_key, state["position"].shape) * 0.01 + new_pos = ( + state["position"] + + jax.random.normal(rng_key, state["position"].shape) * 0.01 + ) new_state = state.copy() new_state["position"] = new_pos return new_state, {"step": 1} - + initial_state = {"position": jnp.array([1.0, 2.0])} - + # Test decorated kernel final_state, infos = mock_kernel(self.key, initial_state) - + # Should have run 3 times (scan packs infos into dict structure) self.assertIsInstance(final_state, dict) self.assertIn("position", final_state) chex.assert_shape(final_state["position"], (2,)) - + # Info structure depends on how scan handles the dict self.assertIsInstance(infos, dict) self.assertIn("step", infos) @@ -711,17 +719,17 @@ def test_combine_dead_info(self): loglikelihood=jnp.array([0.1, 0.2]), loglikelihood_birth=jnp.array([-jnp.inf, 0.05]), logprior=jnp.array([-1.0, -1.1]), - inner_kernel_info={"test": "value1"} + inner_kernel_info={"test": "value1"}, ) - + dead2 = base.NSInfo( particles=jnp.array([[3.0], [4.0]]), loglikelihood=jnp.array([0.3, 0.4]), loglikelihood_birth=jnp.array([0.25, 0.35]), logprior=jnp.array([-1.2, -1.3]), - inner_kernel_info={"test": "value2"} + inner_kernel_info={"test": "value2"}, ) - + # Mock live state live = base.NSState( particles=jnp.array([[5.0], [6.0]]), @@ -732,11 +740,11 @@ def test_combine_dead_info(self): logX=-2.0, logZ=-5.0, logZ_live=-3.0, - inner_kernel_params={} + inner_kernel_params={}, ) - + combined = utils.combine_dead_info([dead1, dead2], live) - + # Should combine all dead + live particles expected_total = 2 + 2 + 2 # dead1 + dead2 + live chex.assert_shape(combined.particles, (expected_total, 1)) @@ -750,33 +758,33 @@ def test_sample_particles(self): loglikelihood=jnp.array([0.1, 0.3, 0.2, 0.4]), loglikelihood_birth=jnp.array([-jnp.inf, 0.05, 0.15, 0.25]), logprior=jnp.array([-1.0, -1.1, -1.2, -1.3]), - inner_kernel_info={} + inner_kernel_info={}, ) - + # Sample particles sampled = utils.sample_particles(self.key, mock_info, shape=6) - + chex.assert_shape(sampled, (6, 1)) def test_uniform_prior_function(self): """Test uniform_prior function (utils.py lines 381-398).""" bounds = {"x": (-2.0, 2.0), "y": (0.0, 1.0)} num_live = 10 - - particles, logprior_fn = utils.uniform_prior(self.key, bounds, num_live) - + + particles, logprior_fn = utils.uniform_prior(self.key, num_live, bounds) + # Check particles structure self.assertIn("x", particles) self.assertIn("y", particles) chex.assert_shape(particles["x"], (num_live,)) chex.assert_shape(particles["y"], (num_live,)) - + # Check bounds are respected self.assertTrue(jnp.all(particles["x"] >= -2.0)) self.assertTrue(jnp.all(particles["x"] <= 2.0)) self.assertTrue(jnp.all(particles["y"] >= 0.0)) self.assertTrue(jnp.all(particles["y"] <= 1.0)) - + # Test logprior function test_params = {"x": 0.5, "y": 0.3} logprior_val = logprior_fn(test_params) @@ -786,44 +794,45 @@ def test_uniform_prior_function(self): def test_adaptive_kernel_missing_lines(self): """Test adaptive kernel missing coverage (adaptive.py lines 148-154).""" from blackjax.ns import adaptive - + def logprior_fn(x): return -0.5 * jnp.sum(x**2) - + def loglik_fn(x): return -jnp.sum(x**2) - - def mock_inner_kernel(rng_key, inner_state, logprior_fn, loglik_fn, loglik_0, params): + + def mock_inner_kernel( + rng_key, inner_state, logprior_fn, loglik_fn, loglik_0, params + ): # Mock that always accepts new_inner_state = base.PartitionedState( inner_state.position + 0.01, inner_state.logprior, - inner_state.loglikelihood + 0.1 + inner_state.loglikelihood + 0.1, ) return new_inner_state, {"accepted": True} - + def mock_update_fn(state, info, params): # This should exercise the missing lines in adaptive.py return {"updated": True, "cov": jnp.eye(2)} - + # Create state - num_live = 5 - particles = jax.random.normal(self.key, (num_live, 2)) * 0.1 + particles = jax.random.normal(self.key, (5, 2)) * 0.1 state = base.init(particles, logprior_fn, loglik_fn) - + # Build adaptive kernel delete_fn = functools.partial(base.delete_fn, num_delete=1) kernel = adaptive.build_kernel( logprior_fn, loglik_fn, delete_fn, mock_inner_kernel, mock_update_fn ) - + # Execute to test missing lines new_state, info = kernel(self.key, state) - + self.assertIsInstance(new_state, base.NSState) # Check that update function was called (adaptive logic) self.assertIn("updated", new_state.inner_kernel_params) if __name__ == "__main__": - absltest.main() \ No newline at end of file + absltest.main()