From b422efda3600d01525555e0b904e4e18d798fe4c Mon Sep 17 00:00:00 2001 From: Will Handley Date: Tue, 30 Sep 2025 07:41:48 +0100 Subject: [PATCH] Add error tracking and information metrics to nested sampling MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add logZ_error, H (information), i_eff, and n_eff fields to NSState - Update runtime info calculation to compute these new metrics - Add converged() utility function for checking termination criteria - Update example to use new convergence function 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude --- blackjax/ns/base.py | 56 +++++++++++++++++++++++++++----- blackjax/ns/utils.py | 6 ++++ docs/examples/nested_sampling.py | 4 +-- 3 files changed, 55 insertions(+), 11 deletions(-) diff --git a/blackjax/ns/base.py b/blackjax/ns/base.py index 978f242d4..29922f5dc 100644 --- a/blackjax/ns/base.py +++ b/blackjax/ns/base.py @@ -66,6 +66,14 @@ class NSState(NamedTuple): The accumulated log evidence estimate from the "dead" points . logZ_live The current estimate of the log evidence contribution from the live points. + logZ_error + The current estimate of the error on logZ. + H + The current estimate of the information (negative entropy) in nats. + i_eff + The effective number of iterations. + n_eff + The effective sample size of the current set of live particles. inner_kernel_params A dictionary of parameters for the inner kernel. 
""" @@ -78,6 +86,10 @@ class NSState(NamedTuple): logX: Array # The current log-volume estimate logZ: Array # The accumulated evidence estimate logZ_live: Array # The current evidence estimate + logZ_error: Array # The current error estimate on logZ + H: Array # The current information estimate + i_eff: Array # The effective number of iterations + n_eff: Array # The effective sample size of the current particles inner_kernel_params: Dict # Parameters for the inner kernel @@ -210,6 +222,10 @@ def init( loglikelihood_birth: Array = -jnp.nan, logX: Optional[Array] = 0.0, logZ: Optional[Array] = -jnp.inf, + logZ_error: Optional[Array] = 0.0, + H: Optional[Array] = 0.0, + i_eff: Optional[Array] = 0.0, + n_eff: Optional[Array] = 0.0, ) -> NSState: """Initializes the Nested Sampler state. @@ -244,6 +260,11 @@ def init( logX = jnp.array(logX, dtype=dtype) logZ = jnp.array(logZ, dtype=dtype) logZ_live = logmeanexp(loglikelihood) + logX + logZ_error = jnp.array(logZ_error, dtype=dtype) + H = jnp.array(H, dtype=dtype) + i_eff = jnp.array(i_eff, dtype=dtype) + n_eff = jnp.array(n_eff, dtype=dtype) + n_eff = jnp.where(n_eff == 0.0, len(loglikelihood), n_eff) inner_kernel_params: Dict = {} return NSState( particles, @@ -254,6 +275,10 @@ def init( logX, logZ, logZ_live, + logZ_error, + H, + i_eff, + n_eff, inner_kernel_params, ) @@ -350,8 +375,8 @@ def kernel(rng_key: PRNGKey, state: NSState) -> tuple[NSState, NSInfo]: pid = state.pid.at[target_update_idx].set(state.pid[start_idx]) # Update the run-time information - logX, logZ, logZ_live = update_ns_runtime_info( - state.logX, state.logZ, loglikelihood, dead_loglikelihood + logX, logZ, logZ_live, logZ_error, H, i_eff, n_eff = update_ns_runtime_info( + state, loglikelihood, dead_loglikelihood ) # Return updated state and info @@ -364,6 +389,10 @@ def kernel(rng_key: PRNGKey, state: NSState) -> tuple[NSState, NSInfo]: logX, logZ, logZ_live, + logZ_error, + H, + i_eff, + n_eff, state.inner_kernel_params, ) info = NSInfo( @@ -428,20 
+457,29 @@ def delete_fn( def update_ns_runtime_info( - logX: Array, logZ: Array, loglikelihood: Array, dead_loglikelihood: Array -) -> tuple[Array, Array, Array]: + state: NSState, loglikelihood: Array, dead_loglikelihood: Array +) -> tuple[Array, Array, Array, Array, Array, Array, Array]: num_particles = len(loglikelihood) num_deleted = len(dead_loglikelihood) num_live = jnp.arange(num_particles, num_particles - num_deleted, -1) delta_logX = -1 / num_live - logX = logX + jnp.cumsum(delta_logX) - log_delta_X = logX + jnp.log(1 - jnp.exp(delta_logX)) + logX = state.logX + jnp.cumsum(delta_logX) + log_delta_X = logX + jnp.log1p(-jnp.exp(delta_logX)) log_delta_Z = dead_loglikelihood + log_delta_X - delta_logZ = logsumexp(log_delta_Z) - logZ = jnp.logaddexp(logZ, delta_logZ) + logZ = jnp.logaddexp(state.logZ, delta_logZ) + A = state.i_eff / state.n_eff + jnp.sum(1 / num_live) + B = state.i_eff / state.n_eff**2 + jnp.sum(1 / num_live**2) + i_eff = A**2 / B + n_eff = A / B + H = jnp.nan_to_num(jnp.exp(state.logZ - logZ) * (state.H + state.logZ), 0.0) + H += jnp.sum(jnp.exp(log_delta_Z - logZ) * dead_loglikelihood) - logZ + logZ_error = jnp.sqrt(H / n_eff) logZ_live = logmeanexp(loglikelihood) + logX[-1] - return logX[-1], logZ, logZ_live + return logX[-1], logZ, logZ_live, logZ_error, H, i_eff, n_eff + + +# H ~ Σ 1/n ± sqrt(Σ 1/n^2) def logmeanexp(x: Array) -> Array: diff --git a/blackjax/ns/utils.py b/blackjax/ns/utils.py index 15f3ef35b..9fced73d4 100644 --- a/blackjax/ns/utils.py +++ b/blackjax/ns/utils.py @@ -396,3 +396,9 @@ def prior_sample(rng_key): particles = jax.vmap(prior_sample)(init_keys) return particles, logprior_fn + + +def converged(live: NSState, precision_criterion: float = jnp.exp(-3)) -> bool: + all_same = jnp.max(live.loglikelihood) == jnp.min(live.loglikelihood) + live_evidence = live.logZ_live - live.logZ < jnp.log(precision_criterion) + return live_evidence | all_same diff --git a/docs/examples/nested_sampling.py 
b/docs/examples/nested_sampling.py index bb8f05e82..065064cea 100644 --- a/docs/examples/nested_sampling.py +++ b/docs/examples/nested_sampling.py @@ -4,7 +4,7 @@ from jax.scipy.linalg import inv, solve import blackjax -from blackjax.ns.utils import finalise, log_weights +from blackjax.ns.utils import converged, finalise, log_weights # jax.config.update("jax_enable_x64", True) @@ -98,7 +98,7 @@ def compute_logZ(mu_L, Sigma_L, logLmax=0, mu_pi=None, Sigma_pi=None): for _ in tqdm.trange(1000): # We track the estimate of the evidence in the live points as logZ_live, and the accumulated sum across all steps in logZ # this gives a handy termination that allows us to stop early - if live.logZ_live - live.logZ < -3: # type: ignore[attr-defined] + if converged(live): # type: ignore[attr-defined] break rng_key, subkey = jax.random.split(rng_key, 2) live, dead_info = step_fn(subkey, live)