Skip to content

Commit 158caad

Browse files
authored
Merge pull request #15 from alan-turing-institute/12-norm-stats
Record normalization stats (#12)
2 parents 6beebb3 + 0bc366a commit 158caad

2 files changed

Lines changed: 363 additions & 5 deletions

File tree

src/autosim/cli.py

Lines changed: 251 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,12 @@
44
import sys
55
import uuid
66
from pathlib import Path
7-
from typing import Any
7+
from typing import Any, cast
88

99
import hydra
1010
import torch
1111
from hydra.utils import get_original_cwd, instantiate
12-
from omegaconf import OmegaConf
12+
from omegaconf import DictConfig, OmegaConf
1313

1414
from autosim.simulations.base import SpatioTemporalSimulator
1515
from autosim.utils import plot_spatiotemporal_video
@@ -177,6 +177,193 @@ def save_example_videos(
177177
)
178178

179179

180+
def _parse_field_names_csv(field_names_csv: str | None) -> list[str] | None:
181+
"""Parse a comma-separated field-name string into a cleaned list."""
182+
if field_names_csv is None:
183+
return None
184+
names = [name.strip() for name in field_names_csv.split(",") if name.strip()]
185+
return names if names else None
186+
187+
188+
def _infer_core_field_names_from_resolved_config(
189+
dataset_dir: Path, n_channels: int
190+
) -> list[str] | None:
191+
"""Infer channel names from `resolved_config.yaml` when available."""
192+
resolved_cfg_path = dataset_dir / "resolved_config.yaml"
193+
if not resolved_cfg_path.exists():
194+
return None
195+
try:
196+
cfg = OmegaConf.load(resolved_cfg_path)
197+
assert isinstance(cfg, DictConfig)
198+
simulator_cfg = cfg.get("simulator")
199+
if simulator_cfg is None:
200+
return None
201+
sim = build_simulator(simulator_cfg)
202+
inferred_names = [str(name) for name in sim.output_names]
203+
except Exception:
204+
return None
205+
206+
if len(inferred_names) != n_channels:
207+
return None
208+
return inferred_names
209+
210+
211+
def compute_normalization_stats(
212+
split_payload: dict[str, Any],
213+
core_field_names: list[str] | None = None,
214+
constant_field_names: list[str] | None = None,
215+
) -> dict[str, Any]:
216+
"""Compute normalization statistics for one split payload."""
217+
data = split_payload.get("data")
218+
if not isinstance(data, torch.Tensor) or data.ndim != 5:
219+
msg = (
220+
"Normalization stats require split payload 'data' as a 5D torch.Tensor "
221+
"with shape [batch,time,x,y,channels]."
222+
)
223+
raise ValueError(msg)
224+
225+
_, n_time, _, _, n_channels = data.shape
226+
if n_time < 2:
227+
msg = (
228+
"Normalization delta stats require at least 2 time steps in "
229+
"split payload 'data'."
230+
)
231+
raise ValueError(msg)
232+
233+
resolved_core_field_names = core_field_names
234+
if resolved_core_field_names is None:
235+
resolved_core_field_names = [f"field_{idx}" for idx in range(n_channels)]
236+
if len(resolved_core_field_names) != n_channels:
237+
msg = (
238+
"Number of core field names must match data channel count. "
239+
f"Received {len(resolved_core_field_names)} names "
240+
f"for {n_channels} channels."
241+
)
242+
raise ValueError(msg)
243+
244+
deltas = data[:, 1:, ...] - data[:, :-1, ...]
245+
246+
flattened_data = data.reshape(-1, n_channels)
247+
flattened_deltas = deltas.reshape(-1, n_channels)
248+
mean = flattened_data.mean(dim=0)
249+
std = flattened_data.std(dim=0, unbiased=False)
250+
mean_delta = flattened_deltas.mean(dim=0)
251+
std_delta = flattened_deltas.std(dim=0, unbiased=False)
252+
253+
def _stats_by_channel(values: torch.Tensor) -> dict[str, float]:
254+
return {
255+
name: float(values[idx].detach().cpu().item())
256+
for idx, name in enumerate(resolved_core_field_names or [])
257+
}
258+
259+
return {
260+
"stats": {
261+
"mean": _stats_by_channel(mean),
262+
"std": _stats_by_channel(std),
263+
"mean_delta": _stats_by_channel(mean_delta),
264+
"std_delta": _stats_by_channel(std_delta),
265+
},
266+
"core_field_names": resolved_core_field_names,
267+
"constant_field_names": constant_field_names or [],
268+
}
269+
270+
271+
def _round_sigfigs(value: float, sig_figs: int) -> float:
272+
"""Round a float to a fixed number of significant figures."""
273+
if sig_figs <= 0:
274+
msg = "sig_figs must be positive."
275+
raise ValueError(msg)
276+
if value == 0.0:
277+
return 0.0
278+
# General format preserves significant figures; may emit scientific notation.
279+
return float(f"{value:.{sig_figs}g}")
280+
281+
282+
def _rounded_normalization_stats_payload(
283+
stats_payload: dict[str, Any], sig_figs: int
284+
) -> dict[str, Any]:
285+
"""Return a copy of stats_payload with rounded float stat values."""
286+
rounded = cast(
287+
dict[str, Any],
288+
OmegaConf.to_container(OmegaConf.create(stats_payload), resolve=True),
289+
)
290+
291+
stats = rounded.get("stats")
292+
if not isinstance(stats, dict):
293+
return rounded
294+
295+
for key in ("mean", "std", "mean_delta", "std_delta"):
296+
bucket = stats.get(key)
297+
if not isinstance(bucket, dict):
298+
continue
299+
for field_name, field_value in list(bucket.items()):
300+
if isinstance(field_value, int | float):
301+
bucket[field_name] = _round_sigfigs(float(field_value), sig_figs)
302+
303+
return rounded
304+
305+
306+
def save_normalization_stats(
307+
stats_payload: dict[str, Any],
308+
output_path: Path,
309+
sig_figs: int = 4,
310+
) -> None:
311+
"""Persist normalization statistics as YAML."""
312+
output_path.parent.mkdir(parents=True, exist_ok=True)
313+
rounded_payload = _rounded_normalization_stats_payload(
314+
stats_payload=stats_payload, sig_figs=sig_figs
315+
)
316+
yaml_payload = OmegaConf.to_yaml(OmegaConf.create(rounded_payload), resolve=True)
317+
output_path.write_text(yaml_payload, encoding="utf-8")
318+
319+
320+
def generate_normalization_stats_yaml(
321+
dataset_dir: Path,
322+
split: str = "train",
323+
output_path: Path | None = None,
324+
core_field_names: list[str] | None = None,
325+
sig_figs: int = 4,
326+
) -> Path:
327+
"""Generate normalization-stats YAML from an existing dataset directory."""
328+
split_data_path = dataset_dir / split / "data.pt"
329+
if not split_data_path.exists():
330+
msg = f"Could not find split file '{split_data_path}'."
331+
raise FileNotFoundError(msg)
332+
split_payload = torch.load(split_data_path, map_location="cpu")
333+
if not isinstance(split_payload, dict):
334+
msg = f"Expected dict payload in '{split_data_path}'."
335+
raise ValueError(msg)
336+
337+
payload_data = split_payload.get("data")
338+
if not isinstance(payload_data, torch.Tensor) or payload_data.ndim != 5:
339+
msg = (
340+
"Expected split payload 'data' as a 5D torch.Tensor with shape "
341+
"[batch,time,x,y,channels]."
342+
)
343+
raise ValueError(msg)
344+
345+
resolved_field_names = core_field_names
346+
if resolved_field_names is None:
347+
resolved_field_names = _infer_core_field_names_from_resolved_config(
348+
dataset_dir=dataset_dir,
349+
n_channels=payload_data.shape[-1],
350+
)
351+
stats_payload = compute_normalization_stats(
352+
split_payload=split_payload,
353+
core_field_names=resolved_field_names,
354+
)
355+
356+
resolved_output_path = (
357+
output_path if output_path is not None else dataset_dir / "stats.yml"
358+
)
359+
save_normalization_stats(
360+
stats_payload=stats_payload,
361+
output_path=resolved_output_path,
362+
sig_figs=sig_figs,
363+
)
364+
return resolved_output_path
365+
366+
180367
def get_per_strata_counts(
181368
n_train: int,
182369
n_valid: int,
@@ -301,6 +488,14 @@ def _generate_main(cfg: Any) -> None:
301488
save_resolved_config(cfg=cfg, output_dir=output_dir)
302489

303490
save_dataset_splits(splits=splits, output_dir=output_dir, overwrite=cfg.overwrite)
491+
normalization_stats_payload = compute_normalization_stats(
492+
split_payload=splits["train"],
493+
core_field_names=channel_names_for_visualization,
494+
)
495+
save_normalization_stats(
496+
stats_payload=normalization_stats_payload,
497+
output_path=output_dir / "stats.yml",
498+
)
304499
save_example_videos(
305500
splits=splits,
306501
output_dir=output_dir,
@@ -321,6 +516,7 @@ def main() -> None:
321516
"""Dispatch tiny autosim subcommands.
322517
323518
- `autosim list` prints simulator config names.
519+
- `autosim stats` writes normalization stats YAML for an existing dataset.
324520
- `autosim` (or any Hydra overrides) runs data generation.
325521
"""
326522
argv = sys.argv[1:]
@@ -330,13 +526,15 @@ def main() -> None:
330526
prog="autosim",
331527
description=(
332528
"Generate simulation datasets using Hydra overrides, or list "
333-
"available simulator configs."
529+
"available simulator configs, or compute normalization stats."
334530
),
335531
)
336532
parser.add_argument(
337533
"command",
338534
nargs="?",
339-
help="Subcommand: 'list'. Omit to run data generation with Hydra.",
535+
help=(
536+
"Subcommand: 'list' or 'stats'. Omit to run data generation with Hydra."
537+
),
340538
)
341539
parser.print_help()
342540
return
@@ -351,6 +549,55 @@ def main() -> None:
351549
print(name)
352550
return
353551

552+
if argv and argv[0] == "stats":
553+
stats_parser = argparse.ArgumentParser(
554+
prog="autosim stats",
555+
description=(
556+
"Generate normalization_stats YAML for an existing dataset directory."
557+
),
558+
)
559+
stats_parser.add_argument(
560+
"dataset_dir",
561+
help="Dataset root containing split folders such as train/data.pt.",
562+
)
563+
stats_parser.add_argument(
564+
"--split",
565+
default="train",
566+
help="Split to use for stats (default: train).",
567+
)
568+
stats_parser.add_argument(
569+
"--output",
570+
default=None,
571+
help=("Optional output YAML path (default: <dataset_dir>/stats.yml)."),
572+
)
573+
stats_parser.add_argument(
574+
"--field-names",
575+
default=None,
576+
help=(
577+
"Optional comma-separated core field names, e.g. 'U,V'. "
578+
"If omitted, names are inferred from resolved_config.yaml "
579+
"when possible."
580+
),
581+
)
582+
stats_parser.add_argument(
583+
"--sig-figs",
584+
type=int,
585+
default=4,
586+
help="Significant figures for float stats in YAML (default: 4).",
587+
)
588+
args = stats_parser.parse_args(argv[1:])
589+
590+
output_path = Path(args.output) if args.output is not None else None
591+
written_path = generate_normalization_stats_yaml(
592+
dataset_dir=Path(args.dataset_dir),
593+
split=str(args.split),
594+
output_path=output_path,
595+
core_field_names=_parse_field_names_csv(args.field_names),
596+
sig_figs=int(args.sig_figs),
597+
)
598+
print(written_path.as_posix())
599+
return
600+
354601
# Preserve all original arguments for Hydra's own parser.
355602
sys.argv = [sys.argv[0], *argv]
356603
_generate_main()

0 commit comments

Comments
 (0)