Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 13 additions & 7 deletions blackjax/mcmc/ss.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
"""

from functools import partial
from typing import Callable, NamedTuple
from typing import Callable, NamedTuple, Optional

import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -87,7 +87,9 @@ class SliceInfo(NamedTuple):


def init(
position: ArrayTree, logdensity_fn: Callable, constraint_fn: Callable
position: ArrayTree,
logdensity_fn: Callable,
constraint_fn: Optional[Callable] = None,
) -> SliceState:
"""Initialize the Slice Sampler state.

Expand All @@ -103,7 +105,9 @@ def init(
SliceState
The initial state of the Slice Sampler.
"""
return SliceState(position, logdensity_fn(position), constraint_fn(position))
logp = logdensity_fn(position)
constraint_val = (constraint_fn or (lambda _: jnp.array([])))(position)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this the blackjaxy way to do it, or should this default go at line 92?

return SliceState(position, logp, constraint_val)


def build_kernel(
Expand Down Expand Up @@ -266,6 +270,7 @@ def build_hrss_kernel(
generate_slice_direction_fn: Callable,
stepper_fn: Callable,
max_steps: int = 10,

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Defaults are important -- is max_steps 10 well motivated (and not e.g. 16, or 1024)?

max_shrinkage: int = 100,
) -> Callable:
"""Build a Hit-and-Run Slice Sampling kernel.

Expand All @@ -292,7 +297,7 @@ def build_hrss_kernel(
A kernel function that takes a PRNG key, the current `SliceState`, and
the log-density function, and returns a new `SliceState` and `SliceInfo`.
"""
slice_kernel = build_kernel(stepper_fn, max_steps)
slice_kernel = build_kernel(stepper_fn, max_steps, max_shrinkage)

def kernel(
rng_key: PRNGKey, state: SliceState, logdensity_fn: Callable
Expand Down Expand Up @@ -359,8 +364,7 @@ def sample_direction_from_covariance(rng_key: PRNGKey, cov: Array) -> Array:


def hrss_as_top_level_api(
logdensity_fn: Callable,
cov: Array,
logdensity_fn: Callable, cov: Array, max_steps: int = 10, max_shrinkage: int = 100
) -> SamplingAlgorithm:
"""Creates a Hit-and-Run Slice Sampling algorithm.

Expand All @@ -384,7 +388,9 @@ def hrss_as_top_level_api(
the configured Hit-and-Run Slice Sampler.
"""
generate_slice_direction_fn = partial(sample_direction_from_covariance, cov=cov)
kernel = build_hrss_kernel(generate_slice_direction_fn, default_stepper_fn)
kernel = build_hrss_kernel(
generate_slice_direction_fn, default_stepper_fn, max_steps, max_shrinkage
)
init_fn = partial(init, logdensity_fn=logdensity_fn)
step_fn = partial(kernel, logdensity_fn=logdensity_fn)
return SamplingAlgorithm(init_fn, step_fn)
Loading