@@ -257,16 +257,14 @@ def _stats_by_channel(values: torch.Tensor) -> dict[str, float]:
257257 }
258258
259259 return {
260- "normalization_stats" : {
261- "stats" : {
262- "mean" : _stats_by_channel (mean ),
263- "std" : _stats_by_channel (std ),
264- "mean_delta" : _stats_by_channel (mean_delta ),
265- "std_delta" : _stats_by_channel (std_delta ),
266- },
267- "core_field_names" : resolved_core_field_names ,
268- "constant_field_names" : constant_field_names or [],
269- }
260+ "stats" : {
261+ "mean" : _stats_by_channel (mean ),
262+ "std" : _stats_by_channel (std ),
263+ "mean_delta" : _stats_by_channel (mean_delta ),
264+ "std_delta" : _stats_by_channel (std_delta ),
265+ },
266+ "core_field_names" : resolved_core_field_names ,
267+ "constant_field_names" : constant_field_names or [],
270268 }
271269
272270
@@ -290,10 +288,7 @@ def _rounded_normalization_stats_payload(
290288 OmegaConf .to_container (OmegaConf .create (stats_payload ), resolve = True ),
291289 )
292290
293- normalization_stats = rounded .get ("normalization_stats" )
294- if not isinstance (normalization_stats , dict ):
295- return rounded
296- stats = normalization_stats .get ("stats" )
291+ stats = rounded .get ("stats" )
297292 if not isinstance (stats , dict ):
298293 return rounded
299294
@@ -359,9 +354,7 @@ def generate_normalization_stats_yaml(
359354 )
360355
361356 resolved_output_path = (
362- output_path
363- if output_path is not None
364- else dataset_dir / f"{ dataset_dir .name } .yaml"
357+ output_path if output_path is not None else dataset_dir / "stats.yml"
365358 )
366359 save_normalization_stats (
367360 stats_payload = stats_payload ,
@@ -501,7 +494,7 @@ def _generate_main(cfg: Any) -> None:
501494 )
502495 save_normalization_stats (
503496 stats_payload = normalization_stats_payload ,
504- output_path = output_dir / f" { output_dir . name } .yaml " ,
497+ output_path = output_dir / "stats.yml " ,
505498 )
506499 save_example_videos (
507500 splits = splits ,
@@ -575,10 +568,7 @@ def main() -> None:
575568 stats_parser .add_argument (
576569 "--output" ,
577570 default = None ,
578- help = (
579- "Optional output YAML path (default: "
580- "<dataset_dir>/<dataset_dir_name>.yaml)."
581- ),
571+ help = ("Optional output YAML path (default: <dataset_dir>/stats.yml)." ),
582572 )
583573 stats_parser .add_argument (
584574 "--field-names" ,
0 commit comments