Refactor slice sampler to decouple constraint handling from core algorithm #45
base: nested_sampling
Conversation
…rithm

This commit refactors the slice sampling implementation to separate generic slice sampling functionality from nested-sampling-specific constraint handling.

Changes to blackjax/mcmc/ss.py:
- Removed constraint field from SliceState
- Simplified build_kernel() interface to accept a generic slicer function instead of constraint_fn, constraint values, and strict flags
- Removed stepper_fn parameter; stepping logic is now part of the slicer
- Removed default_stepper_fn (moved to nss.py where needed)
- Simplified build_hrss_kernel() and hrss_as_top_level_api() interfaces

Changes to blackjax/ns/nss.py:
- Added PartitionedSliceState to track position, logdensity, and loglikelihood
- Moved default_stepper_fn from ss.py to this module
- Implemented constraint checking (loglikelihood > loglikelihood_0) via a custom slicer function passed to slice_kernel
- Constraint logic is now localized to the nested sampling context

This separation makes the slice sampler more modular and reusable while preserving all nested sampling functionality.

🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude <[email protected]>
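As a rough illustration of the new separation (a hypothetical sketch, not the actual blackjax code: `SliceState` here is a minimal stand-in and `make_constrained_slice_fn` is an invented helper name), the nested-sampling constraint can be folded into the `is_accepted` flag returned by a custom slicer, so the generic sampler never sees the constraint itself:

```python
from typing import Callable, NamedTuple

import jax.numpy as jnp


class SliceState(NamedTuple):
    """Minimal stand-in for the generic slice sampler's state."""
    position: jnp.ndarray
    logdensity: jnp.ndarray


def make_constrained_slice_fn(
    position, direction, logdensity_fn, loglikelihood_fn, loglikelihood_0
) -> Callable:
    """Build a slicer that bakes the nested-sampling constraint
    (loglikelihood > loglikelihood_0) into the acceptance flag."""

    def slice_fn(t):
        # Move along the slice direction: x_new = position + t * direction.
        x_new = position + t * direction
        state = SliceState(x_new, logdensity_fn(x_new))
        # The hard likelihood constraint lives entirely in this closure.
        is_accepted = loglikelihood_fn(x_new) > loglikelihood_0
        return state, is_accepted

    return slice_fn
```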
Renames the 'slicer' parameter and local variables to 'slice_fn' throughout the slice sampling implementation to follow blackjax's *_fn naming convention.

Changes:
- blackjax/mcmc/ss.py: Renamed slicer → slice_fn in build_kernel(), horizontal_slice(), and build_hrss_kernel()
- blackjax/ns/nss.py: Renamed local slicer → slice_fn in build_kernel()

The _slice_fn prefix for wrapped local functions follows the standard Python convention for internal/private variables.

🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude <[email protected]>
The kernel function inside build_hrss_kernel was referencing logdensity_fn without it being in scope. Added logdensity_fn as a kernel parameter and fixed the lambda in slice_fn to use state.position instead of the undefined x.

This fixes the standalone Hit-and-Run Slice Sampling (HRSS) functionality.

🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude <[email protected]>
Rewrote all tests to use the new slice_fn interface instead of the old stepper_fn and constraint-based API.

Changes include:
- Removed test_vertical_slice (function no longer exposed)
- Removed test_default_stepper_fn (function moved to nss.py)
- Updated test_slice_sampling_dimensions to use the new build_hrss_kernel API
- Updated test_constrained_slice_sampling to use logdensity for constraints
- Added test_build_kernel_with_custom_slice_fn for direct kernel testing
- Added test_horizontal_slice_basic to test horizontal_slice directly
- Added test_multimodal_sampling for multimodal distributions
- Updated test_slice_info_structure for the new SliceInfo fields (is_accepted, num_steps, num_shrink instead of constraint, l_steps, etc.)

All 12 tests now pass.

🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude <[email protected]>
…rnel

Changed slice_fn in build_hrss_kernel to use init(x, logdensity_fn) instead of SliceState(x, logdensity_fn(x)) for consistency with codebase patterns.

🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude <[email protected]>
Updated docstrings throughout ss.py to match the new slice_fn interface:
- build_kernel: Clarified that the kernel takes a slice_fn mapping t to (SliceState, is_accepted) instead of the old constraint-based interface
- horizontal_slice: Removed references to direction `d` from the old API; updated to describe the one-dimensional parameterization by `t`
- build_hrss_kernel: Added missing max_steps and max_shrinkage parameter docs
- hrss_as_top_level_api: Fixed reference to the non-existent `default_proposal_distribution`, changed to `sample_direction_from_covariance`. Added missing max_steps and max_shrinkage parameter docs

All docstrings now accurately describe the current implementation.

🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude <[email protected]>
Added more specific type signatures following blackjax conventions:
- slice_fn: Callable[[float], tuple[SliceState, bool]], specifying that slice_fn takes a scalar parameter and returns a (state, is_accepted) tuple
- generate_slice_direction_fn: Callable[[PRNGKey], ArrayTree], specifying that the direction generator takes a PRNG key and returns a direction vector with the same structure as the position

These additions improve type safety and IDE support without changing runtime behavior.

🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude <[email protected]>
Moved slice_fn from a kernel runtime parameter to a build_kernel configuration parameter. This simplifies the API and makes the separation clearer.

Changes to blackjax/mcmc/ss.py:
- build_kernel now takes slice_fn as its first parameter (before max_steps)
- Returned kernel signature changed from (rng_key, state, slice_fn) to (rng_key, state)
- build_hrss_kernel now calls build_kernel inside the kernel function instead of once at module scope

Changes to blackjax/ns/nss.py:
- Moved the build_slice_kernel call from the outer scope to inside inner_kernel
- The kernel is now built per-step with a fresh slice_fn capturing the current state

Changes to tests/mcmc/test_slice_sampling.py:
- Updated test_build_kernel_with_custom_slice_fn to pass slice_fn to build_kernel and call the kernel with (key, state) only

Conceptual improvement: slice_fn is now treated as configuration (part of building a specific kernel) rather than runtime data. JIT compilation should optimize away any per-step overhead from rebuilding the kernel.

All 12 tests pass.

🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude <[email protected]>
Fixed several documentation issues:

blackjax/mcmc/ss.py:
- horizontal_slice: Reordered parameter documentation to match the actual function signature (rng_key, state, slice_fn, m, max_shrinkage)

blackjax/ns/nss.py:
- Added a docstring for the PartitionedSliceState class
- Fixed the stepper_fn documentation to show the correct signature (x, direction, t) -> (x_new, is_accepted) instead of just -> x_new
- Fixed a reference from `default_stepper` to `default_stepper_fn`

All docstrings now accurately reflect the current implementation.

🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude <[email protected]>
Applied automatic formatting from pre-commit hooks:
- isort: Reorganized imports in nss.py; removed the unused import ss_init from blackjax.mcmc.ss
- black: Applied code formatting to ss.py, nss.py, and test_slice_sampling.py
- flake8: Fixed indentation issues

All pre-commit hooks now pass.

🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude <[email protected]>
…on sampling

Instead of taking a generic generate_slice_direction_fn parameter, build_hrss_kernel now directly takes a covariance matrix and uses sample_direction_from_covariance. This simplifies the API and removes unnecessary abstraction for the common case.
- Replace the local PartitionedSliceState with the existing PartitionedState from ns/base.py
- Create an NSSInfo namedtuple for nested slice sampling diagnostic information
- Make state and info construction explicit instead of using the new_state_and_info helper

This improves code reuse and follows the pattern of other samplers like HMC, where Info types contain both the proposal state and diagnostic information.
- Fix the slice_fn return type annotation: tuple[SliceStateWithLoglikelihood, bool], not SliceInfo
- Add a comprehensive docstring to the NSSInfo class
- Fix the default_stepper_fn return type and improve its docstring
…alls

Replace explicit params["cov"] extraction with **params unpacking at the call site. This makes the interface more generic and allows future direction-generation functions to accept additional parameters beyond just 'cov'.
- base.py: Use explicit dtypes for pid indices (int32), the logmeanexp division, loglikelihood_birth array creation, and update_ns_runtime_info calculations
- ss.py: Preserve state.logdensity.dtype in random uniform calls, use int32 for indices, and create log_accepted with an explicit dtype
- utils.py: Preserve dead_info.loglikelihood.dtype in random uniform calls and array creation (zeros, full) to maintain consistency

These changes prevent implicit dtype conversions that could cause precision loss when mixing float32/float64, as well as platform-dependent behavior.
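The dtype-preservation pattern described here can be sketched as follows (illustrative only; `sample_log_uniform` is an invented name, not a blackjax function):

```python
import jax
import jax.numpy as jnp


def sample_log_uniform(rng_key, logdensity):
    """Sample log(u) + logdensity while inheriting the dtype of an
    existing array, instead of relying on JAX's default float dtype
    (which depends on the jax_enable_x64 flag)."""
    # Matching logdensity.dtype keeps float32 pipelines in float32
    # and float64 pipelines in float64.
    u = jax.random.uniform(rng_key, dtype=logdensity.dtype)
    return jnp.log(u) + logdensity
```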
- ss.py: Replace static_binomial_sampling with jax.tree.map for state selection; use a simple tree map with jnp.where instead of complex probabilistic sampling
- base.py: Remove unnecessary jnp.ones and explicit dtype wrapping; JAX's at[].set() automatically broadcasts scalars with the correct dtype, and division automatically promotes -1 to match the num_live dtype

These simplifications maintain dtype correctness while removing redundant code.
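The `jax.tree.map` + `jnp.where` selection mentioned above amounts to something like this (a sketch of the pattern with an invented helper name, not the exact blackjax code):

```python
import jax
import jax.numpy as jnp


def select_state(is_accepted, proposed, current):
    """Choose between two pytree states leaf-by-leaf: each leaf of the
    result comes from `proposed` when is_accepted is True, otherwise
    from `current`. A boolean decision makes probabilistic binomial
    sampling unnecessary."""
    return jax.tree.map(
        lambda new, old: jnp.where(is_accepted, new, old), proposed, current
    )
```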
…to slice_state_dict
Ensure that the Cholesky factor L uses the same dtype as the position to prevent unwanted dtype promotion during the matrix multiply L @ u.
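The fix can be sketched like this (hypothetical function name; only the `.astype` cast reflects the change described):

```python
import jax.numpy as jnp


def direction_from_covariance(u, cov, position):
    """Map a direction u through the Cholesky factor of cov, casting
    L to the position's dtype first so that L @ u cannot silently
    promote a float32 pipeline to a wider dtype."""
    L = jnp.linalg.cholesky(cov).astype(position.dtype)
    return L @ u.astype(position.dtype)
```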
Add 2*sqrt(d+2) scaling factor to slice direction generation to optimize
stepping-out performance in hit-and-run slice sampling.
## Mathematical Justification
For points uniformly distributed in a d-dimensional ball of radius R, the
empirical covariance matrix has the form:
Σ_empirical = (R²/(d+2)) I
This can be derived by computing E[x_i²] for the uniform distribution:
- By spherical symmetry: Σ = σ² I where σ² = E[||x||²]/d = E[r²]/d
- The radial density is p(r) = d r^(d-1) / R^d for 0 ≤ r ≤ R
- Computing E[r²] = ∫₀^R r² · d r^(d-1)/R^d dr = d R²/(d+2)
- Therefore σ² = R²/(d+2)
## The Scaling Factor: 2*sqrt(d+2)
The optimal scaling consists of two components:
1. **Correcting the (d+2) bias** [sqrt(d+2) factor]:
The empirical covariance underestimates the squared spatial extent of the region by a factor of (d+2). To recover the true scale R from the empirical standard deviation R/sqrt(d+2), the Cholesky factor must be scaled by sqrt(d+2) (equivalently, the covariance by (d+2)).
2. **Diameter vs radius scaling** [factor of 2]:
For slice sampling, the initial stepping-out interval should be able to
span the full diameter (2R) of the constrained region, not just the
radius (R). Since we want directions scaled to 2R, and the Cholesky
factor L satisfies L @ L.T = cov, we need:
L_optimal = 2R I = 2 * sqrt(d+2) * L_empirical
Combined scaling: 2 * sqrt(d+2) on the Cholesky factor L, which is
equivalent to scaling the covariance by 4(d+2).
## Implementation
- Applied scaling in `sample_direction_from_covariance` after Cholesky
decomposition but before applying to unit direction vector
- Kept empirical covariance computation pure in NSS
- Updated docstrings to explain the scaling and its mathematical basis
- Simplified `compute_covariance_from_particles` docstring for consistency
## Validation
See `validate_covariance.py` for numerical validation of the d-ball
covariance formula with relative errors < 0.03% across dimensions 2-20.
Addresses review comment from PR #47 about optimal covariance scaling.
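The d-ball variance formula above is easy to check with a quick Monte Carlo experiment (a standalone sketch, separate from the `validate_covariance.py` script mentioned):

```python
import numpy as np

# For x uniform in a d-ball of radius R, Var(x_i) should equal R^2/(d+2).
rng = np.random.default_rng(0)
d, R, n = 5, 2.0, 200_000

# Sample uniformly in the ball: direction uniform on the sphere,
# radius r with density proportional to r^(d-1), i.e. r = R * U^(1/d).
g = rng.standard_normal((n, d))
directions = g / np.linalg.norm(g, axis=1, keepdims=True)
radii = R * rng.uniform(size=(n, 1)) ** (1.0 / d)
x = radii * directions

empirical_var = x.var(axis=0).mean()
predicted_var = R**2 / (d + 2)
```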
Adds some additional standalone slicing compatibility to the slice refactor, and merges in the changes from #27.
…blackjax into refactor-slice-sampling-api
refactor to have an update strategy and from_mcmc initialisation
Removes the pid field from NSState, as it was primarily used for diagnostic purposes and is not needed for the core nested sampling algorithm.

🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude <[email protected]>
- Rename NSParticle to StateWithLogLikelihood (clearer semantics)
- Bundle (position, logdensity, loglikelihood) into StateWithLogLikelihood
- Keep loglikelihood_birth in NSState (NS-specific bookkeeping)
- Update NSInfo to include both particles and extracted arrays for the evidence calculation
- Fix covariance computation to use state.particles.position
- Update finalise() to work with the new structure

This simplifies the inner kernel interface by having it operate on clean StateWithLogLikelihood objects while keeping NS-specific tracking separate.

🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude <[email protected]>
The loglikelihood is already available via particles.loglikelihood, so storing it separately in NSInfo was redundant. All utility functions now access it through the particles field. This eliminates duplication and enforces a single source of truth.

🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude <[email protected]>
All particle data is now bundled in StateWithLogLikelihood:
- position, logdensity, loglikelihood, loglikelihood_birth

NSState is now simpler:
- particles (StateWithLogLikelihood)
- logX, logZ, logZ_live (evidence tracking)
- inner_kernel_params

NSInfo is also simpler:
- particles (contains all dead particle data)
- update_info

The birth likelihood is set correctly by inner kernels receiving loglikelihood_0 as a parameter. During initialization, the scalar birth likelihood is broadcast to match the batch dimension for vmap compatibility.

🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude <[email protected]>
Use len(positions) instead of tree_flatten to get the batch size. Much cleaner and more readable.

🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude <[email protected]>
Major refactoring to flatten the NS state structure and manage adaptive parameters externally:

1. Removed adaptive.py; adaptive parameter management is now done in the user loop
2. Removed inner_kernel_params from NSState; it is now passed as a kernel argument
3. Eliminated the StateWithLogLikelihood wrapper; NSState is now flat with direct fields
4. Updated the kernel signature: kernel(rng_key, state, inner_kernel_params)
5. Added init_inner_kernel_params() and update_inner_kernel_params() in nss.py
6. Updated all modules to use the flattened state (state.X instead of state.particles.X)

Final NSState structure:
- position: ArrayLikeTree
- logdensity: Array
- loglikelihood: Array
- loglikelihood_birth: Array

This follows BlackJAX patterns where:
- Cumulative statistics (NSIntegrator) are managed externally, like the SMC log-likelihood
- Adaptive parameters are passed separately, like the SMC temperature parameter
- State is minimal and contains only particle data

All tests pass with identical evidence calculations.
- Remove logX, logZ, logZ_live from NSState (now only particles + inner_kernel_params)
- Create an NSIntegrator class in blackjax/ns/integrator.py to track evidence integration
- NSIntegrator accumulates cumulative statistics (evidence, prior volume) from dead particles
- Update the nested_sampling.py example to manage NSIntegrator externally
- Clean separation: NSState = live points state, NSIntegrator = cumulative statistics

This follows the pattern where derived/cumulative quantities are computed externally to the kernel, similar to how SMC returns normalizing_constant in SMCInfo.
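Under this design the user loop might accumulate evidence along these lines (a hedged sketch: `update_integrator` and its names are illustrative, based on the standard nested-sampling volume shrinkage of one particle per iteration, not the actual NSIntegrator API):

```python
import jax.numpy as jnp


def update_integrator(logX, logZ, dead_loglikelihood, num_live):
    """One evidence-accumulation step after deleting a single particle:
    the prior volume shrinks by a factor exp(-1/num_live), and the dead
    particle contributes dZ = L_dead * (X_old - X_new) to the evidence."""
    new_logX = logX - 1.0 / num_live
    log_delta_X = logX + jnp.log1p(-jnp.exp(-1.0 / num_live))
    new_logZ = jnp.logaddexp(logZ, dead_loglikelihood + log_delta_X)
    return new_logX, new_logZ
```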
Ensure loglikelihood_birth has the same dtype and shape as loglikelihood_values using ones_like. This is important for type consistency within StateWithLogLikelihood.

🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude <[email protected]>
Co-authored-by: Copilot <[email protected]>
Reinstate internal adaptive and tuning
Flatten NSState and move to external parameter management
Extract evidence integration to NSIntegrator
Bundle particle data into StateWithLogLikelihood
```python
def init(
    positions: ArrayLikeTree,
    init_state_fn: Callable,
    loglikelihood_birth: Array = jnp.nan,
```
sus default given type annotation
blackjax/ns/integrator.py
```python
num_live = jnp.arange(num_particles, num_particles - num_deleted, -1)
delta_logX = -1 / num_live
logX = integrator.logX + jnp.cumsum(delta_logX)
log_delta_X = logX + jnp.log(1 - jnp.exp(delta_logX))
```
this should be a utils log1mexp no?
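For reference, a standard numerically stable `log1mexp` along the lines the reviewer suggests (a sketch following the well-known two-branch recipe; whether blackjax exposes such a utility is not shown here):

```python
import jax.numpy as jnp


def log1mexp(x):
    """Stable log(1 - exp(x)) for x <= 0: use log1p(-exp(x)) when
    exp(x) is small (x < -log 2) and log(-expm1(x)) when exp(x) is
    close to 1, avoiding catastrophic cancellation in both regimes."""
    return jnp.where(
        x < -jnp.log(2.0),
        jnp.log1p(-jnp.exp(x)),
        jnp.log(-jnp.expm1(x)),
    )
```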
```python
def _logmeanexp(x: Array) -> Array:
    """Compute log(mean(exp(x))) in a numerically stable way."""
    n = jnp.array(x.shape[0])
```
explicit axis here making me nervous
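A version with an explicit reduction axis, addressing the concern (a sketch using `jax.scipy.special.logsumexp`, not necessarily the final blackjax helper):

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def logmeanexp(x, axis=0):
    """Compute log(mean(exp(x))) stably along an explicit axis,
    defaulting to axis=0 to match the n = x.shape[0] in the snippet
    under review."""
    return logsumexp(x, axis=axis) - jnp.log(x.shape[axis])
```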
…blackjax into refactor-slice-sampling-api
Summary
Refactors the slice sampling implementation to separate generic slice sampling from nested-sampling-specific constraint handling. This improves modularity and makes the slice sampler reusable while preserving all nested sampling functionality.
Architectural Change

Before: Generic slice sampler had built-in constraint handling
- `constraint` field in `SliceState`/`SliceInfo`
- `constraint_fn`, `constraint` array, `strict` flags passed everywhere

After: Generic slice sampler accepts a `slice_fn: (t) -> (state, is_accepted)`
- Nested sampling implements its constraint (`loglikelihood > threshold`) in a custom `slice_fn`
- `ss.py` has no knowledge of constraints

Changes

`blackjax/mcmc/ss.py` (-96 lines, net simplification)

Removed constraint machinery:
- `constraint` field from `SliceState` and `SliceInfo`
- `constraint_fn`, `constraint`, `strict` parameters

New API:
- `build_kernel(slice_fn, max_steps, max_shrinkage)` takes `slice_fn` as configuration
- `slice_fn: Callable[[float], tuple[SliceState, bool]]` maps the parameter `t` to (state, is_accepted)
- Kernel signature is `(rng_key, state)` instead of `(rng_key, state, logdensity_fn, d, constraint_fn, constraint, strict)`

Renamed for consistency:
- `slicer` → `slice_fn` throughout (follows the blackjax `*_fn` convention)

Bug fix:
- `build_hrss_kernel` now correctly accepts `logdensity_fn` as a parameter (it was referenced in a closure without being in scope)

Removed:
- `default_stepper_fn` (moved to nss.py, where it's actually used)

Added:
- `slice_fn: Callable[[float], tuple[SliceState, bool]]`
- `generate_slice_direction_fn: Callable[[PRNGKey], ArrayTree]`
- `max_steps` and `max_shrinkage` parameters to `hrss_as_top_level_api`

`blackjax/ns/nss.py` (+96 lines, explicit constraint handling)

- Added `PartitionedSliceState`
- Moved from ss.py: `default_stepper_fn` (only used in nested sampling)
- Explicit constraint implementation, much clearer than the old `jnp.where(strict, constraint > value, constraint >= value)` approach
- Builds the kernel per-step: `build_slice_kernel(slice_fn, max_steps, max_shrinkage)` inside `inner_kernel`, so `slice_fn` captures the current state

`tests/mcmc/test_slice_sampling.py` (complete rewrite)

Removed tests for unexposed functions:
- `test_vertical_slice`
- `test_default_stepper_fn`

Updated all tests for the new API: removed `constraint_fn` parameters and switched to the `slice_fn` interface.

Added new tests:
- `test_build_kernel_with_custom_slice_fn`: tests `build_kernel` directly
- `test_horizontal_slice_basic`: tests the `horizontal_slice` function
- `test_multimodal_sampling`: tests on a bimodal distribution

All 12 tests pass.

Benefits

- Constraint handling is localized to a custom `slice_fn`
- The standalone `build_hrss_kernel` now works correctly

Breaking Changes

None for end users: `hrss_as_top_level_api` maintains the same interface.

Internal API changes:
- `build_kernel` now takes `slice_fn` as its first parameter

Testing
🤖 Generated with Claude Code
Co-Authored-By: Claude [email protected]