Skip to content
Open
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ jobs:
- { python: '3.12', resolution: highest }
- { python: '3.13', resolution: lowest-direct }
model:
- { name: chgnet, test_path: "tests/models/test_chgnet.py" }
- { name: fairchem, test_path: "tests/models/test_fairchem.py" }
- { name: fairchem-legacy, test_path: "tests/models/test_fairchem_legacy.py" }
- { name: graphpes, test_path: "tests/models/test_graphpes.py" }
Expand Down
1 change: 1 addition & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
]

autodoc_mock_imports = [
"chgnet",
"fairchem",
"mace",
"mattersim",
Expand Down
149 changes: 149 additions & 0 deletions examples/scripts/1_Introduction/1.4_CHGNet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
"""CHGNet model example for TorchSim."""

# /// script
# dependencies = ["chgnet>=0.4.2", "mace-torch>=0.3.12"]
# ///

import os
import warnings

import torch
from ase import Atoms
from ase.build import bulk
from mace.calculators.foundations_models import mace_mp

import torch_sim as ts
from torch_sim.models.chgnet import CHGNetModel
from torch_sim.models.mace import MaceModel, MaceUrls


# Silence warnings
warnings.filterwarnings("ignore")
os.environ["TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD"] = "1"

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.float32

print("CHGNet Example for TorchSim")
print("=" * 40)

# Create CHGNet model
model = CHGNetModel(
device=device,
dtype=dtype,
compute_forces=True,
compute_stress=True,
)

# Create test systems
al_atoms = bulk("Al", "fcc", a=4.05, cubic=True)
c_atoms = bulk("C", "diamond", a=3.57, cubic=True)
mg_atoms = bulk("Mg", "hcp", a=3.21, c=5.21)
a_perovskite = 3.84
ca_tio3_atoms = Atoms(
["Ca", "Ti", "O", "O", "O"],
positions=[
[0, 0, 0],
[a_perovskite / 2, a_perovskite / 2, a_perovskite / 2],
[a_perovskite / 2, 0, 0],
[0, a_perovskite / 2, 0],
[0, 0, a_perovskite / 2],
],
cell=[a_perovskite, a_perovskite, a_perovskite],
pbc=True,
)

# Convert to TorchSim state
state = ts.io.atoms_to_state([al_atoms, c_atoms, mg_atoms], device, dtype)

# Load MACE model for comparison
raw_mace_mp = mace_mp(model=MaceUrls.mace_mp_small, return_raw_model=True)
mace_model = MaceModel(
model=raw_mace_mp,
device=device,
dtype=dtype,
compute_forces=True,
compute_stress=True,
)
mace_available = True

# Single comprehensive results table
print(
"\nCHGNet vs MACE Results "
"(E: Total Energy, F: Maximum Force, S: Maximum Stress, M: Maximum Magnetic Moment)"
)
print("=" * 87)
print(
f"{'System':<10} {'CHGNet E':<12} {'CHGNet F':<12} {'CHGNet S':<12} "
f"{'CHGNet M':<12} {'MACE E':<12} {'MACE F':<12}"
)
print("-" * 87)

# Test equilibrium structures
for i, system_name in enumerate(["Al", "C", "Mg"]):
single_state = ts.io.atoms_to_state([[al_atoms, c_atoms, mg_atoms][i]], device, dtype)

# CHGNet results
chgnet_result = model.forward(single_state)
chgnet_energy = chgnet_result["energy"].item()
chgnet_force = torch.norm(chgnet_result["forces"], dim=1).max().item()
chgnet_stress = torch.norm(chgnet_result["stress"], dim=(1, 2)).max().item()
chgnet_magmom = (
torch.norm(chgnet_result.get("magnetic_moments", torch.zeros(1, 3)), dim=-1)
.max()
.item()
)

# MACE results
mace_result = mace_model.forward(single_state)
mace_energy = mace_result["energy"].item()
mace_force = torch.norm(mace_result["forces"], dim=1).max().item()
print(
f"{system_name:<10} {chgnet_energy:<12.3f} {chgnet_force:<12.3f} "
f"{chgnet_stress:<12.3f} {chgnet_magmom:<12.3f} {mace_energy:<12.3f} "
f"{mace_force:<12.3f}"
)

# Test optimization on displaced structures
for atoms, system_name in zip(
[al_atoms, c_atoms, ca_tio3_atoms], ["Al", "C", "CaTiO3"], strict=False
):
single_state = ts.io.atoms_to_state([atoms], device, dtype)
displacement = torch.randn_like(single_state.positions) * 0.1
displaced_state = single_state.clone()
displaced_state.positions = single_state.positions + displacement

# CHGNet optimization
chgnet_optimized = ts.optimize(
displaced_state, model, optimizer=ts.optimizers.Optimizer.fire, max_steps=100
)
chgnet_final = model.forward(chgnet_optimized)
chgnet_final_energy = chgnet_final["energy"].item()
chgnet_final_force = torch.norm(chgnet_final["forces"], dim=1).max().item()
chgnet_final_stress = torch.norm(chgnet_final["stress"], dim=(1, 2)).max().item()
chgnet_final_magmom = (
torch.norm(chgnet_final.get("magnetic_moments", torch.zeros(1, 3)), dim=-1)
.max()
.item()
)

# MACE optimization
mace_optimized = ts.optimize(
displaced_state,
mace_model,
optimizer=ts.optimizers.Optimizer.fire,
max_steps=100,
)
mace_final = mace_model.forward(mace_optimized)
mace_final_energy = mace_final["energy"].item()
mace_final_force = torch.norm(mace_final["forces"], dim=1).max().item()
print(
f"{system_name + '_opt':<10} {chgnet_final_energy:<12.3f} "
f"{chgnet_final_force:<12.3f} {chgnet_final_stress:<12.3f} "
f"{chgnet_final_magmom:<12.3f} {mace_final_energy:<12.3f} "
f"{mace_final_force:<12.3f}"
)

print("=" * 87)
print("CHGNet example completed successfully!")
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ sevenn = ["sevenn>=0.11.0"]
graphpes = ["graph-pes>=0.1", "mace-torch>=0.3.12"]
nequip = ["nequip>=0.12.0"]
fairchem = ["fairchem-core>=2.7"]
chgnet = ["chgnet>=0.4.2"]
docs = [
"autodoc_pydantic==2.2.0",
"furo==2024.8.6",
Expand Down
125 changes: 125 additions & 0 deletions tests/models/test_chgnet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
import traceback
from typing import Any, ClassVar

import pytest
import torch
from ase.atoms import Atoms
from ase.calculators.calculator import Calculator, all_changes

from tests.conftest import DEVICE
from tests.models.conftest import (
make_model_calculator_consistency_test,
make_validate_model_outputs_test,
)


try:
from chgnet.model.model import CHGNet

from torch_sim.models.chgnet import CHGNetModel
except (ImportError, ValueError):
pytest.skip(
f"CHGNet not installed: {traceback.format_exc()}", allow_module_level=True
)


class CHGNetCalculator(Calculator):
"""ASE Calculator wrapper for CHGNet."""

implemented_properties: ClassVar[list[str]] = ["energy", "forces", "stress"]

def __init__(self, model: CHGNet | None = None, **kwargs) -> None:
Calculator.__init__(self, **kwargs)
self.model = model or CHGNet.load()

def calculate(
self,
atoms: Atoms | None = None,
properties: list[str] | None = None,
system_changes: Any = all_changes,
):
if properties is None:
properties = ["energy"]
Calculator.calculate(self, atoms, properties, system_changes)

# Convert ASE atoms to pymatgen Structure
from pymatgen.io.ase import AseAtomsAdaptor

structure = AseAtomsAdaptor.get_structure(atoms)

# Get CHGNet predictions
result = self.model.predict_structure(structure)

# Convert to ASE format
self.results = {}
if "energy" in properties:
# CHGNet returns energy per atom, convert to total energy
self.results["energy"] = result["e"] * len(structure)

if "forces" in properties:
self.results["forces"] = result["f"]

if "stress" in properties:
self.results["stress"] = result["s"]


DTYPE = torch.float32


@pytest.fixture
def ts_chgnet_model() -> CHGNetModel:
"""Create a TorchSim CHGNet model for testing."""
return CHGNetModel(
device=DEVICE,
dtype=DTYPE,
compute_forces=True,
compute_stress=True,
)


@pytest.fixture
def ase_chgnet_calculator(ts_chgnet_model: CHGNetModel) -> CHGNetCalculator:
"""Create an ASE CHGNet calculator for testing."""
# Use the same model instance to ensure consistency
return CHGNetCalculator(model=ts_chgnet_model.model)


def test_chgnet_missing_atomic_numbers() -> None:
"""Test that CHGNet raises appropriate error when atomic numbers are missing."""
model = CHGNetModel(
device=DEVICE,
dtype=DTYPE,
compute_forces=True,
compute_stress=True,
)

# Create state without atomic numbers by using a state dict
state_dict = {
"positions": torch.randn(8, 3, device=DEVICE, dtype=DTYPE),
"cell": torch.eye(3, device=DEVICE, dtype=DTYPE).unsqueeze(0),
"pbc": True,
"atomic_numbers": None, # Missing atomic numbers
"system_idx": torch.zeros(8, dtype=torch.long, device=DEVICE),
}

with pytest.raises(ValueError, match="Atomic numbers must be provided"):
model.forward(state_dict)


test_chgnet_model_outputs = make_validate_model_outputs_test(
model_fixture_name="ts_chgnet_model", dtype=DTYPE
)

test_chgnet_consistency = make_model_calculator_consistency_test(
test_name="chgnet",
model_fixture_name="ts_chgnet_model",
calculator_fixture_name="ase_chgnet_calculator",
sim_state_names=("si_sim_state", "cu_sim_state", "mg_sim_state", "ti_sim_state"),
dtype=DTYPE,
energy_rtol=1e-4,
energy_atol=1e-4,
force_rtol=1e-4,
force_atol=1e-4,
stress_rtol=1e-3,
stress_atol=1e-3,
)
Loading
Loading