Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
58 commits
Select commit Hold shift + click to select a range
4157a20
fix:orb squeeze incorrect energy shape
thomasloux Sep 18, 2025
646ddf5
Merge branch 'TorchSim:main' into main
thomasloux Oct 8, 2025
69ee796
Merge branch 'TorchSim:main' into main
thomasloux Oct 10, 2025
38c6138
First draft constraints
thomasloux Oct 17, 2025
6eb3d78
change base class name for constraint
thomasloux Oct 21, 2025
c630f39
remove useless methods
thomasloux Oct 21, 2025
bfdf6de
Merge branch 'main' into features/constraints
thomasloux Oct 21, 2025
f5459b9
change redundant definition
thomasloux Oct 21, 2025
6b2710e
constraint to optimizer, compatibility with state manipulation
thomasloux Oct 23, 2025
c955273
Merge branch 'features/constraints' of https://github.com/thomasloux/…
thomasloux Oct 23, 2025
7d63069
test temperature, adapt calc_kt for reduced degrees of freedom
thomasloux Oct 23, 2025
e1388fd
Merge branch 'main' into pr/thomasloux/294
janosh Nov 10, 2025
ad4fa0a
fix typo + unreleased changelog entry
janosh Nov 10, 2025
8beb9d9
renamed validate_constraints now called in SimState.add_constraints a…
janosh Nov 10, 2025
c577e1d
tests for constraint validation warnings and errors
janosh Nov 10, 2025
9cfe52b
refactor to use getter setter and _constraints
thomasloux Nov 10, 2025
be30d45
remove edge case slice(None)
thomasloux Nov 10, 2025
33d6025
new API (remove slice(None) and _constraint as private var
thomasloux Nov 10, 2025
399fbfd
correct get_centers_of_mass
thomasloux Nov 10, 2025
b31ba80
add warnings for npt dynamics
thomasloux Nov 10, 2025
1483977
simplify state updating in _filter_attrs_by_mask
orionarcher Nov 18, 2025
3c267eb
simplify _split_state with select_sub_constraint function
orionarcher Nov 19, 2025
06400e1
make constraint handling more modular with methods, merge states curr…
orionarcher Nov 19, 2025
4081973
No longer allow initializing FixCom() or FixAtoms() with empty arguments
orionarcher Nov 21, 2025
35749c3
vibe code and verify some tests
orionarcher Nov 21, 2025
8c067fb
Merge pull request #1 from TorchSim/contraints
thomasloux Nov 24, 2025
0688bfe
rename update_constraint to select_constraint, remove None Constraint…
thomasloux Nov 24, 2025
be161e3
change to _constraint name
thomasloux Nov 24, 2025
6afab52
revert to previous return as it actually also change the device/dtype…
thomasloux Nov 24, 2025
4aa1447
use post_init to enforce constraint on forces
thomasloux Nov 24, 2025
e61e452
constraint is not a global_attrs anymore
thomasloux Nov 24, 2025
8144ed6
increase slightly steps to test FixCom
thomasloux Nov 24, 2025
940827b
add _constraint to attributes so that it's kept when cloning simstate
thomasloux Nov 24, 2025
1cbd0b0
compute com for all and only subselect depending on system_idx, remov…
thomasloux Nov 24, 2025
6e09895
remove comments
thomasloux Nov 24, 2025
7d8890f
remove comment and raise if dof is negative
thomasloux Nov 24, 2025
be55c9b
remove unwrap_pos and add dummy state to test for validate_constraints
thomasloux Nov 24, 2025
33c6e92
ruff happy, simplify function
thomasloux Nov 24, 2025
d99a1a7
test for unwrap_positions
thomasloux Nov 24, 2025
50d566f
Merge branch 'main' into features/constraints
thomasloux Nov 24, 2025
eb26975
silence ruff
thomasloux Nov 24, 2025
c15a012
modify args names
thomasloux Nov 24, 2025
87644fa
reduce precision for test_unwrap
thomasloux Nov 24, 2025
95857b0
updates names
thomasloux Nov 24, 2025
7022df2
remove einsteinModel (not for this PR)
thomasloux Nov 24, 2025
07624f0
rename var and add mask
thomasloux Nov 26, 2025
0940919
remove comment now that a warning is set up for NPT MD with constraints
thomasloux Nov 26, 2025
b49e309
Add duplicate error in FixAtoms (subclass of AtomConstraint will hand…
thomasloux Nov 26, 2025
65fd0cf
rename args FixAtoms tests
thomasloux Nov 26, 2025
fee207f
system_idx for constraint must be dim 1
thomasloux Nov 26, 2025
99e8ad3
Merge branch 'main' into features/constraints
thomasloux Dec 29, 2025
3b8af40
remove duplicate code
thomasloux Dec 29, 2025
a226b2d
rename to constrain_positions
thomasloux Dec 29, 2025
d42fcf7
rename to constrain_momenta
thomasloux Dec 29, 2025
153f957
rename to constrain_forces
thomasloux Dec 29, 2025
a954832
update names to set_constrained_PROP
thomasloux Dec 29, 2025
f07c930
Merge branch 'features/constraints' of https://github.com/thomasloux/…
thomasloux Dec 29, 2025
9b1a88a
Merge branch 'main' into features/constraints
thomasloux Dec 29, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 41 additions & 7 deletions torch_sim/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,21 +400,55 @@ def count_degrees_of_freedom(
return max(0, total_dof) # Ensure non-negative


def warn_if_overlapping_constraints(constraints: list[Constraint]) -> None:
"""Issue warnings if constraints might overlap in problematic ways.
def validate_constraints( # noqa: C901
constraints: list[Constraint], state: SimState | None = None
) -> None:
"""Validate constraints for potential issues and incompatibilities.

This function checks for potential issues like multiple constraints
acting on the same atoms, which could lead to unexpected behavior.
This function checks for:
1. Overlapping atom indices across multiple constraints
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For 2. AtomIndexedConstraints spanning multiple systems:
So actually the current FixAtoms is implemented so that it is supposed to be only one FixAtoms contraints for a batch system. So it's expected to act on multiple systems.

for 1. it's not so clear that it's a problem to have the same atoms affected by multiple constraints. Another precise example of that, take a water molecule, you often want to constraints H-bonds. Then the oxygen atom will be concerned by 2 constraints.

2. AtomIndexedConstraints spanning multiple systems (requires state)
3. Mixing FixCom with other constraints (warning only)

Args:
constraints: List of constraints to check
constraints: List of constraints to validate
state: Optional SimState for validating atom indices belong to same system

Raises:
ValueError: If constraints are invalid or span multiple systems

Warns:
UserWarning: If constraints may lead to unexpected behavior
"""
if not constraints:
return

indexed_constraints = []
has_com_constraint = False

for constraint in constraints:
if isinstance(constraint, AtomIndexedConstraint):
indexed_constraints.append(constraint)

# Validate that atom indices exist in state if provided
if state is not None and len(constraint.indices) > 0:
if constraint.indices.max() >= state.n_atoms:
raise ValueError(
f"Constraint {type(constraint).__name__} has indices up to "
f"{constraint.indices.max()}, but state only has {state.n_atoms} "
"atoms"
)

# Check that all constrained atoms belong to same system
constrained_system_indices = state.system_idx[constraint.indices]
unique_systems = torch.unique(constrained_system_indices)
if len(unique_systems) > 1:
raise ValueError(
f"Constraint {type(constraint).__name__} acts on atoms from "
f"multiple systems {unique_systems.tolist()}. Each constraint "
f"must operate within a single system."
)

elif isinstance(constraint, FixCom):
has_com_constraint = True

Expand All @@ -427,7 +461,7 @@ def warn_if_overlapping_constraints(constraints: list[Constraint]) -> None:
"Multiple constraints are acting on the same atoms. "
"This may lead to unexpected behavior.",
UserWarning,
stacklevel=2,
stacklevel=3,
)

# Warn about COM constraint with fixed atoms
Expand All @@ -437,5 +471,5 @@ def warn_if_overlapping_constraints(constraints: list[Constraint]) -> None:
"unexpected behavior. The center of mass constraint is applied "
"to all atoms, including those that may be constrained by other means.",
UserWarning,
stacklevel=2,
stacklevel=3,
)
17 changes: 16 additions & 1 deletion torch_sim/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,12 @@
from phonopy.structure.atoms import PhonopyAtoms
from pymatgen.core import Structure

from torch_sim.constraints import AtomIndexedConstraint, Constraint, SystemConstraint
from torch_sim.constraints import (
AtomIndexedConstraint,
Constraint,
SystemConstraint,
validate_constraints,
)


@dataclass
Expand Down Expand Up @@ -140,6 +145,9 @@ def __post_init__(self) -> None:
if not torch.all(counts == torch.bincount(initial_system_idx)):
raise ValueError("System indices must be unique consecutive integers")

if self.constraints:
validate_constraints(self.constraints, state=self)

if self.cell.ndim != 3 and initial_system_idx is None:
self.cell = self.cell.unsqueeze(0)

Expand Down Expand Up @@ -249,6 +257,9 @@ def add_constraints(self, constraints: list[Constraint] | Constraint) -> None:
Args:
constraints (list["Constraint"] | None): List of constraints to apply.
If None, no constraints are applied.

Raises:
ValueError: If constraints are invalid or span multiple systems
"""
# check it is a list
if isinstance(constraints, Constraint):
Expand All @@ -258,6 +269,10 @@ def add_constraints(self, constraints: list[Constraint] | Constraint) -> None:
if hasattr(constraint, "system_idx") and constraint.system_idx == slice(None):
constraint.system_idx = torch.arange(self.n_systems, device=self.device)

# Validate new constraints before adding
all_constraints = self.constraints + constraints
validate_constraints(all_constraints, state=self)

self.constraints += constraints

def get_number_of_degrees_of_freedom(self) -> torch.Tensor:
Expand Down