1+ from pathlib import Path
12from typing import Any
23
34import numpy as np
5+ import pytest
46import torch
57
68from torch_sim .autobatching import ChunkingAutoBatcher , HotSwappingAutoBatcher
1416from torch_sim .units import UnitSystem
1517
1618
17- def test_integrate_nve (ar_sim_state : SimState , lj_model : Any , tmp_path : Any ) -> None :
19+ def test_integrate_nve (ar_sim_state : SimState , lj_model : Any , tmp_path : Path ) -> None :
1820 """Test NVE integration with LJ potential."""
19- trajectory_file = tmp_path / "nve.h5md"
21+ traj_file = tmp_path / "nve.h5md"
2022 reporter = TrajectoryReporter (
21- filenames = trajectory_file ,
23+ filenames = traj_file ,
2224 state_frequency = 1 ,
2325 prop_calculators = {
2426 1 : {"ke" : lambda state : kinetic_energy (state .momenta , state .masses )}
@@ -37,22 +39,22 @@ def test_integrate_nve(ar_sim_state: SimState, lj_model: Any, tmp_path: Any) ->
3739 )
3840
3941 assert isinstance (final_state , SimState )
40- assert trajectory_file .exists ()
42+ assert traj_file .exists ()
4143
4244 # Check energy conservation
43- with TorchSimTrajectory (trajectory_file ) as traj :
45+ with TorchSimTrajectory (traj_file ) as traj :
4446 energies = traj .get_array ("ke" )
4547 std_energy = np .std (energies )
4648 assert std_energy / np .mean (energies ) < 0.1 # 10% tolerance
4749
4850
4951def test_integrate_single_nvt (
50- ar_sim_state : SimState , lj_model : Any , tmp_path : Any
52+ ar_sim_state : SimState , lj_model : Any , tmp_path : Path
5153) -> None :
5254 """Test NVT integration with LJ potential."""
53- trajectory_file = tmp_path / "nvt.h5md"
55+ traj_file = tmp_path / "nvt.h5md"
5456 reporter = TrajectoryReporter (
55- filenames = trajectory_file ,
57+ filenames = traj_file ,
5658 state_frequency = 1 ,
5759 prop_calculators = {
5860 1 : {"ke" : lambda state : kinetic_energy (state .momenta , state .masses )}
@@ -72,10 +74,10 @@ def test_integrate_single_nvt(
7274 )
7375
7476 assert isinstance (final_state , SimState )
75- assert trajectory_file .exists ()
77+ assert traj_file .exists ()
7678
7779 # Check energy fluctuations
78- with TorchSimTrajectory (trajectory_file ) as traj :
80+ with TorchSimTrajectory (traj_file ) as traj :
7981 energies = traj .get_array ("ke" )
8082 std_energy = np .std (energies )
8183 assert std_energy / np .mean (energies ) < 0.2 # 20% tolerance for NVT
@@ -98,7 +100,7 @@ def test_integrate_double_nvt(ar_double_sim_state: SimState, lj_model: Any) -> N
98100
99101
100102def test_integrate_double_nvt_with_reporter (
101- ar_double_sim_state : SimState , lj_model : Any , tmp_path : Any
103+ ar_double_sim_state : SimState , lj_model : Any , tmp_path : Path
102104) -> None :
103105 """Test NVT integration with LJ potential."""
104106 trajectory_files = [tmp_path / "nvt_0.h5md" , tmp_path / "nvt_1.h5md" ]
@@ -139,7 +141,7 @@ def test_integrate_many_nvt(
139141 ar_sim_state : SimState ,
140142 fe_fcc_sim_state : SimState ,
141143 lj_model : Any ,
142- tmp_path : Any ,
144+ tmp_path : Path ,
143145) -> None :
144146 """Test NVT integration with LJ potential."""
145147 triple_state = initialize_state (
@@ -216,7 +218,7 @@ def test_integrate_with_autobatcher_and_reporting(
216218 ar_sim_state : SimState ,
217219 fe_fcc_sim_state : SimState ,
218220 lj_model : Any ,
219- tmp_path : Any ,
221+ tmp_path : Path ,
220222) -> None :
221223 """Test integration with autobatcher."""
222224 states = [ar_sim_state , fe_fcc_sim_state , ar_sim_state ]
@@ -280,7 +282,7 @@ def test_integrate_with_autobatcher_and_reporting(
280282 assert torch .any (final_state .positions != init_state .positions )
281283
282284
283- def test_optimize_fire (ar_sim_state : SimState , lj_model : Any , tmp_path : Any ) -> None :
285+ def test_optimize_fire (ar_sim_state : SimState , lj_model : Any , tmp_path : Path ) -> None :
284286 """Test FIRE optimization with LJ potential."""
285287 trajectory_files = [tmp_path / "opt.h5md" ]
286288 reporter = TrajectoryReporter (
@@ -311,7 +313,7 @@ def test_optimize_fire(ar_sim_state: SimState, lj_model: Any, tmp_path: Any) ->
311313
312314
313315def test_default_converged_fn (
314- ar_sim_state : SimState , lj_model : Any , tmp_path : Any
316+ ar_sim_state : SimState , lj_model : Any , tmp_path : Path
315317) -> None :
316318 """Test default converged function."""
317319 ar_sim_state .positions += torch .randn_like (ar_sim_state .positions ) * 0.1
@@ -341,7 +343,7 @@ def test_default_converged_fn(
341343def test_batched_optimize_fire (
342344 ar_double_sim_state : SimState ,
343345 lj_model : Any ,
344- tmp_path : Any ,
346+ tmp_path : Path ,
345347) -> None :
346348 """Test batched FIRE optimization with LJ potential."""
347349 trajectory_files = [
@@ -403,7 +405,7 @@ def test_optimize_with_autobatcher_and_reporting(
403405 ar_sim_state : SimState ,
404406 fe_fcc_sim_state : SimState ,
405407 lj_model : Any ,
406- tmp_path : Any ,
408+ tmp_path : Path ,
407409) -> None :
408410 """Test optimize with autobatcher and reporting."""
409411 states = [ar_sim_state , fe_fcc_sim_state , ar_sim_state ]
@@ -540,13 +542,14 @@ def mock_estimate(*args, **kwargs) -> float: # noqa: ARG001
540542 assert torch .any (final_state .positions != init_state .positions )
541543
542544
543- def test_static_single (ar_sim_state : SimState , lj_model : Any , tmp_path : Any ) -> None :
545+ def test_static_single (ar_sim_state : SimState , lj_model : Any , tmp_path : Path ) -> None :
544546 """Test static calculation with LJ potential."""
545- trajectory_file = tmp_path / "static.h5md"
547+ traj_file = tmp_path / "static.h5md"
546548 reporter = TrajectoryReporter (
547- filenames = trajectory_file ,
549+ filenames = traj_file ,
548550 state_frequency = 1 ,
549551 prop_calculators = {1 : {"potential_energy" : lambda state : state .energy }},
552+ state_kwargs = {"save_forces" : True }, # Enable force saving
550553 )
551554
552555 props = static (
@@ -558,17 +561,23 @@ def test_static_single(ar_sim_state: SimState, lj_model: Any, tmp_path: Any) ->
558561 assert isinstance (props , list )
559562 assert len (props ) == 1 # Single system = single props dict
560563 assert "potential_energy" in props [0 ]
561- assert trajectory_file .exists ()
564+ assert traj_file .exists ()
562565
563566 # Check that energy was computed and saved correctly
564- with TorchSimTrajectory (trajectory_file ) as traj :
567+ with TorchSimTrajectory (traj_file ) as traj :
565568 saved_energy = traj .get_array ("potential_energy" )
566569 assert len (saved_energy ) == 1 # Static calc = single frame
567570 np .testing .assert_allclose (saved_energy [0 ], props [0 ]["potential_energy" ].numpy ())
568571
572+ # Verify state_kwargs were applied correctly
573+ assert traj .get_array ("atomic_numbers" ).shape == (1 , ar_sim_state .n_atoms )
574+ assert traj .get_array ("masses" ).shape == (1 , ar_sim_state .n_atoms )
575+ if lj_model .compute_forces :
576+ assert "forces" in traj .array_registry
577+
569578
570579def test_static_double (
571- ar_double_sim_state : SimState , lj_model : Any , tmp_path : Any
580+ ar_double_sim_state : SimState , lj_model : Any , tmp_path : Path
572581) -> None :
573582 """Test static calculation with multiple systems."""
574583 trajectory_files = [tmp_path / "static_0.h5md" , tmp_path / "static_1.h5md" ]
@@ -636,7 +645,7 @@ def test_static_with_autobatcher_and_reporting(
636645 ar_sim_state : SimState ,
637646 fe_fcc_sim_state : SimState ,
638647 lj_model : Any ,
639- tmp_path : Any ,
648+ tmp_path : Path ,
640649) -> None :
641650 """Test static calculation with autobatcher and trajectory reporting."""
642651 states = [ar_sim_state , fe_fcc_sim_state , ar_sim_state ]
@@ -685,3 +694,35 @@ def test_static_with_autobatcher_and_reporting(
685694 assert not np .allclose (
686695 props [0 ]["potential_energy" ].numpy (), props [1 ]["potential_energy" ].numpy ()
687696 )
697+
698+
699+ def test_static_no_filenames (
700+ ar_sim_state : SimState , lj_model : Any , tmp_path : Path
701+ ) -> None :
702+ """Test static calculation with no trajectory filenames."""
703+ reporter = TrajectoryReporter (
704+ filenames = None ,
705+ state_frequency = 1 ,
706+ prop_calculators = {1 : {"potential_energy" : lambda state : state .energy }},
707+ )
708+
709+ props = static (system = ar_sim_state , model = lj_model , trajectory_reporter = reporter )
710+
711+ assert isinstance (props , list )
712+ assert len (props ) == 1
713+ assert "potential_energy" in props [0 ]
714+ assert isinstance (props [0 ]["potential_energy" ], torch .Tensor )
715+
716+ reporter = TrajectoryReporter (
717+ filenames = tmp_path / "static.h5md" ,
718+ state_frequency = 2 , # Invalid for static calculations
719+ prop_calculators = {1 : {"potential_energy" : lambda state : state .energy }},
720+ )
721+
722+ # should raise for invalid state frequency
723+ with pytest .raises (ValueError , match = "state_frequency=2 must be 1 for statics" ):
724+ static (
725+ system = ar_sim_state ,
726+ model = lj_model ,
727+ trajectory_reporter = reporter ,
728+ )
0 commit comments