class MultiCoverageByTime(Metric):
    """
    Computes coverage for multiple coverage levels, plotting by time zones.

    Similar to MultiCoverage but plots separate lines for different time zones
    instead of different channels. One ``Coverage`` metric is held per coverage
    level, each created with ``reduce_all=False``; the tensor returned by its
    ``compute()`` is sliced along its first axis to obtain per-zone values.
    """

    def __init__(
        self,
        time_zones: list[tuple[int, int]],
        coverage_levels: list[float] | None = None,
    ):
        """Initialize MultiCoverageByTime metric.

        Args:
            time_zones: List of (start_idx, end_idx) tuples defining time ranges.
                Each range is [start, end) (exclusive end).
            coverage_levels: List of coverage levels to compute. Defaults to
                19 levels from 0.05 to 0.95.
        """
        super().__init__()
        if coverage_levels is None:
            # linspace yields float32 values; rounding gives clean 0.05 steps.
            coverage_levels = [
                round(x, 2) for x in torch.linspace(0.05, 0.95, steps=19).tolist()
            ]

        # Defensive copy so later mutation of the caller's list cannot change
        # which zones this metric plots.
        self.time_zones = list(time_zones)
        self.coverage_levels = coverage_levels
        self.metrics = ModuleList(
            [Coverage(coverage_level=cl, reduce_all=False) for cl in coverage_levels]
        )

    def update(self, y_pred, y_true) -> None:
        """Forward one batch of predictions and targets to every level's metric."""
        for metric in self.metrics:
            assert isinstance(metric, Coverage)  # narrows the ModuleList item type
            metric.update(y_pred, y_true)

    def compute(self) -> Tensor:
        """Compute the Average Calibration Error.

        Returns the mean absolute difference between the observed coverage and
        the nominal coverage level, averaged over all levels and all entries
        of each per-level coverage tensor.
        """
        errors = []
        for cl, metric in zip(self.coverage_levels, self.metrics, strict=True):
            assert isinstance(metric, Coverage)
            errors.append(torch.abs(metric.compute() - cl))

        return torch.stack(errors).mean()

    def _observed_coverage(self) -> tuple[list[float], list[list[float]], set[int]]:
        """Collect observed coverage per level, overall and per time zone.

        Returns
        -------
        tuple
            ``(overall, per_zone, empty_zones)`` where ``overall[i]`` is the
            mean observed coverage at level ``i``, ``per_zone[z][i]`` is the
            mean over time zone ``z`` at level ``i``, and ``empty_zones`` holds
            the indices of zones whose slice selected no entries (their
            ``per_zone`` list is left empty).
        """
        overall: list[float] = []
        per_zone: list[list[float]] = [[] for _ in self.time_zones]
        empty_zones: set[int] = set()

        for metric in self.metrics:
            assert isinstance(metric, Coverage)
            # NOTE(review): expected to be time-major, e.g. (T, C) after batch
            # reduction -- confirm against Coverage(reduce_all=False).
            val = metric.compute().cpu()

            # Mean across all time steps and channels
            overall.append(val.mean().item())

            for z_idx, (start, end) in enumerate(self.time_zones):
                window = val[start:end]
                if window.numel() == 0:
                    # mean() of an empty tensor is NaN; record the zone instead
                    # of silently producing an invisible line.
                    empty_zones.add(z_idx)
                    continue
                # Average over time slice and all channels
                per_zone[z_idx].append(window.mean().item())

        return overall, per_zone, empty_zones

    def plot(
        self,
        save_path: Path | str | None = None,
        title: str = "Coverage Plot by Time Zone",
        cmap_str: str = "viridis",
    ):
        """
        Plot reliability diagram showing expected vs observed coverage by time zone.

        Time zones that select no time steps are skipped with a warning rather
        than drawn as NaN lines.

        Parameters
        ----------
        save_path: str, optional
            Path to save the plot.
        title: str
            Plot title.
        cmap_str: str
            Color map string from matplotlib.

        Returns
        -------
        matplotlib.figure.Figure
            The figure, already closed in pyplot's registry.
        """
        # Function-scope import: the module's import block is not part of this
        # change, so the dependency is brought in locally.
        import warnings

        levels = self.coverage_levels
        n_zones = len(self.time_zones)

        observed_mean, observed_by_zone, empty_zones = self._observed_coverage()

        for z_idx in sorted(empty_zones):
            start, end = self.time_zones[z_idx]
            warnings.warn(
                f"Time zone [{start}, {end}) selects no time steps; "
                "omitting it from the coverage plot.",
                stacklevel=2,
            )

        # Create matplotlib figure
        fig, ax = plt.subplots(figsize=(8, 8))

        # Optimal line (y=x)
        ax.plot([0, 1], [0, 1], "k:", label="Expected", linewidth=2)

        # Plot each time zone as dashed lines
        cmap = plt.get_cmap(cmap_str)
        for z_idx, (start, end) in enumerate(self.time_zones):
            if z_idx in empty_zones:
                continue
            # Index by the zone's original position so colors stay stable
            # even when some zones are skipped.
            color = cmap(z_idx / n_zones) if n_zones > 1 else "blue"
            ax.plot(
                levels,
                observed_by_zone[z_idx],
                color=color,
                linewidth=1,
                linestyle="--",
                label=f"t=[{start}, {end})",
            )

        # Plot mean across all time steps as solid line
        ax.plot(levels, observed_mean, "k-", linewidth=3, label="All time")

        ax.set_xlabel(r"Coverage level, $\alpha$")
        ax.set_ylabel("Observed Coverage")
        ax.set_title(title)
        ax.set_xlim(0, 1)
        ax.set_ylim(0, 1)
        ax.grid(True, linestyle=":", alpha=0.6)
        ax.legend()

        if save_path:
            # Save this figure explicitly instead of pyplot's "current" figure.
            fig.savefig(save_path, bbox_inches="tight")
            print(f"Plot saved to {save_path}")

        # Release pyplot's reference; the Figure object itself remains usable.
        plt.close(fig)
        return fig

    def reset(self) -> None:
        """Reset the accumulated state of every per-level coverage metric."""
        for metric in self.metrics:
            assert isinstance(metric, Coverage)
            metric.reset()
from matplotlib.gridspec import GridSpec -from autocast.metrics.coverage import MultiCoverage +from autocast.metrics.coverage import MultiCoverage, MultiCoverageByTime from autocast.types.types import Tensor, TensorBTSC, TensorBTSCM @@ -338,13 +338,15 @@ def plot_coverage( pred: TensorBTSCM, true: TensorBTSC, coverage_levels: list[float] | None = None, + time_zones: list[tuple[int, int]] | None = None, save_path: str | None = None, title: str = "Coverage plot", ): """ Plot reliability diagram showing expected vs observed coverage. - This is a convenience wrapper around MultiCoverage.plot(). + This is a convenience wrapper around MultiCoverage.plot() or + MultiCoverageByTime.plot() if time_zones is provided. Parameters ---------- @@ -354,6 +356,9 @@ def plot_coverage( Ground truth tensor. coverage_levels: list[float], optional Coverage levels to evaluate (default: 0.05 to 0.95). + time_zones: list[tuple[int, int]], optional + If provided, plot coverage by time zones instead of by channel. + Each tuple is (start_idx, end_idx) defining a time range [start, end). save_path: str, optional Path to save the plot. title: str @@ -368,6 +373,11 @@ def plot_coverage( ) # Create metric, update with data, and plot - metric = MultiCoverage(coverage_levels=coverage_levels_) + if time_zones is not None: + metric = MultiCoverageByTime( + time_zones=time_zones, coverage_levels=coverage_levels_ + ) + else: + metric = MultiCoverage(coverage_levels=coverage_levels_) metric.update(pred, true) return metric.plot(save_path=save_path, title=title)