
Conversation

@williamjameshandley

Summary

Refactors the slice sampling implementation to separate generic slice sampling from nested-sampling-specific constraint handling. This improves modularity and makes the slice sampler reusable while preserving all nested sampling functionality.

Architectural Change

Before: Generic slice sampler had built-in constraint handling

  • constraint field in SliceState/SliceInfo
  • Parameters: constraint_fn, constraint array, strict flags passed everywhere
  • Complex conditional logic mixing slice sampling with constraint checking

After: Generic slice sampler accepts slice_fn: (t) -> (state, is_accepted)

  • Caller defines what "accepted" means via slice_fn
  • Nested sampling implements its constraint (loglikelihood > threshold) in custom slice_fn
  • Clean separation: ss.py has no knowledge of constraints

Changes

blackjax/mcmc/ss.py (-96 lines, net simplification)

Removed constraint machinery:

  • Removed constraint field from SliceState and SliceInfo
  • Removed constraint_fn, constraint, strict parameters

New API:

  • build_kernel(slice_fn, max_steps, max_shrinkage) - takes slice_fn as configuration
  • slice_fn: Callable[[float], tuple[SliceState, bool]] - maps parameter t to (state, is_accepted)
  • Kernel signature: (rng_key, state) instead of (rng_key, state, logdensity_fn, d, constraint_fn, constraint, strict)
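
A minimal sketch of the new interface, based on the description above (the SliceState constructor and exact blackjax signatures are assumptions, not verified against the source):

import jax
import jax.numpy as jnp
from blackjax.mcmc.ss import SliceState, build_kernel

logdensity_fn = lambda x: -0.5 * jnp.sum(x**2)  # standard Gaussian target

def make_slice_fn(state, direction):
    """Build a slice_fn that captures the current state and a direction."""
    def slice_fn(t):
        x = jax.tree.map(lambda p, d: p + t * d, state.position, direction)
        new_state = SliceState(x, logdensity_fn(x))  # (position, logdensity) assumed
        return new_state, True  # no extra constraint, so every step is acceptable
    return slice_fn

position = jnp.zeros(2)
state = SliceState(position, logdensity_fn(position))
kernel = build_kernel(make_slice_fn(state, jnp.array([1.0, 0.0])), max_steps=10, max_shrinkage=100)
new_state, info = kernel(jax.random.PRNGKey(0), state)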

Renamed for consistency:

  • slicer → slice_fn throughout (follows the blackjax *_fn convention)

Bug fix:

  • build_hrss_kernel now correctly accepts logdensity_fn as a parameter (it was previously referenced in the closure without being in scope)

Removed:

  • default_stepper_fn (moved to nss.py where it's actually used)

Added:

  • Type annotations: slice_fn: Callable[[float], tuple[SliceState, bool]]
  • generate_slice_direction_fn: Callable[[PRNGKey], ArrayTree]
  • max_steps and max_shrinkage parameters to hrss_as_top_level_api

blackjax/ns/nss.py (+96 lines, explicit constraint handling)

Added PartitionedSliceState:

class PartitionedSliceState(NamedTuple):
    position: ArrayLikeTree
    logdensity: float        # log-prior
    loglikelihood: Array     # for constraint checking

Moved from ss.py:

  • default_stepper_fn - only used in nested sampling

Explicit constraint implementation:

def slice_fn(t):
    x, step_accepted = stepper_fn(state.position, d, t)
    new_state = PartitionedSliceState(
        position=x,
        logdensity=logprior_fn(x),
        loglikelihood=loglikelihood_fn(x),
    )
    in_contour = new_state.loglikelihood > loglikelihood_0  # Clear constraint!
    is_accepted = in_contour & step_accepted
    return new_state, is_accepted

Much clearer than the old jnp.where(strict, constraint > value, constraint >= value) approach.

Builds kernel per-step:

  • Calls build_slice_kernel(slice_fn, max_steps, max_shrinkage) inside inner_kernel
  • Fresh slice_fn captures current state
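
In outline (make_slice_fn is a hypothetical helper standing in for the closure shown above; build_slice_kernel, max_steps, and max_shrinkage come from the surrounding module scope):

def inner_kernel(rng_key, state, loglikelihood_0):
    # slice_fn closes over the *current* state, so the kernel is rebuilt each
    # step; under jit the rebuild is traced away and adds no runtime cost.
    slice_fn = make_slice_fn(state, loglikelihood_0)
    kernel = build_slice_kernel(slice_fn, max_steps, max_shrinkage)
    return kernel(rng_key, state)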

tests/mcmc/test_slice_sampling.py (complete rewrite)

Removed tests for unexposed functions:

  • test_vertical_slice
  • test_default_stepper_fn

Updated all tests for new API:

  • No more constraint_fn parameters
  • Use slice_fn interface

Added new tests:

  • test_build_kernel_with_custom_slice_fn - tests build_kernel directly
  • test_horizontal_slice_basic - tests horizontal_slice function
  • test_multimodal_sampling - tests on bimodal distribution

All 12 tests pass

Benefits

  1. Modularity: Generic slice sampler is now truly generic
  2. Clarity: Constraint logic is explicit in nested sampling, not buried in generic code
  3. Extensibility: Easy to implement different constraints via custom slice_fn
  4. Type safety: Explicit type annotations improve IDE support
  5. Bug fix: build_hrss_kernel now works correctly

Breaking Changes

None for end users - hrss_as_top_level_api maintains the same interface.

Internal API changes:

  • build_kernel now takes slice_fn as first parameter
  • Removed unexposed internal functions/parameters

Testing

  • ✅ All 12 slice sampling tests pass
  • ✅ All pre-commit hooks pass (black, flake8, isort, mypy)
  • ✅ Verified standalone HRSS functionality
  • ✅ Nested sampling integration preserved (API unchanged)

🤖 Generated with Claude Code

Co-Authored-By: Claude [email protected]

yallup and others added 15 commits July 21, 2025 16:45
…rithm

This commit refactors the slice sampling implementation to separate
generic slice sampling functionality from nested-sampling-specific
constraint handling.

Changes to blackjax/mcmc/ss.py:
- Removed constraint field from SliceState
- Simplified build_kernel() interface to accept a generic slicer function
  instead of constraint_fn, constraint values, and strict flags
- Removed stepper_fn parameter; stepping logic now part of slicer
- Removed default_stepper_fn (moved to nss.py where needed)
- Simplified build_hrss_kernel() and hrss_as_top_level_api() interfaces

Changes to blackjax/ns/nss.py:
- Added PartitionedSliceState to track position, logdensity, and loglikelihood
- Moved default_stepper_fn from ss.py to this module
- Implemented constraint checking (loglikelihood > loglikelihood_0) via
  custom slicer function passed to slice_kernel
- Constraint logic now localized to nested sampling context

This separation makes the slice sampler more modular and reusable while
preserving all nested sampling functionality.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <[email protected]>
Renames the 'slicer' parameter and local variables to 'slice_fn' throughout
the slice sampling implementation to follow blackjax's *_fn naming convention.

Changes:
- blackjax/mcmc/ss.py: Renamed slicer → slice_fn in build_kernel(),
  horizontal_slice(), and build_hrss_kernel()
- blackjax/ns/nss.py: Renamed local slicer → slice_fn in build_kernel()

The _slice_fn prefix for wrapped local functions follows standard Python
convention for internal/private variables.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <[email protected]>
The kernel function inside build_hrss_kernel was referencing logdensity_fn
without it being in scope. Added logdensity_fn as a kernel parameter and
fixed the lambda in slice_fn to use state.position instead of undefined x.

This fixes the standalone Hit-and-Run Slice Sampling (HRSS) functionality.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <[email protected]>
Rewrote all tests to use the new slice_fn interface instead of the old
stepper_fn and constraint-based API. Changes include:

- Removed test_vertical_slice (function no longer exposed)
- Removed test_default_stepper_fn (function moved to nss.py)
- Updated test_slice_sampling_dimensions to use new build_hrss_kernel API
- Updated test_constrained_slice_sampling to use logdensity for constraints
- Added test_build_kernel_with_custom_slice_fn for direct kernel testing
- Added test_horizontal_slice_basic to test horizontal_slice directly
- Added test_multimodal_sampling for multimodal distributions
- Updated test_slice_info_structure for new SliceInfo fields
  (is_accepted, num_steps, num_shrink instead of constraint, l_steps, etc.)

All 12 tests now pass.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <[email protected]>
…rnel

Changed slice_fn in build_hrss_kernel to use init(x, logdensity_fn) instead
of SliceState(x, logdensity_fn(x)) for consistency with codebase patterns.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <[email protected]>
Updated docstrings throughout ss.py to match the new slice_fn interface:

- build_kernel: Clarified that kernel takes slice_fn mapping t to
  (SliceState, is_accepted) instead of old constraint-based interface
- horizontal_slice: Removed references to direction `d` from old API,
  updated to describe one-dimensional parameterization by `t`
- build_hrss_kernel: Added missing max_steps and max_shrinkage parameter docs
- hrss_as_top_level_api: Fixed reference to non-existent
  `default_proposal_distribution`, changed to `sample_direction_from_covariance`.
  Added missing max_steps and max_shrinkage parameter docs

All docstrings now accurately describe the current implementation.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <[email protected]>
Added more specific type signatures following blackjax conventions:

- slice_fn: Callable[[float], tuple[SliceState, bool]]
  Specifies that slice_fn takes a scalar parameter and returns
  (state, is_accepted) tuple

- generate_slice_direction_fn: Callable[[PRNGKey], ArrayTree]
  Specifies that direction generator takes a PRNG key and returns
  a direction vector with the same structure as position

These additions improve type safety and IDE support without changing
runtime behavior.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <[email protected]>
Moved slice_fn from kernel runtime parameter to build_kernel configuration
parameter. This simplifies the API and makes the separation clearer:

Changes to blackjax/mcmc/ss.py:
- build_kernel now takes slice_fn as first parameter (before max_steps)
- Returned kernel signature changed from (rng_key, state, slice_fn) to
  (rng_key, state)
- build_hrss_kernel now calls build_kernel inside kernel function instead
  of once at module scope

Changes to blackjax/ns/nss.py:
- Moved build_slice_kernel call from outer scope to inside inner_kernel
- Kernel now built per-step with fresh slice_fn capturing current state

Changes to tests/mcmc/test_slice_sampling.py:
- Updated test_build_kernel_with_custom_slice_fn to pass slice_fn to
  build_kernel and call kernel with (key, state) only

Conceptual improvement: slice_fn is now treated as configuration (part of
building a specific kernel) rather than runtime data. JIT compilation should
optimize away any per-step overhead from rebuilding the kernel.

All 12 tests pass.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <[email protected]>
Fixed several documentation issues:

blackjax/mcmc/ss.py:
- horizontal_slice: Reordered parameter documentation to match actual
  function signature (rng_key, state, slice_fn, m, max_shrinkage)

blackjax/ns/nss.py:
- Added docstring for PartitionedSliceState class
- Fixed stepper_fn documentation to show correct signature
  (x, direction, t) -> (x_new, is_accepted) instead of just -> x_new
- Fixed reference from `default_stepper` to `default_stepper_fn`

All docstrings now accurately reflect the current implementation.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <[email protected]>
Applied automatic formatting from pre-commit hooks:

- isort: Reorganized imports in nss.py
- Removed unused import: ss_init from blackjax.mcmc.ss
- black: Applied code formatting to ss.py, nss.py, and test_slice_sampling.py
- flake8: Fixed indentation issues

All pre-commit hooks now pass.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <[email protected]>
williamjameshandley and others added 14 commits October 17, 2025 14:32
…on sampling

Instead of taking a generic generate_slice_direction_fn parameter, build_hrss_kernel
now directly takes a covariance matrix and uses sample_direction_from_covariance.
This simplifies the API and removes unnecessary abstraction for the common case.
- Replace local PartitionedSliceState with existing PartitionedState from ns/base.py
- Create NSSInfo namedtuple for nested slice sampling diagnostic information
- Make state and info construction explicit instead of using new_state_and_info helper

This improves code reuse and follows the pattern of other samplers like HMC where
Info types contain both proposal state and diagnostic information.
- Fix slice_fn return type annotation: tuple[SliceStateWithLoglikelihood, bool] not SliceInfo
- Add comprehensive docstring to NSSInfo class
- Fix default_stepper_fn return type and improve its docstring
…alls

Replace explicit params["cov"] extraction with **params unpacking at the call site.
This makes the interface more generic and allows future direction generation functions
to accept additional parameters beyond just 'cov'.
- base.py: Use explicit dtype for pid indices (int32), logmeanexp division,
  loglikelihood_birth array creation, and update_ns_runtime_info calculations
- ss.py: Preserve state.logdensity.dtype in random uniform calls, use int32
  for indices, create log_accepted with explicit dtype
- utils.py: Preserve dead_info.loglikelihood.dtype in random uniform and
  array creation (zeros, full) to maintain consistency

These changes prevent implicit dtype conversions that could cause precision
loss when mixing float32/float64 or platform-dependent behavior.
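
The pattern, in standalone illustrative form (not the actual call sites):

import jax
import jax.numpy as jnp

rng_key = jax.random.PRNGKey(0)
logdensity = jnp.zeros(3, dtype=jnp.float32)

# Draw uniforms in the dtype of the existing data rather than the default:
u = jax.random.uniform(rng_key, logdensity.shape, dtype=logdensity.dtype)

# Use explicit int32 index arrays instead of the platform-dependent default:
idx = jnp.arange(logdensity.shape[0], dtype=jnp.int32)

# Create companion arrays in the same dtype as the data they accompany:
log_accepted = jnp.zeros_like(logdensity)
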
- ss.py: Replace static_binomial_sampling with jax.tree.map for state selection
  Use simple tree map with jnp.where instead of complex probabilistic sampling
- base.py: Remove unnecessary jnp.ones and explicit dtype wrapping
  JAX's at[].set() automatically broadcasts scalars with correct dtype
  Division automatically promotes -1 to match num_live dtype

These simplifications maintain dtype correctness while removing redundant code.
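
The selection pattern, sketched standalone:

import jax
import jax.numpy as jnp

def select_state(is_accepted, proposed, current):
    """Keep the proposed state where accepted, otherwise fall back to current."""
    return jax.tree.map(
        lambda new, old: jnp.where(is_accepted, new, old), proposed, current
    )
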
Ensure that the Cholesky factor L uses the same dtype as the position
to prevent unwanted dtype promotion during the matrix multiply L @ u.
Add 2*sqrt(d+2) scaling factor to slice direction generation to optimize
stepping-out performance in hit-and-run slice sampling.

## Mathematical Justification

For points uniformly distributed in a d-dimensional ball of radius R, the
empirical covariance matrix has the form:

    Σ_empirical = (R²/(d+2)) I

This can be derived by computing E[x_i²] for the uniform distribution:
- By spherical symmetry: Σ = σ² I where σ² = E[||x||²]/d = E[r²]/d
- The radial density is p(r) = d r^(d-1) / R^d for 0 ≤ r ≤ R
- Computing E[r²] = ∫₀^R r² · d r^(d-1)/R^d dr = d R²/(d+2)
- Therefore σ² = R²/(d+2)

## The Scaling Factor: 2*sqrt(d+2)

The optimal scaling consists of two components:

1. **Correcting the (d+2) bias** [sqrt(d+2) factor]:
   The empirical covariance underestimates the spatial extent of the region
   by a factor of (d+2). To recover the true scale R from the empirical
   variance R²/(d+2), we need to scale by (d+2).

2. **Diameter vs radius scaling** [factor of 2]:
   For slice sampling, the initial stepping-out interval should be able to
   span the full diameter (2R) of the constrained region, not just the
   radius (R). Since we want directions scaled to 2R, and the Cholesky
   factor L satisfies L @ L.T = cov, we need:

   L_optimal = 2R I = 2 * sqrt(d+2) * L_empirical

Combined scaling: 2 * sqrt(d+2) on the Cholesky factor L, which is
equivalent to scaling the covariance by 4(d+2).

## Implementation

- Applied scaling in `sample_direction_from_covariance` after Cholesky
  decomposition but before applying to unit direction vector
- Kept empirical covariance computation pure in NSS
- Updated docstrings to explain the scaling and its mathematical basis
- Simplified `compute_covariance_from_particles` docstring for consistency
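
A hedged sketch of where the scaling lands (the function name comes from this PR; the signature and internals are assumptions):

import jax
import jax.numpy as jnp

def sample_direction_from_covariance(rng_key, cov):
    d = cov.shape[0]
    L = jnp.linalg.cholesky(cov)                   # L @ L.T == cov
    u = jax.random.normal(rng_key, (d,), dtype=cov.dtype)
    u = u / jnp.linalg.norm(u)                     # unit direction on the sphere
    # Undo the (d+2) bias and scale from radius to diameter:
    return 2.0 * jnp.sqrt(d + 2.0) * (L @ u)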

## Validation

See `validate_covariance.py` for numerical validation of the d-ball
covariance formula with relative errors < 0.03% across dimensions 2-20.
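
An illustrative Monte Carlo check of the d-ball formula (not the validation script itself):

import jax
import jax.numpy as jnp

d, R, n = 5, 1.0, 200_000
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(k1, (n, d))
x = x / jnp.linalg.norm(x, axis=1, keepdims=True)    # uniform on the unit sphere
r = R * jax.random.uniform(k2, (n, 1)) ** (1.0 / d)  # inverse-CDF of p(r) = d r^(d-1)/R^d
samples = r * x                                      # uniform in the d-ball
print(jnp.var(samples, axis=0))                      # each entry ≈ R**2/(d + 2) ≈ 0.143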

Addresses review comment from PR #47 about optimal covariance scaling.
Adds some additional standalone slicing compatibility to the slice refactor, and merges changes from #27
williamjameshandley and others added 23 commits October 25, 2025 08:58
refactor to have an update strategy and from_mcmc initialisation
Removes the pid field from NSState as it was primarily used for diagnostic purposes and is not needed for the core nested sampling algorithm.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <[email protected]>
- Rename NSParticle to StateWithLogLikelihood (clearer semantics)
- Bundle (position, logdensity, loglikelihood) into StateWithLogLikelihood
- Keep loglikelihood_birth in NSState (NS-specific bookkeeping)
- Update NSInfo to include both particles and extracted arrays for evidence calculation
- Fix covariance computation to use state.particles.position
- Update finalise() to work with new structure

This simplifies the inner kernel interface by having it operate on
clean StateWithLogLikelihood objects while keeping NS-specific tracking
separate.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <[email protected]>
The loglikelihood is already available via particles.loglikelihood,
so storing it separately in NSInfo was redundant. All utility functions
now access it through the particles field.

This eliminates duplication and enforces a single source of truth.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <[email protected]>
All particle data is now bundled in StateWithLogLikelihood:
- position, logdensity, loglikelihood, loglikelihood_birth

NSState is now simpler:
- particles (StateWithLogLikelihood)
- logX, logZ, logZ_live (evidence tracking)
- inner_kernel_params

NSInfo is also simpler:
- particles (contains all dead particle data)
- update_info

The birth likelihood is set correctly by inner kernels receiving
loglikelihood_0 as a parameter. During initialization, the scalar
birth likelihood is broadcast to match the batch dimension for
vmap compatibility.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <[email protected]>
Use len(positions) instead of tree_flatten to get batch size.
Much cleaner and more readable.
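
For illustration (assuming positions is an array whose leading axis is the batch dimension):

import jax
import jax.numpy as jnp

positions = jnp.zeros((100, 3))  # 100 particles in 3 dimensions

# Before: batch size recovered by flattening the pytree
batch_size = jax.tree.leaves(positions)[0].shape[0]

# After: the leading axis supports len() directly
batch_size = len(positions)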

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <[email protected]>
Major refactoring to flatten NS state structure and manage adaptive parameters externally:

1. Removed adaptive.py - adaptive parameter management now done in user loop
2. Removed inner_kernel_params from NSState - now passed as kernel argument
3. Eliminated StateWithLogLikelihood wrapper - NSState is now flat with direct fields
4. Updated kernel signature: kernel(rng_key, state, inner_kernel_params)
5. Added init_inner_kernel_params() and update_inner_kernel_params() in nss.py
6. Updated all modules to use flattened state (state.X instead of state.particles.X)

Final NSState structure:
- position: ArrayLikeTree
- logdensity: Array
- loglikelihood: Array
- loglikelihood_birth: Array

This follows BlackJAX patterns where:
- Cumulative statistics (NSIntegrator) managed externally like SMC log-likelihood
- Adaptive parameters passed separately like SMC temperature parameter
- State is minimal and contains only particle data

All tests pass with identical evidence calculations.
- Remove logX, logZ, logZ_live from NSState (now only particles + inner_kernel_params)
- Create NSIntegrator class in blackjax/ns/integrator.py to track evidence integration
- NSIntegrator accumulates cumulative statistics (evidence, prior volume) from dead particles
- Update nested_sampling.py example to manage NSIntegrator externally
- Clean separation: NSState = live points state, NSIntegrator = cumulative statistics

This follows the pattern where derived/cumulative quantities are computed externally
to the kernel, similar to how SMC returns normalizing_constant in SMCInfo.
Ensure loglikelihood_birth has the same dtype and shape as
loglikelihood_values using ones_like. This is important for
type consistency within StateWithLogLikelihood.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <[email protected]>
Reinstate internal adaptive and tuning
Flatten NSState and move to external parameter management
Extract evidence integration to NSIntegrator
Bundle particle data into StateWithLogLikelihood
def init(
    positions: ArrayLikeTree,
    init_state_fn: Callable,
    loglikelihood_birth: Array = jnp.nan,

@AdamOrmondroyd commented on Dec 8, 2025:


sus default given type annotation

num_live = jnp.arange(num_particles, num_particles - num_deleted, -1)
delta_logX = -1 / num_live
logX = integrator.logX + jnp.cumsum(delta_logX)
log_delta_X = logX + jnp.log(1 - jnp.exp(delta_logX))


this should be a utils log1mexp no?
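
For reference, a standard numerically stable log1mexp looks like this (a sketch of such a utility, following Mächler's 2012 note; not necessarily what blackjax provides):

import jax.numpy as jnp

def log1mexp(x):
    """Compute log(1 - exp(x)) stably for x < 0."""
    return jnp.where(
        x > -jnp.log(2.0),
        jnp.log(-jnp.expm1(x)),   # accurate when exp(x) is close to 1
        jnp.log1p(-jnp.exp(x)),   # accurate when exp(x) is close to 0
    )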


def _logmeanexp(x: Array) -> Array:
    """Compute log(mean(exp(x))) in a numerically stable way."""
    n = jnp.array(x.shape[0])


explicit axis here making me nervous
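
A version with the axis made explicit, illustrating the reviewer's point (an assumption, not the code under review):

import jax.numpy as jnp
from jax.scipy.special import logsumexp

def _logmeanexp(x, axis=0):
    """Compute log(mean(exp(x))) along `axis` in a numerically stable way."""
    n = x.shape[axis]
    return logsumexp(x, axis=axis) - jnp.log(n)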

