diff --git a/src/autocast/metrics/__init__.py b/src/autocast/metrics/__init__.py index 5aad06b6..85ec39c6 100644 --- a/src/autocast/metrics/__init__.py +++ b/src/autocast/metrics/__init__.py @@ -1,5 +1,25 @@ from .coverage import Coverage, MultiCoverage -from .deterministic import MAE, MSE, NMAE, NMSE, NRMSE, RMSE, VMSE, VRMSE, LInfinity +from .deterministic import ( + MAE, + MSE, + NMAE, + NMSE, + NRMSE, + RMSE, + VMSE, + VRMSE, + LInfinity, + PowerSpectrumCCRMSE, + PowerSpectrumCCRMSEHigh, + PowerSpectrumCCRMSELow, + PowerSpectrumCCRMSEMid, + PowerSpectrumCCRMSETail, + PowerSpectrumRMSE, + PowerSpectrumRMSEHigh, + PowerSpectrumRMSELow, + PowerSpectrumRMSEMid, + PowerSpectrumRMSETail, +) from .ensemble import CRPS, AlphaFairCRPS, FairCRPS __all__ = [ @@ -17,7 +37,43 @@ "FairCRPS", "LInfinity", "MultiCoverage", + "PowerSpectrumCCRMSE", + "PowerSpectrumCCRMSEHigh", + "PowerSpectrumCCRMSELow", + "PowerSpectrumCCRMSEMid", + "PowerSpectrumCCRMSETail", + "PowerSpectrumRMSE", + "PowerSpectrumRMSEHigh", + "PowerSpectrumRMSELow", + "PowerSpectrumRMSEMid", + "PowerSpectrumRMSETail", ] -ALL_DETERMINISTIC_METRICS = (MSE, MAE, NMAE, NMSE, RMSE, NRMSE, VMSE, VRMSE, LInfinity) -ALL_ENSEMBLE_METRICS = (CRPS, AlphaFairCRPS, FairCRPS, Coverage, MultiCoverage) +ALL_DETERMINISTIC_METRICS = ( + MSE, + MAE, + NMAE, + NMSE, + RMSE, + NRMSE, + VMSE, + VRMSE, + LInfinity, + PowerSpectrumRMSE, + PowerSpectrumRMSELow, + PowerSpectrumRMSEMid, + PowerSpectrumRMSEHigh, + PowerSpectrumRMSETail, + PowerSpectrumCCRMSE, + PowerSpectrumCCRMSELow, + PowerSpectrumCCRMSEMid, + PowerSpectrumCCRMSEHigh, + PowerSpectrumCCRMSETail, +) +ALL_ENSEMBLE_METRICS = ( + CRPS, + AlphaFairCRPS, + FairCRPS, + Coverage, + MultiCoverage, +) diff --git a/src/autocast/metrics/deterministic.py b/src/autocast/metrics/deterministic.py index 8eb81406..8098a417 100644 --- a/src/autocast/metrics/deterministic.py +++ b/src/autocast/metrics/deterministic.py @@ -1,3 +1,17 @@ +"""Deterministic metrics. 
+ +Power-spectrum RMSE utilities in this module are based on the implementation from: +- Lost in Latent Space: An Empirical Study of Latent Diffusion Models for Physics +Emulation (Rozet et al., 2025), https://arxiv.org/abs/2507.02608, +https://github.com/PolymathicAI/lola +- Specific code from: + - https://github.com/PolymathicAI/lola/blob/main/lola/fourier.py + - https://github.com/PolymathicAI/lola/blob/main/experiments/eval.py +""" + +import math +from functools import cache + import numpy as np import torch @@ -273,3 +287,296 @@ def _score(self, y_pred: TensorBTSC, y_true: TensorBTSC) -> TensorBTC: torch.abs(y_pred - y_true).flatten(start_dim=spatial_dims[0], end_dim=-2), dim=-2, ).values + + +@cache +def _isotropic_binning_cpu( + shape: tuple[int, ...], + bins: int | None = None, +) -> tuple[Tensor, Tensor, Tensor]: + """Isotropic frequency binning over FFT domain on CPU (cached). + + References + ---------- + - https://github.com/PolymathicAI/lola/blob/bd4bdf2a9fc024e6b2aa95eb4e24a800fec98dae/lola/fourier.py + """ + k = [] + for s in shape: + k_i = torch.fft.fftfreq(s) + k.append(k_i) + + k2 = map(torch.square, k) + k2_grid = torch.meshgrid(*k2, indexing="ij") + k2_iso = torch.zeros_like(k2_grid[0]) + for component in k2_grid: + k2_iso = k2_iso + component + k_iso = torch.sqrt(k2_iso) + + if bins is None: + bins = math.floor(math.sqrt(k_iso.ndim) * min(k_iso.shape) / 2) + + edges = torch.linspace(0, k_iso.max(), bins + 1) + indices = torch.bucketize(k_iso.flatten(), edges) + counts = torch.bincount(indices, minlength=bins + 1) + + return edges, counts, indices + + +def _isotropic_binning( + shape: tuple[int, ...], + bins: int | None = None, + device: torch.device | None = None, +) -> tuple[Tensor, Tensor, Tensor]: + """Isotropic frequency binning over FFT domain. + + The cached representation is always on CPU to avoid storing device-specific + tensors in the global cache. Returned tensors are moved to `device`. 
+ """ + edges, counts, indices = _isotropic_binning_cpu(shape, bins) + if device is None: + return edges, counts, indices + + dev = torch.device(device) + return edges.to(dev), counts.to(dev), indices.to(dev) + + +def _lola_eval_power_band_masks( + freq_bins: Tensor, +) -> tuple[Tensor, Tensor, Tensor, Tensor]: + """Build the four power-spectrum masks used in Lola eval.py.""" + # bins = torch.logspace(k[0].log2(), -1.0, steps=4, base=2) + bins = torch.logspace( + freq_bins[0].log2(), + -1.0, + steps=4, + base=2, + device=freq_bins.device, + ) + + m0 = torch.logical_and(bins[0] <= freq_bins, freq_bins <= bins[1]) + m1 = torch.logical_and(bins[1] <= freq_bins, freq_bins <= bins[2]) + m2 = torch.logical_and(bins[2] <= freq_bins, freq_bins <= bins[3]) + m3 = bins[3] <= freq_bins + return m0, m1, m2, m3 + + +def _isotropic_spectral_components( + y_pred: TensorBTSC, y_true: TensorBTSC, n_spatial_dims: int +) -> tuple[Tensor, Tensor, Tensor, Tensor]: + """Compute isotropic power and cross-power spectra from a single FFT pair. + + Returns (pred_spec, true_spec, cross_spec, freq_bins), each with shape + (B, T, C, bins) except freq_bins which has shape (bins,). + """ + y_pred_btc = y_pred.movedim(-1, 2) # (B, T, C, S...) 
+ y_true_btc = y_true.movedim(-1, 2) + spatial_shape = tuple(y_pred_btc.shape[-n_spatial_dims:]) + + edges, counts, indices = _isotropic_binning(spatial_shape, device=y_pred.device) + + fft_dims = tuple(range(-n_spatial_dims, 0)) + spec_pred = torch.fft.fftn(y_pred_btc, dim=fft_dims, norm="ortho") + spec_true = torch.fft.fftn(y_true_btc, dim=fft_dims, norm="ortho") + + power_pred = torch.abs(spec_pred).square().flatten(start_dim=-n_spatial_dims) + power_true = torch.abs(spec_true).square().flatten(start_dim=-n_spatial_dims) + cross = torch.abs(spec_pred * torch.conj(spec_true)).flatten( + start_dim=-n_spatial_dims + ) + + counts_clamped = torch.clamp(counts, min=1).to(dtype=y_pred.dtype) + counts_view = counts_clamped.view(*([1] * (power_pred.ndim - 1)), -1) + + def _bin(p: Tensor) -> Tensor: + iso = torch.zeros( + (*p.shape[:-1], edges.numel()), dtype=y_pred.dtype, device=y_pred.device + ) + iso = iso.scatter_add(dim=-1, index=indices.expand_as(p), src=p) + return (iso / counts_view)[..., 1:] + + return _bin(power_pred), _bin(power_true), _bin(cross), edges[1:] + + +def _power_spectrum_rmse_bands( + y_pred: TensorBTSC, + y_true: TensorBTSC, + eps: float, +) -> tuple[TensorBTC, TensorBTC, TensorBTC, TensorBTC]: + """Compute Lola-style per-band RMSE of relative isotropic power spectra.""" + n_spatial_dims = y_true.ndim - 3 + pred_spec, true_spec, _, freq_bins = _isotropic_spectral_components( + y_pred, y_true, n_spatial_dims + ) + + m0, m1, m2, m3 = _lola_eval_power_band_masks(freq_bins) + + def _band_rmse(mask: Tensor) -> TensorBTC: + # Small spatial grids can produce empty spectral bands. + # Define RMSE over an empty band as zero to avoid NaNs in downstream + # deterministic/stateful reductions. 
+ if not torch.any(mask): + return torch.zeros( + pred_spec.shape[:-1], dtype=pred_spec.dtype, device=pred_spec.device + ) + # se_p = (1 - (p_v + eps) / (p_u + eps))^2 + se_p = torch.square( + 1.0 - (pred_spec[..., mask] + eps) / (true_spec[..., mask] + eps) + ) + return torch.sqrt(torch.mean(se_p, dim=-1)) + + return _band_rmse(m0), _band_rmse(m1), _band_rmse(m2), _band_rmse(m3) + + +def _cross_correlation_rmse_bands( + y_pred: TensorBTSC, + y_true: TensorBTSC, + eps: float, +) -> tuple[TensorBTC, TensorBTC, TensorBTC, TensorBTC]: + """Compute Lola-style per-band RMSE of cross-correlation spectra.""" + n_spatial_dims = y_true.ndim - 3 + pred_spec, true_spec, cross_spec, freq_bins = _isotropic_spectral_components( + y_pred, y_true, n_spatial_dims + ) + + m0, m1, m2, m3 = _lola_eval_power_band_masks(freq_bins) + + def _band_rmse(mask: Tensor) -> TensorBTC: + if not torch.any(mask): + return torch.zeros( + cross_spec.shape[:-1], dtype=cross_spec.dtype, device=cross_spec.device + ) + # se_c = (1 - (c_uv + eps) / sqrt(p_u * p_v + eps^2))^2 + se_c = torch.square( + 1.0 + - (cross_spec[..., mask] + eps) + / torch.sqrt(pred_spec[..., mask] * true_spec[..., mask] + eps**2) + ) + return torch.sqrt(torch.mean(se_c, dim=-1)) + + return _band_rmse(m0), _band_rmse(m1), _band_rmse(m2), _band_rmse(m3) + + +class PowerSpectrumRMSE(BTSCMetric): + """Average power spectrum RMSE across first three Lola eval bands.""" + + name: str = "psrmse" + + def __init__( + self, + reduce_all: bool = True, + dist_sync_on_step: bool = False, + eps: float = 1e-6, + ): + super().__init__( + reduce_all=reduce_all, + dist_sync_on_step=dist_sync_on_step, + ) + self.eps = eps + + def _score(self, y_pred: TensorBTSC, y_true: TensorBTSC) -> TensorBTC: + low, mid, high, _tail = _power_spectrum_rmse_bands(y_pred, y_true, eps=self.eps) + return (low + mid + high) / 3.0 + + +class PowerSpectrumRMSELow(PowerSpectrumRMSE): + """Power spectrum RMSE in the low-frequency band.""" + + name: str = "psrmse_low" + + 
def _score(self, y_pred: TensorBTSC, y_true: TensorBTSC) -> TensorBTC: + low, _, _, _ = _power_spectrum_rmse_bands(y_pred, y_true, eps=self.eps) + return low + + +class PowerSpectrumRMSEMid(PowerSpectrumRMSE): + """Power spectrum RMSE in the mid-frequency band.""" + + name: str = "psrmse_mid" + + def _score(self, y_pred: TensorBTSC, y_true: TensorBTSC) -> TensorBTC: + _, mid, _, _ = _power_spectrum_rmse_bands(y_pred, y_true, eps=self.eps) + return mid + + +class PowerSpectrumRMSEHigh(PowerSpectrumRMSE): + """Power spectrum RMSE in the high-frequency band.""" + + name: str = "psrmse_high" + + def _score(self, y_pred: TensorBTSC, y_true: TensorBTSC) -> TensorBTC: + _, _, high, _ = _power_spectrum_rmse_bands(y_pred, y_true, eps=self.eps) + return high + + +class PowerSpectrumRMSETail(PowerSpectrumRMSE): + """Power spectrum RMSE in the Lola high-frequency tail band.""" + + name: str = "psrmse_tail" + + def _score(self, y_pred: TensorBTSC, y_true: TensorBTSC) -> TensorBTC: + _, _, _, tail = _power_spectrum_rmse_bands(y_pred, y_true, eps=self.eps) + return tail + + +class PowerSpectrumCCRMSE(BTSCMetric): + """Average cross-correlation RMSE across first three Lola eval bands.""" + + name: str = "pscc" + + def __init__( + self, + reduce_all: bool = True, + dist_sync_on_step: bool = False, + eps: float = 1e-6, + ): + super().__init__( + reduce_all=reduce_all, + dist_sync_on_step=dist_sync_on_step, + ) + self.eps = eps + + def _score(self, y_pred: TensorBTSC, y_true: TensorBTSC) -> TensorBTC: + low, mid, high, _tail = _cross_correlation_rmse_bands( + y_pred, y_true, eps=self.eps + ) + return (low + mid + high) / 3.0 + + +class PowerSpectrumCCRMSELow(PowerSpectrumCCRMSE): + """Cross-correlation RMSE in the low-frequency band.""" + + name: str = "pscc_low" + + def _score(self, y_pred: TensorBTSC, y_true: TensorBTSC) -> TensorBTC: + low, _, _, _ = _cross_correlation_rmse_bands(y_pred, y_true, eps=self.eps) + return low + + +class PowerSpectrumCCRMSEMid(PowerSpectrumCCRMSE): + 
"""Cross-correlation RMSE in the mid-frequency band.""" + + name: str = "pscc_mid" + + def _score(self, y_pred: TensorBTSC, y_true: TensorBTSC) -> TensorBTC: + _, mid, _, _ = _cross_correlation_rmse_bands(y_pred, y_true, eps=self.eps) + return mid + + +class PowerSpectrumCCRMSEHigh(PowerSpectrumCCRMSE): + """Cross-correlation RMSE in the high-frequency band.""" + + name: str = "pscc_high" + + def _score(self, y_pred: TensorBTSC, y_true: TensorBTSC) -> TensorBTC: + _, _, high, _ = _cross_correlation_rmse_bands(y_pred, y_true, eps=self.eps) + return high + + +class PowerSpectrumCCRMSETail(PowerSpectrumCCRMSE): + """Cross-correlation RMSE in the Lola high-frequency tail band.""" + + name: str = "pscc_tail" + + def _score(self, y_pred: TensorBTSC, y_true: TensorBTSC) -> TensorBTC: + _, _, _, tail = _cross_correlation_rmse_bands(y_pred, y_true, eps=self.eps) + return tail diff --git a/src/autocast/scripts/eval/encoder_processor_decoder.py b/src/autocast/scripts/eval/encoder_processor_decoder.py index 2462f70b..121759ae 100644 --- a/src/autocast/scripts/eval/encoder_processor_decoder.py +++ b/src/autocast/scripts/eval/encoder_processor_decoder.py @@ -14,7 +14,27 @@ from torchmetrics import Metric from autocast.benchmarking import benchmark_model, benchmark_rollout -from autocast.metrics import MAE, MSE, NMAE, NMSE, NRMSE, RMSE, VMSE, VRMSE, LInfinity +from autocast.metrics import ( + MAE, + MSE, + NMAE, + NMSE, + NRMSE, + RMSE, + VMSE, + VRMSE, + LInfinity, + PowerSpectrumCCRMSE, + PowerSpectrumCCRMSEHigh, + PowerSpectrumCCRMSELow, + PowerSpectrumCCRMSEMid, + PowerSpectrumCCRMSETail, + PowerSpectrumRMSE, + PowerSpectrumRMSEHigh, + PowerSpectrumRMSELow, + PowerSpectrumRMSEMid, + PowerSpectrumRMSETail, +) from autocast.metrics.coverage import MultiCoverage from autocast.metrics.ensemble import CRPS, AlphaFairCRPS, FairCRPS from autocast.models.encoder_processor_decoder import EncoderProcessorDecoder @@ -54,6 +74,16 @@ "vmse": VMSE, "vrmse": VRMSE, "linf": LInfinity, + 
"psrmse": PowerSpectrumRMSE, + "psrmse_low": PowerSpectrumRMSELow, + "psrmse_mid": PowerSpectrumRMSEMid, + "psrmse_high": PowerSpectrumRMSEHigh, + "psrmse_tail": PowerSpectrumRMSETail, + "pscc": PowerSpectrumCCRMSE, + "pscc_low": PowerSpectrumCCRMSELow, + "pscc_mid": PowerSpectrumCCRMSEMid, + "pscc_high": PowerSpectrumCCRMSEHigh, + "pscc_tail": PowerSpectrumCCRMSETail, } AVAILABLE_METRICS_ENSEMBLE = { @@ -716,11 +746,15 @@ def run_evaluation(cfg: DictConfig, work_dir: Path | None = None) -> None: # no compute_coverage = eval_cfg.get("compute_coverage", False) test_metric_fns: dict[str, Callable[[], Metric]] = {} + metric_registry = dict(AVAILABLE_METRICS) + if n_members and n_members > 1: + metric_registry.update(AVAILABLE_METRICS_ENSEMBLE) + for name in metrics_list: - if name in AVAILABLE_METRICS: - test_metric_fns[name] = AVAILABLE_METRICS[name] + if name in metric_registry: + test_metric_fns[name] = metric_registry[name] else: - log.warning("Metric %s not found in AVAILABLE_METRICS", name) + log.warning("Metric %s not found in available metrics", name) if (n_members > 1) or compute_coverage: @@ -822,10 +856,10 @@ def coverage_factory() -> Metric: if compute_rollout_metrics: for name in metrics_list: - if name in AVAILABLE_METRICS: - rollout_metric_fns[name] = AVAILABLE_METRICS[name] + if name in metric_registry: + rollout_metric_fns[name] = metric_registry[name] else: - msg = f"Metric {name} not found in AVAILABLE_METRICS" + msg = f"Metric {name} not found in available metrics" log.warning(msg) if compute_rollout_coverage and n_members and n_members > 1: diff --git a/tests/metrics/test_deterministic.py b/tests/metrics/test_deterministic.py index e0c38351..a013c15b 100644 --- a/tests/metrics/test_deterministic.py +++ b/tests/metrics/test_deterministic.py @@ -2,6 +2,19 @@ import torch from autocast.metrics import ALL_DETERMINISTIC_METRICS +from autocast.metrics.deterministic import ( + PowerSpectrumCCRMSE, + PowerSpectrumCCRMSEHigh, + PowerSpectrumCCRMSELow, + 
PowerSpectrumCCRMSEMid, + PowerSpectrumCCRMSETail, + PowerSpectrumRMSE, + PowerSpectrumRMSEHigh, + PowerSpectrumRMSELow, + PowerSpectrumRMSEMid, + PowerSpectrumRMSETail, + _isotropic_binning, +) from autocast.types import TensorBTSC @@ -31,3 +44,80 @@ def test_spatiotemporal_metrics_stateful(MetricCls): value = metric.compute() assert torch.allclose(value, torch.tensor(0.0)) + + +@pytest.mark.parametrize( + "MetricCls", + [ + PowerSpectrumRMSE, + PowerSpectrumRMSELow, + PowerSpectrumRMSEMid, + PowerSpectrumRMSEHigh, + PowerSpectrumRMSETail, + ], +) +def test_power_spectrum_rmse_increases_with_spectral_scale(MetricCls): + torch.manual_seed(0) + y_true: TensorBTSC = torch.randn((1, 1, 8, 8, 1)) + y_pred: TensorBTSC = 2.0 * y_true + + value = MetricCls()(y_pred, y_true) + assert torch.all(value > 0) + + +@pytest.mark.parametrize( + "MetricCls", + [ + PowerSpectrumCCRMSE, + PowerSpectrumCCRMSELow, + PowerSpectrumCCRMSEMid, + PowerSpectrumCCRMSEHigh, + PowerSpectrumCCRMSETail, + ], +) +def test_cross_correlation_rmse_nonzero_for_uncorrelated(MetricCls): + torch.manual_seed(0) + y_true: TensorBTSC = torch.randn((1, 1, 8, 8, 1)) + y_pred: TensorBTSC = torch.randn((1, 1, 8, 8, 1)) + + value = MetricCls()(y_pred, y_true) + assert torch.all(value > 0) + + +def test_cross_correlation_rmse_near_zero_for_identical(): + torch.manual_seed(0) + y: TensorBTSC = torch.randn((1, 1, 8, 8, 1)) + value = PowerSpectrumCCRMSE()(y, y) + # eps regularisation means result is near-zero, not exactly zero. 
+ assert torch.all(value < 1e-5) + + +def test_isotropic_binning_respects_requested_device(): + shape = (8, 8) + edges_cpu, counts_cpu, indices_cpu = _isotropic_binning( + shape, device=torch.device("cpu") + ) + assert edges_cpu.device.type == "cpu" + assert counts_cpu.device.type == "cpu" + assert indices_cpu.device.type == "cpu" + + target_device: torch.device | None = None + if torch.cuda.is_available(): + target_device = torch.device("cuda:0") + elif torch.backends.mps.is_available(): + target_device = torch.device("mps") + + if target_device is None: + return + + # Cross-device call after CPU call should still honor requested device. + edges_dev, counts_dev, indices_dev = _isotropic_binning(shape, device=target_device) + assert edges_dev.device.type == target_device.type + assert counts_dev.device.type == target_device.type + assert indices_dev.device.type == target_device.type + + # CUDA indices are stable/meaningful; MPS may report mps:0 while target is mps. + if target_device.type == "cuda": + assert edges_dev.device.index == target_device.index + assert counts_dev.device.index == target_device.index + assert indices_dev.device.index == target_device.index