diff --git a/tests/test_state.py b/tests/test_state.py index 1b293668..ab8e42be 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -23,14 +23,49 @@ from pymatgen.core import Structure -def test_get_attrs_for_scope(si_sim_state: SimState) -> None: +@pytest.fixture +def si_sim_state_extra(si_sim_state: SimState) -> SimState: + """Create a basic state from si_structure with extra attributes.""" + si_sim_state.set( + "charge", + torch.tensor([3.0] * si_sim_state.n_atoms, device=si_sim_state.device), + "atom", + ) + si_sim_state.set( + "energy", + torch.tensor([3.0], device=si_sim_state.device), + "system", + ) + si_sim_state.set( + "max_steps", + 100, + "global", + ) + return si_sim_state + + +@pytest.fixture +def si_double_sim_state_extra(si_sim_state_extra: SimState) -> SimState: + return ts.concatenate_states( + [si_sim_state_extra, si_sim_state_extra], + device=si_sim_state_extra.device, + ) + + +def test_get_attrs_for_scope(si_sim_state_extra: SimState) -> None: """Test getting attributes for a scope.""" - per_atom_attrs = dict(get_attrs_for_scope(si_sim_state, "per-atom")) - assert set(per_atom_attrs) == {"positions", "masses", "atomic_numbers", "system_idx"} - per_system_attrs = dict(get_attrs_for_scope(si_sim_state, "per-system")) - assert set(per_system_attrs) == {"cell"} - global_attrs = dict(get_attrs_for_scope(si_sim_state, "global")) - assert set(global_attrs) == {"pbc"} + per_atom_attrs = dict(get_attrs_for_scope(si_sim_state_extra, "per-atom")) + assert set(per_atom_attrs) == { + "positions", + "masses", + "atomic_numbers", + "system_idx", + "charge", + } + per_system_attrs = dict(get_attrs_for_scope(si_sim_state_extra, "per-system")) + assert set(per_system_attrs) == {"cell", "energy"} + global_attrs = dict(get_attrs_for_scope(si_sim_state_extra, "global")) + assert set(global_attrs) == {"pbc", "max_steps"} def test_all_attributes_must_be_specified_in_scopes() -> None: @@ -66,19 +101,32 @@ class ChildState(SimState): assert "duplicated_attribute" in str(exc_info.value) -def test_slice_substate(si_double_sim_state: SimState, si_sim_state: SimState) -> None: +def test_slice_substate( + si_double_sim_state_extra: SimState, si_sim_state_extra: SimState +) -> None: """Test slicing a substate from the SimState.""" for system_index in range(2): - substate = _slice_state(si_double_sim_state, [system_index]) + substate = _slice_state(si_double_sim_state_extra, [system_index]) assert isinstance(substate, SimState) assert substate.positions.shape == (8, 3) assert substate.masses.shape == (8,) assert substate.cell.shape == (1, 3, 3) - assert torch.allclose(substate.positions, si_sim_state.positions) - assert torch.allclose(substate.masses, si_sim_state.masses) - assert torch.allclose(substate.cell, si_sim_state.cell) - assert torch.allclose(substate.atomic_numbers, si_sim_state.atomic_numbers) + assert substate.get_strict("charge").shape == (8,) + assert substate.get_strict("energy").shape == (1,) + assert torch.allclose(substate.positions, si_sim_state_extra.positions) + assert torch.allclose(substate.masses, si_sim_state_extra.masses) + assert torch.allclose(substate.cell, si_sim_state_extra.cell) + assert torch.allclose(substate.atomic_numbers, si_sim_state_extra.atomic_numbers) assert torch.allclose(substate.system_idx, torch.zeros_like(substate.system_idx)) + assert torch.allclose( + substate.get_strict("charge"), si_sim_state_extra.get_strict("charge") + ) + assert torch.allclose( + substate.get_strict("energy"), si_sim_state_extra.get_strict("energy") + ) + assert substate.get_strict("max_steps") == si_sim_state_extra.get_strict( + "max_steps" + ) def test_slice_md_substate(si_double_sim_state: SimState) -> None: @@ -100,19 +148,30 @@ def test_slice_md_substate(si_double_sim_state: SimState) -> None: def test_concatenate_two_si_states( - si_sim_state: SimState, si_double_sim_state: SimState + si_sim_state_extra: SimState, si_double_sim_state_extra: SimState ) -> None: """Test concatenating two identical silicon states.""" # Concatenate two copies of the sim state - concatenated = ts.concatenate_states([si_sim_state, si_sim_state]) + concatenated = ts.concatenate_states([si_sim_state_extra, si_sim_state_extra]) # Check that the result is the same as the double state assert isinstance(concatenated, SimState) - assert concatenated.positions.shape == si_double_sim_state.positions.shape - assert concatenated.masses.shape == si_double_sim_state.masses.shape - assert concatenated.cell.shape == si_double_sim_state.cell.shape - assert concatenated.atomic_numbers.shape == si_double_sim_state.atomic_numbers.shape - assert concatenated.system_idx.shape == si_double_sim_state.system_idx.shape + assert concatenated.positions.shape == si_double_sim_state_extra.positions.shape + assert concatenated.masses.shape == si_double_sim_state_extra.masses.shape + assert concatenated.cell.shape == si_double_sim_state_extra.cell.shape + assert ( + concatenated.atomic_numbers.shape + == si_double_sim_state_extra.atomic_numbers.shape + ) + assert concatenated.system_idx.shape == si_double_sim_state_extra.system_idx.shape + assert ( + concatenated.get_strict("charge").shape + == si_double_sim_state_extra.get_strict("charge").shape + ) + assert ( + concatenated.get_strict("energy").shape + == si_double_sim_state_extra.get_strict("energy").shape + ) # Check system indices tensor_args = dict(dtype=torch.int64, device=si_sim_state.device) @@ -133,6 +192,17 @@ def test_concatenate_two_si_states( si_double_sim_state.positions[mask_double], ) + # check that the extra attributes are concatenated correctly + assert torch.allclose( + concatenated.get_strict("charge"), si_double_sim_state_extra.get_strict("charge") + ) + assert torch.allclose( + concatenated.get_strict("energy"), si_double_sim_state_extra.get_strict("energy") + ) + assert concatenated.get_strict("max_steps") == si_double_sim_state_extra.get_strict( + "max_steps" + ) + def test_concatenate_si_and_fe_states( si_sim_state: SimState, fe_supercell_sim_state: SimState @@ -228,16 +298,26 @@ def test_concatenate_double_si_and_fe_states( assert torch.allclose(fe_slice.positions, fe_supercell_sim_state.positions) -def test_split_state(si_double_sim_state: SimState) -> None: +def test_concatenate_states_with_inconsistent_extra_attributes( + si_sim_state: SimState, si_sim_state_extra: SimState +) -> None: + """We should only be able to concat states with the same extra attributes.""" + with pytest.raises(ValueError): + ts.concatenate_states([si_sim_state, si_sim_state_extra]) + + +def test_split_state(si_double_sim_state_extra: SimState) -> None: """Test splitting a state into a list of states.""" - states = si_double_sim_state.split() - assert len(states) == si_double_sim_state.n_systems + states = si_double_sim_state_extra.split() + assert len(states) == si_double_sim_state_extra.n_systems for state in states: assert isinstance(state, SimState) assert state.positions.shape == (8, 3) assert state.masses.shape == (8,) assert state.cell.shape == (1, 3, 3) assert state.atomic_numbers.shape == (8,) + assert state.get_strict("charge").shape == (8,) + assert state.get_strict("energy").shape == (1,) assert torch.allclose(state.system_idx, torch.zeros_like(state.system_idx)) @@ -287,6 +367,31 @@ def test_pop_states( assert kept_state.system_idx.shape == (len_kept,) +def test_pop_states_with_extra_attributes(si_double_sim_state_extra: SimState) -> None: + """Test popping states with extra attributes.""" + kept_state, popped_states = _pop_states( + si_double_sim_state_extra, + torch.tensor([0], device=si_double_sim_state_extra.device), + ) + assert isinstance(kept_state, SimState) + assert isinstance(popped_states, list) + assert len(popped_states) == 1 + assert isinstance(popped_states[0], SimState) + assert popped_states[0].positions.shape == si_double_sim_state_extra.positions.shape + assert popped_states[0].get_strict("charge").shape == (8,) + assert popped_states[0].get_strict("energy").shape == (1,) + assert popped_states[0].get_strict( + "max_steps" + ) == si_double_sim_state_extra.get_strict("max_steps") + + assert kept_state.positions.shape == (8, 3) + assert kept_state.get_strict("charge").shape == (8,) + assert kept_state.get_strict("energy").shape == (1,) + assert kept_state.get_strict("max_steps") == si_double_sim_state_extra.get_strict( + "max_steps" + ) + + def test_initialize_state_from_structure(si_structure: "Structure") -> None: """Test conversion from pymatgen Structure to state tensors.""" state = ts.initialize_state([si_structure], DEVICE, torch.float64) @@ -650,3 +755,72 @@ def test_state_to_device_no_side_effects(si_sim_state: SimState) -> None: "New state doesn't have correct device!" ) assert si_sim_state is not new_state_gpu, "New state is not a different object!" + + +def test_state_clone(si_sim_state_extra: SimState) -> None: + """Test the clone method of SimState.""" + cloned_state = si_sim_state_extra.clone() + assert isinstance(cloned_state, SimState) + assert cloned_state is not si_sim_state_extra + + for attr in si_sim_state_extra.attributes: + attr_cloned = getattr(cloned_state, attr) + attr_original = getattr(si_sim_state_extra, attr) + if isinstance(attr_cloned, torch.Tensor): + assert torch.allclose(attr_cloned, attr_original) + else: + assert attr_cloned == attr_original + assert attr_cloned is not attr_original + + +def test_state_get_attribute_with_extra_attributes(si_sim_state_extra: SimState) -> None: + """Test the __getitem__ method of SimState with extra attributes.""" + state = si_sim_state_extra + assert torch.equal(state.get_strict("charge"), state._extra_atom_attributes["charge"]) + assert torch.equal( + state.get_strict("energy"), state._extra_system_attributes["energy"] + ) + assert state.get_strict("max_steps") == state._extra_global_attributes["max_steps"] + + +def test_state_set_attribute_with_extra_attributes( + si_sim_state: SimState, si_sim_state_extra: SimState +) -> None: + """Test the __setitem__ method of SimState with extra attributes.""" + target_state = si_sim_state_extra + si_sim_state.set("charge", si_sim_state_extra.get_strict("charge"), "atom") + si_sim_state.set("energy", si_sim_state_extra.get_strict("energy"), "system") + si_sim_state.set("max_steps", si_sim_state_extra.get_strict("max_steps"), "global") + + assert torch.equal( + si_sim_state.get_strict("charge"), target_state.get_strict("charge") + ) + assert torch.equal( + si_sim_state.get_strict("energy"), target_state.get_strict("energy") + ) + assert si_sim_state.get_strict("max_steps") == target_state.get_strict("max_steps") + + +def test_state_set_attribute_must_specify_kind( + si_sim_state: SimState, si_sim_state_extra: SimState +) -> None: + with pytest.raises(ValueError) as exc_info: + si_sim_state.set("charge", torch.randn(si_sim_state.n_atoms, 1)) + assert "Kind must be specified for extra attributes" in str(exc_info.value) + + +def test_state_set_attribute_with_default_attributes(si_sim_state: SimState) -> None: + """Test the __setitem__ method of SimState with default attributes.""" + new_positions = torch.randn(si_sim_state.n_atoms, 3) + new_cell = torch.randn(si_sim_state.n_systems, 3, 3) + new_pbc = torch.tensor([True, False, True]) + + si_sim_state.set("positions", new_positions) + si_sim_state.set("cell", new_cell) + + # test that we can optionally specify a kind + si_sim_state.set("pbc", new_pbc, kind="global") + + assert torch.allclose(si_sim_state.positions, new_positions) + assert torch.allclose(si_sim_state.cell, new_cell) + assert torch.allclose(si_sim_state.pbc, new_pbc) diff --git a/torch_sim/models/interface.py b/torch_sim/models/interface.py index 58f233e8..e4e1edec 100644 --- a/torch_sim/models/interface.py +++ b/torch_sim/models/interface.py @@ -27,6 +27,8 @@ def forward(self, positions, cell, batch, atomic_numbers=None, **kwargs): """ from abc import ABC, abstractmethod +from typing import TypedDict +from typing_extensions import deprecated import torch @@ -35,13 +37,28 @@ def forward(self, positions, cell, batch, atomic_numbers=None, **kwargs): from torch_sim.typing import MemoryScaling, StateDict +class ModelInterfaceOutput(TypedDict): + """The expected output of a model forward pass implementation.""" + + atom_attributes: dict[str, torch.Tensor] + system_attributes: dict[str, torch.Tensor] + global_attributes: dict[str, torch.Tensor] + + # deprecated attributes. People who've written their own model interfaces should move + # away from this and write their results to atom_attributes, system_attributes, and + # global_attributes. + energy: torch.Tensor | None + forces: torch.Tensor | None + stress: torch.Tensor | None + + class ModelInterface(torch.nn.Module, ABC): """Abstract base class for all simulation models in TorchSim. This interface provides a common structure for all energy and force models, ensuring they implement the required methods and properties. It defines how - models should process atomic positions and system information to compute energies, - forces, and stresses. + models should process atomic positions and system information to compute + system-wide attributes like energies/stresses, or atom-wise attributes like forces. Attributes: device (torch.device): Device where the model runs computations. @@ -133,7 +150,7 @@ def memory_scales_with(self) -> MemoryScaling: return getattr(self, "_memory_scales_with", "n_atoms_x_density") @abstractmethod - def forward(self, state: SimState | StateDict, **kwargs) -> dict[str, torch.Tensor]: + def forward(self, state: SimState | StateDict, **kwargs) -> ModelInterfaceOutput: """Calculate energies, forces, and stresses for a atomistic system. This is the main computational method that all model implementations must provide. @@ -151,27 +168,32 @@ def forward(self, state: SimState | StateDict, **kwargs) -> dict[str, torch.Tens **kwargs: Additional model-specific parameters. Returns: - dict[str, torch.Tensor]: Computed properties: - - "energy": Potential energy with shape [n_systems] - - "forces": Atomic forces with shape [n_atoms, 3] - - "stress": Stress tensor with shape [n_systems, 3, 3] (if - compute_stress=True) - - May include additional model-specific outputs + ModelInterfaceOutput: Computed properties: + - "atom_attributes": Dictionary of atom-wise attributes + - "system_attributes": Dictionary of system-wide attributes + - "global_attributes": Dictionary of global attributes Examples: ```py # Compute energies and forces with a model output = model.forward(state) - energy = output["energy"] - forces = output["forces"] - stress = output.get("stress", None) + energy = output["system_attributes"]["energy"] + forces = output["atom_attributes"]["forces"] + stress = output["system_attributes"].get("stress") ``` """ +# TODO: we should put this logic inside __init_subclass__ of Modelinterface to +# automatically validate the model outputs when the model is subclassed. def validate_model_outputs( # noqa: C901, PLR0915 - model: ModelInterface, device: torch.device, dtype: torch.dtype + model: ModelInterface, + device: torch.device, + dtype: torch.dtype, + expected_output_atom_attributes: set[str], + expected_output_system_attributes: set[str], + expected_output_global_attributes: set[str], ) -> None: """Validate the outputs of a model implementation against the interface requirements. @@ -183,7 +205,8 @@ def validate_model_outputs( # noqa: C901, PLR0915 model (ModelInterface): Model implementation to validate. device (torch.device): Device to run the validation tests on. dtype (torch.dtype): Data type to use for validation tensors. - + expected_output_attributes (set[str]): The attributes that the model is expected + to return. Raises: AssertionError: If the model doesn't conform to the required interface, including issues with output shapes, types, or behavior consistency. @@ -243,46 +266,54 @@ def validate_model_outputs( # noqa: C901, PLR0915 raise ValueError(f"{og_atomic_nums=} != {sim_state.atomic_numbers=}") # assert model output has the correct keys - if "energy" not in model_output: - raise ValueError("energy not in model output") - if force_computed and "forces" not in model_output: - raise ValueError("forces not in model output") - if stress_computed and "stress" not in model_output: - raise ValueError("stress not in model output") - - # assert model output shapes are correct - if model_output["energy"].shape != (2,): - raise ValueError(f"{model_output['energy'].shape=} != (2,)") - if force_computed and model_output["forces"].shape != (20, 3): - raise ValueError(f"{model_output['forces'].shape=} != (20, 3)") - if stress_computed and model_output["stress"].shape != (2, 3, 3): - raise ValueError(f"{model_output['stress'].shape=} != (2, 3, 3)") + for attr in expected_output_atom_attributes: + if attr not in model_output["atom_attributes"]: + raise ValueError(f"{attr} not in model output") + for attr in expected_output_system_attributes: + if attr not in model_output["system_attributes"]: + raise ValueError(f"{attr} not in model output") + for attr in expected_output_global_attributes: + if attr not in model_output["global_attributes"]: + raise ValueError(f"{attr} not in model output") si_state = ts.io.atoms_to_state([si_atoms], device, dtype) fe_state = ts.io.atoms_to_state([fe_atoms], device, dtype) si_model_output = model.forward(si_state) - if not torch.allclose( - si_model_output["energy"], model_output["energy"][0], atol=10e-3 - ): - raise ValueError(f"{si_model_output['energy']=} != {model_output['energy'][0]=}") - if not torch.allclose( - forces := si_model_output["forces"], - expected_forces := model_output["forces"][: si_state.n_atoms], - atol=10e-3, - ): - raise ValueError(f"{forces=} != {expected_forces=}") - fe_model_output = model.forward(fe_state) - si_model_output = model.forward(si_state) - if not torch.allclose( - fe_model_output["energy"], model_output["energy"][1], atol=10e-2 - ): - raise ValueError(f"{fe_model_output['energy']=} != {model_output['energy'][1]=}") - if not torch.allclose( - forces := fe_model_output["forces"], - expected_forces := model_output["forces"][si_state.n_atoms :], - atol=10e-2, - ): - raise ValueError(f"{forces=} != {expected_forces=}") + for attr in expected_output_atom_attributes: + if attr in model_output["atom_attributes"]: + si_attr = si_model_output["atom_attributes"][attr] + batched_attr = model_output["atom_attributes"][attr] + expected_attr = batched_attr[: si_state.n_atoms] + if not torch.allclose(si_attr, expected_attr, atol=10e-3): + raise ValueError(f"{attr}: {si_attr=} != {expected_attr=}") + + fe_attr = fe_model_output["atom_attributes"][attr] + expected_fe_attr = batched_attr[si_state.n_atoms :] + if not torch.allclose(fe_attr, expected_fe_attr, atol=10e-2): + raise ValueError(f"{attr}: {fe_attr=} != {expected_fe_attr=}") + + for attr in expected_output_system_attributes: + if attr in model_output["system_attributes"]: + si_attr = si_model_output["system_attributes"][attr] + batched_attr = model_output["system_attributes"][attr] + expected_attr = batched_attr[0] + if not torch.allclose(si_attr, expected_attr, atol=10e-3): + raise ValueError(f"{attr}: {si_attr=} != {expected_attr=}") + + fe_attr = fe_model_output["system_attributes"][attr] + expected_fe_attr = batched_attr[1] + if not torch.allclose(fe_attr, expected_fe_attr, atol=10e-2): + raise ValueError(f"{attr}: {fe_attr=} != {expected_fe_attr=}") + + for attr in expected_output_global_attributes: + if attr in model_output["global_attributes"]: + si_attr = si_model_output["global_attributes"][attr] + fe_attr = fe_model_output["global_attributes"][attr] + batched_attr = model_output["global_attributes"][attr] + if not torch.allclose(si_attr, batched_attr, atol=10e-3): + raise ValueError(f"{attr}: {si_attr=} != {batched_attr=}") + if not torch.allclose(fe_attr, batched_attr, atol=10e-2): + raise ValueError(f"{attr}: {fe_attr=} != {batched_attr=}") diff --git a/torch_sim/runners.py b/torch_sim/runners.py index b7fe73cb..9601e3ce 100644 --- a/torch_sim/runners.py +++ b/torch_sim/runners.py @@ -577,18 +577,6 @@ def static( properties=properties, ) - @dataclass(kw_only=True) - class StaticState(SimState): - energy: torch.Tensor - forces: torch.Tensor - stress: torch.Tensor - - _atom_attributes = SimState._atom_attributes | {"forces"} # noqa: SLF001 - _system_attributes = SimState._system_attributes | { # noqa: SLF001 - "energy", - "stress", - } - all_props: list[dict[str, torch.Tensor]] = [] og_filenames = trajectory_reporter.filenames @@ -609,25 +597,21 @@ class StaticState(SimState): ) model_outputs = model(sub_state) - static_state = StaticState( - positions=sub_state.positions, - masses=sub_state.masses, - cell=sub_state.cell, - pbc=sub_state.pbc, - atomic_numbers=sub_state.atomic_numbers, - system_idx=sub_state.system_idx, - energy=model_outputs["energy"], - forces=( - model_outputs["forces"] - if model.compute_forces - else torch.full_like(sub_state.positions, fill_value=float("nan")) - ), - stress=( - model_outputs["stress"] - if model.compute_stress - else torch.full_like(sub_state.cell, fill_value=float("nan")) - ), - ) + static_state = sub_state.clone() + for attribute_name, value in model_outputs["atom_attributes"].items(): + static_state.set(attribute_name, value, "atom") + for attribute_name, value in model_outputs["system_attributes"].items(): + static_state.set(attribute_name, value, "system") + for attribute_name, value in model_outputs["global_attributes"].items(): + static_state.set(attribute_name, value, "global") + + # Handle deprecated model outputs + if "energy" in model_outputs: + static_state.set("energy", model_outputs["energy"], "system") + if "forces" in model_outputs: + static_state.set("forces", model_outputs["forces"], "atom") + if "stress" in model_outputs: + static_state.set("stress", model_outputs["stress"], "system") props = trajectory_reporter.report(static_state, 0, model=model) all_props.extend(props) diff --git a/torch_sim/state.py b/torch_sim/state.py index 813354fe..533af247 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -10,7 +10,7 @@ from collections import defaultdict from collections.abc import Generator, Sequence from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, ClassVar, Literal, Self +from typing import TYPE_CHECKING, Any, ClassVar, Literal, Self, Set, TypeVar import torch @@ -98,6 +98,10 @@ def pbc(self) -> torch.Tensor: """A getter for pbc that tells type checkers it's always defined.""" return self.pbc + _extra_atom_attributes: dict[str, torch.Tensor] = field(default_factory=dict) + _extra_system_attributes: dict[str, torch.Tensor] = field(default_factory=dict) + _extra_global_attributes: dict[str, Any] = field(default_factory=dict) + _atom_attributes: ClassVar[set[str]] = { "positions", "masses", @@ -204,7 +208,7 @@ def volume(self) -> torch.Tensor: def attributes(self) -> dict[str, torch.Tensor]: """Get all public attributes of the state.""" return { - attr: getattr(self, attr) + attr: self.get(attr) for attr in self._atom_attributes | self._system_attributes | self._global_attributes @@ -402,6 +406,74 @@ def __getitem__(self, system_indices: int | list[int] | slice | torch.Tensor) -> return _slice_state(self, system_indices) + def get(self, attribute_name: str) -> Any: + """Get the attribute of the state.""" + if ( + attribute_name + in self._atom_attributes | self._system_attributes | self._global_attributes + ): + return getattr(self, attribute_name) + + if attribute_name in self._extra_atom_attributes: + return self._extra_atom_attributes[attribute_name] + if attribute_name in self._extra_system_attributes: + return self._extra_system_attributes[attribute_name] + if attribute_name in self._extra_global_attributes: + return self._extra_global_attributes[attribute_name] + return None + + def get_strict(self, attribute_name: str) -> Any: + """Get the attribute of the state. + + Raises a ValueError if the attribute is not found. + """ + res = self.get(attribute_name) + if res is None: + raise ValueError(f"Attribute '{attribute_name}' not found in state") + return res + + def set( + self, + attribute_name: str, + value: Any, + kind: Literal["atom", "system", "global"] | None = None, + ) -> None: + """Set the attribute of the state.""" + # 1) Handle special cases for default attributes + all_default_attributes = ( + self._atom_attributes | self._system_attributes | self._global_attributes + ) + if attribute_name in all_default_attributes: + # no need to check kind since it's already a default attribute + setattr(self, attribute_name, value) + return + + # 2) validate the kind and value + if kind is None: + raise ValueError("Kind must be specified for extra attributes") + if kind in ("atom", "system") and not isinstance(value, torch.Tensor): + raise ValueError(f"Value for '{attribute_name}' must be a torch.Tensor") + + # 3) Write the value to the appropriate extra attribute + if kind == "atom": + if value.shape[0] != self.n_atoms: + raise ValueError( + f"Value for '{attribute_name}' must have shape (n_atoms, ...)" + ) + self._extra_atom_attributes[attribute_name] = value + elif kind == "system": + if value.shape[0] != self.n_systems: + raise ValueError( + f"Value for '{attribute_name}' must have shape (n_systems, ...)" + ) + self._extra_system_attributes[attribute_name] = value + elif kind == "global": + self._extra_global_attributes[attribute_name] = value + else: + raise ValueError( + f"Invalid kind: {kind}. Must be 'atom', 'system', or 'global'." + ) + def __init_subclass__(cls, **kwargs) -> None: """Enforce that all derived states cannot have tensor attributes that can also be None. This is because torch.concatenate cannot concat between a tensor and a None. @@ -617,7 +689,9 @@ def _state_to_device[T: SimState]( def get_attrs_for_scope( - state: SimState, scope: Literal["per-atom", "per-system", "global"] + state: SimState, + scope: Literal["per-atom", "per-system", "global"], + attribute_kind: Literal["only_default", "only_extra", "all"] = "all", ) -> Generator[tuple[str, Any], None, None]: """Get attributes for a given scope. @@ -629,17 +703,27 @@ def get_attrs_for_scope( Returns: Generator[tuple[str, Any], None, None]: A generator of attribute names and values """ + attr_names = set[str]() match scope: case "per-atom": - attr_names = state._atom_attributes # noqa: SLF001 + if attribute_kind in ["only_default", "all"]: + attr_names |= state._atom_attributes # noqa: SLF001 + if attribute_kind in ["only_extra", "all"]: + attr_names |= state._extra_atom_attributes.keys() # noqa: SLF001 case "per-system": - attr_names = state._system_attributes # noqa: SLF001 + if attribute_kind in ["only_default", "all"]: + attr_names |= state._system_attributes # noqa: SLF001 + if attribute_kind in ["only_extra", "all"]: + attr_names |= state._extra_system_attributes.keys() # noqa: SLF001 case "global": - attr_names = state._global_attributes # noqa: SLF001 + if attribute_kind in ["only_default", "all"]: + attr_names |= state._global_attributes # noqa: SLF001 + if attribute_kind in ["only_extra", "all"]: + attr_names |= state._extra_global_attributes.keys() # noqa: SLF001 case _: raise ValueError(f"Unknown scope: {scope!r}") for attr_name in attr_names: - yield attr_name, getattr(state, attr_name) + yield attr_name, state.get(attr_name) def _filter_attrs_by_mask( @@ -870,15 +954,31 @@ def concatenate_states[T: SimState]( # noqa: C901 if not all(isinstance(state, state_class) for state in states): raise TypeError("All states must be of the same type") + # ensure all states have the same extra attributes + first_state_attribute_names = extra_attribute_names(first_state) + if not all( + extra_attribute_names(state) == first_state_attribute_names for state in states + ): + raise ValueError( + "All states must have the same extra attributes. Currently, the first state " + "has these extra attributes: {first_state_attribute_names}" + ) + # Use the target device or default to the first state's device target_device = device or first_state.device # Initialize result with global properties from first state - concatenated = dict(get_attrs_for_scope(first_state, "global")) + concatenated = dict(get_attrs_for_scope(first_state, "global", "only_default")) + concatenated_global_extra = dict( + get_attrs_for_scope(first_state, "global", "only_extra") + ) # Pre-allocate lists for tensors to concatenate per_atom_tensors = defaultdict(list) per_system_tensors = defaultdict(list) + per_atom_tensors_extra = defaultdict[str, list[torch.Tensor]](list) + per_system_tensors_extra = defaultdict[str, list[torch.Tensor]](list) + new_system_indices = [] system_offset = 0 @@ -889,16 +989,30 @@ def concatenate_states[T: SimState]( # noqa: C901 state = state.to(target_device) # Collect per-atom properties - for prop, val in get_attrs_for_scope(state, "per-atom"): + for prop, val in get_attrs_for_scope( + state, "per-atom", attribute_kind="only_default" + ): if prop == "system_idx": # skip system_idx, it will be handled below continue per_atom_tensors[prop].append(val) + for prop, val in get_attrs_for_scope( + state, "per-atom", attribute_kind="only_extra" + ): + per_atom_tensors_extra[prop].append(val) + # Collect per-system properties - for prop, val in get_attrs_for_scope(state, "per-system"): + for prop, val in get_attrs_for_scope( + state, "per-system", attribute_kind="only_default" + ): per_system_tensors[prop].append(val) + for prop, val in get_attrs_for_scope( + state, "per-system", attribute_kind="only_extra" + ): + per_system_tensors_extra[prop].append(val) + # Update system indices num_systems = state.n_systems new_indices = state.system_idx + system_offset @@ -907,21 +1021,53 @@ def concatenate_states[T: SimState]( # noqa: C901 # Concatenate collected tensors for prop, tensors in per_atom_tensors.items(): - # if tensors: concatenated[prop] = torch.cat(tensors, dim=0) for prop, tensors in per_system_tensors.items(): - # if tensors: if isinstance(tensors[0], torch.Tensor): concatenated[prop] = torch.cat(tensors, dim=0) else: # Non-tensor attributes, take first one (they should all be identical) concatenated[prop] = tensors[0] + # concatenate the extra attributes + concatenated_per_atom_extra = dict( + get_attrs_for_scope(first_state, "per-atom", "only_extra") + ) + concatenated_per_system_extra = dict( + get_attrs_for_scope(first_state, "per-system", "only_extra") + ) + for prop, tensors in per_atom_tensors_extra.items(): + concatenated_per_atom_extra[prop] = torch.cat(tensors, dim=0) + for prop, tensors in per_system_tensors_extra.items(): + if isinstance(tensors[0], torch.Tensor): + concatenated_per_system_extra[prop] = torch.cat(tensors, dim=0) + else: # Non-tensor attributes, take first one (they should all be identical) + concatenated_per_system_extra[prop] = tensors[0] + # Concatenate system indices concatenated["system_idx"] = torch.cat(new_system_indices) # Create a new instance of the same class - return state_class(**concatenated) + new_state = state_class(**concatenated) + + # Add the extra attributes (since these attributes are not in the class' constructor) + for prop, value in concatenated_per_atom_extra.items(): + new_state.set(prop, value, kind="atom") + for prop, value in concatenated_per_system_extra.items(): + new_state.set(prop, value, kind="system") + for prop, value in concatenated_global_extra.items(): + new_state.set(prop, value, kind="global") + + return new_state + + +def extra_attribute_names(state: SimState) -> set[str]: + """Get the names of the extra attributes of the state.""" + return ( + state._extra_atom_attributes.keys() # noqa: SLF001 + | state._extra_system_attributes.keys() # noqa: SLF001 + | state._extra_global_attributes.keys() # noqa: SLF001 + ) def initialize_state(