Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 59 additions & 3 deletions src/autocast/metrics/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,25 @@
from .coverage import Coverage, MultiCoverage
from .deterministic import MAE, MSE, NMAE, NMSE, NRMSE, RMSE, VMSE, VRMSE, LInfinity
from .deterministic import (
MAE,
MSE,
NMAE,
NMSE,
NRMSE,
RMSE,
VMSE,
VRMSE,
LInfinity,
PowerSpectrumCCRMSE,
PowerSpectrumCCRMSEHigh,
PowerSpectrumCCRMSELow,
PowerSpectrumCCRMSEMid,
PowerSpectrumCCRMSETail,
PowerSpectrumRMSE,
PowerSpectrumRMSEHigh,
PowerSpectrumRMSELow,
PowerSpectrumRMSEMid,
PowerSpectrumRMSETail,
)
from .ensemble import CRPS, AlphaFairCRPS, FairCRPS

__all__ = [
Expand All @@ -17,7 +37,43 @@
"FairCRPS",
"LInfinity",
"MultiCoverage",
"PowerSpectrumCCRMSE",
"PowerSpectrumCCRMSEHigh",
"PowerSpectrumCCRMSELow",
"PowerSpectrumCCRMSEMid",
"PowerSpectrumCCRMSETail",
"PowerSpectrumRMSE",
"PowerSpectrumRMSEHigh",
"PowerSpectrumRMSELow",
"PowerSpectrumRMSEMid",
"PowerSpectrumRMSETail",
]

ALL_DETERMINISTIC_METRICS = (MSE, MAE, NMAE, NMSE, RMSE, NRMSE, VMSE, VRMSE, LInfinity)
ALL_ENSEMBLE_METRICS = (CRPS, AlphaFairCRPS, FairCRPS, Coverage, MultiCoverage)
ALL_DETERMINISTIC_METRICS = (
MSE,
MAE,
NMAE,
NMSE,
RMSE,
NRMSE,
VMSE,
VRMSE,
LInfinity,
PowerSpectrumRMSE,
PowerSpectrumRMSELow,
PowerSpectrumRMSEMid,
PowerSpectrumRMSEHigh,
PowerSpectrumRMSETail,
PowerSpectrumCCRMSE,
PowerSpectrumCCRMSELow,
PowerSpectrumCCRMSEMid,
PowerSpectrumCCRMSEHigh,
PowerSpectrumCCRMSETail,
)
ALL_ENSEMBLE_METRICS = (
CRPS,
AlphaFairCRPS,
FairCRPS,
Coverage,
MultiCoverage,
)
307 changes: 307 additions & 0 deletions src/autocast/metrics/deterministic.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,17 @@
"""Deterministic metrics.

Power-spectrum RMSE utilities in this module based on the implementation from:
- Lost in Latent Space: An Empirical Study of Latent Diffusion Models for Physics
Emulation (Rozet et al., 2024), https://arxiv.org/abs/2507.02608,
https://github.com/PolymathicAI/lola
- Specific code from:
- https://github.com/PolymathicAI/lola/blob/main/lola/fourier.py
- https://github.com/PolymathicAI/lola/blob/main/experiments/eval.py
"""

import math
from functools import cache

import numpy as np
import torch

Expand Down Expand Up @@ -273,3 +287,296 @@ def _score(self, y_pred: TensorBTSC, y_true: TensorBTSC) -> TensorBTC:
torch.abs(y_pred - y_true).flatten(start_dim=spatial_dims[0], end_dim=-2),
dim=-2,
).values


@cache
def _isotropic_binning_cpu(
shape: tuple[int, ...],
bins: int | None = None,
) -> tuple[Tensor, Tensor, Tensor]:
"""Isotropic frequency binning over FFT domain on CPU (cached).

References
----------
- https://github.com/PolymathicAI/lola/blob/bd4bdf2a9fc024e6b2aa95eb4e24a800fec98dae/lola/fourier.py
"""
k = []
for s in shape:
k_i = torch.fft.fftfreq(s)
k.append(k_i)

k2 = map(torch.square, k)
k2_grid = torch.meshgrid(*k2, indexing="ij")
k2_iso = torch.zeros_like(k2_grid[0])
for component in k2_grid:
k2_iso = k2_iso + component
k_iso = torch.sqrt(k2_iso)

if bins is None:
bins = math.floor(math.sqrt(k_iso.ndim) * min(k_iso.shape) / 2)

edges = torch.linspace(0, k_iso.max(), bins + 1)
indices = torch.bucketize(k_iso.flatten(), edges)
counts = torch.bincount(indices, minlength=bins + 1)

return edges, counts, indices


def _isotropic_binning(
shape: tuple[int, ...],
bins: int | None = None,
device: torch.device | None = None,
) -> tuple[Tensor, Tensor, Tensor]:
"""Isotropic frequency binning over FFT domain.

The cached representation is always on CPU to avoid storing device-specific
tensors in the global cache. Returned tensors are moved to `device`.
"""
edges, counts, indices = _isotropic_binning_cpu(shape, bins)
if device is None:
return edges, counts, indices

dev = torch.device(device)
return edges.to(dev), counts.to(dev), indices.to(dev)


def _lola_eval_power_band_masks(
freq_bins: Tensor,
) -> tuple[Tensor, Tensor, Tensor, Tensor]:
"""Build the four power-spectrum masks used in Lola eval.py."""
# bins = torch.logspace(k[0].log2(), -1.0, steps=4, base=2)
bins = torch.logspace(
freq_bins[0].log2(),
-1.0,
steps=4,
base=2,
device=freq_bins.device,
)

m0 = torch.logical_and(bins[0] <= freq_bins, freq_bins <= bins[1])
m1 = torch.logical_and(bins[1] <= freq_bins, freq_bins <= bins[2])
m2 = torch.logical_and(bins[2] <= freq_bins, freq_bins <= bins[3])
m3 = bins[3] <= freq_bins
return m0, m1, m2, m3


def _isotropic_spectral_components(
y_pred: TensorBTSC, y_true: TensorBTSC, n_spatial_dims: int
) -> tuple[Tensor, Tensor, Tensor, Tensor]:
"""Compute isotropic power and cross-power spectra from a single FFT pair.

Returns (pred_spec, true_spec, cross_spec, freq_bins), each with shape
(B, T, C, bins) except freq_bins which has shape (bins,).
"""
y_pred_btc = y_pred.movedim(-1, 2) # (B, T, C, S...)
y_true_btc = y_true.movedim(-1, 2)
spatial_shape = tuple(y_pred_btc.shape[-n_spatial_dims:])

edges, counts, indices = _isotropic_binning(spatial_shape, device=y_pred.device)

fft_dims = tuple(range(-n_spatial_dims, 0))
spec_pred = torch.fft.fftn(y_pred_btc, dim=fft_dims, norm="ortho")
spec_true = torch.fft.fftn(y_true_btc, dim=fft_dims, norm="ortho")

power_pred = torch.abs(spec_pred).square().flatten(start_dim=-n_spatial_dims)
power_true = torch.abs(spec_true).square().flatten(start_dim=-n_spatial_dims)
cross = torch.abs(spec_pred * torch.conj(spec_true)).flatten(
start_dim=-n_spatial_dims
)

counts_clamped = torch.clamp(counts, min=1).to(dtype=y_pred.dtype)
counts_view = counts_clamped.view(*([1] * (power_pred.ndim - 1)), -1)

def _bin(p: Tensor) -> Tensor:
iso = torch.zeros(
(*p.shape[:-1], edges.numel()), dtype=y_pred.dtype, device=y_pred.device
)
iso = iso.scatter_add(dim=-1, index=indices.expand_as(p), src=p)
return (iso / counts_view)[..., 1:]

return _bin(power_pred), _bin(power_true), _bin(cross), edges[1:]


def _power_spectrum_rmse_bands(
y_pred: TensorBTSC,
y_true: TensorBTSC,
eps: float,
) -> tuple[TensorBTC, TensorBTC, TensorBTC, TensorBTC]:
"""Compute Lola-style per-band RMSE of relative isotropic power spectra."""
n_spatial_dims = y_true.ndim - 3
pred_spec, true_spec, _, freq_bins = _isotropic_spectral_components(
y_pred, y_true, n_spatial_dims
)

m0, m1, m2, m3 = _lola_eval_power_band_masks(freq_bins)

def _band_rmse(mask: Tensor) -> TensorBTC:
# Small spatial grids can produce empty spectral bands.
# Define RMSE over an empty band as zero to avoid NaNs in downstream
# deterministic/stateful reductions.
if not torch.any(mask):
return torch.zeros(
pred_spec.shape[:-1], dtype=pred_spec.dtype, device=pred_spec.device
)
# se_p = (1 - (p_v + eps) / (p_u + eps))^2
se_p = torch.square(
1.0 - (pred_spec[..., mask] + eps) / (true_spec[..., mask] + eps)
)
return torch.sqrt(torch.mean(se_p, dim=-1))

return _band_rmse(m0), _band_rmse(m1), _band_rmse(m2), _band_rmse(m3)


def _cross_correlation_rmse_bands(
y_pred: TensorBTSC,
y_true: TensorBTSC,
eps: float,
) -> tuple[TensorBTC, TensorBTC, TensorBTC, TensorBTC]:
"""Compute Lola-style per-band RMSE of cross-correlation spectra."""
n_spatial_dims = y_true.ndim - 3
pred_spec, true_spec, cross_spec, freq_bins = _isotropic_spectral_components(
y_pred, y_true, n_spatial_dims
)

m0, m1, m2, m3 = _lola_eval_power_band_masks(freq_bins)

def _band_rmse(mask: Tensor) -> TensorBTC:
if not torch.any(mask):
return torch.zeros(
cross_spec.shape[:-1], dtype=cross_spec.dtype, device=cross_spec.device
)
# se_c = (1 - (c_uv + eps) / sqrt(p_u * p_v + eps^2))^2
se_c = torch.square(
1.0
- (cross_spec[..., mask] + eps)
/ torch.sqrt(pred_spec[..., mask] * true_spec[..., mask] + eps**2)
)
return torch.sqrt(torch.mean(se_c, dim=-1))

return _band_rmse(m0), _band_rmse(m1), _band_rmse(m2), _band_rmse(m3)


class PowerSpectrumRMSE(BTSCMetric):
"""Average power spectrum RMSE across first three Lola eval bands."""

name: str = "psrmse"

def __init__(
self,
reduce_all: bool = True,
dist_sync_on_step: bool = False,
eps: float = 1e-6,
):
super().__init__(
reduce_all=reduce_all,
dist_sync_on_step=dist_sync_on_step,
)
self.eps = eps

def _score(self, y_pred: TensorBTSC, y_true: TensorBTSC) -> TensorBTC:
low, mid, high, _tail = _power_spectrum_rmse_bands(y_pred, y_true, eps=self.eps)
return (low + mid + high) / 3.0


class PowerSpectrumRMSELow(PowerSpectrumRMSE):
"""Power spectrum RMSE in the low-frequency band."""

name: str = "psrmse_low"

def _score(self, y_pred: TensorBTSC, y_true: TensorBTSC) -> TensorBTC:
low, _, _, _ = _power_spectrum_rmse_bands(y_pred, y_true, eps=self.eps)
return low


class PowerSpectrumRMSEMid(PowerSpectrumRMSE):
"""Power spectrum RMSE in the mid-frequency band."""

name: str = "psrmse_mid"

def _score(self, y_pred: TensorBTSC, y_true: TensorBTSC) -> TensorBTC:
_, mid, _, _ = _power_spectrum_rmse_bands(y_pred, y_true, eps=self.eps)
return mid


class PowerSpectrumRMSEHigh(PowerSpectrumRMSE):
"""Power spectrum RMSE in the high-frequency band."""

name: str = "psrmse_high"

def _score(self, y_pred: TensorBTSC, y_true: TensorBTSC) -> TensorBTC:
_, _, high, _ = _power_spectrum_rmse_bands(y_pred, y_true, eps=self.eps)
return high


class PowerSpectrumRMSETail(PowerSpectrumRMSE):
"""Power spectrum RMSE in the Lola high-frequency tail band."""

name: str = "psrmse_tail"

def _score(self, y_pred: TensorBTSC, y_true: TensorBTSC) -> TensorBTC:
_, _, _, tail = _power_spectrum_rmse_bands(y_pred, y_true, eps=self.eps)
return tail


class PowerSpectrumCCRMSE(BTSCMetric):
"""Average cross-correlation RMSE across first three Lola eval bands."""

name: str = "pscc"

def __init__(
self,
reduce_all: bool = True,
dist_sync_on_step: bool = False,
eps: float = 1e-6,
):
super().__init__(
reduce_all=reduce_all,
dist_sync_on_step=dist_sync_on_step,
)
self.eps = eps

def _score(self, y_pred: TensorBTSC, y_true: TensorBTSC) -> TensorBTC:
low, mid, high, _tail = _cross_correlation_rmse_bands(
y_pred, y_true, eps=self.eps
)
return (low + mid + high) / 3.0


class PowerSpectrumCCRMSELow(PowerSpectrumCCRMSE):
"""Cross-correlation RMSE in the low-frequency band."""

name: str = "pscc_low"

def _score(self, y_pred: TensorBTSC, y_true: TensorBTSC) -> TensorBTC:
low, _, _, _ = _cross_correlation_rmse_bands(y_pred, y_true, eps=self.eps)
return low


class PowerSpectrumCCRMSEMid(PowerSpectrumCCRMSE):
"""Cross-correlation RMSE in the mid-frequency band."""

name: str = "pscc_mid"

def _score(self, y_pred: TensorBTSC, y_true: TensorBTSC) -> TensorBTC:
_, mid, _, _ = _cross_correlation_rmse_bands(y_pred, y_true, eps=self.eps)
return mid


class PowerSpectrumCCRMSEHigh(PowerSpectrumCCRMSE):
"""Cross-correlation RMSE in the high-frequency band."""

name: str = "pscc_high"

def _score(self, y_pred: TensorBTSC, y_true: TensorBTSC) -> TensorBTC:
_, _, high, _ = _cross_correlation_rmse_bands(y_pred, y_true, eps=self.eps)
return high


class PowerSpectrumCCRMSETail(PowerSpectrumCCRMSE):
"""Cross-correlation RMSE in the Lola high-frequency tail band."""

name: str = "pscc_tail"

def _score(self, y_pred: TensorBTSC, y_true: TensorBTSC) -> TensorBTC:
_, _, _, tail = _cross_correlation_rmse_bands(y_pred, y_true, eps=self.eps)
return tail
Loading
Loading