diff --git a/pymc/backends/__init__.py b/pymc/backends/__init__.py index eaa484a13f..bae204769b 100644 --- a/pymc/backends/__init__.py +++ b/pymc/backends/__init__.py @@ -72,7 +72,7 @@ from pymc.backends.arviz import predictions_to_inference_data, to_inference_data from pymc.backends.base import BaseTrace, IBaseTrace from pymc.backends.ndarray import NDArray -from pymc.backends.zarr import ZarrTrace +from pymc.backends.zarr import TraceAlreadyInitialized, ZarrTrace from pymc.blocking import PointType from pymc.model import Model from pymc.step_methods.compound import BlockedStep, CompoundStep @@ -132,15 +132,41 @@ def init_traces( ) -> tuple[RunType | None, Sequence[IBaseTrace]]: """Initialize a trace recorder for each chain.""" if isinstance(backend, ZarrTrace): - backend.init_trace( - chains=chains, - draws=expected_length - tune, - tune=tune, - step=step, - model=model, - vars=trace_vars, - test_point=initial_point, - ) + try: + backend.init_trace( + chains=chains, + draws=expected_length - tune, + tune=tune, + step=step, + model=model, + vars=trace_vars, + test_point=initial_point, + ) + except TraceAlreadyInitialized: + # Trace has already been initialized. We need to make sure that the + # tracked variable names and the number of chains match, and then resize + # the zarr groups to the desired number of draws and tune. + backend.assert_model_and_step_are_compatible( + step=step, + model=model, + vars=trace_vars, + ) + assert backend.posterior.chain.size == chains, ( + f"The requested number of chains {chains} does not match the number " + f"of chains stored in the trace ({backend.posterior.chain.size})." + ) + vars, var_names = backend.parse_varnames(model=model, vars=trace_vars) + backend.link_model_and_step( + chains=chains, + draws=expected_length - tune, + tune=tune, + step=step, + model=model, + vars=vars, + var_names=var_names, + test_point=initial_point, + ) + backend.resize(tune=tune, draws=expected_length - tune) return None, backend.straces if HAS_MCB and isinstance(backend, Backend): return init_chain_adapters( diff --git a/pymc/backends/zarr.py b/pymc/backends/zarr.py index 9b7664c504..f64336e11a 100644 --- a/pymc/backends/zarr.py +++ b/pymc/backends/zarr.py @@ -11,8 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import itertools +import re + from collections.abc import Callable, Mapping, MutableMapping, Sequence -from typing import Any +from typing import Any, cast import arviz as az import numpy as np @@ -35,7 +38,9 @@ from pymc.step_methods.compound import ( BlockedStep, CompoundStep, + CompoundStepState, StatsBijection, + StepMethodState, get_stats_dtypes_shapes_from_steps, ) from pymc.util import UNSET, _UnsetType, get_default_varnames, is_transformed_name @@ -61,6 +66,9 @@ _zarr_available = False +class TraceAlreadyInitialized(RuntimeError): ... + + class ZarrChain(BaseTrace): """Interface object to interact with a single chain in a :class:`~.ZarrTrace`. 
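For orientation, a rough sketch of the workflow this fallback enables; the store path is illustrative and `ZarrTrace.from_store` is introduced later in this diff:

```python
import pymc as pm

from pymc.backends.zarr import ZarrTrace

with pm.Model() as model:
    x = pm.Normal("x")

    # First run persists draws, per-chain sampling state, and the last
    # MCMC point to the zarr store.
    pm.sample(draws=500, tune=500, trace=ZarrTrace(store="trace.zarr"))

    # A later call with the same model and step method reopens the store.
    # init_traces catches TraceAlreadyInitialized, verifies compatibility,
    # resizes the groups, and each chain resumes from its stored draw_idx.
    pm.sample(draws=1000, tune=500, trace=ZarrTrace.from_store("trace.zarr"))
```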
@@ -101,30 +109,44 @@ def __init__( test_point: dict[str, np.ndarray] | None = None, draws_per_chunk: int = 1, fn: Callable | None = None, + include_transformed: bool = True, ): if not _zarr_available: raise RuntimeError("You must install zarr to be able to create ZarrChain instances") super().__init__(name="zarr", model=model, vars=vars, test_point=test_point, fn=fn) + self.include_transformed = include_transformed self._step_method: BlockedStep | CompoundStep | None = None self.unconstrained_variables = { var.name for var in self.vars if is_transformed_name(var.name) - } + } | {inp.name for inp in self.fn.maker.inputs if is_transformed_name(inp.name)} self.draw_idx = 0 self._buffers: dict[str, dict[str, list]] = { "posterior": {}, "sample_stats": {}, + "warmup_posterior": {}, + "warmup_sample_stats": {}, } self._buffered_draws = 0 self.draws_per_chunk = int(draws_per_chunk) assert self.draws_per_chunk > 0 + self._warmup_posterior = zarr.open_group( + store, synchronizer=synchronizer, path="warmup_posterior", mode="a" + ) self._posterior = zarr.open_group( store, synchronizer=synchronizer, path="posterior", mode="a" ) - if self.unconstrained_variables: + if self.unconstrained_variables and include_transformed: + self._warmup_unconstrained_posterior = zarr.open_group( + store, synchronizer=synchronizer, path="warmup_unconstrained_posterior", mode="a" + ) self._unconstrained_posterior = zarr.open_group( store, synchronizer=synchronizer, path="unconstrained_posterior", mode="a" ) self._buffers["unconstrained_posterior"] = {} + self._buffers["warmup_unconstrained_posterior"] = {} + self._warmup_sample_stats = zarr.open_group( + store, synchronizer=synchronizer, path="warmup_sample_stats", mode="a" + ) self._sample_stats = zarr.open_group( store, synchronizer=synchronizer, path="sample_stats", mode="a" ) @@ -141,10 +163,12 @@ def link_stepper(self, step_method: BlockedStep | CompoundStep): """ self._step_method = step_method - def setup(self, draws: int, chain: int, sampler_vars: Sequence[dict] | None): # type: ignore[override] + def setup(self, draws: int, chain: int, sampler_vars: Sequence[dict] | None, tune: int = 0): # type: ignore[override] self.chain = chain - self.total_draws = draws - self.draws_until_flush = min([self.draws_per_chunk, draws - self.draw_idx]) + self.draws = draws + self.tune = tune + self.draw_idx = self._sampling_state.draw_idx[chain] or 0 + self.update_draws_until_flush() self.clear_buffers() def clear_buffers(self): @@ -153,7 +177,8 @@ def clear_buffers(self): self._buffered_draws = 0 def buffer(self, group, var_name, value): - buffer = self._buffers[group] + group_name = f"warmup_{group}" if self.in_warmup else group + buffer = self._buffers[group_name] if var_name not in buffer: buffer[var_name] = [] buffer[var_name].append(value) @@ -178,16 +203,22 @@ def record( :meth:`~ZarrChain.flush` """ unconstrained_variables = self.unconstrained_variables + include_transformed = self.include_transformed for var_name, var_value in zip(self.varnames, self.fn(**draw)): if var_name in unconstrained_variables: - self.buffer(group="unconstrained_posterior", var_name=var_name, value=var_value) + if include_transformed: + self.buffer( + group="unconstrained_posterior", + var_name=var_name, + value=var_value, + ) else: self.buffer(group="posterior", var_name=var_name, value=var_value) for var_name, var_value in self.stats_bijection.map(stats).items(): self.buffer(group="sample_stats", var_name=var_name, value=var_value) self._buffered_draws += 1 if self._buffered_draws == 
self.draws_until_flush: - self.flush() + self.flush(draw) return True return None @@ -212,12 +243,12 @@ def record_sampling_state(self, step: BlockedStep | CompoundStep | None = None): self.store_sampling_state(step.sampling_state) self._sampling_state.draw_idx.set_coordinate_selection(self.chain, self.draw_idx) - def store_sampling_state(self, sampling_state): + def store_sampling_state(self, sampling_state: StepMethodState | CompoundStepState): self._sampling_state.sampling_state.set_coordinate_selection( self.chain, np.array([sampling_state], dtype="object") ) - def flush(self): + def flush(self, mcmc_point: Mapping[str, np.ndarray] | None = None): """Write the data stored in the internal buffer to the desired zarr store. After writing the draws and stats returned by each step of the step method, @@ -225,7 +256,8 @@ def flush(self): the number of steps until the next flush is determined. """ chain = self.chain - draw_slice = slice(self.draw_idx, self.draw_idx + self.draws_until_flush) + offset = 0 if self.in_warmup else self.tune + draw_slice = slice(self.draw_idx - offset, self.draw_idx + self.draws_until_flush - offset) for group_name, buffer in self._buffers.items(): group = getattr(self, f"_{group_name}") for var_name, var_value in buffer.items(): @@ -236,7 +268,67 @@ def flush(self): self.draw_idx += self.draws_until_flush self.record_sampling_state() self.clear_buffers() - self.draws_until_flush = min([self.draws_per_chunk, self.total_draws - self.draw_idx]) + self.update_draws_until_flush() + if mcmc_point is not None: + self.set_mcmc_point(mcmc_point) + + def update_draws_until_flush(self): + self.in_warmup = self.draw_idx < self.tune + self.draws_until_flush = min( + [ + self.draws_per_chunk, + self.tune - self.draw_idx + if self.in_warmup + else self.draws + self.tune - self.draw_idx, + ] + ) + + def completed_draws_and_divergences(self, chain_specific: bool = False) -> tuple[int, int]: + """Get number of completed draws and divergences in the traces. + + This is a helper function to start the ProgressBarManager when resuming sampling + from an existing trace. + + Parameters + ---------- + chain_specific : bool + If ``True``, only the completed draws and divergences on the current chain + are returned. If ``False``, the draws and divergences across all chains are + returned + + Returns + ------- + draws : int + Number of draws in the current chain or across all chains. + divergences : int + Number of divergences in the current chain or across all chains. + """ + # No need to iterate over ZarrChain instances because the zarr group is + # shared between them + idx: int | slice + if chain_specific: + idx = self.chain + else: + idx = slice(None) + diverging_stat_sums = [ + np.sum(array[idx]) + for stat_name, array in self._sample_stats.arrays() + if "diverging" in stat_name + ] + return int(np.sum(self._sampling_state.draw_idx[idx])), int(sum(diverging_stat_sums)) + + def set_mcmc_point(self, mcmc_point: Mapping[str, np.ndarray]): + for var_name, value in mcmc_point.items(): + self._sampling_state.mcmc_point[var_name].set_basic_selection( + self.chain, + value, + ) + + def get_mcmc_point(self) -> dict[str, np.ndarray]: + return { + str(var_name): np.asarray(array[self.chain]) + for var_name, array in self._sampling_state.mcmc_point.arrays() + } FILL_VALUE_TYPE = float | int | bool | str | np.datetime64 | np.timedelta64 | None @@ -286,10 +378,9 @@ class ZarrTrace: | |--> _sampling_state The root group is created when the ``ZarrTrace`` object is initialized. 
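For reference, the hierarchy sketched in this docstring now looks roughly like the following after `init_trace` (a sketch; shapes assume chains=4, tune=1000, draws=2000, and the `unconstrained` groups appear only when `include_transformed=True`):

```python
# root
# |--> posterior                        (chain, draw, *core_dims) -> (4, 2000, ...)
# |--> warmup_posterior                 (chain, tune, *core_dims) -> (4, 1000, ...)
# |--> sample_stats                     one array per sampler stat
# |--> warmup_sample_stats
# |--> unconstrained_posterior          only if include_transformed=True
# |--> warmup_unconstrained_posterior   only if include_transformed=True
# |--> constant_data / observed_data    only if the model defines them
# |--> _sampling_state                  draw_idx, draws, tuning_steps,
#                                       sampling_state, mcmc_point, ...
```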
The rest of - the groups are created once :meth:`~ZarrChain.init_trace` is called with a few exceptions: - unconstrained_posterior is only created if ``include_transformed = True``, and the - groups prefixed with ``warmup_`` are created only after calling - :meth:`~ZarrTrace.split_warmup_groups`. + the groups are created once :meth:`~ZarrTrace.init_trace` is called with the exception + that the unconstrained_posterior and warmup_unconstrained_posterior groups are + only created if ``include_transformed = True``. Since ``ZarrTrace`` objects are intended to be as close to :class:`arviz.InferenceData` objects as possible, the groups store the dimension @@ -369,7 +460,25 @@ def __init__( self.include_transformed = include_transformed - self._is_base_setup = False + @property + def is_root_populated(self) -> bool: + groups = set(self.root.group_keys()) + out = groups >= { + "posterior", + "sample_stats", + "warmup_posterior", + "warmup_sample_stats", + "_sampling_state", + } + if self.include_transformed and any( + is_transformed_name(name) for name in getattr(self, "varnames", []) + ): + out &= groups >= {"unconstrained_posterior", "warmup_unconstrained_posterior"} + return out + + @property + def _is_base_setup(self) -> bool: + return self.is_root_populated and len(getattr(self, "straces", [])) > 0 def groups(self) -> list[str]: return [str(group_name) for group_name, _ in self.root.groups()] @@ -378,14 +487,26 @@ def groups(self) -> list[str]: def posterior(self) -> Group: return self.root.posterior + @property + def warmup_posterior(self) -> Group: + return self.root.warmup_posterior + @property def unconstrained_posterior(self) -> Group: return self.root.unconstrained_posterior + @property + def warmup_unconstrained_posterior(self) -> Group: + return self.root.warmup_unconstrained_posterior + @property def sample_stats(self) -> Group: return self.root.sample_stats + @property + def warmup_sample_stats(self) -> Group: + return self.root.warmup_sample_stats + @property def constant_data(self) -> Group: return self.root.constant_data @@ -398,6 +519,22 @@ def observed_data(self) -> Group: def _sampling_state(self) -> Group: return self.root._sampling_state + def parse_varnames( + self, + model: Model | None = None, + vars: Sequence[TensorVariable] | None = None, + ) -> tuple[list[TensorVariable], list[str]]: + if vars is None: + vars = modelcontext(model).unobserved_value_vars + + unnamed_vars = {var for var in vars if var.name is None} + assert not unnamed_vars, f"Can't trace unnamed variables: {unnamed_vars}" + var_names = get_default_varnames( + [var.name for var in vars], include_transformed=self.include_transformed + ) + vars = [var for var in vars if var.name in var_names] + return vars, var_names + def init_trace( self, chains: int, @@ -411,16 +548,17 @@ """Initialize the trace groups and arrays. This function creates and fills with default values the groups below the - ``ZarrTrace.root`` group. It creates the ``constant_data``, ``observed_data``, + ``ZarrTrace.root`` group. It creates the ``constant_data`` (only if the model + has ``Data`` containers), ``observed_data`` (only if the model has observed variables), ``posterior``, ``unconstrained_posterior`` (if ``include_transformed = True``), ``sample_stats``, and ``_sampling_state`` zarr groups, and all of the relevant arrays that must be stored there.
Every array in the posterior and sample stats groups will have the - (chains, tune + draws) batch dimensions to the left of the core dimensions of + (chains, draws) batch dimensions to the left of the core dimensions of the model's random variable or the step method's stat shape. The warmup (tuning - draws) and the posterior samples are split at a later stage, once - :meth:`~ZarrTrace.split_warmup_groups` is called. + draws) posterior and sample stats will have (chains, tune) batch dimensions + instead. After the creation of the zarr hierarchies, it initializes the list of :class:`~pymc.backends.zarr.ZarrChain` instances (one for each chain) under the @@ -447,23 +585,14 @@ from :class:`~.BaseTrace`, which uses it to determine the shape and dtype of `vars`. """ - if self._is_base_setup: - raise RuntimeError("The ZarrTrace has already been initialized") # pragma: no cover + if self.is_root_populated: + raise TraceAlreadyInitialized("The ZarrTrace has already been initialized") model = modelcontext(model) - self.model = model self.coords, self.vars_to_dims = coords_and_dims_for_inferencedata(model) - if vars is None: - vars = model.unobserved_value_vars + vars, varnames = self.parse_varnames(model, vars) - unnamed_vars = {var for var in vars if var.name is None} - assert not unnamed_vars, f"Can't trace unnamed variables: {unnamed_vars}" - self.varnames = get_default_varnames( - [var.name for var in vars], include_transformed=self.include_transformed - ) - self.vars = [var for var in vars if var.name in self.varnames] - - self.fn = model.compile_fn( - self.vars, + fn = model.compile_fn( + vars, inputs=model.value_vars, on_unused_input="ignore", point_fn=False, @@ -473,7 +602,7 @@ # information. if test_point is None: test_point = model.initial_point() - var_values = list(zip(self.varnames, self.fn(**test_point))) + var_values = list(zip(varnames, fn(**test_point))) self.var_dtype_shapes = { var: (value.dtype, value.shape) for var, value in var_values @@ -494,34 +623,48 @@ self.create_group( name="constant_data", - data_dict=find_constants(self.model), + data_dict=find_constants(model), ) self.create_group( name="observed_data", - data_dict=find_observations(self.model), + data_dict=find_observations(model), ) - # Create the posterior that includes warmup draws + # Create the posterior and warmup posterior groups self.init_group_with_empty( group=self.root.create_group(name="posterior", overwrite=True), var_dtype_and_shape=self.var_dtype_shapes, chains=chains, - draws=tune + draws, + draws=draws, + extra_var_attrs=extra_var_attrs, + ) + self.init_group_with_empty( + group=self.root.create_group(name="warmup_posterior", overwrite=True), + var_dtype_and_shape=self.var_dtype_shapes, + chains=chains, + draws=tune, extra_var_attrs=extra_var_attrs, ) - # Create the unconstrained posterior group that includes warmup draws + # Create the unconstrained posterior and warmup groups if self.include_transformed and self.unc_var_dtype_shapes: self.init_group_with_empty( group=self.root.create_group(name="unconstrained_posterior", overwrite=True), var_dtype_and_shape=self.unc_var_dtype_shapes, chains=chains, - draws=tune + draws, + draws=draws, + extra_var_attrs=extra_unc_var_attrs, + ) + self.init_group_with_empty( + group=self.root.create_group(name="warmup_unconstrained_posterior", overwrite=True), + var_dtype_and_shape=self.unc_var_dtype_shapes, + chains=chains, + draws=tune, extra_var_attrs=extra_unc_var_attrs, ) - # Create the sample stats that include
warmup draws + # Create the sample stats and warmup groups stats_dtypes_shapes = get_stats_dtypes_shapes_from_steps( [step] if isinstance(step, BlockedStep) else step.methods ) @@ -529,11 +672,61 @@ def init_trace( group=self.root.create_group(name="sample_stats", overwrite=True), var_dtype_and_shape=stats_dtypes_shapes, chains=chains, - draws=tune + draws, + draws=draws, + ) + self.init_group_with_empty( + group=self.root.create_group(name="warmup_sample_stats", overwrite=True), + var_dtype_and_shape=stats_dtypes_shapes, + chains=chains, + draws=tune, ) - self.init_sampling_state_group(tune=tune, chains=chains) + self.init_sampling_state_group( + tune=tune, + draws=draws, + chains=chains, + mcmc_point=test_point, + ) + self.link_model_and_step( + chains=chains, + draws=draws, + tune=tune, + step=step, + model=model, + vars=vars, + var_names=varnames, + test_point=test_point, + fn=fn, + ) + def link_model_and_step( + self, + chains: int, + draws: int, + tune: int, + step: BlockedStep | CompoundStep, + vars: Sequence[TensorVariable], + var_names: Sequence[str], + model: Model | None = None, + test_point: dict[str, np.ndarray] | None = None, + fn: Callable | None = None, + ): + model = modelcontext(model) + self.model = model + self.varnames = var_names + self.vars = vars + if fn is None: + self.fn = cast( + Callable, + model.compile_fn( + self.vars, + inputs=model.value_vars, + on_unused_input="ignore", + point_fn=False, + ), + ) + else: + self.fn = fn self.straces = [ ZarrChain( store=self.root.store, @@ -544,31 +737,12 @@ def init_trace( stats_bijection=StatsBijection(step.stats_dtypes), draws_per_chunk=self.draws_per_chunk, fn=self.fn, + include_transformed=self.include_transformed, ) for _ in range(chains) ] for chain, strace in enumerate(self.straces): - strace.setup(draws=tune + draws, chain=chain, sampler_vars=None) - - def split_warmup_groups(self): - """Split the warmup and standard groups. - - This method takes the entries in the arrays in the posterior, sample_stats - and unconstrained_posterior that happened in the tuning phase and moves them - into the warmup_ groups. If the ``warmup_posterior`` group already exists, then - nothing is done. 
- - See Also - -------- - :meth:`~ZarrTrace.split_warmup` - """ - if "warmup_posterior" not in self.groups(): - self.split_warmup("posterior", error_if_already_split=False) - self.split_warmup("sample_stats", error_if_already_split=False) - try: - self.split_warmup("unconstrained_posterior", error_if_already_split=False) - except KeyError: - pass + strace.setup(draws=draws, tune=tune, chain=chain, sampler_vars=None) @property def tuning_steps(self): @@ -579,6 +753,15 @@ def tuning_steps(self): "ZarrTrace has not been initialized and there is no tuning step information available" ) + @property + def draws(self): + try: + return int(self._sampling_state.draws.get_basic_selection()) + except AttributeError: # pragma: no cover + raise ValueError( + "ZarrTrace has not been initialized and there is no draw information available" + ) + @property def sampling_time(self): try: @@ -592,7 +775,9 @@ def sampling_time(self): def sampling_time(self, value): self._sampling_state.sampling_time.set_basic_selection((), float(value)) - def init_sampling_state_group(self, tune: int, chains: int): + def init_sampling_state_group( + self, tune: int, draws: int, chains: int, mcmc_point: dict[str, np.ndarray] + ): state = self.root.create_group(name="_sampling_state", overwrite=True) sampling_state = state.empty( name="sampling_state", @@ -623,6 +808,14 @@ def init_sampling_state_group(self, tune: int, chains: int): fill_value=0, compressor=self.compressor, ) + state.array( + name="draws", + data=draws, + overwrite=True, + dtype="int", + fill_value=0, + compressor=self.compressor, + ) state.array( name="sampling_time", data=0.0, @@ -653,6 +846,19 @@ def init_sampling_state_group(self, tune: int, chains: int): shape=(0,), ) + zarr_mcmc_point = state.create_group("mcmc_point", overwrite=True) + for var_name, test_value in mcmc_point.items(): + fill_value, dtype, object_codec = get_initial_fill_value_and_codec(test_value.dtype) + zarr_mcmc_point.full( + name=var_name, + dtype=dtype, + fill_value=fill_value, + object_codec=object_codec, + shape=(chains, *test_value.shape), + chunks=(1, *test_value.shape), + compressor=self.compressor, + ) + def init_group_with_empty( self, group: Group, @@ -680,10 +886,15 @@ def init_group_with_empty( group_coords[dim] = self.coords[dim] except KeyError: dims = [] + if len(shape) > 0: + self.vars_to_dims[name] = [] for i, shape_i in enumerate(shape): dim = f"{name}_dim_{i}" + coord = np.arange(shape_i, dtype="int") dims.append(dim) - group_coords[dim] = np.arange(shape_i, dtype="int") + self.vars_to_dims[name].append(dim) + group_coords[dim] = coord + self.coords[dim] = coord dims = ("chain", "draw", *dims) attrs = extra_var_attrs[name] if extra_var_attrs is not None else {} attrs.update({"_ARRAY_DIMENSIONS": dims}) @@ -719,10 +930,15 @@ def create_group(self, name: str, data_dict: dict[str, np.ndarray]) -> Group | N group_coords[dim] = self.coords[dim] except KeyError: dims = [] + if var_value.ndim > 0: + self.vars_to_dims[var_name] = [] for i in range(var_value.ndim): dim = f"{var_name}_dim_{i}" + coord = np.arange(var_value.shape[i], dtype="int") dims.append(dim) - group_coords[dim] = np.arange(var_value.shape[i], dtype="int") + self.vars_to_dims[var_name].append(dim) + group_coords[dim] = coord + self.coords[dim] = coord array.attrs.update({"_ARRAY_DIMENSIONS": dims}) for dim, coord in group_coords.items(): array = group.array( @@ -734,86 +950,103 @@ def create_group(self, name: str, data_dict: dict[str, np.ndarray]) -> Group | N array.attrs.update({"_ARRAY_DIMENSIONS": [dim]}) 
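A small sketch of the `mcmc_point` round trip that the `_sampling_state` group now supports (`trace` is assumed to be an initialized `ZarrTrace`, and the variable name is hypothetical):

```python
import numpy as np

chain0 = trace.straces[0]  # a ZarrChain

# Each point variable lives in a (chains, *shape) array under
# _sampling_state/mcmc_point; chain i reads and writes only row i.
point = {"x_interval__": np.asarray(0.25)}
chain0.set_mcmc_point(point)
restored = chain0.get_mcmc_point()
assert np.array_equal(restored["x_interval__"], point["x_interval__"])
```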
return group - def split_warmup(self, group_name: str, error_if_already_split: bool = True): - """Split the arrays of a group into the warmup and regular groups. - - This function takes the first ``self.tuning_steps`` draws of supplied - ``group_name`` and moves them into a new zarr group called - ``f"warmup_{group_name}"``. - - Parameters - ---------- - group_name : str - The name of the group that should be split. - error_if_already_split : bool - If ``True`` and if the ``f"warmup_{group_name}"`` group already exists in - the root hierarchy, a ``ValueError`` is raised. If this flag is ``False`` - but the warmup group already exists, the contents of that group are - overwritten. - """ - if error_if_already_split and f"{WARMUP_TAG}{group_name}" in { - group_name for group_name, _ in self.root.groups() - }: - raise RuntimeError(f"Warmup data for {group_name} has already been split") - posterior_group = self.root[group_name] - tune = self.tuning_steps - warmup_group = self.root.create_group(f"{WARMUP_TAG}{group_name}", overwrite=True) - if tune == 0: - try: - self.root.pop(f"{WARMUP_TAG}{group_name}") - except KeyError: - pass - return - for name, array in posterior_group.arrays(): - array_attrs = array.attrs.asdict() - if name == "draw": - warmup_array = warmup_group.array( - name="draw", - data=np.arange(tune), - dtype="int", - compressor=self.compressor, + def resize( + self, + tune: int | None = None, + draws: int | None = None, + ) -> "ZarrTrace": + if not self.is_root_populated: + raise RuntimeError( + "The ZarrTrace has not been initialized yet. You must call resize on " + "an instance that has already been initialized." + ) + old_tuning = self.tuning_steps + old_draws = self.draws + desired_tune = tune or old_tuning + desired_draws = draws or old_draws + draws_in_chains = self._sampling_state.draw_idx[:] + + # For us to be able to resize, a few conditions must be met: + # 1. If we want to change the number of tuning steps, the draws_in_chains must + # not be bigger than the old tune, and it must not be bigger than the desired + # tune. If the first condition weren't true, the sampler would have already + # stopped tuning, and it would be wrong to relabel some samples to belong to + # the tuning phase. If the second condition weren't true, the sampler would + # keep tuning past the desired number of tuning steps. + # 2. If we want to change the number of posterior draws, the draws_in_chains + # minus the old number of tuning steps must be less than or equal to the + # desired number of draws. If this condition is not met, the sampler will have + # taken extra steps and we won't have stored the sampling state information at + # the end of the desired number of draws. + change_tune = False + change_draws = False + if old_tuning != desired_tune: + # Attempting to change the number of tuning steps + if any(draws_in_chains > old_tuning): + raise ValueError( + "Cannot change the number of tuning steps in the trace. " + "Some chains have finished their tuning phase and have " + "already performed steps in the posterior sampling regime." ) - posterior_array = posterior_group.array( - name=name, - data=np.arange(len(array) - tune), - dtype="int", - overwrite=True, - compressor=self.compressor, + elif any(draws_in_chains >= desired_tune): + raise ValueError( + "Cannot change the number of tuning steps in the trace. " + "Some chains have already taken more steps than the desired number " + "of tuning steps.
Please increase the desired number of tuning " + f"steps to at least {max(draws_in_chains)}." ) - posterior_array.attrs.update(array_attrs) - else: - dims = array.attrs["_ARRAY_DIMENSIONS"] - warmup_idx: slice | tuple[slice, slice] - if len(dims) >= 2 and dims[:2] == ["chain", "draw"]: - must_overwrite_posterior = True - warmup_idx = (slice(None), slice(None, tune, None)) - posterior_idx = (slice(None), slice(tune, None, None)) - else: - must_overwrite_posterior = False - warmup_idx = slice(None) - fill_value, dtype, object_codec = get_initial_fill_value_and_codec(array.dtype) - warmup_array = warmup_group.array( - name=name, - data=array[warmup_idx], - chunks=array.chunks, - dtype=dtype, - fill_value=fill_value, - object_codec=object_codec, - compressor=self.compressor, + change_tune = True + if old_draws != desired_draws: + # Attempting to change the number of draws + if any((draws_in_chains - old_tuning) > desired_draws): + raise ValueError( + "Cannot change the number of draws in the trace. " + "Some chains have already taken more steps than the desired number " + "of draws. Please increase the desired number of draws " + f"to at least {max(draws_in_chains) - old_tuning}." ) - if must_overwrite_posterior: - posterior_array = posterior_group.array( - name=name, - data=array[posterior_idx], - chunks=array.chunks, - dtype=dtype, - fill_value=fill_value, - object_codec=object_codec, - overwrite=True, - compressor=self.compressor, - ) - posterior_array.attrs.update(array_attrs) - warmup_array.attrs.update(array_attrs) + change_draws = True + if change_tune: + self._resize_tuning_steps(desired_tune) + if change_draws: + self._resize_draws(desired_draws) + return self + + def _resize_tuning_steps(self, desired_tune: int): + groups = ["warmup_posterior", "warmup_sample_stats"] + if "warmup_unconstrained_posterior" in dict(self.root.groups()): + groups.append("warmup_unconstrained_posterior") + for group in groups: + self._resize_arrays_in_group(group=group, axis=1, new_size=desired_tune) + zarr_draw = getattr(self.root, group).draw + zarr_draw.resize(desired_tune) + zarr_draw.set_basic_selection( + slice(None), np.arange(desired_tune, dtype=zarr_draw.dtype) + ) + self._sampling_state.tuning_steps.set_basic_selection((), desired_tune) + + def _resize_draws(self, desired_draws: int): + groups = ["posterior", "sample_stats"] + if "unconstrained_posterior" in dict(self.root.groups()): + groups.append("unconstrained_posterior") + for group in groups: + self._resize_arrays_in_group(group=group, axis=1, new_size=desired_draws) + zarr_draw = getattr(self.root, group).draw + zarr_draw.resize(desired_draws) + zarr_draw.set_basic_selection( + slice(None), np.arange(desired_draws, dtype=zarr_draw.dtype) + ) + self._sampling_state.draws.set_basic_selection((), desired_draws) + + def _resize_arrays_in_group(self, group: str, axis: int, new_size: int): + zarr_group: Group = getattr(self.root, group) + for _, array in zarr_group.arrays(): + dims = array.attrs.get("_ARRAY_DIMENSIONS", []) + if len(dims) >= 2 and dims[1] == "draw": + new_shape = list(array.shape) + new_shape[axis] = new_size + array.resize(new_shape) def to_inferencedata(self, save_warmup: bool = False) -> az.InferenceData: """Convert ``ZarrTrace`` to :class:`~.arviz.InferenceData`. @@ -837,7 +1070,6 @@ def to_inferencedata(self, save_warmup: bool = False) -> az.InferenceData: than the calling ``ZarrTrace``, so future changes to the ``ZarrTrace`` won't be automatically reflected in the returned ``InferenceData`` object. 
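A usage sketch of the resize rules implemented above, assuming every chain has taken 500 steps of a 1000-step tuning phase and no posterior draws yet (`ZarrTrace.from_store` is introduced later in this diff; the store path is illustrative):

```python
trace = ZarrTrace.from_store("trace.zarr")

trace.resize(tune=1500)   # OK: no chain has finished tuning yet
trace.resize(draws=5000)  # OK: no post-warmup steps have been taken
trace.resize(tune=100)    # ValueError: chains already took 500 tuning steps,
                          # so tune can only shrink to 501 at the minimum
```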
""" - self.split_warmup_groups() # Xarray complains if we try to open a zarr hierarchy that doesn't have consolidated metadata consolidated_root = zarr.consolidate_metadata(self.root.store) # The ConsolidatedMetadataStore looks like an empty store from xarray's point of view @@ -860,3 +1092,195 @@ def to_inferencedata(self, save_warmup: bool = False) -> az.InferenceData: data.attrs = make_attrs(attrs=attrs, library=pymc) groups[name] = data.load() if az.rcParams["data.load"] == "eager" else data return az.InferenceData(**groups) + + @classmethod + def from_store( + cls: type["ZarrTrace"], + store: BaseStore | MutableMapping, + synchronizer: Synchronizer | None = None, + ) -> "ZarrTrace": + if not _zarr_available: + raise RuntimeError("You must install zarr to be able to create ZarrTrace instances") + self: ZarrTrace = object.__new__(cls) + self.root = zarr.group( + store=store, + overwrite=False, + synchronizer=synchronizer, + ) + self.synchronizer = synchronizer + self.compressor = default_compressor + + groups = set(self.root.group_keys()) + assert groups >= { + "posterior", + "sample_stats", + "warmup_posterior", + "warmup_sample_stats", + "constant_data", + "observed_data", + "_sampling_state", + } + + if "posterior" in groups: + for _, array in self.posterior.arrays(): + dims = array.attrs.get("_ARRAY_DIMENSIONS", []) + if len(dims) >= 2 and dims[1] == "draw": + draws_per_chunk = int(array.chunks[1]) + break + else: + draws_per_chunk = 1 + + self.draws_per_chunk = int(draws_per_chunk) + assert self.draws_per_chunk >= 1 + + self.include_transformed = "unconstrained_posterior" in groups + arrays = itertools.chain( + self.posterior.arrays(), + self.constant_data.arrays(), + self.observed_data.arrays(), + ) + if self.include_transformed: + arrays = itertools.chain(arrays, self.unconstrained_posterior.arrays()) + varnames = [] + coords = {} + vars_to_dims = {} + for name, array in arrays: + dims = array.attrs["_ARRAY_DIMENSIONS"] + if dims[:2] == ["chain", "draw"]: + # Random Variable + vars_to_dims[name] = dims[2:] + varnames.append(name) + elif len(dims) == 1 and name == dims[0]: + # Coordinate + # We store all model coordinates, which means we have to exclude chain + # and draw + if name not in ["chain", "draw"]: + coords[name] = np.asarray(array) + else: + # Constant data or observation + vars_to_dims[name] = dims + self.varnames = varnames + self.coords = coords + self.vars_to_dims = vars_to_dims + return self + + def assert_model_and_step_are_compatible( + self, + step: BlockedStep | CompoundStep, + model: Model, + vars: list[TensorVariable] | None = None, + ): + zarr_groups = set(self.root.group_keys()) + arrays_ = itertools.chain( + self.posterior.arrays(), + self.constant_data.arrays() if "constant_data" in zarr_groups else [], + self.observed_data.arrays() if "observed_data" in zarr_groups else [], + ) + if self.include_transformed: + arrays_ = itertools.chain(arrays_, self.unconstrained_posterior.arrays()) + arrays = list(arrays_) + zarr_varnames = [] + zarr_coords = {} + zarr_vars_to_dims = {} + zarr_deterministics = [] + zarr_free_vars = [] + for name, array in arrays: + dims = array.attrs["_ARRAY_DIMENSIONS"] + if dims[:2] == ["chain", "draw"]: + # Random Variable + zarr_vars_to_dims[name] = dims[2:] + zarr_varnames.append(name) + if array.attrs["kind"] == "freeRV": + zarr_free_vars.append(name) + else: + zarr_deterministics.append(name) + elif len(dims) == 1 and name == dims[0]: + # Coordinate + if name not in ["chain", "draw"]: + zarr_coords[name] = np.asarray(array) + 
else: + # Constant data or observation + zarr_vars_to_dims[name] = dims + zarr_constant_data = ( + [name for name in self.constant_data.array_keys() if name not in zarr_coords] + if "constant_data" in zarr_groups + else [] + ) + zarr_observed_data = ( + [name for name in self.observed_data.array_keys() if name not in zarr_coords] + if "observed_data" in zarr_groups + else [] + ) + autogenerated_dims = {dim for dim in zarr_coords if re.search(r"_dim_\d+$", dim)} + + # Check deterministics, free RVs and transformed RVs + _, var_names = self.parse_varnames(model, vars) + assert set(var_names) == set(zarr_free_vars + zarr_deterministics), ( + "The model deterministics and random variables given the sampled var_names " + "do not match the stored deterministic variables in the trace." + ) + for name, array in arrays: + if name not in zarr_free_vars and name not in zarr_deterministics: + continue + model_var = model[name] + assert np.dtype(model_var.dtype) == np.dtype(array.dtype), ( + "The model deterministics and random variables given the sampled " + "var_names do not match the stored deterministic variables in " + "the trace." + ) + + # Check coordinates + assert (set(zarr_coords) - set(autogenerated_dims)) == set(model.coords) and all( + np.array_equal(np.asarray(zarr_coords[dim]), np.asarray(coord)) + for dim, coord in model.coords.items() + ), "Model coordinates don't match the coordinates stored in the trace" + vars_to_explicit_dims = {} + for name, dims in zarr_vars_to_dims.items(): + if len(dims) == 0 or all(dim in autogenerated_dims for dim in dims): + # These variables won't be included in the named_vars_to_dims + continue + vars_to_explicit_dims[name] = [ + dim if dim not in autogenerated_dims else None for dim in dims + ] + assert set(vars_to_explicit_dims) == set(model.named_vars_to_dims) and all( + vars_to_explicit_dims[name] == list(dims) + for name, dims in model.named_vars_to_dims.items() + ), "Some model variables have different dimensions than those stored in the trace." + + # Check constant data + model_constant_data = find_constants(model) + assert set(zarr_constant_data) == set(model_constant_data), ( + "The model constant data does not match the stored constant data" + ) + for name, model_data in model_constant_data.items(): + assert np.array_equal(self.constant_data[name], model_data, equal_nan=True), ( + "The model constant data does not match the stored constant data" + ) + + # Check observed data + model_observed_data = find_observations(model) + assert set(zarr_observed_data) == set(model_observed_data), ( + "The model observed data does not match the stored observed data" + ) + for name, model_data in model_observed_data.items(): + assert np.array_equal(self.observed_data[name], model_data, equal_nan=True), ( + "The model observed data does not match the stored observed data" + ) + + # Check sample stats given the step method + stats_dtypes_shapes = get_stats_dtypes_shapes_from_steps( + [step] if isinstance(step, BlockedStep) else step.methods + ) + assert (set(stats_dtypes_shapes) | {"chain", "draw"}) == set( + self.sample_stats.array_keys() + ), "The step method sample stats do not match the ones stored in the trace." + for name, array in self.sample_stats.arrays(): + if name in ("chain", "draw"): + continue + assert np.dtype(stats_dtypes_shapes[name][0]) == np.dtype(array.dtype), ( + "The step method sample stats do not match the ones stored in the trace."
+ ) + + assert step.sampling_state.is_compatible(self._sampling_state.sampling_state[0]), ( + "The step method's sampling state class is incompatible with what's stored in the trace." + ) diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py index f2dfa6e9c2..9457273944 100644 --- a/pymc/sampling/mcmc.py +++ b/pymc/sampling/mcmc.py @@ -993,11 +993,8 @@ Final step of `pm.sample`. """ if isinstance(traces, ZarrTrace): - # Split warmup from posterior samples - traces.split_warmup_groups() - # Set sampling time - traces.sampling_time = t_sampling + traces.sampling_time = traces.sampling_time + t_sampling # Compute number of actual draws per chain total_draws_per_chain = traces._sampling_state.draw_idx[:] @@ -1160,13 +1157,24 @@ with progress_manager: for i in range(chains): + trace = traces[i] + if isinstance(trace, ZarrChain): + progress_manager.set_initial_state(*trace.completed_draws_and_divergences()) + progress_manager._progress.update( + progress_manager.tasks[i], + draws=progress_manager.completed_draws + if progress_manager.combined_progress + else progress_manager.draws, + divergences=progress_manager.divergences, + refresh=True, + ) step.sampling_state = initial_step_state _sample( draws=draws, chain=i, start=start[i], step=step, - trace=traces[i], + trace=trace, rng=rngs[i], callback=callback, progress_manager=progress_manager, @@ -1226,7 +1234,7 @@ callback=callback, ) try: - for it, stats in enumerate(sampling_gen): + for it, stats in sampling_gen: progress_manager.update( chain_idx=chain, is_last=False, draw=it, stats=stats, tuning=it > tune ) @@ -1251,7 +1259,7 @@ rng: np.random.Generator, model: Model | None = None, callback: SamplingIteratorCallback | None = None, -) -> Iterator[list[dict[str, Any]]]: +) -> Iterator[tuple[int, list[dict[str, Any]]]]: """Sample one chain with a generator (single process). Parameters @@ -1285,14 +1293,33 @@ step.set_rng(rng) point = start + initial_draw_idx = 0 + step.tune = bool(tune) + if hasattr(step, "reset_tuning"): + step.reset_tuning() if isinstance(trace, ZarrChain): trace.link_stepper(step) + stored_draw_idx = trace._sampling_state.draw_idx[chain] + stored_sampling_state = trace._sampling_state.sampling_state[chain] + if stored_draw_idx > 0: + if stored_sampling_state is not None: + step.sampling_state = stored_sampling_state + else: + raise RuntimeError( + "Cannot use the supplied ZarrTrace to restart sampling because " + "it has no sampling_state information stored. You will have to " + "resample from scratch."
+ ) + initial_draw_idx = stored_draw_idx + point = trace.get_mcmc_point() + else: + # Store initial point in trace + trace.set_mcmc_point(point) try: - step.tune = bool(tune) - if hasattr(step, "reset_tuning"): - step.reset_tuning() - for i in range(draws): + for i in range(initial_draw_idx, draws): + diverging = False + if i == 0 and hasattr(step, "iter_count"): step.iter_count = 0 if i == tune: @@ -1308,7 +1335,7 @@ def _iter_sample( draw=Draw(chain, i == draws, i, i < tune, stats, point), ) - yield stats + yield i, stats except (KeyboardInterrupt, BaseException): if isinstance(trace, ZarrChain): diff --git a/pymc/sampling/parallel.py b/pymc/sampling/parallel.py index af2106ce6f..d0a1f31287 100644 --- a/pymc/sampling/parallel.py +++ b/pymc/sampling/parallel.py @@ -194,6 +194,24 @@ def _start_loop(self): draw = 0 tuning = True + if self._zarr_recording: + trace = self._zarr_chain + stored_draw_idx = trace._sampling_state.draw_idx[self.chain] + stored_sampling_state = trace._sampling_state.sampling_state[self.chain] + if stored_draw_idx > 0: + if stored_sampling_state is not None: + self._step_method.sampling_state = stored_sampling_state + else: + raise RuntimeError( + "Cannot use the supplied ZarrTrace to restart sampling because " + "it has no sampling_state information stored. You will have to " + "resample from scratch." + ) + draw = stored_draw_idx + self._write_point(trace.get_mcmc_point()) + else: + # Store starting point in trace's mcmc_point + trace.set_mcmc_point(self._point) msg = self._recv_msg() if msg[0] == "abort": @@ -491,6 +509,10 @@ def __init__( progressbar=progressbar, progressbar_theme=progressbar_theme, ) + if self.zarr_recording: + self._progress.set_initial_state( + *cast(ZarrChain, zarr_chains)[0].completed_draws_and_divergences() + ) def _make_active(self): while self._inactive and len(self._active) < self._max_active: diff --git a/pymc/sampling/population.py b/pymc/sampling/population.py index 92de63d0c2..f74cbadb78 100644 --- a/pymc/sampling/population.py +++ b/pymc/sampling/population.py @@ -19,7 +19,7 @@ from collections.abc import Iterator, Sequence from copy import copy -from typing import TypeAlias +from typing import TypeAlias, cast import cloudpickle import numpy as np @@ -37,7 +37,7 @@ PopulationArrayStepShared, StatsType, ) -from pymc.step_methods.compound import StepMethodState +from pymc.step_methods.compound import CompoundStepState, StepMethodState from pymc.step_methods.metropolis import DEMetropolis from pymc.util import CustomProgress @@ -54,7 +54,7 @@ def _sample_population( *, initial_points: Sequence[PointType], draws: int, - start: Sequence[PointType], + start: list[PointType], rngs: Sequence[np.random.Generator], step: BlockedStep | CompoundStep, tune: int, @@ -110,6 +110,9 @@ def _sample_population( with CustomProgress(disable=not progressbar) as progress: task = progress.add_task("[red]Sampling...", total=draws) + if isinstance(traces[0], ZarrChain): + completed_draws, _ = traces[0].completed_draws_and_divergences() + progress.update(task, completed=completed_draws) for _ in sampling: progress.update(task) @@ -151,7 +154,9 @@ def warn_population_size( class PopulationStepper: """Wraps population of step methods to step them in parallel with single or multiprocessing.""" - def __init__(self, steppers, parallelize: bool, progressbar: bool = True): + def __init__( + self, steppers, parallelize: bool, progressbar: bool = True, first_draw_idx: int = 0 + ): """Use multiprocessing to parallelize chains. 
Falls back to sequential evaluation if multiprocessing fails. @@ -195,6 +200,7 @@ def __init__(self, steppers, parallelize: bool, progressbar: bool = True): # enumerate(progress_bar(steppers)) if progressbar else enumerate(steppers) # ): task = self._progress.add_task(description=f"Chain {c}") + self._progress.update(task, completed=first_draw_idx) secondary_end, primary_end = multiprocessing.Pipe() stepper_dumps = cloudpickle.dumps(stepper, protocol=4) process = multiprocessing.Process( @@ -330,7 +336,7 @@ *, draws: int, step, - start: Sequence[PointType], + start: list[PointType], parallelize: bool, traces: Sequence[BaseTrace], tune: int, @@ -376,10 +382,35 @@ raise ValueError("Argument `draws` should be above 0.") # The initialization of traces, samplers and points must happen in the right order: + # 0. previous sampling state is loaded if possible # 1. population of points is created # 2. steppers are initialized and linked to the points object # 3. a PopulationStepper is configured for parallelized stepping + # 0. load sampling state and start point from traces if possible + first_draw_idx = 0 + stored_sampling_states: Sequence[StepMethodState | CompoundStepState] | None = None + can_resume_sampling = False + if isinstance(traces[0], ZarrChain): + # All traces share the same store. This lets us load the past sampling states and draw + # indices for all chains + stored_draw_idxs = traces[0]._sampling_state.draw_idx[:] + stored_sampling_states = cast( + Sequence[StepMethodState | CompoundStepState], + traces[0]._sampling_state.sampling_state[:], + ) + can_resume_sampling = ( + all(stored_draw_idxs > 0) + and all(stored_draw_idxs == stored_draw_idxs[0]) + and all(sampling_state is not None for sampling_state in stored_sampling_states) + ) + if can_resume_sampling: + # Resume every chain from the shared stored draw index + first_draw_idx = int(stored_draw_idxs[0]) + for chain, trace in enumerate(traces): + trace = cast(ZarrChain, trace) + if can_resume_sampling: + start[chain] = trace.get_mcmc_point() + else: + trace.set_mcmc_point(start[chain]) + # 1. create a population (points) that tracks each chain # it is updated as the chains are advanced population = [start[c] for c in range(nchains)] @@ -401,15 +432,25 @@ for sm in chainstep.methods if isinstance(step, CompoundStep) else [chainstep]: if isinstance(sm, PopulationArrayStepShared): sm.link_population(population, c) + if can_resume_sampling: + chainstep.sampling_state = cast(Sequence[CompoundStepState], stored_sampling_states)[c] steppers.append(chainstep) # 3. configure the PopulationStepper (expensive call) - popstep = PopulationStepper(steppers, parallelize, progressbar=progressbar) + popstep = PopulationStepper( + steppers, parallelize, progressbar=progressbar, first_draw_idx=first_draw_idx + ) # Because the preparations above are expensive, the actual iterator is # in another method. This way the progbar will not be disturbed. return _iter_population( - draws=draws, tune=tune, popstep=popstep, steppers=steppers, traces=traces, points=population + draws=draws, + tune=tune, + popstep=popstep, + steppers=steppers, + traces=traces, + points=population, + first_draw_idx=first_draw_idx, ) @@ -421,6 +462,7 @@ def _iter_population( steppers, traces: Sequence[BaseTrace], points, + first_draw_idx=0, ) -> Iterator[int]: """Iterate a ``PopulationStepper``.
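Because population steppers advance all chains in lockstep, the resume gate above is stricter than in the single-chain path; a sketch with illustrative values:

```python
import numpy as np

stored_draw_idxs = np.array([120, 120, 120, 120])  # per-chain progress
stored_states = [object()] * 4                     # stand-ins for step states

can_resume_sampling = (
    all(stored_draw_idxs > 0)                         # every chain has started
    and all(stored_draw_idxs == stored_draw_idxs[0])  # all stopped at same draw
    and all(state is not None for state in stored_states)
)
assert can_resume_sampling
```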
@@ -450,7 +492,7 @@ def _iter_population( try: with popstep: # iterate draws of all chains - for i in range(draws): + for i in range(first_draw_idx, draws): # this call steps all chains and returns a list of (point, stats) # the `popstep` may interact with subprocesses internally updates = popstep.step(i == tune, points) diff --git a/pymc/step_methods/compound.py b/pymc/step_methods/compound.py index d07b070f0f..3c92360106 100644 --- a/pymc/step_methods/compound.py +++ b/pymc/step_methods/compound.py @@ -40,7 +40,7 @@ ) from pymc.util import RandomGenerator, get_random_generator -__all__ = ("Competence", "CompoundStep") +__all__ = ("Competence", "CompoundStep", "StepMethodState") @unique @@ -101,7 +101,8 @@ class StepMethodState(DataClassState): rng: RandomGeneratorState -class BlockedStep(ABC, WithSamplingState): +class BlockedStep(ABC, WithSamplingState[StepMethodState]): + _state_class = StepMethodState stats_dtypes: list[dict[str, type]] = [] """A list containing <=1 dictionary that maps stat names to dtypes. @@ -254,7 +255,7 @@ def __init__(self, methods: list[StepMethodState]): self.methods = methods -class CompoundStep(WithSamplingState): +class CompoundStep(WithSamplingState[CompoundStepState]): """Step method composed of a list of several other step methods applied in sequence.""" _state_class = CompoundStepState @@ -291,7 +292,7 @@ def reset_tuning(self): method.reset_tuning() @property - def sampling_state(self) -> DataClassState: + def sampling_state(self) -> CompoundStepState: return CompoundStepState(methods=[method.sampling_state for method in self.methods]) @sampling_state.setter diff --git a/pymc/step_methods/hmc/quadpotential.py b/pymc/step_methods/hmc/quadpotential.py index dd7ad6922b..86fb0f1d99 100644 --- a/pymc/step_methods/hmc/quadpotential.py +++ b/pymc/step_methods/hmc/quadpotential.py @@ -196,9 +196,9 @@ class QuadPotentialDiagAdaptState(PotentialState): _n: int = field(metadata={"frozen": True}) _discard_window: int = field(metadata={"frozen": True}) _early_update: int = field(metadata={"frozen": True}) - _initial_mean: np.ndarray = field(metadata={"frozen": True}) - _initial_diag: np.ndarray = field(metadata={"frozen": True}) - _initial_weight: np.ndarray = field(metadata={"frozen": True}) + _initial_mean: np.ndarray + _initial_diag: np.ndarray + _initial_weight: np.ndarray adaptation_window_multiplier: float = field(metadata={"frozen": True}) _store_mass_matrix_trace: bool = field(metadata={"frozen": True}) @@ -734,9 +734,9 @@ class QuadPotentialFullAdaptState(PotentialState): dtype: Any = field(metadata={"frozen": True}) _n: int = field(metadata={"frozen": True}) _update_window: int = field(metadata={"frozen": True}) - _initial_mean: np.ndarray = field(metadata={"frozen": True}) - _initial_cov: np.ndarray = field(metadata={"frozen": True}) - _initial_weight: np.ndarray = field(metadata={"frozen": True}) + _initial_mean: np.ndarray + _initial_cov: np.ndarray + _initial_weight: np.ndarray adaptation_window_multiplier: float = field(metadata={"frozen": True}) diff --git a/pymc/step_methods/metropolis.py b/pymc/step_methods/metropolis.py index 70c650653d..bdc0df58c5 100644 --- a/pymc/step_methods/metropolis.py +++ b/pymc/step_methods/metropolis.py @@ -875,7 +875,8 @@ def competence(var): class DEMetropolisState(StepMethodState): scaling: np.ndarray lamb: float - tune: str | None + tune: bool + tune_target: str | None tune_interval: int steps_until_tune: int accepted: int @@ -977,7 +978,8 @@ def __init__( self.lamb = float(lamb) if tune not in {None, "scaling", 
"lambda"}: raise ValueError('The parameter "tune" must be one of {None, scaling, lambda}') - self.tune = tune + self.tune = True + self.tune_target = tune self.tune_interval = tune_interval self.steps_until_tune = tune_interval self.accepted = 0 @@ -993,9 +995,9 @@ def astep(self, q0: RaveledVars) -> tuple[RaveledVars, StatsType]: q0d = q0.data if not self.steps_until_tune and self.tune: - if self.tune == "scaling": + if self.tune_target == "scaling": self.scaling = tune(self.scaling, self.accepted / float(self.tune_interval)) - elif self.tune == "lambda": + elif self.tune_target == "lambda": self.lamb = tune(self.lamb, self.accepted / float(self.tune_interval)) # Reset counter self.steps_until_tune = self.tune_interval diff --git a/pymc/step_methods/state.py b/pymc/step_methods/state.py index 98e177aa03..071cbbf595 100644 --- a/pymc/step_methods/state.py +++ b/pymc/step_methods/state.py @@ -13,7 +13,8 @@ # limitations under the License. from copy import deepcopy from dataclasses import MISSING, Field, dataclass, fields -from typing import Any, ClassVar +from numbers import Number +from typing import Any, ClassVar, Generic, TypeVar import numpy as np @@ -26,9 +27,37 @@ class DataClassState: __dataclass_fields__: ClassVar[dict[str, Field[Any]]] = {} + def is_compatible(self, other: Any) -> bool: + return compatible_dataclass_values(self, other) + + +def compatible_dataclass_values(v1: Any, v2: Any) -> bool: + if v1.__class__ != v2.__class__ and not (isinstance(v1, Number) and isinstance(v2, Number)): + # Numbers might have different classes (e.g. float("32") and np.float64(32)) + # but numbers are compatible with each other + return False + if isinstance(v1, tuple): + return len(v1) == len(v2) or all( + compatible_dataclass_values(v1i, v2i) for v1i, v2i in zip(v1, v2, strict=True) + ) + elif isinstance(v1, dict): + return set(v1) == set(v2) or all(compatible_dataclass_values(v1[k], v2[k]) for k in v1) + elif isinstance(v1, np.ndarray): + return v1.dtype == v2.dtype + elif isinstance(v1, np.random.Generator): + return True + elif isinstance(v1, DataClassState): + return set(fields(v1)) == set(fields(v2)) and all( + compatible_dataclass_values(getattr(v1, f1.name), getattr(v2, f2.name)) + for f1, f2 in zip(fields(v1), fields(v2), strict=True) + ) + return True + def equal_dataclass_values(v1, v2): - if v1.__class__ != v2.__class__: + if v1.__class__ != v2.__class__ and not (isinstance(v1, Number) and isinstance(v2, Number)): + # Numbers might have different classes (e.g. float("32") and np.float64(32)) + # but numbers are equal based on their value and not their type return False if isinstance(v1, (list, tuple)): # noqa: UP038 return len(v1) == len(v2) and all( @@ -51,7 +80,10 @@ def equal_dataclass_values(v1, v2): return v1 == v2 -class WithSamplingState: +SamplingStateType = TypeVar("SamplingStateType", bound=DataClassState) + + +class WithSamplingState(Generic[SamplingStateType]): """Mixin class that adds the ``sampling_state`` property to an object. The object's type must define the ``_state_class`` as a valid @@ -60,10 +92,10 @@ class WithSamplingState: the state represented as objects of the ``_state_class`` type. 
""" - _state_class: type[DataClassState] = DataClassState + _state_class: type[SamplingStateType] @property - def sampling_state(self) -> DataClassState: + def sampling_state(self) -> SamplingStateType: state_class = self._state_class kwargs = {} for field in fields(state_class): @@ -93,6 +125,9 @@ def sampling_state(self, state: DataClassState): assert isinstance(state, state_class), ( f"Encountered invalid state class '{state.__class__}'. State must be '{state_class}'" ) + assert self.sampling_state.is_compatible(state), ( + "The supplied state is incompatible with the current sampling state." + ) for field in fields(state_class): is_tensor_name = field.metadata.get("tensor_name", False) state_val = deepcopy(getattr(state, field.name)) diff --git a/pymc/util.py b/pymc/util.py index 979b3beebf..b8362e776e 100644 --- a/pymc/util.py +++ b/pymc/util.py @@ -812,6 +812,7 @@ def __init__( self._show_progress = show_progress self.divergences = 0 + self.draws = 0 self.completed_draws = 0 self.total_draws = draws + tune self.desc = "Sampling chain" @@ -827,18 +828,26 @@ def __enter__(self): def __exit__(self, exc_type, exc_val, exc_tb): return self._progress.__exit__(exc_type, exc_val, exc_tb) + def set_initial_state(self, draws: int = 0, divergences: int = 0): + self.draws = draws + self.completed_draws += draws + self.divergences = divergences + def _initialize_tasks(self): if self.combined_progress: self.tasks = [ self._progress.add_task( self.desc.format(self), - completed=0, - draws=0, + completed=self.completed_draws, + draws=self.completed_draws, total=self.total_draws * self.chains - 1, chain_idx=0, sampling_speed=0, speed_unit="draws/s", - **{stat: value[0] for stat, value in self.progress_stats.items()}, + **{ + stat: value[0] if stat != "diverging" else self.divergences + for stat, value in self.progress_stats.items() + }, ) ] @@ -846,8 +855,8 @@ def _initialize_tasks(self): self.tasks = [ self._progress.add_task( self.desc.format(self), - completed=0, - draws=0, + completed=self.completed_draws, + draws=self.draws, total=self.total_draws - 1, chain_idx=chain_idx, sampling_speed=0, diff --git a/tests/backends/test_zarr.py b/tests/backends/test_zarr.py index af9c9e0a06..695128c1fb 100644 --- a/tests/backends/test_zarr.py +++ b/tests/backends/test_zarr.py @@ -27,6 +27,7 @@ from pymc.backends.zarr import ZarrTrace from pymc.stats.convergence import SamplerWarning from pymc.step_methods import NUTS, CompoundStep, Metropolis +from pymc.step_methods.hmc import quadpotential from pymc.step_methods.state import equal_dataclass_values from tests.helpers import equal_sampling_states @@ -94,6 +95,40 @@ def model_step(request, model): return step +@pytest.fixture(scope="function", params=["mid_tuning", "finished_tuning"]) +def populated_trace(model, request): + tune = 5 + draws = 5 + chains = 1 + if request.param == "mid_tuning": + total_steps = 2 + else: + total_steps = 7 + trace = ZarrTrace( + draws_per_chunk=1, + include_transformed=True, + ) + with model: + rng = np.random.default_rng(42) + stepper = NUTS(rng=rng) + trace.init_trace( + chains=chains, + draws=draws, + tune=tune, + step=stepper, + model=model, + ) + point = model.initial_point() + for draw in range(total_steps): + tuning = draw < tune + if not tuning: + stepper.stop_tuning() + point, stats = stepper.step(point) + trace.straces[0].record(point, stats) + trace.straces[0].record_sampling_state(stepper) + return trace, total_steps, tune, draws + + def test_record(model, model_step, include_transformed, draws_per_chunk): store = 
zarr.TempStore() trace = ZarrTrace( @@ -108,11 +143,14 @@ def test_record(model, model_step, include_transformed, draws_per_chunk): "_sampling_state", "sample_stats", "posterior", + "warmup_sample_stats", + "warmup_posterior", "constant_data", "observed_data", } if include_transformed: expected_groups.add("unconstrained_posterior") + expected_groups.add("warmup_unconstrained_posterior") assert {group_name for group_name, _ in trace.root.groups()} == expected_groups # Record samples from the ZarrChain @@ -121,6 +159,7 @@ def test_record(model, model_step, include_transformed, draws_per_chunk): manually_collected_draws = [] manually_collected_stats = [] point = model.initial_point() + divergences = 0 for draw in range(tune + draws): tuning = draw < tune if not tuning: @@ -133,28 +172,17 @@ def test_record(model, model_step, include_transformed, draws_per_chunk): manually_collected_draws.append(point) manually_collected_stats.append(stats) trace.straces[0].record(point, stats) + for step_stats in stats: + divergences += sum( + int(step_stats[key] and not step_stats["tune"]) + for key in step_stats + if "diverging" in key + ) + assert trace.straces[0].completed_draws_and_divergences() == (draw + 1, divergences) + last_point = point trace.straces[0].record_sampling_state(model_step) assert {group_name for group_name, _ in trace.root.groups()} == expected_groups - # Assert split warmup - trace.split_warmup("posterior") - trace.split_warmup("sample_stats") - expected_groups = { - "_sampling_state", - "sample_stats", - "posterior", - "warmup_sample_stats", - "warmup_posterior", - "constant_data", - "observed_data", - } - if include_transformed: - trace.split_warmup("unconstrained_posterior") - expected_groups.add("unconstrained_posterior") - expected_groups.add("warmup_unconstrained_posterior") - assert {group_name for group_name, _ in trace.root.groups()} == expected_groups - # trace.consolidate() - # Assert observed data is correct assert set(dict(trace.observed_data.arrays())) == {"obs", "dim_time", "dim_str"} assert list(trace.observed_data.obs.attrs["_ARRAY_DIMENSIONS"]) == ["dim_time", "dim_str"] @@ -236,7 +264,7 @@ def test_record(model, model_step, include_transformed, draws_per_chunk): stat = stats_bijection.map(stat) for var, value in draw.items(): if var in trace.posterior.arrays(): - assert np.array_equal(trace.posterior[var][0, draw_idx], value) + np.testing.assert_array_equal(trace.posterior[var][0, draw_idx], value) for var, value in stat.items(): sample_stats = trace.root["sample_stats"] stat_val = sample_stats[var][0, draw_idx] @@ -260,7 +288,7 @@ def test_record(model, model_step, include_transformed, draws_per_chunk): else: posterior = trace.root["warmup_posterior"] if var in posterior.arrays(): - assert np.array_equal(posterior[var][0, draw_idx], value) + np.testing.assert_array_equal(posterior[var][0, draw_idx], value) for var, value in stat.items(): sample_stats = trace.root["warmup_sample_stats"] stat_val = sample_stats[var][0, draw_idx] @@ -284,7 +312,7 @@ def test_record(model, model_step, include_transformed, draws_per_chunk): else: posterior = trace.root["posterior"] if var in posterior.arrays(): - assert np.array_equal(posterior[var][0, draw_idx], value) + np.testing.assert_array_equal(posterior[var][0, draw_idx], value) for var, value in stat.items(): sample_stats = trace.root["sample_stats"] stat_val = sample_stats[var][0, draw_idx] @@ -301,6 +329,9 @@ def test_record(model, model_step, include_transformed, draws_per_chunk): trace._sampling_state.sampling_state[0], 
model_step.sampling_state, ) + assert set(last_point) == set(trace._sampling_state.mcmc_point.array_keys()) + for var_name, value in trace._sampling_state.mcmc_point.arrays(): + np.testing.assert_array_equal(last_point[var_name][None, ...], value) # Assert to inference data returns the expected groups idata = trace.to_inferencedata(save_warmup=True) @@ -336,30 +367,82 @@ def test_split_warmup(tune, model, model_step, include_transformed): draws = 10 - tune trace.init_trace(chains=1, draws=draws, tune=tune, model=model, step=model_step) - trace.split_warmup("posterior") - trace.split_warmup("sample_stats") assert len(trace.root.posterior.draw) == draws assert len(trace.root.sample_stats.draw) == draws - if tune == 0: - with pytest.raises(KeyError): - trace.root["warmup_posterior"] + assert len(trace.root["warmup_posterior"].draw) == tune + assert len(trace.root["warmup_sample_stats"].draw) == tune + + for var_name, posterior_array in trace.posterior.arrays(): + dims = posterior_array.attrs["_ARRAY_DIMENSIONS"] + if len(dims) >= 2 and dims[1] == "draw": + assert posterior_array.shape[1] == draws + assert trace.root["warmup_posterior"][var_name].shape[1] == tune + for var_name, sample_stats_array in trace.sample_stats.arrays(): + dims = sample_stats_array.attrs["_ARRAY_DIMENSIONS"] + if len(dims) >= 2 and dims[1] == "draw": + assert sample_stats_array.shape[1] == draws + assert trace.root["warmup_sample_stats"][var_name].shape[1] == tune + + +@pytest.mark.parametrize( + "desired_tune_and_draws", + [ + [None, 1], + [3, None], + [10, None], + [None, 10], + ], +) +def test_resize(populated_trace, desired_tune_and_draws): + desired_tune, desired_draws = desired_tune_and_draws + trace, total_steps, tune, draws = populated_trace + expect_to_fail = False + failure_message = "" + if desired_tune is not None: + if total_steps > tune: + expect_to_fail = True + failure_message = ( + "Cannot change the number of tuning steps in the trace. " + "Some chains have finished their tuning phase and have " + "already performed steps in the posterior sampling regime." + ) + elif total_steps > desired_tune: + expect_to_fail = True + failure_message = ( + "Cannot change the number of tuning steps in the trace. " + "Some chains have already taken more steps than the desired number " + "of tuning steps. Please increase the desired number of tuning " + f"steps to at least {total_steps}." + ) + if desired_draws is not None and total_steps > (desired_draws + tune): + expect_to_fail = True + failure_message = ( + "Cannot change the number of draws in the trace. " + "Some chains have already taken more steps than the desired number " + "of draws. Please increase the desired number of draws " + f"to at least {total_steps - tune}." 
+        )
+    if expect_to_fail:
+        with pytest.raises(ValueError, match=failure_message):
+            trace.resize(tune=desired_tune, draws=desired_draws)
     else:
-        assert len(trace.root["warmup_posterior"].draw) == tune
-        assert len(trace.root["warmup_sample_stats"].draw) == tune
-
-    with pytest.raises(RuntimeError):
-        trace.split_warmup("posterior")
-
-    for var_name, posterior_array in trace.posterior.arrays():
-        dims = posterior_array.attrs["_ARRAY_DIMENSIONS"]
-        if len(dims) >= 2 and dims[1] == "draw":
-            assert posterior_array.shape[1] == draws
-            assert trace.root["warmup_posterior"][var_name].shape[1] == tune
-    for var_name, sample_stats_array in trace.sample_stats.arrays():
-        dims = sample_stats_array.attrs["_ARRAY_DIMENSIONS"]
-        if len(dims) >= 2 and dims[1] == "draw":
-            assert sample_stats_array.shape[1] == draws
-            assert trace.root["warmup_sample_stats"][var_name].shape[1] == tune
+        trace.resize(tune=desired_tune, draws=desired_draws)
+        result_tune = desired_tune or tune
+        result_draws = desired_draws or draws
+        assert trace.tuning_steps == result_tune
+        assert trace.draws == result_draws
+        posterior_groups = ["posterior", "sample_stats", "unconstrained_posterior"]
+        warmup_groups = [f"warmup_{name}" for name in posterior_groups]
+        for group_set, expected_size in zip(
+            [posterior_groups, warmup_groups], [result_draws, result_tune]
+        ):
+            for group in group_set:
+                zarr_group = getattr(trace, group)
+                for name, values in zarr_group.arrays():
+                    if values.ndim > 1:  # 1-D arrays are coordinates; data arrays are (chain, draw, ...)
+                        assert values.shape[1] == expected_size
+                    elif name == "draw":
+                        assert values.shape[0] == expected_size


 @pytest.fixture(scope="function", params=["discard_tuning", "keep_tuning"])
@@ -536,3 +619,252 @@ def test_sampling_consistency(
         sequential_trace._sampling_state.sampling_state[chain],
     )
     xr.testing.assert_equal(parallel_idata.posterior, sequential_idata.posterior)
+
+
+def test_from_store(populated_trace):
+    trace, total_steps, tune, draws = populated_trace
+    loaded_trace = ZarrTrace.from_store(
+        trace.root.store,
+    )
+    assert loaded_trace.is_root_populated and not loaded_trace._is_base_setup
+    assert trace.draws_per_chunk == loaded_trace.draws_per_chunk
+    assert trace.include_transformed == loaded_trace.include_transformed
+    assert set(trace.varnames) == set(loaded_trace.varnames)
+    assert set(trace.coords) == set(loaded_trace.coords) and (
+        all(
+            np.array_equal(np.asarray(coord), np.asarray(loaded_trace.coords[dim]))
+            for dim, coord in trace.coords.items()
+        )
+    )
+    assert trace.vars_to_dims == loaded_trace.vars_to_dims
+
+    assert not hasattr(loaded_trace, "straces")
+    assert set(trace.root.group_keys()) == set(loaded_trace.root.group_keys())
+    for group_name, group in trace.root.groups():
+        loaded_group = loaded_trace.root[group_name]
+        if group_name == "_sampling_state":
+            assert all(
+                equal_sampling_states(this, other) if this is not None else this is other
+                for this, other in zip(group.sampling_state[:], loaded_group.sampling_state[:])
+            )
+            np.testing.assert_array_equal(group.draw_idx, loaded_group.draw_idx)
+            assert trace.tuning_steps == loaded_trace.tuning_steps
+            assert trace.draws == loaded_trace.draws
+            assert trace.sampling_time == loaded_trace.sampling_time
+        else:
+            assert set(group.array_keys()) == set(loaded_group.array_keys())
+            for name, array in group.arrays():
+                loaded_array = loaded_group[name]
+                assert dict(array.attrs) == dict(loaded_array.attrs)
+                np.testing.assert_array_equal(np.asarray(array), np.asarray(loaded_array))
+
+
+def test_resume_sampling(
+    model,
+    model_step,
+    include_transformed,
+    parallel,
+    draws_per_chunk,
+):
+    tune = 2
+    draws = 3
+    if parallel:
+        chains = 2
+        cores = 2
+    else:
+        chains = 1
+        cores = 1
+    store1 = zarr.TempStore()
+    store2 = zarr.TempStore()
+    trace1 = ZarrTrace(
+        store=store1, include_transformed=include_transformed, draws_per_chunk=draws_per_chunk
+    )
+    trace2 = ZarrTrace(
+        store=store2, include_transformed=include_transformed, draws_per_chunk=draws_per_chunk
+    )
+    initial_step_state = model_step.sampling_state
+    with model:
+        idata_full = pm.sample(
+            draws=draws,
+            tune=tune,
+            chains=chains,
+            cores=cores,
+            trace=trace1,
+            step=model_step,
+            discard_tuned_samples=False,
+            return_inferencedata=True,
+            keep_warning_stat=False,
+            idata_kwargs={"log_likelihood": True},
+            random_seed=42,
+        )
+    model_step.sampling_state = initial_step_state
+    with model:
+        pm.sample(
+            draws=0,
+            tune=tune - 1,
+            chains=chains,
+            cores=cores,
+            trace=trace2,
+            step=model_step,
+            discard_tuned_samples=False,
+            return_inferencedata=False,
+            keep_warning_stat=False,
+            idata_kwargs={"log_likelihood": True},
+            random_seed=42,
+        )
+        pm.sample(
+            draws=draws - 1,
+            tune=tune,
+            chains=chains,
+            cores=1,
+            trace=trace2,
+            step=model_step,
+            discard_tuned_samples=False,
+            return_inferencedata=False,
+            keep_warning_stat=False,
+            idata_kwargs={"log_likelihood": True},
+        )
+        idata_with_pauses = pm.sample(
+            draws=draws,
+            tune=tune,
+            chains=chains,
+            cores=cores,
+            trace=trace2,
+            step=model_step,
+            discard_tuned_samples=False,
+            return_inferencedata=True,
+            keep_warning_stat=False,
+            idata_kwargs={"log_likelihood": True},
+        )
+    for group in idata_full.groups():
+        if "sample_stats" in group:
+            comparable_stats = [
+                stat_name
+                for stat_name in idata_full[group].data_vars
+                if not any(
+                    incomparable in stat_name
+                    for incomparable in [
+                        "process_time_diff",
+                        "perf_counter_diff",
+                        "perf_counter_start",
+                    ]
+                )
+            ]
+            for comparable_stat in comparable_stats:
+                xr.testing.assert_equal(
+                    idata_full[group][comparable_stat],
+                    idata_with_pauses[group][comparable_stat],
+                )
+        else:
+            xr.testing.assert_equal(idata_full[group], idata_with_pauses[group])
+
+
+incompatibility_modes = [
+    "wrong_coordinates",
+    "changed_coordinates",
+    "changed_data",
+    "changed_observations",
+    "untracked_vars",
+    "different_step_stats",
+    "different_step_state",
+]
+
+
+def basic_model(coords, observed_value, include_free_var=True, include_data=True, mix_dims=False):
+    with pm.Model(coords=coords) as base_model:
+        trans_var = pm.HalfNormal("trans_var", dims="free_dims" if mix_dims else "trans_dims")
+        if include_free_var:
+            free_var = pm.Normal("free_var", dims="free_dims")
+            det_var = pm.Deterministic("det_var", free_var.sum() + trans_var.sum())
+        else:
+            det_var = pm.Deterministic("det_var", trans_var.sum())
+        if include_data:
+            data_var = pm.Data(
+                "data_var", np.ones(len(coords.get("data_dims", [1]))), dims="data_dims"
+            )
+            obs_var = pm.Normal("obs_var", data_var.sum() + det_var, observed=observed_value)
+        else:
+            obs_var = pm.Normal("obs_var", det_var, observed=observed_value)
+    return base_model
+
+
+@pytest.fixture(scope="module", params=incompatibility_modes)
+def incompatible_model(request):
+    mode = request.param
+    base_coords = {
+        "trans_dims": range(3),
+        "free_dims": ["A", "B"],
+        "data_dims": range(5),
+    }
+    base_observed = np.arange(4)
+    base_model = basic_model(
+        coords=base_coords, include_free_var=True, include_data=True, observed_value=base_observed
+    )
+    with base_model:
+        base_step = NUTS()
+    store = zarr.TempStore()
+    trace = ZarrTrace(store=store, include_transformed=True)
+    trace = pm.sample(
+        tune=4,
+        draws=4,
+        chains=1,
+        step=base_step,
+        random_seed=42,
+        trace=trace,
+        return_inferencedata=False,
+        discard_tuned_samples=False,
+    )
+    test_step = base_step
+    if mode == "wrong_coordinates":
+        wrong_model = basic_model(coords=base_coords, observed_value=base_observed, mix_dims=True)
+        error_message = (
+            "Some model variables have different dimensions than those stored in the trace."
+        )
+    elif mode == "changed_coordinates":
+        wrong_coords = base_coords.copy()
+        wrong_coords["trans_dims"] = range(10)
+        wrong_model = basic_model(coords=wrong_coords, observed_value=base_observed)
+        error_message = "Model coordinates don't match the coordinates stored in the trace"
+    elif mode == "changed_data":
+        wrong_model = basic_model(coords=base_coords, observed_value=base_observed)
+        with wrong_model:
+            pm.set_data({"data_var": np.zeros_like(wrong_model["data_var"].get_value())})
+        error_message = "The model constant data does not match with the stored constant data"
+    elif mode == "changed_observations":
+        wrong_model = basic_model(coords=base_coords, observed_value=base_observed + 44)
+        error_message = "The model observed data does not match with the stored observed data"
+    elif mode == "untracked_vars":
+        wrong_model = basic_model(
+            coords=base_coords, include_free_var=False, observed_value=base_observed
+        )
+        error_message = (
+            "The model deterministics and random variables given the sampled var_names "
+            "do not match with the stored deterministics variables in the trace."
+        )
+    elif mode == "different_step_stats":
+        wrong_model = base_model
+        with wrong_model:
+            test_step = Metropolis()
+        error_message = "The step method sample stats do not match the ones stored in the trace."
+    elif mode == "different_step_state":
+        wrong_model = base_model
+        with wrong_model:
+            potential = quadpotential.QuadPotentialFullAdapt(
+                base_step.potential._n,
+                base_step.potential._initial_mean,
+            )
+            test_step = NUTS(potential=potential)
+        error_message = (
+            "The state method sampling state class is incompatible with what's stored in the trace."
+ ) + else: + raise NotImplementedError() + return trace, wrong_model, error_message, test_step + + +def test_model_and_step_are_compatible(incompatible_model): + trace, model, expected_error, step = incompatible_model + with pytest.raises(AssertionError, match=expected_error): + trace.assert_model_and_step_are_compatible( + step=step, model=model, vars=model.unobserved_value_vars + ) diff --git a/tests/step_methods/test_metropolis.py b/tests/step_methods/test_metropolis.py index 234dabb5a4..781dbddc0b 100644 --- a/tests/step_methods/test_metropolis.py +++ b/tests/step_methods/test_metropolis.py @@ -159,16 +159,16 @@ def test_demcmc_tune_parameter(self): pm.Normal("n", mu=0, sigma=1, size=(2, 3)) step = DEMetropolis() - assert step.tune == "scaling" + assert step.tune_target == "scaling" step = DEMetropolis(tune=None) - assert step.tune is None + assert step.tune_target is None step = DEMetropolis(tune="scaling") - assert step.tune == "scaling" + assert step.tune_target == "scaling" step = DEMetropolis(tune="lambda") - assert step.tune == "lambda" + assert step.tune_target == "lambda" with pytest.raises(ValueError): DEMetropolis(tune="foo") diff --git a/tests/step_methods/test_state.py b/tests/step_methods/test_state.py index dd351bb555..fbedbf8429 100644 --- a/tests/step_methods/test_state.py +++ b/tests/step_methods/test_state.py @@ -101,7 +101,7 @@ def __init__(self, rng=None): def test_sampling_state(): b1 = B() b2 = B(mutable_field=2.0) - b3 = B(c=1, extra_info1=np.array([10, 20])) + b3 = B(extra_info1=np.array([10, 20])) b4 = B(a=2, b=3.0, c="d") b5 = B(c=1) b6 = B(f={"a": 1, "b": "c", "d": None}) @@ -130,6 +130,11 @@ def test_sampling_state(): with pytest.raises(AssertionError, match="Encountered invalid state class"): b1.sampling_state = b1_state.state1 + with pytest.raises( + AssertionError, match="The supplied state is incompatible with the current sampling state." + ): + b1.sampling_state = b5.sampling_state + b1.sampling_state = b4_state assert equal_sampling_states(b1.sampling_state, b4_state) assert not equal_sampling_states(b1.sampling_state, b5.sampling_state)
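
The workflow that test_from_store and test_resume_sampling exercise looks roughly as follows from user code. This is a minimal sketch, not part of the patch: the toy model is hypothetical, and it assumes (as the init_traces changes above imply) that passing an already-populated ZarrTrace to pm.sample goes through the compatibility check and resize path before sampling resumes.

    import pymc as pm
    import zarr

    from pymc.backends.zarr import ZarrTrace

    with pm.Model() as model:  # hypothetical toy model, for illustration only
        x = pm.Normal("x")

    store = zarr.TempStore()  # any persistent zarr store works the same way

    # First session: record a short run into the store.
    with model:
        pm.sample(tune=100, draws=100, chains=2, trace=ZarrTrace(store=store))

    # Later session: reopen the stored trace and ask for more posterior draws.
    # init_traces catches TraceAlreadyInitialized, checks that the model, step
    # method, and number of chains match what is stored, resizes the zarr
    # groups, and each chain resumes from its recorded sampling state.
    with model:
        idata = pm.sample(
            tune=100,   # tuning length is frozen once chains have finished tuning
            draws=500,  # may grow, but not shrink below the draws already taken
            chains=2,   # must match the number of chains stored in the trace
            trace=ZarrTrace.from_store(store),
        )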
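The resize rules that test_resize encodes, summarized as a sketch against a trace built like the populated_trace fixture above (one chain, tune=5, draws=5); the failing calls are annotated with the ValueError they would raise:

    # "mid_tuning": 2 steps taken, all of them tuning steps.
    trace.resize(tune=3, draws=None)   # OK: tune may still change while every
                                       # chain is tuning, as long as it stays
                                       # at or above the 2 steps already taken
    trace.resize(tune=None, draws=10)  # OK: the posterior may always be grown

    # "finished_tuning": 7 steps taken (5 tuning + 2 posterior draws).
    trace.resize(tune=10, draws=None)  # ValueError: some chains have finished
                                       # their tuning phase, so tune is frozen
    trace.resize(tune=None, draws=1)   # ValueError: 2 posterior draws already
                                       # recorded, so draws must stay >= 2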
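Finally, when resuming through pm.sample rather than calling the assertion helper directly (as test_model_and_step_are_compatible does), the same guard fires inside init_traces, so an incompatible resume fails before any steps are taken. A sketch, assuming store holds a finished run and changed_model is a hypothetical model that redefines one of the stored coordinates:

    with changed_model:
        pm.sample(
            tune=100,
            draws=500,
            chains=2,
            trace=ZarrTrace.from_store(store),
        )
    # AssertionError: Model coordinates don't match the coordinates stored in the trace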