Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions src/autocast/metrics/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from .coverage import Coverage, MultiCoverage
from .deterministic import MAE, MSE, NMAE, NMSE, NRMSE, RMSE, VMSE, VRMSE, LInfinity
from .ensemble import CRPS, AlphaFairCRPS, FairCRPS
from .ensemble import CRPS, AlphaFairCRPS, FairCRPS, SpreadSkillRatio

__all__ = [
"CRPS",
Expand All @@ -17,7 +17,15 @@
"FairCRPS",
"LInfinity",
"MultiCoverage",
"SpreadSkillRatio",
]

ALL_DETERMINISTIC_METRICS = (MSE, MAE, NMAE, NMSE, RMSE, NRMSE, VMSE, VRMSE, LInfinity)
ALL_ENSEMBLE_METRICS = (CRPS, AlphaFairCRPS, FairCRPS, Coverage, MultiCoverage)
ALL_ENSEMBLE_METRICS = (
CRPS,
AlphaFairCRPS,
FairCRPS,
SpreadSkillRatio,
Coverage,
MultiCoverage,
)
60 changes: 60 additions & 0 deletions src/autocast/metrics/ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,3 +279,63 @@ def _score(self, y_pred: TensorBTSCM, y_true: TensorBTSC) -> TensorBTC:
afcrps_reduced = afcrps.mean(dim=tuple(range(2, 2 + n_spatial_dims)))

return afcrps_reduced


class SpreadSkillRatio(BTSCMMetric):
    r"""
    Corrected spread-to-skill ratio (SSR) for ensemble forecasts.

    Notes
    -----
    Uses the corrected finite-ensemble form:
    .. math::
        \text{SSR}_{\text{corrected}} = \frac{\text{Spread}}{\text{Skill}}
        \sqrt{\frac{M + 1}{M}},
    where skill is RMSE of ensemble mean and spread is ensemble standard deviation,
    both aggregated over spatial dimensions.

    """

    name: str = "ssr"

    def __init__(self, eps: float = 1e-6):
        """Initialize the metric with a positive floor for the skill denominator."""
        super().__init__()
        if eps <= 0:
            msg = "eps must be > 0"
            raise ValueError(msg)
        self.eps = eps

    def _score(self, y_pred: TensorBTSCM, y_true: TensorBTSC) -> TensorBTC:
        """
        Compute corrected spread-to-skill ratio reduced over spatial dims.

        Args:
            y_pred: (B, T, S, C, M)
            y_true: (B, T, S, C)

        Returns
        -------
        SSR: (B, T, C)
        """
        n_members = y_pred.shape[-1]
        if n_members < 2:
            # The unbiased ensemble variance is undefined for a single member.
            raise ValueError(
                "SpreadSkillRatio requires at least 2 ensemble members "
                f"(got {n_members})."
            )

        # Mean over the member axis gives the ensemble-mean forecast: (B, T, S, C).
        mean_forecast = y_pred.mean(dim=-1)
        reduce_dims = tuple(range(2, 2 + self._infer_n_spatial_dims(mean_forecast)))

        # Skill: RMSE of the ensemble mean over spatial dims -> (B, T, C).
        squared_error = (mean_forecast - y_true) ** 2
        skill = squared_error.mean(dim=reduce_dims).sqrt()

        # Spread: sqrt of the spatially averaged unbiased member variance -> (B, T, C).
        member_variance = y_pred.var(dim=-1, unbiased=True)
        spread = member_variance.mean(dim=reduce_dims).sqrt()

        # Finite-ensemble inflation factor sqrt((M + 1) / M).
        inflation = float(np.sqrt((n_members + 1) / n_members))
        return (spread / skill.clamp(min=self.eps)) * inflation
19 changes: 12 additions & 7 deletions src/autocast/scripts/eval/encoder_processor_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from autocast.benchmarking import benchmark_model, benchmark_rollout
from autocast.metrics import MAE, MSE, NMAE, NMSE, NRMSE, RMSE, VMSE, VRMSE, LInfinity
from autocast.metrics.coverage import MultiCoverage
from autocast.metrics.ensemble import CRPS, AlphaFairCRPS, FairCRPS
from autocast.metrics.ensemble import CRPS, AlphaFairCRPS, FairCRPS, SpreadSkillRatio
from autocast.models.encoder_processor_decoder import EncoderProcessorDecoder
from autocast.models.encoder_processor_decoder_ensemble import (
EncoderProcessorDecoderEnsemble,
Expand Down Expand Up @@ -57,6 +57,7 @@
"crps": CRPS,
"fcrps": FairCRPS,
"afcrps": AlphaFairCRPS,
"ssr": SpreadSkillRatio,
}


Expand Down Expand Up @@ -691,11 +692,15 @@ def run_evaluation(cfg: DictConfig, work_dir: Path | None = None) -> None: # no
compute_coverage = eval_cfg.get("compute_coverage", False)
test_metric_fns: dict[str, Callable[[], Metric]] = {}

metric_registry = dict(AVAILABLE_METRICS)
if n_members and n_members > 1:
metric_registry.update(AVAILABLE_METRICS_ENSEMBLE)

for name in metrics_list:
if name in AVAILABLE_METRICS:
test_metric_fns[name] = AVAILABLE_METRICS[name]
if name in metric_registry:
test_metric_fns[name] = metric_registry[name]
else:
log.warning("Metric %s not found in AVAILABLE_METRICS", name)
log.warning("Metric %s not found in available metrics", name)

if (n_members > 1) or compute_coverage:

Expand Down Expand Up @@ -796,10 +801,10 @@ def coverage_factory() -> Metric:

if compute_rollout_metrics:
for name in metrics_list:
if name in AVAILABLE_METRICS:
rollout_metric_fns[name] = AVAILABLE_METRICS[name]
if name in metric_registry:
rollout_metric_fns[name] = metric_registry[name]
else:
msg = f"Metric {name} not found in AVAILABLE_METRICS"
msg = f"Metric {name} not found in available metrics"
log.warning(msg)

if compute_rollout_coverage and n_members and n_members > 1:
Expand Down
28 changes: 27 additions & 1 deletion tests/metrics/test_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from autocast.metrics import ALL_ENSEMBLE_METRICS
from autocast.metrics.base import BaseMetric
from autocast.metrics.coverage import Coverage
from autocast.metrics.ensemble import SpreadSkillRatio
from autocast.types import TensorBTSC
from autocast.types.types import TensorBTC

Expand All @@ -13,7 +14,9 @@
if issubclass(metric_cls, BaseMetric)
)

ENSEMBLE_ERROR_METRICS = tuple(m for m in ENSEMBLE_BASE_METRICS if m not in [Coverage])
ENSEMBLE_ERROR_METRICS = tuple(
m for m in ENSEMBLE_BASE_METRICS if m not in [Coverage, SpreadSkillRatio]
)


@pytest.mark.parametrize("MetricCls", ENSEMBLE_ERROR_METRICS)
Expand Down Expand Up @@ -74,3 +77,26 @@ def test_ensemble_metrics_stateful(MetricCls):
value = metric.compute()

assert torch.allclose(value, torch.tensor(0.0))


def test_spread_skill_ratio_matches_reference_formula():
    # Shape (B=1, T=1, S=1, C=1, M=2); members [0, 2], truth 0.
    #   ensemble mean          = 1  -> skill = sqrt((1 - 0)^2) = 1
    #   unbiased variance      = ((0-1)^2 + (2-1)^2) / (2-1)   = 2
    #   spread                 = sqrt(2)
    #   finite-ensemble factor = sqrt((M+1)/M)                 = sqrt(3/2)
    #   corrected SSR          = sqrt(2) * sqrt(3/2)           = sqrt(3)
    members = torch.tensor([[[[[0.0, 2.0]]]]])
    truth = torch.zeros((1, 1, 1, 1))

    metric = SpreadSkillRatio(eps=1e-12)
    result = metric(members, truth)

    assert torch.allclose(result, torch.tensor(3.0**0.5), atol=1e-6)


def test_spread_skill_ratio_requires_multiple_ensemble_members():
    # A single-member ensemble has no defined (unbiased) spread.
    single_member_pred = torch.ones((1, 1, 1, 1, 1))
    target = torch.ones((1, 1, 1, 1))

    metric = SpreadSkillRatio()
    with pytest.raises(ValueError, match="at least 2 ensemble members"):
        metric(single_member_pred, target)