diff --git a/blackjax/mcmc/ss.py b/blackjax/mcmc/ss.py index 50764ad4a..3a267cde0 100644 --- a/blackjax/mcmc/ss.py +++ b/blackjax/mcmc/ss.py @@ -27,7 +27,7 @@ """ from functools import partial -from typing import Callable, NamedTuple +from typing import Callable, NamedTuple, Optional import jax import jax.numpy as jnp @@ -87,7 +87,9 @@ class SliceInfo(NamedTuple): def init( - position: ArrayTree, logdensity_fn: Callable, constraint_fn: Callable + position: ArrayTree, + logdensity_fn: Callable, + constraint_fn: Optional[Callable] = None, ) -> SliceState: """Initialize the Slice Sampler state. @@ -103,7 +105,9 @@ def init( SliceState The initial state of the Slice Sampler. """ - return SliceState(position, logdensity_fn(position), constraint_fn(position)) + logp = logdensity_fn(position) + constraint_val = (constraint_fn or (lambda _: jnp.array([])))(position) + return SliceState(position, logp, constraint_val) def build_kernel( @@ -266,6 +270,7 @@ def build_hrss_kernel( generate_slice_direction_fn: Callable, stepper_fn: Callable, max_steps: int = 10, + max_shrinkage: int = 100, ) -> Callable: """Build a Hit-and-Run Slice Sampling kernel. @@ -292,7 +297,7 @@ def build_hrss_kernel( A kernel function that takes a PRNG key, the current `SliceState`, and the log-density function, and returns a new `SliceState` and `SliceInfo`. """ - slice_kernel = build_kernel(stepper_fn, max_steps) + slice_kernel = build_kernel(stepper_fn, max_steps, max_shrinkage) def kernel( rng_key: PRNGKey, state: SliceState, logdensity_fn: Callable @@ -359,8 +364,7 @@ def sample_direction_from_covariance(rng_key: PRNGKey, cov: Array) -> Array: def hrss_as_top_level_api( - logdensity_fn: Callable, - cov: Array, + logdensity_fn: Callable, cov: Array, max_steps: int = 10, max_shrinkage: int = 100 ) -> SamplingAlgorithm: """Creates a Hit-and-Run Slice Sampling algorithm. @@ -384,7 +388,9 @@ def hrss_as_top_level_api( the configured Hit-and-Run Slice Sampler. """ generate_slice_direction_fn = partial(sample_direction_from_covariance, cov=cov) - kernel = build_hrss_kernel(generate_slice_direction_fn, default_stepper_fn) + kernel = build_hrss_kernel( + generate_slice_direction_fn, default_stepper_fn, max_steps, max_shrinkage + ) init_fn = partial(init, logdensity_fn=logdensity_fn) step_fn = partial(kernel, logdensity_fn=logdensity_fn) return SamplingAlgorithm(init_fn, step_fn)