28 changes: 17 additions & 11 deletions blackjax/mcmc/ss.py
@@ -386,31 +386,37 @@ def default_stepper_fn(x: ArrayTree, d: ArrayTree, t: float) -> ArrayTree:
return jax.tree.map(lambda x, d: x + t * d, x, d)


def sample_direction_from_covariance(rng_key: PRNGKey, cov: Array) -> Array:
def sample_direction_from_covariance(
rng_key: PRNGKey, cov: Array, chol: Array
) -> Array:
"""Generates a random direction vector, normalized, from a multivariate Gaussian.

This function samples a direction `d` from a zero-mean multivariate Gaussian
distribution with covariance matrix `cov`, and then normalizes `d` to be a
unit vector with respect to the Mahalanobis norm defined by `inv(cov)`.
That is, `d_normalized^T @ inv(cov) @ d_normalized = 1`.
This function generates a slice direction via a simplified procedure:
1. Sample from a standard multivariate normal N(0, I).
2. Normalize to a unit vector (uniform on the hypersphere).
3. Transform by a square root of the covariance matrix S (its Cholesky factor).

This is equivalent to sampling from N(0, S) and normalizing by the Mahalanobis
norm defined by inv(S), but it is more numerically stable and efficient because
no matrix inversion is required.

Parameters
----------
rng_key
A JAX PRNG key.
cov
The covariance matrix for the multivariate Gaussian distribution from which
the initial direction is sampled. Assumed to be a 2D array.
The covariance matrix (unused in the simplified version; kept for API compatibility).
chol
The lower-triangular Cholesky factor (matrix square root) of the covariance matrix.

Returns
-------
Array
A normalized direction vector (1D array).
"""
d = jax.random.multivariate_normal(rng_key, mean=jnp.zeros(cov.shape[0]), cov=cov)
invcov = jnp.linalg.inv(cov)
norm = jnp.sqrt(jnp.einsum("...i,...ij,...j", d, invcov, d))
d = d / norm[..., None]
z = jax.random.normal(rng_key, shape=(cov.shape[0],))
u = z / jnp.linalg.norm(z)
d = chol @ u  # chol @ chol.T == cov, so d has unit Mahalanobis norm: d^T inv(cov) d == 1
return d
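A minimal sketch, not part of the diff, checking the claimed equivalence on a toy covariance. It assumes `chol` is the lower-triangular factor returned by `jnp.linalg.cholesky` (as produced by `compute_covariance_from_particles` in nss.py below), so the resulting direction has unit Mahalanobis norm.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
cov = jnp.array([[2.0, 0.5], [0.5, 1.0]])          # toy SPD covariance
chol = jnp.linalg.cholesky(cov)                     # lower triangular: chol @ chol.T == cov

z = jax.random.normal(key, shape=(cov.shape[0],))  # step 1: N(0, I)
u = z / jnp.linalg.norm(z)                          # step 2: uniform on the unit sphere
d = chol @ u                                        # step 3: scale by the covariance square root

# The Mahalanobis norm is exactly 1: d^T inv(cov) d == u^T u
print(d @ jnp.linalg.solve(cov, d))                 # ~1.0
```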


54 changes: 27 additions & 27 deletions blackjax/ns/nss.py
@@ -61,22 +61,14 @@ def sample_direction_from_covariance(
) -> ArrayTree:
"""Default function to generate a normalized slice direction for NSS.

This function is designed to work with covariance parameters adapted by
`default_adapt_direction_params_fn`. It expects `params` to contain
'cov', a PyTree structured identically to a single particle. Each leaf
of this 'cov' PyTree contains rows of the full covariance matrix that
correspond to that leaf's elements in the flattened particle vector.
(Specifically, if the full DxD covariance matrix of flattened particles is
`M_flat`, and `unravel_fn` un-flattens a D-vector to the particle PyTree,
then the input `cov` is effectively `jax.vmap(unravel_fn)(M_flat)`).

The function reassembles the full (D,D) covariance matrix from this
PyTree structure. It then samples a flat direction vector `d_flat` from
a multivariate Gaussian $\\mathcal{N}(0, M_{reassembled})$, normalizes
`d_flat` using the Mahalanobis norm defined by $M_{reassembled}^{-1}$,
and finally un-flattens this normalized direction back into the
particle's PyTree structure using an `unravel_fn` derived from the
particle structure.
This function uses a simplified procedure to generate a slice direction:
1. Sample from a standard multivariate normal N(0, I).
2. Normalize to a unit vector (uniform on the hypersphere).
3. Transform by a square root of the covariance matrix S (its Cholesky factor).

This is equivalent to the traditional approach of sampling from N(0, S) and
normalizing by the Mahalanobis norm, but it is more numerically stable and
efficient because no matrix inversion is required.

Parameters
----------
@@ -85,20 +77,23 @@ params
params
Keyword arguments, must contain:
- `cov`: A PyTree (structured like a particle) whose leaves are rows
of the covariance matrix, typically output by
`compute_covariance_from_particles`.
of the covariance matrix.
- `chol`: A PyTree (structured like a particle) whose leaves are rows of
the lower-triangular Cholesky factor of the covariance matrix.

Returns
-------
ArrayTree
A Mahalanobis-normalized direction vector (PyTree, matching the
structure of a single particle), to be used by the slice sampler.
A direction vector normalized under the Mahalanobis norm of the covariance
(PyTree matching the structure of a single particle), to be used by the slice sampler.
"""
cov = params["cov"]
chol = params["chol"]
row = get_first_row(cov)
_, unravel_fn = ravel_pytree(row)
cov = particles_as_rows(cov)
d = ss_sample_direction_from_covariance(rng_key, cov)
chol = particles_as_rows(chol)
d = ss_sample_direction_from_covariance(rng_key, cov, chol)
return unravel_fn(d)
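A self-contained sketch, assuming toy leaf shapes and plain-JAX stand-ins for `particles_as_rows` and `get_first_row`, of the round trip this wrapper performs: the (D, D) Cholesky factor is stored as a row-PyTree, stacked back into a matrix, and the sampled flat direction is un-flattened into the particle structure.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# A toy particle with two leaves; flattened dimension D = 3.
particle = {"x": jnp.zeros(2), "y": jnp.zeros(1)}
_, unravel_fn = ravel_pytree(particle)

D = 3
cov = jnp.eye(D) * jnp.array([2.0, 1.0, 0.5])       # toy diagonal covariance
chol = jnp.linalg.cholesky(cov)

# Row-PyTree representation: each leaf holds its rows of the (D, D) matrix,
# mirroring what compute_covariance_from_particles returns.
chol_pytree = jax.vmap(unravel_fn)(chol)             # leaves: (D, 2) and (D, 1)

# Stack the rows back into a (D, D) matrix (roughly what particles_as_rows
# is assumed to do).
chol_matrix = jnp.concatenate(
    [leaf.reshape(D, -1) for leaf in jax.tree.leaves(chol_pytree)], axis=1
)

# Sample a flat direction and un-flatten it into the particle structure.
key = jax.random.PRNGKey(0)
z = jax.random.normal(key, (D,))
d_flat = chol_matrix @ (z / jnp.linalg.norm(z))
d = unravel_fn(d_flat)                               # {"x": (2,), "y": (1,)}
```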


@@ -126,18 +121,23 @@ def compute_covariance_from_particles(
Returns
-------
Dict[str, ArrayTree]
A dictionary `{'cov': cov_pytree}`. `cov_pytree` is a PyTree with the
same structure as a single particle. If the full DxD covariance matrix
of the flattened particles is `M_flat`, and `unravel_fn` is the function
to un-flatten a D-vector to the particle's PyTree structure, then
`cov_pytree` is equivalent to `jax.vmap(unravel_fn)(M_flat)`.
A dictionary `{'cov': cov_pytree, 'chol': chol_pytree}`. `cov_pytree` is a
PyTree with the same structure as a single particle whose leaves hold rows of
the covariance matrix; `chol_pytree` holds the rows of its lower-triangular
Cholesky factor (matrix square root) in the same layout. If the full DxD
covariance matrix of the flattened particles is `M_flat`, and `unravel_fn`
is the function to un-flatten a D-vector to the particle's PyTree structure,
then `cov_pytree` is equivalent to `jax.vmap(unravel_fn)(M_flat)`.
This means each leaf of `cov_pytree` will have a shape `(D, *leaf_original_dims)`.
"""
cov_matrix = jnp.atleast_2d(particles_covariance_matrix(state.particles))
cov_matrix *= cov_matrix.shape[0] + 2
chol_matrix = jnp.linalg.cholesky(cov_matrix)
single_particle = get_first_row(state.particles)
_, unravel_fn = ravel_pytree(single_particle)
cov_pytree = jax.vmap(unravel_fn)(cov_matrix)
return {"cov": cov_pytree}
chol_pytree = jax.vmap(unravel_fn)(chol_matrix)
return {"cov": cov_pytree, "chol": chol_pytree}


def build_kernel(
52 changes: 36 additions & 16 deletions blackjax/ns/utils.py
Expand Up @@ -147,21 +147,26 @@ def logX(rng_key: PRNGKey, dead_info: NSInfo, shape: int = 100) -> tuple[Array,
`dX_i` is approximately `X_i - X_{i+1}`.
"""
rng_key, subkey = jax.random.split(rng_key)
min_val = jnp.finfo(dead_info.loglikelihood.dtype).tiny
r = jnp.log(
jax.random.uniform(
subkey, shape=(dead_info.loglikelihood.shape[0], shape)
).clip(min_val, 1 - min_val)
)
r = -jax.random.exponential(subkey, shape=(dead_info.loglikelihood.shape[0], shape))

num_live = compute_num_live(dead_info)
t = r / num_live[:, jnp.newaxis]

# Cumulative sum of per-iteration log-shrinkage factors gives the log prior volumes
logX = jnp.cumsum(t, axis=0)

logXp = jnp.concatenate([jnp.zeros((1, logX.shape[1])), logX[:-1]], axis=0)
logXm = jnp.concatenate([logX[1:], jnp.full((1, logX.shape[1]), -jnp.inf)], axis=0)
log_diff = logXm - logXp
logdX = log1mexp(log_diff) + logXp - jnp.log(2)

# When log_diff == 0 due to precision limits, the volume element is below
# machine precision. Nudge log_diff just below zero so that log1mexp returns
# roughly log(eps) (a large negative value at the precision limit) instead of -inf.
eps = jnp.finfo(logX.dtype).eps
safe_log_diff = jnp.where(log_diff == 0.0, -eps, log_diff)

logdX = log1mexp(safe_log_diff) + logXp - jnp.log(2)

return logX, logdX
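A standalone sketch of the shrinkage simulation in this hunk, with a constant number of live points and a local `log1mexp` stand-in for BlackJAX's helper: each log shrinkage factor is `-Exp(1)/n_live`, so `E[log X_i] = -i/n_live`, and the trapezoid weights `(X_{i-1} - X_{i+1})/2` are accumulated in log space.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp

def log1mexp(x):
    # log(1 - exp(x)) for x <= 0, stable both near 0 and for very negative x
    return jnp.where(x > -jnp.log(2.0), jnp.log(-jnp.expm1(x)), jnp.log1p(-jnp.exp(x)))

key = jax.random.PRNGKey(0)
n_dead, n_samples, n_live = 500, 100, 50

r = -jax.random.exponential(key, shape=(n_dead, n_samples))  # log of Uniform(0, 1)
t = r / n_live                                               # log shrinkage per iteration
logX = jnp.cumsum(t, axis=0)                                 # log prior volume after each death

logXp = jnp.concatenate([jnp.zeros((1, n_samples)), logX[:-1]], axis=0)
logXm = jnp.concatenate([logX[1:], jnp.full((1, n_samples), -jnp.inf)], axis=0)
logdX = log1mexp(logXm - logXp) + logXp - jnp.log(2)         # log((X_{i-1} - X_{i+1}) / 2)

print(logX[-1].mean(), -n_dead / n_live)          # both roughly -10
print(logsumexp(logdX, axis=0).mean())            # roughly 0: the volume elements sum to ~1
```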


@@ -229,22 +234,37 @@ def finalise(live: NSState, dead: list[NSInfo]) -> NSInfo:
The `inner_kernel_info` from the last element of `dead` is used
for the final live points' info (as a placeholder).
"""

all_pytrees_to_combine = dead + [
NSInfo(
if not dead:
return NSInfo(
live.particles,
live.loglikelihood,
live.loglikelihood_birth,
live.logprior,
dead[-1].inner_kernel_info,
{}, # type: ignore
)

all_particles = [d.particles for d in dead] + [live.particles]
combined_particles = jax.tree.map(
lambda *args: jnp.concatenate(args, axis=0), *all_particles # type: ignore
)

all_loglikelihood = [d.loglikelihood for d in dead] + [live.loglikelihood]
all_loglikelihood_birth = [d.loglikelihood_birth for d in dead] + [
live.loglikelihood_birth
]
combined_dead_info = jax.tree.map(
lambda *args: jnp.concatenate(args),
all_pytrees_to_combine[0],
*all_pytrees_to_combine[1:],
all_logprior = [d.logprior for d in dead] + [live.logprior]

combined_loglikelihood = jnp.concatenate(all_loglikelihood, axis=0)
combined_loglikelihood_birth = jnp.concatenate(all_loglikelihood_birth, axis=0)
combined_logprior = jnp.concatenate(all_logprior, axis=0)

return NSInfo(
combined_particles,
combined_loglikelihood,
combined_loglikelihood_birth,
combined_logprior,
dead[-1].inner_kernel_info,
)
return combined_dead_info
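A reduced sketch using a stand-in NamedTuple rather than BlackJAX's actual `NSInfo`, showing the field-wise concatenation the rewritten `finalise` performs across the dead records and the final live points.

```python
from typing import NamedTuple
import jax
import jax.numpy as jnp

class Info(NamedTuple):   # stand-in for NSInfo with only the fields used here
    particles: dict
    loglikelihood: jnp.ndarray

dead = [
    Info({"x": jnp.ones((5, 2)) * i}, jnp.full((5,), float(i))) for i in range(3)
]
live = Info({"x": jnp.zeros((10, 2))}, jnp.zeros((10,)))

# Concatenate particles leaf-by-leaf, then each scalar field along the particle axis.
all_particles = [d.particles for d in dead] + [live.particles]
combined_particles = jax.tree.map(lambda *a: jnp.concatenate(a, axis=0), *all_particles)
combined_loglik = jnp.concatenate([d.loglikelihood for d in dead] + [live.loglikelihood])

print(combined_particles["x"].shape, combined_loglik.shape)   # (25, 2) (25,)
```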


def ess(rng_key: PRNGKey, dead_info_map: NSInfo) -> Array: