diff --git a/blackjax/mcmc/ss.py b/blackjax/mcmc/ss.py index 50764ad4a..46855fbef 100644 --- a/blackjax/mcmc/ss.py +++ b/blackjax/mcmc/ss.py @@ -33,7 +33,6 @@ import jax.numpy as jnp from blackjax.base import SamplingAlgorithm -from blackjax.mcmc.proposal import static_binomial_sampling from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey __all__ = [ @@ -59,7 +58,6 @@ class SliceState(NamedTuple): position: ArrayLikeTree logdensity: float - constraint: Array class SliceInfo(NamedTuple): @@ -72,8 +70,6 @@ class SliceInfo(NamedTuple): ---------- is_accepted A boolean indicating whether the proposed sample was accepted. - constraint - The constraint values at the final accepted position. num_steps The number of steps taken to expand the interval during the "stepping-out" phase. num_shrink @@ -86,9 +82,7 @@ class SliceInfo(NamedTuple): num_shrink: int -def init( - position: ArrayTree, logdensity_fn: Callable, constraint_fn: Callable -) -> SliceState: +def init(position: ArrayTree, logdensity_fn: Callable) -> SliceState: """Initialize the Slice Sampler state. Parameters @@ -103,11 +97,11 @@ def init( SliceState The initial state of the Slice Sampler. """ - return SliceState(position, logdensity_fn(position), constraint_fn(position)) + return SliceState(position, logdensity_fn(position)) def build_kernel( - stepper_fn: Callable, + slice_fn: Callable[[float], tuple[SliceState, bool]], max_steps: int = 10, max_shrinkage: int = 100, ) -> Callable: @@ -119,17 +113,21 @@ def build_kernel( Parameters ---------- - stepper_fn - A function that computes a new position given an initial position, - direction `d` and a slice parameter `t`. - `(x0, d, t) -> x_new` where e.g. `x_new = x0 + t * d`. + slice_fn + A function that takes a scalar parameter `t` and returns a tuple + (SliceState, is_accepted) indicating the state at that parameter value + and whether it satisfies acceptance criteria. + max_steps + The maximum number of steps to take when expanding the interval in + each direction during the stepping-out phase. + max_shrinkage + The maximum number of shrinking steps to perform to avoid infinite loops. Returns ------- Callable - A kernel function that takes a PRNG key, the current `SliceState`, - the log-density function, direction `d`, constraint function, constraint - values, and strict flags, and returns a new `SliceState` and `SliceInfo`. + A kernel function that takes a PRNG key and the current `SliceState`, + and returns a new `SliceState` and `SliceInfo`. 
References ---------- @@ -139,32 +137,19 @@ def build_kernel( def kernel( rng_key: PRNGKey, state: SliceState, - logdensity_fn: Callable, - d: ArrayTree, - constraint_fn: Callable, - constraint: Array, - strict: Array, ) -> tuple[SliceState, SliceInfo]: vs_key, hs_key = jax.random.split(rng_key) - logslice = state.logdensity + jnp.log(jax.random.uniform(vs_key)) + u = jax.random.uniform(vs_key) + logslice = state.logdensity + jnp.log(u) vertical_is_accepted = logslice < state.logdensity - def slicer(t) -> tuple[SliceState, SliceInfo]: - x, step_accepted = stepper_fn(state.position, d, t) - new_state = init(x, logdensity_fn, constraint_fn) - constraints_ok = jnp.all( - jnp.where( - strict, - new_state.constraint > constraint, - new_state.constraint >= constraint, - ) - ) + def _slice_fn(t): + new_state, is_accepted = slice_fn(t) in_slice = new_state.logdensity >= logslice - is_accepted = in_slice & constraints_ok & step_accepted - return new_state, is_accepted + return new_state, is_accepted & in_slice new_state, info = horizontal_slice( - hs_key, slicer, state, max_steps, max_shrinkage + hs_key, state, _slice_fn, max_steps, max_shrinkage ) info = info._replace(is_accepted=info.is_accepted & vertical_is_accepted) return new_state, info @@ -174,29 +159,29 @@ def slicer(t) -> tuple[SliceState, SliceInfo]: def horizontal_slice( rng_key: PRNGKey, - slicer: Callable, state: SliceState, + slice_fn: Callable[[float], tuple[SliceState, bool]], m: int, max_shrinkage: int, ) -> tuple[SliceState, SliceInfo]: """Propose a new sample using the stepping-out and shrinking procedures. This function implements the core of the Hit-and-Run Slice Sampling algorithm. - It first expands an interval (`[l, r]`) along the slice starting - from `x0` and proceeding along direction `d` until both ends are outside - the slice defined by `logslice` (stepping-out). Then, it samples - points uniformly from this interval and shrinks the interval until a point - is found that lies within the slice (shrinking). + It first expands an interval (`[l, r]`) along a one-dimensional parameterization + until both ends are outside the slice defined by `logslice` (stepping-out). + Then, it samples points uniformly from this interval and shrinks the interval + until a point is found that lies within the slice (shrinking). Parameters ---------- rng_key A JAX PRNG key. - slicer - A function that takes a scalar `t` and returns a state and info on the - slice. state The current slice sampling state. + slice_fn + A function that takes a scalar parameter `t` and returns a tuple + (SliceState, is_accepted) indicating the state at that parameter value + and whether it satisfies acceptance criteria. m The maximum number of steps to take when expanding the interval in each direction during the stepping-out phase. 
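For reference, the stepping-out and shrinking procedure described in the docstring above can be sketched in a few lines of plain Python. The snippet below targets a 1D standard normal and is purely illustrative; it is not the BlackJAX implementation, and the names (logdensity, w, x0) are local to the example.

import numpy as np

def logdensity(x):
    return -0.5 * x**2

rng = np.random.default_rng(0)
x0 = 0.5
logslice = logdensity(x0) + np.log(rng.uniform())   # vertical slice level

# Stepping-out: expand [l, r] in unit steps until both ends leave the slice.
w = 1.0
l = x0 - w * rng.uniform()
r = l + w
while logdensity(l) >= logslice:
    l -= w
while logdensity(r) >= logslice:
    r += w

# Shrinking: sample uniformly in [l, r]; shrink towards x0 on rejection.
while True:
    x1 = rng.uniform(l, r)
    if logdensity(x1) >= logslice:
        break
    if x1 < x0:
        l = x1
    else:
        r = x1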
@@ -213,14 +198,14 @@ def horizontal_slice( # Initial bounds rng_key, subkey = jax.random.split(rng_key) u, v = jax.random.uniform(subkey, 2) - j = jnp.floor(m * v).astype(int) + j = jnp.floor(m * v).astype(jnp.int32) k = (m - 1) - j # Expand def step_body_fun(carry): i, s, t, _ = carry t += s - _, is_accepted = slicer(t) + _, is_accepted = slice_fn(t) i -= 1 return i, s, t, is_accepted @@ -240,7 +225,7 @@ def shrink_body_fun(carry): rng_key, subkey = jax.random.split(rng_key) u = jax.random.uniform(subkey, minval=l, maxval=r) - new_state, is_accepted = slicer(u) + new_state, is_accepted = slice_fn(u) n += 1 l = jnp.where(u < 0, u, l) @@ -255,17 +240,18 @@ def shrink_cond_fun(carry): carry = 0, rng_key, l, r, state, False carry = jax.lax.while_loop(shrink_cond_fun, shrink_body_fun, carry) n, _, _, _, new_state, is_accepted = carry - new_state, (is_accepted, _, _) = static_binomial_sampling( - rng_key, jnp.log(is_accepted), state, new_state + new_state = jax.tree.map( + lambda new, old: jnp.where(is_accepted, new, old), new_state, state ) slice_info = SliceInfo(is_accepted, m + 1 - j - k, n) return new_state, slice_info def build_hrss_kernel( - generate_slice_direction_fn: Callable, - stepper_fn: Callable, + cov: Array, + init_fn: Callable = init, max_steps: int = 10, + max_shrinkage: int = 100, ) -> Callable: """Build a Hit-and-Run Slice Sampling kernel. @@ -276,15 +262,15 @@ def build_hrss_kernel( Parameters ---------- - generate_slice_direction_fn - A function that, given a PRNG key, generates a direction vector (PyTree - with the same structure as the position) for the "hit-and-run" part of - the algorithm. This direction is typically normalized. - - stepper_fn - A function that computes a new position given an initial position, a - direction, and a step size `t`. It should implement something analogous - to `x_new = x_initial + t * direction`. + cov + The covariance matrix used by the direction proposal function + init_fn + A function initializing a SliceState + max_steps + The maximum number of steps to take when expanding the interval in + each direction during the stepping-out phase. + max_shrinkage + The maximum number of shrinking steps to perform to avoid infinite loops. Returns ------- @@ -292,75 +278,59 @@ def build_hrss_kernel( A kernel function that takes a PRNG key, the current `SliceState`, and the log-density function, and returns a new `SliceState` and `SliceInfo`. """ - slice_kernel = build_kernel(stepper_fn, max_steps) def kernel( rng_key: PRNGKey, state: SliceState, logdensity_fn: Callable ) -> tuple[SliceState, SliceInfo]: rng_key, prop_key = jax.random.split(rng_key, 2) - d = generate_slice_direction_fn(prop_key) - constraint_fn = lambda x: jnp.array([]) - constraint = jnp.array([]) - strict = jnp.array([], dtype=bool) - return slice_kernel( - rng_key, state, logdensity_fn, d, constraint_fn, constraint, strict - ) - - return kernel - + d = sample_direction_from_covariance(prop_key, state.position, cov) -def default_stepper_fn(x: ArrayTree, d: ArrayTree, t: float) -> ArrayTree: - """A simple stepper function that moves from `x` along direction `d` by `t` units. - - Implements the operation: `x_new = x + t * d`. + def slice_fn(t): + x = jax.tree.map(lambda x, d: x + t * d, state.position, d) + is_accepted = True + new_state = init_fn(x, logdensity_fn) + return new_state, is_accepted - Parameters - ---------- - x - The starting position (PyTree). - d - The direction of movement (PyTree, same structure as `x`). - t - The scalar step size or distance along the direction. 
+ slice_kernel = build_kernel(slice_fn, max_steps, max_shrinkage) + return slice_kernel(rng_key, state) - Returns - ------- - position, is_accepted - """ - return jax.tree.map(lambda x, d: x + t * d, x, d), True + return kernel -def sample_direction_from_covariance(rng_key: PRNGKey, cov: Array) -> Array: +def sample_direction_from_covariance( + rng_key: PRNGKey, position: ArrayLikeTree, cov: Array +) -> Array: """Generates a random direction vector, normalized, from a multivariate Gaussian. - This function samples a direction `d` from a zero-mean multivariate Gaussian - distribution with covariance matrix `cov`, and then normalizes `d` to be a - unit vector with respect to the Mahalanobis norm defined by `inv(cov)`. - That is, `d_normalized^T @ inv(cov) @ d_normalized = 1`. Parameters ---------- rng_key A JAX PRNG key. + position + The current position of the chain (used for extracting shape). cov - The covariance matrix for the multivariate Gaussian distribution from which - the initial direction is sampled. Assumed to be a 2D array. - + The covariance matrix. Returns ------- Array A normalized direction vector (1D array). """ - d = jax.random.multivariate_normal(rng_key, mean=jnp.zeros(cov.shape[0]), cov=cov) + p, unravel_fn = jax.flatten_util.ravel_pytree(position) + d = jax.random.normal(rng_key, shape=p.shape, dtype=p.dtype) invcov = jnp.linalg.inv(cov) norm = jnp.sqrt(jnp.einsum("...i,...ij,...j", d, invcov, d)) d = d / norm[..., None] - return d + d *= 2 + return unravel_fn(d) def hrss_as_top_level_api( logdensity_fn: Callable, cov: Array, + init_fn: Callable = init, + max_steps: int = 10, + max_shrinkage: int = 100, ) -> SamplingAlgorithm: """Creates a Hit-and-Run Slice Sampling algorithm. @@ -373,9 +343,16 @@ def hrss_as_top_level_api( logdensity_fn The log-density function of the target distribution to sample from. cov - The covariance matrix used by the default direction proposal function - (`default_proposal_distribution`). This matrix shapes the random + The covariance matrix used by the direction proposal function + (`sample_direction_from_covariance`). This matrix shapes the random directions proposed for the slice sampling steps. + init_fn + A function initializing a SliceState + max_steps + The maximum number of steps to take when expanding the interval in + each direction during the stepping-out phase. + max_shrinkage + The maximum number of shrinking steps to perform to avoid infinite loops. Returns ------- @@ -383,8 +360,7 @@ def hrss_as_top_level_api( A `SamplingAlgorithm` tuple containing `init` and `step` functions for the configured Hit-and-Run Slice Sampler. """ - generate_slice_direction_fn = partial(sample_direction_from_covariance, cov=cov) - kernel = build_hrss_kernel(generate_slice_direction_fn, default_stepper_fn) - init_fn = partial(init, logdensity_fn=logdensity_fn) + kernel = build_hrss_kernel(cov, init_fn, max_steps, max_shrinkage) + init_fn = partial(init_fn, logdensity_fn=logdensity_fn) step_fn = partial(kernel, logdensity_fn=logdensity_fn) return SamplingAlgorithm(init_fn, step_fn) diff --git a/blackjax/ns/__init__.py b/blackjax/ns/__init__.py index d9c37553d..bc2a66ff1 100644 --- a/blackjax/ns/__init__.py +++ b/blackjax/ns/__init__.py @@ -13,8 +13,7 @@ # limitations under the License. """Nested Sampling Algorithms in BlackJAX. -This subpackage provides implementations of Nested Sampling algorithms, -including a base version, an adaptive version, and Nested Slice Sampling (NSS). +This subpackage provides implementations of Nested Sampling algorithms. 
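For orientation, a rough usage sketch of the reworked HRSS interface above. It assumes the module is importable as blackjax.mcmc.ss (as the file path suggests) and is illustrative only, not a tested recipe.

import jax
import jax.numpy as jnp
from blackjax.mcmc import ss

def logdensity_fn(x):
    return -0.5 * jnp.sum(x**2)

cov = jnp.eye(2)                      # shapes the random slice directions
hrss = ss.hrss_as_top_level_api(logdensity_fn, cov)

rng_key = jax.random.PRNGKey(0)
state = hrss.init(jnp.zeros(2))
for _ in range(100):
    rng_key, step_key = jax.random.split(rng_key)
    state, info = hrss.step(step_key, state)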
Nested Sampling is a Monte Carlo method for Bayesian computation, primarily used for evidence (marginal likelihood) calculation and posterior sampling. @@ -23,20 +22,23 @@ Available modules: ------------------ -- `adaptive`: Implements an adaptive Nested Sampling algorithm where inner -  kernel parameters are tuned at each iteration. -- `base`: Provides core components and a non-adaptive Nested Sampling kernel. +- `base`: Provides core components for Nested Sampling. +- `adaptive`: Implements Adaptive Nested Sampling, combining SMC-style tempering with inner kernel tuning. - `nss`: Implements Nested Slice Sampling, using Hit-and-Run Slice Sampling as the inner kernel with adaptive tuning of its proposal mechanism. +- `integrator`: Provides NSIntegrator for tracking evidence integration. - `utils`: Contains utility functions for processing and analyzing Nested Sampling results. +- `from_mcmc`: Utilities to build Nested Sampling algorithms from MCMC kernels. """ -from . import adaptive, base, nss, utils +from . import adaptive, base, from_mcmc, integrator, nss, utils __all__ = [ - "adaptive", "base", - "utils", + "adaptive", + "integrator", "nss", + "utils", + "from_mcmc", ] diff --git a/blackjax/ns/adaptive.py b/blackjax/ns/adaptive.py index 1bddb6296..1f0aece93 100644 --- a/blackjax/ns/adaptive.py +++ b/blackjax/ns/adaptive.py @@ -13,144 +13,108 @@ # limitations under the License. """Adaptive Nested Sampling for BlackJAX. -This module provides an adaptive version of the Nested Sampling algorithm. -In this variant, the parameters of the inner kernel, which is used to -sample new live points, are updated (tuned) at each iteration of the -Nested Sampling loop. This adaptation is based on the information from the -current set of live particles or the history of the sampling process, -allowing the kernel to adjust to the changing characteristics of the -constrained prior distribution as the likelihood threshold increases. +This module combines the SMC equivalent of adaptive tempering with inner kernel tuning in one place. """ -from typing import Callable, Dict, Optional +from functools import partial +from typing import Callable, Dict, NamedTuple, Optional +import jax import jax.numpy as jnp -from blackjax.ns.base import NSInfo, NSState +from blackjax.ns.base import NSInfo, NSState, StateWithLogLikelihood from blackjax.ns.base import build_kernel as base_build_kernel from blackjax.ns.base import init as base_init +from blackjax.ns.integrator import NSIntegrator, init_integrator, update_integrator from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey __all__ = ["init", "build_kernel"] -def init( - particles: ArrayLikeTree, - logprior_fn: Callable, - loglikelihood_fn: Callable, - loglikelihood_birth: Array = -jnp.nan, - update_inner_kernel_params_fn: Optional[Callable] = None, -) -> NSState: - """Initializes the Nested Sampler state. +class AdaptiveNSState(NamedTuple): + """An extension of the base NSState to include inner kernel parameters. + + This state class extends the base Nested Sampling state by adding a + dictionary of parameters for the inner kernel and an integrator to track + relevant values for the evidence computation. - Parameters ---------- particles - An initial set of particles (PyTree of arrays) drawn from the prior - distribution. The leading dimension of each leaf array must be equal to - the number of particles. - loglikelihood_fn - A function that computes the log-likelihood of a single particle. - logprior_fn - A function that computes the log-prior of a single particle.
- loglikelihood_birth - The initial log-likelihood birth threshold. Defaults to -NaN, which - implies no initial likelihood constraint beyond the prior. - update_inner_kernel_params_fn - A function that takes the `NSState`, `NSInfo` from the completed NS - step, and the current inner kernel parameters dictionary, and returns - a dictionary of parameters to be used for the kernel in the *next* NS step. - - Returns - ------- - NSState - The initial state of the Nested Sampler. + The StateWithLogLikelihood of the current live particles. + integrator + The NSIntegrator instance that tracks evidence-related statistics. + inner_kernel_params + A dictionary of parameters for the inner kernel used to generate new + particles during the Nested Sampling process. """ - state = base_init(particles, logprior_fn, loglikelihood_fn, loglikelihood_birth) + + particles: StateWithLogLikelihood + integrator: NSIntegrator + inner_kernel_params: Dict[str, ArrayTree] + + +def init( + positions: ArrayLikeTree, + init_state_fn: Callable, + loglikelihood_birth: Array = jnp.nan, + update_inner_kernel_params_fn: Optional[Callable] = None, + rng_key: Optional[jax.random.PRNGKey] = None, +) -> AdaptiveNSState: + base_state = base_init( + positions, init_state_fn, loglikelihood_birth=loglikelihood_birth + ) + integrator = init_integrator(base_state.particles) + inner_kernel_params = {} if update_inner_kernel_params_fn is not None: - inner_kernel_params = update_inner_kernel_params_fn(state, None, {}) - state = state._replace(inner_kernel_params=inner_kernel_params) - return state + inner_kernel_params = update_inner_kernel_params_fn( + rng_key, base_state, None, {} + ) + return AdaptiveNSState( + particles=base_state.particles, + inner_kernel_params=inner_kernel_params, + integrator=integrator, + ) def build_kernel( - logprior_fn: Callable, - loglikelihood_fn: Callable, delete_fn: Callable, inner_kernel: Callable, update_inner_kernel_params_fn: Callable[ - [NSState, NSInfo, Dict[str, ArrayTree]], Dict[str, ArrayTree] + [PRNGKey, NSState, NSInfo, Dict[str, ArrayTree]], Dict[str, ArrayTree] ], ) -> Callable: """Build an adaptive Nested Sampling kernel. - This kernel extends the base Nested Sampling kernel by re-computing/tuning - the parameters for the inner kernel at each step. The `update_inner_kernel_params_fn` - is called after each NS step to determine the parameters for the *next* NS - step. - - Parameters - ---------- - logprior_fn - A function that computes the log-prior probability of a single particle. - loglikelihood_fn - A function that computes the log-likelihood of a single particle. - delete_fn - this particle deletion function has the signature - `(rng_key, current_state) -> (dead_idx, target_update_idx, start_idx)` - and identifies particles to be deleted, particles to be updated, and - selects live particles to be starting points for the inner kernel - for new particle generation. - inner_kernel - This kernel function has the signature - `(rng_key, inner_state, logprior_fn, loglikelihood_fn, loglikelihood_0, inner_kernel_params) -> (new_inner_state, inner_info)`, - and is used to generate new particles. - update_inner_kernel_params_fn - A function that takes the `NSState`, `NSInfo` from the completed NS - step, and the current inner kernel parameters dictionary, and returns - a dictionary of parameters to be used for the kernel in the *next* NS step. - - Returns - ------- - Callable - A kernel function for adaptive Nested Sampling. 
It takes an `rng_key` and the - current `NSState` and returns a tuple containing the new `NSState` and - the `NSInfo` for the step. + This function constructs a Nested Sampling kernel that incorporates + adaptive tuning of the inner kernel parameters based on the current state + of the sampler and the information from the previous update step. """ - base_kernel = base_build_kernel( - logprior_fn, - loglikelihood_fn, - delete_fn, - inner_kernel, - ) - - def kernel(rng_key: PRNGKey, state: NSState) -> tuple[NSState, NSInfo]: - """Performs one step of adaptive Nested Sampling. - - This involves running a step of the base Nested Sampling algorithm using - the current inner kernel parameters, and then updating these parameters - for the next step. - - Parameters - ---------- - rng_key - A JAX PRNG key. - state - The current `NSState`. - - Returns - ------- - tuple[NSState, NSInfo] - A tuple with the new `NSState` (including updated inner kernel - parameters) and the `NSInfo` for this step. - """ - new_state, info = base_kernel(rng_key, state) + def kernel( + rng_key: PRNGKey, state: AdaptiveNSState + ) -> tuple[AdaptiveNSState, NSInfo]: + """Performs one step of adaptive Nested Sampling.""" + adapted_kernel = base_build_kernel( + delete_fn, + partial(inner_kernel, **state.inner_kernel_params), + ) - inner_kernel_params = update_inner_kernel_params_fn( - new_state, info, new_state.inner_kernel_params + new_state, info = adapted_kernel(rng_key, state) + inner_kernel_update_key, rng_key = jax.random.split(rng_key) + new_inner_kernel_params = update_inner_kernel_params_fn( + inner_kernel_update_key, new_state, info, new_state.inner_kernel_params + ) + new_integrator_state = update_integrator( + state.integrator, new_state.particles, info.particles + ) + return ( + AdaptiveNSState( + particles=new_state.particles, + inner_kernel_params=new_inner_kernel_params, + integrator=new_integrator_state, + ), + info, ) - new_state = new_state._replace(inner_kernel_params=inner_kernel_params) - return new_state, info return kernel diff --git a/blackjax/ns/base.py b/blackjax/ns/base.py index 978f242d4..c6fab4dc7 100644 --- a/blackjax/ns/base.py +++ b/blackjax/ns/base.py @@ -11,285 +11,144 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Base components for Nested Sampling algorithms in BlackJAX. - -This module provides the fundamental data structures (`NSState`, `NSInfo`) and -a basic, non-adaptive kernel for Nested Sampling. Nested Sampling is a -Monte Carlo method primarily aimed at Bayesian evidence (marginal likelihood) -computation and posterior sampling, particularly effective for multi-modal -distributions. - -The core idea is to transform the multi-dimensional evidence integral into a -one-dimensional integral over the prior volume, ordered by likelihood. This is -achieved by iteratively replacing the point with the lowest likelihood among a -set of "live" points with a new point sampled from the prior, subject to the -constraint that its likelihood must be higher than the one just discarded. - -This base implementation uses a provided kernel to perform the constrained -sampling. 
-""" -from typing import Callable, Dict, NamedTuple, Optional +"""Base components for Nested Sampling algorithms in BlackJAX.""" +from typing import Callable, NamedTuple import jax import jax.numpy as jnp -from jax.scipy.special import logsumexp -from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey +from blackjax.types import Array, ArrayLikeTree, PRNGKey __all__ = ["init", "build_kernel", "NSState", "NSInfo", "delete_fn"] -class NSState(NamedTuple): - """State of the Nested Sampler. +class StateWithLogLikelihood(NamedTuple): + """State of a particle in NS, mostly a conventional MCMC state dressed + with log-likelihood information. Positions are an ArrayTree + where each leaf represents a variable from the posterior. Attributes ---------- - particles - A PyTree of arrays, where each leaf array has a leading dimension - equal to the number of live particles. Stores the current positions of - the live particles. + position + The position of the particle (PyTree). + logdensity + The log-density of the particle under the prior (Array). loglikelihood - An array of log-likelihood values, one for each live particle. + The log-likelihood of the particle (Array). loglikelihood_birth - An array storing the log-likelihood threshold that each current live - particle was required to exceed when it was "born" (i.e., sampled). - This is used for reconstructing the nested sampling path. - logprior - An array of log-prior values, one for each live particle. - pid - Particle ID. An array of integers tracking the identity or lineage of - particles, primarily for diagnostic purposes. - logX - The log of the current prior volume estimate. - logZ - The accumulated log evidence estimate from the "dead" points . - logZ_live - The current estimate of the log evidence contribution from the live points. - inner_kernel_params - A dictionary of parameters for the inner kernel. + The log-likelihood birth threshold for the particle (Array). """ - particles: ArrayLikeTree - loglikelihood: Array # The log-likelihood of the particles - loglikelihood_birth: Array # The log-likelihood threshold at particle birth - logprior: Array # The log-prior density of the particles - pid: Array # particle IDs - logX: Array # The current log-volume estimate - logZ: Array # The accumulated evidence estimate - logZ_live: Array # The current evidence estimate - inner_kernel_params: Dict # Parameters for the inner kernel + position: ArrayLikeTree + logdensity: Array + loglikelihood: Array + loglikelihood_birth: Array -class NSInfo(NamedTuple): - """Additional information returned at each step of the Nested Sampling algorithm. +class NSState(NamedTuple): + """State of the Nested Sampler. - Attributes - ---------- - particles - The PyTree of particles that were marked as "dead" (replaced) in the - current step. - loglikelihood - The log-likelihood values of the dead particles. - loglikelihood_birth - The birth log-likelihood thresholds of the dead particles. - logprior - The log-prior values of the dead particles. - inner_kernel_info - A NamedTuple (or any PyTree) containing information from the update step - (inner kernel) used to generate new live particles. The content - depends on the specific inner kernel used. + At the most basic level, this is just a wrapper around a StateWithLogLikelihood; + however, it is extended in other NS implementations.
""" - particles: ArrayTree - loglikelihood: Array # The log-likelihood of the particles - loglikelihood_birth: Array # The log-likelihood threshold at particle birth - logprior: Array # The log-prior density of the particles - inner_kernel_info: NamedTuple # Information from the inner kernel update step - - -class PartitionedState(NamedTuple): - """State container that partitions out the loglikelihood and logprior. - - This intermediate construction wraps around the usual State of an MCMC chain - so that the loglikelihood and logprior can be efficiently recorded, a - necessary step for the Parition function reconstruction that Nested - Sampling builds - + particles: StateWithLogLikelihood - Attributes - ---------- - position - A PyTree of arrays representing the current positions of the particles. - Each leaf array has a leading dimension corresponding to the number of particles. - logprior - An array of log-prior density values evaluated at the particle positions. - Shape: (n_particles,) - loglikelihood - An array of log-likelihood values evaluated at the particle positions. - Shape: (n_particles,) - """ - position: ArrayLikeTree # Current positions of particles in the inner kernel - logprior: Array # Log-prior values for particles in the inner kernel - loglikelihood: Array # Log-likelihood values for particles in the inner kernel - - -class PartitionedInfo(NamedTuple): - """Transition information that additionally records a partitioned loglikelihood - and logprior. - - See PartitionedState +class NSInfo(NamedTuple): + """Additional information returned at each step of the Nested Sampling algorithm. Attributes ---------- - position - A PyTree of arrays representing the final positions after the transition step. - Structure matches the input particle positions. - logprior - An array of log-prior density values at the final positions. - Kept separate to support posterior repartitioning schemes. - Shape: (n_particles,) - loglikelihood - An array of log-likelihood values at the final positions. - Kept separate to support posterior repartitioning schemes. - Shape: (n_particles,) - info - Additional transition-specific diagnostic information from the step. - The content and structure depend on the specific transition implementation - (e.g., acceptance rates, step sizes, number of evaluations, etc.). + particles + The StateWithLogLikelihood of particles that were marked as "dead" (replaced). + update_info + A NamedTuple (or any PyTree) containing information from the update step + (inner kernel) used to generate new live particles. """ - position: ArrayTree - logprior: ArrayTree - loglikelihood: ArrayTree - info: NamedTuple + particles: StateWithLogLikelihood + update_info: NamedTuple -def new_state_and_info(position, logprior, loglikelihood, info): - """Create new PartitionedState and PartitionedInfo from transition results. - - This utility function packages the results of a transition into the standard - partitioned state and info containers, maintaining the separation of logprior - and loglikelihood components. +def init_state_strategy( + position: ArrayLikeTree, + logprior_fn: Callable, + loglikelihood_fn: Callable, + loglikelihood_birth: Array = jnp.nan, +) -> StateWithLogLikelihood: + """The default initialisation strategy for each state. Parameters ---------- position - The particle positions after the transition step. + A PyTree of arrays representing the initial positions of the particles. + Each leaf array has a leading dimension corresponding to the number of particles. 
logprior - The log-prior densities at the new positions. + A function that computes the log-prior density for a single particle. loglikelihood - The log-likelihood values at the new positions. - info - Additional transition-specific information from the step. + A function that computes the log-likelihood for a single particle. + loglikelihood_birth + The log-likelihood threshold that the particle must exceed. Defaults to NaN. Returns ------- - tuple[PartitionedState, PartitionedInfo] - A tuple containing the new partitioned state and associated information. + StateWithLogLikelihood + The initialized state containing positions, log-prior, log-likelihood, and birth likelihood. """ - new_state = PartitionedState( - position=position, - logprior=logprior, - loglikelihood=loglikelihood, + logprior_values = logprior_fn(position) + loglikelihood_values = loglikelihood_fn(position) + loglikelihood_birth_values = loglikelihood_birth * jnp.ones_like( + loglikelihood_values ) - info = PartitionedInfo( - position=position, - logprior=logprior, - loglikelihood=loglikelihood, - info=info, + + return StateWithLogLikelihood( + position, logprior_values, loglikelihood_values, loglikelihood_birth_values ) - return new_state, info def init( - particles: ArrayLikeTree, - logprior_fn: Callable, - loglikelihood_fn: Callable, - loglikelihood_birth: Array = -jnp.nan, - logX: Optional[Array] = 0.0, - logZ: Optional[Array] = -jnp.inf, + positions: ArrayLikeTree, + init_state_fn: Callable, + loglikelihood_birth: Array = jnp.nan, ) -> NSState: """Initializes the Nested Sampler state. Parameters ---------- - particles - An initial set of particles (PyTree of arrays) drawn from the prior + positions + An initial set of positions (PyTree of arrays) drawn from the prior distribution. The leading dimension of each leaf array must be equal to - the number of particles. - logprior_fn - A function that computes the log-prior of a single particle. - loglikelihood_fn - A function that computes the log-likelihood of a single particle. + the number of positions. + init_state_fn + A function that initializes an NSState from positions. loglikelihood_birth - The initial log-likelihood birth threshold. Defaults to -NaN, which + The initial log-likelihood birth threshold. Defaults to NaN, which implies no initial likelihood constraint beyond the prior. - logX - The initial log prior volume estimate. Defaults to 0.0. - logZ - The initial log evidence estimate. Defaults to -inf. Returns ------- NSState The initial state of the Nested Sampler. """ - loglikelihood = loglikelihood_fn(particles) - loglikelihood_birth = loglikelihood_birth * jnp.ones_like(loglikelihood) - logprior = logprior_fn(particles) - pid = jnp.arange(len(loglikelihood)) - dtype = loglikelihood.dtype - logX = jnp.array(logX, dtype=dtype) - logZ = jnp.array(logZ, dtype=dtype) - logZ_live = logmeanexp(loglikelihood) + logX - inner_kernel_params: Dict = {} - return NSState( - particles, - loglikelihood, - loglikelihood_birth, - logprior, - pid, - logX, - logZ, - logZ_live, - inner_kernel_params, + state_init = init_state_fn(positions) + loglikelihood_birth_array = loglikelihood_birth * jnp.ones_like( + state_init.loglikelihood_birth ) + return NSState(state_init._replace(loglikelihood_birth=loglikelihood_birth_array)) def build_kernel( - logprior_fn: Callable, - loglikelihood_fn: Callable, delete_fn: Callable, inner_kernel: Callable, ) -> Callable: """Build a generic Nested Sampling kernel. - This kernel implements one step of the Nested Sampling algorithm. 
In each step: - 1. A set of particles with the lowest log-likelihoods are identified and - marked as "dead" using `delete_fn`. The log-likelihood of the "worst" - of these dead particles (i.e., max among the lowest ones) defines the new - likelihood constraint `loglikelihood_0`. - 2. Live particles are selected (typically with replacement from the remaining - live particles, determined by `delete_fn`) to act as starting points for - the updates. - 3. These selected live particles are evolved using an kernel - `inner_kernel`. The sampling is constrained to the region where - `loglikelihood(new_particle) > loglikelihood_0`. - 4. The newly generated particles replace particles marked for replacement, - (typically the ones that have just been deleted). - 5. The prior volume `logX` and evidence `logZ` are updated based on the - number of deleted particles and their likelihoods. - - This base version does not adapt the kernel parameters. + This function creates a kernel for the Nested Sampling algorithm by combining + a particle deletion function and an inner kernel for generating new particles. Parameters ---------- - logprior_fn - A function that computes the log-prior probability of a single particle. - loglikelihood_fn - A function that computes the log-likelihood of a single particle. delete_fn this particle deletion function has the signature `(rng_key, current_state) -> (dead_idx, target_update_idx, start_idx)` @@ -298,80 +157,45 @@ def build_kernel( for new particle generation. inner_kernel This kernel function has the signature - `(rng_key, inner_state, logprior_fn, loglikelihood_fn, loglikelihood_0, params) -> (new_inner_state, inner_info)`, + `(rng_keys, inner_state, loglikelihood_0) -> (new_inner_state, inner_info)`, and is used to generate new particles. Returns ------- Callable A kernel function for Nested Sampling: - `(rng_key, state) -> (new_state, ns_info)`. + `(rng_key, state, inner_kernel_params) -> (new_state, ns_info)`. 
""" def kernel(rng_key: PRNGKey, state: NSState) -> tuple[NSState, NSInfo]: # Delete, and grab all the dead information rng_key, delete_fn_key = jax.random.split(rng_key) - dead_idx, target_update_idx, start_idx = delete_fn(delete_fn_key, state) + dead_idx, target_update_idx, start_idx = delete_fn( + delete_fn_key, state.particles + ) dead_particles = jax.tree.map(lambda x: x[dead_idx], state.particles) - dead_loglikelihood = state.loglikelihood[dead_idx] - dead_loglikelihood_birth = state.loglikelihood_birth[dead_idx] - dead_logprior = state.logprior[dead_idx] # Resample the live particles - loglikelihood_0 = dead_loglikelihood.max() - rng_key, sample_key = jax.random.split(rng_key) - sample_keys = jax.random.split(sample_key, len(start_idx)) - particles = jax.tree.map(lambda x: x[start_idx], state.particles) - logprior = state.logprior[start_idx] - loglikelihood = state.loglikelihood[start_idx] - inner_state = PartitionedState(particles, logprior, loglikelihood) - new_inner_state, inner_info = inner_kernel( - sample_keys, - inner_state, - logprior_fn, - loglikelihood_fn, - loglikelihood_0, - state.inner_kernel_params, + sample_keys = jax.random.split(rng_key, len(start_idx)) + inner_state = jax.tree.map(lambda x: x[start_idx], state.particles) + loglikelihood_0 = dead_particles.loglikelihood.max() + new_inner_state, inner_update_info = inner_kernel( + sample_keys, inner_state, loglikelihood_0 ) # Update the particles - particles = jax.tree_util.tree_map( - lambda p, n: p.at[target_update_idx].set(n), - state.particles, - new_inner_state.position, - ) - loglikelihood = state.loglikelihood.at[target_update_idx].set( - new_inner_state.loglikelihood - ) - loglikelihood_birth = state.loglikelihood_birth.at[target_update_idx].set( - loglikelihood_0 * jnp.ones(len(target_update_idx)) - ) - logprior = state.logprior.at[target_update_idx].set(new_inner_state.logprior) - pid = state.pid.at[target_update_idx].set(state.pid[start_idx]) - - # Update the run-time information - logX, logZ, logZ_live = update_ns_runtime_info( - state.logX, state.logZ, loglikelihood, dead_loglikelihood + state = state._replace( + particles=jax.tree_util.tree_map( + lambda p, n: p.at[target_update_idx].set(n), + state.particles, + new_inner_state, + ) ) # Return updated state and info - state = NSState( - particles, - loglikelihood, - loglikelihood_birth, - logprior, - pid, - logX, - logZ, - logZ_live, - state.inner_kernel_params, - ) info = NSInfo( dead_particles, - dead_loglikelihood, - dead_loglikelihood_birth, - dead_logprior, - inner_info, + inner_update_info, ) return state, info @@ -379,7 +203,7 @@ def kernel(rng_key: PRNGKey, state: NSState) -> tuple[NSState, NSInfo]: def delete_fn( - rng_key: PRNGKey, state: NSState, num_delete: int + rng_key: PRNGKey, state: StateWithLogLikelihood, num_delete: int ) -> tuple[Array, Array, Array]: """Identifies particles to be deleted and selects live particles for resampling. 
@@ -414,7 +238,7 @@ def delete_fn( loglikelihood = state.loglikelihood neg_dead_loglikelihood, dead_idx = jax.lax.top_k(-loglikelihood, num_delete) constraint_loglikelihood = loglikelihood > -neg_dead_loglikelihood.min() - weights = jnp.array(constraint_loglikelihood, dtype=jnp.float32) + weights = jnp.array(constraint_loglikelihood) weights = jnp.where(weights.sum() > 0.0, weights, jnp.ones_like(weights)) start_idx = jax.random.choice( rng_key, @@ -425,24 +249,3 @@ def delete_fn( ) target_update_idx = dead_idx return dead_idx, target_update_idx, start_idx - - -def update_ns_runtime_info( - logX: Array, logZ: Array, loglikelihood: Array, dead_loglikelihood: Array -) -> tuple[Array, Array, Array]: - num_particles = len(loglikelihood) - num_deleted = len(dead_loglikelihood) - num_live = jnp.arange(num_particles, num_particles - num_deleted, -1) - delta_logX = -1 / num_live - logX = logX + jnp.cumsum(delta_logX) - log_delta_X = logX + jnp.log(1 - jnp.exp(delta_logX)) - log_delta_Z = dead_loglikelihood + log_delta_X - - delta_logZ = logsumexp(log_delta_Z) - logZ = jnp.logaddexp(logZ, delta_logZ) - logZ_live = logmeanexp(loglikelihood) + logX[-1] - return logX[-1], logZ, logZ_live - - -def logmeanexp(x: Array) -> Array: - return logsumexp(x) - jnp.log(len(x)) diff --git a/blackjax/ns/from_mcmc.py b/blackjax/ns/from_mcmc.py new file mode 100644 index 000000000..3c6e1fb94 --- /dev/null +++ b/blackjax/ns/from_mcmc.py @@ -0,0 +1,138 @@ +from functools import partial +from typing import Callable, NamedTuple + +import jax +import jax.numpy as jnp + +from blackjax.ns.adaptive import build_kernel as build_adaptive_kernel +from blackjax.ns.base import delete_fn as default_delete_fn + + +class MCMCUpdateInfo(NamedTuple): + """Thin layer to hold all the info pertaining to the update step.""" + + mcmc_states: NamedTuple + mcmc_infos: NamedTuple + + +class ConstrainedMCMCInfo(NamedTuple): + """Info for a constrained MCMC proposal.""" + + info: NamedTuple + is_accepted: jnp.ndarray + num_trials: jnp.ndarray + + +def update_with_mcmc_take_last( + constrained_mcmc_step_fn, + num_mcmc_steps, +): + """An update strategy for NS that uses MCMC to update the particles. + For now we will not keep the states as they will be too large to store. + Similar to the update_and_take_last from SMC. + + Parameters + ---------- + constrained_mcmc_step_fn + Wrapped MCMC step function that enforces the NS likelihood constraint. + num_mcmc_steps + Number of MCMC proposals per particle. 
+ """ + + def update_function(rng_key, state, loglikelihood_0, **step_parameters): + shared_mcmc_step_fn = partial( + constrained_mcmc_step_fn, + loglikelihood_0=loglikelihood_0, + **step_parameters, + ) + + def mcmc_kernel(rng_key, state): + keys = jax.random.split(rng_key, num_mcmc_steps) + + def body_fn(state, rng_key): + new_state, info = shared_mcmc_step_fn(rng_key, state) + return new_state, info + + final_state, infos = jax.lax.scan(body_fn, state, keys) + return final_state, infos # MCMCUpdateInfo(final_state, infos) + + return jax.vmap(mcmc_kernel)(rng_key, state) + + return update_function + + +def build_kernel( + init_state_fn: Callable, + logdensity_fn: Callable, + mcmc_init_fn: Callable, + mcmc_step_fn: Callable, + num_inner_steps: int, + update_inner_kernel_params_fn: Callable, + num_delete: int = 1, + delete_fn: Callable = default_delete_fn, +) -> Callable: + """Builds a Nested Sampling kernel wrapping any MCMC algorithm.""" + + def constrained_mcmc_step_fn(rng_key, state, loglikelihood_0, **params): + def propose_once(rng_key, current_state): + rng_key, step_key = jax.random.split(rng_key) + mcmc_state = mcmc_init_fn(current_state.position, logdensity_fn) + new_mcmc_state, mcmc_info = mcmc_step_fn( + step_key, mcmc_state, logdensity_fn, **params + ) + proposed_state = init_state_fn( + new_mcmc_state.position, loglikelihood_birth=loglikelihood_0 + ) + within_contour = proposed_state.loglikelihood > loglikelihood_0 + proposal_accepted = getattr(mcmc_info, "is_accepted", True) + is_accepted = proposal_accepted & within_contour + new_state = jax.lax.cond( + is_accepted, + lambda _: proposed_state, + lambda _: current_state, + operand=None, + ) + return ( + rng_key, + new_state, + mcmc_info, + is_accepted, + jnp.array(1, dtype=jnp.int32), + ) + + rng_key, state, mcmc_info, is_accepted, trials = propose_once(rng_key, state) + + def cond_fn(carry): + _, _, _, accepted, _ = carry + return ~accepted + + def body_fn(carry): + rng_key, current_state, _, _, trials = carry + rng_key, new_state, new_info, is_accepted, new_trials = propose_once( + rng_key, current_state + ) + return ( + rng_key, + new_state, + new_info, + is_accepted, + trials + new_trials, + ) + + rng_key, state, mcmc_info, is_accepted, trials = jax.lax.while_loop( + cond_fn, body_fn, (rng_key, state, mcmc_info, is_accepted, trials) + ) + + mcmc_info = ConstrainedMCMCInfo(mcmc_info, is_accepted, trials) + return state, mcmc_info, trials + + inner_kernel = update_with_mcmc_take_last(constrained_mcmc_step_fn, num_inner_steps) + + delete_fn = partial(delete_fn, num_delete=num_delete) + + kernel = build_adaptive_kernel( + delete_fn, + inner_kernel, + update_inner_kernel_params_fn=update_inner_kernel_params_fn, + ) + return kernel diff --git a/blackjax/ns/integrator.py b/blackjax/ns/integrator.py new file mode 100644 index 000000000..ec4bfab61 --- /dev/null +++ b/blackjax/ns/integrator.py @@ -0,0 +1,118 @@ +# Copyright 2020- The Blackjax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
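The update_with_mcmc_take_last strategy above scans a step function over several keys and keeps only the final state, vmapping over particles. A self-contained sketch of the same pattern follows, with toy_step as a made-up Metropolis stand-in rather than a BlackJAX kernel.

import jax
import jax.numpy as jnp

def toy_step(rng_key, x):
    # Toy Metropolis step targeting a standard normal (stand-in for the wrapped kernel).
    prop_key, accept_key = jax.random.split(rng_key)
    prop = x + 0.5 * jax.random.normal(prop_key)
    log_ratio = -0.5 * prop**2 + 0.5 * x**2
    accept = jnp.log(jax.random.uniform(accept_key)) < log_ratio
    return jnp.where(accept, prop, x), accept

def take_last(rng_key, x, num_steps=10):
    keys = jax.random.split(rng_key, num_steps)

    def body_fn(carry, key):
        new_x, accepted = toy_step(key, carry)
        return new_x, accepted

    final_x, accepted = jax.lax.scan(body_fn, x, keys)
    return final_x, accepted

# vmap over particles, mirroring jax.vmap(mcmc_kernel)(rng_key, state) above.
keys = jax.random.split(jax.random.PRNGKey(0), 4)
final, history = jax.vmap(take_last)(keys, jnp.zeros(4))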
+"""Evidence integration for Nested Sampling. + +This module provides utilities for tracking the evidence integral during +a Nested Sampling run. The NSIntegrator accumulates statistics as the algorithm +compresses the prior volume, computing the marginal likelihood (evidence), +information gain (entropy), and related quantities. +""" + +from typing import NamedTuple + +import jax.numpy as jnp +from jax.scipy.special import logsumexp + +from blackjax.ns.base import StateWithLogLikelihood +from blackjax.ns.utils import log1mexp +from blackjax.types import Array + +__all__ = ["NSIntegrator", "init_integrator", "update_integrator"] + + +class NSIntegrator(NamedTuple): + """Integrator for computing the evidence integral in Nested Sampling. + + This accumulates statistics over the course of a Nested Sampling run, + computing the evidence (marginal likelihood) and related quantities + from the history of dead particles. These are derived quantities that + can be reconstructed from the dead particle history. + + Attributes + ---------- + logX + The log of the current prior volume estimate. + logZ + The accumulated log evidence estimate from the "dead" points. + logZ_live + The current estimate of the log evidence contribution from the live points. + """ + + logX: Array + logZ: Array + logZ_live: Array + + +def init_integrator(particle_state: StateWithLogLikelihood) -> NSIntegrator: + """Initialize the evidence integrator from the initial live points. + + Parameters + ---------- + particle_state + The initial NSState containing the live particles. + + Returns + ------- + NSIntegrator + The initial integrator with logX=0, logZ=-inf, and logZ_live computed + from the initial live points. + """ + logX = jnp.array(0.0) + logZ = jnp.array(-jnp.inf) + logZ_live = _logmeanexp(particle_state.loglikelihood) + logX + return NSIntegrator(logX, logZ, logZ_live) + + +def update_integrator( + integrator: NSIntegrator, + particle_state: StateWithLogLikelihood, + dead_particles: StateWithLogLikelihood, +) -> NSIntegrator: + """Update the evidence integrator after a Nested Sampling step. + + Parameters + ---------- + integrator + The current integrator state. + live_state + The updated live state after the NS step. + dead_info + Information about the particles that died in this step. + + Returns + ------- + NSIntegrator + The updated integrator with new logX, logZ, and logZ_live. + """ + loglikelihood = particle_state.loglikelihood + dead_loglikelihood = dead_particles.loglikelihood + + num_particles = len(loglikelihood) + num_deleted = len(dead_loglikelihood) + num_live = jnp.arange(num_particles, num_particles - num_deleted, -1) + delta_logX = -1 / num_live + logX = integrator.logX + jnp.cumsum(delta_logX) + log_delta_X = logX + log1mexp(delta_logX) + log_delta_Z = dead_loglikelihood + log_delta_X + + delta_logZ = logsumexp(log_delta_Z) + logZ = jnp.logaddexp(integrator.logZ, delta_logZ) + logZ_live = _logmeanexp(loglikelihood) + logX[-1] + return NSIntegrator(logX[-1], logZ, logZ_live) + + +def _logmeanexp(x: Array) -> Array: + """Compute log(mean(exp(x))) in a numerically stable way.""" + n = jnp.array(x.shape[0]) + return logsumexp(x) - jnp.log(n) diff --git a/blackjax/ns/nss.py b/blackjax/ns/nss.py index 236138fa2..bb30807c9 100644 --- a/blackjax/ns/nss.py +++ b/blackjax/ns/nss.py @@ -12,15 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. """Nested Slice Sampling (NSS) algorithm. 
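The bookkeeping in update_integrator shrinks the log prior volume by 1/n_live per deleted point and accumulates the evidence contribution of the dead points. A small numeric sketch of the same arithmetic, with made-up likelihood values, purely for illustration:

import jax.numpy as jnp
from jax.scipy.special import logsumexp

# Made-up values: 4 live points, 2 of which are deleted this step.
num_particles, num_deleted = 4, 2
live_logL = jnp.array([-1.0, -0.5, -0.2, -0.1])  # log-likelihoods of the updated live set
dead_logL = jnp.array([-3.0, -2.0])              # log-likelihoods of the deleted points

logX, logZ = jnp.array(0.0), jnp.array(-jnp.inf)                        # integrator start
num_live = jnp.arange(num_particles, num_particles - num_deleted, -1)   # [4, 3]
delta_logX = -1.0 / num_live                                            # per-death volume compression
logX_steps = logX + jnp.cumsum(delta_logX)
log_delta_X = logX_steps + jnp.log1p(-jnp.exp(delta_logX))              # log width of each shell
logZ = jnp.logaddexp(logZ, logsumexp(dead_logL + log_delta_X))          # dead-point contribution
logZ_live = logsumexp(live_logL) - jnp.log(num_particles) + logX_steps[-1]  # live-point estimate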
- -This module implements the Nested Slice Sampling algorithm, which combines the -Nested Sampling framework with an inner Hit-and-Run Slice Sampling (HRSS) kernel -for exploring the constrained prior distribution at each likelihood level. - -The key idea is to leverage the efficiency of slice sampling for constrained -sampling tasks. The parameters of the HRSS kernel, specifically the covariance -matrix for proposing slice directions, are adaptively tuned based on the current -set of live particles. +A specific implementation of Nested Sampling that uses +Hit-and-Run Slice Sampling (HRSS) as the inner MCMC kernel. """ from functools import partial @@ -28,216 +21,124 @@ import jax import jax.numpy as jnp -from jax.flatten_util import ravel_pytree from blackjax import SamplingAlgorithm -from blackjax.mcmc.ss import SliceState from blackjax.mcmc.ss import build_kernel as build_slice_kernel -from blackjax.mcmc.ss import default_stepper_fn -from blackjax.mcmc.ss import ( - sample_direction_from_covariance as ss_sample_direction_from_covariance, -) +from blackjax.mcmc.ss import sample_direction_from_covariance from blackjax.ns.adaptive import build_kernel as build_adaptive_kernel from blackjax.ns.adaptive import init from blackjax.ns.base import NSInfo, NSState from blackjax.ns.base import delete_fn as default_delete_fn -from blackjax.ns.base import new_state_and_info -from blackjax.ns.utils import get_first_row, repeat_kernel -from blackjax.smc.tuning.from_particles import ( - particles_as_rows, - particles_covariance_matrix, -) -from blackjax.types import ArrayTree, PRNGKey +from blackjax.ns.base import init_state_strategy +from blackjax.ns.from_mcmc import update_with_mcmc_take_last +from blackjax.smc.tuning.from_particles import particles_covariance_matrix +from blackjax.types import ArrayTree __all__ = [ - "init", "as_top_level_api", "build_kernel", + "init", + "update_inner_kernel_params", ] -def sample_direction_from_covariance( - rng_key: PRNGKey, params: Dict[str, ArrayTree] -) -> ArrayTree: - """Default function to generate a normalized slice direction for NSS. - - This function is designed to work with covariance parameters adapted by - `default_adapt_direction_params_fn`. It expects `params` to contain - 'cov', a PyTree structured identically to a single particle. Each leaf - of this 'cov' PyTree contains rows of the full covariance matrix that - correspond to that leaf's elements in the flattened particle vector. - (Specifically, if the full DxD covariance matrix of flattened particles is - `M_flat`, and `unravel_fn` un-flattens a D-vector to the particle PyTree, - then the input `cov` is effectively `jax.vmap(unravel_fn)(M_flat)`). - - The function reassembles the full (D,D) covariance matrix from this - PyTree structure. It then samples a flat direction vector `d_flat` from - a multivariate Gaussian $\\mathcal{N}(0, M_{reassembled})$, normalizes - `d_flat` using the Mahalanobis norm defined by $M_{reassembled}^{-1}$, - and finally un-flattens this normalized direction back into the - particle's PyTree structure using an `unravel_fn` derived from the - particle structure. +def default_stepper_fn(x: ArrayTree, d: ArrayTree, t: float) -> tuple[ArrayTree, bool]: + """A simple stepper function that moves from `x` along direction `d` by `t` units. + + Implements the operation: `x_new = x + t * d`. Parameters ---------- - rng_key - A JAX PRNG key. 
- params - Keyword arguments, must contain: - - `cov`: A PyTree (structured like a particle) whose leaves are rows - of the covariance matrix, typically output by - `compute_covariance_from_particles`. + x + The starting position (PyTree). + d + The direction of movement (PyTree, same structure as `x`). + t + The scalar step size or distance along the direction. Returns ------- - ArrayTree - A Mahalanobis-normalized direction vector (PyTree, matching the - structure of a single particle), to be used by the slice sampler. + tuple[ArrayTree, bool] + A tuple containing the new position and whether the step was accepted. """ - cov = params["cov"] - row = get_first_row(cov) - _, unravel_fn = ravel_pytree(row) - cov = particles_as_rows(cov) - d = ss_sample_direction_from_covariance(rng_key, cov) - return unravel_fn(d) + return jax.tree.map(lambda x, d: x + t * d, x, d), True -def compute_covariance_from_particles( +def update_inner_kernel_params( + rng_key: jax.random.PRNGKey, state: NSState, info: NSInfo, inner_kernel_params: Optional[Dict[str, ArrayTree]] = None, ) -> Dict[str, ArrayTree]: - """Default function to adapt/tune the slice direction proposal parameters. + """Update inner kernel parameters from current particles. - This function computes the empirical covariance matrix from the current set of - live particles in `state.particles`. This covariance matrix is then returned - and can be used by the slice direction generation function (e.g., - `default_generate_slice_direction_fn`) in the next Nested Sampling iteration. + Computes the empirical covariance matrix from the live particles + for use in slice direction proposals. Parameters ---------- state - The current `NSState` of the Nested Sampler, containing the live particles. + The current NSState containing live particles. info - The `NSInfo` from the last Nested Sampling step (currently unused by this function). + Information from the last NS step (unused but kept for interface consistency). inner_kernel_params - A dictionary of parameters for the inner kernel (currently unused by this function). + Previous inner kernel parameters (unused but kept for interface consistency). Returns ------- Dict[str, ArrayTree] - A dictionary `{'cov': cov_pytree}`. `cov_pytree` is a PyTree with the - same structure as a single particle. If the full DxD covariance matrix - of the flattened particles is `M_flat`, and `unravel_fn` is the function - to un-flatten a D-vector to the particle's PyTree structure, then - `cov_pytree` is equivalent to `jax.vmap(unravel_fn)(M_flat)`. - This means each leaf of `cov_pytree` will have a shape `(D, *leaf_original_dims)`. + Dictionary containing updated 'cov' (covariance matrix). 
""" - cov_matrix = jnp.atleast_2d(particles_covariance_matrix(state.particles)) - single_particle = get_first_row(state.particles) - _, unravel_fn = ravel_pytree(single_particle) - cov_pytree = jax.vmap(unravel_fn)(cov_matrix) - return {"cov": cov_pytree} + return { + "cov": jnp.atleast_2d(particles_covariance_matrix(state.particles.position)) + } def build_kernel( - logprior_fn: Callable, - loglikelihood_fn: Callable, + init_state_fn: Callable, num_inner_steps: int, num_delete: int = 1, stepper_fn: Callable = default_stepper_fn, - adapt_direction_params_fn: Callable = compute_covariance_from_particles, generate_slice_direction_fn: Callable = sample_direction_from_covariance, + update_inner_kernel_params_fn: Callable = update_inner_kernel_params, + delete_fn: Callable = default_delete_fn, max_steps: int = 10, max_shrinkage: int = 100, ) -> Callable: """Builds the Nested Slice Sampling kernel. - This function creates a Nested Slice Sampling kernel that uses - Hit-and-Run Slice Sampling (HRSS) as its inner kernel. The parameters - for the HRSS direction proposal (specifically, the covariance matrix) - are adaptively tuned at each step using `adapt_direction_params_fn`. - - Parameters - ---------- - logprior_fn - A function that computes the log-prior probability of a single particle. - loglikelihood_fn - A function that computes the log-likelihood of a single particle. - num_inner_steps - The number of HRSS steps to run for each new particle generation. - This should be a multiple of the dimension of the parameter space. - num_delete - The number of particles to delete and replace at each NS step. - Defaults to 1. - stepper_fn - The stepper function `(x, direction, t) -> x_new` for the HRSS kernel. - Defaults to `default_stepper_fn`. - adapt_direction_params_fn - A function `(ns_state, ns_info) -> dict_of_params` that computes/adapts - the parameters (e.g., covariance matrix) for the slice direction proposal, - based on the current NS state. Defaults to `compute_covariance_from_particles`. - generate_slice_direction_fn - A function `(rng_key, **params) -> direction_pytree` that generates a - normalized direction for HRSS, using parameters from `adapt_direction_params_fn`. - Defaults to `sample_direction_from_covariance`. - max_steps - The maximum number of steps to take when expanding the interval in - each direction during the stepping-out phase. Defaults to 10. - max_shrinkage - The maximum number of shrinking steps to perform to avoid infinite loops. - Defaults to 100. - - Returns - ------- - Callable - A kernel function for Nested Slice Sampling that takes an `rng_key` and - the current `NSState` and returns a tuple containing the new `NSState` and - the `NSInfo` for the step. + see `as_top_level_api` for parameter descriptions. 
""" - slice_kernel = build_slice_kernel(stepper_fn, max_steps, max_shrinkage) - - @repeat_kernel(num_inner_steps) - def inner_kernel( - rng_key, state, logprior_fn, loglikelihood_fn, loglikelihood_0, params - ): - # Do constrained slice sampling - slice_state = SliceState( - position=state.position, - logdensity=state.logprior, - constraint=jnp.array([state.loglikelihood]), - ) + def constrained_mcmc_slice_fn(rng_key, state, loglikelihood_0, **params): rng_key, prop_key = jax.random.split(rng_key, 2) - d = generate_slice_direction_fn(prop_key, params) - logdensity_fn = logprior_fn - constraint_fn = lambda x: jnp.array([loglikelihood_fn(x)]) - constraint = jnp.array([loglikelihood_0]) - strict = jnp.array([True]) - new_slice_state, slice_info = slice_kernel( - rng_key, slice_state, logdensity_fn, d, constraint_fn, constraint, strict + d = generate_slice_direction_fn(prop_key, state.position, **params) + + def slice_fn(t) -> tuple[NSState, bool]: + x, step_accepted = stepper_fn(state.position, d, t) + new_state = init_state_fn(x, loglikelihood_birth=loglikelihood_0) + in_contour = new_state.loglikelihood > loglikelihood_0 + is_accepted = in_contour & step_accepted + return new_state, is_accepted + + slice_kernel = build_slice_kernel( + slice_fn, + max_steps=max_steps, + max_shrinkage=max_shrinkage, ) + new_slice_state, slice_info = slice_kernel(rng_key, state) + return new_slice_state, slice_info - # Pass the relevant information back to PartitionedState and PartitionedInfo - return new_state_and_info( - position=new_slice_state.position, - logprior=new_slice_state.logdensity, - loglikelihood=new_slice_state.constraint[0], - info=slice_info, - ) - - delete_fn = partial(default_delete_fn, num_delete=num_delete) + inner_kernel = update_with_mcmc_take_last( + constrained_mcmc_slice_fn, num_inner_steps + ) - # Vectorize the inner kernel for parallel execution - in_axes = (0, 0, None, None, None, None) + delete_fn = partial(delete_fn, num_delete=num_delete) - update_inner_kernel_params_fn = adapt_direction_params_fn kernel = build_adaptive_kernel( - logprior_fn, - loglikelihood_fn, delete_fn, - jax.vmap(inner_kernel, in_axes=in_axes), - update_inner_kernel_params_fn, + inner_kernel, + update_inner_kernel_params_fn=update_inner_kernel_params_fn, ) return kernel @@ -248,17 +149,19 @@ def as_top_level_api( num_inner_steps: int, num_delete: int = 1, stepper_fn: Callable = default_stepper_fn, - adapt_direction_params_fn: Callable = compute_covariance_from_particles, generate_slice_direction_fn: Callable = sample_direction_from_covariance, + init_state_strategy_fn: Callable = init_state_strategy, + update_inner_kernel_params_fn: Callable = update_inner_kernel_params, + delete_fn: Callable = default_delete_fn, max_steps: int = 10, max_shrinkage: int = 100, ) -> SamplingAlgorithm: - """Creates an adaptive Nested Slice Sampling (NSS) algorithm. + """Creates a Nested Slice Sampling (NSS) algorithm. This function configures a Nested Sampling algorithm that uses Hit-and-Run Slice Sampling (HRSS) as its inner kernel. The parameters for the HRSS - direction proposal (specifically, the covariance matrix) are adaptively tuned - at each step using `adapt_direction_params_fn`. + direction proposal (specifically, the covariance matrix) are managed + externally using `init_inner_kernel_params` and `update_inner_kernel_params`. Parameters ---------- @@ -273,16 +176,15 @@ def as_top_level_api( The number of particles to delete and replace at each NS step. Defaults to 1. 
stepper_fn - The stepper function `(x, direction, t) -> x_new` for the HRSS kernel. - Defaults to `default_stepper`. - adapt_direction_params_fn - A function `(ns_state, ns_info) -> dict_of_params` that computes/adapts - the parameters (e.g., covariance matrix) for the slice direction proposal, - based on the current NS state. Defaults to `compute_covariance_from_particles`. + The stepper function `(x, direction, t) -> (x_new, is_accepted)` for the HRSS kernel. + Defaults to `default_stepper_fn`. generate_slice_direction_fn - A function `(rng_key, **params) -> direction_pytree` that generates a - normalized direction for HRSS, using parameters from `adapt_direction_params_fn`. - Defaults to `sample_direction_from_covariance`. + A function `(rng_key, position, **kwargs) -> direction_pytree` that generates a + normalized direction for HRSS. Keyword arguments are unpacked from the + inner_kernel_params dict. Defaults to `sample_direction_from_covariance`. + init_state_strategy_fn + A function that initializes a single particle's state from its position. + Defaults to `init_state_strategy`. max_steps The maximum number of steps to take when expanding the interval in each direction during the stepping-out phase. Defaults to 10. @@ -294,31 +196,37 @@ ------- SamplingAlgorithm A `SamplingAlgorithm` tuple containing `init` and `step` functions for - the configured Nested Slice Sampler. The state managed by this - algorithm is `NSState`. + the configured Nested Slice Sampler. The step function signature is + `step(rng_key, state) -> (new_state, info)`. """ + init_state_fn = partial( + init_state_strategy_fn, + logprior_fn=logprior_fn, + loglikelihood_fn=loglikelihood_fn, + ) kernel = build_kernel( - logprior_fn, - loglikelihood_fn, + init_state_fn, num_inner_steps, num_delete, stepper_fn=stepper_fn, - adapt_direction_params_fn=adapt_direction_params_fn, generate_slice_direction_fn=generate_slice_direction_fn, + update_inner_kernel_params_fn=update_inner_kernel_params_fn, + delete_fn=delete_fn, max_steps=max_steps, max_shrinkage=max_shrinkage, ) def init_fn(position, rng_key=None): # Vectorize the functions for parallel evaluation over particles + # vmap maps over positional args, keyword args (like loglikelihood_birth) are broadcast return init( position, - logprior_fn=jax.vmap(logprior_fn), - loglikelihood_fn=jax.vmap(loglikelihood_fn), - update_inner_kernel_params_fn=adapt_direction_params_fn, + init_state_fn=jax.vmap(init_state_fn), + update_inner_kernel_params_fn=update_inner_kernel_params_fn, ) - step_fn = kernel + def step_fn(rng_key, state): + return kernel(rng_key, state) return SamplingAlgorithm(init_fn, step_fn) diff --git a/blackjax/ns/utils.py b/blackjax/ns/utils.py index 15f3ef35b..2bdebbcb0 100644 --- a/blackjax/ns/utils.py +++ b/blackjax/ns/utils.py @@ -11,14 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Utility functions for Nested Sampling. - -This module provides helper functions for common tasks associated with Nested -Sampling, such as calculating log-volumes, log-weights, effective sample sizes, -and post-processing of results. +"""Utility functions for Nested Sampling post-processing. """ -import functools from typing import Callable, Dict, Tuple import jax @@ -29,28 +24,7 @@ def log1mexp(x: Array) -> Array: - """Computes log(1 - exp(x)) in a numerically stable way.
- - This function implements the algorithm from Mächler (2012) [1]_ for computing - log(1 - exp(x)) while avoiding precision issues, especially when x is close to 0. - - Parameters - ---------- - x - Input array or scalar. Values in x should be less than or equal to 0; - the function returns `jnp.nan` for `x > 0`. - - Returns - ------- - Array - The value of log(1 - exp(x)). - - References - ---------- - .. [1] Mächler, M. (2012). Accurately computing log(1-exp(-|a|)). - CRAN R project, package Rmpfr, vignette log1mexp-note.pdf. - https://cran.r-project.org/web/packages/Rmpfr/vignettes/log1mexp-note.pdf - """ + """Computes log(1 - exp(x)) in a numerically stable way.""" return jnp.where( x > -0.6931472, # approx log(2) jnp.log(-jnp.expm1(x)), @@ -61,27 +35,9 @@ def log1mexp(x: Array) -> Array: def compute_num_live(info: NSInfo) -> Array: """Compute the effective number of live points at each death contour. - In Nested Sampling, especially with batch deletions (k > 1), the conceptual - number of live points changes with each individual particle considered "dead" - within that batch. - - The function works by: - 1. Creating "birth" events (particle added to live set, count +1) and "death" - events (particle removed, count -1). - 2. Sorting all events by their log-likelihood. In case of ties, birth events - can be processed before death events by sorting on the count type (1 before -1), - though the primary sort is logL. - 3. Computing the cumulative sum of these +1/-1 counts. This gives the number - of particles with log-likelihood greater than or equal to the current event's logL. - 4. For each death event, this cumulative sum (plus 1, because the dead particle itself - was live just before its "death") represents `m*_i`. - - Parameters - ---------- - info - An `NSInfo` object (or a PyTree with compatible `loglikelihood_birth` - and `loglikelihood` fields, typically from a concatenated history of NS steps) - containing the birth and death log-likelihoods of particles. + When doing batch deletions, the jump in energy level can be smoothed by + transforming 1 jump of size k into k jumps of size 1. This function computes + the effective population size associated with this transformation. Returns ------- @@ -90,15 +46,11 @@ def compute_num_live(info: NSInfo) -> Array: points `m*_i` when the j-th particle (in the sorted list of dead particles) was considered "dead". """ - birth_logL = info.loglikelihood_birth - death_logL = info.loglikelihood + birth_logL = info.particles.loglikelihood_birth + death_logL = info.particles.loglikelihood - birth_events = jnp.column_stack( - (birth_logL, jnp.ones_like(birth_logL, dtype=jnp.int32)) - ) - death_events = jnp.column_stack( - (death_logL, -jnp.ones_like(death_logL, dtype=jnp.int32)) - ) + birth_events = jnp.column_stack((birth_logL, jnp.ones_like(birth_logL, dtype=int))) + death_events = jnp.column_stack((death_logL, -jnp.ones_like(death_logL, dtype=int))) combined = jnp.concatenate([birth_events, death_events], axis=0) logL_col = combined[:, 0] n_col = combined[:, 1] @@ -111,18 +63,15 @@ def compute_num_live(info: NSInfo) -> Array: cumsum = jnp.maximum(cumsum, 0) death_mask_sorted = sorted_n_col == -1 num_live = cumsum[death_mask_sorted] + 1 - return num_live def logX(rng_key: PRNGKey, dead_info: NSInfo, shape: int = 100) -> tuple[Array, Array]: """Simulate the stochastic evolution of log prior volumes. 
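# A quick numerical illustration of why the stable branch of log1mexp above matters
# (Mächler 2012): near x = 0 the naive log(1 - exp(x)) underflows to -inf in float32,
# while log(-expm1(x)) remains accurate. Values below are illustrative only.
import jax.numpy as jnp

x = jnp.float32(-1e-8)
naive = jnp.log(1.0 - jnp.exp(x))  # exp(x) rounds to 1.0 in float32, giving log(0) = -inf
stable = jnp.log(-jnp.expm1(x))    # expm1 keeps the tiny value, giving ~ log(1e-8) ~ -18.4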
- This function estimates the sequence of log prior volumes `logX_i` and the - log prior volume elements `log(dX_i)` associated with each dead particle. - For each dead particle `i`, the change in log volume is modeled as - `delta_logX_i = log(u_i) / m*_i`, where `u_i` is a standard uniform random - variable and `m*_i` is the effective number of live points when particle `i` died. + Wraps the effective population size in `compute_num_live`, along with stochastic + simulation of the log prior shrinkage associated with each deleted particle. + Parameters ---------- @@ -147,13 +96,11 @@ def logX(rng_key: PRNGKey, dead_info: NSInfo, shape: int = 100) -> tuple[Array, `dX_i` is approximately `X_i - X_{i+1}`. """ rng_key, subkey = jax.random.split(rng_key) - min_val = jnp.finfo(dead_info.loglikelihood.dtype).tiny - r = jnp.log( - jax.random.uniform( - subkey, shape=(dead_info.loglikelihood.shape[0], shape) - ).clip(min_val, 1 - min_val) + u = jax.random.uniform( + subkey, + shape=(dead_info.particles.loglikelihood.shape[0], shape), ) - + r = jax.lax.log1p(jax.lax.neg(u)) num_live = compute_num_live(dead_info) t = r / num_live[:, jnp.newaxis] logX = jnp.cumsum(t, axis=0) @@ -170,11 +117,6 @@ def log_weights( ) -> Array: """Calculate the log importance weights for Nested Sampling results. - The importance weight for each dead particle `i` is `w_i = dX_i * L_i^beta`, - where `dX_i` is the prior volume element associated with the particle and - `L_i` is its likelihood. This function computes `log(w_i)` using stochastically - simulated `log(dX_i)` values. - Parameters ---------- rng_key @@ -195,24 +137,18 @@ def log_weights( An array of log importance weights, shape `(num_dead_particles, *shape)`. The original order of particles in `dead_info` is preserved. """ - sort_indices = jnp.argsort(dead_info.loglikelihood) + sort_indices = jnp.argsort(dead_info.particles.loglikelihood) unsort_indices = jnp.empty_like(sort_indices) unsort_indices = unsort_indices.at[sort_indices].set(jnp.arange(len(sort_indices))) dead_info_sorted = jax.tree.map(lambda x: x[sort_indices], dead_info) _, log_dX = logX(rng_key, dead_info_sorted, shape) - log_w = log_dX + beta * dead_info_sorted.loglikelihood[..., jnp.newaxis] + log_w = log_dX + beta * dead_info_sorted.particles.loglikelihood[..., jnp.newaxis] return log_w[unsort_indices] -def finalise(live: NSState, dead: list[NSInfo]) -> NSInfo: +def finalise(live: NSState, dead: list[NSInfo], update_info: bool = True) -> NSInfo: """Combines the history of dead particle information with the final live points. - At the end of a Nested Sampling run, the remaining live points are treated - as if they were the next set of "dead" points to complete the evidence - integral and posterior sample set. This function concatenates the `NSInfo` - objects accumulated for dead particles throughout the run with a new `NSInfo` - object created from the final live particles in `live`. - Parameters ---------- live @@ -230,36 +166,29 @@ def finalise(live: NSState, dead: list[NSInfo]) -> NSInfo: for the final live points' `update_info` (as a placeholder). 
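# A minimal sketch of the shrinkage model behind logX and log_weights above: each
# death multiplies the prior volume by a random factor with E[log(1 - u)] = -1, scaled
# by the effective number of live points m_i, and each dead point then carries weight
# log w_i = log dX_i + beta * logL_i. The toy m and logL values are illustrative only.
import jax
import jax.numpy as jnp

key = jax.random.key(0)
m = jnp.full(4, 5.0)                        # effective live points at each of 4 deaths
logL = jnp.array([-3.0, -2.0, -1.0, -0.5])  # death log-likelihoods, sorted increasing
u = jax.random.uniform(key, m.shape)
log_X = jnp.cumsum(jnp.log1p(-u) / m)       # log prior volume after each death, ~ -(i + 1) / m
X = jnp.concatenate([jnp.ones(1), jnp.exp(log_X)])
log_dX = jnp.log(X[:-1] - X[1:])            # volume element dX_i ~ X_{i-1} - X_i
log_w = log_dX + 1.0 * logL                 # beta = 1 targets the posterior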
""" - all_pytrees_to_combine = dead + [ - NSInfo( - live.particles, - live.loglikelihood, - live.loglikelihood_birth, - live.logprior, - dead[-1].inner_kernel_info, + if update_info: + update_infos = [d.update_info for d in dead] + final_update_info = jax.tree_util.tree_map( + lambda *xs: jnp.concatenate(xs, axis=0), *update_infos ) - ] - combined_dead_info = jax.tree.map( - lambda *args: jnp.concatenate(args), - all_pytrees_to_combine[0], - *all_pytrees_to_combine[1:], + else: + final_update_info = None + + particles = [d.particles for d in dead] + [live.particles] + final_particles = jax.tree_util.tree_map( + lambda *xs: jnp.concatenate(xs, axis=0), *particles ) - return combined_dead_info + return NSInfo(final_particles, final_update_info) -def ess(rng_key: PRNGKey, dead_info_map: NSInfo) -> Array: +def ess(rng_key: PRNGKey, dead: NSInfo) -> Array: """Computes the Effective Sample Size (ESS) from log-weights. - The ESS is a measure of the quality of importance samples, indicating - how many independent samples the weighted set is equivalent to. - It's calculated as `(sum w_i)^2 / sum (w_i^2)`. This function computes - the mean ESS across multiple stochastic log-weight samples. - Parameters ---------- rng_key A JAX PRNG key, used by `log_weights`. - dead_info_map + dead An `NSInfo` object containing the full set of dead (and final live) particles, typically the output of `finalise`. @@ -268,7 +197,7 @@ def ess(rng_key: PRNGKey, dead_info_map: NSInfo) -> Array: Array The mean Effective Sample Size, a scalar float. """ - logw = log_weights(rng_key, dead_info_map).mean(axis=-1) + logw = log_weights(rng_key, dead).mean(axis=-1) logw -= logw.max() l_sum_w = jax.scipy.special.logsumexp(logw) l_sum_w_sq = jax.scipy.special.logsumexp(2 * logw) @@ -276,39 +205,23 @@ def ess(rng_key: PRNGKey, dead_info_map: NSInfo) -> Array: return ess -def sample(rng_key: PRNGKey, dead_info_map: NSInfo, shape: int = 1000) -> ArrayTree: +def sample(rng_key: PRNGKey, dead: NSInfo, shape: int = 1000) -> ArrayTree: """Resamples particles according to their importance weights. - This function takes the full set of dead (and final live) particles and - their computed importance weights, and draws `shape` particles with - replacement, where the probability of drawing each particle is proportional - to its weight. This produces an unweighted sample from the target posterior - distribution. - - Parameters - ---------- - rng_key - A JAX PRNG key, used for both `log_weights` and `jax.random.choice`. - dead_info_map - An `NSInfo` object containing the full set of dead (and final live) - particles, typically the output of `finalise`. - shape - The number of posterior samples to draw. Defaults to 1000. - Returns ------- ArrayTree A PyTree of resampled particles, where each leaf has `shape`. 
""" - logw = log_weights(rng_key, dead_info_map).mean(axis=-1) + logw = log_weights(rng_key, dead).mean(axis=-1) indices = jax.random.choice( rng_key, - dead_info_map.loglikelihood.shape[0], + dead.particles.loglikelihood.shape[0], p=jnp.exp(logw.squeeze() - jnp.max(logw)), shape=(shape,), replace=True, ) - return jax.tree.map(lambda leaf: leaf[indices], dead_info_map.particles) + return jax.tree.map(lambda leaf: leaf[indices], dead.particles) def get_first_row(x: ArrayTree) -> ArrayTree: @@ -332,23 +245,6 @@ def get_first_row(x: ArrayTree) -> ArrayTree: return jax.tree.map(lambda x: x[0], x) -def repeat_kernel(num_repeats: int): - """Decorator to repeat a kernel function multiple times.""" - - def decorator(kernel): - @functools.wraps(kernel) - def repeated_kernel(rng_key: PRNGKey, state, *args, **kwargs): - def body_fn(state, rng_key): - return kernel(rng_key, state, *args, **kwargs) - - keys = jax.random.split(rng_key, num_repeats) - return jax.lax.scan(body_fn, state, keys) - - return repeated_kernel - - return decorator - - def uniform_prior( rng_key: PRNGKey, num_live: int, bounds: Dict[str, Tuple[float, float]] ) -> Tuple[ArrayTree, Callable]: diff --git a/docs/examples/nested_sampling.py b/docs/examples/nested_sampling.py index bb8f05e82..dc69c89df 100644 --- a/docs/examples/nested_sampling.py +++ b/docs/examples/nested_sampling.py @@ -98,7 +98,7 @@ def compute_logZ(mu_L, Sigma_L, logLmax=0, mu_pi=None, Sigma_pi=None): for _ in tqdm.trange(1000): # We track the estimate of the evidence in the live points as logZ_live, and the accumulated sum across all steps in logZ # this gives a handy termination that allows us to stop early - if live.logZ_live - live.logZ < -3: # type: ignore[attr-defined] + if live.integrator.logZ_live - live.integrator.logZ < -3: break rng_key, subkey = jax.random.split(rng_key, 2) live, dead_info = step_fn(subkey, live) @@ -116,5 +116,5 @@ def compute_logZ(mu_L, Sigma_L, logLmax=0, mu_pi=None, Sigma_pi=None): logZs = jax.scipy.special.logsumexp(logw, axis=0) print(f"Analytic evidence: {log_analytic_evidence:.2f}") -print(f"Runtime evidence: {live.logZ:.2f}") # type: ignore[attr-defined] +print(f"Integrated evidence: {live.integrator.logZ:.2f}") print(f"Estimated evidence: {logZs.mean():.2f} +- {logZs.std():.2f}") diff --git a/tests/mcmc/test_slice_sampling.py b/tests/mcmc/test_slice_sampling.py index 54c2a721f..319ed9931 100644 --- a/tests/mcmc/test_slice_sampling.py +++ b/tests/mcmc/test_slice_sampling.py @@ -40,40 +40,16 @@ def test_slice_init(self): expected_logdensity = logdensity_fn(position) chex.assert_trees_all_close(state.logdensity, expected_logdensity) - def test_vertical_slice(self): - """Test vertical slice height sampling""" - key = jax.random.key(123) - position = jnp.array([0.0]) - state = ss.init(position, logdensity_fn) - - # Sample many slice heights - keys = jax.random.split(key, 1000) - new_state, info = jax.vmap(ss.vertical_slice, in_axes=(0, None))(keys, state) - - # Heights should be below log density at position - logdens_at_pos = logdensity_fn(position) - self.assertTrue(jnp.all(new_state.logslice <= logdens_at_pos)) - - # Heights should be reasonably distributed - mean_height = jnp.mean(new_state.logslice) - expected_mean = logdens_at_pos - 1.0 # E[log(U)] = -1 for U~Uniform(0,1) - chex.assert_trees_all_close(mean_height, expected_mean, atol=0.1) - @parameterized.parameters([1, 2, 5]) def test_slice_sampling_dimensions(self, ndim): """Test slice sampling in different dimensions""" key = jax.random.key(456) position = jnp.zeros(ndim) 
- # Simple step function - def stepper_fn(x, d, t): - return x + t * d - - # Build kernel - def direction_fn(rng_key): - return jax.random.normal(rng_key, (ndim,)) + # Build kernel with identity covariance matrix + cov = jnp.eye(ndim) - kernel = ss.build_hrss_kernel(direction_fn, stepper_fn) + kernel = ss.build_hrss_kernel(cov) state = ss.init(position, logdensity_fn) # Take one step @@ -81,51 +57,65 @@ def direction_fn(rng_key): chex.assert_shape(new_state.position, (ndim,)) self.assertIsInstance(new_state.logdensity, (float, jax.Array)) + self.assertIsInstance(info.is_accepted, (bool, jax.Array)) def test_constrained_slice_sampling(self): - """Test slice sampling with constraints""" + """Test slice sampling with constraints via logdensity""" key = jax.random.key(789) position = jnp.array([1.0]) # Start in valid region - def stepper_fn(x, d, t): - return x + t * d + cov = jnp.eye(1) + algorithm = ss.hrss_as_top_level_api(constrained_logdensity, cov) + state = algorithm.init(position) + + # Take multiple steps + for _ in range(10): + key, subkey = jax.random.split(key) + state, info = algorithm.step(subkey, state) + # Should remain in valid region (x > 0) + self.assertTrue(jnp.all(state.position > 0)) - kernel = ss.build_kernel(stepper_fn) - state = ss.init(position, constrained_logdensity) + def test_build_kernel_with_custom_slice_fn(self): + """Test build_kernel with custom slice_fn""" + key = jax.random.key(111) + position = jnp.array([0.0]) + state = ss.init(position, logdensity_fn) - # Direction pointing outward + # Custom slice_fn that samples along a direction direction = jnp.array([1.0]) - # Constraint function - def constraint_fn(x): - return jnp.array([]) # No additional constraints for this test - - new_state, info = kernel( - key, - state, - constrained_logdensity, - direction, - constraint_fn, - jnp.array([]), - jnp.array([]), - ) + def slice_fn(t): + new_position = state.position + t * direction + new_state = ss.SliceState(new_position, logdensity_fn(new_position)) + is_accepted = True + return new_state, is_accepted + + # Build kernel with slice_fn + slice_kernel = ss.build_kernel(slice_fn, max_steps=10, max_shrinkage=100) - # Should remain in valid region - self.assertTrue(jnp.all(new_state.position > 0)) + # Take one step + new_state, info = slice_kernel(key, state) + + chex.assert_shape(new_state.position, (1,)) + self.assertIsInstance(info, ss.SliceInfo) def test_default_direction_generation(self): """Test default direction generation function""" key = jax.random.key(101112) + position = jnp.zeros(3) cov = jnp.eye(3) * 2.0 - direction = ss.sample_direction_from_covariance(key, cov) + direction = ss.sample_direction_from_covariance(key, position, cov) chex.assert_shape(direction, (3,)) - # Direction should be normalized in Mahalanobis sense + # Direction should be normalized in Mahalanobis sense with scaling factor + # The scaling factor is 2 * sqrt(dim + 2) + dim = 3 + expected_norm = 2 * jnp.sqrt(dim + 2) invcov = jnp.linalg.inv(cov) mahal_norm = jnp.sqrt(jnp.einsum("i,ij,j", direction, invcov, direction)) - chex.assert_trees_all_close(mahal_norm, 1.0, atol=1e-6) + chex.assert_trees_all_close(mahal_norm, expected_norm, atol=1e-6) def test_hrss_top_level_api(self): """Test hit-and-run slice sampling top-level API""" @@ -143,6 +133,7 @@ def test_hrss_top_level_api(self): new_state, info = algorithm.step(key, state) chex.assert_shape(new_state.position, (2,)) + self.assertIsInstance(info.is_accepted, (bool, jax.Array)) def test_slice_sampling_statistical_correctness(self): 
"""Test that slice sampling produces correct statistics""" @@ -187,49 +178,70 @@ def test_slice_sampling_statistical_correctness(self): self.assertGreater(sample_std, 0.1, "Standard deviation is too small") self.assertLess(sample_std, 5.0, "Standard deviation is too large") - def test_default_stepper_fn(self): - """Test default stepper function""" - x = jnp.array([1.0, 2.0]) - d = jnp.array([0.5, -0.5]) - t = 2.0 - - result = ss.default_stepper_fn(x, d, t) - expected = x + t * d - - chex.assert_trees_all_close(result, expected) - def test_slice_info_structure(self): """Test that SliceInfo contains expected fields""" key = jax.random.key(789) position = jnp.array([0.0]) - def stepper_fn(x, d, t): - return x + t * d + cov = jnp.eye(1) + algorithm = ss.hrss_as_top_level_api(logdensity_fn, cov) + state = algorithm.init(position) + + new_state, info = algorithm.step(key, state) + + # Check that info has expected structure + self.assertIsInstance(info, ss.SliceInfo) + self.assertTrue(hasattr(info, "is_accepted")) + self.assertTrue(hasattr(info, "num_steps")) + self.assertTrue(hasattr(info, "num_shrink")) + + # Check types + self.assertIsInstance(info.is_accepted, (bool, jax.Array)) + self.assertIsInstance(info.num_steps, (int, jax.Array)) + self.assertIsInstance(info.num_shrink, (int, jax.Array)) + + def test_multimodal_sampling(self): + """Test slice sampling on multimodal distribution""" + key = jax.random.key(999) + position = jnp.array([2.5]) # Start near first mode + + cov = jnp.eye(1) * 4.0 # Large covariance for mode hopping + algorithm = ss.hrss_as_top_level_api(multimodal_logdensity, cov) + state = algorithm.init(position) - kernel = ss.build_kernel(stepper_fn) + # Run a few steps + samples = [] + for _ in range(50): + key, subkey = jax.random.split(key) + state, info = algorithm.step(subkey, state) + samples.append(state.position[0]) + + samples = jnp.array(samples) + + # Just check that sampling works without errors + self.assertFalse(jnp.isnan(samples).any()) + self.assertFalse(jnp.isinf(samples).any()) + + def test_horizontal_slice_basic(self): + """Test horizontal_slice function directly""" + key = jax.random.key(321) + position = jnp.array([0.0]) state = ss.init(position, logdensity_fn) - direction = jnp.array([1.0]) - def constraint_fn(x): - return jnp.array([]) - - new_state, info = kernel( - key, - state, - logdensity_fn, - direction, - constraint_fn, - jnp.array([]), - jnp.array([]), + # Simple slice_fn that accepts positions in [-1, 1] + def slice_fn(t): + new_position = jnp.array([t]) + new_state = ss.SliceState(new_position, logdensity_fn(new_position)) + is_accepted = jnp.abs(t) <= 1.0 + return new_state, is_accepted + + new_state, info = ss.horizontal_slice( + key, state, slice_fn, m=10, max_shrinkage=100 ) - # Check that info has expected structure + # Should find a point within [-1, 1] + self.assertLessEqual(jnp.abs(new_state.position[0]), 1.0) self.assertIsInstance(info, ss.SliceInfo) - self.assertTrue(hasattr(info, "constraint")) - self.assertTrue(hasattr(info, "l_steps")) - self.assertTrue(hasattr(info, "r_steps")) - self.assertTrue(hasattr(info, "s_steps")) - self.assertTrue(hasattr(info, "evals")) if __name__ == "__main__": diff --git a/tests/ns/test_nested_sampling.py b/tests/ns/test_nested_sampling.py index 4280a9d8c..dc9d698c3 100644 --- a/tests/ns/test_nested_sampling.py +++ b/tests/ns/test_nested_sampling.py @@ -20,6 +20,26 @@ def gaussian_loglikelihood(x): return stats.norm.logpdf(x - 1.0).sum() +def make_init_state_fn(logprior_fn, loglikelihood_fn): + 
"""Helper to create init_state_fn from logprior and loglikelihood functions.""" + return functools.partial( + base.init_state_strategy, + logprior_fn=logprior_fn, + loglikelihood_fn=loglikelihood_fn, + ) + + +def make_mock_nsinfo(positions, loglikelihood, loglikelihood_birth, logdensity): + """Helper to create NSInfo with correct structure.""" + particles = base.StateWithLogLikelihood( + position=positions, + logdensity=logdensity, + loglikelihood=loglikelihood, + loglikelihood_birth=loglikelihood_birth, + ) + return base.NSInfo(particles=particles, update_info={}) + + def uniform_logprior_2d(x): """Uniform prior on [-5, 5]^2""" return jnp.where(jnp.all(jnp.abs(x) <= 5.0), 0.0, -jnp.inf) @@ -43,23 +63,26 @@ def test_base_ns_init(self): num_live = 50 # Generate initial particles - particles = jax.random.normal(key, (num_live,)) + positions = jax.random.normal(key, (num_live,)) - # Initialize NS state - state = base.init(particles, gaussian_logprior, gaussian_loglikelihood) + # Initialize NS state using the correct API + init_state_fn = jax.vmap( + make_init_state_fn(gaussian_logprior, gaussian_loglikelihood) + ) + state = base.init(positions, init_state_fn) - # Check state structure - chex.assert_shape(state.particles, (num_live,)) - chex.assert_shape(state.loglikelihood, (num_live,)) - chex.assert_shape(state.logprior, (num_live,)) - chex.assert_shape(state.pid, (num_live,)) + # Check state structure - particles is now a StateWithLogLikelihood + chex.assert_shape(state.particles.position, (num_live,)) + chex.assert_shape(state.particles.loglikelihood, (num_live,)) + chex.assert_shape(state.particles.logdensity, (num_live,)) + chex.assert_shape(state.particles.loglikelihood_birth, (num_live,)) # Check that loglikelihood and logprior are properly computed - expected_loglik = jax.vmap(gaussian_loglikelihood)(particles) - expected_logprior = jax.vmap(gaussian_logprior)(particles) + expected_loglik = jax.vmap(gaussian_loglikelihood)(positions) + expected_logprior = jax.vmap(gaussian_logprior)(positions) - chex.assert_trees_all_close(state.loglikelihood, expected_loglik) - chex.assert_trees_all_close(state.logprior, expected_logprior) + chex.assert_trees_all_close(state.particles.loglikelihood, expected_loglik) + chex.assert_trees_all_close(state.particles.logdensity, expected_logprior) def test_delete_fn(self): """Test particle deletion function""" @@ -67,10 +90,15 @@ def test_delete_fn(self): num_live = 20 num_delete = 3 - particles = jax.random.normal(key, (num_live,)) - state = base.init(particles, gaussian_logprior, gaussian_loglikelihood) + positions = jax.random.normal(key, (num_live,)) + init_state_fn = jax.vmap( + make_init_state_fn(gaussian_logprior, gaussian_loglikelihood) + ) + state = base.init(positions, init_state_fn) - dead_idx, target_idx, start_idx = base.delete_fn(key, state, num_delete) + dead_idx, target_idx, start_idx = base.delete_fn( + key, state.particles, num_delete + ) # Check correct number of deletions chex.assert_shape(dead_idx, (num_delete,)) @@ -78,8 +106,8 @@ def test_delete_fn(self): chex.assert_shape(start_idx, (num_delete,)) # Check that worst particles are selected - worst_loglik = jnp.sort(state.loglikelihood)[:num_delete] - selected_loglik = state.loglikelihood[dead_idx] + worst_loglik = jnp.sort(state.particles.loglikelihood)[:num_delete] + selected_loglik = state.particles.loglikelihood[dead_idx] chex.assert_trees_all_close(jnp.sort(selected_loglik), worst_loglik) @parameterized.parameters([1, 2, 5]) @@ -88,44 +116,43 @@ def test_ns_step_consistency(self, 
num_delete): key = jax.random.key(789) num_live = 50 - particles = jax.random.normal(key, (num_live, 2)) - state = base.init( - particles, uniform_logprior_2d, gaussian_mixture_loglikelihood + positions = jax.random.normal(key, (num_live, 2)) + init_state_fn = jax.vmap( + make_init_state_fn(uniform_logprior_2d, gaussian_mixture_loglikelihood) ) + state = base.init(positions, init_state_fn) - # Mock inner kernel for testing - def mock_inner_kernel( - rng_key, inner_state, logprior_fn, loglikelihood_fn, loglikelihood_0, params - ): + # Mock inner kernel for testing - matches new API signature + def mock_inner_kernel(rng_keys, inner_state, loglikelihood_0): + # inner_state is StateWithLogLikelihood # Simple random walk for testing - new_pos = ( - inner_state["position"] - + jax.random.normal(rng_key, inner_state["position"].shape) * 0.1 - ) - new_logprior = logprior_fn(new_pos) - new_loglik = loglikelihood_fn(new_pos) - - new_inner_state = { - "position": new_pos, - "logprior": new_logprior, - "loglikelihood": new_loglik, - } + def single_step(rng_key, state): + new_pos = ( + state.position + + jax.random.normal(rng_key, state.position.shape) * 0.1 + ) + new_state = base.init_state_strategy( + new_pos, + uniform_logprior_2d, + gaussian_mixture_loglikelihood, + loglikelihood_birth=loglikelihood_0, + ) + return new_state + + new_inner_state = jax.vmap(single_step)(rng_keys, inner_state) return new_inner_state, {} delete_fn = functools.partial(base.delete_fn, num_delete=num_delete) - kernel = base.build_kernel( - uniform_logprior_2d, - gaussian_mixture_loglikelihood, - delete_fn, - mock_inner_kernel, - ) + kernel = base.build_kernel(delete_fn, mock_inner_kernel) # Test that the kernel can be constructed with mock components # Full execution would require more complex mocking of inner kernel behavior self.assertTrue(callable(kernel)) # Test delete function works - dead_idx, target_idx, start_idx = base.delete_fn(key, state, num_delete) + dead_idx, target_idx, start_idx = base.delete_fn( + key, state.particles, num_delete + ) chex.assert_shape(dead_idx, (num_delete,)) chex.assert_shape(target_idx, (num_delete,)) chex.assert_shape(start_idx, (num_delete,)) @@ -139,14 +166,16 @@ def test_utils_functions(self): dead_loglik = jnp.sort(jax.random.uniform(key, (n_dead,))) * 10 - 5 dead_loglik_birth = jnp.full_like(dead_loglik, -jnp.inf) - mock_info = base.NSInfo( - particles=jnp.zeros((n_dead, 2)), + # Create StateWithLogLikelihood for particles + particles = base.StateWithLogLikelihood( + position=jnp.zeros((n_dead, 2)), + logdensity=jnp.zeros(n_dead), loglikelihood=dead_loglik, loglikelihood_birth=dead_loglik_birth, - logprior=jnp.zeros(n_dead), - inner_kernel_info={}, ) + mock_info = base.NSInfo(particles=particles, update_info={}) + # Test compute_num_live num_live = utils.compute_num_live(mock_info) chex.assert_shape(num_live, (n_dead,)) @@ -170,15 +199,17 @@ def test_adaptive_init(self): key = jax.random.key(123) num_live = 30 - particles = jax.random.normal(key, (num_live,)) + positions = jax.random.normal(key, (num_live,)) def mock_update_params_fn(state, info, current_params): return {"test_param": 1.0} + init_state_fn = jax.vmap( + make_init_state_fn(gaussian_logprior, gaussian_loglikelihood) + ) state = adaptive.init( - particles, - gaussian_logprior, - gaussian_loglikelihood, + positions, + init_state_fn, update_inner_kernel_params_fn=mock_update_params_fn, ) @@ -196,25 +227,29 @@ def test_nss_direction_functions(self): key = jax.random.key(456) # Test covariance computation - particles = 
jax.random.normal(key, (50, 3)) - state = base.init(particles, gaussian_logprior, gaussian_loglikelihood) + positions = jax.random.normal(key, (50, 3)) + + def logprior_fn(x): + return stats.norm.logpdf(x).sum() + + def loglikelihood_fn(x): + return stats.norm.logpdf(x).sum() + + init_state_fn = jax.vmap(make_init_state_fn(logprior_fn, loglikelihood_fn)) + state = base.init(positions, init_state_fn) - params = nss.compute_covariance_from_particles(state, None, {}) + # Use update_inner_kernel_params instead of removed init_inner_kernel_params + params = nss.update_inner_kernel_params(state, None, {}) # Check that covariance is computed self.assertIn("cov", params) cov_pytree = params["cov"] chex.assert_shape(cov_pytree, (3, 3)) - # Test direction sampling - direction = nss.sample_direction_from_covariance(key, params) - chex.assert_shape(direction, (3,)) - def test_nss_kernel_construction(self): """Test NSS kernel can be constructed""" - kernel = nss.build_kernel( - gaussian_logprior, gaussian_loglikelihood, num_inner_steps=10 - ) + init_state_fn = make_init_state_fn(gaussian_logprior, gaussian_loglikelihood) + kernel = nss.build_kernel(init_state_fn, num_inner_steps=10) # Test that kernel is callable self.assertTrue(callable(kernel)) @@ -272,12 +307,8 @@ def loglikelihood_fn(x): dead_loglik_birth = jnp.full_like(dead_loglik, -jnp.inf) # Create NSInfo object - mock_info = base.NSInfo( - particles=positions, - loglikelihood=dead_loglik, - loglikelihood_birth=dead_loglik_birth, - logprior=dead_logprior, - inner_kernel_info={}, + mock_info = make_mock_nsinfo( + positions, dead_loglik, dead_loglik_birth, dead_logprior ) # Generate many evidence estimates for statistical testing @@ -345,22 +376,18 @@ def loglikelihood_fn(x): key = jax.random.key(456) # Initialize particles uniformly in [0, 1] - particles = jax.random.uniform(key, (num_live,)) - state = base.init(particles, logprior_fn, loglikelihood_fn) + positions = jax.random.uniform(key, (num_live,)) + init_state_fn = jax.vmap(make_init_state_fn(logprior_fn, loglikelihood_fn)) + state = base.init(positions, init_state_fn) # Check that initialization worked correctly - self.assertTrue(jnp.all(state.particles >= 0.0)) - self.assertTrue(jnp.all(state.particles <= 1.0)) - self.assertFalse(jnp.any(jnp.isinf(state.logprior))) - self.assertFalse(jnp.any(jnp.isnan(state.loglikelihood))) - - # Test evidence contribution from live points - logZ_live_contribution = state.logZ_live - self.assertIsInstance(logZ_live_contribution, (float, jax.Array)) - self.assertFalse(jnp.isnan(logZ_live_contribution)) + self.assertTrue(jnp.all(state.particles.position >= 0.0)) + self.assertTrue(jnp.all(state.particles.position <= 1.0)) + self.assertFalse(jnp.any(jnp.isinf(state.particles.logdensity))) + self.assertFalse(jnp.any(jnp.isnan(state.particles.loglikelihood))) def test_evidence_monotonicity(self): - """Test that evidence estimates are monotonically increasing during NS run.""" + """Test that we can initialize state and track integrator.""" # Simple setup for testing monotonicity def logprior_fn(x): @@ -372,40 +399,23 @@ def loglikelihood_fn(x): num_live = 30 key = jax.random.key(789) - particles = jax.random.normal(key, (num_live,)) - initial_state = base.init(particles, logprior_fn, loglikelihood_fn) - - # Test that we can track evidence during run - logZ_sequence = [initial_state.logZ] - - # Simulate a few evidence updates manually - for i in range(5): - # Simulate removing worst particle and updating evidence - worst_idx = 
jnp.argmin(initial_state.loglikelihood) - dead_loglik = initial_state.loglikelihood[worst_idx] - - # Update evidence (simplified) - delta_logX = -1.0 / num_live # Approximate volume decrease - new_logZ = jnp.logaddexp(initial_state.logZ, dead_loglik + delta_logX) - logZ_sequence.append(new_logZ) - - # Update for next iteration (simplified) - new_loglik = jnp.concatenate( - [ - initial_state.loglikelihood[:worst_idx], - initial_state.loglikelihood[worst_idx + 1 :], - jnp.array([dead_loglik + 0.1]), # Mock new particle - ] - ) - initial_state = initial_state._replace(loglikelihood=new_loglik) + positions = jax.random.normal(key, (num_live,)) + init_state_fn = jax.vmap(make_init_state_fn(logprior_fn, loglikelihood_fn)) + initial_state = base.init(positions, init_state_fn) - # Check monotonicity - logZ_array = jnp.array(logZ_sequence) - differences = logZ_array[1:] - logZ_array[:-1] - self.assertTrue( - jnp.all(differences >= -1e-10), - "Evidence should be monotonically increasing", - ) + # Test that we can access particle likelihoods + self.assertIsNotNone(initial_state.particles.loglikelihood) + chex.assert_shape(initial_state.particles.loglikelihood, (num_live,)) + + # For integrator tests, use adaptive state instead + from blackjax.ns import adaptive as adaptive_module + + adaptive_state = adaptive_module.init(positions, init_state_fn) + + # Check integrator exists and has expected fields + self.assertIsNotNone(adaptive_state.integrator) + self.assertIsNotNone(adaptive_state.integrator.logZ) + self.assertIsNotNone(adaptive_state.integrator.logX) def test_nested_sampling_utils_statistical_properties(self): """Test statistical properties of nested sampling utility functions.""" @@ -432,12 +442,8 @@ def test_nested_sampling_utils_statistical_properties(self): # Ensure birth likelihoods don't exceed death likelihoods dead_loglik_birth = jnp.minimum(dead_loglik_birth, dead_loglik - 0.01) - mock_info = base.NSInfo( - particles=jnp.zeros((n_dead, 2)), - loglikelihood=dead_loglik, - loglikelihood_birth=dead_loglik_birth, - logprior=jnp.zeros(n_dead), - inner_kernel_info={}, + mock_info = make_mock_nsinfo( + jnp.zeros((n_dead, 2)), dead_loglik, dead_loglik_birth, jnp.zeros(n_dead) ) # Test compute_num_live @@ -540,12 +546,8 @@ def loglikelihood_fn(x): ) dead_loglik_birth = jnp.minimum(dead_loglik_birth, dead_loglik - 0.01) - mock_info = base.NSInfo( - particles=positions, - loglikelihood=dead_loglik, - loglikelihood_birth=dead_loglik_birth, - logprior=dead_logprior, - inner_kernel_info={}, + mock_info = make_mock_nsinfo( + positions, dead_loglik, dead_loglik_birth, dead_logprior ) # Generate evidence estimates for statistical testing @@ -595,14 +597,11 @@ def test_evidence_integration_simple_case(self): dead_loglik = jnp.full(n_dead, loglik_constant) dead_loglik_birth = jnp.full(n_dead, -jnp.inf) # All from prior - mock_info = base.NSInfo( - particles=jnp.zeros((n_dead, 1)), - loglikelihood=dead_loglik, - loglikelihood_birth=dead_loglik_birth, - logprior=jnp.full( - n_dead, -jnp.log(prior_width) - ), # Uniform prior log density - inner_kernel_info={}, + mock_info = make_mock_nsinfo( + jnp.zeros((n_dead, 1)), + dead_loglik, + dead_loglik_birth, + jnp.full(n_dead, -jnp.log(prior_width)), # Uniform prior log density ) # Generate many evidence estimates @@ -645,12 +644,11 @@ def test_effective_sample_size_calculation(self): dead_loglik = jax.random.uniform(key, (n_dead,)) * 5 - 10 # Range [-10, -5] dead_loglik_birth = jnp.full(n_dead, -jnp.inf) - mock_info = base.NSInfo( - particles=jnp.zeros((n_dead, 
1)), - loglikelihood=jnp.sort(dead_loglik), # Ensure increasing - loglikelihood_birth=dead_loglik_birth, - logprior=jnp.zeros(n_dead), - inner_kernel_info={}, + mock_info = make_mock_nsinfo( + jnp.zeros((n_dead, 1)), + jnp.sort(dead_loglik), # Ensure increasing + dead_loglik_birth, + jnp.zeros(n_dead), ) # Calculate ESS
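# For completeness, a small sketch of the importance resampling done by utils.sample,
# which complements the ESS check here: indices are drawn with probability proportional
# to the max-shifted exponentiated log-weights, with replacement (the patch passes the
# unnormalized weights to jax.random.choice in the same way). Toy values only.
import jax
import jax.numpy as jnp

key = jax.random.key(1)
logw = jnp.array([-0.1, -2.0, -5.0, -0.3])
indices = jax.random.choice(
    key,
    logw.shape[0],
    p=jnp.exp(logw - jnp.max(logw)),  # same max-shift as in utils.sample
    shape=(1000,),
    replace=True,
)
# most draws land on the two high-weight entries (indices 0 and 3)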