Skip to content

Commit c843e14

Browse files
committed
add test_static_no_filenames
1 parent 3fa3bfe commit c843e14

File tree

1 file changed

+65
-24
lines changed

1 file changed

+65
-24
lines changed

tests/test_runners.py

Lines changed: 65 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1+
from pathlib import Path
12
from typing import Any
23

34
import numpy as np
5+
import pytest
46
import torch
57

68
from torch_sim.autobatching import ChunkingAutoBatcher, HotSwappingAutoBatcher
@@ -14,11 +16,11 @@
1416
from torch_sim.units import UnitSystem
1517

1618

17-
def test_integrate_nve(ar_sim_state: SimState, lj_model: Any, tmp_path: Any) -> None:
19+
def test_integrate_nve(ar_sim_state: SimState, lj_model: Any, tmp_path: Path) -> None:
1820
"""Test NVE integration with LJ potential."""
19-
trajectory_file = tmp_path / "nve.h5md"
21+
traj_file = tmp_path / "nve.h5md"
2022
reporter = TrajectoryReporter(
21-
filenames=trajectory_file,
23+
filenames=traj_file,
2224
state_frequency=1,
2325
prop_calculators={
2426
1: {"ke": lambda state: kinetic_energy(state.momenta, state.masses)}
@@ -37,22 +39,22 @@ def test_integrate_nve(ar_sim_state: SimState, lj_model: Any, tmp_path: Any) ->
3739
)
3840

3941
assert isinstance(final_state, SimState)
40-
assert trajectory_file.exists()
42+
assert traj_file.exists()
4143

4244
# Check energy conservation
43-
with TorchSimTrajectory(trajectory_file) as traj:
45+
with TorchSimTrajectory(traj_file) as traj:
4446
energies = traj.get_array("ke")
4547
std_energy = np.std(energies)
4648
assert std_energy / np.mean(energies) < 0.1 # 10% tolerance
4749

4850

4951
def test_integrate_single_nvt(
50-
ar_sim_state: SimState, lj_model: Any, tmp_path: Any
52+
ar_sim_state: SimState, lj_model: Any, tmp_path: Path
5153
) -> None:
5254
"""Test NVT integration with LJ potential."""
53-
trajectory_file = tmp_path / "nvt.h5md"
55+
traj_file = tmp_path / "nvt.h5md"
5456
reporter = TrajectoryReporter(
55-
filenames=trajectory_file,
57+
filenames=traj_file,
5658
state_frequency=1,
5759
prop_calculators={
5860
1: {"ke": lambda state: kinetic_energy(state.momenta, state.masses)}
@@ -72,10 +74,10 @@ def test_integrate_single_nvt(
7274
)
7375

7476
assert isinstance(final_state, SimState)
75-
assert trajectory_file.exists()
77+
assert traj_file.exists()
7678

7779
# Check energy fluctuations
78-
with TorchSimTrajectory(trajectory_file) as traj:
80+
with TorchSimTrajectory(traj_file) as traj:
7981
energies = traj.get_array("ke")
8082
std_energy = np.std(energies)
8183
assert std_energy / np.mean(energies) < 0.2 # 20% tolerance for NVT
@@ -98,7 +100,7 @@ def test_integrate_double_nvt(ar_double_sim_state: SimState, lj_model: Any) -> N
98100

99101

100102
def test_integrate_double_nvt_with_reporter(
101-
ar_double_sim_state: SimState, lj_model: Any, tmp_path: Any
103+
ar_double_sim_state: SimState, lj_model: Any, tmp_path: Path
102104
) -> None:
103105
"""Test NVT integration with LJ potential."""
104106
trajectory_files = [tmp_path / "nvt_0.h5md", tmp_path / "nvt_1.h5md"]
@@ -139,7 +141,7 @@ def test_integrate_many_nvt(
139141
ar_sim_state: SimState,
140142
fe_fcc_sim_state: SimState,
141143
lj_model: Any,
142-
tmp_path: Any,
144+
tmp_path: Path,
143145
) -> None:
144146
"""Test NVT integration with LJ potential."""
145147
triple_state = initialize_state(
@@ -216,7 +218,7 @@ def test_integrate_with_autobatcher_and_reporting(
216218
ar_sim_state: SimState,
217219
fe_fcc_sim_state: SimState,
218220
lj_model: Any,
219-
tmp_path: Any,
221+
tmp_path: Path,
220222
) -> None:
221223
"""Test integration with autobatcher."""
222224
states = [ar_sim_state, fe_fcc_sim_state, ar_sim_state]
@@ -280,7 +282,7 @@ def test_integrate_with_autobatcher_and_reporting(
280282
assert torch.any(final_state.positions != init_state.positions)
281283

282284

283-
def test_optimize_fire(ar_sim_state: SimState, lj_model: Any, tmp_path: Any) -> None:
285+
def test_optimize_fire(ar_sim_state: SimState, lj_model: Any, tmp_path: Path) -> None:
284286
"""Test FIRE optimization with LJ potential."""
285287
trajectory_files = [tmp_path / "opt.h5md"]
286288
reporter = TrajectoryReporter(
@@ -311,7 +313,7 @@ def test_optimize_fire(ar_sim_state: SimState, lj_model: Any, tmp_path: Any) ->
311313

312314

313315
def test_default_converged_fn(
314-
ar_sim_state: SimState, lj_model: Any, tmp_path: Any
316+
ar_sim_state: SimState, lj_model: Any, tmp_path: Path
315317
) -> None:
316318
"""Test default converged function."""
317319
ar_sim_state.positions += torch.randn_like(ar_sim_state.positions) * 0.1
@@ -341,7 +343,7 @@ def test_default_converged_fn(
341343
def test_batched_optimize_fire(
342344
ar_double_sim_state: SimState,
343345
lj_model: Any,
344-
tmp_path: Any,
346+
tmp_path: Path,
345347
) -> None:
346348
"""Test batched FIRE optimization with LJ potential."""
347349
trajectory_files = [
@@ -403,7 +405,7 @@ def test_optimize_with_autobatcher_and_reporting(
403405
ar_sim_state: SimState,
404406
fe_fcc_sim_state: SimState,
405407
lj_model: Any,
406-
tmp_path: Any,
408+
tmp_path: Path,
407409
) -> None:
408410
"""Test optimize with autobatcher and reporting."""
409411
states = [ar_sim_state, fe_fcc_sim_state, ar_sim_state]
@@ -540,13 +542,14 @@ def mock_estimate(*args, **kwargs) -> float: # noqa: ARG001
540542
assert torch.any(final_state.positions != init_state.positions)
541543

542544

543-
def test_static_single(ar_sim_state: SimState, lj_model: Any, tmp_path: Any) -> None:
545+
def test_static_single(ar_sim_state: SimState, lj_model: Any, tmp_path: Path) -> None:
544546
"""Test static calculation with LJ potential."""
545-
trajectory_file = tmp_path / "static.h5md"
547+
traj_file = tmp_path / "static.h5md"
546548
reporter = TrajectoryReporter(
547-
filenames=trajectory_file,
549+
filenames=traj_file,
548550
state_frequency=1,
549551
prop_calculators={1: {"potential_energy": lambda state: state.energy}},
552+
state_kwargs={"save_forces": True}, # Enable force saving
550553
)
551554

552555
props = static(
@@ -558,17 +561,23 @@ def test_static_single(ar_sim_state: SimState, lj_model: Any, tmp_path: Any) ->
558561
assert isinstance(props, list)
559562
assert len(props) == 1 # Single system = single props dict
560563
assert "potential_energy" in props[0]
561-
assert trajectory_file.exists()
564+
assert traj_file.exists()
562565

563566
# Check that energy was computed and saved correctly
564-
with TorchSimTrajectory(trajectory_file) as traj:
567+
with TorchSimTrajectory(traj_file) as traj:
565568
saved_energy = traj.get_array("potential_energy")
566569
assert len(saved_energy) == 1 # Static calc = single frame
567570
np.testing.assert_allclose(saved_energy[0], props[0]["potential_energy"].numpy())
568571

572+
# Verify state_kwargs were applied correctly
573+
assert traj.get_array("atomic_numbers").shape == (1, ar_sim_state.n_atoms)
574+
assert traj.get_array("masses").shape == (1, ar_sim_state.n_atoms)
575+
if lj_model.compute_forces:
576+
assert "forces" in traj.array_registry
577+
569578

570579
def test_static_double(
571-
ar_double_sim_state: SimState, lj_model: Any, tmp_path: Any
580+
ar_double_sim_state: SimState, lj_model: Any, tmp_path: Path
572581
) -> None:
573582
"""Test static calculation with multiple systems."""
574583
trajectory_files = [tmp_path / "static_0.h5md", tmp_path / "static_1.h5md"]
@@ -636,7 +645,7 @@ def test_static_with_autobatcher_and_reporting(
636645
ar_sim_state: SimState,
637646
fe_fcc_sim_state: SimState,
638647
lj_model: Any,
639-
tmp_path: Any,
648+
tmp_path: Path,
640649
) -> None:
641650
"""Test static calculation with autobatcher and trajectory reporting."""
642651
states = [ar_sim_state, fe_fcc_sim_state, ar_sim_state]
@@ -685,3 +694,35 @@ def test_static_with_autobatcher_and_reporting(
685694
assert not np.allclose(
686695
props[0]["potential_energy"].numpy(), props[1]["potential_energy"].numpy()
687696
)
697+
698+
699+
def test_static_no_filenames(
700+
ar_sim_state: SimState, lj_model: Any, tmp_path: Path
701+
) -> None:
702+
"""Test static calculation with no trajectory filenames."""
703+
reporter = TrajectoryReporter(
704+
filenames=None,
705+
state_frequency=1,
706+
prop_calculators={1: {"potential_energy": lambda state: state.energy}},
707+
)
708+
709+
props = static(system=ar_sim_state, model=lj_model, trajectory_reporter=reporter)
710+
711+
assert isinstance(props, list)
712+
assert len(props) == 1
713+
assert "potential_energy" in props[0]
714+
assert isinstance(props[0]["potential_energy"], torch.Tensor)
715+
716+
reporter = TrajectoryReporter(
717+
filenames=tmp_path / "static.h5md",
718+
state_frequency=2, # Invalid for static calculations
719+
prop_calculators={1: {"potential_energy": lambda state: state.energy}},
720+
)
721+
722+
# should raise for invalid state frequency
723+
with pytest.raises(ValueError, match="state_frequency=2 must be 1 for statics"):
724+
static(
725+
system=ar_sim_state,
726+
model=lj_model,
727+
trajectory_reporter=reporter,
728+
)

0 commit comments

Comments
 (0)