Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 36 additions & 28 deletions meridian/analysis/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3959,36 +3959,44 @@ def rhat_summary(self, bad_rhat_threshold: float = 1.2) -> pd.DataFrame:
is not `1` or `2`.
"""
rhat = self.get_rhat()

rhat_summary = []

for param in rhat:
# Skip if parameter is deterministic according to the prior.
if self._meridian.prior_broadcast.has_deterministic_param(param):
continue

if rhat[param].ndim == 2:
row_idx, col_idx = np.where(rhat[param] > bad_rhat_threshold)
elif rhat[param].ndim == 1:
row_idx = np.where(rhat[param] > bad_rhat_threshold)[0]
col_idx = []
elif rhat[param].ndim == 0:
row_idx = col_idx = []
else:
raise ValueError(f"Unexpected dimension for parameter {param}.")

rhat_summary.append(
pd.Series({
constants.PARAM: param,
constants.N_PARAMS: np.prod(rhat[param].shape),
constants.AVG_RHAT: np.nanmean(rhat[param]),
constants.MAX_RHAT: np.nanmax(rhat[param]),
constants.PERCENT_BAD_RHAT: np.nanmean(
rhat[param] > bad_rhat_threshold
),
constants.ROW_IDX_BAD_RHAT: row_idx,
constants.COL_IDX_BAD_RHAT: col_idx,
})
)
# Skip if parameter is deterministic according to the prior.
if self._meridian.prior_broadcast.has_deterministic_param(param):
continue
param_rhat = rhat[param]

# Handle scalar case (0D array)
if param_rhat.ndim == 0:
row_idx = col_idx = []
param_rhat = np.array([param_rhat]) # Convert to 1D for consistent processing

# Handle vector case (1D array)
elif param_rhat.ndim == 1:
row_idx = np.where(param_rhat > bad_rhat_threshold)[0]
col_idx = []

# Handle matrix case (2D array)
elif param_rhat.ndim == 2:
row_idx, col_idx = np.where(param_rhat > bad_rhat_threshold)

else:
raise ValueError(f"Unexpected dimension {param_rhat.ndim} for parameter {param}.")

# Calculate statistics
flat_rhat = param_rhat.ravel() # Flatten for consistent calculations
rhat_summary.append(
pd.Series({
constants.PARAM: param,
constants.N_PARAMS: np.prod(param_rhat.shape),
constants.AVG_RHAT: np.nanmean(flat_rhat),
constants.MAX_RHAT: np.nanmax(flat_rhat),
constants.PERCENT_BAD_RHAT: np.nanmean(flat_rhat > bad_rhat_threshold),
constants.ROW_IDX_BAD_RHAT: row_idx,
constants.COL_IDX_BAD_RHAT: col_idx,
})
)
return pd.DataFrame(rhat_summary)

def response_curves(
Expand Down