diff --git a/meridian/analysis/analyzer.py b/meridian/analysis/analyzer.py index b29522fca..a693b6e86 100644 --- a/meridian/analysis/analyzer.py +++ b/meridian/analysis/analyzer.py @@ -3959,36 +3959,44 @@ def rhat_summary(self, bad_rhat_threshold: float = 1.2) -> pd.DataFrame: is not `1` or `2`. """ rhat = self.get_rhat() - rhat_summary = [] + for param in rhat: - # Skip if parameter is deterministic according to the prior. - if self._meridian.prior_broadcast.has_deterministic_param(param): - continue - - if rhat[param].ndim == 2: - row_idx, col_idx = np.where(rhat[param] > bad_rhat_threshold) - elif rhat[param].ndim == 1: - row_idx = np.where(rhat[param] > bad_rhat_threshold)[0] - col_idx = [] - elif rhat[param].ndim == 0: - row_idx = col_idx = [] - else: - raise ValueError(f"Unexpected dimension for parameter {param}.") - - rhat_summary.append( - pd.Series({ - constants.PARAM: param, - constants.N_PARAMS: np.prod(rhat[param].shape), - constants.AVG_RHAT: np.nanmean(rhat[param]), - constants.MAX_RHAT: np.nanmax(rhat[param]), - constants.PERCENT_BAD_RHAT: np.nanmean( - rhat[param] > bad_rhat_threshold - ), - constants.ROW_IDX_BAD_RHAT: row_idx, - constants.COL_IDX_BAD_RHAT: col_idx, - }) - ) + # Skip if parameter is deterministic according to the prior. + if self._meridian.prior_broadcast.has_deterministic_param(param): + continue + param_rhat = rhat[param] + + # Handle scalar case (0D array) + if param_rhat.ndim == 0: + row_idx = col_idx = [] + param_rhat = np.array([param_rhat]) # Convert to 1D for consistent processing + + # Handle vector case (1D array) + elif param_rhat.ndim == 1: + row_idx = np.where(param_rhat > bad_rhat_threshold)[0] + col_idx = [] + + # Handle matrix case (2D array) + elif param_rhat.ndim == 2: + row_idx, col_idx = np.where(param_rhat > bad_rhat_threshold) + + else: + raise ValueError(f"Unexpected dimension {param_rhat.ndim} for parameter {param}.") + + # Calculate statistics + flat_rhat = param_rhat.ravel() # Flatten for consistent calculations + rhat_summary.append( + pd.Series({ + constants.PARAM: param, + constants.N_PARAMS: np.prod(param_rhat.shape), + constants.AVG_RHAT: np.nanmean(flat_rhat), + constants.MAX_RHAT: np.nanmax(flat_rhat), + constants.PERCENT_BAD_RHAT: np.nanmean(flat_rhat > bad_rhat_threshold), + constants.ROW_IDX_BAD_RHAT: row_idx, + constants.COL_IDX_BAD_RHAT: col_idx, + }) + ) return pd.DataFrame(rhat_summary) def response_curves(