44import sys
55import uuid
66from pathlib import Path
7- from typing import Any
7+ from typing import Any , cast
88
99import hydra
1010import torch
1111from hydra .utils import get_original_cwd , instantiate
12- from omegaconf import OmegaConf
12+ from omegaconf import DictConfig , OmegaConf
1313
1414from autosim .simulations .base import SpatioTemporalSimulator
1515from autosim .utils import plot_spatiotemporal_video
@@ -177,6 +177,193 @@ def save_example_videos(
177177 )
178178
179179
180+ def _parse_field_names_csv (field_names_csv : str | None ) -> list [str ] | None :
181+ """Parse a comma-separated field-name string into a cleaned list."""
182+ if field_names_csv is None :
183+ return None
184+ names = [name .strip () for name in field_names_csv .split ("," ) if name .strip ()]
185+ return names if names else None
186+
187+
188+ def _infer_core_field_names_from_resolved_config (
189+ dataset_dir : Path , n_channels : int
190+ ) -> list [str ] | None :
191+ """Infer channel names from `resolved_config.yaml` when available."""
192+ resolved_cfg_path = dataset_dir / "resolved_config.yaml"
193+ if not resolved_cfg_path .exists ():
194+ return None
195+ try :
196+ cfg = OmegaConf .load (resolved_cfg_path )
197+ assert isinstance (cfg , DictConfig )
198+ simulator_cfg = cfg .get ("simulator" )
199+ if simulator_cfg is None :
200+ return None
201+ sim = build_simulator (simulator_cfg )
202+ inferred_names = [str (name ) for name in sim .output_names ]
203+ except Exception :
204+ return None
205+
206+ if len (inferred_names ) != n_channels :
207+ return None
208+ return inferred_names
209+
210+
211+ def compute_normalization_stats (
212+ split_payload : dict [str , Any ],
213+ core_field_names : list [str ] | None = None ,
214+ constant_field_names : list [str ] | None = None ,
215+ ) -> dict [str , Any ]:
216+ """Compute normalization statistics for one split payload."""
217+ data = split_payload .get ("data" )
218+ if not isinstance (data , torch .Tensor ) or data .ndim != 5 :
219+ msg = (
220+ "Normalization stats require split payload 'data' as a 5D torch.Tensor "
221+ "with shape [batch,time,x,y,channels]."
222+ )
223+ raise ValueError (msg )
224+
225+ _ , n_time , _ , _ , n_channels = data .shape
226+ if n_time < 2 :
227+ msg = (
228+ "Normalization delta stats require at least 2 time steps in "
229+ "split payload 'data'."
230+ )
231+ raise ValueError (msg )
232+
233+ resolved_core_field_names = core_field_names
234+ if resolved_core_field_names is None :
235+ resolved_core_field_names = [f"field_{ idx } " for idx in range (n_channels )]
236+ if len (resolved_core_field_names ) != n_channels :
237+ msg = (
238+ "Number of core field names must match data channel count. "
239+ f"Received { len (resolved_core_field_names )} names "
240+ f"for { n_channels } channels."
241+ )
242+ raise ValueError (msg )
243+
244+ deltas = data [:, 1 :, ...] - data [:, :- 1 , ...]
245+
246+ flattened_data = data .reshape (- 1 , n_channels )
247+ flattened_deltas = deltas .reshape (- 1 , n_channels )
248+ mean = flattened_data .mean (dim = 0 )
249+ std = flattened_data .std (dim = 0 , unbiased = False )
250+ mean_delta = flattened_deltas .mean (dim = 0 )
251+ std_delta = flattened_deltas .std (dim = 0 , unbiased = False )
252+
253+ def _stats_by_channel (values : torch .Tensor ) -> dict [str , float ]:
254+ return {
255+ name : float (values [idx ].detach ().cpu ().item ())
256+ for idx , name in enumerate (resolved_core_field_names or [])
257+ }
258+
259+ return {
260+ "stats" : {
261+ "mean" : _stats_by_channel (mean ),
262+ "std" : _stats_by_channel (std ),
263+ "mean_delta" : _stats_by_channel (mean_delta ),
264+ "std_delta" : _stats_by_channel (std_delta ),
265+ },
266+ "core_field_names" : resolved_core_field_names ,
267+ "constant_field_names" : constant_field_names or [],
268+ }
269+
270+
271+ def _round_sigfigs (value : float , sig_figs : int ) -> float :
272+ """Round a float to a fixed number of significant figures."""
273+ if sig_figs <= 0 :
274+ msg = "sig_figs must be positive."
275+ raise ValueError (msg )
276+ if value == 0.0 :
277+ return 0.0
278+ # General format preserves significant figures; may emit scientific notation.
279+ return float (f"{ value :.{sig_figs }g} " )
280+
281+
def _rounded_normalization_stats_payload(
    stats_payload: dict[str, Any], sig_figs: int
) -> dict[str, Any]:
    """Build a copy of ``stats_payload`` whose stat values are rounded.

    The payload is copied by round-tripping it through OmegaConf, so the
    input is left unmodified. Inside the copy's ``"stats"`` mapping, numeric
    values in the ``mean``, ``std``, ``mean_delta`` and ``std_delta`` buckets
    are rounded with ``_round_sigfigs``; everything else is left as is.

    Args:
        stats_payload: Payload to copy and round.
        sig_figs: Significant figures to keep for each numeric stat.

    Returns:
        The rounded copy. If ``"stats"`` is absent or not a dict, the copy is
        returned without any rounding.
    """
    payload_copy = cast(
        dict[str, Any],
        OmegaConf.to_container(OmegaConf.create(stats_payload), resolve=True),
    )

    stats_section = payload_copy.get("stats")
    if isinstance(stats_section, dict):
        for stat_name in ("mean", "std", "mean_delta", "std_delta"):
            per_field = stats_section.get(stat_name)
            if not isinstance(per_field, dict):
                continue
            # Compute every replacement first, then overwrite in one step so
            # the bucket is never mutated while it is being iterated.
            per_field.update(
                {
                    field_name: _round_sigfigs(float(raw_value), sig_figs)
                    for field_name, raw_value in per_field.items()
                    if isinstance(raw_value, (int, float))
                }
            )

    return payload_copy
304+
305+
def save_normalization_stats(
    stats_payload: dict[str, Any],
    output_path: Path,
    sig_figs: int = 4,
) -> None:
    """Write normalization statistics to ``output_path`` as UTF-8 YAML.

    Missing parent directories of ``output_path`` are created first. The
    stat values are rounded to ``sig_figs`` significant figures before the
    payload is serialized.

    Args:
        stats_payload: Statistics payload to persist.
        output_path: Destination YAML file.
        sig_figs: Significant figures kept for float stats (default 4).
    """
    output_path.parent.mkdir(parents=True, exist_ok=True)
    payload_to_write = _rounded_normalization_stats_payload(
        stats_payload=stats_payload,
        sig_figs=sig_figs,
    )
    serialized = OmegaConf.to_yaml(OmegaConf.create(payload_to_write), resolve=True)
    output_path.write_text(serialized, encoding="utf-8")
318+
319+
320+ def generate_normalization_stats_yaml (
321+ dataset_dir : Path ,
322+ split : str = "train" ,
323+ output_path : Path | None = None ,
324+ core_field_names : list [str ] | None = None ,
325+ sig_figs : int = 4 ,
326+ ) -> Path :
327+ """Generate normalization-stats YAML from an existing dataset directory."""
328+ split_data_path = dataset_dir / split / "data.pt"
329+ if not split_data_path .exists ():
330+ msg = f"Could not find split file '{ split_data_path } '."
331+ raise FileNotFoundError (msg )
332+ split_payload = torch .load (split_data_path , map_location = "cpu" )
333+ if not isinstance (split_payload , dict ):
334+ msg = f"Expected dict payload in '{ split_data_path } '."
335+ raise ValueError (msg )
336+
337+ payload_data = split_payload .get ("data" )
338+ if not isinstance (payload_data , torch .Tensor ) or payload_data .ndim != 5 :
339+ msg = (
340+ "Expected split payload 'data' as a 5D torch.Tensor with shape "
341+ "[batch,time,x,y,channels]."
342+ )
343+ raise ValueError (msg )
344+
345+ resolved_field_names = core_field_names
346+ if resolved_field_names is None :
347+ resolved_field_names = _infer_core_field_names_from_resolved_config (
348+ dataset_dir = dataset_dir ,
349+ n_channels = payload_data .shape [- 1 ],
350+ )
351+ stats_payload = compute_normalization_stats (
352+ split_payload = split_payload ,
353+ core_field_names = resolved_field_names ,
354+ )
355+
356+ resolved_output_path = (
357+ output_path if output_path is not None else dataset_dir / "stats.yml"
358+ )
359+ save_normalization_stats (
360+ stats_payload = stats_payload ,
361+ output_path = resolved_output_path ,
362+ sig_figs = sig_figs ,
363+ )
364+ return resolved_output_path
365+
366+
180367def get_per_strata_counts (
181368 n_train : int ,
182369 n_valid : int ,
@@ -301,6 +488,14 @@ def _generate_main(cfg: Any) -> None:
301488 save_resolved_config (cfg = cfg , output_dir = output_dir )
302489
303490 save_dataset_splits (splits = splits , output_dir = output_dir , overwrite = cfg .overwrite )
491+ normalization_stats_payload = compute_normalization_stats (
492+ split_payload = splits ["train" ],
493+ core_field_names = channel_names_for_visualization ,
494+ )
495+ save_normalization_stats (
496+ stats_payload = normalization_stats_payload ,
497+ output_path = output_dir / "stats.yml" ,
498+ )
304499 save_example_videos (
305500 splits = splits ,
306501 output_dir = output_dir ,
@@ -321,6 +516,7 @@ def main() -> None:
321516 """Dispatch tiny autosim subcommands.
322517
323518 - `autosim list` prints simulator config names.
519+ - `autosim stats` writes normalization stats YAML for an existing dataset.
324520 - `autosim` (or any Hydra overrides) runs data generation.
325521 """
326522 argv = sys .argv [1 :]
@@ -330,13 +526,15 @@ def main() -> None:
330526 prog = "autosim" ,
331527 description = (
332528 "Generate simulation datasets using Hydra overrides, or list "
333- "available simulator configs."
529+ "available simulator configs, or compute normalization stats ."
334530 ),
335531 )
336532 parser .add_argument (
337533 "command" ,
338534 nargs = "?" ,
339- help = "Subcommand: 'list'. Omit to run data generation with Hydra." ,
535+ help = (
536+ "Subcommand: 'list' or 'stats'. Omit to run data generation with Hydra."
537+ ),
340538 )
341539 parser .print_help ()
342540 return
@@ -351,6 +549,55 @@ def main() -> None:
351549 print (name )
352550 return
353551
552+ if argv and argv [0 ] == "stats" :
553+ stats_parser = argparse .ArgumentParser (
554+ prog = "autosim stats" ,
555+ description = (
556+ "Generate normalization_stats YAML for an existing dataset directory."
557+ ),
558+ )
559+ stats_parser .add_argument (
560+ "dataset_dir" ,
561+ help = "Dataset root containing split folders such as train/data.pt." ,
562+ )
563+ stats_parser .add_argument (
564+ "--split" ,
565+ default = "train" ,
566+ help = "Split to use for stats (default: train)." ,
567+ )
568+ stats_parser .add_argument (
569+ "--output" ,
570+ default = None ,
571+ help = ("Optional output YAML path (default: <dataset_dir>/stats.yml)." ),
572+ )
573+ stats_parser .add_argument (
574+ "--field-names" ,
575+ default = None ,
576+ help = (
577+ "Optional comma-separated core field names, e.g. 'U,V'. "
578+ "If omitted, names are inferred from resolved_config.yaml "
579+ "when possible."
580+ ),
581+ )
582+ stats_parser .add_argument (
583+ "--sig-figs" ,
584+ type = int ,
585+ default = 4 ,
586+ help = "Significant figures for float stats in YAML (default: 4)." ,
587+ )
588+ args = stats_parser .parse_args (argv [1 :])
589+
590+ output_path = Path (args .output ) if args .output is not None else None
591+ written_path = generate_normalization_stats_yaml (
592+ dataset_dir = Path (args .dataset_dir ),
593+ split = str (args .split ),
594+ output_path = output_path ,
595+ core_field_names = _parse_field_names_csv (args .field_names ),
596+ sig_figs = int (args .sig_figs ),
597+ )
598+ print (written_path .as_posix ())
599+ return
600+
354601 # Preserve all original arguments for Hydra's own parser.
355602 sys .argv = [sys .argv [0 ], * argv ]
356603 _generate_main ()
0 commit comments