diff --git a/src/autocast/metrics/__init__.py b/src/autocast/metrics/__init__.py
index 5aad06b6..64dae463 100644
--- a/src/autocast/metrics/__init__.py
+++ b/src/autocast/metrics/__init__.py
@@ -1,6 +1,6 @@
 from .coverage import Coverage, MultiCoverage
 from .deterministic import MAE, MSE, NMAE, NMSE, NRMSE, RMSE, VMSE, VRMSE, LInfinity
-from .ensemble import CRPS, AlphaFairCRPS, FairCRPS
+from .ensemble import CRPS, AlphaFairCRPS, FairCRPS, SpreadSkillRatio
 
 __all__ = [
     "CRPS",
@@ -17,7 +17,15 @@
     "FairCRPS",
     "LInfinity",
     "MultiCoverage",
+    "SpreadSkillRatio",
 ]
 
 ALL_DETERMINISTIC_METRICS = (MSE, MAE, NMAE, NMSE, RMSE, NRMSE, VMSE, VRMSE, LInfinity)
-ALL_ENSEMBLE_METRICS = (CRPS, AlphaFairCRPS, FairCRPS, Coverage, MultiCoverage)
+ALL_ENSEMBLE_METRICS = (
+    CRPS,
+    AlphaFairCRPS,
+    FairCRPS,
+    SpreadSkillRatio,
+    Coverage,
+    MultiCoverage,
+)
diff --git a/src/autocast/metrics/ensemble.py b/src/autocast/metrics/ensemble.py
index d276bba8..266394ae 100644
--- a/src/autocast/metrics/ensemble.py
+++ b/src/autocast/metrics/ensemble.py
@@ -279,3 +279,63 @@ def _score(self, y_pred: TensorBTSCM, y_true: TensorBTSC) -> TensorBTC:
         afcrps_reduced = afcrps.mean(dim=tuple(range(2, 2 + n_spatial_dims)))
 
         return afcrps_reduced
+
+
+class SpreadSkillRatio(BTSCMMetric):
+    r"""
+    Corrected spread-to-skill ratio (SSR) for ensemble forecasts.
+
+    Notes
+    -----
+    Uses the corrected finite-ensemble form:
+    .. math::
+        \text{SSR}_{\text{corrected}} = \frac{\text{Spread}}{\text{Skill}}
+        \sqrt{\frac{M + 1}{M}},
+    where skill is RMSE of ensemble mean and spread is ensemble standard deviation,
+    both aggregated over spatial dimensions.
+
+    """
+
+    name: str = "ssr"
+
+    def __init__(self, eps: float = 1e-6):
+        super().__init__()
+        if eps <= 0:
+            msg = "eps must be > 0"
+            raise ValueError(msg)
+        self.eps = eps
+
+    def _score(self, y_pred: TensorBTSCM, y_true: TensorBTSC) -> TensorBTC:
+        """
+        Compute corrected spread-to-skill ratio reduced over spatial dims.
+
+        Args:
+            y_pred: (B, T, S, C, M)
+            y_true: (B, T, S, C)
+
+        Returns
+        -------
+            SSR: (B, T, C)
+        """
+        n_ensemble = y_pred.shape[-1]
+        if n_ensemble < 2:
+            raise ValueError(
+                "SpreadSkillRatio requires at least 2 ensemble members "
+                f"(got {n_ensemble})."
+            )
+
+        # Ensemble mean forecast: (B, T, S, C)
+        ensemble_mean = y_pred.mean(dim=-1)
+        n_spatial_dims = self._infer_n_spatial_dims(ensemble_mean)
+        spatial_dims = tuple(range(2, 2 + n_spatial_dims))
+
+        # Skill = RMSE of ensemble mean over spatial dims: (B, T, C)
+        skill = torch.sqrt(((ensemble_mean - y_true) ** 2).mean(dim=spatial_dims))
+
+        # Spread = sqrt(mean spatial ensemble variance): (B, T, C)
+        spread_variance = y_pred.var(dim=-1, unbiased=True)
+        spread = torch.sqrt(spread_variance.mean(dim=spatial_dims))
+
+        correction = float(np.sqrt((n_ensemble + 1) / n_ensemble))
+        ssr = (spread / torch.clamp(skill, min=self.eps)) * correction
+        return ssr
diff --git a/src/autocast/scripts/eval/encoder_processor_decoder.py b/src/autocast/scripts/eval/encoder_processor_decoder.py
index ce51824e..3d3051a9 100644
--- a/src/autocast/scripts/eval/encoder_processor_decoder.py
+++ b/src/autocast/scripts/eval/encoder_processor_decoder.py
@@ -16,7 +16,7 @@
 from autocast.benchmarking import benchmark_model, benchmark_rollout
 from autocast.metrics import MAE, MSE, NMAE, NMSE, NRMSE, RMSE, VMSE, VRMSE, LInfinity
 from autocast.metrics.coverage import MultiCoverage
-from autocast.metrics.ensemble import CRPS, AlphaFairCRPS, FairCRPS
+from autocast.metrics.ensemble import CRPS, AlphaFairCRPS, FairCRPS, SpreadSkillRatio
 from autocast.models.encoder_processor_decoder import EncoderProcessorDecoder
 from autocast.models.encoder_processor_decoder_ensemble import (
     EncoderProcessorDecoderEnsemble,
@@ -57,6 +57,7 @@
     "crps": CRPS,
     "fcrps": FairCRPS,
     "afcrps": AlphaFairCRPS,
+    "ssr": SpreadSkillRatio,
 }
 
 
@@ -691,11 +692,15 @@ def run_evaluation(cfg: DictConfig, work_dir: Path | None = None) -> None:  # no
     compute_coverage = eval_cfg.get("compute_coverage", False)
 
     test_metric_fns: dict[str, Callable[[], Metric]] = {}
 
+    metric_registry = dict(AVAILABLE_METRICS)
+    if n_members and n_members > 1:
+        metric_registry.update(AVAILABLE_METRICS_ENSEMBLE)
+
     for name in metrics_list:
-        if name in AVAILABLE_METRICS:
-            test_metric_fns[name] = AVAILABLE_METRICS[name]
+        if name in metric_registry:
+            test_metric_fns[name] = metric_registry[name]
         else:
-            log.warning("Metric %s not found in AVAILABLE_METRICS", name)
+            log.warning("Metric %s not found in available metrics", name)
 
     if (n_members > 1) or compute_coverage:
@@ -796,10 +801,10 @@ def coverage_factory() -> Metric:
 
     if compute_rollout_metrics:
         for name in metrics_list:
-            if name in AVAILABLE_METRICS:
-                rollout_metric_fns[name] = AVAILABLE_METRICS[name]
+            if name in metric_registry:
+                rollout_metric_fns[name] = metric_registry[name]
             else:
-                msg = f"Metric {name} not found in AVAILABLE_METRICS"
+                msg = f"Metric {name} not found in available metrics"
                 log.warning(msg)
 
     if compute_rollout_coverage and n_members and n_members > 1:
diff --git a/tests/metrics/test_ensemble.py b/tests/metrics/test_ensemble.py
index cbe3d351..8d69d738 100644
--- a/tests/metrics/test_ensemble.py
+++ b/tests/metrics/test_ensemble.py
@@ -4,6 +4,7 @@
 from autocast.metrics import ALL_ENSEMBLE_METRICS
 from autocast.metrics.base import BaseMetric
 from autocast.metrics.coverage import Coverage
+from autocast.metrics.ensemble import SpreadSkillRatio
 from autocast.types import TensorBTSC
 from autocast.types.types import TensorBTC
 
@@ -13,7 +14,9 @@
     if issubclass(metric_cls, BaseMetric)
 )
 
-ENSEMBLE_ERROR_METRICS = tuple(m for m in ENSEMBLE_BASE_METRICS if m not in [Coverage])
+ENSEMBLE_ERROR_METRICS = tuple(
+    m for m in ENSEMBLE_BASE_METRICS if m not in [Coverage, SpreadSkillRatio]
+)
 
 
 @pytest.mark.parametrize("MetricCls", ENSEMBLE_ERROR_METRICS)
@@ -74,3 +77,26 @@ def test_ensemble_metrics_stateful(MetricCls):
     value = metric.compute()
 
     assert torch.allclose(value, torch.tensor(0.0))
+
+
+def test_spread_skill_ratio_matches_reference_formula():
+    # Shape: (B=1, T=1, S=1, C=1, M=2)
+    # Members: [0, 2], truth: [0]
+    # ensemble_mean = 1 -> skill = sqrt((1 - 0)^2) = 1
+    # unbiased ensemble variance = ((0-1)^2 + (2-1)^2) / (2-1) = 2
+    # spread = sqrt(2)
+    # correction = sqrt((M+1)/M) = sqrt(3/2)
+    # corrected SSR = sqrt(2) * sqrt(3/2) = sqrt(3)
+    y_pred = torch.tensor([[[[[0.0, 2.0]]]]])
+    y_true = torch.tensor([[[[0.0]]]])
+
+    value = SpreadSkillRatio(eps=1e-12)(y_pred, y_true)
+    assert torch.allclose(value, torch.tensor(3.0**0.5), atol=1e-6)
+
+
+def test_spread_skill_ratio_requires_multiple_ensemble_members():
+    y_pred = torch.ones((1, 1, 1, 1, 1))
+    y_true = torch.ones((1, 1, 1, 1))
+
+    with pytest.raises(ValueError, match="at least 2 ensemble members"):
+        SpreadSkillRatio()(y_pred, y_true)