From a8af2e8ddcfd5503834dcf942b836ed9cda62715 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Thu, 20 Mar 2025 16:01:28 +0800 Subject: [PATCH 1/3] Document how step methods provide progress bar stats --- pymc/step_methods/compound.py | 143 +++++++++++++++++++++++++++++--- pymc/step_methods/hmc/nuts.py | 72 +++++++++++++--- pymc/step_methods/metropolis.py | 71 +++++++++++++--- pymc/step_methods/slicer.py | 51 ++++++++++-- 4 files changed, 296 insertions(+), 41 deletions(-) diff --git a/pymc/step_methods/compound.py b/pymc/step_methods/compound.py index d07b070f0f..820896758a 100644 --- a/pymc/step_methods/compound.py +++ b/pymc/step_methods/compound.py @@ -21,7 +21,7 @@ import warnings from abc import ABC, abstractmethod -from collections.abc import Iterable, Mapping, Sequence +from collections.abc import Callable, Iterable, Mapping, Sequence from dataclasses import field from enum import IntEnum, unique from typing import Any @@ -29,6 +29,7 @@ import numpy as np from pytensor.graph.basic import Variable +from rich.progress import ProgressColumn from pymc.blocking import PointType, StatDtype, StatsDict, StatShape, StatsType from pymc.model import modelcontext @@ -181,17 +182,72 @@ def __new__(cls, *args, **kwargs): step.__newargs = (vars, *args), kwargs return step - @staticmethod - def _progressbar_config(n_chains=1): + def _progressbar_config(self, n_chains: int = 1): + """ + Get progressbar configuration for this step sampler. + + By default, the progress bar displays no stats columns, only basic info (number of draws and sampling time). + Specific step methods should overload this method to specify which stats to display and how. + + Parameters + ---------- + n_chains: int + Number of chains being sampled. This controls the number of progress bars that will be displayed. + + Returns + ------- + columns: list of rich.progress.ProgressColumn + List of columns to display in the progress bar. + + stats: dict + Dictionary of statistics associated with each column. + """ columns = [] stats = {} return columns, stats - @staticmethod - def _make_update_stats_function(): - def update_stats(stats, step_stats, chain_idx): - return stats + def _make_update_stats_function(self) -> Callable[[dict, dict, int], dict]: + """ + Create an update function used by the progress bar to update statistics during sampling. + + By default, the update is a no-op. Specific step methods should implement special logic for which + statistics to display and how. + + Returns + ------- + update_stats: Callable + Function that updates displayed statistics for the current chain, given statistics generated by the step + during the most recent step. + """ + + def update_stats( + displayed_stats: dict[str, np.ndarray], + step_stats: dict[str, str | float | int | bool | None], + chain_idx: int, + ) -> dict[str, np.ndarray]: + """ + Update the statistics displayed in the progress bar after each step. + + Parameters + ---------- + displayed_stats: dict + Dictionary of statistics displayed in the progress bar. The keys are the names of the statistics and + the values are the current values of the statistics, with one value per chain being sampled. + + step_stats: dict + Dictionary of statistics generated by the step sampler when taking the current step. The keys are the + names of the statistics and the values are the values of the statistics generated by the step sampler. + + chain_idx: int + The chain number associated with the current step + + Returns + ------- + dict + The updated statistics dictionary to be displayed in the progress bar. + """ + return displayed_stats return update_stats @@ -311,7 +367,28 @@ def set_rng(self, rng: RandomGenerator): for method, _rng in zip(self.methods, _rngs): method.set_rng(_rng) - def _progressbar_config(self, n_chains=1): + def _progressbar_config( + self, n_chains: int = 1 + ) -> tuple[list[ProgressColumn], dict[str, np.ndarray | float]]: + """ + Get progressbar configuration for this step sampler. + + The columns of the rich progress bar displayed during sampler are chosen by the step samplers themselves. In + the compound step case, we display the set union of all columns from the sub-step samplers. + + Parameters + ---------- + n_chains: int + Number of chains being sampled. This controls the number of progress bars that will be displayed. + + Returns + ------- + columns: list of rich.progress.ProgressColumn + List of columns to display in the progress bar. + + stats: dict + Dictionary of statistics associated with each column. + """ from functools import reduce column_lists, stat_dict_list = zip( @@ -332,14 +409,56 @@ def _progressbar_config(self, n_chains=1): return columns, stats - def _make_update_stats_function(self): + def _make_update_stats_function(self) -> Callable[[dict, list[dict], int], dict]: + """ + Create an update function used by the progress bar to update statistics during sampling. + + Returns + ------- + update_stats: Callable + Function that updates displayed statistics for the current chain, given statistics generated by the step + during the most recent step. + """ update_fns = [method._make_update_stats_function() for method in self.methods] - def update_stats(stats, step_stats, chain_idx): + def update_stats( + displayed_stats: dict[str, np.ndarray], + step_stats: list[dict[str, str | float | int | bool | None]], + chain_idx: int, + ) -> dict[str, np.ndarray]: + """ + Update the statistics displayed in the progress bar after each step. + + Parameters + ---------- + displayed_stats: dict + Dictionary of statistics displayed in the progress bar. The keys are the names of the statistics and + the values are the current values of the statistics, with one value per chain being sampled. + + step_stats: list of dict + List of dictionaries containing statistics generated by **each** step sampler in the CompoundStep when + taking the current step. For each dictionary, the keys are names of statistics and the values are + the values of the statistics generated by the step sampler. + + chain_idx: int + The chain number associated with the current step + + Returns + ------- + dict + The updated statistics dictionary to be displayed in the progress bar. + """ + # TODO: The compound step is commonly made of many instances of the same step (e.g. 3 Metropolis steps). + # In this case, the current loop logic is just overriding each Metropolis steps' stats with those of the + # next step (so the user only ever sees the 3rd step's stats). We should have a better way to aggregate + # the stats from each step. + if not isinstance(step_stats, list): + step_stats = [step_stats] + for step_stat, update_fn in zip(step_stats, update_fns): - stats = update_fn(stats, step_stat, chain_idx) + displayed_stats = update_fn(displayed_stats, step_stat, chain_idx) - return stats + return displayed_stats return update_stats diff --git a/pymc/step_methods/hmc/nuts.py b/pymc/step_methods/hmc/nuts.py index 18707c3592..23affa54ee 100644 --- a/pymc/step_methods/hmc/nuts.py +++ b/pymc/step_methods/hmc/nuts.py @@ -15,12 +15,13 @@ from __future__ import annotations from collections import namedtuple +from collections.abc import Callable from dataclasses import field import numpy as np from pytensor import config -from rich.progress import TextColumn +from rich.progress import ProgressColumn, TextColumn from rich.table import Column from pymc.stats.convergence import SamplerWarning @@ -231,8 +232,25 @@ def competence(var, has_grad): return Competence.PREFERRED return Competence.INCOMPATIBLE - @staticmethod - def _progressbar_config(n_chains=1): + def _progressbar_config( + self, n_chains: int = 1 + ) -> tuple[list[ProgressColumn], dict[str, np.ndarray | float]]: + """ + Get progressbar configuration for this step sampler. + + Parameters + ---------- + n_chains: int + Number of chains being sampled. This controls the number of progress bars that will be displayed. + + Returns + ------- + columns: list of rich.progress.ProgressColumn + List of columns to display in the progress bar. + + stats: dict + Dictionary of statistics associated with each column. + """ columns = [ TextColumn("{task.fields[divergences]}", table_column=Column("Divergences", ratio=1)), TextColumn("{task.fields[step_size]:0.2f}", table_column=Column("Step size", ratio=1)), @@ -247,18 +265,52 @@ def _progressbar_config(n_chains=1): return columns, stats - @staticmethod - def _make_update_stats_function(): - def update_stats(stats, step_stats, chain_idx): + def _make_update_stats_function(self) -> Callable[[dict, dict, int], dict]: + """ + Create an update function used by the progress bar to update statistics during sampling. + + Returns + ------- + update_stats: Callable + Function that updates displayed statistics for the current chain, given statistics generated by the step + during the most recent step. + """ + + def update_stats( + displayed_stats: dict[str, np.ndarray], + step_stats: dict[str, str | float | int | bool | None], + chain_idx: int, + ) -> dict[str, np.ndarray]: + """ + Update the statistics displayed in the progress bar after each step. + + Parameters + ---------- + displayed_stats: dict + Dictionary of statistics displayed in the progress bar. The keys are the names of the statistics and + the values are the current values of the statistics, with one value per chain being sampled. + + step_stats: dict + Dictionary of statistics generated by the step sampler when taking the current step. The keys are the + names of the statistics and the values are the values of the statistics generated by the step sampler. + + chain_idx: int + The chain number associated with the current step + + Returns + ------- + dict + The updated statistics dictionary to be displayed in the progress bar. + """ if isinstance(step_stats, list): step_stats = step_stats[0] if not step_stats["tune"]: - stats["divergences"][chain_idx] += step_stats["diverging"] + displayed_stats["divergences"][chain_idx] += step_stats["diverging"] - stats["step_size"][chain_idx] = step_stats["step_size"] - stats["tree_size"][chain_idx] = step_stats["tree_size"] - return stats + displayed_stats["step_size"][chain_idx] = step_stats["step_size"] + displayed_stats["tree_size"][chain_idx] = step_stats["tree_size"] + return displayed_stats return update_stats diff --git a/pymc/step_methods/metropolis.py b/pymc/step_methods/metropolis.py index 70c650653d..396a6693f8 100644 --- a/pymc/step_methods/metropolis.py +++ b/pymc/step_methods/metropolis.py @@ -24,7 +24,7 @@ from pytensor import tensor as pt from pytensor.graph.fg import MissingInputError from pytensor.tensor.random.basic import BernoulliRV, CategoricalRV -from rich.progress import TextColumn +from rich.progress import ProgressColumn, TextColumn from rich.table import Column import pymc as pm @@ -327,8 +327,25 @@ def astep(self, q0: RaveledVars) -> tuple[RaveledVars, StatsType]: def competence(var, has_grad): return Competence.COMPATIBLE - @staticmethod - def _progressbar_config(n_chains=1): + def _progressbar_config( + self, n_chains: int = 1 + ) -> tuple[list[ProgressColumn], dict[str, np.ndarray | float]]: + """ + Get progressbar configuration for this step sampler. + + Parameters + ---------- + n_chains: int + Number of chains being sampled. This controls the number of progress bars that will be displayed. + + Returns + ------- + columns: list of rich.progress.ProgressColumn + List of columns to display in the progress bar. + + stats: dict + Dictionary of statistics associated with each column. + """ columns = [ TextColumn("{task.fields[tune]}", table_column=Column("Tuning", ratio=1)), TextColumn("{task.fields[scaling]:0.2f}", table_column=Column("Scaling", ratio=1)), @@ -345,17 +362,51 @@ def _progressbar_config(n_chains=1): return columns, stats - @staticmethod - def _make_update_stats_function(): - def update_stats(stats, step_stats, chain_idx): + def _make_update_stats_function(self) -> Callable[[dict, dict, int], dict]: + """ + Create an update function used by the progress bar to update statistics during sampling. + + Returns + ------- + update_stats: Callable + Function that updates displayed statistics for the current chain, given statistics generated by the step + during the most recent step. + """ + + def update_stats( + displayed_stats: dict[str, np.ndarray], + step_stats: dict[str, str | float | int | bool | None], + chain_idx: int, + ) -> dict[str, np.ndarray]: + """ + Update the statistics displayed in the progress bar after each step. + + Parameters + ---------- + displayed_stats: dict + Dictionary of statistics displayed in the progress bar. The keys are the names of the statistics and + the values are the current values of the statistics, with one value per chain being sampled. + + step_stats: dict + Dictionary of statistics generated by the step sampler when taking the current step. The keys are the + names of the statistics and the values are the values of the statistics generated by the step sampler. + + chain_idx: int + The chain number associated with the current step + + Returns + ------- + dict + The updated statistics dictionary to be displayed in the progress bar. + """ if isinstance(step_stats, list): step_stats = step_stats[0] - stats["tune"][chain_idx] = step_stats["tune"] - stats["accept_rate"][chain_idx] = step_stats["accept"] - stats["scaling"][chain_idx] = step_stats["scaling"] + displayed_stats["tune"][chain_idx] = step_stats["tune"] + displayed_stats["accept_rate"][chain_idx] = step_stats["accept"] + displayed_stats["scaling"][chain_idx] = step_stats["scaling"] - return stats + return displayed_stats return update_stats diff --git a/pymc/step_methods/slicer.py b/pymc/step_methods/slicer.py index 9c10acfdf4..f726efbdf5 100644 --- a/pymc/step_methods/slicer.py +++ b/pymc/step_methods/slicer.py @@ -11,10 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from collections.abc import Callable # Modified from original implementation by Dominik Wabersich (2013) - - import numpy as np from rich.progress import TextColumn @@ -211,16 +210,50 @@ def _progressbar_config(n_chains=1): return columns, stats - @staticmethod - def _make_update_stats_function(): - def update_stats(stats, step_stats, chain_idx): + def _make_update_stats_function(self) -> Callable[[dict, dict, int], dict]: + """ + Create an update function used by the progress bar to update statistics during sampling. + + Returns + ------- + update_stats: Callable + Function that updates displayed statistics for the current chain, given statistics generated by the step + during the most recent step. + """ + + def update_stats( + displayed_stats: dict[str, np.ndarray], + step_stats: dict[str, str | float | int | bool | None], + chain_idx: int, + ) -> dict[str, np.ndarray]: + """ + Update the statistics displayed in the progress bar after each step. + + Parameters + ---------- + displayed_stats: dict + Dictionary of statistics displayed in the progress bar. The keys are the names of the statistics and + the values are the current values of the statistics, with one value per chain being sampled. + + step_stats: dict + Dictionary of statistics generated by the step sampler when taking the current step. The keys are the + names of the statistics and the values are the values of the statistics generated by the step sampler. + + chain_idx: int + The chain number associated with the current step + + Returns + ------- + dict + The updated statistics dictionary to be displayed in the progress bar. + """ if isinstance(step_stats, list): step_stats = step_stats[0] - stats["tune"][chain_idx] = step_stats["tune"] - stats["nstep_out"][chain_idx] = step_stats["nstep_out"] - stats["nstep_in"][chain_idx] = step_stats["nstep_in"] + displayed_stats["tune"][chain_idx] = step_stats["tune"] + displayed_stats["nstep_out"][chain_idx] = step_stats["nstep_out"] + displayed_stats["nstep_in"][chain_idx] = step_stats["nstep_in"] - return stats + return displayed_stats return update_stats From eccfa39b3a1f7d3dab6432ca19c96b2d654eb784 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Mon, 24 Mar 2025 01:46:12 +0800 Subject: [PATCH 2/3] Assign each step method a unique id, and track ids via stats emitted by `Step.step` --- pymc/backends/ndarray.py | 4 ++++ pymc/sampling/mcmc.py | 11 ++++++++- pymc/step_methods/arraystep.py | 37 ++++++++++++++++++++++++---- pymc/step_methods/compound.py | 40 ++++++++++++++++++------------- pymc/step_methods/hmc/base_hmc.py | 3 +++ pymc/step_methods/hmc/nuts.py | 7 +++--- pymc/step_methods/metropolis.py | 16 ++++++++----- pymc/step_methods/slicer.py | 14 +++++------ 8 files changed, 92 insertions(+), 40 deletions(-) diff --git a/pymc/backends/ndarray.py b/pymc/backends/ndarray.py index a08fc8f47e..326fe1e3f3 100644 --- a/pymc/backends/ndarray.py +++ b/pymc/backends/ndarray.py @@ -113,6 +113,10 @@ def record(self, point, sampler_stats=None) -> None: if sampler_stats is not None: for data, vars in zip(self._stats, sampler_stats): for key, val in vars.items(): + # step_meta is a key used by the progress bars to track which draw came from which step instance. It + # should never be stored as a sampler statistic. + if key == "step_meta": + continue data[key][draw_idx] = val elif self._stats is not None: raise ValueError("Expected sampler_stats") diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py index f2dfa6e9c2..f12c0345a0 100644 --- a/pymc/sampling/mcmc.py +++ b/pymc/sampling/mcmc.py @@ -15,6 +15,7 @@ """Functions for MCMC sampling.""" import contextlib +import itertools import logging import pickle import sys @@ -111,6 +112,7 @@ def instantiate_steppers( step_kwargs: dict[str, dict] | None = None, initial_point: PointType | None = None, compile_kwargs: dict | None = None, + step_id_generator: Iterator[int] | None = None, ) -> Step | list[Step]: """Instantiate steppers assigned to the model variables. @@ -139,6 +141,9 @@ def instantiate_steppers( if step_kwargs is None: step_kwargs = {} + if step_id_generator is None: + step_id_generator = itertools.count() + used_keys = set() if selected_steps: if initial_point is None: @@ -154,6 +159,7 @@ def instantiate_steppers( model=model, initial_point=initial_point, compile_kwargs=compile_kwargs, + step_id_generator=step_id_generator, **kwargs, ) steps.append(step) @@ -853,6 +859,8 @@ def joined_blas_limiter(): initial_points = [ipfn(seed) for ipfn, seed in zip(ipfns, random_seed_list)] # Instantiate automatically selected steps + # Use a counter to generate a unique id for each stepper used in the model. + step_id_generator = itertools.count() step = instantiate_steppers( model, steps=provided_steps, @@ -860,9 +868,10 @@ def joined_blas_limiter(): step_kwargs=kwargs, initial_point=initial_points[0], compile_kwargs=compile_kwargs, + step_id_generator=step_id_generator, ) if isinstance(step, list): - step = CompoundStep(step) + step = CompoundStep(step, step_id_generator=step_id_generator) if var_names is not None: trace_vars = [v for v in model.unobserved_RVs if v.name in var_names] diff --git a/pymc/step_methods/arraystep.py b/pymc/step_methods/arraystep.py index 0c20e09a47..54785c2fe8 100644 --- a/pymc/step_methods/arraystep.py +++ b/pymc/step_methods/arraystep.py @@ -13,7 +13,7 @@ # limitations under the License. from abc import abstractmethod -from collections.abc import Callable +from collections.abc import Callable, Iterator from typing import cast import numpy as np @@ -43,14 +43,25 @@ class ArrayStep(BlockedStep): :py:func:`pymc.util.get_random_generator` for more information. """ - def __init__(self, vars, fs, allvars=False, blocked=True, rng: RandomGenerator = None): + def __init__( + self, + vars, + fs, + allvars=False, + blocked=True, + rng: RandomGenerator = None, + step_id_generator: Iterator[int] | None = None, + ): self.vars = vars self.fs = fs self.allvars = allvars self.blocked = blocked self.rng = get_random_generator(rng) + self._step_id = next(step_id_generator) if step_id_generator else None - def step(self, point: PointType) -> tuple[PointType, StatsType]: + def step( + self, point: PointType, step_parent_id: int | None = None + ) -> tuple[PointType, StatsType]: partial_funcs_and_point: list[Callable | PointType] = [ DictToArrayBijection.mapf(x, start_point=point) for x in self.fs ] @@ -61,6 +72,9 @@ def step(self, point: PointType) -> tuple[PointType, StatsType]: apoint = DictToArrayBijection.map(var_dict) apoint_new, stats = self.astep(apoint, *partial_funcs_and_point) + for sts in stats: + sts["step_meta"] = {"step_id": self._step_id, "step_parent_id": step_parent_id} + if not isinstance(apoint_new, RaveledVars): # We assume that the mapping has stayed the same apoint_new = RaveledVars(apoint_new, apoint.point_map_info) @@ -84,7 +98,14 @@ class ArrayStepShared(BlockedStep): and unmapping overhead as well as moving fewer variables around. """ - def __init__(self, vars, shared, blocked=True, rng: RandomGenerator = None): + def __init__( + self, + vars, + shared, + blocked=True, + rng: RandomGenerator = None, + step_id_generator: Iterator[int] | None = None, + ): """ Create the ArrayStepShared object. @@ -103,8 +124,11 @@ def __init__(self, vars, shared, blocked=True, rng: RandomGenerator = None): self.shared = {get_var_name(var): shared for var, shared in shared.items()} self.blocked = blocked self.rng = get_random_generator(rng) + self._step_id = next(step_id_generator) if step_id_generator else None - def step(self, point: PointType) -> tuple[PointType, StatsType]: + def step( + self, point: PointType, step_parent_id: int | None = None + ) -> tuple[PointType, StatsType]: full_point = None if self.shared: for name, shared_var in self.shared.items(): @@ -115,6 +139,9 @@ def step(self, point: PointType) -> tuple[PointType, StatsType]: q = DictToArrayBijection.map(point) apoint, stats = self.astep(q) + for sts in stats: + sts["step_meta"] = {"step_id": self._step_id, "step_parent_id": step_parent_id} + if not isinstance(apoint, RaveledVars): # We assume that the mapping has stayed the same apoint = RaveledVars(apoint, q.point_map_info) diff --git a/pymc/step_methods/compound.py b/pymc/step_methods/compound.py index 820896758a..d522fa9961 100644 --- a/pymc/step_methods/compound.py +++ b/pymc/step_methods/compound.py @@ -21,7 +21,7 @@ import warnings from abc import ABC, abstractmethod -from collections.abc import Callable, Iterable, Mapping, Sequence +from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence from dataclasses import field from enum import IntEnum, unique from typing import Any @@ -125,6 +125,8 @@ class BlockedStep(ABC, WithSamplingState): def __new__(cls, *args, **kwargs): blocked = kwargs.get("blocked") + step_id_generator = kwargs.pop("step_id_generator", None) + if blocked is None: # Try to look up default value from class blocked = getattr(cls, "default_blocked", True) @@ -168,16 +170,19 @@ def __new__(cls, *args, **kwargs): # call __init__ _kwargs = kwargs.copy() _kwargs["rng"] = rng + _kwargs["step_id_generator"] = step_id_generator step.__init__([var], *args, **_kwargs) # Hack for creating the class correctly when unpickling. step.__newargs = ([var], *args), _kwargs steps.append(step) - return CompoundStep(steps) + return CompoundStep(steps, step_id_generator=step_id_generator) else: step = super().__new__(cls) step.stats_dtypes = stats_dtypes step.stats_dtypes_shapes = stats_dtypes_shapes + step._step_id = next(step_id_generator) if step_id_generator else None + # Hack for creating the class correctly when unpickling. step.__newargs = (vars, *args), kwargs return step @@ -223,7 +228,7 @@ def _make_update_stats_function(self) -> Callable[[dict, dict, int], dict]: def update_stats( displayed_stats: dict[str, np.ndarray], - step_stats: dict[str, str | float | int | bool | None], + step_stats_dict: dict[int, dict[str, str | float | int | bool | None]], chain_idx: int, ) -> dict[str, np.ndarray]: """ @@ -235,7 +240,7 @@ def update_stats( Dictionary of statistics displayed in the progress bar. The keys are the names of the statistics and the values are the current values of the statistics, with one value per chain being sampled. - step_stats: dict + step_stats_dict: dict Dictionary of statistics generated by the step sampler when taking the current step. The keys are the names of the statistics and the values are the values of the statistics generated by the step sampler. @@ -256,7 +261,9 @@ def __getnewargs_ex__(self): return self.__newargs @abstractmethod - def step(self, point: PointType) -> tuple[PointType, StatsType]: + def step( + self, point: PointType, step_parent_id: int | None = None + ) -> tuple[PointType, StatsType]: """Perform a single step of the sampler.""" @staticmethod @@ -315,7 +322,7 @@ class CompoundStep(WithSamplingState): _state_class = CompoundStepState - def __init__(self, methods): + def __init__(self, methods, step_id_generator: Iterator[int] | None = None): self.methods = list(methods) self.stats_dtypes = [] for method in self.methods: @@ -325,11 +332,12 @@ def __init__(self, methods): f"Compound[{', '.join(getattr(m, 'name', 'UNNAMED_STEP') for m in self.methods)}]" ) self.tune = True + self._step_id = next(step_id_generator) if step_id_generator else None - def step(self, point) -> tuple[PointType, StatsType]: + def step(self, point, step_parent_id: int | None = None) -> tuple[PointType, StatsType]: stats = [] for method in self.methods: - point, sts = method.step(point) + point, sts = method.step(point, step_parent_id=self._step_id) stats.extend(sts) # Model logp can only be the logp of the _last_ stats, # if there is one. Pop all others. @@ -409,7 +417,7 @@ def _progressbar_config( return columns, stats - def _make_update_stats_function(self) -> Callable[[dict, list[dict], int], dict]: + def _make_update_stats_function(self) -> Callable[[dict, dict[int, dict], int], dict]: """ Create an update function used by the progress bar to update statistics during sampling. @@ -419,11 +427,13 @@ def _make_update_stats_function(self) -> Callable[[dict, list[dict], int], dict] Function that updates displayed statistics for the current chain, given statistics generated by the step during the most recent step. """ - update_fns = [method._make_update_stats_function() for method in self.methods] + update_fns = { + method._step_id: method._make_update_stats_function() for method in self.methods + } def update_stats( displayed_stats: dict[str, np.ndarray], - step_stats: list[dict[str, str | float | int | bool | None]], + step_stats_dict: dict[int, dict[str, str | float | int | bool | None]], chain_idx: int, ) -> dict[str, np.ndarray]: """ @@ -435,7 +445,7 @@ def update_stats( Dictionary of statistics displayed in the progress bar. The keys are the names of the statistics and the values are the current values of the statistics, with one value per chain being sampled. - step_stats: list of dict + step_stats_dict: dict of dict List of dictionaries containing statistics generated by **each** step sampler in the CompoundStep when taking the current step. For each dictionary, the keys are names of statistics and the values are the values of the statistics generated by the step sampler. @@ -452,11 +462,9 @@ def update_stats( # In this case, the current loop logic is just overriding each Metropolis steps' stats with those of the # next step (so the user only ever sees the 3rd step's stats). We should have a better way to aggregate # the stats from each step. - if not isinstance(step_stats, list): - step_stats = [step_stats] - for step_stat, update_fn in zip(step_stats, update_fns): - displayed_stats = update_fn(displayed_stats, step_stat, chain_idx) + for step_id, update_fn in update_fns.items(): + displayed_stats = update_fn(displayed_stats, step_stats_dict, chain_idx) return displayed_stats diff --git a/pymc/step_methods/hmc/base_hmc.py b/pymc/step_methods/hmc/base_hmc.py index e8c96e8c4b..ccce50e9f5 100644 --- a/pymc/step_methods/hmc/base_hmc.py +++ b/pymc/step_methods/hmc/base_hmc.py @@ -18,6 +18,7 @@ import time from abc import abstractmethod +from collections.abc import Iterator from typing import Any, NamedTuple import numpy as np @@ -99,6 +100,7 @@ def __init__( step_rand=None, rng=None, initial_point: PointType | None = None, + step_id_generator: Iterator[int] | None = None, **pytensor_kwargs, ): """Set up Hamiltonian samplers with common structures. @@ -133,6 +135,7 @@ def __init__( **pytensor_kwargs: passed to PyTensor functions """ self._model = modelcontext(model) + self._step_id = next(step_id_generator) if step_id_generator else None if vars is None: vars = self._model.continuous_value_vars diff --git a/pymc/step_methods/hmc/nuts.py b/pymc/step_methods/hmc/nuts.py index 23affa54ee..98ca981b8e 100644 --- a/pymc/step_methods/hmc/nuts.py +++ b/pymc/step_methods/hmc/nuts.py @@ -278,7 +278,7 @@ def _make_update_stats_function(self) -> Callable[[dict, dict, int], dict]: def update_stats( displayed_stats: dict[str, np.ndarray], - step_stats: dict[str, str | float | int | bool | None], + step_stats_dict: dict[int, dict[str, str | float | int | bool | None]], chain_idx: int, ) -> dict[str, np.ndarray]: """ @@ -290,7 +290,7 @@ def update_stats( Dictionary of statistics displayed in the progress bar. The keys are the names of the statistics and the values are the current values of the statistics, with one value per chain being sampled. - step_stats: dict + step_stats_dict: dict Dictionary of statistics generated by the step sampler when taking the current step. The keys are the names of the statistics and the values are the values of the statistics generated by the step sampler. @@ -302,8 +302,7 @@ def update_stats( dict The updated statistics dictionary to be displayed in the progress bar. """ - if isinstance(step_stats, list): - step_stats = step_stats[0] + step_stats = step_stats_dict[self._step_id] if not step_stats["tune"]: displayed_stats["divergences"][chain_idx] += step_stats["diverging"] diff --git a/pymc/step_methods/metropolis.py b/pymc/step_methods/metropolis.py index 396a6693f8..e27ea83458 100644 --- a/pymc/step_methods/metropolis.py +++ b/pymc/step_methods/metropolis.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from collections.abc import Callable +from collections.abc import Callable, Iterator from dataclasses import field from typing import Any @@ -166,6 +166,7 @@ def __init__( initial_point: PointType | None = None, compile_kwargs: dict | None = None, blocked: bool = False, + step_id_generator: Iterator[int] | None = None, ): """Create an instance of a Metropolis stepper. @@ -258,7 +259,9 @@ def __init__( shared = pm.make_shared_replacements(initial_point, vars, model) self.delta_logp = delta_logp(initial_point, model.logp(), vars, shared, compile_kwargs) - super().__init__(vars, shared, blocked=blocked, rng=rng) + super().__init__( + vars, shared, blocked=blocked, rng=rng, step_id_generator=step_id_generator + ) def reset_tuning(self): """Reset the tuned sampler parameters to their initial values.""" @@ -375,7 +378,7 @@ def _make_update_stats_function(self) -> Callable[[dict, dict, int], dict]: def update_stats( displayed_stats: dict[str, np.ndarray], - step_stats: dict[str, str | float | int | bool | None], + step_stats_dict: dict[int, dict[str, str | float | int | bool | None]], chain_idx: int, ) -> dict[str, np.ndarray]: """ @@ -387,7 +390,7 @@ def update_stats( Dictionary of statistics displayed in the progress bar. The keys are the names of the statistics and the values are the current values of the statistics, with one value per chain being sampled. - step_stats: dict + step_stats_dict: dict Dictionary of statistics generated by the step sampler when taking the current step. The keys are the names of the statistics and the values are the values of the statistics generated by the step sampler. @@ -399,8 +402,7 @@ def update_stats( dict The updated statistics dictionary to be displayed in the progress bar. """ - if isinstance(step_stats, list): - step_stats = step_stats[0] + step_stats = step_stats_dict[self._step_id] displayed_stats["tune"][chain_idx] = step_stats["tune"] displayed_stats["accept_rate"][chain_idx] = step_stats["accept"] @@ -1002,7 +1004,9 @@ def __init__( initial_point: PointType | None = None, compile_kwargs: dict | None = None, blocked: bool = True, + step_id_generator: Iterator[int] | None = None, ): + self._step_id = next(step_id_generator) if step_id_generator else None model = pm.modelcontext(model) if initial_point is None: initial_point = model.initial_point() diff --git a/pymc/step_methods/slicer.py b/pymc/step_methods/slicer.py index f726efbdf5..8350e6b767 100644 --- a/pymc/step_methods/slicer.py +++ b/pymc/step_methods/slicer.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from collections.abc import Callable +from collections.abc import Callable, Iterator # Modified from original implementation by Dominik Wabersich (2013) import numpy as np @@ -34,9 +34,6 @@ LOOP_ERR_MSG = "max slicer iters %d exceeded" -dataclass_state - - @dataclass_state class SliceState(StepMethodState): w: np.ndarray @@ -90,7 +87,9 @@ def __init__( initial_point: PointType | None = None, compile_kwargs: dict | None = None, blocked: bool = False, # Could be true since tuning is independent across dims? + step_id_generator: Iterator[int] | None = None, ): + self._step_id = next(step_id_generator) if step_id_generator else None model = modelcontext(model) self.w = np.asarray(w).copy() self.tune = tune @@ -223,7 +222,7 @@ def _make_update_stats_function(self) -> Callable[[dict, dict, int], dict]: def update_stats( displayed_stats: dict[str, np.ndarray], - step_stats: dict[str, str | float | int | bool | None], + step_stats_dict: dict[int, dict[str, str | float | int | bool | None]], chain_idx: int, ) -> dict[str, np.ndarray]: """ @@ -235,7 +234,7 @@ def update_stats( Dictionary of statistics displayed in the progress bar. The keys are the names of the statistics and the values are the current values of the statistics, with one value per chain being sampled. - step_stats: dict + step_stats_dict: dict Dictionary of statistics generated by the step sampler when taking the current step. The keys are the names of the statistics and the values are the values of the statistics generated by the step sampler. @@ -247,8 +246,7 @@ def update_stats( dict The updated statistics dictionary to be displayed in the progress bar. """ - if isinstance(step_stats, list): - step_stats = step_stats[0] + step_stats = step_stats_dict[self._step_id] displayed_stats["tune"][chain_idx] = step_stats["tune"] displayed_stats["nstep_out"][chain_idx] = step_stats["nstep_out"] From 02f7f2b0b2aac33ddabf087f716a3e8df00ba1b7 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Mon, 24 Mar 2025 01:46:24 +0800 Subject: [PATCH 3/3] Update progress bar via step id --- pymc/util.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/pymc/util.py b/pymc/util.py index 979b3beebf..586d567204 100644 --- a/pymc/util.py +++ b/pymc/util.py @@ -883,7 +883,10 @@ def update(self, chain_idx, is_last, draw, tuning, stats): if not tuning and stats and stats[0].get("diverging"): self.divergences += 1 - self.progress_stats = self.update_stats(self.progress_stats, stats, chain_idx) + step_meta = [entry["step_meta"] for entry in stats] + step_id_to_stats = {meta["step_id"]: entry for meta, entry in zip(step_meta, stats)} + + self.progress_stats = self.update_stats(self.progress_stats, step_id_to_stats, chain_idx) more_updates = ( {stat: value[chain_idx] for stat, value in self.progress_stats.items()} if self.full_stats