Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
220 changes: 197 additions & 23 deletions tests/test_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,49 @@
from pymatgen.core import Structure


def test_get_attrs_for_scope(si_sim_state: SimState) -> None:
@pytest.fixture
def si_sim_state_extra(si_sim_state: SimState) -> SimState:
"""Create a basic state from si_structure with extra attributes."""
si_sim_state.set(
"charge",
torch.tensor([3.0] * si_sim_state.n_atoms, device=si_sim_state.device),
"atom",
)
si_sim_state.set(
"energy",
torch.tensor([3.0], device=si_sim_state.device),
"system",
)
si_sim_state.set(
"max_steps",
100,
"global",
)
return si_sim_state


@pytest.fixture
def si_double_sim_state_extra(si_sim_state_extra: SimState) -> SimState:
return ts.concatenate_states(
[si_sim_state_extra, si_sim_state_extra],
device=si_sim_state_extra.device,
)


def test_get_attrs_for_scope(si_sim_state_extra: SimState) -> None:
"""Test getting attributes for a scope."""
per_atom_attrs = dict(get_attrs_for_scope(si_sim_state, "per-atom"))
assert set(per_atom_attrs) == {"positions", "masses", "atomic_numbers", "system_idx"}
per_system_attrs = dict(get_attrs_for_scope(si_sim_state, "per-system"))
assert set(per_system_attrs) == {"cell"}
global_attrs = dict(get_attrs_for_scope(si_sim_state, "global"))
assert set(global_attrs) == {"pbc"}
per_atom_attrs = dict(get_attrs_for_scope(si_sim_state_extra, "per-atom"))
assert set(per_atom_attrs) == {
"positions",
"masses",
"atomic_numbers",
"system_idx",
"charge",
}
per_system_attrs = dict(get_attrs_for_scope(si_sim_state_extra, "per-system"))
assert set(per_system_attrs) == {"cell", "energy"}
global_attrs = dict(get_attrs_for_scope(si_sim_state_extra, "global"))
assert set(global_attrs) == {"pbc", "max_steps"}


def test_all_attributes_must_be_specified_in_scopes() -> None:
Expand Down Expand Up @@ -66,19 +101,32 @@ class ChildState(SimState):
assert "duplicated_attribute" in str(exc_info.value)


def test_slice_substate(si_double_sim_state: SimState, si_sim_state: SimState) -> None:
def test_slice_substate(
si_double_sim_state_extra: SimState, si_sim_state_extra: SimState
) -> None:
"""Test slicing a substate from the SimState."""
for system_index in range(2):
substate = _slice_state(si_double_sim_state, [system_index])
substate = _slice_state(si_double_sim_state_extra, [system_index])
assert isinstance(substate, SimState)
assert substate.positions.shape == (8, 3)
assert substate.masses.shape == (8,)
assert substate.cell.shape == (1, 3, 3)
assert torch.allclose(substate.positions, si_sim_state.positions)
assert torch.allclose(substate.masses, si_sim_state.masses)
assert torch.allclose(substate.cell, si_sim_state.cell)
assert torch.allclose(substate.atomic_numbers, si_sim_state.atomic_numbers)
assert substate.get_strict("charge").shape == (8,)
assert substate.get_strict("energy").shape == (1,)
assert torch.allclose(substate.positions, si_sim_state_extra.positions)
assert torch.allclose(substate.masses, si_sim_state_extra.masses)
assert torch.allclose(substate.cell, si_sim_state_extra.cell)
assert torch.allclose(substate.atomic_numbers, si_sim_state_extra.atomic_numbers)
assert torch.allclose(substate.system_idx, torch.zeros_like(substate.system_idx))
assert torch.allclose(
substate.get_strict("charge"), si_sim_state_extra.get_strict("charge")
)
assert torch.allclose(
substate.get_strict("energy"), si_sim_state_extra.get_strict("energy")
)
assert substate.get_strict("max_steps") == si_sim_state_extra.get_strict(
"max_steps"
)


def test_slice_md_substate(si_double_sim_state: SimState) -> None:
Expand All @@ -100,19 +148,30 @@ def test_slice_md_substate(si_double_sim_state: SimState) -> None:


def test_concatenate_two_si_states(
si_sim_state: SimState, si_double_sim_state: SimState
si_sim_state_extra: SimState, si_double_sim_state_extra: SimState
) -> None:
"""Test concatenating two identical silicon states."""
# Concatenate two copies of the sim state
concatenated = ts.concatenate_states([si_sim_state, si_sim_state])
concatenated = ts.concatenate_states([si_sim_state_extra, si_sim_state_extra])

# Check that the result is the same as the double state
assert isinstance(concatenated, SimState)
assert concatenated.positions.shape == si_double_sim_state.positions.shape
assert concatenated.masses.shape == si_double_sim_state.masses.shape
assert concatenated.cell.shape == si_double_sim_state.cell.shape
assert concatenated.atomic_numbers.shape == si_double_sim_state.atomic_numbers.shape
assert concatenated.system_idx.shape == si_double_sim_state.system_idx.shape
assert concatenated.positions.shape == si_double_sim_state_extra.positions.shape
assert concatenated.masses.shape == si_double_sim_state_extra.masses.shape
assert concatenated.cell.shape == si_double_sim_state_extra.cell.shape
assert (
concatenated.atomic_numbers.shape
== si_double_sim_state_extra.atomic_numbers.shape
)
assert concatenated.system_idx.shape == si_double_sim_state_extra.system_idx.shape
assert (
concatenated.get_strict("charge").shape
== si_double_sim_state_extra.get_strict("charge").shape
)
assert (
concatenated.get_strict("energy").shape
== si_double_sim_state_extra.get_strict("energy").shape
)

# Check system indices
tensor_args = dict(dtype=torch.int64, device=si_sim_state.device)
Expand All @@ -133,6 +192,17 @@ def test_concatenate_two_si_states(
si_double_sim_state.positions[mask_double],
)

# check that the extra attributes are concatenated correctly
assert torch.allclose(
concatenated.get_strict("charge"), si_double_sim_state_extra.get_strict("charge")
)
assert torch.allclose(
concatenated.get_strict("energy"), si_double_sim_state_extra.get_strict("energy")
)
assert concatenated.get_strict("max_steps") == si_double_sim_state_extra.get_strict(
"max_steps"
)


def test_concatenate_si_and_fe_states(
si_sim_state: SimState, fe_supercell_sim_state: SimState
Expand Down Expand Up @@ -228,16 +298,26 @@ def test_concatenate_double_si_and_fe_states(
assert torch.allclose(fe_slice.positions, fe_supercell_sim_state.positions)


def test_split_state(si_double_sim_state: SimState) -> None:
def test_concatenate_states_with_inconsistent_extra_attributes(
si_sim_state: SimState, si_sim_state_extra: SimState
) -> None:
"""We should only be able to concat states with the same extra attributes."""
with pytest.raises(ValueError):
ts.concatenate_states([si_sim_state, si_sim_state_extra])


def test_split_state(si_double_sim_state_extra: SimState) -> None:
"""Test splitting a state into a list of states."""
states = si_double_sim_state.split()
assert len(states) == si_double_sim_state.n_systems
states = si_double_sim_state_extra.split()
assert len(states) == si_double_sim_state_extra.n_systems
for state in states:
assert isinstance(state, SimState)
assert state.positions.shape == (8, 3)
assert state.masses.shape == (8,)
assert state.cell.shape == (1, 3, 3)
assert state.atomic_numbers.shape == (8,)
assert state.get_strict("charge").shape == (8,)
assert state.get_strict("energy").shape == (1,)
assert torch.allclose(state.system_idx, torch.zeros_like(state.system_idx))


Expand Down Expand Up @@ -287,6 +367,31 @@ def test_pop_states(
assert kept_state.system_idx.shape == (len_kept,)


def test_pop_states_with_extra_attributes(si_double_sim_state_extra: SimState) -> None:
"""Test popping states with extra attributes."""
kept_state, popped_states = _pop_states(
si_double_sim_state_extra,
torch.tensor([0], device=si_double_sim_state_extra.device),
)
assert isinstance(kept_state, SimState)
assert isinstance(popped_states, list)
assert len(popped_states) == 1
assert isinstance(popped_states[0], SimState)
assert popped_states[0].positions.shape == si_double_sim_state_extra.positions.shape
assert popped_states[0].get_strict("charge").shape == (8,)
assert popped_states[0].get_strict("energy").shape == (1,)
assert popped_states[0].get_strict(
"max_steps"
) == si_double_sim_state_extra.get_strict("max_steps")

assert kept_state.positions.shape == (8, 3)
assert kept_state.get_strict("charge").shape == (8,)
assert kept_state.get_strict("energy").shape == (1,)
assert kept_state.get_strict("max_steps") == si_double_sim_state_extra.get_strict(
"max_steps"
)


def test_initialize_state_from_structure(si_structure: "Structure") -> None:
"""Test conversion from pymatgen Structure to state tensors."""
state = ts.initialize_state([si_structure], DEVICE, torch.float64)
Expand Down Expand Up @@ -650,3 +755,72 @@ def test_state_to_device_no_side_effects(si_sim_state: SimState) -> None:
"New state doesn't have correct device!"
)
assert si_sim_state is not new_state_gpu, "New state is not a different object!"


def test_state_clone(si_sim_state_extra: SimState) -> None:
"""Test the clone method of SimState."""
cloned_state = si_sim_state_extra.clone()
assert isinstance(cloned_state, SimState)
assert cloned_state is not si_sim_state_extra

for attr in si_sim_state_extra.attributes:
attr_cloned = getattr(cloned_state, attr)
attr_original = getattr(si_sim_state_extra, attr)
if isinstance(attr_cloned, torch.Tensor):
assert torch.allclose(attr_cloned, attr_original)
else:
assert attr_cloned == attr_original
assert attr_cloned is not attr_original


def test_state_get_attribute_with_extra_attributes(si_sim_state_extra: SimState) -> None:
"""Test the __getitem__ method of SimState with extra attributes."""
state = si_sim_state_extra
assert torch.equal(state.get_strict("charge"), state._extra_atom_attributes["charge"])
assert torch.equal(
state.get_strict("energy"), state._extra_system_attributes["energy"]
)
assert state.get_strict("max_steps") == state._extra_global_attributes["max_steps"]


def test_state_set_attribute_with_extra_attributes(
si_sim_state: SimState, si_sim_state_extra: SimState
) -> None:
"""Test the __setitem__ method of SimState with extra attributes."""
target_state = si_sim_state_extra
si_sim_state.set("charge", si_sim_state_extra.get_strict("charge"), "atom")
si_sim_state.set("energy", si_sim_state_extra.get_strict("energy"), "system")
si_sim_state.set("max_steps", si_sim_state_extra.get_strict("max_steps"), "global")

assert torch.equal(
si_sim_state.get_strict("charge"), target_state.get_strict("charge")
)
assert torch.equal(
si_sim_state.get_strict("energy"), target_state.get_strict("energy")
)
assert si_sim_state.get_strict("max_steps") == target_state.get_strict("max_steps")


def test_state_set_attribute_must_specify_kind(
si_sim_state: SimState, si_sim_state_extra: SimState
) -> None:
with pytest.raises(ValueError) as exc_info:
si_sim_state.set("charge", torch.randn(si_sim_state.n_atoms, 1))
assert "Kind must be specified for extra attributes" in str(exc_info.value)


def test_state_set_attribute_with_default_attributes(si_sim_state: SimState) -> None:
"""Test the __setitem__ method of SimState with default attributes."""
new_positions = torch.randn(si_sim_state.n_atoms, 3)
new_cell = torch.randn(si_sim_state.n_systems, 3, 3)
new_pbc = torch.tensor([True, False, True])

si_sim_state.set("positions", new_positions)
si_sim_state.set("cell", new_cell)

# test that we can optionally specify a kind
si_sim_state.set("pbc", new_pbc, kind="global")

assert torch.allclose(si_sim_state.positions, new_positions)
assert torch.allclose(si_sim_state.cell, new_cell)
assert torch.allclose(si_sim_state.pbc, new_pbc)
Loading
Loading