Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
b336015
Test implementation
ElliottKasoar Apr 21, 2026
2aa6229
Add mock models
ElliottKasoar Apr 22, 2026
53200b5
Add new CLI option to run mock model
ElliottKasoar Apr 22, 2026
c5e7290
Add option to only run mock
ElliottKasoar Apr 22, 2026
a93b85d
Refactor getting info during analysis
ElliottKasoar Apr 22, 2026
856000a
Allow GMTKN55 to fail during calc
ElliottKasoar Apr 22, 2026
84563f5
Load info and add element dropdown
ElliottKasoar Apr 22, 2026
b9873b9
Refactor analysis for reuse in app
ElliottKasoar Apr 24, 2026
9f39664
Temp app updates
ElliottKasoar Apr 24, 2026
8983086
Write mock structs during analysis
ElliottKasoar Apr 28, 2026
35b4ae5
Make element filter deselective
ElliottKasoar Apr 28, 2026
7db865f
Update apps for filter
ElliottKasoar Apr 28, 2026
ca2c4fd
Update analysis to save structures
ElliottKasoar Apr 28, 2026
3f1fa52
Fix mock model for precision as kwarg
ElliottKasoar Apr 28, 2026
57736ad
Add filter callback
ElliottKasoar Apr 28, 2026
43e8b9b
Fix filter list from apps
ElliottKasoar Apr 29, 2026
3922ea2
Simplify inputs
ElliottKasoar Apr 29, 2026
e7abcf0
Reorder parameters
ElliottKasoar Apr 29, 2026
bc67708
Warn if no mock directory
ElliottKasoar Apr 29, 2026
54466d3
Print missing mock dir in warning
ElliottKasoar Apr 29, 2026
fe6b209
Temp app update
ElliottKasoar Apr 29, 2026
dae8793
Update Li diffusion for filtering
ElliottKasoar Apr 29, 2026
07aaf5c
Allow NaNs in metrics
ElliottKasoar Apr 29, 2026
3d32f79
Warn for missing info file
ElliottKasoar Apr 29, 2026
7220492
Allow null filter
ElliottKasoar Apr 29, 2026
91be595
Fix missing models during analysis
ElliottKasoar Apr 29, 2026
84c9113
Refactor filter and update GMTKN55 for filter
ElliottKasoar Apr 30, 2026
8a1a5da
Update GMTKN55 analysis
ElliottKasoar Apr 30, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
179 changes: 117 additions & 62 deletions ml_peg/analysis/molecular/GMTKN55/analyse_GMTKN55.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

import json
from pathlib import Path

from ase import units
Expand Down Expand Up @@ -48,37 +49,54 @@ def structure_info() -> dict[str, dict[str, float] | list | NDArray]:
"categories": [],
"subsets": [],
"systems": [],
"elements": [],
"excluded": [],
"weights": {},
"counts": {},
}
for model_name in MODELS:
for subset in [dir.name for dir in sorted((CALC_PATH / model_name).glob("*"))]:
count = 0
for system_path in sorted((CALC_PATH / model_name / subset).glob("*.xyz")):
count += 1
structs = read(system_path, index=":")
info["subsets"].append(subset)

info["categories"].append(structs[0].info["category"])
info["systems"].append(structs[0].info["system_name"])
info["excluded"].append(
any(
struct.info["excluded"]
or struct.info["charge"] not in ALLOWED_CHARGES
or struct.info["spin"] not in ALLOWED_MULTIPLICITY
for struct in structs
)
data_dir = CALC_PATH / "mock"
structs_dir = OUT_PATH / "mock"
if not data_dir.exists():
raise ValueError(f"{data_dir} does not exist. Please run mock calculation.")
structs_dir.mkdir(parents=True, exist_ok=True)

for subset in [dir.name for dir in sorted((data_dir).glob("*"))]:
count = 0
for system_path in sorted((data_dir / subset).glob("*.xyz")):
count += 1
structs = read(system_path, index=":")
info["subsets"].append(subset)

info["categories"].append(structs[0].info["category"])
info["systems"].append(structs[0].info["system_name"])
info["excluded"].append(
any(
struct.info["excluded"]
or struct.info["charge"] not in ALLOWED_CHARGES
or struct.info["spin"] not in ALLOWED_MULTIPLICITY
for struct in structs
)
info["weights"][subset] = structs[0].info["weight"]
info["counts"][subset] = count

# Convert to numpy arrays for filtering
info["categories"] = np.array(info["categories"])
info["subsets"] = np.array(info["subsets"])
info["excluded"] = np.array(info["excluded"])
# Only need to access info from one model
return info
)

info["elements"].append(
list(
set().union(*(struct.get_chemical_symbols() for struct in structs))
)
)
write(structs_dir / f"{count}.xyz", structs)

info["weights"][subset] = structs[0].info["weight"]
info["counts"][subset] = count

# Convert to numpy arrays for filtering
info["categories"] = np.array(info["categories"])
info["subsets"] = np.array(info["subsets"])
info["excluded"] = np.array(info["excluded"])

out_file = OUT_PATH / "info.json"
with out_file.open("w", encoding="utf8") as f:
json.dump({"elements": info["elements"]}, f, indent=1)

return info


Expand Down Expand Up @@ -139,8 +157,7 @@ def rel_energies() -> dict[str, list[float]]:
return results


@pytest.fixture
def all_errors(rel_energies: dict[str, list[float]]) -> dict[str, list[float]]:
def get_all_errors(rel_energies: dict[str, list[float]]) -> dict[str, list[float]]:
"""
Calculate MAD for all models for all systems with respect to reference.

Expand All @@ -162,39 +179,49 @@ def all_errors(rel_energies: dict[str, list[float]]) -> dict[str, list[float]]:
return errors


@pytest.fixture
def subset_errors(all_errors: dict[str, list[float]]) -> dict[str, dict[str, float]]:
def get_subset_errors(
    rel_energies: dict[str, list[float]],
    mask: list[bool],
) -> dict[str, dict[str, float]]:
    """
    Calculate mean error for each subset for all models.

    Parameters
    ----------
    rel_energies
        All reference and predicted relative energies, grouped by model.
    mask
        Additional boolean mask to apply to info. It is used to index
        ``INFO["excluded"]`` and ``INFO["subsets"]`` before excluded systems
        are removed.

    Returns
    -------
    dict[str, dict[str, float]]
        Mean error for all models, grouped by subset.
    """
    errors_by_model = get_all_errors(rel_energies)

    # Systems that pass the mask and are not flagged as excluded
    keep = np.invert(INFO["excluded"][mask])
    kept_subsets = INFO["subsets"][mask][keep]
    subset_names = set(kept_subsets)

    # NOTE(review): each model's errors are indexed with `keep`, whose length
    # is the number of mask-selected systems, so the errors are presumably
    # already aligned with the masked systems - confirm against callers.
    results = {}
    for model_name in MODELS:
        model_errors = errors_by_model[model_name][keep]
        results[model_name] = {
            name: np.mean(model_errors[kept_subsets == name])
            for name in subset_names
        }

    return results


@pytest.fixture
def category_errors(
def get_category_errors(
subset_errors: dict[str, dict[str, float]],
mask: list[bool],
) -> dict[str, dict[str, float]]:
"""
Calculate MAD for all models, grouped by category.
Expand All @@ -203,6 +230,8 @@ def category_errors(
----------
subset_errors
Nested dictionary of mean errors, grouped by model and subset.
mask
Additional boolean mask to apply to info.

Returns
-------
Expand All @@ -214,20 +243,20 @@ def category_errors(
for model_name in MODELS:
results[model_name] = {}

all_categories = INFO["categories"]
all_subsets = INFO["subsets"]
all_categories = INFO["categories"][mask]
all_subsets = INFO["subsets"][mask]
excluded = INFO["excluded"][mask]
all_weights = INFO["weights"]
all_counts = INFO["counts"]
excluded = INFO["excluded"]

# Filter excluded systems
categories = all_categories[np.logical_not(excluded)]
valid = ~excluded
categories = all_categories[valid]
subsets = all_subsets[valid]

for category in set(categories):
# Filter non-excluded subsets in current category
filtered_subsets = np.unique(
all_subsets[np.logical_not(excluded)][categories == category]
)
filtered_subsets = np.unique(subsets[categories == category])

# Get number of systems in each subset
counts = np.array([all_counts[subset] for subset in filtered_subsets])
Expand All @@ -245,15 +274,18 @@ def category_errors(
return results


@pytest.fixture
def weighted_error(subset_errors: dict[str, dict[str, float]]) -> dict[str, float]:
def get_weighted_error(
subset_errors: dict[str, dict[str, float]], mask: list[bool]
) -> dict[str, float]:
"""
Calculate weighted mean absolute deviation for all models.

Parameters
----------
subset_errors
Nested dictionary of mean errors, grouped by model and subset.
mask
Additional boolean mask to apply to info.

Returns
-------
Expand All @@ -265,10 +297,10 @@ def weighted_error(subset_errors: dict[str, dict[str, float]]) -> dict[str, floa
for model_name in MODELS:
results[model_name] = {}

all_subsets = INFO["subsets"]
all_subsets = INFO["subsets"][mask]
excluded = INFO["excluded"][mask]
all_weights = INFO["weights"]
all_counts = INFO["counts"]
excluded = INFO["excluded"]

# Filter all non-excluded subsets
filtered_subsets = np.unique(all_subsets[np.logical_not(excluded)])
Expand All @@ -285,32 +317,30 @@ def weighted_error(subset_errors: dict[str, dict[str, float]]) -> dict[str, floa
return results


@pytest.fixture
@build_table(
filename=OUT_PATH / "gmtkn55_metrics_table.json",
metric_tooltips=DEFAULT_TOOLTIPS,
thresholds=DEFAULT_THRESHOLDS,
weights=DEFAULT_WEIGHTS,
mlip_name_map=DISPERSION_NAME_MAP,
)
def metrics(
category_errors: dict[str, dict[str, float]], weighted_error: dict[str, float]
def get_metrics(
rel_energies: dict[str, list[float]], mask: list[bool] | None = None
) -> dict[str, dict]:
"""
Get all GMTKN55 metrics.

Parameters
----------
category_errors
Relative errors for each models, grouped by categories.
weighted_error
Weighted relative error for each model.
rel_energies
All reference and predicted relative energies, grouped by model.
mask
Additional boolean mask to apply to info. Default is `True` for all systems.

Returns
-------
dict[str, dict]
Metric names and values for all models.
"""
if mask is None:
mask = [True] * len(INFO["subsets"])
subset_errors = get_subset_errors(rel_energies, mask=mask)
category_errors = get_category_errors(subset_errors, mask=mask)
weighted_error = get_weighted_error(subset_errors, mask=mask)

category_abbrevs = {
"Basic properties and reaction energies for small systems": "Small systems",
"Reaction energies for large systems and isomerisation reactions": "Large "
Expand All @@ -323,12 +353,37 @@ def metrics(
metrics = {}
for full_category, short_category in category_abbrevs.items():
metrics[short_category] = {
model: category_errors[model][full_category] for model in MODELS
model: category_errors[model].get(full_category, None) for model in MODELS
}

return metrics | {"WTMAD": weighted_error}


@pytest.fixture
@build_table(
    filename=OUT_PATH / "gmtkn55_metrics_table.json",
    metric_tooltips=DEFAULT_TOOLTIPS,
    thresholds=DEFAULT_THRESHOLDS,
    weights=DEFAULT_WEIGHTS,
    mlip_name_map=DISPERSION_NAME_MAP,
)
def metrics(rel_energies: dict[str, list[float]]) -> dict[str, dict]:
    """
    Get all GMTKN55 metrics without any additional system filter.

    Thin fixture wrapper around ``get_metrics``: no mask is supplied, so
    ``get_metrics`` falls back to its default of including every system.

    Parameters
    ----------
    rel_energies
        All reference and predicted relative energies, grouped by model.

    Returns
    -------
    dict[str, dict]
        Metric names and values for all models.
    """
    # Passing ``mask=None`` spells out the default: no extra filtering
    return get_metrics(rel_energies, mask=None)


def test_gmtkn55(metrics):
"""
Run GMTKN55 test.
Expand Down
Loading
Loading