Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
43 commits
Select commit Hold shift + click to select a range
7be781e
Add initial coverage metrics
sgreenbury Feb 2, 2026
4e42a96
Add support for plotting coverage with wandb
sgreenbury Feb 2, 2026
ae652d2
Refactor and add tests
sgreenbury Feb 2, 2026
7a229c1
Fix typing
sgreenbury Feb 2, 2026
135a8c6
Refactor coverage and comment out plots
sgreenbury Feb 2, 2026
aa663af
Fix trues shape returned
sgreenbury Feb 2, 2026
61deb1b
Add plotting functionality
sgreenbury Feb 2, 2026
c2bfe41
Remove plot logging
sgreenbury Feb 2, 2026
86b0daa
Initial coverage plotting
sgreenbury Feb 2, 2026
515130f
Initial coverage plotting ensemble notebook
sgreenbury Feb 2, 2026
8371982
Refactor coverage scores function
sgreenbury Feb 3, 2026
553cfc5
Refactor MultiCoverage
sgreenbury Feb 3, 2026
50b0555
Update eval script for coverage
sgreenbury Feb 3, 2026
00d87b3
Update notebook
sgreenbury Feb 3, 2026
0f2e786
Fix device handling
sgreenbury Feb 4, 2026
2ac5019
Refactor to use lightning, fix dim extraction
sgreenbury Feb 4, 2026
5e0bcc1
Fix types and refactor lightning use
sgreenbury Feb 4, 2026
0b596fd
Fix destructure
sgreenbury Feb 4, 2026
0092470
Add optimizations
sgreenbury Feb 4, 2026
e4ff577
Initial per-channel approach
sgreenbury Feb 4, 2026
0c2cec2
Add coverage window config and handling
sgreenbury Feb 4, 2026
3348a27
Skip if empty predictions or ground truth
sgreenbury Feb 4, 2026
ec0a2fe
Fix window types
sgreenbury Feb 4, 2026
59b2122
Add map_windows
sgreenbury Feb 4, 2026
716c169
Add batch_indices back
sgreenbury Feb 4, 2026
963b652
Refactor to simplify script
sgreenbury Feb 4, 2026
c4ce1ba
Simplify coverage plot functionality
sgreenbury Feb 4, 2026
cc92cb9
Move shape comment
sgreenbury Feb 4, 2026
7ca3ee6
Remove exception handling
sgreenbury Feb 4, 2026
bc8a5e6
Refactor plot function
sgreenbury Feb 4, 2026
ee7d0fd
Refactor, add comments, remove logging
sgreenbury Feb 5, 2026
086cab3
Fix test
sgreenbury Feb 5, 2026
b111dff
Fix label
sgreenbury Feb 5, 2026
599ceeb
Add tests
sgreenbury Feb 5, 2026
1f50877
Update notebook
sgreenbury Feb 5, 2026
4908b3d
Rename coverage to metric
sgreenbury Feb 5, 2026
a0b5c57
Fix tests
sgreenbury Feb 5, 2026
d7740ab
Add API to support computing any metrics from dataloader
sgreenbury Feb 5, 2026
717600e
Update eval script to get metrics for rollout windows
sgreenbury Feb 5, 2026
94634e3
Refactor eval script to reuse evaluate metric logic
sgreenbury Feb 5, 2026
ad35e50
Add per batch metrics, refactor
sgreenbury Feb 5, 2026
8085728
Update eval config
sgreenbury Feb 6, 2026
2e369f3
Refactor `compute_metrics_from_dataloader`
sgreenbury Feb 6, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions configs/eval/encoder_processor_decoder.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,12 @@ defaults:
# Override with EPD-specific metrics and settings
metrics:
- mse
- mae
- rmse
- vrmse

max_rollout_steps: 25
compute_rollout_coverage: true
compute_rollout_metrics: true
metric_windows: [null]
metric_windows_rollout: [[0, 1], [6, 12], [13, 30], [31, 99]]
76 changes: 64 additions & 12 deletions notebooks/07_ViT_ensemble.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,9 @@
"\n",
"logger, watch = create_notebook_logger(\n",
" project=\"autocast-notebooks\",\n",
" name=f\"07_ViT_ensemble_{simulation_name}\",\n",
" name=f\"07_ViT_ensemble_{simulation_name}_coverage\",\n",
" tags=[\"notebook\", simulation_name],\n",
" # enabled=False\n",
")"
]
},
Expand Down Expand Up @@ -164,11 +165,19 @@
"metadata": {},
"outputs": [],
"source": [
"from autocast.metrics.coverage import MultiCoverage\n",
"from autocast.metrics.deterministic import MAE, RMSE, VRMSE\n",
"from autocast.utils import get_optimizer_config\n",
"\n",
"encoder = PermuteConcat(in_channels=n_channels, n_steps_input=n_steps_input, with_constants=True)\n",
"decoder = ChannelsLast(output_channels=n_channels, time_steps=n_steps_output)\n",
"encoder = PermuteConcat(\n",
" in_channels=n_channels,\n",
" n_steps_input=n_steps_input,\n",
" with_constants=True,\n",
")\n",
"decoder = ChannelsLast(\n",
" output_channels=n_channels,\n",
" time_steps=n_steps_output\n",
")\n",
"\n",
"noise_channels = 1\n",
"processor = AViTProcessor(\n",
Expand All @@ -188,8 +197,8 @@
" processor=processor,\n",
" train_in_latent_space=False,\n",
" optimizer_config=get_optimizer_config(5e-4),\n",
" test_metrics=[VRMSE(), RMSE(), MAE(), CRPS()],\n",
" val_metrics=[VRMSE(), RMSE(), MAE(), CRPS()],\n",
" test_metrics=[VRMSE(), RMSE(), MAE(), CRPS(), MultiCoverage()],\n",
" val_metrics=[VRMSE(), RMSE(), MAE(), CRPS(), MultiCoverage()],\n",
"    stride=stride,\n",
" loss_func=loss_func, # processor.loss_func,\n",
" n_members=3,\n",
Expand Down Expand Up @@ -361,7 +370,7 @@
"from autocast.metrics import MSE\n",
"\n",
"assert trues is not None\n",
"assert preds.shape == trues.shape\n",
"# assert preds.shape == trues.shape\n",
"mse = MSE()\n",
"mse_error_spatial = mse(preds, trues)\n",
"mse_error = mse(preds, trues)\n",
Expand Down Expand Up @@ -390,9 +399,10 @@
"else:\n",
" channel_names = None\n",
"\n",
"assert trues is not None\n",
"anim = plot_spatiotemporal_video(\n",
" pred=preds.mean(-1),\n",
" true=trues[..., 0], # type: ignore\n",
" true=trues,\n",
" pred_uq=preds.std(-1),\n",
" batch_idx=batch_idx,\n",
" save_path=f\"{simulation_name}_{batch_idx:02d}.mp4\",\n",
Expand All @@ -411,21 +421,63 @@
"metadata": {},
"outputs": [],
"source": [
"# Plot coverage metrics\n",
"from autocast.utils.plots import plot_coverage\n",
"\n",
"assert trues is not None\n",
"plot_coverage(preds, trues)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "28",
"metadata": {},
"outputs": [],
"source": [
"# Plot example ensemble members and true trajectory for a single spatial location\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"\n",
"plt.figure()\n",
"plt.plot(preds[0, :, 0, 0, 0].detach().numpy())\n",
"plt.plot(trues[0, :, 0, 0, 0].detach().numpy()) # type: ignore\n",
"fig, ax = plt.subplots(1, 1)\n",
"x, y = 0, 0\n",
"plt.plot(preds[0, :, x, y, 0].detach().numpy())\n",
"assert trues is not None\n",
"plt.plot(trues[0, :, x, y, 0].detach().numpy())\n",
"plt.legend(ncol=3, fontsize=\"small\")\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "28",
"id": "29",
"metadata": {},
"outputs": [],
"source": []
"source": [
"# Plot empirical quantiles\n",
"import matplotlib.pyplot as plt\n",
"\n",
"fig, ax = plt.subplots(1, 1)\n",
"x, y = 0, 0\n",
"cmap = plt.get_cmap(\"Reds\")\n",
"for alpha in sorted([0.1, 0.3, 0.5, 0.7, 0.9], reverse=True):\n",
" upper = np.quantile(preds[0, :, x, y, 0].detach().numpy(), 1-alpha/2, axis=-1)\n",
" lower = np.quantile(preds[0, :, x, y, 0].detach().numpy(), alpha/2, axis=-1)\n",
" plt.fill_between(\n",
" range(preds.shape[1]),\n",
" lower,\n",
" upper,\n",
" color=cmap(alpha),\n",
" alpha=0.2,\n",
" label=f\"{round((1-alpha)*100)}% coverage\",\n",
" lw=0\n",
" )\n",
"assert trues is not None\n",
"plt.plot(trues[0, :, x, y, 0].detach().numpy())\n",
"plt.legend(ncol=3, fontsize=\"small\")\n",
"plt.show()"
]
}
],
"metadata": {
Expand Down
5 changes: 4 additions & 1 deletion src/autocast/metrics/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .coverage import Coverage, MultiCoverage
from .deterministic import MAE, MSE, NMAE, NMSE, NRMSE, RMSE, VMSE, VRMSE, LInfinity
from .ensemble import CRPS, AlphaFairCRPS, FairCRPS

Expand All @@ -12,9 +13,11 @@
"VMSE",
"VRMSE",
"AlphaFairCRPS",
"Coverage",
"FairCRPS",
"LInfinity",
"MultiCoverage",
]

ALL_DETERMINISTIC_METRICS = (MSE, MAE, NMAE, NMSE, RMSE, NRMSE, VMSE, VRMSE, LInfinity)
ALL_ENSEMBLE_METRICS = (CRPS, AlphaFairCRPS, FairCRPS)
ALL_ENSEMBLE_METRICS = (CRPS, AlphaFairCRPS, FairCRPS, Coverage, MultiCoverage)
214 changes: 214 additions & 0 deletions src/autocast/metrics/coverage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,214 @@
from pathlib import Path

import numpy as np
import torch
from matplotlib import pyplot as plt
from matplotlib.lines import Line2D
from torch.nn import ModuleList
from torchmetrics import Metric

from autocast.metrics.ensemble import BTSCMMetric
from autocast.types import Tensor, TensorBTC, TensorBTSC, TensorBTSCM


class Coverage(BTSCMMetric):
    """
    Coverage probability for a fixed coverage level.

    Calculates the proportion of true values that fall within the symmetric
    prediction interval defined by the coverage level.
    """

    name: str = "coverage"

    def __init__(self, coverage_level: float = 0.95, **kwargs):
        """Initialize Coverage metric.

        Args:
            coverage_level: nominal coverage probability (e.g. 0.95 for 95% interval).
                Must be between 0 and 1.
            **kwargs: Additional arguments passed to BaseMetric.
        """
        super().__init__(**kwargs)
        if not (0 < coverage_level < 1):
            raise ValueError(f"coverage_level must be in (0, 1), got {coverage_level}")
        self.coverage_level = coverage_level

    def _score(self, y_pred: TensorBTSCM, y_true: TensorBTSC) -> TensorBTC:
        """
        Compute coverage reduced over spatial dims.

        Args:
            y_pred: (B, T, S, C, M)
            y_true: (B, T, S, C)

        Returns
        -------
        coverage: (B, T, C)
        """
        # Symmetric interval around the median: e.g. coverage_level=0.95
        # maps to the 0.025 and 0.975 quantiles of the ensemble (last) dim.
        alpha = self.coverage_level
        probs = torch.tensor(
            [0.5 - alpha / 2, 0.5 + alpha / 2],
            device=y_pred.device,
            dtype=y_pred.dtype,
        )
        # torch.quantile prepends the quantile axis: result is (2, B, T, S, C)
        lower_q, upper_q = torch.quantile(y_pred, probs, dim=-1).unbind(0)

        # Indicator of the truth lying inside the interval (1.0 inside, 0.0 outside)
        inside = ((lower_q <= y_true) & (y_true <= upper_q)).float()

        # Average the indicator over the spatial axes: (B, T, S, C) -> (B, T, C)
        spatial_axes = tuple(range(2, 2 + self._infer_n_spatial_dims(inside)))
        return inside.mean(dim=spatial_axes)


class MultiCoverage(Metric):
    """
    Computes coverage for multiple coverage levels at once.

    This is a wrapper around multiple Coverage metrics. It inherits from Metric
    to integrate with PyTorch Lightning and TorchMetrics. `compute()` returns
    the average calibration error (mean |observed - nominal| over all levels);
    per-level values are available via `compute_detailed()`.
    """

    def __init__(self, coverage_levels: list[float] | None = None):
        """Initialize MultiCoverage.

        Args:
            coverage_levels: nominal coverage probabilities to track. Defaults
                to 19 evenly spaced levels from 0.05 to 0.95 (rounded to 2 dp).
        """
        super().__init__()
        if coverage_levels is None:
            coverage_levels = [
                round(x, 2) for x in torch.linspace(0.05, 0.95, steps=19).tolist()
            ]

        self.coverage_levels = coverage_levels
        # List of Coverage metrics with reduce_all=False to allow per-channel analysis
        self.metrics = ModuleList(
            [Coverage(coverage_level=cl, reduce_all=False) for cl in coverage_levels]
        )

    def update(self, y_pred: TensorBTSCM, y_true: TensorBTSC) -> None:
        """Accumulate one batch of predictions/targets into every sub-metric.

        Args:
            y_pred: ensemble predictions, (B, T, S, C, M).
            y_true: ground truth, (B, T, S, C).
        """
        for metric in self.metrics:
            assert isinstance(metric, Coverage)
            metric.update(y_pred, y_true)

    def compute(self) -> Tensor:
        """Compute the Average Calibration Error."""
        errors = []
        for cl, metric in zip(self.coverage_levels, self.metrics, strict=True):
            assert isinstance(metric, Coverage)
            # Calibration error: |observed - expected|
            errors.append(torch.abs(metric.compute() - cl))

        return torch.stack(errors).mean()

    def _compute_levels_and_values(self) -> tuple[list[float], list[float]]:
        """Get coverage levels and observed values for plotting."""
        levels, observed = [], []
        for cl, metric in zip(self.coverage_levels, self.metrics, strict=True):
            assert isinstance(metric, Coverage)
            levels.append(cl)
            observed.append(metric.compute().mean().item())
        return levels, observed

    def compute_detailed(self) -> dict[str, float]:
        """Return a dict of results, keys formatted as 'coverage_{coverage_level}'."""
        return {
            f"coverage_{level}": value
            for level, value in zip(*self._compute_levels_and_values(), strict=True)
        }

    def plot(
        self,
        save_path: Path | str | None = None,
        title: str = "Coverage Plot",
        cmap_str: str = "viridis",
    ):
        """
        Plot reliability diagram showing expected vs observed coverage.

        Parameters
        ----------
        save_path: str, optional
            Path to save the plot.
        title: str
            Plot title.
        cmap_str: str
            Color map string from matplotlib.

        Returns
        -------
        matplotlib.figure.Figure
        """
        # Prepare data structure: levels -> [val_c1, val_c2, ...]
        levels = self.coverage_levels
        observed_means = []
        observed_channels = []  # shape (L, C)

        # Loop over metrics; each compute() yields per-channel observed coverage
        for metric in self.metrics:
            assert isinstance(metric, Coverage)
            val = metric.compute()
            val_c = val.mean(dim=0).cpu().numpy()  # (C,)
            observed_channels.append(val_c)
            observed_means.append(val_c.mean())

        # Create matplotlib figure
        fig, ax = plt.subplots(figsize=(8, 8))

        # Optimal line (y=x): a perfectly calibrated model lies on this diagonal
        ax.plot([0, 1], [0, 1], "k:", label="Expected", linewidth=2)

        # Plot channels
        observed_arr = np.stack(observed_channels)  # (L, C)
        cmap = plt.get_cmap(cmap_str)  # cmap for each channel
        n_channels = observed_channels[0].shape[0]
        for c in range(n_channels):
            color = cmap(c / n_channels) if n_channels > 1 else "blue"
            # Per-channel labels only when the legend stays readable
            label = f"Ch {c}" if n_channels <= 10 else None
            ax.plot(
                levels,
                observed_arr[:, c],
                color=color,
                alpha=0.3,
                linewidth=1,
                label=label,
            )

        # Plot mean coverage in bold
        ax.plot(levels, observed_means, "k-", linewidth=3, label="Mean")
        ax.set_xlabel(r"Coverage level, $\alpha$")
        ax.set_ylabel("Observed Coverage")
        ax.set_title(title)
        ax.set_xlim(0, 1)
        ax.set_ylim(0, 1)
        ax.grid(True, linestyle=":", alpha=0.6)

        # Only show legend if not cluttered
        if n_channels > 10:
            # Add manual legend for trace
            custom_lines = [
                Line2D([0], [0], color="k", lw=3),
                Line2D([0], [0], color="grey", lw=1, alpha=0.5),
            ]
            ax.legend(custom_lines, ["Mean", "Individual Channels"])
        else:
            ax.legend()

        if save_path:
            plt.savefig(save_path, bbox_inches="tight")
            print(f"Plot saved to {save_path}")

        plt.close(fig)
        return fig

    def reset(self) -> None:
        """Reset this metric and all sub-metrics."""
        # Clear torchmetrics' own internal state as well (update count and the
        # cached compute() result); without super().reset() a stale cached value
        # could be served after a reset.
        super().reset()
        # Reset all sub-metrics
        for metric in self.metrics:
            assert isinstance(metric, Coverage)
            metric.reset()
Loading