diff --git a/src/rxn_network/entries/gibbs.py b/src/rxn_network/entries/gibbs.py index ea51c71d..aaecfa8d 100644 --- a/src/rxn_network/entries/gibbs.py +++ b/src/rxn_network/entries/gibbs.py @@ -12,6 +12,7 @@ import numpy as np from monty.json import MontyDecoder from pymatgen.analysis.phase_diagram import GrandPotPDEntry +from pymatgen.core.periodic_table import Element from pymatgen.entries.computed_entries import ComputedEntry, ConstantEnergyAdjustment from scipy.interpolate import interp1d @@ -19,7 +20,6 @@ from rxn_network.data import G_ELEMS if TYPE_CHECKING: - from pymatgen.core.periodic_table import Element from pymatgen.core.structure import Structure from pymatgen.entries.computed_entries import EnergyAdjustment @@ -42,6 +42,8 @@ class GibbsComputedEntry(ComputedEntry): Nature Communications, 9(1), 4168. https://doi.org/10.1038/s41467-018-06682-4. """ + data: dict # Type annotation for parent class attribute + def __init__( self, composition: Composition, @@ -289,11 +291,29 @@ def unique_id(self) -> str: def as_dict(self) -> dict: """Returns an MSONable dict.""" - data = super().as_dict() - data["volume_per_atom"] = self.volume_per_atom - data["formation_energy_per_atom"] = self.formation_energy_per_atom - data["temperature"] = self.temperature - return data + # Temporarily swap self.data to sanitize Element keys for JSON serialization. + # MP API returns oxidation_states with Element keys (e.g., {Element('Li'): 1.0}) + # which aren't valid JSON keys. We must swap the attribute because pymatgen's + # ComputedEntry.as_dict() directly accesses self.data - there's no way to pass + # sanitized data as an argument. + original_data = self.data + self.data = self._sanitize_data(original_data) + result = super().as_dict() + self.data = original_data + result["volume_per_atom"] = self.volume_per_atom + result["formation_energy_per_atom"] = self.formation_energy_per_atom + result["temperature"] = self.temperature + return result + + @staticmethod + def _sanitize_data(data: dict) -> dict: + """Convert Element keys in nested dicts to strings for JSON serialization.""" + sanitized = {} + for k, v in data.items(): + if isinstance(v, dict): + v = {(str(dk) if isinstance(dk, Element) else dk): dv for dk, dv in v.items()} + sanitized[k] = v + return sanitized @classmethod def from_dict(cls, d: dict) -> GibbsComputedEntry: diff --git a/tests/entries/test_gibbs.py b/tests/entries/test_gibbs.py index 3cb70fcf..13b46e21 100644 --- a/tests/entries/test_gibbs.py +++ b/tests/entries/test_gibbs.py @@ -70,6 +70,22 @@ def test_to_from_dict(entry): assert e.energy == pytest.approx(entry.energy) +def test_as_dict_with_element_keys_in_data(entry): + """Test that as_dict works when entry.data contains Element keys. + + The Materials Project API returns oxidation_states with Element keys + (e.g., {Element('Li'): 1.0}), which are not JSON-serializable. + This test ensures as_dict handles this case. + """ + # Simulate data from MP API with Element keys + entry.data["oxidation_states"] = {Element("Li"): 1.0, Element("Fe"): 2.0} + + # This should not raise TypeError: keys must be str, int, float, bool or None, not Element + d = entry.as_dict() + e = GibbsComputedEntry.from_dict(d) + assert e.energy == pytest.approx(entry.energy) + + def test_get_new_temperature(entry): new_temp = 600 # != original temp of 450 new_entry = entry.get_new_temperature(new_temp)