Fix progress bar error when nested CompoundStep samplers are assigned #7730

Open · wants to merge 3 commits into main
4 changes: 4 additions & 0 deletions pymc/backends/ndarray.py
@@ -113,6 +113,10 @@ def record(self, point, sampler_stats=None) -> None:
if sampler_stats is not None:
for data, vars in zip(self._stats, sampler_stats):
for key, val in vars.items():
# step_meta is a key used by the progress bars to track which draw came from which step instance. It
# should never be stored as a sampler statistic.
if key == "step_meta":
continue
data[key][draw_idx] = val
elif self._stats is not None:
raise ValueError("Expected sampler_stats")
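
To illustrate the filter above: a minimal sketch (not part of the diff) with a hypothetical sampler_stats payload. The step_meta entry added for progress-bar routing is dropped, so only genuine sampler statistics are written to the trace.

# Sketch (not part of the diff): how a hypothetical stats payload is filtered before storage.
sampler_stats = [
    {
        "accept": 0.87,
        "tune": False,
        "step_meta": {"step_id": 0, "step_parent_id": None},  # progress-bar bookkeeping only
    }
]

recorded = [
    {key: val for key, val in group.items() if key != "step_meta"}
    for group in sampler_stats
]
print(recorded)  # [{'accept': 0.87, 'tune': False}] -- step_meta never reaches the trace
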
11 changes: 10 additions & 1 deletion pymc/sampling/mcmc.py
@@ -15,6 +15,7 @@
"""Functions for MCMC sampling."""

import contextlib
import itertools
import logging
import pickle
import sys
@@ -111,6 +112,7 @@ def instantiate_steppers(
step_kwargs: dict[str, dict] | None = None,
initial_point: PointType | None = None,
compile_kwargs: dict | None = None,
step_id_generator: Iterator[int] | None = None,
) -> Step | list[Step]:
"""Instantiate steppers assigned to the model variables.
@@ -139,6 +141,9 @@ def instantiate_steppers(
if step_kwargs is None:
step_kwargs = {}

if step_id_generator is None:
step_id_generator = itertools.count()

used_keys = set()
if selected_steps:
if initial_point is None:
@@ -154,6 +159,7 @@ def instantiate_steppers(
model=model,
initial_point=initial_point,
compile_kwargs=compile_kwargs,
step_id_generator=step_id_generator,
**kwargs,
)
steps.append(step)
@@ -853,16 +859,19 @@ def joined_blas_limiter():
initial_points = [ipfn(seed) for ipfn, seed in zip(ipfns, random_seed_list)]

# Instantiate automatically selected steps
# Use a counter to generate a unique id for each stepper used in the model.
step_id_generator = itertools.count()
step = instantiate_steppers(
model,
steps=provided_steps,
selected_steps=selected_steps,
step_kwargs=kwargs,
initial_point=initial_points[0],
compile_kwargs=compile_kwargs,
step_id_generator=step_id_generator,
)
if isinstance(step, list):
step = CompoundStep(step)
step = CompoundStep(step, step_id_generator=step_id_generator)

if var_names is not None:
trace_vars = [v for v in model.unobserved_RVs if v.name in var_names]
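
The counter shared here is plain itertools.count: every stepper created by instantiate_steppers and the wrapping CompoundStep draw from the same generator, so each step instance in the model gets a unique integer id. A minimal sketch of the assignment, with hypothetical stepper names:

import itertools

step_id_generator = itertools.count()

# Hypothetical assignment order, mirroring instantiate_steppers followed by CompoundStep:
metropolis_id = next(step_id_generator)  # 0
nuts_id = next(step_id_generator)        # 1
compound_id = next(step_id_generator)    # 2 (the wrapping CompoundStep)

print(metropolis_id, nuts_id, compound_id)  # 0 1 2 -- no two step instances share an id
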
37 changes: 32 additions & 5 deletions pymc/step_methods/arraystep.py
@@ -13,7 +13,7 @@
# limitations under the License.

from abc import abstractmethod
from collections.abc import Callable
from collections.abc import Callable, Iterator
from typing import cast

import numpy as np
@@ -43,14 +43,25 @@ class ArrayStep(BlockedStep):
:py:func:`pymc.util.get_random_generator` for more information.
"""

def __init__(self, vars, fs, allvars=False, blocked=True, rng: RandomGenerator = None):
def __init__(
self,
vars,
fs,
allvars=False,
blocked=True,
rng: RandomGenerator = None,
step_id_generator: Iterator[int] | None = None,
):
self.vars = vars
self.fs = fs
self.allvars = allvars
self.blocked = blocked
self.rng = get_random_generator(rng)
self._step_id = next(step_id_generator) if step_id_generator else None

def step(self, point: PointType) -> tuple[PointType, StatsType]:
def step(
self, point: PointType, step_parent_id: int | None = None
) -> tuple[PointType, StatsType]:
partial_funcs_and_point: list[Callable | PointType] = [
DictToArrayBijection.mapf(x, start_point=point) for x in self.fs
]
@@ -61,6 +72,9 @@ def step(self, point: PointType) -> tuple[PointType, StatsType]:
apoint = DictToArrayBijection.map(var_dict)
apoint_new, stats = self.astep(apoint, *partial_funcs_and_point)

for sts in stats:
sts["step_meta"] = {"step_id": self._step_id, "step_parent_id": step_parent_id}

if not isinstance(apoint_new, RaveledVars):
# We assume that the mapping has stayed the same
apoint_new = RaveledVars(apoint_new, apoint.point_map_info)
@@ -84,7 +98,14 @@ class ArrayStepShared(BlockedStep):
and unmapping overhead as well as moving fewer variables around.
"""

def __init__(self, vars, shared, blocked=True, rng: RandomGenerator = None):
def __init__(
self,
vars,
shared,
blocked=True,
rng: RandomGenerator = None,
step_id_generator: Iterator[int] | None = None,
):
"""
Create the ArrayStepShared object.
@@ -103,8 +124,11 @@ def __init__(self, vars, shared, blocked=True, rng: RandomGenerator = None):
self.shared = {get_var_name(var): shared for var, shared in shared.items()}
self.blocked = blocked
self.rng = get_random_generator(rng)
self._step_id = next(step_id_generator) if step_id_generator else None

def step(self, point: PointType) -> tuple[PointType, StatsType]:
def step(
self, point: PointType, step_parent_id: int | None = None
) -> tuple[PointType, StatsType]:
full_point = None
if self.shared:
for name, shared_var in self.shared.items():
@@ -115,6 +139,9 @@ def step(self, point: PointType) -> tuple[PointType, StatsType]:
q = DictToArrayBijection.map(point)
apoint, stats = self.astep(q)

for sts in stats:
sts["step_meta"] = {"step_id": self._step_id, "step_parent_id": step_parent_id}

if not isinstance(apoint, RaveledVars):
# We assume that the mapping has stayed the same
apoint = RaveledVars(apoint, q.point_map_info)
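
The tagging added to both step methods above boils down to the following sketch (ids and stats are hypothetical): every stats dict a step emits carries its own id plus the id of the CompoundStep that invoked it, which is what later distinguishes draws coming from nested compound steps.

stats = [{"accept": 0.42, "tune": True}]
_step_id = 3         # this step's own id
step_parent_id = 7   # id of the CompoundStep that called step(), or None at the top level

for sts in stats:
    sts["step_meta"] = {"step_id": _step_id, "step_parent_id": step_parent_id}

print(stats)
# [{'accept': 0.42, 'tune': True, 'step_meta': {'step_id': 3, 'step_parent_id': 7}}]
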
165 changes: 146 additions & 19 deletions pymc/step_methods/compound.py
@@ -21,14 +21,15 @@
import warnings

from abc import ABC, abstractmethod
from collections.abc import Iterable, Mapping, Sequence
from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence
from dataclasses import field
from enum import IntEnum, unique
from typing import Any

import numpy as np

from pytensor.graph.basic import Variable
from rich.progress import ProgressColumn

from pymc.blocking import PointType, StatDtype, StatsDict, StatShape, StatsType
from pymc.model import modelcontext
@@ -124,6 +125,8 @@ class BlockedStep(ABC, WithSamplingState):

def __new__(cls, *args, **kwargs):
blocked = kwargs.get("blocked")
step_id_generator = kwargs.pop("step_id_generator", None)

if blocked is None:
# Try to look up default value from class
blocked = getattr(cls, "default_blocked", True)
@@ -167,31 +170,89 @@ def __new__(cls, *args, **kwargs):
# call __init__
_kwargs = kwargs.copy()
_kwargs["rng"] = rng
_kwargs["step_id_generator"] = step_id_generator
step.__init__([var], *args, **_kwargs)
# Hack for creating the class correctly when unpickling.
step.__newargs = ([var], *args), _kwargs
steps.append(step)

return CompoundStep(steps)
return CompoundStep(steps, step_id_generator=step_id_generator)
else:
step = super().__new__(cls)
step.stats_dtypes = stats_dtypes
step.stats_dtypes_shapes = stats_dtypes_shapes
step._step_id = next(step_id_generator) if step_id_generator else None

# Hack for creating the class correctly when unpickling.
step.__newargs = (vars, *args), kwargs
return step

@staticmethod
def _progressbar_config(n_chains=1):
def _progressbar_config(self, n_chains: int = 1):
"""
Get progressbar configuration for this step sampler.
By default, the progress bar displays no stats columns, only basic info (number of draws and sampling time).
Specific step methods should override this method to specify which stats to display and how.
Parameters
----------
n_chains: int
Number of chains being sampled. This controls the number of progress bars that will be displayed.
Returns
-------
columns: list of rich.progress.ProgressColumn
List of columns to display in the progress bar.
stats: dict
Dictionary of statistics associated with each column.
"""
columns = []
stats = {}

return columns, stats

@staticmethod
def _make_update_stats_function():
def update_stats(stats, step_stats, chain_idx):
return stats
def _make_update_stats_function(self) -> Callable[[dict, dict, int], dict]:
"""
Create an update function used by the progress bar to update statistics during sampling.
By default, the update is a no-op. Specific step methods should override this to decide which
statistics to display and how to update them.
Returns
-------
update_stats: Callable
Function that updates displayed statistics for the current chain, given statistics generated by the step
during the most recent step.
"""

def update_stats(
displayed_stats: dict[str, np.ndarray],
step_stats_dict: dict[int, dict[str, str | float | int | bool | None]],
chain_idx: int,
) -> dict[str, np.ndarray]:
"""
Update the statistics displayed in the progress bar after each step.
Parameters
----------
displayed_stats: dict
Dictionary of statistics displayed in the progress bar. The keys are the names of the statistics and
the values are the current values of the statistics, with one value per chain being sampled.
step_stats_dict: dict
Dictionary mapping step ids to the statistics generated by the corresponding step sampler when taking the
current step. For each entry, the keys are the names of the statistics and the values are their current
values.
chain_idx: int
The chain number associated with the current step.
Returns
-------
dict
The updated statistics dictionary to be displayed in the progress bar.
"""
return displayed_stats

return update_stats

@@ -200,7 +261,9 @@ def __getnewargs_ex__(self):
return self.__newargs

@abstractmethod
def step(self, point: PointType) -> tuple[PointType, StatsType]:
def step(
self, point: PointType, step_parent_id: int | None = None
) -> tuple[PointType, StatsType]:
"""Perform a single step of the sampler."""

@staticmethod
@@ -259,7 +322,7 @@ class CompoundStep(WithSamplingState):

_state_class = CompoundStepState

def __init__(self, methods):
def __init__(self, methods, step_id_generator: Iterator[int] | None = None):
self.methods = list(methods)
self.stats_dtypes = []
for method in self.methods:
@@ -269,11 +332,12 @@ def __init__(self, methods):
f"Compound[{', '.join(getattr(m, 'name', 'UNNAMED_STEP') for m in self.methods)}]"
)
self.tune = True
self._step_id = next(step_id_generator) if step_id_generator else None

def step(self, point) -> tuple[PointType, StatsType]:
def step(self, point, step_parent_id: int | None = None) -> tuple[PointType, StatsType]:
stats = []
for method in self.methods:
point, sts = method.step(point)
point, sts = method.step(point, step_parent_id=self._step_id)
stats.extend(sts)
# Model logp can only be the logp of the _last_ stats,
# if there is one. Pop all others.
@@ -311,7 +375,28 @@ def set_rng(self, rng: RandomGenerator):
for method, _rng in zip(self.methods, _rngs):
method.set_rng(_rng)

def _progressbar_config(self, n_chains=1):
def _progressbar_config(
self, n_chains: int = 1
) -> tuple[list[ProgressColumn], dict[str, np.ndarray | float]]:
"""
Get progressbar configuration for this step sampler.
The columns of the rich progress bar displayed during sampling are chosen by the step samplers themselves. In
the compound step case, we display the set union of all columns from the sub-step samplers.
Parameters
----------
n_chains: int
Number of chains being sampled. This controls the number of progress bars that will be displayed.
Returns
-------
columns: list of rich.progress.ProgressColumn
List of columns to display in the progress bar.
stats: dict
Dictionary of statistics associated with each column.
"""
from functools import reduce

column_lists, stat_dict_list = zip(
@@ -332,14 +417,56 @@ def _progressbar_config(self, n_chains=1):

return columns, stats

def _make_update_stats_function(self):
update_fns = [method._make_update_stats_function() for method in self.methods]
def _make_update_stats_function(self) -> Callable[[dict, dict[int, dict], int], dict]:
"""
Create an update function used by the progress bar to update statistics during sampling.
def update_stats(stats, step_stats, chain_idx):
for step_stat, update_fn in zip(step_stats, update_fns):
stats = update_fn(stats, step_stat, chain_idx)
Returns
-------
update_stats: Callable
Function that updates displayed statistics for the current chain, given statistics generated by the step
during the most recent step.
"""
update_fns = {
method._step_id: method._make_update_stats_function() for method in self.methods
}

return stats
def update_stats(
displayed_stats: dict[str, np.ndarray],
step_stats_dict: dict[int, dict[str, str | float | int | bool | None]],
chain_idx: int,
) -> dict[str, np.ndarray]:
"""
Update the statistics displayed in the progress bar after each step.
Parameters
----------
displayed_stats: dict
Dictionary of statistics displayed in the progress bar. The keys are the names of the statistics and
the values are the current values of the statistics, with one value per chain being sampled.
step_stats_dict: dict of dict
Dictionary mapping step ids to the statistics generated by **each** step sampler in the CompoundStep when
taking the current step. For each inner dictionary, the keys are names of statistics and the values are
their current values.
chain_idx: int
The chain number associated with the current step.
Returns
-------
dict
The updated statistics dictionary to be displayed in the progress bar.
"""
# TODO: The compound step is commonly made of many instances of the same step (e.g. 3 Metropolis steps).
# In this case, the current loop logic is just overriding each Metropolis step's stats with those of the
# next step (so the user only ever sees the 3rd step's stats). We should have a better way to aggregate
# the stats from each step.

for step_id, update_fn in update_fns.items():
displayed_stats = update_fn(displayed_stats, step_stats_dict, chain_idx)

return displayed_stats

return update_stats
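
A self-contained sketch of the routing this rewritten _make_update_stats_function performs; the ids, stats, and make_accept_rate_update helper below are hypothetical. Each per-step update function now looks up its own entry in step_stats_dict by id rather than consuming a positional list, and the final print also shows the last-step-wins behaviour the TODO describes.

import numpy as np

def make_accept_rate_update(step_id):
    # Hypothetical Metropolis-style update function bound to one step id.
    def update(displayed_stats, step_stats_dict, chain_idx):
        step_stats = step_stats_dict[step_id]  # only this step's statistics
        displayed_stats["accept_rate"][chain_idx] = step_stats["accept"]
        return displayed_stats
    return update

update_fns = {0: make_accept_rate_update(0), 1: make_accept_rate_update(1)}

displayed_stats = {"accept_rate": np.zeros(4)}
step_stats_dict = {0: {"accept": 0.2}, 1: {"accept": 0.9}}

for step_id, update_fn in update_fns.items():
    displayed_stats = update_fn(displayed_stats, step_stats_dict, chain_idx=0)

print(displayed_stats["accept_rate"][0])  # 0.9 -- the last step's value wins, per the TODO
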

3 changes: 3 additions & 0 deletions pymc/step_methods/hmc/base_hmc.py
@@ -18,6 +18,7 @@
import time

from abc import abstractmethod
from collections.abc import Iterator
from typing import Any, NamedTuple

import numpy as np
@@ -99,6 +100,7 @@ def __init__(
step_rand=None,
rng=None,
initial_point: PointType | None = None,
step_id_generator: Iterator[int] | None = None,
**pytensor_kwargs,
):
"""Set up Hamiltonian samplers with common structures.
@@ -133,6 +135,7 @@ def __init__(
**pytensor_kwargs: passed to PyTensor functions
"""
self._model = modelcontext(model)
self._step_id = next(step_id_generator) if step_id_generator else None

if vars is None:
vars = self._model.continuous_value_vars
75 changes: 63 additions & 12 deletions pymc/step_methods/hmc/nuts.py
@@ -15,12 +15,13 @@
from __future__ import annotations

from collections import namedtuple
from collections.abc import Callable
from dataclasses import field

import numpy as np

from pytensor import config
from rich.progress import TextColumn
from rich.progress import ProgressColumn, TextColumn
from rich.table import Column

from pymc.stats.convergence import SamplerWarning
@@ -231,8 +232,25 @@ def competence(var, has_grad):
return Competence.PREFERRED
return Competence.INCOMPATIBLE

@staticmethod
def _progressbar_config(n_chains=1):
def _progressbar_config(
self, n_chains: int = 1
) -> tuple[list[ProgressColumn], dict[str, np.ndarray | float]]:
"""
Get progressbar configuration for this step sampler.
Parameters
----------
n_chains: int
Number of chains being sampled. This controls the number of progress bars that will be displayed.
Returns
-------
columns: list of rich.progress.ProgressColumn
List of columns to display in the progress bar.
stats: dict
Dictionary of statistics associated with each column.
"""
columns = [
TextColumn("{task.fields[divergences]}", table_column=Column("Divergences", ratio=1)),
TextColumn("{task.fields[step_size]:0.2f}", table_column=Column("Step size", ratio=1)),
@@ -247,18 +265,51 @@ def _progressbar_config(n_chains=1):

return columns, stats

@staticmethod
def _make_update_stats_function():
def update_stats(stats, step_stats, chain_idx):
if isinstance(step_stats, list):
step_stats = step_stats[0]
def _make_update_stats_function(self) -> Callable[[dict, dict, int], dict]:
"""
Create an update function used by the progress bar to update statistics during sampling.
Returns
-------
update_stats: Callable
Function that updates displayed statistics for the current chain, given statistics generated by the step
during the most recent step.
"""

def update_stats(
displayed_stats: dict[str, np.ndarray],
step_stats_dict: dict[int, dict[str, str | float | int | bool | None]],
chain_idx: int,
) -> dict[str, np.ndarray]:
"""
Update the statistics displayed in the progress bar after each step.
Parameters
----------
displayed_stats: dict
Dictionary of statistics displayed in the progress bar. The keys are the names of the statistics and
the values are the current values of the statistics, with one value per chain being sampled.
step_stats_dict: dict
Dictionary mapping step ids to the statistics generated by the corresponding step sampler when taking the
current step. For each entry, the keys are the names of the statistics and the values are their current
values.
chain_idx: int
The chain number associated with the current step.
Returns
-------
dict
The updated statistics dictionary to be displayed in the progress bar.
"""
step_stats = step_stats_dict[self._step_id]

if not step_stats["tune"]:
stats["divergences"][chain_idx] += step_stats["diverging"]
displayed_stats["divergences"][chain_idx] += step_stats["diverging"]

stats["step_size"][chain_idx] = step_stats["step_size"]
stats["tree_size"][chain_idx] = step_stats["tree_size"]
return stats
displayed_stats["step_size"][chain_idx] = step_stats["step_size"]
displayed_stats["tree_size"][chain_idx] = step_stats["tree_size"]
return displayed_stats

return update_stats
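
For reference, a sketch of the (columns, stats) pair that _progressbar_config returns. The two columns mirror those visible in the diff above; the per-chain arrays in stats are an assumption about their shape, based on how update_stats indexes them by chain_idx.

import numpy as np
from rich.progress import TextColumn
from rich.table import Column

n_chains = 4
columns = [
    TextColumn("{task.fields[divergences]}", table_column=Column("Divergences", ratio=1)),
    TextColumn("{task.fields[step_size]:0.2f}", table_column=Column("Step size", ratio=1)),
]
stats = {
    "divergences": np.zeros(n_chains, dtype=int),
    "step_size": np.zeros(n_chains),
    "tree_size": np.zeros(n_chains),
}
# One progress bar per chain; each update writes into its chain's slot of these arrays.
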

83 changes: 69 additions & 14 deletions pymc/step_methods/metropolis.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Callable
from collections.abc import Callable, Iterator
from dataclasses import field
from typing import Any

@@ -24,7 +24,7 @@
from pytensor import tensor as pt
from pytensor.graph.fg import MissingInputError
from pytensor.tensor.random.basic import BernoulliRV, CategoricalRV
from rich.progress import TextColumn
from rich.progress import ProgressColumn, TextColumn
from rich.table import Column

import pymc as pm
@@ -166,6 +166,7 @@ def __init__(
initial_point: PointType | None = None,
compile_kwargs: dict | None = None,
blocked: bool = False,
step_id_generator: Iterator[int] | None = None,
):
"""Create an instance of a Metropolis stepper.
@@ -258,7 +259,9 @@ def __init__(

shared = pm.make_shared_replacements(initial_point, vars, model)
self.delta_logp = delta_logp(initial_point, model.logp(), vars, shared, compile_kwargs)
super().__init__(vars, shared, blocked=blocked, rng=rng)
super().__init__(
vars, shared, blocked=blocked, rng=rng, step_id_generator=step_id_generator
)

def reset_tuning(self):
"""Reset the tuned sampler parameters to their initial values."""
@@ -327,8 +330,25 @@ def astep(self, q0: RaveledVars) -> tuple[RaveledVars, StatsType]:
def competence(var, has_grad):
return Competence.COMPATIBLE

@staticmethod
def _progressbar_config(n_chains=1):
def _progressbar_config(
self, n_chains: int = 1
) -> tuple[list[ProgressColumn], dict[str, np.ndarray | float]]:
"""
Get progressbar configuration for this step sampler.
Parameters
----------
n_chains: int
Number of chains being sampled. This controls the number of progress bars that will be displayed.
Returns
-------
columns: list of rich.progress.ProgressColumn
List of columns to display in the progress bar.
stats: dict
Dictionary of statistics associated with each column.
"""
columns = [
TextColumn("{task.fields[tune]}", table_column=Column("Tuning", ratio=1)),
TextColumn("{task.fields[scaling]:0.2f}", table_column=Column("Scaling", ratio=1)),
@@ -345,17 +365,50 @@ def _progressbar_config(n_chains=1):

return columns, stats

@staticmethod
def _make_update_stats_function():
def update_stats(stats, step_stats, chain_idx):
if isinstance(step_stats, list):
step_stats = step_stats[0]
def _make_update_stats_function(self) -> Callable[[dict, dict, int], dict]:
"""
Create an update function used by the progress bar to update statistics during sampling.
Returns
-------
update_stats: Callable
Function that updates displayed statistics for the current chain, given statistics generated by the step
during the most recent step.
"""

def update_stats(
displayed_stats: dict[str, np.ndarray],
step_stats_dict: dict[int, dict[str, str | float | int | bool | None]],
chain_idx: int,
) -> dict[str, np.ndarray]:
"""
Update the statistics displayed in the progress bar after each step.
Parameters
----------
displayed_stats: dict
Dictionary of statistics displayed in the progress bar. The keys are the names of the statistics and
the values are the current values of the statistics, with one value per chain being sampled.
step_stats_dict: dict
Dictionary mapping step ids to the statistics generated by the corresponding step sampler when taking the
current step. For each entry, the keys are the names of the statistics and the values are their current
values.
chain_idx: int
The chain number associated with the current step.
Returns
-------
dict
The updated statistics dictionary to be displayed in the progress bar.
"""
step_stats = step_stats_dict[self._step_id]

stats["tune"][chain_idx] = step_stats["tune"]
stats["accept_rate"][chain_idx] = step_stats["accept"]
stats["scaling"][chain_idx] = step_stats["scaling"]
displayed_stats["tune"][chain_idx] = step_stats["tune"]
displayed_stats["accept_rate"][chain_idx] = step_stats["accept"]
displayed_stats["scaling"][chain_idx] = step_stats["scaling"]

return stats
return displayed_stats

return update_stats

@@ -951,7 +1004,9 @@ def __init__(
initial_point: PointType | None = None,
compile_kwargs: dict | None = None,
blocked: bool = True,
step_id_generator: Iterator[int] | None = None,
):
self._step_id = next(step_id_generator) if step_id_generator else None
model = pm.modelcontext(model)
if initial_point is None:
initial_point = model.initial_point()
63 changes: 47 additions & 16 deletions pymc/step_methods/slicer.py
@@ -11,10 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Callable, Iterator

# Modified from original implementation by Dominik Wabersich (2013)


import numpy as np

from rich.progress import TextColumn
@@ -35,9 +34,6 @@
LOOP_ERR_MSG = "max slicer iters %d exceeded"


dataclass_state


@dataclass_state
class SliceState(StepMethodState):
w: np.ndarray
@@ -91,7 +87,9 @@ def __init__(
initial_point: PointType | None = None,
compile_kwargs: dict | None = None,
blocked: bool = False, # Could be true since tuning is independent across dims?
step_id_generator: Iterator[int] | None = None,
):
self._step_id = next(step_id_generator) if step_id_generator else None
model = modelcontext(model)
self.w = np.asarray(w).copy()
self.tune = tune
@@ -211,16 +209,49 @@ def _progressbar_config(n_chains=1):

return columns, stats

@staticmethod
def _make_update_stats_function():
def update_stats(stats, step_stats, chain_idx):
if isinstance(step_stats, list):
step_stats = step_stats[0]

stats["tune"][chain_idx] = step_stats["tune"]
stats["nstep_out"][chain_idx] = step_stats["nstep_out"]
stats["nstep_in"][chain_idx] = step_stats["nstep_in"]

return stats
def _make_update_stats_function(self) -> Callable[[dict, dict, int], dict]:
"""
Create an update function used by the progress bar to update statistics during sampling.
Returns
-------
update_stats: Callable
Function that updates displayed statistics for the current chain, given statistics generated by the step
during the most recent step.
"""

def update_stats(
displayed_stats: dict[str, np.ndarray],
step_stats_dict: dict[int, dict[str, str | float | int | bool | None]],
chain_idx: int,
) -> dict[str, np.ndarray]:
"""
Update the statistics displayed in the progress bar after each step.
Parameters
----------
displayed_stats: dict
Dictionary of statistics displayed in the progress bar. The keys are the names of the statistics and
the values are the current values of the statistics, with one value per chain being sampled.
step_stats_dict: dict
Dictionary mapping step ids to the statistics generated by the corresponding step sampler when taking the
current step. For each entry, the keys are the names of the statistics and the values are their current
values.
chain_idx: int
The chain number associated with the current step.
Returns
-------
dict
The updated statistics dictionary to be displayed in the progress bar.
"""
step_stats = step_stats_dict[self._step_id]

displayed_stats["tune"][chain_idx] = step_stats["tune"]
displayed_stats["nstep_out"][chain_idx] = step_stats["nstep_out"]
displayed_stats["nstep_in"][chain_idx] = step_stats["nstep_in"]

return displayed_stats

return update_stats
5 changes: 4 additions & 1 deletion pymc/util.py
@@ -883,7 +883,10 @@ def update(self, chain_idx, is_last, draw, tuning, stats):
if not tuning and stats and stats[0].get("diverging"):
self.divergences += 1

self.progress_stats = self.update_stats(self.progress_stats, stats, chain_idx)
step_meta = [entry["step_meta"] for entry in stats]
step_id_to_stats = {meta["step_id"]: entry for meta, entry in zip(step_meta, stats)}

self.progress_stats = self.update_stats(self.progress_stats, step_id_to_stats, chain_idx)
more_updates = (
{stat: value[chain_idx] for stat, value in self.progress_stats.items()}
if self.full_stats