From 35ec3db00148a0ef71cb9485195fd182b06f38d4 Mon Sep 17 00:00:00 2001 From: Sam Greenbury Date: Tue, 10 Mar 2026 11:29:06 +0000 Subject: [PATCH] Add SpreadSkillRatio metric and update ensemble metrics Introduce the SpreadSkillRatio class for ensemble forecasts, implementing the corrected spread-to-skill ratio formula. Update the metrics initialization to include SpreadSkillRatio in both the ensemble metrics list and the evaluation scripts. Enhance tests to validate the functionality and requirements of the new metric. --- src/autocast/metrics/__init__.py | 12 +++- src/autocast/metrics/ensemble.py | 60 +++++++++++++++++++ .../scripts/eval/encoder_processor_decoder.py | 19 +++--- tests/metrics/test_ensemble.py | 28 ++++++++- 4 files changed, 109 insertions(+), 10 deletions(-) diff --git a/src/autocast/metrics/__init__.py b/src/autocast/metrics/__init__.py index 5aad06b6..64dae463 100644 --- a/src/autocast/metrics/__init__.py +++ b/src/autocast/metrics/__init__.py @@ -1,6 +1,6 @@ from .coverage import Coverage, MultiCoverage from .deterministic import MAE, MSE, NMAE, NMSE, NRMSE, RMSE, VMSE, VRMSE, LInfinity -from .ensemble import CRPS, AlphaFairCRPS, FairCRPS +from .ensemble import CRPS, AlphaFairCRPS, FairCRPS, SpreadSkillRatio __all__ = [ "CRPS", @@ -17,7 +17,15 @@ "FairCRPS", "LInfinity", "MultiCoverage", + "SpreadSkillRatio", ] ALL_DETERMINISTIC_METRICS = (MSE, MAE, NMAE, NMSE, RMSE, NRMSE, VMSE, VRMSE, LInfinity) -ALL_ENSEMBLE_METRICS = (CRPS, AlphaFairCRPS, FairCRPS, Coverage, MultiCoverage) +ALL_ENSEMBLE_METRICS = ( + CRPS, + AlphaFairCRPS, + FairCRPS, + SpreadSkillRatio, + Coverage, + MultiCoverage, +) diff --git a/src/autocast/metrics/ensemble.py b/src/autocast/metrics/ensemble.py index d276bba8..266394ae 100644 --- a/src/autocast/metrics/ensemble.py +++ b/src/autocast/metrics/ensemble.py @@ -279,3 +279,63 @@ def _score(self, y_pred: TensorBTSCM, y_true: TensorBTSC) -> TensorBTC: afcrps_reduced = afcrps.mean(dim=tuple(range(2, 2 + n_spatial_dims))) return afcrps_reduced + + +class SpreadSkillRatio(BTSCMMetric): + r""" + Corrected spread-to-skill ratio (SSR) for ensemble forecasts. + + Notes + ----- + Uses the corrected finite-ensemble form: + .. math:: + \text{SSR}_{\text{corrected}} = \frac{\text{Spread}}{\text{Skill}} + \sqrt{\frac{M + 1}{M}}, + where skill is RMSE of ensemble mean and spread is ensemble standard deviation, + both aggregated over spatial dimensions. + + """ + + name: str = "ssr" + + def __init__(self, eps: float = 1e-6): + super().__init__() + if eps <= 0: + msg = "eps must be > 0" + raise ValueError(msg) + self.eps = eps + + def _score(self, y_pred: TensorBTSCM, y_true: TensorBTSC) -> TensorBTC: + """ + Compute corrected spread-to-skill ratio reduced over spatial dims. + + Args: + y_pred: (B, T, S, C, M) + y_true: (B, T, S, C) + + Returns + ------- + SSR: (B, T, C) + """ + n_ensemble = y_pred.shape[-1] + if n_ensemble < 2: + raise ValueError( + "SpreadSkillRatio requires at least 2 ensemble members " + f"(got {n_ensemble})." + ) + + # Ensemble mean forecast: (B, T, S, C) + ensemble_mean = y_pred.mean(dim=-1) + n_spatial_dims = self._infer_n_spatial_dims(ensemble_mean) + spatial_dims = tuple(range(2, 2 + n_spatial_dims)) + + # Skill = RMSE of ensemble mean over spatial dims: (B, T, C) + skill = torch.sqrt(((ensemble_mean - y_true) ** 2).mean(dim=spatial_dims)) + + # Spread = sqrt(mean spatial ensemble variance): (B, T, C) + spread_variance = y_pred.var(dim=-1, unbiased=True) + spread = torch.sqrt(spread_variance.mean(dim=spatial_dims)) + + correction = float(np.sqrt((n_ensemble + 1) / n_ensemble)) + ssr = (spread / torch.clamp(skill, min=self.eps)) * correction + return ssr diff --git a/src/autocast/scripts/eval/encoder_processor_decoder.py b/src/autocast/scripts/eval/encoder_processor_decoder.py index ce51824e..3d3051a9 100644 --- a/src/autocast/scripts/eval/encoder_processor_decoder.py +++ b/src/autocast/scripts/eval/encoder_processor_decoder.py @@ -16,7 +16,7 @@ from autocast.benchmarking import benchmark_model, benchmark_rollout from autocast.metrics import MAE, MSE, NMAE, NMSE, NRMSE, RMSE, VMSE, VRMSE, LInfinity from autocast.metrics.coverage import MultiCoverage -from autocast.metrics.ensemble import CRPS, AlphaFairCRPS, FairCRPS +from autocast.metrics.ensemble import CRPS, AlphaFairCRPS, FairCRPS, SpreadSkillRatio from autocast.models.encoder_processor_decoder import EncoderProcessorDecoder from autocast.models.encoder_processor_decoder_ensemble import ( EncoderProcessorDecoderEnsemble, @@ -57,6 +57,7 @@ "crps": CRPS, "fcrps": FairCRPS, "afcrps": AlphaFairCRPS, + "ssr": SpreadSkillRatio, } @@ -691,11 +692,15 @@ def run_evaluation(cfg: DictConfig, work_dir: Path | None = None) -> None: # no compute_coverage = eval_cfg.get("compute_coverage", False) test_metric_fns: dict[str, Callable[[], Metric]] = {} + metric_registry = dict(AVAILABLE_METRICS) + if n_members and n_members > 1: + metric_registry.update(AVAILABLE_METRICS_ENSEMBLE) + for name in metrics_list: - if name in AVAILABLE_METRICS: - test_metric_fns[name] = AVAILABLE_METRICS[name] + if name in metric_registry: + test_metric_fns[name] = metric_registry[name] else: - log.warning("Metric %s not found in AVAILABLE_METRICS", name) + log.warning("Metric %s not found in available metrics", name) if (n_members > 1) or compute_coverage: @@ -796,10 +801,10 @@ def coverage_factory() -> Metric: if compute_rollout_metrics: for name in metrics_list: - if name in AVAILABLE_METRICS: - rollout_metric_fns[name] = AVAILABLE_METRICS[name] + if name in metric_registry: + rollout_metric_fns[name] = metric_registry[name] else: - msg = f"Metric {name} not found in AVAILABLE_METRICS" + msg = f"Metric {name} not found in available metrics" log.warning(msg) if compute_rollout_coverage and n_members and n_members > 1: diff --git a/tests/metrics/test_ensemble.py b/tests/metrics/test_ensemble.py index cbe3d351..8d69d738 100644 --- a/tests/metrics/test_ensemble.py +++ b/tests/metrics/test_ensemble.py @@ -4,6 +4,7 @@ from autocast.metrics import ALL_ENSEMBLE_METRICS from autocast.metrics.base import BaseMetric from autocast.metrics.coverage import Coverage +from autocast.metrics.ensemble import SpreadSkillRatio from autocast.types import TensorBTSC from autocast.types.types import TensorBTC @@ -13,7 +14,9 @@ if issubclass(metric_cls, BaseMetric) ) -ENSEMBLE_ERROR_METRICS = tuple(m for m in ENSEMBLE_BASE_METRICS if m not in [Coverage]) +ENSEMBLE_ERROR_METRICS = tuple( + m for m in ENSEMBLE_BASE_METRICS if m not in [Coverage, SpreadSkillRatio] +) @pytest.mark.parametrize("MetricCls", ENSEMBLE_ERROR_METRICS) @@ -74,3 +77,26 @@ def test_ensemble_metrics_stateful(MetricCls): value = metric.compute() assert torch.allclose(value, torch.tensor(0.0)) + + +def test_spread_skill_ratio_matches_reference_formula(): + # Shape: (B=1, T=1, S=1, C=1, M=2) + # Members: [0, 2], truth: [0] + # ensemble_mean = 1 -> skill = sqrt((1 - 0)^2) = 1 + # unbiased ensemble variance = ((0-1)^2 + (2-1)^2) / (2-1) = 2 + # spread = sqrt(2) + # correction = sqrt((M+1)/M) = sqrt(3/2) + # corrected SSR = sqrt(2) * sqrt(3/2) = sqrt(3) + y_pred = torch.tensor([[[[[0.0, 2.0]]]]]) + y_true = torch.tensor([[[[0.0]]]]) + + value = SpreadSkillRatio(eps=1e-12)(y_pred, y_true) + assert torch.allclose(value, torch.tensor(3.0**0.5), atol=1e-6) + + +def test_spread_skill_ratio_requires_multiple_ensemble_members(): + y_pred = torch.ones((1, 1, 1, 1, 1)) + y_true = torch.ones((1, 1, 1, 1)) + + with pytest.raises(ValueError, match="at least 2 ensemble members"): + SpreadSkillRatio()(y_pred, y_true)