diff --git a/CHANGELOG.md b/CHANGELOG.md index 647ee854..05abd71f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,16 @@ # Changelog +## Unreleased + +### 🎉 New Features +* Constraints support for molecular dynamics and optimization by @thomasloux in [#294](https://github.com/TorchSim/torch-sim/pull/294) + - Added `FixAtoms` constraint to fix specific atoms in place + - Added `FixCom` constraint to prevent center of mass drift + - Constraints automatically adjust degrees of freedom for accurate temperature calculations + - Full support across all integrators (NVE, NVT, NPT) and optimizers (FIRE, Gradient Descent) + - Constraints preserved during state manipulation (slicing, splitting, concatenation) + ## v0.5.0 This release focuses on improving batch processing capabilities across TorchSim. The neighbor list module has been completely refactored to support batched calculations with multiple backend implementations, elastic tensor calculations now leverage batched operations for improved performance, and a bug fix ensures Monte Carlo swaps work correctly with ragged (different-sized) systems. diff --git a/examples/scripts/3_Dynamics/3.10_Hybrid_swap_mc.py b/examples/scripts/3_Dynamics/3.10_Hybrid_swap_mc.py index 8a25f0c9..ee638199 100644 --- a/examples/scripts/3_Dynamics/3.10_Hybrid_swap_mc.py +++ b/examples/scripts/3_Dynamics/3.10_Hybrid_swap_mc.py @@ -67,9 +67,8 @@ class HybridSwapMCState(ts.SwapMCState, MDState): last_swap: Last swap attempted """ - last_permutation: torch.Tensor _atom_attributes = ( - ts.SwapMCState._atom_attributes | MDState._atom_attributes | {"last_permutation"} # noqa: SLF001 + ts.SwapMCState._atom_attributes | MDState._atom_attributes # noqa: SLF001 ) _system_attributes = ( ts.SwapMCState._system_attributes | MDState._system_attributes # noqa: SLF001 diff --git a/examples/tutorials/hybrid_swap_tutorial.py b/examples/tutorials/hybrid_swap_tutorial.py index 4ceca9f9..2f959fb6 100644 --- a/examples/tutorials/hybrid_swap_tutorial.py +++ b/examples/tutorials/hybrid_swap_tutorial.py @@ -100,9 +100,11 @@ class HybridSwapMCState(SwapMCState, MDState): from MDState. """ - last_permutation: torch.Tensor _atom_attributes = ( - MDState._atom_attributes | {"last_permutation"} # noqa: SLF001 + ts.SwapMCState._atom_attributes | MDState._atom_attributes # noqa: SLF001 + ) + _system_attributes = ( + ts.SwapMCState._system_attributes | MDState._system_attributes # noqa: SLF001 ) diff --git a/tests/conftest.py b/tests/conftest.py index de97b021..d9ce55ff 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -339,7 +339,7 @@ def distorted_fcc_al_conventional_sim_state() -> ts.SimState: positions = atoms_fcc.get_positions() np_rng = np.random.default_rng(seed=42) positions += np_rng.normal(scale=0.01, size=positions.shape) - atoms_fcc.set_positions(positions) + atoms_fcc.positions = positions # Convert the ASE Atoms object to SimState (will be a single batch with 4 atoms) return ts.io.atoms_to_state(atoms_fcc, device=DEVICE, dtype=DTYPE) diff --git a/tests/test_constraints.py b/tests/test_constraints.py new file mode 100644 index 00000000..c6e72ac3 --- /dev/null +++ b/tests/test_constraints.py @@ -0,0 +1,839 @@ +from typing import get_args + +import pytest +import torch + +import torch_sim as ts +from tests.conftest import DTYPE +from torch_sim.constraints import ( + Constraint, + FixAtoms, + FixCom, + merge_constraints, + validate_constraints, +) +from torch_sim.models.interface import ModelInterface +from torch_sim.models.lennard_jones import LennardJonesModel +from torch_sim.optimizers import FireFlavor +from torch_sim.transforms import get_centers_of_mass +from torch_sim.units import MetalUnits + + +def test_fix_com(ar_supercell_sim_state: ts.SimState, lj_model: LennardJonesModel): + """Test adjustment of positions and momenta with FixCom constraint.""" + ar_supercell_sim_state.constraints = [FixCom([0])] + initial_positions = ar_supercell_sim_state.positions.clone() + ar_supercell_sim_state.set_constrained_positions(initial_positions + 0.5) + assert torch.allclose(ar_supercell_sim_state.positions, initial_positions, atol=1e-8) + + ar_supercell_md_state = ts.nve_init( + state=ar_supercell_sim_state, + model=lj_model, + kT=torch.tensor(10.0, dtype=DTYPE), + seed=42, + ) + ar_supercell_md_state.set_constrained_momenta( + torch.randn_like(ar_supercell_md_state.momenta) * 0.1 + ) + assert torch.allclose( + ar_supercell_md_state.momenta.mean(dim=0), + torch.zeros(3, dtype=DTYPE), + atol=1e-8, + ) + + +def test_fix_atoms(ar_supercell_sim_state: ts.SimState, lj_model: LennardJonesModel): + """Test adjustment of positions and momenta with FixAtoms constraint.""" + indices_to_fix = torch.tensor([0, 5, 10], dtype=torch.long) + ar_supercell_sim_state.constraints = [FixAtoms(atom_idx=indices_to_fix)] + initial_positions = ar_supercell_sim_state.positions.clone() + # displacement = torch.randn_like(ar_supercell_sim_state.positions) * 0.5 + displacement = 0.5 + ar_supercell_sim_state.set_constrained_positions(initial_positions + displacement) + assert torch.allclose( + ar_supercell_sim_state.positions[indices_to_fix], + initial_positions[indices_to_fix], + atol=1e-8, + ) + # Check that other positions have changed + unfixed_indices = torch.tensor( + [i for i in range(ar_supercell_sim_state.n_atoms) if i not in indices_to_fix], + dtype=torch.long, + ) + assert not torch.allclose( + ar_supercell_sim_state.positions[unfixed_indices], + initial_positions[unfixed_indices], + atol=1e-8, + ) + + ar_supercell_md_state = ts.nve_init( + state=ar_supercell_sim_state, + model=lj_model, + kT=torch.tensor(10.0, dtype=DTYPE), + seed=42, + ) + ar_supercell_md_state.set_constrained_momenta( + torch.randn_like(ar_supercell_md_state.momenta) * 0.1 + ) + assert torch.allclose( + ar_supercell_md_state.momenta[indices_to_fix], + torch.zeros_like(ar_supercell_md_state.momenta[indices_to_fix]), + atol=1e-8, + ) + + +def test_fix_com_nvt_langevin(cu_sim_state: ts.SimState, lj_model: LennardJonesModel): + """Test FixCom constraint in NVT Langevin dynamics.""" + n_steps = 1000 + dt = torch.tensor(0.001, dtype=DTYPE) + kT = torch.tensor(300, dtype=DTYPE) * MetalUnits.temperature + + dofs_before = cu_sim_state.get_number_of_degrees_of_freedom() + cu_sim_state.constraints = [FixCom([0])] + assert torch.allclose( + cu_sim_state.get_number_of_degrees_of_freedom(), dofs_before - 3 + ) + + state = ts.nvt_langevin_init(state=cu_sim_state, model=lj_model, kT=kT, seed=42) + positions = [] + system_masses = torch.zeros((state.n_systems, 1), dtype=DTYPE).scatter_add_( + 0, + state.system_idx.unsqueeze(-1).expand(-1, 1), + state.masses.unsqueeze(-1), + ) + temperatures = [] + for _step in range(n_steps): + state = ts.nvt_langevin_step(model=lj_model, state=state, dt=dt, kT=kT) + positions.append(state.positions.clone()) + temp = ts.calc_kT( + masses=state.masses, + momenta=state.momenta, + system_idx=state.system_idx, + dof_per_system=state.get_number_of_degrees_of_freedom(), + ) + temperatures.append(temp / MetalUnits.temperature) + temperatures = torch.stack(temperatures) + + traj_positions = torch.stack(positions) + + coms = torch.zeros((n_steps, state.n_systems, 3), dtype=DTYPE).scatter_add_( + 1, + state.system_idx[None, :, None].expand(n_steps, -1, 3), + state.masses.unsqueeze(-1) * traj_positions, + ) + coms /= system_masses + coms_drift = coms - coms[0] + assert torch.allclose(coms_drift, torch.zeros_like(coms_drift), atol=1e-6) + assert (torch.mean(temperatures[len(temperatures) // 2 :]) - 300) / 300 < 0.30 + + +def test_fix_atoms_nvt_langevin(cu_sim_state: ts.SimState, lj_model: LennardJonesModel): + """Test FixAtoms constraint in NVT Langevin dynamics.""" + n_steps = 1000 + dt = torch.tensor(0.001, dtype=DTYPE) + kT = torch.tensor(300, dtype=DTYPE) * MetalUnits.temperature + + dofs_before = cu_sim_state.get_number_of_degrees_of_freedom() + cu_sim_state.constraints = [FixAtoms(atom_idx=torch.tensor([0, 1], dtype=torch.long))] + assert torch.allclose( + cu_sim_state.get_number_of_degrees_of_freedom(), dofs_before - torch.tensor([6]) + ) + state = ts.nvt_langevin_init(state=cu_sim_state, model=lj_model, kT=kT, seed=42) + positions = [] + temperatures = [] + for _step in range(n_steps): + state = ts.nvt_langevin_step(model=lj_model, state=state, dt=dt, kT=kT) + positions.append(state.positions.clone()) + temp = ts.calc_kT( + masses=state.masses, + momenta=state.momenta, + system_idx=state.system_idx, + dof_per_system=state.get_number_of_degrees_of_freedom(), + ) + temperatures.append(temp / MetalUnits.temperature) + temperatures = torch.stack(temperatures) + traj_positions = torch.stack(positions) + + diff_positions = traj_positions - traj_positions[0] + assert torch.max(diff_positions[:, :2]) < 1e-8 + assert torch.max(diff_positions[:, 2:]) > 1e-3 + assert (torch.mean(temperatures[len(temperatures) // 2 :]) - 300) / 300 < 0.30 + + +def test_state_manipulation_with_constraints(ar_double_sim_state: ts.SimState): + """Test that constraints are properly propagated during state manipulation.""" + # Set up constraints on the original state + ar_double_sim_state.constraints = [ + FixAtoms(atom_idx=torch.tensor([0, 1])), # Only applied to first system + FixCom([0, 1]), + ] + + # Extract individual systems from the double system state + first_system = ar_double_sim_state[0] # FixAtoms + FixCom + second_system = ar_double_sim_state[1] # FixCom only + concatenated_state = ts.concatenate_states( + [first_system, first_system, second_system] + ) + + # Verify constraint propagation to subsystems + assert len(first_system.constraints) == 2 + assert len(second_system.constraints) == 1 + assert len(concatenated_state.constraints) == 2 + + # Verify FixAtoms constraint indices are correctly mapped + assert torch.all(first_system.constraints[0].atom_idx == torch.tensor([0, 1])) + assert torch.all( + concatenated_state.constraints[0].atom_idx == torch.tensor([0, 1, 32, 33]) + ) + + # Verify FixCom constraint system masks + assert torch.all( + concatenated_state.constraints[1].system_idx == torch.tensor([0, 1, 2]) + ) + + # Test constraint propagation after splitting concatenated state + split_systems = concatenated_state.split() + assert len(split_systems[0].constraints) == 2 + assert torch.all(split_systems[0].constraints[0].atom_idx == torch.tensor([0, 1])) + assert torch.all(split_systems[1].constraints[0].atom_idx == torch.tensor([0, 1])) + assert len(split_systems[2].constraints) == 1 + + # Test constraint manipulation with different configurations + ar_double_sim_state.constraints = [] + ar_double_sim_state.constraints = [FixCom([0, 1])] + isolated_system = ar_double_sim_state[0] + assert torch.all( + isolated_system.constraints[0].system_idx == torch.tensor([0], dtype=torch.long) + ) + + # Test concatenation with mixed constraint states + isolated_system.constraints = [] + mixed_concatenated_state = ts.concatenate_states( + [isolated_system, ar_double_sim_state, isolated_system] + ) + assert torch.all( + mixed_concatenated_state.constraints[0].system_idx == torch.tensor([1, 2]) + ) + + +def test_fix_com_gradient_descent_optimization( + ar_supercell_sim_state: ts.SimState, lj_model: ModelInterface +) -> None: + """Test FixCom constraint in Gradient Descent optimization.""" + # Add some random displacement to positions + perturbed_positions = ( + ar_supercell_sim_state.positions + + torch.randn_like(ar_supercell_sim_state.positions) * 0.1 + ) + + ar_supercell_sim_state.positions = perturbed_positions + initial_state = ar_supercell_sim_state + ar_supercell_sim_state.constraints = [FixCom([0])] + + initial_coms = get_centers_of_mass( + positions=initial_state.positions, + masses=initial_state.masses, + system_idx=initial_state.system_idx, + n_systems=initial_state.n_systems, + ) + + # Initialize Gradient Descent optimizer + state = ts.gradient_descent_init( + state=ar_supercell_sim_state, model=lj_model, lr=0.01 + ) + + # Run optimization for a few steps + energies = [1000, state.energy.item()] + while abs(energies[-2] - energies[-1]) > 1e-6: + state = ts.gradient_descent_step(state=state, model=lj_model, pos_lr=0.01) + energies.append(state.energy.item()) + + final_coms = get_centers_of_mass( + positions=state.positions, + masses=state.masses, + system_idx=state.system_idx, + n_systems=initial_state.n_systems, + ) + + assert torch.allclose(final_coms, initial_coms, atol=1e-4) + assert not torch.allclose(state.positions, initial_state.positions) + + +def test_fix_atoms_gradient_descent_optimization( + ar_supercell_sim_state: ts.SimState, lj_model: ModelInterface +) -> None: + """Test FixAtoms constraint in Gradient Descent optimization.""" + # Add some random displacement to positions + perturbed_positions = ( + ar_supercell_sim_state.positions + + torch.randn_like(ar_supercell_sim_state.positions) * 0.1 + ) + + ar_supercell_sim_state.positions = perturbed_positions + initial_state = ar_supercell_sim_state + initial_state.constraints = [FixAtoms(atom_idx=[0])] + initial_position = initial_state.positions[0].clone() + + # Initialize Gradient Descent optimizer + state = ts.gradient_descent_init( + state=ar_supercell_sim_state, model=lj_model, lr=0.01 + ) + + # Run optimization for a few steps + energies = [1000, state.energy.item()] + while abs(energies[-2] - energies[-1]) > 1e-6: + state = ts.gradient_descent_step(state=state, model=lj_model, pos_lr=0.01) + energies.append(state.energy.item()) + + final_position = state.positions[0] + + assert torch.allclose(final_position, initial_position, atol=1e-5) + assert not torch.allclose(state.positions, initial_state.positions) + + +@pytest.mark.parametrize("fire_flavor", get_args(FireFlavor)) +def test_test_atoms_fire_optimization( + ar_supercell_sim_state: ts.SimState, + lj_model: ModelInterface, + fire_flavor: FireFlavor, +) -> None: + """Test FixAtoms constraint in FIRE optimization.""" + # Add some random displacement to positions + # Create a fresh copy for each test run to avoid interference + + current_positions = ( + ar_supercell_sim_state.positions.clone() + + torch.randn_like(ar_supercell_sim_state.positions) * 0.1 + ) + + current_sim_state = ts.SimState( + positions=current_positions, + masses=ar_supercell_sim_state.masses.clone(), + cell=ar_supercell_sim_state.cell.clone(), + pbc=ar_supercell_sim_state.pbc, + atomic_numbers=ar_supercell_sim_state.atomic_numbers.clone(), + system_idx=ar_supercell_sim_state.system_idx.clone(), + ) + indices = torch.tensor([0, 2], dtype=torch.long) + current_sim_state.constraints = [FixAtoms(atom_idx=indices)] + + # Initialize FIRE optimizer + state = ts.fire_init( + current_sim_state, lj_model, fire_flavor=fire_flavor, dt_start=0.1 + ) + initial_position = state.positions[indices].clone() + + # Run optimization for a few steps + energies = [1000, state.energy.item()] + max_steps = 1000 # Add max step to prevent infinite loop + steps_taken = 0 + while abs(energies[-2] - energies[-1]) > 1e-6 and steps_taken < max_steps: + state = ts.fire_step(state=state, model=lj_model, dt_max=0.3) + energies.append(state.energy.item()) + steps_taken += 1 + + final_position = state.positions[indices] + + assert torch.allclose(final_position, initial_position, atol=1e-5) + + +@pytest.mark.parametrize("fire_flavor", get_args(FireFlavor)) +def test_fix_com_fire_optimization( + ar_supercell_sim_state: ts.SimState, + lj_model: ModelInterface, + fire_flavor: FireFlavor, +) -> None: + """Test FixCom constraint in FIRE optimization.""" + # Add some random displacement to positions + # Create a fresh copy for each test run to avoid interference + + current_positions = ( + ar_supercell_sim_state.positions.clone() + + torch.randn_like(ar_supercell_sim_state.positions) * 0.1 + ) + + current_sim_state = ts.SimState( + positions=current_positions, + masses=ar_supercell_sim_state.masses.clone(), + cell=ar_supercell_sim_state.cell.clone(), + pbc=ar_supercell_sim_state.pbc, + atomic_numbers=ar_supercell_sim_state.atomic_numbers.clone(), + system_idx=ar_supercell_sim_state.system_idx.clone(), + ) + current_sim_state.constraints = [FixCom([0])] + + # Initialize FIRE optimizer + state = ts.fire_init( + current_sim_state, lj_model, fire_flavor=fire_flavor, dt_start=0.1 + ) + initial_com = get_centers_of_mass( + positions=state.positions, + masses=state.masses, + system_idx=state.system_idx, + n_systems=state.n_systems, + ) + + # Run optimization for a few steps + energies = [1000, state.energy.item()] + max_steps = 1000 # Add max step to prevent infinite loop + steps_taken = 0 + while abs(energies[-2] - energies[-1]) > 1e-6 and steps_taken < max_steps: + state = ts.fire_step(state=state, model=lj_model, dt_max=0.3) + energies.append(state.energy.item()) + steps_taken += 1 + + final_com = get_centers_of_mass( + positions=state.positions, + masses=state.masses, + system_idx=state.system_idx, + n_systems=state.n_systems, + ) + + assert torch.allclose(final_com, initial_com, atol=1e-4) + + +def test_fix_atoms_validation() -> None: + """Test FixAtoms construction and validation.""" + # Boolean mask conversion + mask = torch.zeros(10, dtype=torch.bool) + mask[:3] = True + assert torch.all(FixAtoms(atom_mask=mask).atom_idx == torch.tensor([0, 1, 2])) + + # Invalid indices + with pytest.raises(ValueError, match="Indices must be integers"): + FixAtoms(atom_idx=torch.tensor([0.5, 1.5])) + with pytest.raises(ValueError, match="Duplicate"): + FixAtoms(atom_idx=torch.tensor([0, 1, 1])) + with pytest.raises(ValueError, match="wrong number of dimensions"): + FixAtoms(atom_idx=torch.tensor([[0, 1]])) + + +def test_constraint_validation_warnings(ar_double_sim_state: ts.SimState) -> None: + """Test validation warnings for constraint conflicts.""" + with pytest.warns(UserWarning, match="Multiple constraints.*same atoms"): + validate_constraints( + [FixAtoms(atom_idx=[0, 1, 2]), FixAtoms(atom_idx=[2, 3, 4])], + ar_double_sim_state, + ) + with pytest.warns(UserWarning, match="FixCom together with other constraints"): + validate_constraints( + [FixCom([0]), FixAtoms(atom_idx=[0, 1])], ar_double_sim_state + ) + + +def test_constraint_validation_errors( + cu_sim_state: ts.SimState, + ar_supercell_sim_state: ts.SimState, +) -> None: + """Test validation errors for invalid constraints.""" + # Out of bounds + with pytest.raises(ValueError, match=r"has indices up to.*only has.*atoms"): + cu_sim_state.constraints = [FixAtoms(atom_idx=[0, 1, 100])] + + # Validation in __post_init__ + with pytest.raises(ValueError, match="Duplicate"): + ts.SimState( + positions=ar_supercell_sim_state.positions.clone(), + masses=ar_supercell_sim_state.masses, + cell=ar_supercell_sim_state.cell, + pbc=ar_supercell_sim_state.pbc, + atomic_numbers=ar_supercell_sim_state.atomic_numbers, + system_idx=ar_supercell_sim_state.system_idx, + _constraints=[FixAtoms(atom_idx=[0, 0, 1])], + ) + + +@pytest.mark.parametrize( + ("integrator", "constraint", "n_steps"), + [ + ("nve", FixAtoms(atom_idx=[0, 1]), 100), + ("nvt_nose_hoover", FixCom([0]), 200), + ("npt_langevin", FixAtoms(atom_idx=[0, 3]), 200), + ("npt_nose_hoover", FixCom([0]), 200), + ], +) +def test_integrators_with_constraints( + cu_sim_state: ts.SimState, + lj_model: LennardJonesModel, + integrator: str, + constraint: Constraint, + n_steps: int, +) -> None: + """Test all integrators respect constraints.""" + cu_sim_state.constraints = [constraint] + kT = torch.tensor(300.0, dtype=DTYPE) * MetalUnits.temperature + dt = torch.tensor(0.001, dtype=DTYPE) + + # Store initial state + if isinstance(constraint, FixAtoms): + initial = cu_sim_state.positions[constraint.atom_idx].clone() + else: + initial = get_centers_of_mass( + cu_sim_state.positions, + cu_sim_state.masses, + cu_sim_state.system_idx, + cu_sim_state.n_systems, + ) + + # Run integration + if integrator == "nve": + state = ts.nve_init(cu_sim_state, lj_model, kT=kT, seed=42) + for _ in range(n_steps): + state = ts.nve_step(state, lj_model, dt=dt) + elif integrator == "nvt_nose_hoover": + state = ts.nvt_nose_hoover_init(cu_sim_state, lj_model, kT=kT, dt=dt) + for _ in range(n_steps): + state = ts.nvt_nose_hoover_step(state, lj_model, dt=dt, kT=kT) + elif integrator == "npt_langevin": + state = ts.npt_langevin_init(cu_sim_state, lj_model, kT=kT, seed=42, dt=dt) + for _ in range(n_steps): + state = ts.npt_langevin_step( + state, + lj_model, + dt=dt, + kT=kT, + external_pressure=torch.tensor(0.0, dtype=DTYPE), + ) + else: # npt_nose_hoover + state = ts.npt_nose_hoover_init(cu_sim_state, lj_model, kT=kT, dt=dt) + for _ in range(n_steps): + state = ts.npt_nose_hoover_step( + state, + lj_model, + dt=torch.tensor(0.001, dtype=DTYPE), + kT=kT, + external_pressure=torch.tensor(0.0, dtype=DTYPE), + ) + + # Verify constraint held + if isinstance(constraint, FixAtoms): + assert torch.allclose(state.positions[constraint.atom_idx], initial, atol=1e-6) + else: + final = get_centers_of_mass( + state.positions, state.masses, state.system_idx, state.n_systems + ) + assert torch.allclose(final, initial, atol=1e-5) + + +def test_multiple_constraints_and_dof( + cu_sim_state: ts.SimState, lj_model: LennardJonesModel +) -> None: + """Test multiple constraints together with correct DOF calculation.""" + # Test DOF calculation + n = cu_sim_state.n_atoms + assert torch.all(cu_sim_state.get_number_of_degrees_of_freedom() == 3 * n) + cu_sim_state.constraints = [FixAtoms(atom_idx=[0])] + assert torch.all(cu_sim_state.get_number_of_degrees_of_freedom() == 3 * n - 3) + cu_sim_state.constraints = [FixCom([0]), FixAtoms(atom_idx=[0])] + assert torch.all(cu_sim_state.get_number_of_degrees_of_freedom() == 3 * n - 6) + + # Verify both constraints hold during dynamics + initial_pos = cu_sim_state.positions[0].clone() + initial_com = get_centers_of_mass( + cu_sim_state.positions, + cu_sim_state.masses, + cu_sim_state.system_idx, + cu_sim_state.n_systems, + ) + state = ts.nvt_langevin_init( + cu_sim_state, + lj_model, + kT=torch.tensor(300.0, dtype=DTYPE) * MetalUnits.temperature, + seed=42, + ) + for _ in range(200): + state = ts.nvt_langevin_step( + state, + lj_model, + dt=torch.tensor(0.001, dtype=DTYPE), + kT=torch.tensor(300.0, dtype=DTYPE) * MetalUnits.temperature, + ) + assert torch.allclose(state.positions[0], initial_pos, atol=1e-6) + final_com = get_centers_of_mass( + state.positions, state.masses, state.system_idx, state.n_systems + ) + assert torch.allclose(final_com, initial_com, atol=1e-5) + + +@pytest.mark.parametrize( + ("cell_filter", "fire_flavor"), + [ + (ts.CellFilter.unit, "ase_fire"), + (ts.CellFilter.frechet, "ase_fire"), + (ts.CellFilter.frechet, "vv_fire"), + ], +) +def test_cell_optimization_with_constraints( + ar_supercell_sim_state: ts.SimState, + lj_model: LennardJonesModel, + cell_filter: str, + fire_flavor: FireFlavor, +) -> None: + """Test cell filters work with constraints.""" + ar_supercell_sim_state.positions += ( + torch.randn_like(ar_supercell_sim_state.positions) * 0.05 + ) + ar_supercell_sim_state.constraints = [FixAtoms(atom_idx=[0, 1])] + state = ts.fire_init( + ar_supercell_sim_state, + lj_model, + cell_filter=cell_filter, + fire_flavor=fire_flavor, + ) + for _ in range(50): + state = ts.fire_step(state, lj_model, dt_max=0.1) + if state.forces.abs().max() < 0.05: + break + assert len(state.constraints) > 0 + + +def test_batched_constraints(ar_double_sim_state: ts.SimState) -> None: + """Test system-specific constraints in batched states.""" + s1, s2 = ar_double_sim_state.split() + s1.constraints = [FixAtoms(atom_idx=[0, 1])] + s2.constraints = [FixCom([0])] + combined = ts.concatenate_states([s1, s2]) + assert len(combined.constraints) == 2 + assert isinstance(combined.constraints[0], FixAtoms) + assert torch.all(combined.constraints[0].atom_idx == torch.tensor([0, 1])) + assert isinstance(combined.constraints[1], FixCom) + assert torch.all(combined.constraints[1].system_idx == torch.tensor([1])) + + +def test_constraints_with_non_pbc(lj_model: LennardJonesModel) -> None: + """Test constraints work with non-periodic boundaries.""" + state = ts.SimState( + positions=torch.tensor( + [[0.0, 0.0, 0.0], [2.0, 0.0, 0.0], [0.0, 2.0, 0.0], [0.0, 0.0, 2.0]], + dtype=DTYPE, + ), + masses=torch.ones(4, dtype=DTYPE) * 39.948, + cell=torch.eye(3, dtype=DTYPE).unsqueeze(0) * 10.0, + pbc=False, + atomic_numbers=torch.full((4,), 18, dtype=torch.long), + system_idx=torch.zeros(4, dtype=torch.long), + ) + state.constraints = [FixCom([0])] + initial = get_centers_of_mass( + state.positions, state.masses, state.system_idx, state.n_systems + ) + md_state = ts.nve_init(state, lj_model, kT=torch.tensor(100.0, dtype=DTYPE), seed=42) + for _ in range(100): + md_state = ts.nve_step(md_state, lj_model, dt=torch.tensor(0.001, dtype=DTYPE)) + final = get_centers_of_mass( + md_state.positions, md_state.masses, md_state.system_idx, md_state.n_systems + ) + assert torch.allclose(final, initial, atol=1e-5) + + +def test_high_level_api_with_constraints( + cu_sim_state: ts.SimState, + ar_supercell_sim_state: ts.SimState, + lj_model: LennardJonesModel, +) -> None: + """Test high-level integrate() and optimize() APIs with constraints.""" + # Test integrate() + cu_sim_state.constraints = [FixCom([0])] + initial_com = get_centers_of_mass( + cu_sim_state.positions, + cu_sim_state.masses, + cu_sim_state.system_idx, + cu_sim_state.n_systems, + ) + final = ts.integrate( + cu_sim_state, + lj_model, + integrator=ts.Integrator.nvt_langevin, + n_steps=50, + temperature=300.0, + timestep=0.001, + ) + final_com = get_centers_of_mass( + final.positions, final.masses, final.system_idx, final.n_systems + ) + assert torch.allclose(final_com, initial_com, atol=1e-5) + + # Test optimize() + ar_supercell_sim_state.positions += ( + torch.randn_like(ar_supercell_sim_state.positions) * 0.1 + ) + ar_supercell_sim_state.constraints = [FixAtoms(atom_idx=[0, 1, 2])] + initial_pos = ar_supercell_sim_state.positions[[0, 1, 2]].clone() + final = ts.optimize( + ar_supercell_sim_state, lj_model, optimizer=ts.Optimizer.fire, max_steps=500 + ) + assert torch.allclose(final.positions[[0, 1, 2]], initial_pos, atol=1e-5) + + +def test_temperature_with_constrained_dof( + cu_sim_state: ts.SimState, lj_model: LennardJonesModel +) -> None: + """Test temperature calculation uses constrained DOF.""" + target = 300.0 + cu_sim_state.constraints = [FixAtoms(atom_idx=[0, 1, 2])] + state = ts.nvt_langevin_init( + cu_sim_state, + lj_model, + kT=torch.tensor(target, dtype=DTYPE) * MetalUnits.temperature, + seed=42, + ) + temps = [] + for _ in range(4000): + state = ts.nvt_langevin_step( + state, + lj_model, + dt=torch.tensor(0.001, dtype=DTYPE), + kT=torch.tensor(target, dtype=DTYPE) * MetalUnits.temperature, + ) + temp = state.calc_kT() + temps.append(temp / MetalUnits.temperature) + avg = torch.mean(torch.stack(temps)[500:]) + assert abs(avg - target) / target < 0.30 + + +def test_system_constraint_update_and_select() -> None: + """Test select_constraint and select_sub_constraint for SystemConstraint.""" + # Create a FixCom constraint for systems 0, 1, 2 + constraint = FixCom([0, 1, 2]) + + # Test select_constraint with system_mask + # Keep systems 0 and 2 (drop system 1) + atom_mask = torch.ones(10, dtype=torch.bool) + system_mask = torch.tensor([True, False, True], dtype=torch.bool) + updated_constraint = constraint.select_constraint(atom_mask, system_mask) + + # System indices should be renumbered: [0, 2] -> [0, 1] + assert torch.all(updated_constraint.system_idx == torch.tensor([0, 1])) + + # Test select_sub_constraint + # Select system 1 from the original constraint + constraint = FixCom([0, 1, 2]) + atom_idx = torch.arange(5, 10) # Atoms for a specific system + sys_idx = 1 + sub_constraint = constraint.select_sub_constraint(atom_idx, sys_idx) + + # Should return a constraint with system_idx = [0] (renumbered from 1) + assert sub_constraint is not None + assert torch.all(sub_constraint.system_idx == torch.tensor([0])) + + # Test when system is not in constraint + constraint = FixCom([0, 2]) + sub_constraint = constraint.select_sub_constraint(atom_idx, sys_idx=1) + assert sub_constraint is None + + +def test_atom_indexed_constraint_update_and_select() -> None: + """Test select_constraint and select_sub_constraint for AtomConstraint.""" + # Create a FixAtoms constraint for atoms 0, 1, 5, 8 + constraint = FixAtoms(atom_idx=[0, 1, 5, 8]) + + # Test select_constraint with atom_mask + # Keep atoms 0, 1, 2, 3, 5, 6, 7, 8 (drop atoms 4) + atom_mask = torch.tensor( + [True, True, True, True, False, True, True, True, True], dtype=torch.bool + ) + system_mask = torch.ones(2, dtype=torch.bool) + updated_constraint = constraint.select_constraint(atom_mask, system_mask) + + # Atom indices should be renumbered: + # Original: [0, 1, 5, 8] + # After dropping atom 4: [0, 1, 4, 7] (indices shift down by 1 after index 4) + assert torch.all(updated_constraint.atom_idx == torch.tensor([0, 1, 4, 7])) + + # Test select_sub_constraint + # Select atoms that belong to a specific system + constraint = FixAtoms(atom_idx=[0, 1, 5, 8]) + atom_idx = torch.tensor([0, 1, 2, 3, 4]) # Atoms for first system + sys_idx = 0 + sub_constraint = constraint.select_sub_constraint(atom_idx, sys_idx) + + # Should return a constraint with only atoms 0, 1 (within atom_idx range) + # Renumbered to start from 0 + assert sub_constraint is not None + assert torch.all(sub_constraint.atom_idx == torch.tensor([0, 1])) + + # Test with different atom range + constraint = FixAtoms(atom_idx=[0, 1, 5, 8]) + atom_idx = torch.tensor([5, 6, 7, 8, 9]) # Atoms for second system + sys_idx = 1 + sub_constraint = constraint.select_sub_constraint(atom_idx, sys_idx) + + # Should return a constraint with atoms 5, 8 renumbered to [0, 3] + assert sub_constraint is not None + assert torch.all(sub_constraint.atom_idx == torch.tensor([0, 3])) + + # Test when no atoms in range + constraint = FixAtoms(atom_idx=[0, 1]) + atom_idx = torch.tensor([5, 6, 7, 8]) + sub_constraint = constraint.select_sub_constraint(atom_idx, sys_idx=1) + assert sub_constraint is None + + +def test_merge_constraints(ar_double_sim_state: ts.SimState) -> None: + """Test merge_constraints combines constraints from multiple systems.""" + # Split the double system state + s1, s2 = ar_double_sim_state.split() + n_atoms_s1 = s1.n_atoms + n_atoms_s2 = s2.n_atoms + + # Create constraints for each system + # System 1: Fix atoms 0, 1 and fix COM for system 0 + s1_constraints = [ + FixAtoms(atom_idx=[0, 1]), + FixCom([0]), + ] + + # System 2: Fix atoms 2, 3 and fix COM for system 0 + s2_constraints = [ + FixAtoms(atom_idx=[2, 3]), + FixCom([0]), + ] + + # Merge constraints + constraint_lists = [s1_constraints, s2_constraints] + num_atoms_per_state = torch.tensor([n_atoms_s1, n_atoms_s2]) + merged_constraints = merge_constraints(constraint_lists, num_atoms_per_state) + + # Should have 2 constraints: one FixAtoms and one FixCom + assert len(merged_constraints) == 2 + + # Find FixAtoms and FixCom in merged list + fix_atoms = None + fix_com = None + for constraint in merged_constraints: + if isinstance(constraint, FixAtoms): + fix_atoms = constraint + elif isinstance(constraint, FixCom): + fix_com = constraint + + assert fix_atoms is not None + assert fix_com is not None + + # FixAtoms should have indices [0, 1] from s1 and [2+n_atoms_s1, 3+n_atoms_s1] from s2 + expected_atom_indices = torch.tensor([0, 1, 2 + n_atoms_s1, 3 + n_atoms_s1]) + assert torch.all(fix_atoms.atom_idx == expected_atom_indices) + + # FixCom should have system_idx [0, 1] (one for each original system) + expected_system_indices = torch.tensor([0, 1]) + assert torch.all(fix_com.system_idx == expected_system_indices) + + # Test with three systems + s3 = s1.clone() + s3_constraints = [FixAtoms(atom_idx=[0])] + constraint_lists = [s1_constraints, s2_constraints, s3_constraints] + num_atoms_per_state = torch.tensor([n_atoms_s1, n_atoms_s2, s3.n_atoms]) + merged_constraints = merge_constraints(constraint_lists, num_atoms_per_state) + + # Find FixAtoms + fix_atoms = None + for constraint in merged_constraints: + if isinstance(constraint, FixAtoms): + fix_atoms = constraint + break + + assert fix_atoms is not None + # Should include atoms from all three systems with proper offsets + expected_atom_indices = torch.tensor( + [0, 1, 2 + n_atoms_s1, 3 + n_atoms_s1, 0 + n_atoms_s1 + n_atoms_s2] + ) + assert torch.all(fix_atoms.atom_idx == expected_atom_indices) diff --git a/tests/test_transforms.py b/tests/test_transforms.py index 16565b73..fe9e8c3f 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -9,6 +9,8 @@ import torch_sim as ts import torch_sim.transforms as ft from tests.conftest import DEVICE, DTYPE +from torch_sim.models.lennard_jones import LennardJonesModel +from torch_sim.units import MetalUnits def test_inverse_box_scalar() -> None: @@ -1301,3 +1303,65 @@ def test_build_linked_cell_neighborhood_basic() -> None: # Verify that there are neighbors from both batches assert torch.any(system_mapping == 0) assert torch.any(system_mapping == 1) + + +def test_unwrap_positions(ar_double_sim_state: ts.SimState, lj_model: LennardJonesModel): + n_steps = 50 + dt = torch.tensor(0.001, dtype=DTYPE) + kT = torch.tensor(300, dtype=DTYPE) * MetalUnits.temperature + + # Same cell + state = ts.nvt_langevin_init( + state=ar_double_sim_state, model=lj_model, kT=kT, seed=42 + ) + state.positions = ft.pbc_wrap_batched(state.positions, state.cell, state.system_idx) + positions = [state.positions.detach().clone()] + for _step in range(n_steps): + state = ts.nvt_langevin_step(model=lj_model, state=state, dt=dt, kT=kT) + positions.append(state.positions.detach().clone()) + + positions = torch.stack(positions) + wrapped_positions = torch.stack( + [ + ft.pbc_wrap_batched(positions, state.cell, state.system_idx) + for positions in positions + ] + ) + unwrapped_positions = ft.unwrap_positions( + wrapped_positions, + state.cell, + state.system_idx, + ) + assert torch.allclose(unwrapped_positions, positions, atol=1e-4) + + # Different cell + state = ts.npt_langevin_init( + state=ar_double_sim_state, model=lj_model, kT=kT, seed=42, dt=dt + ) + state.positions = ft.pbc_wrap_batched(state.positions, state.cell, state.system_idx) + positions = [state.positions.detach().clone()] + cells = [state.cell.detach().clone()] + for _step in range(n_steps): + state = ts.npt_langevin_step( + model=lj_model, + state=state, + dt=dt, + kT=kT, + external_pressure=torch.tensor(0.0, dtype=DTYPE, device=DEVICE), + ) + positions.append(state.positions.detach().clone()) + cells.append(state.cell.detach().clone()) + + positions = torch.stack(positions) + wrapped_positions = torch.stack( + [ + ft.pbc_wrap_batched(positions, cell, state.system_idx) + for positions, cell in zip(positions, cells, strict=True) + ] + ) + unwrapped_positions = ft.unwrap_positions( + wrapped_positions, + state.cell, + state.system_idx, + ) + assert torch.allclose(unwrapped_positions, positions, atol=1e-4) diff --git a/torch_sim/__init__.py b/torch_sim/__init__.py index f632cbfa..8af19906 100644 --- a/torch_sim/__init__.py +++ b/torch_sim/__init__.py @@ -8,6 +8,7 @@ import torch_sim as ts from torch_sim import ( autobatching, + constraints, elastic, io, math, diff --git a/torch_sim/constraints.py b/torch_sim/constraints.py new file mode 100644 index 00000000..d352d539 --- /dev/null +++ b/torch_sim/constraints.py @@ -0,0 +1,610 @@ +"""Constraints for molecular dynamics simulations. + +This module implements constraints inspired by ASE's constraint system, +adapted for the torch-sim framework with support for batched operations +and PyTorch tensors. + +The constraints affect degrees of freedom counting and modify forces, momenta, +and positions during MD simulations. +""" + +from __future__ import annotations + +import warnings +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Self + +import torch + + +if TYPE_CHECKING: + from torch_sim.state import SimState + + +class Constraint(ABC): + """Base class for all constraints in torch-sim. + + This is the abstract base class that all constraints must inherit from. + It defines the interface that constraints must implement to work with + the torch-sim MD system. + """ + + @abstractmethod + def get_removed_dof(self, state: SimState) -> torch.Tensor: + """Get the number of degrees of freedom removed by this constraint. + + Args: + state: The simulation state + + Returns: + Number of degrees of freedom removed by this constraint + """ + + @abstractmethod + def adjust_positions(self, state: SimState, new_positions: torch.Tensor) -> None: + """Adjust positions to satisfy the constraint. + + This method should modify new_positions in-place to ensure the + constraint is satisfied. + + Args: + state: Current simulation state + new_positions: Proposed new positions to be adjusted + """ + + def adjust_momenta(self, state: SimState, momenta: torch.Tensor) -> None: + """Adjust momenta to satisfy the constraint. + + This method should modify momenta in-place to ensure the constraint + is satisfied. By default, it calls adjust_forces with the momenta. + + Args: + state: Current simulation state + momenta: Momenta to be adjusted + """ + # Default implementation: treat momenta like forces + self.adjust_forces(state, momenta) + + @abstractmethod + def adjust_forces(self, state: SimState, forces: torch.Tensor) -> None: + """Adjust forces to satisfy the constraint. + + This method should modify forces in-place to ensure the constraint + is satisfied. + + Args: + state: Current simulation state + forces: Forces to be adjusted + """ + + @abstractmethod + def select_constraint( + self, atom_mask: torch.Tensor, system_mask: torch.Tensor + ) -> None | Self: + """Update the constraint to account for atom and system masks. + + Args: + atom_mask: Boolean mask for atoms to keep + system_mask: Boolean mask for systems to keep + """ + + @abstractmethod + def select_sub_constraint(self, atom_idx: torch.Tensor, sys_idx: int) -> None | Self: + """Select a constraint for a given atom and system index. + + Args: + atom_idx: Atom indices for a single system + sys_idx: System index for a single system + + Returns: + Constraint for the given atom and system index + """ + + +def _mask_constraint_indices(idx: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + cumsum_atom_mask = torch.cumsum(~mask, dim=0) + new_indices = idx - cumsum_atom_mask[idx] + mask_indices = torch.where(mask)[0] + drop_indices = ~torch.isin(idx, mask_indices) + return new_indices[~drop_indices] + + +class AtomConstraint(Constraint): + """Base class for constraints that act on specific atom indices. + + This class provides common functionality for constraints that operate + on a subset of atoms, identified by their indices. + """ + + def __init__( + self, + atom_idx: torch.Tensor | list[int] | None = None, + atom_mask: torch.Tensor | list[int] | None = None, + ) -> None: + """Initialize indexed constraint. + + Args: + atom_idx: Indices of atoms to constrain. Can be a tensor or list of integers. + atom_mask: Boolean mask for atoms to constrain. + + Raises: + ValueError: If both indices and mask are provided, or if indices have + wrong shape/type + """ + if atom_idx is not None and atom_mask is not None: + raise ValueError("Provide either atom_idx or atom_mask, not both.") + if atom_mask is not None: + atom_mask = torch.as_tensor(atom_mask) + atom_idx = torch.where(atom_mask)[0] + + # Convert to tensor if needed + atom_idx = torch.as_tensor(atom_idx) + + # Ensure we have the right shape and type + atom_idx = torch.atleast_1d(atom_idx) + if atom_idx.ndim != 1: + raise ValueError( + "atom_idx has wrong number of dimensions. " + f"Got {atom_idx.ndim}, expected ndim <= 1" + ) + + if torch.is_floating_point(atom_idx): + raise ValueError( + f"Indices must be integers or boolean mask, not dtype={atom_idx.dtype}" + ) + + self.atom_idx = atom_idx.long() + + def get_indices(self) -> torch.Tensor: + """Get the constrained atom indices. + + Returns: + Tensor of atom indices affected by this constraint + """ + return self.atom_idx.clone() + + def select_constraint( + self, + atom_mask: torch.Tensor, + system_mask: torch.Tensor, # noqa: ARG002 + ) -> None | Self: + """Update the constraint to account for atom and system masks. + + Args: + atom_mask: Boolean mask for atoms to keep + system_mask: Boolean mask for systems to keep + """ + indices = self.atom_idx.clone() + indices = _mask_constraint_indices(indices, atom_mask) + if len(indices) == 0: + return None + return type(self)(indices) + + def select_sub_constraint( + self, + atom_idx: torch.Tensor, + sys_idx: int, # noqa: ARG002 + ) -> None | Self: + """Select a constraint for a given atom and system index. + + Args: + atom_idx: Atom indices for a single system + sys_idx: System index for a single system + """ + mask = torch.isin(self.atom_idx, atom_idx) + masked_indices = self.atom_idx[mask] + new_atom_idx = masked_indices - atom_idx.min() + if len(new_atom_idx) == 0: + return None + return type(self)(new_atom_idx) + + +class SystemConstraint(Constraint): + """Base class for constraints that act on specific system indices. + + This class provides common functionality for constraints that operate + on a subset of systems, identified by their indices. + """ + + def __init__( + self, + system_idx: torch.Tensor | list[int] | None = None, + system_mask: torch.Tensor | list[int] | None = None, + ) -> None: + """Initialize indexed constraint. + + Args: + system_idx: Indices of systems to constrain. + Can be a tensor or list of integers. + system_mask: Boolean mask for systems to constrain. + + Raises: + ValueError: If both indices and mask are provided, or if indices have + wrong shape/type + """ + if system_idx is not None and system_mask is not None: + raise ValueError("Provide either system_idx or system_mask, not both.") + if system_mask is not None: + system_idx = torch.as_tensor(system_idx) + system_idx = torch.where(system_mask)[0] + + # Convert to tensor if needed + system_idx = torch.as_tensor(system_idx) + + # Ensure we have the right shape and type + system_idx = torch.atleast_1d(system_idx) + if system_idx.ndim != 1: + raise ValueError( + "system_idx has wrong number of dimensions. " + f"Got {system_idx.ndim}, expected ndim <= 1" + ) + + # Check for duplicates + if len(system_idx) != len(torch.unique(system_idx)): + raise ValueError("Duplicate system indices found in SystemConstraint.") + + if torch.is_floating_point(system_idx): + raise ValueError( + f"Indices must be integers or boolean mask, not dtype={system_idx.dtype}" + ) + + self.system_idx = system_idx.long() + + def select_constraint( + self, + atom_mask: torch.Tensor, # noqa: ARG002 + system_mask: torch.Tensor, + ) -> None | Self: + """Update the constraint to account for atom and system masks. + + Args: + atom_mask: Boolean mask for atoms to keep + system_mask: Boolean mask for systems to keep + """ + system_idx = self.system_idx.clone() + system_idx = _mask_constraint_indices(system_idx, system_mask) + if len(system_idx) == 0: + return None + return type(self)(system_idx) + + def select_sub_constraint( + self, + atom_idx: torch.Tensor, # noqa: ARG002 + sys_idx: int, + ) -> None | Self: + """Select a constraint for a given atom and system index. + + Args: + atom_idx: Atom indices for a single system + sys_idx: System index for a single system + """ + return type(self)(torch.tensor([0])) if sys_idx in self.system_idx else None + + +def merge_constraints( + constraint_lists: list[list[AtomConstraint | SystemConstraint]], + num_atoms_per_state: torch.Tensor, +) -> list[Constraint]: + """Merge constraints from multiple systems into a single list of constraints. + + Args: + constraint_lists: List of lists of constraints + num_atoms_per_state: Number of atoms per system + + Returns: + List of merged constraints + """ + from collections import defaultdict + + cumsum_atoms = torch.cumsum(num_atoms_per_state, dim=0) - num_atoms_per_state[0] + + # aggregate updated constraint indices by constraint type + constraint_indices: dict[type[Constraint], list[torch.Tensor]] = defaultdict(list) + for i, constraint_list in enumerate(constraint_lists): + for constraint in constraint_list: + if isinstance(constraint, AtomConstraint): + idxs = constraint.atom_idx + offset = cumsum_atoms[i] + elif isinstance(constraint, SystemConstraint): + idxs = constraint.system_idx + offset = i + else: + raise NotImplementedError( + f"Constraint type {type(constraint)} is not implemented" + ) + constraint_indices[type(constraint)].append(idxs + offset) + + return [ + constraint_type(torch.cat(idxs)) + for constraint_type, idxs in constraint_indices.items() + ] + + +class FixAtoms(AtomConstraint): + """Constraint that fixes specified atoms in place. + + This constraint prevents the specified atoms from moving by: + - Resetting their positions to original values + - Setting their forces to zero + - Removing 3 degrees of freedom per fixed atom + + Examples: + Fix atoms with indices [0, 1, 2]: + >>> constraint = FixAtoms(atom_idx=[0, 1, 2]) + + Fix atoms using a boolean mask: + >>> mask = torch.tensor([True, True, True, False, False]) + >>> constraint = FixAtoms(mask=mask) + """ + + def __init__( + self, + atom_idx: torch.Tensor | list[int] | None = None, + atom_mask: torch.Tensor | list[int] | None = None, + ) -> None: + """Initialize FixAtoms constraint and check for duplicate indices.""" + super().__init__(atom_idx=atom_idx, atom_mask=atom_mask) + # Check duplicates + if len(self.atom_idx) != len(torch.unique(self.atom_idx)): + raise ValueError("Duplicate atom indices found in FixAtoms constraint.") + + def get_removed_dof(self, state: SimState) -> torch.Tensor: + """Get number of removed degrees of freedom. + + Each fixed atom removes 3 degrees of freedom (x, y, z motion). + + Args: + state: Simulation state + + Returns: + Number of degrees of freedom removed (3 * number of fixed atoms) + """ + fixed_atoms_system_idx = torch.bincount( + state.system_idx[self.atom_idx], minlength=state.n_systems + ) + return 3 * fixed_atoms_system_idx + + def adjust_positions(self, state: SimState, new_positions: torch.Tensor) -> None: + """Reset positions of fixed atoms to their current values. + + Args: + state: Current simulation state + new_positions: Proposed positions to be adjusted in-place + """ + new_positions[self.atom_idx] = state.positions[self.atom_idx] + + def adjust_forces( + self, + state: SimState, # noqa: ARG002 + forces: torch.Tensor, + ) -> None: + """Set forces on fixed atoms to zero. + + Args: + state: Current simulation state + forces: Forces to be adjusted in-place + """ + forces[self.atom_idx] = 0.0 + + def __repr__(self) -> str: + """String representation of the constraint.""" + if len(self.atom_idx) <= 10: + indices_str = self.atom_idx.tolist() + else: + indices_str = f"{self.atom_idx[:5].tolist()}...{self.atom_idx[-5:].tolist()}" + return f"FixAtoms(indices={indices_str})" + + +class FixCom(SystemConstraint): + """Constraint that fixes the center of mass of all atoms per system. + + This constraint prevents the center of mass from moving by: + - Adjusting positions to maintain center of mass position + - Removing center of mass velocity from momenta + - Adjusting forces to remove net force + - Removing 3 degrees of freedom (center of mass translation) + + The constraint is applied to all atoms in the system. + """ + + coms: torch.Tensor | None = None + + def get_removed_dof(self, state: SimState) -> torch.Tensor: + """Get number of removed degrees of freedom. + + Fixing center of mass removes 3 degrees of freedom (x, y, z translation). + + Args: + state: Simulation state + + Returns: + Always returns 3 (center of mass translation degrees of freedom) + """ + affected_systems = torch.zeros(state.n_systems, dtype=torch.long) + affected_systems[self.system_idx] = 1 + return 3 * affected_systems + + def adjust_positions(self, state: SimState, new_positions: torch.Tensor) -> None: + """Adjust positions to maintain center of mass position. + + Args: + state: Current simulation state + new_positions: Proposed positions to be adjusted in-place + """ + dtype = state.positions.dtype + system_mass = torch.zeros(state.n_systems, dtype=dtype).scatter_add_( + 0, state.system_idx, state.masses + ) + if self.coms is None: + self.coms = torch.zeros((state.n_systems, 3), dtype=dtype).scatter_add_( + 0, + state.system_idx.unsqueeze(-1).expand(-1, 3), + state.masses.unsqueeze(-1) * state.positions, + ) + self.coms /= system_mass.unsqueeze(-1) + + new_com = torch.zeros((state.n_systems, 3), dtype=dtype).scatter_add_( + 0, + state.system_idx.unsqueeze(-1).expand(-1, 3), + state.masses.unsqueeze(-1) * new_positions, + ) + new_com /= system_mass.unsqueeze(-1) + displacement = torch.zeros(state.n_systems, 3, dtype=dtype) + displacement[self.system_idx] = ( + -new_com[self.system_idx] + self.coms[self.system_idx] + ) + new_positions += displacement[state.system_idx] + + def adjust_momenta(self, state: SimState, momenta: torch.Tensor) -> None: + """Remove center of mass velocity from momenta. + + Args: + state: Current simulation state + momenta: Momenta to be adjusted in-place + """ + # Compute center of mass momenta + dtype = momenta.dtype + com_momenta = torch.zeros((state.n_systems, 3), dtype=dtype).scatter_add_( + 0, + state.system_idx.unsqueeze(-1).expand(-1, 3), + momenta, + ) + system_mass = torch.zeros(state.n_systems, dtype=dtype).scatter_add_( + 0, state.system_idx, state.masses + ) + velocity_com = com_momenta / system_mass.unsqueeze(-1) + velocity_change = torch.zeros(state.n_systems, 3, dtype=dtype) + velocity_change[self.system_idx] = velocity_com[self.system_idx] + momenta -= velocity_change[state.system_idx] * state.masses.unsqueeze(-1) + + def adjust_forces(self, state: SimState, forces: torch.Tensor) -> None: + """Remove net force to prevent center of mass acceleration. + + This implements the constraint from Eq. (3) and (7) in + https://doi.org/10.1021/jp9722824 + + Args: + state: Current simulation state + forces: Forces to be adjusted in-place + """ + dtype = state.positions.dtype + system_square_mass = torch.zeros(state.n_systems, dtype=dtype).scatter_add_( + 0, + state.system_idx, + torch.square(state.masses), + ) + lmd = torch.zeros((state.n_systems, 3), dtype=dtype).scatter_add_( + 0, + state.system_idx.unsqueeze(-1).expand(-1, 3), + forces * state.masses.unsqueeze(-1), + ) + lmd /= system_square_mass.unsqueeze(-1) + forces_change = torch.zeros(state.n_systems, 3, dtype=dtype) + forces_change[self.system_idx] = lmd[self.system_idx] + forces -= forces_change[state.system_idx] * state.masses.unsqueeze(-1) + + def __repr__(self) -> str: + """String representation of the constraint.""" + return f"FixCom(system_idx={self.system_idx})" + + +def count_degrees_of_freedom( + state: SimState, constraints: list[Constraint] | None = None +) -> int: + """Count the total degrees of freedom in a system with constraints. + + This function calculates the total number of degrees of freedom by starting + with the unconstrained count (n_atoms * 3) and subtracting the degrees of + freedom removed by each constraint. + + Args: + state: Simulation state + constraints: List of active constraints (optional) + + Returns: + Total number of degrees of freedom + """ + # Start with unconstrained DOF + total_dof = state.n_atoms * 3 + + # Subtract DOF removed by constraints + if constraints is not None: + for constraint in constraints: + total_dof -= constraint.get_removed_dof(state) + + return max(0, total_dof) # Ensure non-negative + + +def check_no_index_out_of_bounds( + indices: torch.Tensor, max_state_indices: int, constraint_name: str +) -> None: + """Check that constraint indices are within bounds of the state.""" + if (len(indices) > 0) and (indices.max() >= max_state_indices): + raise ValueError( + f"Constraint {constraint_name} has indices up to " + f"{indices.max()}, but state only has {max_state_indices} " + "atoms" + ) + + +def validate_constraints(constraints: list[Constraint], state: SimState) -> None: + """Validate constraints for potential issues and incompatibilities. + + This function checks for: + 1. Overlapping atom indices across multiple constraints + 2. AtomConstraints spanning multiple systems (requires state) + 3. Mixing FixCom with other constraints (warning only) + + Args: + constraints: List of constraints to validate + state: SimState to check against + + Raises: + ValueError: If constraints are invalid or span multiple systems + + Warns: + UserWarning: If constraints may lead to unexpected behavior + """ + if not constraints: + return + + indexed_constraints = [] + has_com_constraint = False + + for constraint in constraints: + if isinstance(constraint, AtomConstraint): + indexed_constraints.append(constraint) + + # Validate that atom indices exist in state if provided + check_no_index_out_of_bounds( + constraint.atom_idx, state.n_atoms, type(constraint).__name__ + ) + elif isinstance(constraint, SystemConstraint): + check_no_index_out_of_bounds( + constraint.system_idx, state.n_systems, type(constraint).__name__ + ) + + if isinstance(constraint, FixCom): + has_com_constraint = True + + # Check for overlapping atom indices + if len(indexed_constraints) > 1: + all_indices = torch.cat([c.atom_idx for c in indexed_constraints]) + unique_indices = torch.unique(all_indices) + if len(unique_indices) < len(all_indices): + warnings.warn( + "Multiple constraints are acting on the same atoms. " + "This may lead to unexpected behavior.", + UserWarning, + stacklevel=3, + ) + + # Warn about COM constraint with fixed atoms + if has_com_constraint and indexed_constraints: + warnings.warn( + "Using FixCom together with other constraints may lead to " + "unexpected behavior. The center of mass constraint is applied " + "to all atoms, including those that may be constrained by other means.", + UserWarning, + stacklevel=3, + ) diff --git a/torch_sim/integrators/md.py b/torch_sim/integrators/md.py index 1196468d..a604bc94 100644 --- a/torch_sim/integrators/md.py +++ b/torch_sim/integrators/md.py @@ -5,9 +5,8 @@ import torch -from torch_sim import transforms from torch_sim.models.interface import ModelInterface -from torch_sim.quantities import calc_kT, calc_temperature +from torch_sim.quantities import calc_kT from torch_sim.state import SimState from torch_sim.units import MetalUnits @@ -57,6 +56,12 @@ def velocities(self) -> torch.Tensor: """ return self.momenta / self.masses.unsqueeze(-1) + def set_constrained_momenta(self, new_momenta: torch.Tensor) -> None: + """Set new momenta, applying any constraints as needed.""" + for constraint in self.constraints: + constraint.adjust_momenta(self, new_momenta) + self.momenta = new_momenta + def calc_temperature( self, units: MetalUnits = MetalUnits.temperature ) -> torch.Tensor: @@ -68,19 +73,13 @@ def calc_temperature( Returns: torch.Tensor: Calculated temperature """ - return calc_temperature( - masses=self.masses, - momenta=self.momenta, - system_idx=self.system_idx, - dof_per_system=self.get_number_of_degrees_of_freedom(), - units=units, - ) + return self.calc_kT() / units.temperature def calc_kT(self) -> torch.Tensor: # noqa: N802 """Calculate kT from momenta, masses, and system indices. Returns: - torch.Tensor: Calculated kT + torch.Tensor: Calculated kT in energy units """ return calc_kT( masses=self.masses, @@ -167,7 +166,7 @@ def momentum_step[T: MDState](state: T, dt: float | torch.Tensor) -> T: """ new_momenta = state.momenta + state.forces * dt - state.momenta = new_momenta + state.set_constrained_momenta(new_momenta) return state @@ -187,17 +186,7 @@ def position_step[T: MDState](state: T, dt: float | torch.Tensor) -> T: """ new_positions = state.positions + state.velocities * dt - - if state.pbc.any(): - # Split positions and cells by system - new_positions = transforms.pbc_wrap_batched( - new_positions, - state.cell, - state.system_idx, - state.pbc, - ) - - state.positions = new_positions + state.set_constrained_positions(new_positions) return state diff --git a/torch_sim/integrators/npt.py b/torch_sim/integrators/npt.py index 30003524..f9110ccf 100644 --- a/torch_sim/integrators/npt.py +++ b/torch_sim/integrators/npt.py @@ -1,5 +1,6 @@ """Implementations of NPT integrators.""" +import warnings from collections.abc import Callable from dataclasses import dataclass from typing import Any @@ -364,14 +365,7 @@ def _npt_langevin_position_step( ) # Update positions with all contributions - state.positions = c_1 + c_2.unsqueeze(-1) * c_3 - - # Apply periodic boundary conditions if needed - if state.pbc.any(): - state.positions = ts.transforms.pbc_wrap_batched( - state.positions, state.cell, state.system_idx, state.pbc - ) - + state.set_constrained_positions(c_1 + c_2.unsqueeze(-1) * c_3) return state @@ -435,7 +429,8 @@ def _npt_langevin_velocity_step( # Update momenta (velocities * masses) with all contributions new_velocities = c_1 + c_2 + c_3 - state.momenta = new_velocities * state.masses.unsqueeze(-1) + # Apply constraints. + state.set_constrained_momenta(new_velocities * state.masses.unsqueeze(-1)) return state @@ -565,6 +560,9 @@ def npt_langevin_init( kT = torch.as_tensor(kT, device=device, dtype=dtype) dt = torch.as_tensor(dt, device=device, dtype=dtype) + if not isinstance(state, SimState): + state = SimState(**state) + if alpha.ndim == 0: alpha = alpha.expand(state.n_systems) if cell_alpha.ndim == 0: @@ -572,9 +570,6 @@ def npt_langevin_init( if b_tau.ndim == 0: b_tau = b_tau.expand(state.n_systems) - if not isinstance(state, SimState): - state = SimState(**state) - # Get model output to initialize forces and stress model_output = model(state) @@ -606,6 +601,16 @@ def npt_langevin_init( ) cell_masses = (n_atoms_per_system + 1) * batch_kT * b_tau * b_tau + if state.constraints: + # warn if constraints are present + warnings.warn( + "Constraints are present in the system. " + "Make sure they are compatible with NPT Langevin dynamics." + "We recommend not using constraints with NPT dynamics for now.", + UserWarning, + stacklevel=3, + ) + # Create the initial state return NPTLangevinState( positions=state.positions, @@ -625,6 +630,7 @@ def npt_langevin_init( cell_velocities=cell_velocities, cell_masses=cell_masses, cell_alpha=cell_alpha, + _constraints=state.constraints, ) @@ -1027,14 +1033,7 @@ def _npt_nose_hoover_exp_iL1( # noqa: N802 state.positions * (torch.exp(x_expanded) - 1) + dt * velocities * torch.exp(x_2_expanded) * sinh_expanded ) - new_positions = state.positions + new_positions - - # Apply periodic boundary conditions if needed - if state.pbc.any(): - return ts.transforms.pbc_wrap_batched( - new_positions, state.current_cell, state.system_idx, pbc=state.pbc - ) - return new_positions + return state.positions + new_positions def _npt_nose_hoover_exp_iL2( # noqa: N802 @@ -1244,7 +1243,7 @@ def _npt_nose_hoover_inner_step( # Update particle positions and forces positions = _npt_nose_hoover_exp_iL1(state, state.velocities, cell_velocities, dt) - state.positions = positions + state.set_constrained_positions(positions) state.cell = cell model_output = model(state) @@ -1265,8 +1264,8 @@ def _npt_nose_hoover_inner_step( cell_momentum = cell_momentum + dt_2 * cell_force_val.unsqueeze(-1) # Return updated state - state.positions = positions - state.momenta = momenta + state.set_constrained_positions(positions) + state.set_constrained_momenta(momenta) state.forces = model_output["forces"] state.energy = model_output["energy"] state.cell_position = cell_position @@ -1411,6 +1410,16 @@ def npt_nose_hoover_init( forces = model_output["forces"] energy = model_output["energy"] + if state.constraints: + # warn if constraints are present + warnings.warn( + "Constraints are present in the system. " + "Make sure they are compatible with NPT Nosé Hoover dynamics." + "We recommend not using constraints with NPT dynamics for now.", + UserWarning, + stacklevel=3, + ) + # Create initial state return NPTNoseHooverState( positions=state.positions, @@ -1430,6 +1439,7 @@ def npt_nose_hoover_init( thermostat=thermostat_fns.initialize(dof_per_system, KE_thermostat, kT), barostat_fns=barostat_fns, thermostat_fns=thermostat_fns, + _constraints=state.constraints, ) diff --git a/torch_sim/integrators/nve.py b/torch_sim/integrators/nve.py index d3773b3c..b4db4e6c 100644 --- a/torch_sim/integrators/nve.py +++ b/torch_sim/integrators/nve.py @@ -67,6 +67,7 @@ def nve_init( pbc=state.pbc, system_idx=state.system_idx, atomic_numbers=state.atomic_numbers, + _constraints=state.constraints, ) diff --git a/torch_sim/integrators/nvt.py b/torch_sim/integrators/nvt.py index 4bbcdb63..d773a922 100644 --- a/torch_sim/integrators/nvt.py +++ b/torch_sim/integrators/nvt.py @@ -72,7 +72,7 @@ def _ou_step( c1.unsqueeze(-1) * state.momenta + c2 * torch.sqrt(state.masses).unsqueeze(-1) * noise ) - state.momenta = new_momenta + state.set_constrained_momenta(new_momenta) return state @@ -118,7 +118,6 @@ def nvt_langevin_init( "momenta", calculate_momenta(state.positions, state.masses, state.system_idx, kT, seed), ) - return MDState( positions=state.positions, momenta=momenta, @@ -129,6 +128,7 @@ def nvt_langevin_init( pbc=state.pbc, system_idx=state.system_idx, atomic_numbers=state.atomic_numbers, + _constraints=state.constraints, ) @@ -328,6 +328,7 @@ def nvt_nose_hoover_init( system_idx=state.system_idx, chain=chain_fns.initialize(dof_per_system, KE, kT), _chain_fns=chain_fns, # Store the chain functions + _constraints=state.constraints, ) @@ -372,7 +373,7 @@ def nvt_nose_hoover_step( # First half-step of chain evolution momenta, chain = chain_fns.half_step(state.momenta, chain, kT, state.system_idx) - state.momenta = momenta + state.set_constrained_momenta(momenta) # Full velocity Verlet step state = velocity_verlet(state=state, dt=dt, model=model) @@ -385,7 +386,7 @@ def nvt_nose_hoover_step( # Second half-step of chain evolution momenta, chain = chain_fns.half_step(state.momenta, chain, kT, state.system_idx) - state.momenta = momenta + state.set_constrained_momenta(momenta) state.chain = chain return state diff --git a/torch_sim/monte_carlo.py b/torch_sim/monte_carlo.py index 6f4dd5bf..fc88d7f5 100644 --- a/torch_sim/monte_carlo.py +++ b/torch_sim/monte_carlo.py @@ -214,6 +214,7 @@ def swap_mc_init( system_idx=state.system_idx, energy=model_output["energy"], last_permutation=torch.arange(state.n_atoms, device=state.device), + _constraints=state.constraints, ) diff --git a/torch_sim/optimizers/fire.py b/torch_sim/optimizers/fire.py index 0a689432..5583cc0c 100644 --- a/torch_sim/optimizers/fire.py +++ b/torch_sim/optimizers/fire.py @@ -80,6 +80,7 @@ def fire_init( "cell": state.cell.clone(), "atomic_numbers": state.atomic_numbers.clone(), "system_idx": state.system_idx.clone(), + "_constraints": state.constraints, "pbc": state.pbc, # Optimization state "forces": forces, @@ -211,13 +212,13 @@ def _vv_fire_step[T: "FireState | CellFireState"]( # noqa: PLR0915 state.velocities += 0.5 * atom_wise_dt * state.forces / state.masses.unsqueeze(-1) # Position update - state.positions = state.positions + atom_wise_dt * state.velocities + state.set_constrained_positions(state.positions + atom_wise_dt * state.velocities) # Cell position updates are handled in the velocity update step above # Get new forces and energy model_output = model(state) - state.forces = model_output["forces"] + state.set_constrained_forces(model_output["forces"]) state.energy = model_output["energy"] if "stress" in model_output: state.stress = model_output["stress"] @@ -419,7 +420,7 @@ def _ase_fire_step[T: "FireState | CellFireState"]( # noqa: C901, PLR0915 cur_deform_grad = cell_filters.deform_grad( state.reference_cell.mT, state.row_vector_cell ) - state.positions = ( + state.set_constrained_positions( torch.linalg.solve( cur_deform_grad[state.system_idx], state.positions.unsqueeze(-1) ).squeeze(-1) @@ -454,16 +455,18 @@ def _ase_fire_step[T: "FireState | CellFireState"]( # noqa: C901, PLR0915 new_deform_grad = cell_filters.deform_grad( state.reference_cell.mT, state.row_vector_cell ) - state.positions = torch.bmm( - state.positions.unsqueeze(1), - new_deform_grad[state.system_idx].transpose(-2, -1), - ).squeeze(1) + state.set_constrained_positions( + torch.bmm( + state.positions.unsqueeze(1), + new_deform_grad[state.system_idx].transpose(-2, -1), + ).squeeze(1) + ) else: - state.positions = state.positions + dr_atom + state.set_constrained_positions(state.positions + dr_atom) # Get new forces, energy, and stress model_output = model(state) - state.forces = model_output["forces"] + state.set_constrained_forces(model_output["forces"]) state.energy = model_output["energy"] if "stress" in model_output: state.stress = model_output["stress"] diff --git a/torch_sim/optimizers/gradient_descent.py b/torch_sim/optimizers/gradient_descent.py index bfdfcf3f..d6bf52f5 100644 --- a/torch_sim/optimizers/gradient_descent.py +++ b/torch_sim/optimizers/gradient_descent.py @@ -61,6 +61,7 @@ def gradient_descent_init( "pbc": state.pbc, "atomic_numbers": state.atomic_numbers, "system_idx": state.system_idx, + "_constraints": state.constraints, } if cell_filter is not None: # Create cell optimization state @@ -107,7 +108,7 @@ def gradient_descent_step( atom_lr = pos_lr[state.system_idx].unsqueeze(-1) # Update atomic positions - state.positions = state.positions + atom_lr * state.forces + state.set_constrained_positions(state.positions + atom_lr * state.forces) # Update cell if using cell optimization if isinstance(state, CellOptimState): @@ -117,7 +118,7 @@ def gradient_descent_step( # Get updated forces, energy, and stress model_output = model(state) - state.forces = model_output["forces"] + state.set_constrained_forces(model_output["forces"]) state.energy = model_output["energy"] if "stress" in model_output: state.stress = model_output["stress"] diff --git a/torch_sim/optimizers/state.py b/torch_sim/optimizers/state.py index 2ab530db..358ea6b2 100644 --- a/torch_sim/optimizers/state.py +++ b/torch_sim/optimizers/state.py @@ -23,6 +23,16 @@ class OptimState(SimState): _atom_attributes = SimState._atom_attributes | {"forces"} # noqa: SLF001 _system_attributes = SimState._system_attributes | {"energy", "stress"} # noqa: SLF001 + def set_constrained_forces(self, new_forces: torch.Tensor) -> None: + """Set new forces in the optimization state.""" + for constraint in self._constraints: + constraint.adjust_forces(self, new_forces) + self.forces = new_forces + + def __post_init__(self) -> None: + """Post-initialization to ensure SimState setup.""" + self.set_constrained_forces(self.forces) + @dataclass(kw_only=True) class FireState(OptimState): diff --git a/torch_sim/quantities.py b/torch_sim/quantities.py index 30c8e690..aa2b4a2a 100644 --- a/torch_sim/quantities.py +++ b/torch_sim/quantities.py @@ -57,7 +57,7 @@ def calc_kT( # noqa: N802 # Count degrees of freedom per system system_sizes = torch.bincount(system_idx) if dof_per_system is None: - dof_per_system = system_sizes * squared_term.shape[-1] + dof_per_system = system_sizes * squared_term.shape[-1] # multiply by n_dimensions # Calculate temperature per system system_sums = torch.segment_reduce( diff --git a/torch_sim/runners.py b/torch_sim/runners.py index b7fe73cb..5f4e09ee 100644 --- a/torch_sim/runners.py +++ b/torch_sim/runners.py @@ -168,7 +168,6 @@ def integrate[T: SimState]( # noqa: C901 f"integrator must be key from Integrator or a tuple of " f"(init_func, step_func), got {type(integrator)}" ) - # batch_iterator will be a list if autobatcher is False batch_iterator = _configure_batches_iterator( initial_state, model, autobatcher=autobatcher diff --git a/torch_sim/state.py b/torch_sim/state.py index bb0bf1e7..e97101d7 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -23,6 +23,8 @@ from phonopy.structure.atoms import PhonopyAtoms from pymatgen.core import Structure +from torch_sim.constraints import Constraint, merge_constraints, validate_constraints + @dataclass class SimState: @@ -53,6 +55,8 @@ class SimState: atomic_numbers (torch.Tensor): Atomic numbers with shape (n_atoms,) system_idx (torch.Tensor): Maps each atom index to its system index. Has shape (n_atoms,), must be unique consecutive integers starting from 0. + constraints (list["Constraint"] | None): List of constraints applied to the + system. Constraints affect degrees of freedom and modify positions. Properties: wrap_positions (torch.Tensor): Positions wrapped according to periodic boundary @@ -87,6 +91,7 @@ class SimState: charge: torch.Tensor | None = field(default=None) spin: torch.Tensor | None = field(default=None) system_idx: torch.Tensor | None = field(default=None) + _constraints: list["Constraint"] = field(default_factory=lambda: []) # noqa: PIE807 if TYPE_CHECKING: @@ -138,6 +143,9 @@ def __post_init__(self) -> None: # noqa: C901 if not torch.all(counts == torch.bincount(initial_system_idx)): raise ValueError("System indices must be unique consecutive integers") + if self.constraints: + validate_constraints(self.constraints, state=self) + if self.charge is None: self.charge = torch.zeros( self.n_systems, device=self.device, dtype=self.dtype @@ -221,6 +229,7 @@ def attributes(self) -> dict[str, torch.Tensor]: for attr in self._atom_attributes | self._system_attributes | self._global_attributes + | {"_constraints"} } @property @@ -251,6 +260,46 @@ def row_vector_cell(self, value: torch.Tensor) -> None: """ self.cell = value.mT + def set_constrained_positions(self, new_positions: torch.Tensor) -> None: + """Set the positions and apply constraints if they exist. + + Args: + new_positions: New positions tensor with shape (n_atoms, 3) + """ + # Apply constraints if they exist + for constraint in self.constraints: + constraint.adjust_positions(self, new_positions) + self.positions = new_positions + + @property + def constraints(self) -> list[Constraint]: + """Get the constraints for the SimState. + + Returns: + list["Constraint"]: List of constraints applied to the system. + """ + return self._constraints + + @constraints.setter + def constraints(self, constraints: list[Constraint] | Constraint) -> None: + """Set the constraints for the SimState. + + Args: + constraints (list["Constraint"] | None): List of constraints to apply. + If None, no constraints are applied. + + Raises: + ValueError: If constraints are invalid or span multiple systems + """ + # check it is a list + if isinstance(constraints, Constraint): + constraints = [constraints] + + # Validate new constraints before adding + validate_constraints(constraints, state=self) + + self._constraints = constraints + def set_cell( self, cell: torch.Tensor, @@ -285,7 +334,18 @@ def get_number_of_degrees_of_freedom(self) -> torch.Tensor: of freedom, minus any degrees removed by constraints. """ # Start with unconstrained DOF: 3 degrees per atom - return 3 * self.n_atoms_per_system + dof_per_system = 3 * self.n_atoms_per_system + + # Subtract DOF removed by constraints + if self.constraints is not None: + for constraint in self.constraints: + removed_dof = constraint.get_removed_dof(self) + dof_per_system -= removed_dof + + # Ensure non-negative DOF + if (dof_per_system <= 0).any(): + raise ValueError("Degrees of freedom cannot be zero or negative") + return dof_per_system def clone(self) -> Self: """Create a deep copy of the SimState. @@ -651,7 +711,7 @@ def _state_to_device[T: SimState]( attrs["masses"] = attrs["masses"].to(dtype=dtype) attrs["cell"] = attrs["cell"].to(dtype=dtype) attrs["atomic_numbers"] = attrs["atomic_numbers"].to(dtype=torch.int) - return type(state)(**attrs) # type: ignore[invalid-return-type] + return type(state)(**attrs) def get_attrs_for_scope( @@ -702,6 +762,18 @@ def _filter_attrs_by_mask( # Copy global attributes directly filtered_attrs = dict(get_attrs_for_scope(state, "global")) + # take into account constraints that are AtomConstraint + filtered_attrs["_constraints"] = [ + constraint.select_constraint(atom_mask, system_mask) + for constraint in copy.deepcopy(state.constraints) + ] + # Remove any None constraints resulting from selection + filtered_attrs["_constraints"] = [ + constraint + for constraint in filtered_attrs["_constraints"] + if constraint is not None + ] + # Filter per-atom attributes for attr_name, attr_value in get_attrs_for_scope(state, "per-atom"): if attr_name == "system_idx": @@ -723,6 +795,7 @@ def _filter_attrs_by_mask( dtype=attr_value.dtype, ) filtered_attrs[attr_name] = new_system_idxs + else: filtered_attrs[attr_name] = attr_value[atom_mask] @@ -749,7 +822,7 @@ def _split_state[T: SimState](state: T) -> list[T]: list[SimState]: A list of SimState objects, each containing a single system """ - system_sizes = torch.bincount(state.system_idx).tolist() + system_sizes = state.n_atoms_per_system.tolist() split_per_atom = {} for attr_name, attr_value in get_attrs_for_scope(state, "per-atom"): @@ -768,6 +841,8 @@ def _split_state[T: SimState](state: T) -> list[T]: # Create a state for each system states: list[T] = [] n_systems = len(system_sizes) + zero_tensor = torch.tensor([0], device=state.device, dtype=torch.int64) + cumsum_atoms = torch.cat((zero_tensor, torch.cumsum(state.n_atoms_per_system, dim=0))) for sys_idx in range(n_systems): system_attrs = { # Create a system tensor with all zeros for this system @@ -787,6 +862,15 @@ def _split_state[T: SimState](state: T) -> list[T]: # Add the global attributes **global_attrs, } + + atom_idx = torch.arange(cumsum_atoms[sys_idx], cumsum_atoms[sys_idx + 1]) + new_constraints = [ + new_constraint + for constraint in state.constraints + if (new_constraint := constraint.select_sub_constraint(atom_idx, sys_idx)) + ] + + system_attrs["_constraints"] = new_constraints states.append(type(state)(**system_attrs)) # type: ignore[invalid-argument-type] return states @@ -919,6 +1003,7 @@ def concatenate_states[T: SimState]( # noqa: C901 per_system_tensors = defaultdict(list) new_system_indices = [] system_offset = 0 + num_atoms_per_state = [] # Process all states in a single pass for state in states: @@ -941,6 +1026,8 @@ def concatenate_states[T: SimState]( # noqa: C901 num_systems = state.n_systems new_indices = state.system_idx + system_offset new_system_indices.append(new_indices) + num_atoms_per_state.append(state.n_atoms) + system_offset += num_systems # Concatenate collected tensors @@ -958,8 +1045,14 @@ def concatenate_states[T: SimState]( # noqa: C901 # Concatenate system indices concatenated["system_idx"] = torch.cat(new_system_indices) + # Merge constraints + constraint_lists = [state.constraints for state in states] + constraints = merge_constraints( + constraint_lists, torch.tensor(num_atoms_per_state, device=target_device) + ) + # Create a new instance of the same class - return state_class(**concatenated) + return state_class(**concatenated, _constraints=constraints) def initialize_state( diff --git a/torch_sim/transforms.py b/torch_sim/transforms.py index 2ab4ab2e..28acb977 100644 --- a/torch_sim/transforms.py +++ b/torch_sim/transforms.py @@ -1175,3 +1175,97 @@ def safe_mask( """ masked = torch.where(mask, operand, torch.zeros_like(operand)) return torch.where(mask, fn(masked), torch.full_like(operand, placeholder)) + + +def unwrap_positions( + positions: torch.Tensor, cells: torch.Tensor, system_idx: torch.Tensor +) -> torch.Tensor: + """Vectorized unwrapping for multiple systems without explicit loops. + + Parameters + ---------- + positions : (T, N_tot, 3) + Wrapped cartesian positions for all systems concatenated. + cells : (n_systems, 3, 3) or (T, n_systems, 3, 3) + Box matrices, constant or time-dependent. + system_idx : (N_tot,) + For each atom, which system it belongs to (0..n_systems-1). + + Returns: + ------- + unwrapped_pos : (T, N_tot, 3) + Unwrapped cartesian positions. + """ + # -- Constant boxes per system + if cells.ndim == 3: + inv_box = torch.inverse(cells) # (n_systems, 3, 3) + + # Map each atom to its system's box + inv_box_atoms = inv_box[system_idx] # (N, 3, 3) + box_atoms = cells[system_idx] # (N, 3, 3) + + # Compute fractional coordinates + frac = torch.einsum("tni,nij->tnj", positions, inv_box_atoms) + + # Fractional displacements and unwrap + dfrac = frac[1:] - frac[:-1] + dfrac -= torch.round(dfrac) + + # Back to Cartesian + dcart = torch.einsum("tni,nij->tnj", dfrac, box_atoms) + + # -- Time-dependent boxes per system + elif cells.ndim == 4: + inv_box = torch.inverse(cells) # (T, n_systems, 3, 3) + + # Gather each atom's box per frame efficiently + inv_box_atoms = inv_box[:, system_idx] # (T, N, 3, 3) + box_atoms = cells[:, system_idx] # (T, N, 3, 3) + + # Compute fractional coordinates per frame + frac = torch.einsum("tni,tnij->tnj", positions, inv_box_atoms) + + dfrac = frac[1:] - frac[:-1] + dfrac -= torch.round(dfrac) + + dcart = torch.einsum("tni,tnij->tnj", dfrac, box_atoms[:-1]) + + else: + raise ValueError("box must have shape (n_systems,3,3) or (T,n_systems,3,3)") + + # Cumulative reconstruction + unwrapped = torch.empty_like(positions) + unwrapped[0] = positions[0] + unwrapped[1:] = torch.cumsum(dcart, dim=0) + unwrapped[0] + + return unwrapped + + +def get_centers_of_mass( + positions: torch.Tensor, + masses: torch.Tensor, + system_idx: torch.Tensor, + n_systems: int, +) -> torch.Tensor: + """Compute the centers of mass for each structure in the simulation state.s. + + Args: + positions (torch.Tensor): Atomic positions of shape (N, 3). + masses (torch.Tensor): Atomic masses of shape (N,). + system_idx (torch.Tensor): System indices for each atom of shape (N,). + n_systems (int): Total number of systems. + + Returns: + torch.Tensor: A tensor of shape (n_structures, 3) containing + the center of mass coordinates for each structure. + """ + coms = torch.zeros((n_systems, 3), dtype=positions.dtype).scatter_add_( + 0, + system_idx.unsqueeze(-1).expand(-1, 3), + masses.unsqueeze(-1) * positions, + ) + system_masses = torch.zeros((n_systems,), dtype=positions.dtype).scatter_add_( + 0, system_idx, masses + ) + coms /= system_masses.unsqueeze(-1) + return coms