Skip to content

Commit 0bc366a

Browse files
committed
Write to stats.yml, remove root
1 parent e75d5e4 commit 0bc366a

2 files changed

Lines changed: 21 additions & 37 deletions

File tree

src/autosim/cli.py

Lines changed: 12 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -257,16 +257,14 @@ def _stats_by_channel(values: torch.Tensor) -> dict[str, float]:
257257
}
258258

259259
return {
260-
"normalization_stats": {
261-
"stats": {
262-
"mean": _stats_by_channel(mean),
263-
"std": _stats_by_channel(std),
264-
"mean_delta": _stats_by_channel(mean_delta),
265-
"std_delta": _stats_by_channel(std_delta),
266-
},
267-
"core_field_names": resolved_core_field_names,
268-
"constant_field_names": constant_field_names or [],
269-
}
260+
"stats": {
261+
"mean": _stats_by_channel(mean),
262+
"std": _stats_by_channel(std),
263+
"mean_delta": _stats_by_channel(mean_delta),
264+
"std_delta": _stats_by_channel(std_delta),
265+
},
266+
"core_field_names": resolved_core_field_names,
267+
"constant_field_names": constant_field_names or [],
270268
}
271269

272270

@@ -290,10 +288,7 @@ def _rounded_normalization_stats_payload(
290288
OmegaConf.to_container(OmegaConf.create(stats_payload), resolve=True),
291289
)
292290

293-
normalization_stats = rounded.get("normalization_stats")
294-
if not isinstance(normalization_stats, dict):
295-
return rounded
296-
stats = normalization_stats.get("stats")
291+
stats = rounded.get("stats")
297292
if not isinstance(stats, dict):
298293
return rounded
299294

@@ -359,9 +354,7 @@ def generate_normalization_stats_yaml(
359354
)
360355

361356
resolved_output_path = (
362-
output_path
363-
if output_path is not None
364-
else dataset_dir / f"{dataset_dir.name}.yaml"
357+
output_path if output_path is not None else dataset_dir / "stats.yml"
365358
)
366359
save_normalization_stats(
367360
stats_payload=stats_payload,
@@ -501,7 +494,7 @@ def _generate_main(cfg: Any) -> None:
501494
)
502495
save_normalization_stats(
503496
stats_payload=normalization_stats_payload,
504-
output_path=output_dir / f"{output_dir.name}.yaml",
497+
output_path=output_dir / "stats.yml",
505498
)
506499
save_example_videos(
507500
splits=splits,
@@ -575,10 +568,7 @@ def main() -> None:
575568
stats_parser.add_argument(
576569
"--output",
577570
default=None,
578-
help=(
579-
"Optional output YAML path (default: "
580-
"<dataset_dir>/<dataset_dir_name>.yaml)."
581-
),
571+
help=("Optional output YAML path (default: <dataset_dir>/stats.yml)."),
582572
)
583573
stats_parser.add_argument(
584574
"--field-names",

tests/test_cli.py

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ def test_compute_normalization_stats_includes_temporal_deltas() -> None:
183183
stats_payload = compute_normalization_stats(
184184
split_payload=split_payload, core_field_names=["U", "V"]
185185
)
186-
stats = stats_payload["normalization_stats"]["stats"]
186+
stats = stats_payload["stats"]
187187

188188
assert stats["mean"]["U"] == pytest.approx(3.0)
189189
assert stats["mean"]["V"] == pytest.approx(14.0)
@@ -193,8 +193,8 @@ def test_compute_normalization_stats_includes_temporal_deltas() -> None:
193193
assert stats["mean_delta"]["V"] == pytest.approx(4.0)
194194
assert stats["std_delta"]["U"] == pytest.approx(0.0)
195195
assert stats["std_delta"]["V"] == pytest.approx(0.0)
196-
assert stats_payload["normalization_stats"]["core_field_names"] == ["U", "V"]
197-
assert stats_payload["normalization_stats"]["constant_field_names"] == []
196+
assert stats_payload["core_field_names"] == ["U", "V"]
197+
assert stats_payload["constant_field_names"] == []
198198

199199

200200
def test_generate_normalization_stats_yaml_from_existing_dataset(
@@ -228,11 +228,9 @@ def test_generate_normalization_stats_yaml_from_existing_dataset(
228228
assert output_path.exists()
229229
stats_cfg = OmegaConf.load(output_path)
230230
assert isinstance(stats_cfg, DictConfig)
231-
assert stats_cfg["normalization_stats"]["core_field_names"] == ["U", "V"]
232-
assert stats_cfg["normalization_stats"]["stats"]["mean"]["U"] == pytest.approx(2.5)
233-
assert stats_cfg["normalization_stats"]["stats"]["mean_delta"][
234-
"V"
235-
] == pytest.approx(2.0)
231+
assert stats_cfg["core_field_names"] == ["U", "V"]
232+
assert stats_cfg["stats"]["mean"]["U"] == pytest.approx(2.5)
233+
assert stats_cfg["stats"]["mean_delta"]["V"] == pytest.approx(2.0)
236234

237235

238236
def test_cli_stats_subcommand_writes_yaml(tmp_path: Path) -> None:
@@ -269,16 +267,12 @@ def test_cli_stats_subcommand_writes_yaml(tmp_path: Path) -> None:
269267
cwd=repo_root,
270268
)
271269

272-
stats_path = dataset_dir / f"{dataset_dir.name}.yaml"
270+
stats_path = dataset_dir / "stats.yml"
273271
assert stats_path.exists()
274272
stats_cfg = OmegaConf.load(stats_path)
275273
assert isinstance(stats_cfg, DictConfig)
276-
assert stats_cfg["normalization_stats"]["stats"]["mean_delta"][
277-
"U"
278-
] == pytest.approx(1.0)
279-
assert stats_cfg["normalization_stats"]["stats"]["std_delta"]["V"] == pytest.approx(
280-
0.0
281-
)
274+
assert stats_cfg["stats"]["mean_delta"]["U"] == pytest.approx(1.0)
275+
assert stats_cfg["stats"]["std_delta"]["V"] == pytest.approx(0.0)
282276

283277

284278
def test_cli_help_outputs_usage() -> None:

0 commit comments

Comments
 (0)