56 changes: 47 additions & 9 deletions blackjax/ns/base.py
@@ -66,6 +66,14 @@ class NSState(NamedTuple):
The accumulated log evidence estimate from the "dead" points.
logZ_live
The current estimate of the log evidence contribution from the live points.
logZ_error
The current estimate of the error on logZ.
H
The current estimate of the information (negative entropy) in nats.
i_eff
The effective number of iterations.
n_eff
The effective sample size of the current set of live particles.
inner_kernel_params
A dictionary of parameters for the inner kernel.
"""
@@ -78,6 +86,10 @@ class NSState(NamedTuple):
logX: Array # The current log-volume estimate
logZ: Array # The accumulated evidence estimate
logZ_live: Array # The current evidence estimate
logZ_error: Array # The current error estimate on logZ
Collaborator: Can any of this be calculated "offline", i.e. in the utils rather than in the online function?

Reply: It could all be calculated offline, but there is a virtue in accumulation here that we appreciated in the first place for logZ (this is not true for logZ_live, which we do not accumulate and which could be calculated as part of converged). My main reason for doing this is to be able to construct more advanced convergence criteria that e.g. incorporate an evidence error, or a KL-based stopping criterion (which John Skilling prefers).
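Purely as an illustration of the kind of criterion this enables (a sketch, not part of this PR; converged_within_error is a hypothetical name): once logZ_error is carried in the state, termination can weigh the evidence still locked in the live points against the statistical error already accumulated.

    import jax.numpy as jnp
    from blackjax.ns.base import NSState

    def converged_within_error(state: NSState):
        # Fraction of the total evidence still held by the live points;
        # truncating now shifts logZ by roughly log1p(frac) ≈ frac.
        frac_remaining = jnp.exp(state.logZ_live - state.logZ)
        # Stop once that truncation bias is dominated by the statistical
        # error sqrt(H / n_eff) tracked in state.logZ_error.
        return frac_remaining < state.logZ_error
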
H: Array # The current information estimate
i_eff: Array # The effective number of iterations
n_eff: Array # The effective sample size of the current particles
inner_kernel_params: Dict # Parameters for the inner kernel


@@ -210,6 +222,10 @@ def init(
loglikelihood_birth: Array = -jnp.nan,
logX: Optional[Array] = 0.0,
logZ: Optional[Array] = -jnp.inf,
logZ_error: Optional[Array] = 0.0,
H: Optional[Array] = 0.0,
i_eff: Optional[Array] = 0.0,
n_eff: Optional[Array] = 0.0,
) -> NSState:
"""Initializes the Nested Sampler state.

@@ -244,6 +260,11 @@ def init(
logX = jnp.array(logX, dtype=dtype)
logZ = jnp.array(logZ, dtype=dtype)
logZ_live = logmeanexp(loglikelihood) + logX
logZ_error = jnp.array(logZ_error, dtype=dtype)
H = jnp.array(H, dtype=dtype)
i_eff = jnp.array(i_eff, dtype=dtype)
n_eff = jnp.array(n_eff, dtype=dtype)
n_eff = jnp.where(n_eff == 0.0, len(loglikelihood), n_eff)  # default to the number of live particles
inner_kernel_params: Dict = {}
return NSState(
particles,
@@ -254,6 +275,10 @@
logX,
logZ,
logZ_live,
logZ_error,
H,
i_eff,
n_eff,
inner_kernel_params,
)

@@ -350,8 +375,8 @@ def kernel(rng_key: PRNGKey, state: NSState) -> tuple[NSState, NSInfo]:
pid = state.pid.at[target_update_idx].set(state.pid[start_idx])

# Update the run-time information
logX, logZ, logZ_live = update_ns_runtime_info(
state.logX, state.logZ, loglikelihood, dead_loglikelihood
logX, logZ, logZ_live, logZ_error, H, i_eff, n_eff = update_ns_runtime_info(
state, loglikelihood, dead_loglikelihood
)

# Return updated state and info
@@ -364,6 +389,10 @@ def kernel(rng_key: PRNGKey, state: NSState) -> tuple[NSState, NSInfo]:
logX,
logZ,
logZ_live,
logZ_error,
H,
i_eff,
n_eff,
state.inner_kernel_params,
)
info = NSInfo(
@@ -428,20 +457,29 @@ def delete_fn(


def update_ns_runtime_info(
logX: Array, logZ: Array, loglikelihood: Array, dead_loglikelihood: Array
) -> tuple[Array, Array, Array]:
state: NSState, loglikelihood: Array, dead_loglikelihood: Array
) -> tuple[Array, Array, Array, Array, Array, Array, Array]:
num_particles = len(loglikelihood)
num_deleted = len(dead_loglikelihood)
num_live = jnp.arange(num_particles, num_particles - num_deleted, -1)
delta_logX = -1 / num_live
logX = logX + jnp.cumsum(delta_logX)
log_delta_X = logX + jnp.log(1 - jnp.exp(delta_logX))
logX = state.logX + jnp.cumsum(delta_logX)
log_delta_X = logX + jnp.log1p(-jnp.exp(delta_logX))
log_delta_Z = dead_loglikelihood + log_delta_X

delta_logZ = logsumexp(log_delta_Z)
logZ = jnp.logaddexp(logZ, delta_logZ)
logZ = jnp.logaddexp(state.logZ, delta_logZ)
A = state.i_eff / state.n_eff + jnp.sum(1 / num_live)
Collaborator: If this has to be all calculated online, would it be useful to put any of the calculation logic into the utils, in case they are useful for various post hoc analysis bits?
B = state.i_eff / state.n_eff**2 + jnp.sum(1 / num_live**2)
i_eff = A**2 / B
n_eff = A / B
H = jnp.nan_to_num(jnp.exp(state.logZ - logZ) * (state.H + state.logZ), nan=0.0)
H += jnp.sum(jnp.exp(log_delta_Z - logZ) * dead_loglikelihood) - logZ
logZ_error = jnp.sqrt(H / n_eff)
logZ_live = logmeanexp(loglikelihood) + logX[-1]
return logX[-1], logZ, logZ_live
return logX[-1], logZ, logZ_live, logZ_error, H, i_eff, n_eff


# -logX accumulates Σ 1/n ± sqrt(Σ 1/n²); matching A (mean) and B (variance)
# to i_eff iterations over n_eff live points gives i_eff = A²/B, n_eff = A/B.
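Following the review suggestion above, the accumulated quantities can also be reconstructed post hoc from the dead points. A minimal sketch, assuming a constant number of live points and dead loglikelihoods stored in death order (offline_evidence is a hypothetical helper, not part of this PR):

    import jax.numpy as jnp
    from jax.scipy.special import logsumexp

    def offline_evidence(dead_loglikelihood, num_live):
        # Deterministic compression: after the i-th death, logX_i = -i / num_live.
        num_dead = len(dead_loglikelihood)
        logX = -jnp.arange(1, num_dead + 1) / num_live
        # Shell volumes, mirroring log_delta_X in the online update above.
        log_delta_X = logX + jnp.log1p(-jnp.exp(-1.0 / num_live))
        log_delta_Z = dead_loglikelihood + log_delta_X
        logZ = logsumexp(log_delta_Z)
        # Information: posterior-weighted mean loglikelihood minus logZ.
        H = jnp.sum(jnp.exp(log_delta_Z - logZ) * dead_loglikelihood) - logZ
        # With a constant number of live points, n_eff reduces to num_live.
        return logZ, H, jnp.sqrt(H / num_live)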


def logmeanexp(x: Array) -> Array:
6 changes: 6 additions & 0 deletions blackjax/ns/utils.py
@@ -396,3 +396,9 @@ def prior_sample(rng_key):
particles = jax.vmap(prior_sample)(init_keys)

return particles, logprior_fn


def converged(live: NSState, precision_criterion: float = jnp.exp(-3)) -> bool:
all_same = jnp.max(live.loglikelihood) == jnp.min(live.loglikelihood)
live_evidence = live.logZ_live - live.logZ < jnp.log(precision_criterion)
return live_evidence | all_same
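
With the default precision_criterion = jnp.exp(-3), converged reproduces the inline live.logZ_live - live.logZ < -3 check previously used in the example below, and additionally terminates in the degenerate case where all live points share the same loglikelihood.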
4 changes: 2 additions & 2 deletions docs/examples/nested_sampling.py
@@ -4,7 +4,7 @@
from jax.scipy.linalg import inv, solve

import blackjax
from blackjax.ns.utils import finalise, log_weights
from blackjax.ns.utils import converged, finalise, log_weights

# jax.config.update("jax_enable_x64", True)

@@ -98,7 +98,7 @@ def compute_logZ(mu_L, Sigma_L, logLmax=0, mu_pi=None, Sigma_pi=None):
for _ in tqdm.trange(1000):
# We track the evidence in the live points as logZ_live and the accumulated sum across all steps as logZ;
# converged() compares the two, giving a handy termination that lets us stop early
if live.logZ_live - live.logZ < -3: # type: ignore[attr-defined]
if converged(live): # type: ignore[attr-defined]
break
rng_key, subkey = jax.random.split(rng_key, 2)
live, dead_info = step_fn(subkey, live)