Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 133 additions & 1 deletion src/autocast/metrics/coverage.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ def plot(

# Plot mean coverage in bold
ax.plot(levels, observed_means, "k-", linewidth=3, label="Mean")
ax.set_xlabel("Coverage level, $\alpha$")
ax.set_xlabel(r"Coverage level, $\alpha$")
ax.set_ylabel("Observed Coverage")
ax.set_title(title)
ax.set_xlim(0, 1)
Expand Down Expand Up @@ -212,3 +212,135 @@ def reset(self):
for metric in self.metrics:
assert isinstance(metric, Coverage)
metric.reset()


class MultiCoverageByTime(Metric):
    """
    Computes coverage for multiple coverage levels, plotting by time zones.

    Similar to MultiCoverage but plots separate lines for different time zones
    instead of different channels. One ``Coverage`` sub-metric is held per
    coverage level and every ``update`` call is forwarded to all of them.
    """

    def __init__(
        self,
        time_zones: list[tuple[int, int]],
        coverage_levels: list[float] | None = None,
    ):
        """Initialize MultiCoverageByTime metric.

        Args:
            time_zones: List of (start_idx, end_idx) tuples defining time ranges.
                Each range is [start, end) (exclusive end). The pairs are used
                as slice bounds on the first dimension of each sub-metric's
                computed result when plotting.
            coverage_levels: List of coverage levels to compute. Defaults to
                19 levels from 0.05 to 0.95.
        """
        super().__init__()
        if coverage_levels is None:
            # Round to two decimals so the defaults read 0.05, 0.10, ..., 0.95
            # rather than carrying float32 noise from linspace.
            coverage_levels = [
                round(x, 2) for x in torch.linspace(0.05, 0.95, steps=19).tolist()
            ]

        self.time_zones = time_zones
        self.coverage_levels = coverage_levels
        # reduce_all=False: presumably keeps a per-time-step result so it can
        # be sliced by zone in plot() -- confirm against Coverage.
        self.metrics = ModuleList(
            [Coverage(coverage_level=cl, reduce_all=False) for cl in coverage_levels]
        )

    def update(self, y_pred, y_true):
        """Forward one batch of predictions and targets to every sub-metric."""
        for metric in self.metrics:
            assert isinstance(metric, Coverage)
            metric.update(y_pred, y_true)

    def compute(self) -> Tensor:
        """Compute the Average Calibration Error.

        Returns the mean, over all coverage levels and all remaining
        (unreduced) entries, of ``|observed coverage - nominal level|``.
        """
        errors = []
        for cl, metric in zip(self.coverage_levels, self.metrics, strict=True):
            assert isinstance(metric, Coverage)
            errors.append(torch.abs(metric.compute() - cl))

        return torch.stack(errors).mean()

    def plot(
        self,
        save_path: Path | str | None = None,
        title: str = "Coverage Plot by Time Zone",
        cmap_str: str = "viridis",
    ):
        """
        Plot reliability diagram showing expected vs observed coverage by time zone.

        One dashed line is drawn per time zone, plus a bold solid line for the
        mean over everything and a dotted diagonal for perfect calibration.

        Parameters
        ----------
        save_path: Path or str, optional
            Path to save the plot.
        title: str
            Plot title.
        cmap_str: str
            Color map string from matplotlib.

        Returns
        -------
        matplotlib.figure.Figure
            The figure, already closed in pyplot's registry.

        Warns
        -----
        UserWarning
            If a time zone selects no time steps; its coverage is NaN and the
            corresponding line will not be visible.
        """
        # Local import keeps this block self-contained; only used for the
        # empty-zone diagnostic below.
        import warnings

        levels = self.coverage_levels
        n_zones = len(self.time_zones)

        # Collect observed coverage per time zone and overall mean
        observed_by_zone: list[list[float]] = [[] for _ in self.time_zones]
        observed_mean: list[float] = []  # Mean across all time steps
        empty_zones: list[tuple[int, int]] = []

        for metric in self.metrics:
            assert isinstance(metric, Coverage)
            # Presumably (T, C) after batch reduction -- TODO confirm against
            # Coverage; the zone slicing below indexes the first dimension.
            val = metric.compute().cpu()

            # Mean across all time steps and channels
            observed_mean.append(val.mean().item())

            for z_idx, (start, end) in enumerate(self.time_zones):
                # Average over time slice and all channels
                zone_vals = val[start:end]
                if zone_vals.numel() == 0 and (start, end) not in empty_zones:
                    # mean() of an empty tensor is NaN; remember the zone so
                    # the caller is told once instead of silently losing a line.
                    empty_zones.append((start, end))
                observed_by_zone[z_idx].append(zone_vals.mean().item())

        if empty_zones:
            warnings.warn(
                f"Time zone(s) {empty_zones} select no time steps; "
                "their observed coverage is NaN and will not be drawn.",
                stacklevel=2,
            )

        # Create matplotlib figure
        fig, ax = plt.subplots(figsize=(8, 8))

        # Optimal line (y=x)
        ax.plot([0, 1], [0, 1], "k:", label="Expected", linewidth=2)

        # Plot each time zone as dashed lines
        cmap = plt.get_cmap(cmap_str)
        for z_idx, (start, end) in enumerate(self.time_zones):
            color = cmap(z_idx / n_zones) if n_zones > 1 else "blue"
            label = f"t=[{start}, {end})"
            ax.plot(
                levels,
                observed_by_zone[z_idx],
                color=color,
                linewidth=1,
                linestyle="--",
                label=label,
            )

        # Plot mean across all time steps as solid line
        ax.plot(levels, observed_mean, "k-", linewidth=3, label="All time")

        ax.set_xlabel(r"Coverage level, $\alpha$")
        ax.set_ylabel("Observed Coverage")
        ax.set_title(title)
        ax.set_xlim(0, 1)
        ax.set_ylim(0, 1)
        ax.grid(True, linestyle=":", alpha=0.6)
        ax.legend()

        if save_path:
            # Save through the figure object so the output never depends on
            # which figure pyplot currently considers active.
            fig.savefig(save_path, bbox_inches="tight")
            print(f"Plot saved to {save_path}")

        # Drop the figure from pyplot's registry to avoid accumulating open
        # figures; the returned object remains usable by the caller.
        plt.close(fig)
        return fig

    def reset(self):
        """Reset the accumulated state of every sub-metric."""
        for metric in self.metrics:
            assert isinstance(metric, Coverage)
            metric.reset()
16 changes: 13 additions & 3 deletions src/autocast/utils/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from matplotlib.colors import Normalize, TwoSlopeNorm
from matplotlib.gridspec import GridSpec

from autocast.metrics.coverage import MultiCoverage
from autocast.metrics.coverage import MultiCoverage, MultiCoverageByTime
from autocast.types.types import Tensor, TensorBTSC, TensorBTSCM


Expand Down Expand Up @@ -338,13 +338,15 @@ def plot_coverage(
pred: TensorBTSCM,
true: TensorBTSC,
coverage_levels: list[float] | None = None,
time_zones: list[tuple[int, int]] | None = None,
save_path: str | None = None,
title: str = "Coverage plot",
):
"""
Plot reliability diagram showing expected vs observed coverage.

This is a convenience wrapper around MultiCoverage.plot().
This is a convenience wrapper around MultiCoverage.plot() or
MultiCoverageByTime.plot() if time_zones is provided.

Parameters
----------
Expand All @@ -354,6 +356,9 @@ def plot_coverage(
Ground truth tensor.
coverage_levels: list[float], optional
Coverage levels to evaluate (default: 0.05 to 0.95).
time_zones: list[tuple[int, int]], optional
If provided, plot coverage by time zones instead of by channel.
Each tuple is (start_idx, end_idx) defining a time range [start, end).
save_path: str, optional
Path to save the plot.
title: str
Expand All @@ -368,6 +373,11 @@ def plot_coverage(
)

# Create metric, update with data, and plot
metric = MultiCoverage(coverage_levels=coverage_levels_)
if time_zones is not None:
metric = MultiCoverageByTime(
time_zones=time_zones, coverage_levels=coverage_levels_
)
else:
metric = MultiCoverage(coverage_levels=coverage_levels_)
metric.update(pred, true)
return metric.plot(save_path=save_path, title=title)
Loading