Skip to content
Open
Show file tree
Hide file tree
Changes from 69 commits
Commits
Show all changes
73 commits
Select commit Hold shift + click to select a range
03e4d98
update to more stable cov
yallup Jul 21, 2025
9faaffb
normalize correctly
yallup Jul 21, 2025
e148bf7
Implement mathematically simplified calculation
zwei-beiner Jul 22, 2025
898616f
Ensure calculation of lower cholesky decomposition
zwei-beiner Jul 22, 2025
800497e
remove vector constraints and add init factory
yallup Oct 2, 2025
e70259c
Refactor slice sampler to decouple constraint handling from core algo…
williamjameshandley Oct 16, 2025
b3b39fe
Rename slicer to slice_fn for consistency with blackjax conventions
williamjameshandley Oct 16, 2025
29d20c9
Fix build_hrss_kernel to accept logdensity_fn parameter
williamjameshandley Oct 16, 2025
337ead1
Update slice sampling tests to match refactored API
williamjameshandley Oct 16, 2025
4f971ee
Use init() instead of direct SliceState construction in build_hrss_ke…
williamjameshandley Oct 16, 2025
b609905
Update docstrings to reflect refactored slice_fn API
williamjameshandley Oct 16, 2025
f1d34f4
Add explicit type annotations for slice_fn and direction function
williamjameshandley Oct 16, 2025
6acee5b
Refactor build_kernel to accept slice_fn as configuration parameter
williamjameshandley Oct 16, 2025
502df29
Fix docstring inconsistencies in ss.py and nss.py
williamjameshandley Oct 16, 2025
c512f8c
Apply pre-commit formatting fixes
williamjameshandley Oct 16, 2025
865e460
Simplify build_hrss_kernel API by hardcoding covariance-based directi…
williamjameshandley Oct 17, 2025
256223d
Use PartitionedState from base and create explicit NSSInfo
williamjameshandley Oct 17, 2025
50734eb
clean up nss for state based slicing
yallup Oct 20, 2025
53548b9
Fix docstring inconsistencies in ss.py and nss.py
williamjameshandley Oct 20, 2025
0d571fc
Use **params unpacking for extensible direction generation function c…
williamjameshandley Oct 20, 2025
88fc2dd
use more stable cholesky
yallup Oct 20, 2025
48ab92c
Fix dtype consistency issues across NS and slice sampling code
williamjameshandley Oct 20, 2025
2644e75
Simplify dtype handling by leveraging JAX's automatic broadcasting
williamjameshandley Oct 20, 2025
2d0e0ec
Merge branch 'slice_state_dict' of github.com:handley-lab/blackjax in…
williamjameshandley Oct 20, 2025
e7f5d81
Fix dtype consistency in Cholesky-based direction sampling
williamjameshandley Oct 20, 2025
5bac447
Merge changes from log weights branch
yallup Oct 20, 2025
667fc63
Apply optimal covariance scaling for slice sampling direction proposals
williamjameshandley Oct 20, 2025
746ebdf
Merge pull request #47 from handley-lab/slice_state_dict
yallup Oct 20, 2025
87eed90
Merge branch 'refactor-slice-sampling-api' of github.com:handley-lab/…
williamjameshandley Oct 20, 2025
d945b13
Pass position template through params dict for consistency
williamjameshandley Oct 20, 2025
703ae27
Update docstring to reflect position now in kwargs
williamjameshandley Oct 20, 2025
7d9ca74
refactor factory init
yallup Oct 21, 2025
89e4858
fix slice tests
yallup Oct 21, 2025
3d220ea
refactor partitioned state to be consistent throughout
yallup Oct 22, 2025
8cff7b7
Idea for saving intermediate transition states
yallup Oct 22, 2025
6f72e63
add transition states to info
yallup Oct 22, 2025
80e2144
new format with strategy
yallup Oct 23, 2025
4a7d37f
fix naming convention in base
yallup Oct 23, 2025
d8d0914
remove partitioned naming
yallup Oct 23, 2025
95fdce6
remove logprior and logl from lower level functions
yallup Oct 23, 2025
97b0547
remove unnecesary casting axes
yallup Oct 23, 2025
a62ed51
make init strategy consistent throughout
yallup Oct 23, 2025
349063e
add placeholder generic from_mcmc kernel
yallup Oct 23, 2025
ddc07cf
make the info not accumulate states
yallup Oct 23, 2025
e8fa485
really remove states as compilation is slow
yallup Oct 24, 2025
7b2f01b
Fix docstrings in NS modules
williamjameshandley Oct 25, 2025
7ed83cf
Merge branch 'factory' of github.com:handley-lab/blackjax into factory
williamjameshandley Oct 25, 2025
ce2bd3f
Merge pull request #48 from handley-lab/factory
yallup Oct 25, 2025
aec8d1d
Remove particle ID tracking from NSState
williamjameshandley Oct 25, 2025
7e93499
Refactor NSState to use StateWithLogLikelihood
williamjameshandley Oct 25, 2025
779c431
Remove redundant loglikelihood field from NSInfo
williamjameshandley Oct 25, 2025
87a08c0
Move loglikelihood_birth into StateWithLogLikelihood
williamjameshandley Oct 25, 2025
536ca5a
Simplify loglikelihood_birth broadcasting using len()
williamjameshandley Oct 25, 2025
182bbbb
Eliminate state nesting and move to external parameter management
williamjameshandley Oct 25, 2025
2ea0155
Refactor NS to separate live state from evidence integration
williamjameshandley Oct 25, 2025
fc5b55c
Restore ones_like for dtype and shape consistency
williamjameshandley Oct 25, 2025
c099774
reinstate particles internally
yallup Oct 27, 2025
8807de0
include attempt to clean up some things in finalise
yallup Oct 28, 2025
2ff586b
fix tests and prune some docstrings
yallup Nov 3, 2025
1649426
better naming in utils
yallup Nov 11, 2025
19fff99
revert to inv cov due to vmap bug with cholesky
yallup Nov 12, 2025
10a3778
remove inner kernel params from base
yallup Nov 14, 2025
fa04761
Update blackjax/ns/adaptive.py
williamjameshandley Nov 24, 2025
6233068
Update tests/ns/test_nested_sampling.py
williamjameshandley Nov 24, 2025
85091fa
post review changes
yallup Nov 24, 2025
2fedde9
Merge pull request #55 from handley-lab/external_params_refactor
yallup Nov 24, 2025
6a42204
Merge pull request #54 from handley-lab/external-params
yallup Nov 24, 2025
73b6163
Merge pull request #53 from handley-lab/integrator-refactor
yallup Nov 24, 2025
f419265
Merge pull request #52 from handley-lab/particle-bundling
yallup Nov 24, 2025
d454247
use log1mexp in update_integrator
AdamOrmondroyd Dec 9, 2025
18bd5e0
fix from_mcmc generic kernel
yallup Dec 15, 2025
ab2a117
fix return type
yallup Dec 15, 2025
4b07423
Merge branch 'refactor-slice-sampling-api' of github.com:handley-lab/…
yallup Dec 15, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
182 changes: 79 additions & 103 deletions blackjax/mcmc/ss.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
import jax.numpy as jnp

from blackjax.base import SamplingAlgorithm
from blackjax.mcmc.proposal import static_binomial_sampling
from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey

__all__ = [
Expand All @@ -59,7 +58,6 @@ class SliceState(NamedTuple):

position: ArrayLikeTree
logdensity: float
constraint: Array


class SliceInfo(NamedTuple):
Expand All @@ -72,8 +70,6 @@ class SliceInfo(NamedTuple):
----------
is_accepted
A boolean indicating whether the proposed sample was accepted.
constraint
The constraint values at the final accepted position.
num_steps
The number of steps taken to expand the interval during the "stepping-out" phase.
num_shrink
Expand All @@ -86,9 +82,7 @@ class SliceInfo(NamedTuple):
num_shrink: int


def init(
position: ArrayTree, logdensity_fn: Callable, constraint_fn: Callable
) -> SliceState:
def init(position: ArrayTree, logdensity_fn: Callable) -> SliceState:
"""Initialize the Slice Sampler state.

Parameters
Expand All @@ -103,11 +97,11 @@ def init(
SliceState
The initial state of the Slice Sampler.
"""
return SliceState(position, logdensity_fn(position), constraint_fn(position))
return SliceState(position, logdensity_fn(position))


def build_kernel(
stepper_fn: Callable,
slice_fn: Callable[[float], tuple[SliceState, bool]],
max_steps: int = 10,
max_shrinkage: int = 100,
) -> Callable:
Expand All @@ -119,17 +113,21 @@ def build_kernel(

Parameters
----------
stepper_fn
A function that computes a new position given an initial position,
direction `d` and a slice parameter `t`.
`(x0, d, t) -> x_new` where e.g. `x_new = x0 + t * d`.
slice_fn
A function that takes a scalar parameter `t` and returns a tuple
(SliceState, is_accepted) indicating the state at that parameter value
and whether it satisfies acceptance criteria.
max_steps
The maximum number of steps to take when expanding the interval in
each direction during the stepping-out phase.
max_shrinkage
The maximum number of shrinking steps to perform to avoid infinite loops.

Returns
-------
Callable
A kernel function that takes a PRNG key, the current `SliceState`,
the log-density function, direction `d`, constraint function, constraint
values, and strict flags, and returns a new `SliceState` and `SliceInfo`.
A kernel function that takes a PRNG key and the current `SliceState`,
and returns a new `SliceState` and `SliceInfo`.

References
----------
Expand All @@ -139,32 +137,19 @@ def build_kernel(
def kernel(
rng_key: PRNGKey,
state: SliceState,
logdensity_fn: Callable,
d: ArrayTree,
constraint_fn: Callable,
constraint: Array,
strict: Array,
) -> tuple[SliceState, SliceInfo]:
vs_key, hs_key = jax.random.split(rng_key)
logslice = state.logdensity + jnp.log(jax.random.uniform(vs_key))
u = jax.random.uniform(vs_key)
logslice = state.logdensity + jnp.log(u)
vertical_is_accepted = logslice < state.logdensity

def slicer(t) -> tuple[SliceState, SliceInfo]:
x, step_accepted = stepper_fn(state.position, d, t)
new_state = init(x, logdensity_fn, constraint_fn)
constraints_ok = jnp.all(
jnp.where(
strict,
new_state.constraint > constraint,
new_state.constraint >= constraint,
)
)
def _slice_fn(t):
new_state, is_accepted = slice_fn(t)
in_slice = new_state.logdensity >= logslice
is_accepted = in_slice & constraints_ok & step_accepted
return new_state, is_accepted
return new_state, is_accepted & in_slice

new_state, info = horizontal_slice(
hs_key, slicer, state, max_steps, max_shrinkage
hs_key, state, _slice_fn, max_steps, max_shrinkage
)
info = info._replace(is_accepted=info.is_accepted & vertical_is_accepted)
return new_state, info
Expand All @@ -174,29 +159,29 @@ def slicer(t) -> tuple[SliceState, SliceInfo]:

def horizontal_slice(
rng_key: PRNGKey,
slicer: Callable,
state: SliceState,
slice_fn: Callable[[float], tuple[SliceState, bool]],
m: int,
max_shrinkage: int,
) -> tuple[SliceState, SliceInfo]:
"""Propose a new sample using the stepping-out and shrinking procedures.

This function implements the core of the Hit-and-Run Slice Sampling algorithm.
It first expands an interval (`[l, r]`) along the slice starting
from `x0` and proceeding along direction `d` until both ends are outside
the slice defined by `logslice` (stepping-out). Then, it samples
points uniformly from this interval and shrinks the interval until a point
is found that lies within the slice (shrinking).
It first expands an interval (`[l, r]`) along a one-dimensional parameterization
until both ends are outside the slice defined by `logslice` (stepping-out).
Then, it samples points uniformly from this interval and shrinks the interval
until a point is found that lies within the slice (shrinking).

Parameters
----------
rng_key
A JAX PRNG key.
slicer
A function that takes a scalar `t` and returns a state and info on the
slice.
state
The current slice sampling state.
slice_fn
A function that takes a scalar parameter `t` and returns a tuple
(SliceState, is_accepted) indicating the state at that parameter value
and whether it satisfies acceptance criteria.
m
The maximum number of steps to take when expanding the interval in
each direction during the stepping-out phase.
Expand All @@ -213,14 +198,14 @@ def horizontal_slice(
# Initial bounds
rng_key, subkey = jax.random.split(rng_key)
u, v = jax.random.uniform(subkey, 2)
j = jnp.floor(m * v).astype(int)
j = jnp.floor(m * v).astype(jnp.int32)
k = (m - 1) - j

# Expand
def step_body_fun(carry):
i, s, t, _ = carry
t += s
_, is_accepted = slicer(t)
_, is_accepted = slice_fn(t)
i -= 1
return i, s, t, is_accepted

Expand All @@ -240,7 +225,7 @@ def shrink_body_fun(carry):
rng_key, subkey = jax.random.split(rng_key)
u = jax.random.uniform(subkey, minval=l, maxval=r)

new_state, is_accepted = slicer(u)
new_state, is_accepted = slice_fn(u)
n += 1

l = jnp.where(u < 0, u, l)
Expand All @@ -255,17 +240,18 @@ def shrink_cond_fun(carry):
carry = 0, rng_key, l, r, state, False
carry = jax.lax.while_loop(shrink_cond_fun, shrink_body_fun, carry)
n, _, _, _, new_state, is_accepted = carry
new_state, (is_accepted, _, _) = static_binomial_sampling(
rng_key, jnp.log(is_accepted), state, new_state
new_state = jax.tree.map(
lambda new, old: jnp.where(is_accepted, new, old), new_state, state
)
slice_info = SliceInfo(is_accepted, m + 1 - j - k, n)
return new_state, slice_info


def build_hrss_kernel(
generate_slice_direction_fn: Callable,
stepper_fn: Callable,
cov: Array,
init_fn: Callable = init,
max_steps: int = 10,
max_shrinkage: int = 100,
) -> Callable:
"""Build a Hit-and-Run Slice Sampling kernel.

Expand All @@ -276,91 +262,75 @@ def build_hrss_kernel(

Parameters
----------
generate_slice_direction_fn
A function that, given a PRNG key, generates a direction vector (PyTree
with the same structure as the position) for the "hit-and-run" part of
the algorithm. This direction is typically normalized.

stepper_fn
A function that computes a new position given an initial position, a
direction, and a step size `t`. It should implement something analogous
to `x_new = x_initial + t * direction`.
cov
The covariance matrix used by the direction proposal function
init_fn
A function initializing a SliceState
max_steps
The maximum number of steps to take when expanding the interval in
each direction during the stepping-out phase.
max_shrinkage
The maximum number of shrinking steps to perform to avoid infinite loops.

Returns
-------
Callable
A kernel function that takes a PRNG key, the current `SliceState`, and
the log-density function, and returns a new `SliceState` and `SliceInfo`.
"""
slice_kernel = build_kernel(stepper_fn, max_steps)

def kernel(
rng_key: PRNGKey, state: SliceState, logdensity_fn: Callable
) -> tuple[SliceState, SliceInfo]:
rng_key, prop_key = jax.random.split(rng_key, 2)
d = generate_slice_direction_fn(prop_key)
constraint_fn = lambda x: jnp.array([])
constraint = jnp.array([])
strict = jnp.array([], dtype=bool)
return slice_kernel(
rng_key, state, logdensity_fn, d, constraint_fn, constraint, strict
)

return kernel

d = sample_direction_from_covariance(prop_key, state.position, cov)

def default_stepper_fn(x: ArrayTree, d: ArrayTree, t: float) -> ArrayTree:
"""A simple stepper function that moves from `x` along direction `d` by `t` units.

Implements the operation: `x_new = x + t * d`.
def slice_fn(t):
x = jax.tree.map(lambda x, d: x + t * d, state.position, d)
is_accepted = True
new_state = init_fn(x, logdensity_fn)
return new_state, is_accepted

Parameters
----------
x
The starting position (PyTree).
d
The direction of movement (PyTree, same structure as `x`).
t
The scalar step size or distance along the direction.
slice_kernel = build_kernel(slice_fn, max_steps, max_shrinkage)
return slice_kernel(rng_key, state)

Returns
-------
position, is_accepted
"""
return jax.tree.map(lambda x, d: x + t * d, x, d), True
return kernel


def sample_direction_from_covariance(rng_key: PRNGKey, cov: Array) -> Array:
def sample_direction_from_covariance(
rng_key: PRNGKey, position: ArrayLikeTree, cov: Array
) -> Array:
"""Generates a random direction vector, normalized, from a multivariate Gaussian.

This function samples a direction `d` from a zero-mean multivariate Gaussian
distribution with covariance matrix `cov`, and then normalizes `d` to be a
unit vector with respect to the Mahalanobis norm defined by `inv(cov)`.
That is, `d_normalized^T @ inv(cov) @ d_normalized = 1`.

Parameters
----------
rng_key
A JAX PRNG key.
position
The current position of the chain (used for extracting shape).
cov
The covariance matrix for the multivariate Gaussian distribution from which
the initial direction is sampled. Assumed to be a 2D array.

The covariance matrix.
Returns
-------
Array
A normalized direction vector (1D array).
"""
d = jax.random.multivariate_normal(rng_key, mean=jnp.zeros(cov.shape[0]), cov=cov)
p, unravel_fn = jax.flatten_util.ravel_pytree(position)
d = jax.random.normal(rng_key, shape=p.shape, dtype=p.dtype)
invcov = jnp.linalg.inv(cov)
norm = jnp.sqrt(jnp.einsum("...i,...ij,...j", d, invcov, d))
d = d / norm[..., None]
return d
d *= 2
return unravel_fn(d)


def hrss_as_top_level_api(
logdensity_fn: Callable,
cov: Array,
init_fn: Callable = init,
max_steps: int = 10,
max_shrinkage: int = 100,
) -> SamplingAlgorithm:
"""Creates a Hit-and-Run Slice Sampling algorithm.

Expand All @@ -373,18 +343,24 @@ def hrss_as_top_level_api(
logdensity_fn
The log-density function of the target distribution to sample from.
cov
The covariance matrix used by the default direction proposal function
(`default_proposal_distribution`). This matrix shapes the random
The covariance matrix used by the direction proposal function
(`sample_direction_from_covariance`). This matrix shapes the random
directions proposed for the slice sampling steps.
init_fn
A function initializing a SliceState
max_steps
The maximum number of steps to take when expanding the interval in
each direction during the stepping-out phase.
max_shrinkage
The maximum number of shrinking steps to perform to avoid infinite loops.

Returns
-------
SamplingAlgorithm
A `SamplingAlgorithm` tuple containing `init` and `step` functions for
the configured Hit-and-Run Slice Sampler.
"""
generate_slice_direction_fn = partial(sample_direction_from_covariance, cov=cov)
kernel = build_hrss_kernel(generate_slice_direction_fn, default_stepper_fn)
init_fn = partial(init, logdensity_fn=logdensity_fn)
kernel = build_hrss_kernel(cov, init_fn, max_steps, max_shrinkage)
init_fn = partial(init_fn, logdensity_fn=logdensity_fn)
step_fn = partial(kernel, logdensity_fn=logdensity_fn)
return SamplingAlgorithm(init_fn, step_fn)
18 changes: 10 additions & 8 deletions blackjax/ns/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@
# limitations under the License.
"""Nested Sampling Algorithms in BlackJAX.

This subpackage provides implementations of Nested Sampling algorithms,
including a base version, an adaptive version, and Nested Slice Sampling (NSS).
This subpackage provides implementations of Nested Sampling algorithms.

Nested Sampling is a Monte Carlo method for Bayesian computation, primarily
used for evidence (marginal likelihood) calculation and posterior sampling.
Expand All @@ -23,20 +22,23 @@

Available modules:
------------------
- `adaptive`: Implements an adaptive Nested Sampling algorithm where inner
kernel parameters are tuned at each iteration.
- `base`: Provides core components and a non-adaptive Nested Sampling kernel.
- `base`: Provides core components for Nested Sampling.
- `adaptive`: Implements Adaptive Nested Sampling, combining SMC tempering
- `nss`: Implements Nested Slice Sampling, using Hit-and-Run Slice Sampling as
the inner kernel with adaptive tuning of its proposal mechanism.
- `integrator`: Provides NSIntegrator for tracking evidence integration.
- `utils`: Contains utility functions for processing and analyzing Nested
Sampling results.
- `from_mcmc`: Utilities to build Nested Sampling algorithms from MCMC kernels.

"""
from . import adaptive, base, nss, utils
from . import adaptive, base, from_mcmc, integrator, nss, utils

__all__ = [
"adaptive",
"base",
"utils",
"adaptive",
"integrator",
"nss",
"utils",
"from_mcmc",
]
Loading
Loading