Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion contributing.md
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ To submit a new model to this benchmark and add it to our leaderboard, please cr
pred_file: /models/<model_dir>/<yyyy-mm-dd>-<model_name>-wbm-IS2RE.csv.gz # should contain the model's energy predictions for the WBM test set
pred_col: e_form_per_atom_<model_name>
geo_opt: # only applicable if the model performed structure relaxation
pred_file: /models/<model_dir>/<yyyy-mm-dd>-<model_name>-wbm-IS2RE.json.gz # should contain the model's relaxed structures as ASE Atoms or pymatgen Structures, and forces/stresses at each relaxation step
pred_file: /models/<model_dir>/<yyyy-mm-dd>-<model_name>-wbm-IS2RE.json.gz # should contain the model's relaxed structures as ASE Atoms or pymatgen Structures, and separate columns for material_id and energies/forces/stresses at each relaxation step
pred_col: e_form_per_atom_<model_name>
```

Expand Down
10 changes: 7 additions & 3 deletions matbench_discovery/metrics-which-is-better.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,13 @@ discovery:
lower_is_better: [MAE, RMSE, FPR, FNR, FP, FN]

geo_opt:
higher_is_better: [Match]
# unclear whether a lower symmetry increase is really better: the model may actually have found a higher-symmetry, lower-energy structure than the DFT optimizer
lower_is_better: [RMSD, Decrease, Increase]
higher_is_better:
- σ<sub>match</sub>
lower_is_better:
- RMSD
- σ<sub>dec</sub>
# - σ<sub>inc</sub> # unclear whether a lower symmetry increase is really better: the model may actually have found a higher-symmetry, lower-energy structure than the DFT optimizer
- N<sub>ops,MAE</sub>

phonons:
lower_is_better: [κ_SRME, κ<sub>SRME</sub>]
Expand Down
131 changes: 81 additions & 50 deletions matbench_discovery/metrics/geo_opt.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""Functions to calculate and save geometry optimization metrics."""

import numpy as np
import pandas as pd
from pymatviz.enums import Key, Task
from ruamel.yaml.comments import CommentedMap
Expand All @@ -9,85 +8,115 @@
from matbench_discovery.enums import MbdKey


def write_geo_opt_metrics_to_yaml(
df_sym: pd.DataFrame, df_sym_changes: pd.DataFrame
) -> None:
"""Write geometry optimization metrics to model YAML metadata files."""
for model_name in df_sym.columns.levels[0]:
def write_geo_opt_metrics_to_yaml(df_metrics: pd.DataFrame, symprec: float) -> None:
"""Write geometry optimization metrics to model YAML metadata files.

Args:
df_metrics (pd.DataFrame): DataFrame with all geometry optimization metrics.
Index = model names, columns = metric names including:
- structure_rmsd_vs_dft: RMSD between predicted and DFT structures
- n_sym_ops_mae: Mean absolute error in number of symmetry operations
- symmetry_decrease: Fraction of structures with decreased symmetry
- symmetry_match: Fraction of structures with matching symmetry
- symmetry_increase: Fraction of structures with increased symmetry
- n_structs: Number of structures evaluated
symprec (float): spglib symmetry precision for comparing ML and DFT relaxed
structures.
"""
for model_name in df_metrics.index:
try:
model = Model.from_label(model_name)
except StopIteration:
print(f"Skipping {model_name}")
print(f"Skipping unknown {model_name=}")
continue

df_rmsd = df_sym.xs(
MbdKey.structure_rmsd_vs_dft, level=MbdKey.sym_prop, axis="columns"
).round(4)
if model.label not in df_rmsd:
print(f"No RMSD column for {model.label}")
return

# Calculate RMSD
rmsd = df_rmsd[model.label].mean(axis=0)

# Add symmetric mean error calculation for number of symmetry operations
sym_ops_diff = df_sym.drop(Key.dft.label, level=Key.model, axis="columns")[
model.label
][MbdKey.n_sym_ops_diff]

# Calculate symmetric mean error for each model
sym_ops_mae = np.mean(np.abs(sym_ops_diff))

# Calculate symmetry change statistics
if model.label not in df_sym_changes.index:
print(f"No symmetry data for {model.label}")
return
sym_changes = df_sym_changes.round(4).loc[model.label].to_dict()
# Combine metrics
with open(model.yaml_path) as file: # Load existing metadata
# Load existing metadata
with open(model.yaml_path) as file:
model_metadata = round_trip_yaml.load(file)

all_metrics = model_metadata.setdefault("metrics", {})

# Get metrics for this model
model_metrics = df_metrics.loc[model_name]
new_metrics = {
str(Key.rmsd): float(round(rmsd, 4)),
str(Key.n_sym_ops_mae): float(round(sym_ops_mae, 4)),
**sym_changes,
str(Key.rmsd): float(round(model_metrics[MbdKey.structure_rmsd_vs_dft], 4)),
str(Key.n_sym_ops_mae): float(round(model_metrics[Key.n_sym_ops_mae], 4)),
str(Key.symmetry_decrease): float(
round(model_metrics[Key.symmetry_decrease], 4)
),
str(Key.symmetry_match): float(round(model_metrics[Key.symmetry_match], 4)),
str(Key.symmetry_increase): float(
round(model_metrics[Key.symmetry_increase], 4)
),
str(Key.n_structs): int(model_metrics[Key.n_structs]),
}
geo_opt_metrics = CommentedMap(
all_metrics.setdefault(Task.geo_opt, {}) | new_metrics

geo_opt_metrics = CommentedMap(all_metrics.setdefault(Task.geo_opt, {}))
metrics_for_symprec = CommentedMap(
geo_opt_metrics.setdefault(f"{symprec=}", {})
)
metric_units = dict.fromkeys(sym_changes, "fraction") | {
metrics_for_symprec.update(new_metrics)

# Define units for metrics
metric_units = {
Key.rmsd: "Å",
Key.n_sym_ops_mae: "unitless",
Key.symmetry_decrease: "fraction",
Key.symmetry_match: "fraction",
Key.symmetry_increase: "fraction",
Key.n_structs: "count",
}

# Add units as YAML end-of-line comments
for key in new_metrics:
if unit := metric_units.get(key):
geo_opt_metrics.yaml_add_eol_comment(unit, key, column=1)
metrics_for_symprec.yaml_add_eol_comment(unit, key, column=1)

geo_opt_metrics[f"{symprec=}"] = metrics_for_symprec
all_metrics[Task.geo_opt] = geo_opt_metrics

with open(model.yaml_path, mode="w") as file: # Write back to file
# Write back to file
with open(model.yaml_path, mode="w") as file:
round_trip_yaml.dump(model_metadata, file)


def analyze_symmetry_changes(df_sym: pd.DataFrame) -> pd.DataFrame:
"""Analyze how often each model's predicted structure has different symmetry vs DFT.
def calc_geo_opt_metrics(df_geo_opt: pd.DataFrame) -> pd.DataFrame:
"""Calculate geometry optimization metrics for each model.

Args:
df_geo_opt (pd.DataFrame): DataFrame with geometry optimization metrics for all
models and DFT reference. Must have a 2-level column MultiIndex with levels
[model_name, property]. Required properties are:
- structure_rmsd_vs_dft: RMSD between predicted and DFT structures
- n_sym_ops_diff: Difference in number of symmetry operations vs DFT
- spg_num_diff: Difference in space group number vs DFT

Returns:
pd.DataFrame: DataFrame with columns for fraction of structures where symmetry
decreased, matched, or increased vs DFT.
pd.DataFrame: DataFrame with geometry optimization metrics. Shape = (n_models,
n_metrics). Columns include:
- structure_rmsd_vs_dft: Mean RMSD between predicted and DFT structures
- n_sym_ops_mae: Mean absolute error in number of symmetry operations
- symmetry_decrease: Fraction of structures with decreased symmetry
- symmetry_match: Fraction of structures with matching symmetry
- symmetry_increase: Fraction of structures with increased symmetry
- n_structs: Number of structures evaluated
"""
results: dict[str, dict[str, float]] = {}

for model in df_sym.columns.levels[0]:
if model == Key.dft.label: # don't compare DFT to itself
continue
for model in set(df_geo_opt.columns.levels[0]) - {Key.dft.label}:
try:
spg_diff = df_sym[model][MbdKey.spg_num_diff]
n_sym_ops_diff = df_sym[model][MbdKey.n_sym_ops_diff]
# Get relevant columns for this model
spg_diff = df_geo_opt[model][MbdKey.spg_num_diff]
n_sym_ops_diff = df_geo_opt[model][MbdKey.n_sym_ops_diff]
rmsd = df_geo_opt[model][MbdKey.structure_rmsd_vs_dft]

# Count total number of structures (excluding NaN values)
total = len(spg_diff.dropna())

# Calculate RMSD and MAE metrics
mean_rmsd = rmsd.mean()
sym_ops_mae = n_sym_ops_diff.abs().mean()

# Count cases where spacegroup changed
changed_mask = spg_diff != 0
# Among changed cases, count whether symmetry increased or decreased
Expand All @@ -96,14 +125,16 @@ def analyze_symmetry_changes(df_sym: pd.DataFrame) -> pd.DataFrame:
sym_matched = ~changed_mask

results[model] = {
str(MbdKey.structure_rmsd_vs_dft): float(mean_rmsd),
str(Key.n_sym_ops_mae): float(sym_ops_mae),
str(Key.symmetry_decrease): float(sym_decreased.sum() / total),
str(Key.symmetry_match): float(sym_matched.sum() / total),
str(Key.symmetry_increase): float(sym_increased.sum() / total),
str(Key.n_structs): total,
}
except KeyError as exc:
exc.add_note(
f"Missing data for {model}, available columns={list(df_sym[model])}"
f"Missing data for {model}, available columns={list(df_geo_opt[model])}"
)
raise

Expand Down
59 changes: 26 additions & 33 deletions matbench_discovery/structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,22 @@ def perturb_structure(struct: Structure, gamma: float = 1.5) -> Structure:


def analyze_symmetry(
structures: dict[str, Structure], *, pbar: bool | dict[str, str] = True
structures: dict[str, Structure],
*,
pbar: bool | dict[str, str] = True,
symprec: float = 1e-2,
angle_tolerance: float = -1,
) -> pd.DataFrame:
"""Analyze symmetry of a dictionary of structures using spglib.

Args:
structures (dict[str, Structure]): Map material IDs to pymatgen Structures
pbar (bool | dict[str, str], optional): Whether to show progress bar.
Defaults to True.
symprec (float, optional): Symmetry precision of spglib.get_symmetry_dataset.
Defaults to 1e-2.
angle_tolerance (float, optional): Angle tol. of spglib.get_symmetry_dataset.
Defaults to -1.

Returns:
pd.DataFrame: DataFrame containing symmetry information for each structure
Expand All @@ -78,21 +86,29 @@ def analyze_symmetry(

for struct_key, struct in iterator:
cell = (struct.lattice.matrix, struct.frac_coords, struct.atomic_numbers)
# spglib 2.5.0 issues lots of warnings:
# - get_bravais_exact_positions_and_lattice failed
# - ssm_get_exact_positions failed
with warnings.catch_warnings():
warnings.simplefilter(action="ignore", category=spglib.spglib.SpglibError)
sym_data: spglib.SpglibDataset = spglib.get_symmetry_dataset(cell)
warnings.filterwarnings(action="ignore", module="spglib")
sym_data = spglib.get_symmetry_dataset(
cell, symprec=symprec, angle_tolerance=angle_tolerance
)

if sym_data is None:
raise ValueError(f"spglib returned None for {struct_key}\n{struct}")

sym_info = {
new_key: getattr(sym_data, old_key)
for old_key, new_key in sym_key_map.items()
} | {
Key.n_sym_ops: len(sym_data.rotations),
Key.n_rot_syms: len(sym_data.rotations),
Key.n_trans_syms: len(sym_data.translations),
}
sym_info[Key.n_sym_ops] = len(
sym_data.rotations
) # Each rotation has an associated translation
sym_info[Key.n_rot_syms] = len(sym_data.rotations)
sym_info[Key.n_trans_syms] = len(sym_data.translations)

results[struct_key] = sym_info
results[struct_key] = sym_info | dict(
symprec=symprec, angle_tolerance=angle_tolerance
)

df_sym = pd.DataFrame(results).T
df_sym.index.name = Key.mat_id
Expand Down Expand Up @@ -154,26 +170,3 @@ def pred_vs_ref_struct_symmetry(
df_result.loc[mat_id, Key.max_pair_dist] = max_dist

return df_result


if __name__ == "__main__":
import matplotlib.pyplot as plt

gamma = 1.5
samples = np.array([rng.weibull(gamma) for _ in range(10_000)])
mean = samples.mean()

# reproduces the dist in https://www.nature.com/articles/s41524-022-00891-8#Fig5
ax = plt.hist(samples, bins=100)
# add vertical line at the mean
plt.axvline(mean, color="gray", linestyle="dashed", linewidth=1)
# annotate the mean line
plt.annotate(
f"{mean=:.2f}",
xy=(mean, 1),
# use ax coords for y
xycoords=("data", "axes fraction"),
# add text offset
xytext=(10, -20),
textcoords="offset points",
)
20 changes: 14 additions & 6 deletions models/bowsr/bowsr.yml
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,20 @@ metrics:
geo_opt:
pred_file: models/bowsr/2023-01-23-bowsr-megnet-wbm-IS2RE.json.gz
pred_col: structure_bowsr_megnet
rmsd: 0.043 # Å
symmetry_decrease: 0.0037 # fraction
symmetry_match: 0.5271 # fraction
symmetry_increase: 0.4671 # fraction
n_sym_ops_mae: 29.4778
n_structures: 250779.0 # count
symprec=1e-5:
rmsd: 0.043 # Å
symmetry_decrease: 0.0037 # fraction
symmetry_match: 0.5271 # fraction
symmetry_increase: 0.4671 # fraction
n_sym_ops_mae: 29.4778 # unitless
n_structures: 250779 # count
symprec=1e-2:
rmsd: 0.043 # Å
n_sym_ops_mae: 25.4771 # unitless
symmetry_decrease: 0.0618 # fraction
symmetry_match: 0.795 # fraction
symmetry_increase: 0.1307 # fraction
n_structures: 250430 # count
discovery:
pred_file: models/bowsr/2023-01-23-bowsr-megnet-wbm-IS2RE.csv.gz
pred_col: e_form_per_atom_bowsr_megnet
Expand Down
20 changes: 14 additions & 6 deletions models/chgnet/chgnet.yml
Original file line number Diff line number Diff line change
Expand Up @@ -67,12 +67,20 @@ metrics:
geo_opt:
pred_file: models/chgnet/2023-12-21-chgnet-0.3.0-wbm-IS2RE.json.gz
pred_col: chgnet_structure
rmsd: 0.0216 # Å
symmetry_decrease: 0.2526 # fraction
symmetry_match: 0.5833 # fraction
symmetry_increase: 0.1525 # fraction
n_sym_ops_mae: 3.2731
n_structures: 250779.0 # count
symprec=1e-5:
rmsd: 0.0216 # Å
symmetry_decrease: 0.2526 # fraction
symmetry_match: 0.5833 # fraction
symmetry_increase: 0.1525 # fraction
n_sym_ops_mae: 3.2731 # unitless
n_structures: 250779 # count
symprec=1e-2:
rmsd: 0.0218 # Å
n_sym_ops_mae: 2.0622 # unitless
symmetry_decrease: 0.0937 # fraction
symmetry_match: 0.7799 # fraction
symmetry_increase: 0.118 # fraction
n_structures: 256614 # count
discovery:
pred_file: models/chgnet/2023-12-21-chgnet-0.3.0-wbm-IS2RE.csv.gz
pred_col: e_form_per_atom_chgnet
Expand Down
20 changes: 14 additions & 6 deletions models/eqV2/eqV2-m-omat-mp-salex.yml
Original file line number Diff line number Diff line change
Expand Up @@ -87,12 +87,20 @@ metrics:
geo_opt:
pred_file: models/eqV2/eqV2-86M-omat-salex-mp.json.gz
pred_col: eqV2-86M-omat-mp-salex_structure
rmsd: 0.0138 # Å
n_sym_ops_mae: 10.0558
symmetry_decrease: 0.8611 # fraction
symmetry_match: 0.1382 # fraction
symmetry_increase: 0.0005 # fraction
n_structures: 256614.0 # count
symprec=1e-5:
rmsd: 0.0138 # Å
n_sym_ops_mae: 10.0558 # unitless
symmetry_decrease: 0.8611 # fraction
symmetry_match: 0.1382 # fraction
symmetry_increase: 0.0005 # fraction
n_structures: 256614 # count
symprec=1e-2:
rmsd: 0.0112 # Å
n_sym_ops_mae: 2.077 # unitless
symmetry_decrease: 0.1321 # fraction
symmetry_match: 0.7474 # fraction
symmetry_increase: 0.1077 # fraction
n_structures: 256614 # count
discovery:
pred_file: models/eqV2/eqV2-m-omat-mp-salex.csv.gz
pred_col: e_form_per_atom_eqV2-86M-omat-mp-salex
Expand Down
Loading