diff --git a/.gitignore b/.gitignore
index 7fd27471..e9579d19 100644
--- a/.gitignore
+++ b/.gitignore
@@ -18,6 +18,7 @@ CMakeFiles
 build
 dist
 dpdata.egg-info
+*.db
 _version.py
 !tests/cp2k/aimd/cp2k.log
 !tests/cp2k/restart_aimd/ch4.log
diff --git a/dpdata/plugins/schnetpack.py b/dpdata/plugins/schnetpack.py
new file mode 100644
index 00000000..e72df831
--- /dev/null
+++ b/dpdata/plugins/schnetpack.py
@@ -0,0 +1,131 @@
+from __future__ import annotations
+
+import numpy as np
+
+from dpdata.format import Format
+
+
+@Format.register("schnetpack")
+class SchNetPackFormat(Format):
+    """Format for SchNetPack-compatible ASE database.
+
+    SchNetPack uses ASE database format internally for storing atomic structures
+    and their properties. This format converts dpdata LabeledSystem to
+    SchNetPack-compatible ASE database format using only ASE functionality.
+
+    The created database can be used directly with SchNetPack for training
+    machine learning models.
+
+    For more information, see:
+    https://schnetpack.readthedocs.io/en/latest/tutorials/tutorial_01_preparing_data.html
+    """
+
+    def to_labeled_system(
+        self,
+        data: dict,
+        file_name: str = "schnetpack_data.db",
+        distance_unit: str = "Ang",
+        property_unit_dict: dict | None = None,
+        **kwargs,
+    ) -> None:
+        """Convert dpdata LabeledSystem to SchNetPack-compatible ASE database format.
+
+        One database row is written per frame. Energies and forces are attached
+        to each row through a ``SinglePointCalculator`` and are also copied,
+        together with virials and the unit dictionary, into the row's ``data``.
+
+        Parameters
+        ----------
+        data : dict
+            dpdata LabeledSystem data dictionary
+        file_name : str, optional
+            Path to the output database file, by default "schnetpack_data.db"
+        distance_unit : str, optional
+            Unit for distances, by default "Ang". Recorded in the database
+            metadata; coordinates are written unchanged.
+        property_unit_dict : dict, optional
+            Dictionary mapping property names to their units.
+            If None, defaults to {"energy": "eV", "forces": "eV/Ang"}.
+            The units are only recorded; no values are converted.
+        **kwargs : dict
+            Additional keyword arguments (ignored)
+
+        Raises
+        ------
+        ImportError
+            If ASE is not available
+        """
+        try:
+            from ase import Atoms
+            from ase.calculators.singlepoint import SinglePointCalculator
+            from ase.db import connect
+        except ImportError as e:
+            raise ImportError(
+                "ASE is required for schnetpack format. Install with: pip install ase"
+            ) from e
+
+        if property_unit_dict is None:
+            property_unit_dict = {"energy": "eV", "forces": "eV/Ang"}
+
+        # append=False asks ASE for a fresh database rather than adding rows to
+        # an existing file.
+        db = connect(file_name, append=False)
+
+        # Database-level metadata. "property_units" is this writer's own key;
+        # the underscore-prefixed keys are the names SchNetPack's ASEAtomsData
+        # is understood to look up -- TODO confirm against the SchNetPack
+        # version in use.
+        metadata = {
+            "property_units": property_unit_dict,
+            "_property_unit_dict": property_unit_dict,
+            "_distance_unit": distance_unit,
+        }
+        try:
+            db.metadata = metadata
+        except (AttributeError, NotImplementedError):
+            # Best effort: the property units are also stored per row below.
+            pass
+
+        species = [data["atom_names"][tt] for tt in data["atom_types"]]
+        pbc = not data.get("nopbc", False)
+
+        # np.asarray accepts nested lists as well as existing arrays.
+        coords = np.asarray(data["coords"])
+        cells = np.asarray(data["cells"])
+        energies = np.asarray(data["energies"]) if "energies" in data else None
+        forces = np.asarray(data["forces"]) if "forces" in data else None
+        virials = np.asarray(data["virials"]) if "virials" in data else None
+
+        for frame_idx in range(coords.shape[0]):
+            atoms = Atoms(
+                symbols=species,
+                positions=coords[frame_idx],
+                pbc=pbc,
+                cell=cells[frame_idx],
+            )
+
+            # calc_properties feeds the calculator, so get_potential_energy()
+            # and get_forces() work on atoms read back from the row. db_data
+            # is the per-row data blob; SchNetPack is understood to read its
+            # training targets from there -- TODO confirm.
+            calc_properties = {}
+            db_data = {}
+            if energies is not None:
+                energy = float(energies[frame_idx])
+                calc_properties["energy"] = energy
+                db_data["energy"] = np.array([energy])
+            if forces is not None:
+                calc_properties["forces"] = forces[frame_idx]
+                db_data["forces"] = forces[frame_idx]
+            if virials is not None:
+                db_data["virials"] = virials[frame_idx]
+            if property_unit_dict:
+                db_data["property_units"] = property_unit_dict
+
+            if calc_properties:
+                atoms.calc = SinglePointCalculator(atoms, **calc_properties)
+
+            # Energy and forces are deliberately not passed to db.write() as
+            # keyword arguments: ASE reserves those names and rejects them as
+            # key-value pairs, so such a call could never succeed.
+            db.write(atoms, data=db_data)
diff --git a/tests/test_schnetpack.py b/tests/test_schnetpack.py
new file mode 100644
index 00000000..67ea642c
--- /dev/null
+++ b/tests/test_schnetpack.py
@@ -0,0 +1,365 @@
+from __future__ import annotations
+
+import os
+import tempfile
+import unittest
+from unittest.mock import patch
+
+import numpy as np
+from context import dpdata
+
+try:
+    import ase.db  # noqa: F401
+
+    ase_available = True
+except ImportError:
+    ase_available = False
+
+
+@unittest.skipIf(not ase_available, "skip test_schnetpack")
+class TestSchNetPackRegistration(unittest.TestCase):
+    """Test SchNetPack format registration and error handling."""
+
+    def test_format_registered(self):
+        """Test that schnetpack format is properly registered."""
+        test_system = dpdata.LabeledSystem()
+        test_system.data = {
+            "atom_names": ["H"],
+            "atom_numbs": [1],
+            "atom_types": [0],
+            "cells": [[[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]]],
+            "coords": [[[0.0, 0.0, 0.0]]],
+            "orig": [0.0, 0.0, 0.0],
+            "energies": [1.0],
+            "forces": [[[0.0, 0.0, 0.0]]],
+        }
+
+        # Writing through the registered format name needs ASE, hence the
+        # skipIf guard on this class.
+        with tempfile.TemporaryDirectory() as tmpdir:
+            db_file = os.path.join(tmpdir, "test_registered.db")
+            test_system.to("schnetpack", db_file)
+
+            # Verify the database was created
+            self.assertTrue(os.path.exists(db_file))
+
+
+@unittest.skipIf(not ase_available, "skip test_schnetpack")
+class TestSchNetPack(unittest.TestCase):
+    def setUp(self):
+        # Create a simple test system
+        self.test_system = dpdata.System()
+
+        # Create simple water-like structure for testing
+        # 3 atoms: O, H, H
+        self.test_system.data = {
+            "atom_names": ["O", "H"],
+            "atom_numbs": [1, 2],
+            "atom_types": [0, 1, 1],  # O, H, H
+            "cells": [[[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]]],
+            "coords": [[[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]],
+            "orig": [0.0, 0.0, 0.0],
+        }
+
+        # Create labeled system with dummy energies and forces
+        self.labeled_system = dpdata.LabeledSystem()
+        self.labeled_system.data = self.test_system.data.copy()
+        self.labeled_system.data["energies"] = [-10.5]  # eV
+        self.labeled_system.data["forces"] = [
+            [[0.1, 0.0, 0.0], [0.0, 0.1, 0.0], [0.0, 0.0, 0.1]]
+        ]  # eV/Ang
+
+        # Optional: add virials
+        self.labeled_system.data["virials"] = [
+            [[0.01, 0.0, 0.0], [0.0, 0.01, 0.0], [0.0, 0.0, 0.01]]
+        ]
+
+    def test_to_schnetpack(self):
+        """Test conversion to SchNetPack-compatible ASE database format."""
+        from ase.db import connect
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            db_file = os.path.join(tmpdir, "test_data.db")
+
+            # Convert to SchNetPack format
+            self.labeled_system.to("schnetpack", db_file)
+
+            # Verify the database was created
+            self.assertTrue(os.path.exists(db_file))
+
+            # Load the database and verify contents using ASE
+            db = connect(db_file)
+
+            # Check number of structures
+            self.assertEqual(len(db), 1)
+
+            # Get the structure back
+            row = db.get(1)
+            atoms = row.toatoms()
+
+            # Check basic structure information
+            self.assertEqual(len(atoms), 3)  # O, H, H
+            self.assertEqual(atoms.get_chemical_symbols(), ["O", "H", "H"])
+
+            # Check that properties are present and accessible
+            self.assertTrue(hasattr(atoms, "calc") and atoms.calc is not None)
+
+            # Check energy value
+            self.assertAlmostEqual(atoms.get_potential_energy(), -10.5, places=5)
+
+            # Check forces shape and values
+            forces = atoms.get_forces()
+            self.assertEqual(forces.shape, (3, 3))  # 3 atoms, 3 components
+
+    def test_to_schnetpack_custom_units(self):
+        """Test conversion with custom units."""
+        from ase.db import connect
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            db_file = os.path.join(tmpdir, "test_data_units.db")
+
+            # Convert with custom units
+            property_units = {"energy": "kcal/mol", "forces": "kcal/mol/Ang"}
+
+            self.labeled_system.to(
+                "schnetpack", db_file, property_unit_dict=property_units
+            )
+
+            # Verify the database was created
+            self.assertTrue(os.path.exists(db_file))
+
+            # Load and verify using ASE
+            db = connect(db_file)
+            self.assertEqual(len(db), 1)
+
+            # Basic verification that data is present
+            row = db.get(1)
+            atoms = row.toatoms()
+            self.assertTrue(hasattr(atoms, "calc") and atoms.calc is not None)
+
+            # Check that units were stored in the row data
+            self.assertIn("property_units", row.data)
+            self.assertEqual(row.data["property_units"], property_units)
+
+            # The distance unit (default "Ang") is recorded in the metadata
+            self.assertEqual(db.metadata.get("_distance_unit"), "Ang")
+
+    def test_to_schnetpack_without_virials(self):
+        """Test conversion without virials."""
+        from ase.db import connect
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            db_file = os.path.join(tmpdir, "test_no_virials.db")
+
+            # Create system without virials
+            system_no_virials = dpdata.LabeledSystem()
+            system_no_virials.data = self.test_system.data.copy()
+            system_no_virials.data["energies"] = [-10.5]
+            system_no_virials.data["forces"] = [
+                [[0.1, 0.0, 0.0], [0.0, 0.1, 0.0], [0.0, 0.0, 0.1]]
+            ]
+
+            # Convert to SchNetPack format
+            system_no_virials.to("schnetpack", db_file)
+
+            # Verify the database was created
+            self.assertTrue(os.path.exists(db_file))
+
+            # Load and verify using ASE
+            db = connect(db_file)
+            self.assertEqual(len(db), 1)
+
+            row = db.get(1)
+            atoms = row.toatoms()
+            self.assertTrue(hasattr(atoms, "calc") and atoms.calc is not None)
+            # virials should not be present in row.data
+            self.assertNotIn("virials", row.data)
+
+    def test_multiframe_system(self):
+        """Test conversion of multi-frame system."""
+        from ase.db import connect
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            db_file = os.path.join(tmpdir, "test_multiframe.db")
+
+            # Create multi-frame system
+            multiframe_system = dpdata.LabeledSystem()
+            multiframe_system.data = {
+                "atom_names": ["O", "H"],
+                "atom_numbs": [1, 2],
+                "atom_types": [0, 1, 1],
+                "cells": [
+                    [[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]],
+                    [[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]],
+                ],
+                "coords": [
+                    [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]],
+                    [[0.1, 0.0, 0.0], [1.1, 0.0, 0.0], [0.1, 1.0, 0.0]],
+                ],
+                "orig": [0.0, 0.0, 0.0],
+                "energies": [-10.5, -10.6],
+                "forces": [
+                    [[0.1, 0.0, 0.0], [0.0, 0.1, 0.0], [0.0, 0.0, 0.1]],
+                    [[0.2, 0.0, 0.0], [0.0, 0.2, 0.0], [0.0, 0.0, 0.2]],
+                ],
+            }
+
+            # Convert to SchNetPack format
+            multiframe_system.to("schnetpack", db_file)
+
+            # Verify the database was created
+            self.assertTrue(os.path.exists(db_file))
+
+            # Load and verify using ASE
+            db = connect(db_file)
+
+            # Should have 2 frames
+            self.assertEqual(len(db), 2)
+
+            # Check both frames
+            for i in range(1, 3):  # ASE database IDs start from 1
+                row = db.get(i)
+                atoms = row.toatoms()
+                self.assertTrue(hasattr(atoms, "calc") and atoms.calc is not None)
+                self.assertEqual(len(atoms), 3)  # O, H, H
+
+    def test_schnetpack_script_compatibility(self):
+        """Test compatibility with the specific SchNetPack script from user requirements."""
+        from ase.db import connect
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            db_file = os.path.join(tmpdir, "schnet.db")
+
+            # Create a system similar to what the user script expects
+            # with enough frames for num_train=1000, num_val=100 (need at least 1100)
+            num_frames = 1200
+            coords = []
+            cells = []
+            energies = []
+            forces = []
+
+            for i in range(num_frames):
+                # Base cell
+                cell = [[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]]
+                cells.append(cell)
+
+                # Slightly perturbed coordinates for variety
+                coord = [
+                    [0.0 + 0.01 * i, 0.0, 0.0],  # O
+                    [1.0 + 0.01 * np.sin(i / 10), 0.0, 0.0],  # H
+                    [0.0, 1.0 + 0.01 * np.cos(i / 10), 0.0],  # H
+                ]
+                coords.append(coord)
+
+                # Energy (slightly varying)
+                energy = -10.5 - 0.1 * np.sin(i / 100)
+                energies.append(energy)
+
+                # Forces (varying)
+                force = [
+                    [0.1 + 0.01 * np.cos(i / 20), 0.0, 0.0],
+                    [0.0, 0.1 + 0.01 * np.sin(i / 20), 0.0],
+                    [0.0, 0.0, 0.1 + 0.01 * np.cos(i / 30)],
+                ]
+                forces.append(force)
+
+            labeled_system = dpdata.LabeledSystem()
+            labeled_system.data = {
+                "atom_names": ["O", "H"],
+                "atom_numbs": [1, 2],
+                "atom_types": [0, 1, 1],  # O, H, H
+                "cells": cells,
+                "coords": coords,
+                "orig": [0.0, 0.0, 0.0],
+                "energies": energies,
+                "forces": forces,
+            }
+
+            # Convert with exact units from user script
+            expected_units = {"energy": "eV", "forces": "eV/Ang"}
+            labeled_system.to("schnetpack", db_file, property_unit_dict=expected_units)
+
+            # Verify the database was created
+            self.assertTrue(os.path.exists(db_file))
+
+            # Load and verify using ASE
+            db = connect(db_file)
+
+            # Check we have enough entries for the user's requirements
+            self.assertGreaterEqual(
+                len(db),
+                1100,
+                "Need at least 1100 entries for num_train=1000, num_val=100",
+            )
+
+            # Check property units match exactly
+            first_row = db.get(1)
+            self.assertIn("property_units", first_row.data)
+            self.assertEqual(first_row.data["property_units"], expected_units)
+
+            # Test that all entries have the required properties for SchNetPack
+            sample_size = min(10, len(db))
+            for i in range(1, sample_size + 1):
+                row = db.get(i)
+                atoms = row.toatoms()
+
+                # Check calculator exists
+                self.assertTrue(hasattr(atoms, "calc") and atoms.calc is not None)
+
+                # Check energy is accessible
+                energy = atoms.get_potential_energy()
+                self.assertIsInstance(energy, (int, float))
+
+                # Check forces are accessible and correct shape
+                forces = atoms.get_forces()
+                self.assertEqual(forces.shape, (len(atoms), 3))
+
+                # Check positions are correct shape
+                self.assertEqual(atoms.positions.shape, (len(atoms), 3))
+
+            # Test database iteration (SchNetPack does this)
+            energies_from_db = []
+            for row in db.select(limit=5):
+                atoms = row.toatoms()
+                energy = atoms.get_potential_energy()
+                energies_from_db.append(energy)
+
+            self.assertEqual(len(energies_from_db), 5)
+
+            # Verify all energies are numbers
+            for energy in energies_from_db:
+                self.assertIsInstance(energy, (int, float))
+
+
+class TestSchNetPackMocked(unittest.TestCase):
+    """Test SchNetPack functionality with mocked dependencies."""
+
+    def setUp(self):
+        # Create a simple test system
+        self.labeled_system = dpdata.LabeledSystem()
+        self.labeled_system.data = {
+            "atom_names": ["O", "H"],
+            "atom_numbs": [1, 2],
+            "atom_types": [0, 1, 1],  # O, H, H
+            "cells": [[[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]]],
+            "coords": [[[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]],
+            "orig": [0.0, 0.0, 0.0],
+            "energies": [-10.5],  # eV
+            "forces": [[[0.1, 0.0, 0.0], [0.0, 0.1, 0.0], [0.0, 0.0, 0.1]]],  # eV/Ang
+        }
+
+    @patch("dpdata.plugins.schnetpack.SchNetPackFormat.to_labeled_system")
+    def test_conversion_logic_mocked(self, mock_to_labeled_system):
+        """Test the conversion logic with mocked dependencies."""
+        # Test that the method can be called
+        mock_to_labeled_system.return_value = None
+
+        # Test the conversion - should call the mocked method
+        self.labeled_system.to("schnetpack", "/tmp/test.db")
+
+        # Verify the method was called
+        mock_to_labeled_system.assert_called_once()
+
+
+if __name__ == "__main__":
+    unittest.main()