diff --git a/contributing.md b/contributing.md index 1480436c2..2503e8a00 100644 --- a/contributing.md +++ b/contributing.md @@ -184,7 +184,7 @@ To submit a new model to this benchmark and add it to our leaderboard, please cr pred_file: /models//--wbm-IS2RE.csv.gz # should contain the models energy predictions for the WBM test set pred_col: e_form_per_atom_ geo_opt: # only applicable if the model performed structure relaxation - pred_file: /models//--wbm-IS2RE.json.gz # should contain the models relaxed structures as ASE Atoms or pymatgen Structures, and forces/stresses at each relaxation step + pred_file: /models//--wbm-IS2RE.json.gz # should contain the models relaxed structures as ASE Atoms or pymatgen Structures, and separate columns for material_id and energies/forces/stresses at each relaxation step pred_col: e_form_per_atom_ ``` diff --git a/matbench_discovery/metrics-which-is-better.yml b/matbench_discovery/metrics-which-is-better.yml index 15e990521..963d3bb38 100644 --- a/matbench_discovery/metrics-which-is-better.yml +++ b/matbench_discovery/metrics-which-is-better.yml @@ -17,9 +17,13 @@ discovery: lower_is_better: [MAE, RMSE, FPR, FNR, FP, FN] geo_opt: - higher_is_better: [Match] - # unclear if lower symmetry increase is really better, could be model actually found a higher symmetry lower-energy structure than DFT optimizer - lower_is_better: [RMSD, Decrease, Increase] + higher_is_better: + - σmatch + lower_is_better: + - RMSD + - σdec + # - σinc # unclear if lower symmetry increase is really better, could be model actually found a higher symmetry lower-energy structure than DFT optimizer + - Nops,MAE phonons: lower_is_better: [κ_SRME, κSRME] diff --git a/matbench_discovery/metrics/geo_opt.py b/matbench_discovery/metrics/geo_opt.py index add73dba6..973666312 100644 --- a/matbench_discovery/metrics/geo_opt.py +++ b/matbench_discovery/metrics/geo_opt.py @@ -1,6 +1,5 @@ """Functions to calculate and save geometry optimization metrics.""" -import numpy as np import pandas as pd from pymatviz.enums import Key, Task from ruamel.yaml.comments import CommentedMap @@ -9,85 +8,115 @@ from matbench_discovery.enums import MbdKey -def write_geo_opt_metrics_to_yaml( - df_sym: pd.DataFrame, df_sym_changes: pd.DataFrame -) -> None: - """Write geometry optimization metrics to model YAML metadata files.""" - for model_name in df_sym.columns.levels[0]: +def write_geo_opt_metrics_to_yaml(df_metrics: pd.DataFrame, symprec: float) -> None: + """Write geometry optimization metrics to model YAML metadata files. + + Args: + df_metrics (pd.DataFrame): DataFrame with all geometry optimization metrics. + Index = model names, columns = metric names including: + - structure_rmsd_vs_dft: RMSD between predicted and DFT structures + - n_sym_ops_mae: Mean absolute error in number of symmetry operations + - symmetry_decrease: Fraction of structures with decreased symmetry + - symmetry_match: Fraction of structures with matching symmetry + - symmetry_increase: Fraction of structures with increased symmetry + - n_structs: Number of structures evaluated + symprec (float): spglib symmetry precision for comparing ML and DFT relaxed + structures. + """ + for model_name in df_metrics.index: try: model = Model.from_label(model_name) except StopIteration: - print(f"Skipping {model_name}") + print(f"Skipping unknown {model_name=}") continue - df_rmsd = df_sym.xs( - MbdKey.structure_rmsd_vs_dft, level=MbdKey.sym_prop, axis="columns" - ).round(4) - if model.label not in df_rmsd: - print(f"No RMSD column for {model.label}") - return - - # Calculate RMSD - rmsd = df_rmsd[model.label].mean(axis=0) - - # Add symmetric mean error calculation for number of symmetry operations - sym_ops_diff = df_sym.drop(Key.dft.label, level=Key.model, axis="columns")[ - model.label - ][MbdKey.n_sym_ops_diff] - - # Calculate symmetric mean error for each model - sym_ops_mae = np.mean(np.abs(sym_ops_diff)) - - # Calculate symmetry change statistics - if model.label not in df_sym_changes.index: - print(f"No symmetry data for {model.label}") - return - sym_changes = df_sym_changes.round(4).loc[model.label].to_dict() - # Combine metrics - with open(model.yaml_path) as file: # Load existing metadata + # Load existing metadata + with open(model.yaml_path) as file: model_metadata = round_trip_yaml.load(file) all_metrics = model_metadata.setdefault("metrics", {}) + + # Get metrics for this model + model_metrics = df_metrics.loc[model_name] new_metrics = { - str(Key.rmsd): float(round(rmsd, 4)), - str(Key.n_sym_ops_mae): float(round(sym_ops_mae, 4)), - **sym_changes, + str(Key.rmsd): float(round(model_metrics[MbdKey.structure_rmsd_vs_dft], 4)), + str(Key.n_sym_ops_mae): float(round(model_metrics[Key.n_sym_ops_mae], 4)), + str(Key.symmetry_decrease): float( + round(model_metrics[Key.symmetry_decrease], 4) + ), + str(Key.symmetry_match): float(round(model_metrics[Key.symmetry_match], 4)), + str(Key.symmetry_increase): float( + round(model_metrics[Key.symmetry_increase], 4) + ), + str(Key.n_structs): int(model_metrics[Key.n_structs]), } - geo_opt_metrics = CommentedMap( - all_metrics.setdefault(Task.geo_opt, {}) | new_metrics + + geo_opt_metrics = CommentedMap(all_metrics.setdefault(Task.geo_opt, {})) + metrics_for_symprec = CommentedMap( + geo_opt_metrics.setdefault(f"{symprec=}", {}) ) - metric_units = dict.fromkeys(sym_changes, "fraction") | { + metrics_for_symprec.update(new_metrics) + + # Define units for metrics + metric_units = { Key.rmsd: "Å", + Key.n_sym_ops_mae: "unitless", + Key.symmetry_decrease: "fraction", + Key.symmetry_match: "fraction", + Key.symmetry_increase: "fraction", Key.n_structs: "count", } + # Add units as YAML end-of-line comments for key in new_metrics: if unit := metric_units.get(key): - geo_opt_metrics.yaml_add_eol_comment(unit, key, column=1) + metrics_for_symprec.yaml_add_eol_comment(unit, key, column=1) + geo_opt_metrics[f"{symprec=}"] = metrics_for_symprec all_metrics[Task.geo_opt] = geo_opt_metrics - with open(model.yaml_path, mode="w") as file: # Write back to file + # Write back to file + with open(model.yaml_path, mode="w") as file: round_trip_yaml.dump(model_metadata, file) -def analyze_symmetry_changes(df_sym: pd.DataFrame) -> pd.DataFrame: - """Analyze how often each model's predicted structure has different symmetry vs DFT. +def calc_geo_opt_metrics(df_geo_opt: pd.DataFrame) -> pd.DataFrame: + """Calculate geometry optimization metrics for each model. + + Args: + df_geo_opt (pd.DataFrame): DataFrame with geometry optimization metrics for all + models and DFT reference. Must have a 2-level column MultiIndex with levels + [model_name, property]. Required properties are: + - structure_rmsd_vs_dft: RMSD between predicted and DFT structures + - n_sym_ops_diff: Difference in number of symmetry operations vs DFT + - spg_num_diff: Difference in space group number vs DFT Returns: - pd.DataFrame: DataFrame with columns for fraction of structures where symmetry - decreased, matched, or increased vs DFT. + pd.DataFrame: DataFrame with geometry optimization metrics. Shape = (n_models, + n_metrics). Columns include: + - structure_rmsd_vs_dft: Mean RMSD between predicted and DFT structures + - n_sym_ops_mae: Mean absolute error in number of symmetry operations + - symmetry_decrease: Fraction of structures with decreased symmetry + - symmetry_match: Fraction of structures with matching symmetry + - symmetry_increase: Fraction of structures with increased symmetry + - n_structs: Number of structures evaluated """ results: dict[str, dict[str, float]] = {} - for model in df_sym.columns.levels[0]: - if model == Key.dft.label: # don't compare DFT to itself - continue + for model in set(df_geo_opt.columns.levels[0]) - {Key.dft.label}: try: - spg_diff = df_sym[model][MbdKey.spg_num_diff] - n_sym_ops_diff = df_sym[model][MbdKey.n_sym_ops_diff] + # Get relevant columns for this model + spg_diff = df_geo_opt[model][MbdKey.spg_num_diff] + n_sym_ops_diff = df_geo_opt[model][MbdKey.n_sym_ops_diff] + rmsd = df_geo_opt[model][MbdKey.structure_rmsd_vs_dft] + + # Count total number of structures (excluding NaN values) total = len(spg_diff.dropna()) + # Calculate RMSD and MAE metrics + mean_rmsd = rmsd.mean() + sym_ops_mae = n_sym_ops_diff.abs().mean() + # Count cases where spacegroup changed changed_mask = spg_diff != 0 # Among changed cases, count whether symmetry increased or decreased @@ -96,6 +125,8 @@ def analyze_symmetry_changes(df_sym: pd.DataFrame) -> pd.DataFrame: sym_matched = ~changed_mask results[model] = { + str(MbdKey.structure_rmsd_vs_dft): float(mean_rmsd), + str(Key.n_sym_ops_mae): float(sym_ops_mae), str(Key.symmetry_decrease): float(sym_decreased.sum() / total), str(Key.symmetry_match): float(sym_matched.sum() / total), str(Key.symmetry_increase): float(sym_increased.sum() / total), @@ -103,7 +134,7 @@ def analyze_symmetry_changes(df_sym: pd.DataFrame) -> pd.DataFrame: } except KeyError as exc: exc.add_note( - f"Missing data for {model}, available columns={list(df_sym[model])}" + f"Missing data for {model}, available columns={list(df_geo_opt[model])}" ) raise diff --git a/matbench_discovery/structure.py b/matbench_discovery/structure.py index aa932e0fd..4c5c02b4e 100644 --- a/matbench_discovery/structure.py +++ b/matbench_discovery/structure.py @@ -44,7 +44,11 @@ def perturb_structure(struct: Structure, gamma: float = 1.5) -> Structure: def analyze_symmetry( - structures: dict[str, Structure], *, pbar: bool | dict[str, str] = True + structures: dict[str, Structure], + *, + pbar: bool | dict[str, str] = True, + symprec: float = 1e-2, + angle_tolerance: float = -1, ) -> pd.DataFrame: """Analyze symmetry of a dictionary of structures using spglib. @@ -52,6 +56,10 @@ def analyze_symmetry( structures (dict[str, Structure]): Map material IDs to pymatgen Structures pbar (bool | dict[str, str], optional): Whether to show progress bar. Defaults to True. + symprec (float, optional): Symmetry precision of spglib.get_symmetry_dataset. + Defaults to 1e-2. + angle_tolerance (float, optional): Angle tol. of spglib.get_symmetry_dataset. + Defaults to -1. Returns: pd.DataFrame: DataFrame containing symmetry information for each structure @@ -78,21 +86,29 @@ def analyze_symmetry( for struct_key, struct in iterator: cell = (struct.lattice.matrix, struct.frac_coords, struct.atomic_numbers) + # spglib 2.5.0 issues lots of warnings: + # - get_bravais_exact_positions_and_lattice failed + # - ssm_get_exact_positions failed with warnings.catch_warnings(): - warnings.simplefilter(action="ignore", category=spglib.spglib.SpglibError) - sym_data: spglib.SpglibDataset = spglib.get_symmetry_dataset(cell) + warnings.filterwarnings(action="ignore", module="spglib") + sym_data = spglib.get_symmetry_dataset( + cell, symprec=symprec, angle_tolerance=angle_tolerance + ) + + if sym_data is None: + raise ValueError(f"spglib returned None for {struct_key}\n{struct}") sym_info = { new_key: getattr(sym_data, old_key) for old_key, new_key in sym_key_map.items() + } | { + Key.n_sym_ops: len(sym_data.rotations), + Key.n_rot_syms: len(sym_data.rotations), + Key.n_trans_syms: len(sym_data.translations), } - sym_info[Key.n_sym_ops] = len( - sym_data.rotations - ) # Each rotation has an associated translation - sym_info[Key.n_rot_syms] = len(sym_data.rotations) - sym_info[Key.n_trans_syms] = len(sym_data.translations) - - results[struct_key] = sym_info + results[struct_key] = sym_info | dict( + symprec=symprec, angle_tolerance=angle_tolerance + ) df_sym = pd.DataFrame(results).T df_sym.index.name = Key.mat_id @@ -154,26 +170,3 @@ def pred_vs_ref_struct_symmetry( df_result.loc[mat_id, Key.max_pair_dist] = max_dist return df_result - - -if __name__ == "__main__": - import matplotlib.pyplot as plt - - gamma = 1.5 - samples = np.array([rng.weibull(gamma) for _ in range(10_000)]) - mean = samples.mean() - - # reproduces the dist in https://www.nature.com/articles/s41524-022-00891-8#Fig5 - ax = plt.hist(samples, bins=100) - # add vertical line at the mean - plt.axvline(mean, color="gray", linestyle="dashed", linewidth=1) - # annotate the mean line - plt.annotate( - f"{mean=:.2f}", - xy=(mean, 1), - # use ax coords for y - xycoords=("data", "axes fraction"), - # add text offset - xytext=(10, -20), - textcoords="offset points", - ) diff --git a/models/bowsr/bowsr.yml b/models/bowsr/bowsr.yml index b6a18d9c9..922ee49ea 100644 --- a/models/bowsr/bowsr.yml +++ b/models/bowsr/bowsr.yml @@ -60,12 +60,20 @@ metrics: geo_opt: pred_file: models/bowsr/2023-01-23-bowsr-megnet-wbm-IS2RE.json.gz pred_col: structure_bowsr_megnet - rmsd: 0.043 # Å - symmetry_decrease: 0.0037 # fraction - symmetry_match: 0.5271 # fraction - symmetry_increase: 0.4671 # fraction - n_sym_ops_mae: 29.4778 - n_structures: 250779.0 # count + symprec=1e-5: + rmsd: 0.043 # Å + symmetry_decrease: 0.0037 # fraction + symmetry_match: 0.5271 # fraction + symmetry_increase: 0.4671 # fraction + n_sym_ops_mae: 29.4778 # unitless + n_structures: 250779 # count + symprec=1e-2: + rmsd: 0.043 # Å + n_sym_ops_mae: 25.4771 # unitless + symmetry_decrease: 0.0618 # fraction + symmetry_match: 0.795 # fraction + symmetry_increase: 0.1307 # fraction + n_structures: 250430 # count discovery: pred_file: models/bowsr/2023-01-23-bowsr-megnet-wbm-IS2RE.csv.gz pred_col: e_form_per_atom_bowsr_megnet diff --git a/models/chgnet/chgnet.yml b/models/chgnet/chgnet.yml index 42fa2b777..706076b8b 100644 --- a/models/chgnet/chgnet.yml +++ b/models/chgnet/chgnet.yml @@ -67,12 +67,20 @@ metrics: geo_opt: pred_file: models/chgnet/2023-12-21-chgnet-0.3.0-wbm-IS2RE.json.gz pred_col: chgnet_structure - rmsd: 0.0216 # Å - symmetry_decrease: 0.2526 # fraction - symmetry_match: 0.5833 # fraction - symmetry_increase: 0.1525 # fraction - n_sym_ops_mae: 3.2731 - n_structures: 250779.0 # count + symprec=1e-5: + rmsd: 0.0216 # Å + symmetry_decrease: 0.2526 # fraction + symmetry_match: 0.5833 # fraction + symmetry_increase: 0.1525 # fraction + n_sym_ops_mae: 3.2731 # unitless + n_structures: 250779 # count + symprec=1e-2: + rmsd: 0.0218 # Å + n_sym_ops_mae: 2.0622 # unitless + symmetry_decrease: 0.0937 # fraction + symmetry_match: 0.7799 # fraction + symmetry_increase: 0.118 # fraction + n_structures: 256614 # count discovery: pred_file: models/chgnet/2023-12-21-chgnet-0.3.0-wbm-IS2RE.csv.gz pred_col: e_form_per_atom_chgnet diff --git a/models/eqV2/eqV2-m-omat-mp-salex.yml b/models/eqV2/eqV2-m-omat-mp-salex.yml index 3ace340d5..da803c3d3 100644 --- a/models/eqV2/eqV2-m-omat-mp-salex.yml +++ b/models/eqV2/eqV2-m-omat-mp-salex.yml @@ -87,12 +87,20 @@ metrics: geo_opt: pred_file: models/eqV2/eqV2-86M-omat-salex-mp.json.gz pred_col: eqV2-86M-omat-mp-salex_structure - rmsd: 0.0138 # Å - n_sym_ops_mae: 10.0558 - symmetry_decrease: 0.8611 # fraction - symmetry_match: 0.1382 # fraction - symmetry_increase: 0.0005 # fraction - n_structures: 256614.0 # count + symprec=1e-5: + rmsd: 0.0138 # Å + n_sym_ops_mae: 10.0558 # unitless + symmetry_decrease: 0.8611 # fraction + symmetry_match: 0.1382 # fraction + symmetry_increase: 0.0005 # fraction + n_structures: 256614 # count + symprec=1e-2: + rmsd: 0.0112 # Å + n_sym_ops_mae: 2.077 # unitless + symmetry_decrease: 0.1321 # fraction + symmetry_match: 0.7474 # fraction + symmetry_increase: 0.1077 # fraction + n_structures: 256614 # count discovery: pred_file: models/eqV2/eqV2-m-omat-mp-salex.csv.gz pred_col: e_form_per_atom_eqV2-86M-omat-mp-salex diff --git a/models/eqV2/eqV2-s-dens-mp.yml b/models/eqV2/eqV2-s-dens-mp.yml index 45f82f4bb..0bb68cafe 100644 --- a/models/eqV2/eqV2-s-dens-mp.yml +++ b/models/eqV2/eqV2-s-dens-mp.yml @@ -91,12 +91,20 @@ metrics: geo_opt: pred_file: models/eqV2/eqV2-31M-dens-mp-p5.json.gz pred_col: eqV2-31M-dens-MP-p5_structure - rmsd: 0.0138 # Å - n_sym_ops_mae: 10.0558 - symmetry_decrease: 0.8611 # fraction - symmetry_match: 0.1382 # fraction - symmetry_increase: 0.0005 # fraction - n_structures: 256614.0 # count + symprec=1e-5: + rmsd: 0.0138 # Å + n_sym_ops_mae: 10.0558 # unitless + symmetry_decrease: 0.8611 # fraction + symmetry_match: 0.1382 # fraction + symmetry_increase: 0.0005 # fraction + n_structures: 256614 # count + symprec=1e-2: + rmsd: 0.0138 # Å + n_sym_ops_mae: 3.9922 # unitless + symmetry_decrease: 0.4074 # fraction + symmetry_match: 0.5027 # fraction + symmetry_increase: 0.0631 # fraction + n_structures: 256614 # count discovery: pred_file: models/eqV2/eqV2-s-dens-mp.csv.gz pred_col: e_form_per_atom_eqV2-31M-dens-MP-p5 diff --git a/models/grace2l_r6/grace2l-r6.yml b/models/grace2l_r6/grace2l-r6.yml index f266e31cd..8ba45f4fb 100644 --- a/models/grace2l_r6/grace2l-r6.yml +++ b/models/grace2l_r6/grace2l-r6.yml @@ -54,12 +54,20 @@ metrics: geo_opt: pred_file: models/grace2l_r6/GRACE_2L_r6_11Nov2024_relaxed_structures.json.gz pred_col: grace2l_r6_structure - rmsd: 0.0186 # Å - n_sym_ops_mae: 1.8703 - symmetry_decrease: 0.0355 # fraction - symmetry_match: 0.7315 # fraction - symmetry_increase: 0.2285 # fraction - n_structures: 256862.0 # count + symprec=1e-5: + rmsd: 0.0186 # Å + n_sym_ops_mae: 1.8703 # unitless + symmetry_decrease: 0.0355 # fraction + symmetry_match: 0.7315 # fraction + symmetry_increase: 0.2285 # fraction + n_structures: 256862 # count + symprec=1e-2: + rmsd: 0.0186 # Å + n_sym_ops_mae: 1.8982 # unitless + symmetry_decrease: 0.0592 # fraction + symmetry_match: 0.7976 # fraction + symmetry_increase: 0.1363 # fraction + n_structures: 256513 # count discovery: pred_file: models/grace2l_r6/2024-11-21-MP_GRACE_2L_r6_11Nov2024-wbm-IS2RE-FIRE.csv.gz pred_col: e_form_per_atom_grace diff --git a/models/m3gnet/m3gnet.yml b/models/m3gnet/m3gnet.yml index f6c1b97f7..c4d7a446c 100644 --- a/models/m3gnet/m3gnet.yml +++ b/models/m3gnet/m3gnet.yml @@ -57,12 +57,20 @@ metrics: geo_opt: pred_file: models/m3gnet/2023-06-01-m3gnet-manual-sampling-wbm-IS2RE.json.gz pred_col: m3gnet_structure - rmsd: 0.0217 # Å - symmetry_decrease: 0.0652 # fraction - symmetry_match: 0.7488 # fraction - symmetry_increase: 0.1804 # fraction - n_sym_ops_mae: 1.7751 - n_structures: 256963.0 # count + symprec=1e-5: + rmsd: 0.0217 # Å + n_sym_ops_mae: 1.7751 # unitless + symmetry_decrease: 0.0652 # fraction + symmetry_match: 0.7488 # fraction + symmetry_increase: 0.1804 # fraction + n_structures: 256963 # count + symprec=1e-2: + rmsd: 0.0217 # Å + n_sym_ops_mae: 2.0111 # unitless + symmetry_decrease: 0.0719 # fraction + symmetry_match: 0.7933 # fraction + symmetry_increase: 0.1278 # fraction + n_structures: 256614 # count discovery: pred_file: models/m3gnet/2023-12-28-m3gnet-wbm-IS2RE.csv.gz pred_col: e_form_per_atom_m3gnet diff --git a/models/mace/mace.yml b/models/mace/mace.yml index 7829e65ad..a0727030e 100644 --- a/models/mace/mace.yml +++ b/models/mace/mace.yml @@ -72,12 +72,20 @@ metrics: geo_opt: pred_file: models/mace/2023-12-11-mace-wbm-IS2RE-FIRE.json.gz pred_col: mace_structure - rmsd: 0.0194 # Å - symmetry_decrease: 0.035 # fraction - symmetry_match: 0.7361 # fraction - symmetry_increase: 0.2243 # fraction - n_sym_ops_mae: 1.8584 - n_structures: 243070.0 # count + symprec=1e-5: + rmsd: 0.0194 # Å + n_sym_ops_mae: 1.8584 # unitless + symmetry_decrease: 0.035 # fraction + symmetry_match: 0.7361 # fraction + symmetry_increase: 0.2243 # fraction + n_structures: 243070 # count + symprec=1e-2: + rmsd: 0.0197 # Å + n_sym_ops_mae: 1.8961 # unitless + symmetry_decrease: 0.0602 # fraction + symmetry_match: 0.7977 # fraction + symmetry_increase: 0.1353 # fraction + n_structures: 249034 # count discovery: pred_file: models/mace/2023-12-11-mace-wbm-IS2RE-FIRE.csv.gz pred_col: e_form_per_atom_mace diff --git a/models/orb/orb-mptrj.yml b/models/orb/orb-mptrj.yml index 856fd0748..3cee0be37 100644 --- a/models/orb/orb-mptrj.yml +++ b/models/orb/orb-mptrj.yml @@ -77,12 +77,20 @@ metrics: geo_opt: pred_file: models/orb/orb-mptrj-only-v2-20241014.json.gz pred_col: orb_structure - rmsd: 0.0185 # Å - symmetry_decrease: 0.8594 # fraction - symmetry_match: 0.1397 # fraction - symmetry_increase: 0.0007 # fraction - n_sym_ops_mae: 10.036 - n_structures: 256963.0 # count + symprec=1e-5: + rmsd: 0.0185 # Å + n_sym_ops_mae: 10.036 # unitless + symmetry_decrease: 0.8594 # fraction + symmetry_match: 0.1397 # fraction + symmetry_increase: 0.0007 # fraction + n_structures: 256963 # count + symprec=1e-2: + rmsd: 0.0185 # Å + n_sym_ops_mae: 6.1283 # unitless + symmetry_decrease: 0.586 # fraction + symmetry_match: 0.3716 # fraction + symmetry_increase: 0.0332 # fraction + n_structures: 256614 # count discovery: pred_file: models/orb/orbff-mptrj-only-v2-20241014.csv.gz pred_col: e_form_per_atom_orb diff --git a/models/orb/orb.yml b/models/orb/orb.yml index 1a9124016..43300a521 100644 --- a/models/orb/orb.yml +++ b/models/orb/orb.yml @@ -77,12 +77,20 @@ metrics: geo_opt: pred_file: models/orb/orb-v2-20241011.json.gz pred_col: orb_structure - rmsd: 0.016 # Å - symmetry_decrease: 0.8473 # fraction - symmetry_match: 0.1494 # fraction - symmetry_increase: 0.0031 # fraction - n_sym_ops_mae: 9.8834 - n_structures: 256963.0 # count + symprec=1e-5: + rmsd: 0.016 # Å + n_sym_ops_mae: 9.8834 # unitless + symmetry_decrease: 0.8473 # fraction + symmetry_match: 0.1494 # fraction + symmetry_increase: 0.0031 # fraction + n_structures: 256963 # count + symprec=1e-2: + rmsd: 0.016 # Å + n_sym_ops_mae: 5.5556 # unitless + symmetry_decrease: 0.5211 # fraction + symmetry_match: 0.4286 # fraction + symmetry_increase: 0.0401 # fraction + n_structures: 256614 # count discovery: pred_file: models/orb/orbff-v2-20241011.csv.gz pred_col: e_form_per_atom_orb diff --git a/models/sevennet/sevennet.yml b/models/sevennet/sevennet.yml index 216a4fbdc..995465743 100644 --- a/models/sevennet/sevennet.yml +++ b/models/sevennet/sevennet.yml @@ -80,12 +80,20 @@ metrics: geo_opt: pred_file: models/sevennet/2024-07-11-sevennet-0-relaxed-structures.json.gz pred_col: sevennet_structure - rmsd: 0.0193 # Å - symmetry_decrease: 0.3557 # fraction - symmetry_match: 0.4535 # fraction - symmetry_increase: 0.1446 # fraction - n_sym_ops_mae: 2.5921 - n_structures: 256963.0 # count + symprec=1e-5: + rmsd: 0.0193 # Å + n_sym_ops_mae: 2.5921 # unitless + symmetry_decrease: 0.3557 # fraction + symmetry_match: 0.4535 # fraction + symmetry_increase: 0.1446 # fraction + n_structures: 256963 # count + symprec=1e-2: + rmsd: 0.0193 # Å + n_sym_ops_mae: 1.9558 # unitless + symmetry_decrease: 0.0831 # fraction + symmetry_match: 0.7823 # fraction + symmetry_increase: 0.1262 # fraction + n_structures: 256614 # count discovery: pred_file: models/sevennet/2024-07-11-sevennet-0-preds.csv.gz pred_col: e_form_per_atom_sevennet diff --git a/scripts/metrics/eval_geo_opt.py b/scripts/metrics/eval_geo_opt.py index 09f9b433b..1d3c27234 100644 --- a/scripts/metrics/eval_geo_opt.py +++ b/scripts/metrics/eval_geo_opt.py @@ -1,4 +1,5 @@ -"""Analyze symmetry retention across models.""" +"""Evaluate ML vs DFT-relaxed structure similarity and symmetry retention for different +MLFFs.""" # %% import os @@ -18,6 +19,7 @@ from matbench_discovery.data import Model, df_wbm from matbench_discovery.enums import MbdKey +symprec = 1e-2 init_spg_col = "init_spg_num" dft_spg_col = "dft_spg_num" df_wbm[init_spg_col] = df_wbm[MbdKey.init_wyckoff].str.split("_").str[2].astype(int) @@ -26,15 +28,12 @@ # %% -csv_path = f"{ROOT}/data/2024-11-26-all-models-geo-opt-analysis.csv.gz" -df_sym = pd.read_csv(csv_path, header=[0, 1], index_col=0) +csv_path = f"{ROOT}/data/2024-11-29-all-models-geo-opt-analysis-{symprec=}.csv.gz" +df_go = pd.read_csv(csv_path, header=[0, 1], index_col=0) -df_sym_changes = go_metrics.analyze_symmetry_changes(df_sym).convert_dtypes() +df_go_metrics = go_metrics.calc_geo_opt_metrics(df_go).convert_dtypes() -# limit the number of structures loaded per model to this number, 0 for no limit -debug_mode: int = 0 - retained = (df_wbm[init_spg_col] == df_wbm[dft_spg_col]).sum() print( @@ -42,12 +41,19 @@ f"{retained:,} / {len(df_wbm):,} ({retained/len(df_wbm):.2%})" ) -models = df_sym.columns.levels[0] -print(f"\n{len(models)=}: {', '.join(models)}") +models = df_go.columns.levels[0] + + +# %% +display( + df_go_metrics.style.set_caption(f"Symmetry changes vs DFT for {len(models)} models") + .background_gradient(cmap="RdBu") + .format(precision=3) +) # %% Plot violin plot of RMSD vs DFT -df_rmsd = df_sym.xs(MbdKey.structure_rmsd_vs_dft, level=MbdKey.sym_prop, axis="columns") +df_rmsd = df_go.xs(MbdKey.structure_rmsd_vs_dft, level=MbdKey.sym_prop, axis="columns") fig_rmsd = px.violin( df_rmsd.round(3).dropna(), @@ -61,7 +67,7 @@ fig_rmsd.update_traces(orientation="h", side="positive", width=1.8) # add annotation for mean for each model -for model, srs_rmsd in df_sym.xs( +for model, srs_rmsd in df_go.xs( MbdKey.structure_rmsd_vs_dft, level=MbdKey.sym_prop, axis="columns" ).items(): mean_rmsd = srs_rmsd.mean() @@ -92,7 +98,7 @@ # %% calculate number of model spacegroups agreeing with DFT-relaxed spacegroup avg_spg_diff = ( - df_sym.xs(MbdKey.spg_num_diff, level=MbdKey.sym_prop, axis="columns") + df_go.xs(MbdKey.spg_num_diff, level=MbdKey.sym_prop, axis="columns") .mean(axis=0) .round(1) ) @@ -101,7 +107,7 @@ # %% violin plot of spacegroup number diff vs DFT fig_sym = px.violin( - df_sym.xs(MbdKey.spg_num_diff, level=MbdKey.sym_prop, axis="columns"), + df_go.xs(MbdKey.spg_num_diff, level=MbdKey.sym_prop, axis="columns"), title="Spacegroup Number Diff vs DFT", orientation="h", color="model", @@ -118,7 +124,7 @@ # %% violin plot of number of symmetry operations in ML-relaxed structures fig_sym_ops = px.violin( - df_sym.xs(Key.n_sym_ops, level=MbdKey.sym_prop, axis="columns"), + df_go.xs(Key.n_sym_ops, level=MbdKey.sym_prop, axis="columns"), title="Number of Symmetry Operations in ML-relaxed Structures", orientation="h", color="model", @@ -134,7 +140,7 @@ # %% violin plot of number of symmetry operations in ML-relaxed structures vs DFT fig_sym_ops_diff = px.violin( - df_sym.drop(Key.dft.label, level=Key.model, axis="columns") + df_go.drop(Key.dft.label, level=Key.model, axis="columns") .xs(MbdKey.n_sym_ops_diff, level=MbdKey.sym_prop, axis="columns") .reset_index(), orientation="h", @@ -153,7 +159,7 @@ # %% bar plot of number of symmetry operations in ML-relaxed structures vs DFT -df_sym_ops_diff = df_sym.drop(Key.dft.label, level=Key.model, axis="columns").xs( +df_sym_ops_diff = df_go.drop(Key.dft.label, level=Key.model, axis="columns").xs( MbdKey.n_sym_ops_diff, level=MbdKey.sym_prop, axis="columns" ) @@ -210,7 +216,7 @@ # %% Print summary of symmetry changes display( - df_sym_changes.round(3) + df_go_metrics.round(3) .rename(columns=lambda col: col.removeprefix("symmetry_")) .style.format(lambda x: f"{x:.1%}" if isinstance(x, float) else si_fmt(x)) .background_gradient(cmap="Oranges", subset="decrease") @@ -224,11 +230,11 @@ fig_rmsd_cdf = go.Figure() x_max = 0.05 -models = df_sym_changes.index +models = df_go_metrics.index # Calculate and plot CDF for each model for model in models: - rmsd_vals = df_sym.xs( + rmsd_vals = df_go.xs( MbdKey.structure_rmsd_vs_dft, level=MbdKey.sym_prop, axis="columns" )[model].dropna() @@ -278,10 +284,10 @@ # %% if __name__ == "__main__": - go_metrics.write_geo_opt_metrics_to_yaml(df_sym, df_sym_changes) + go_metrics.write_geo_opt_metrics_to_yaml(df_go_metrics, symprec) # %% plot ML vs DFT relaxed spacegroup correspondence as sankey diagrams - df_spg = df_sym.xs(Key.spg_num, level=MbdKey.sym_prop, axis="columns") + df_spg = df_go.xs(Key.spg_num, level=MbdKey.sym_prop, axis="columns") for model_label in {*df_spg} - {Key.dft.label}: # get most common pairs of DFT/Model spacegroups model = Model.from_label(model_label) diff --git a/scripts/metrics/update_df_geo_opt.py b/scripts/metrics/update_df_geo_opt.py index a5493eaf9..a69c9a806 100644 --- a/scripts/metrics/update_df_geo_opt.py +++ b/scripts/metrics/update_df_geo_opt.py @@ -1,6 +1,8 @@ -"""Functions to calculate and save geometry optimization metrics.""" +"""Run this script to add/update geometry optimization analysis for new models to a CSV +file containing all models.""" # %% +import gc import os import pandas as pd @@ -9,108 +11,109 @@ from matbench_discovery import ROOT, today from matbench_discovery.data import DataFiles -from matbench_discovery.enums import MbdKey from matbench_discovery.models import MODEL_METADATA from matbench_discovery.structure import analyze_symmetry, pred_vs_ref_struct_symmetry debug_mode: int = 0 -csv_path = f"{ROOT}/data/2024-11-26-all-models-geo-opt-analysis.csv.gz" -df_sym = pd.read_csv(csv_path, header=[0, 1], index_col=0) +symprec: float = 1e-2 +csv_path = f"{ROOT}/data/2024-11-29-all-models-geo-opt-analysis-{symprec=}.csv.gz" +if os.path.isfile(csv_path): + df_go = pd.read_csv(csv_path, index_col=0, header=[0, 1]) +else: + df_go = pd.DataFrame(columns=pd.MultiIndex(levels=[[], []], codes=[[], []])) -# %% + +# %% Load WBM reference structures df_wbm_structs = pd.read_json(DataFiles.wbm_computed_structure_entries.path) df_wbm_structs = df_wbm_structs.set_index(Key.mat_id) if debug_mode: df_wbm_structs = df_wbm_structs.head(debug_mode) -# %% Load all available model-relaxed structures for all models -dfs_model_structs: dict[str, pd.DataFrame] = {} +# %% Analyze DFT structures if not already done +dft_structs = locals().get("dft_structs") or { + mat_id: Structure.from_dict(cse[Key.structure]) + for mat_id, cse in df_wbm_structs[Key.cse].items() +} +if Key.dft.label in df_go: + dft_analysis = df_go[Key.dft.label] +else: + dft_analysis = analyze_symmetry( + dft_structs, + pbar=dict(desc=f"Getting DFT symmetries {symprec=}"), + symprec=symprec, + ) + for col in dft_analysis: + df_go[(Key.dft.label, col)] = dft_analysis[col] + -for model_name, model_metadata in MODEL_METADATA.items(): - if model_name in df_sym: - print(f"- {model_name} already analyzed") +# %% Process each model sequentially +for idx, (model_label, model_metadata) in enumerate(MODEL_METADATA.items()): + prog_str = f"{idx + 1}/{len(MODEL_METADATA)}:" + if model_label in df_go: + print(f"{prog_str} {model_label} already analyzed") continue + geo_opt_metrics = model_metadata.get("metrics", {}).get("geo_opt", {}) + + # skip models that don't support geometry optimization if geo_opt_metrics in ("not applicable", "not available"): continue + ml_relaxed_structs_path = f"{ROOT}/{geo_opt_metrics.get('pred_file')}" - if not ml_relaxed_structs_path: - continue - if not os.path.isfile(ml_relaxed_structs_path): - print(f"⚠️ {model_name}-relaxed structures not found") + if not ml_relaxed_structs_path or not os.path.isfile(ml_relaxed_structs_path): + print(f"⚠️ {model_label}-relaxed structures not found") continue - if ( - model_name in dfs_model_structs - # reload df_model if debug_mode changed - and (len(dfs_model_structs[model_name]) == debug_mode or debug_mode == 0) - ): - continue + # Load model structures df_model = pd.read_json(ml_relaxed_structs_path).set_index(Key.mat_id) if debug_mode: df_model = df_model.head(debug_mode) - dfs_model_structs[model_name] = df_model - n_structs_for_model = len(dfs_model_structs[model_name]) - print(f"+ Loaded {n_structs_for_model:,} structures for {model_name}") - - -# %% Perform symmetry analysis for all model-relaxed structures -dfs_sym_all: dict[str, pd.DataFrame] = {} -df_structs = pd.DataFrame() - -for model_name, df_model in dfs_model_structs.items(): - n_structs_for_model = len(df_model.dropna()) - n_structs_analyzed = len(dfs_sym_all.get(model_name, [])) - if n_structs_analyzed / n_structs_for_model > 0.97: - # skip model if >97% of its structures already analyzed - continue # accounts for structures failing symmetry analysis try: struct_col = next(col for col in df_model if Key.structure in col) except StopIteration: - print(f"No structure column found for {model_name}") + print(f"No structure column found for {model_label}") continue - df_structs[model_name] = { + + # Convert structures + model_structs = { mat_id: Structure.from_dict(struct_dict) for mat_id, struct_dict in df_model[struct_col].items() } - dfs_sym_all[model_name] = analyze_symmetry( - df_structs[model_name].dropna().to_dict(), - pbar=dict(desc=f"Analyzing {model_name} symmetries"), - ) - - -# %% Analyze DFT structures -dft_structs = { - mat_id: Structure.from_dict(cse[Key.structure]) - for mat_id, cse in df_wbm_structs[Key.cse].items() -} -if Key.dft.label not in df_sym: - dfs_sym_all[Key.dft.label] = analyze_symmetry(dft_structs) + # Analyze symmetry + model_analysis = analyze_symmetry( + model_structs, + pbar=dict(desc=f"{prog_str} Analyzing {model_label} symmetries"), + symprec=symprec, + ) -# %% Compare symmetry with DFT reference -n_models = len({*dfs_sym_all} - {Key.dft}) + # Add model results with proper multi-index columns + for col in model_analysis: + df_go[(model_label, col)] = model_analysis[col] -for idx, model_name in enumerate({*dfs_sym_all} - {Key.dft}): - dfs_sym_all[model_name] = pred_vs_ref_struct_symmetry( - dfs_sym_all[model_name], - df_sym[Key.dft.label], - df_structs[model_name].dropna().to_dict(), + # Compare with DFT reference + df_sym_change = pred_vs_ref_struct_symmetry( + model_analysis, + dft_analysis, + model_structs, dft_structs, - pbar=dict(desc=f"{idx+1}/{n_models} Comparing DFT vs {model_name} symmetries"), + pbar=dict(desc=f"{prog_str} Comparing DFT vs {model_label} symmetries"), ) + # Add comparison results with proper multi-index columns + for col in df_sym_change: + df_go[(model_label, col)] = df_sym_change[col] -# %% Combine all dataframes -df_sym = df_sym.join( - pd.concat(dfs_sym_all, axis="columns", names=[Key.model, MbdKey.sym_prop]) -) -df_sym = df_sym.convert_dtypes() + # Save after each model is processed + csv_path = f"{ROOT}/data/{today}-all-models-geo-opt-analysis-{symprec=}.csv.gz" + df_go.convert_dtypes().to_csv(csv_path) + print(f"{prog_str} Completed {model_label} and saved results") + # Free up memory (maybe helps with Jupyter kernel crashes) + del model_analysis, df_sym_change + gc.collect() -# %% Save results to CSV -csv_path = f"{ROOT}/data/{today}-all-models-geo-opt-analysis.csv.gz" -df_sym.to_csv(csv_path) +print("All models processed!") diff --git a/site/src/lib/DiscoveryMetricsTable.svelte b/site/src/lib/DiscoveryMetricsTable.svelte index 24004b5c2..384ee6a81 100644 --- a/site/src/lib/DiscoveryMetricsTable.svelte +++ b/site/src/lib/DiscoveryMetricsTable.svelte @@ -17,28 +17,31 @@ ] export let columns: { label: string; tooltip?: string; style?: string }[] = [ { label: `Model` }, - { label: `F1`, tooltip: `harmonic mean of precision and recall` }, - { label: `DAF`, tooltip: `discovery acceleration factor` }, - { label: `Prec`, tooltip: `precision of classifying thermodynamic stability` }, - { label: `Acc`, tooltip: `accuracy of classifying thermodynamic stability` }, + { label: `F1`, tooltip: `Harmonic mean of precision and recall` }, + { label: `DAF`, tooltip: `Discovery acceleration factor` }, + { label: `Prec`, tooltip: `Precision of classifying thermodynamic stability` }, + { label: `Acc`, tooltip: `Accuracy of classifying thermodynamic stability` }, { label: `TPR`, - tooltip: `true positive rate of classifying thermodynamic stability`, + tooltip: `True positive rate of classifying thermodynamic stability`, }, { label: `TNR`, - tooltip: `true negative rate of classifying thermodynamic stability`, + tooltip: `True negative rate of classifying thermodynamic stability`, }, { label: `MAE`, - tooltip: `mean absolute error of predicting the convex hull distance`, + tooltip: `Mean absolute error of predicting the convex hull distance`, style: `border-left: 1px solid black;`, }, - { label: `RMSE` }, - { label: `R2`, tooltip: `coefficient of determination` }, + { + label: `RMSE`, + tooltip: `Root mean squared error of predicting the convex hull distance`, + }, + { label: `R2`, tooltip: `Coefficient of determination` }, { label: `κSRME`, - tooltip: `symmetric relative mean error in predicted phonon mode contributions to thermal conductivity κ`, + tooltip: `Symmetric relative mean error in predicted phonon mode contributions to thermal conductivity κ`, style: `border-left: 1px solid black;`, }, ...(show_metadata ? metadata_cols : []), diff --git a/site/src/lib/GeoOptMetricsTable.svelte b/site/src/lib/GeoOptMetricsTable.svelte index 8bd97ee06..20602469a 100644 --- a/site/src/lib/GeoOptMetricsTable.svelte +++ b/site/src/lib/GeoOptMetricsTable.svelte @@ -6,45 +6,95 @@ export let show_non_compliant: boolean = false export let show_metadata: boolean = true export let metadata_cols: { label: string; tooltip?: string }[] = [] - export let columns: { label: string; tooltip?: string; style?: string }[] = [ + + // Get all unique symprec values from MODEL_METADATA + $: symprec_values = [ + ...new Set( + MODEL_METADATA.flatMap((model) => + Object.keys(model.metrics?.geo_opt ?? {}) + .filter((key) => key.startsWith(`symprec=`)) + .map((key) => key.replace(`symprec=`, ``)), + ), + ), + ].sort((val1, val2) => parseFloat(val2) - parseFloat(val1)) // Sort in descending order + + // Helper to format symprec in scientific notation + const format_symprec = (symprec: string) => { + const exp = symprec.split(`e-`)[1] + return `symprec=10-${exp}Å` + } + + const sep_line_style = `border-left: 1px solid black` + + // Create columns for each symprec value + $: columns = [ { label: `Model` }, { label: `RMSD`, tooltip: `Root mean squared displacement (in Å) of ML vs DFT relaxed atomic positions as calculated by pymatgen StructureMatcher`, + style: sep_line_style, }, - { - label: `σmatch`, - tooltip: `Fraction of structures where ML and DFT ground state have matching spacegroup`, - }, - { - label: `σdec`, - tooltip: `Fraction of structures where the number of symmetry operations decreased after ML relaxation`, - }, - { - label: `σinc`, - tooltip: `Fraction of structures where the number of symmetry operations increased after ML relaxation`, - }, - { - label: `Nops,MAE`, - tooltip: `Mean absolute error of number of symmetry operations in DFT and ML-relaxed structures`, - }, + // Symmetry match columns + ...symprec_values.flatMap((symprec) => [ + { + group: format_symprec(symprec), + label: `σmatch`, + tooltip: `Fraction of structures where ML and DFT ground state have matching spacegroup (symprec=${symprec}Å)`, + style: sep_line_style, + }, + { + group: format_symprec(symprec), + label: `σdec`, + tooltip: `Fraction of structures where the number of symmetry operations decreased after ML relaxation (symprec=${symprec}Å)`, + }, + { + group: format_symprec(symprec), + label: `σinc`, + tooltip: `Fraction of structures where the number of symmetry operations increased after ML relaxation (symprec=${symprec}Å). Not colored because it's high or low is good or bad. Could be models find higher symmetry lower-energy structures than DFT optimizer.`, + color_scale: null, + }, + { + group: format_symprec(symprec), + label: `Nops,MAE`, + tooltip: `Mean absolute error of number of symmetry operations in DFT and ML-relaxed structures (symprec=${symprec}Å)`, + }, + ]), { label: `Nstructs`, tooltip: `Number of structures relaxed by each model and used to compute these metrics`, + style: sep_line_style, }, ...(show_metadata ? metadata_cols : []), ] + // Create arrays of metric IDs for better/worse + $: higher_is_better = symprec_values.flatMap((symprec) => + geo_opt.higher_is_better.map((metric) => `${metric} (${format_symprec(symprec)})`), + ) + + $: lower_is_better = [ + ...geo_opt.lower_is_better.filter((metric) => !metric.includes(`(`)), + ...symprec_values.flatMap((symprec) => + geo_opt.lower_is_better + .filter((metric) => metric.includes(`<`)) + .map((metric) => `${metric} (${format_symprec(symprec)})`), + ), + ] + // Transform MODEL_METADATA into table data format $: metrics_data = MODEL_METADATA.filter( (model) => (show_non_compliant || model_is_compliant(model)) && - model.metrics?.geo_opt?.rmsd != undefined, + // Check if model has data for all symprec values + symprec_values.every( + (symprec) => model.metrics?.geo_opt?.[`symprec=${symprec}`]?.rmsd != undefined, + ) && + model.model_name !== `BOWSR`, // hide BOWSR as it's a huge outlier that makes the table hard to read ) .sort( (row1, row2) => - (row2?.metrics?.geo_opt?.symmetry_match ?? 0) - - (row1?.metrics?.geo_opt?.symmetry_match ?? 0), + (row2?.metrics?.geo_opt?.[`symprec=${symprec_values[0]}`]?.symmetry_match ?? 0) - + (row1?.metrics?.geo_opt?.[`symprec=${symprec_values[0]}`]?.symmetry_match ?? 0), ) .map((model) => { const geo_opt = model.metrics?.geo_opt @@ -52,31 +102,50 @@ return { Model: `${model.model_name}`, - RMSD: geo_opt.rmsd, - 'σmatch': geo_opt.symmetry_match, - 'σdec': geo_opt.symmetry_decrease, - 'σinc': geo_opt.symmetry_increase, - 'Nops,MAE': geo_opt.n_sym_ops_mae, + RMSD: geo_opt[`symprec=${symprec_values[0]}`].rmsd, + ...symprec_values.reduce( + (acc, symprec) => ({ + ...acc, + [`σmatch (${format_symprec(symprec)})`]: + geo_opt[`symprec=${symprec}`].symmetry_match, + [`σdec (${format_symprec(symprec)})`]: + geo_opt[`symprec=${symprec}`].symmetry_decrease, + [`σinc (${format_symprec(symprec)})`]: + geo_opt[`symprec=${symprec}`].symmetry_increase, + [`Nops,MAE (${format_symprec(symprec)})`]: + geo_opt[`symprec=${symprec}`].n_sym_ops_mae, + }), + {}, + ), 'Nstructs': `${pretty_num(geo_opt.n_structures)}`, + )} structures">${pretty_num(geo_opt[`symprec=${symprec_values[0]}`].n_structures)}`, } }) + + // Update format object + $: format = { + RMSD: `.3f`, + ...symprec_values.reduce( + (acc, symprec) => ({ + ...acc, + [`σmatch (${format_symprec(symprec)})`]: `.1%`, + [`σdec (${format_symprec(symprec)})`]: `.1%`, + [`σinc (${format_symprec(symprec)})`]: `.1%`, + [`Nops,MAE (${format_symprec(symprec)})`]: `.3`, + }), + {}, + ), + } match': `.1%`, - 'σdec': `.1%`, - 'σinc': `.1%`, - 'Nops,MAE': `.3f`, - }} + {higher_is_better} + {lower_is_better} + {format} {...$$restProps} /> diff --git a/site/src/lib/HeatmapTable.svelte b/site/src/lib/HeatmapTable.svelte index aefb9b154..0840c6e1b 100644 --- a/site/src/lib/HeatmapTable.svelte +++ b/site/src/lib/HeatmapTable.svelte @@ -11,7 +11,13 @@ type TableData = Record[] export let data: TableData - export let columns: { label: string; tooltip?: string; style?: string }[] = [] + export let columns: { + group?: string + label: string + tooltip?: string + style?: string + color_scale?: keyof typeof d3sc + }[] = [] export let higher_is_better: string[] = [] export let lower_is_better: string[] = [] export let sticky_cols: number[] = [0] // default to sticky first column @@ -26,19 +32,26 @@ $: clean_data = data?.filter?.((row) => Object.values(row).some((val) => val !== undefined)) ?? [] + // Helper to make column IDs (needed since column labels in different groups can be repeated) + const get_col_id = (col: { group?: string; label: string }) => + col.group ? `${col.label} (${col.group})` : col.label + function sort_rows(column: string) { - if ($sort_state.column !== column) { + const col = columns.find((c) => c.label === column) + const col_id = get_col_id(col) + + if ($sort_state.column !== col_id) { $sort_state = { - column, - ascending: lower_is_better.includes(column), + column: col_id, + ascending: lower_is_better.includes(col_id), } } else { $sort_state.ascending = !$sort_state.ascending } clean_data = clean_data.sort((row1, row2) => { - const val1 = row1[column] - const val2 = row2[column] + const val1 = row1[col_id] + const val2 = row2[col_id] if (val1 === val2) return 0 if (val1 === null || val1 === undefined) return 1 @@ -49,46 +62,77 @@ }) } - function calc_color(value: number | string | undefined, col: string) { - const values = clean_data.map((row) => row[col]) + function calc_color( + value: number | string | undefined, + col: { group?: string; label: string }, + ) { + if (col.color_scale === null || typeof value !== `number`) + return { bg: null, text: null } + + const col_id = get_col_id(col) + const values = clean_data.map((row) => row[col_id]) const range = [min(values) ?? 0, max(values) ?? 1] - if (lower_is_better.includes(col)) { + if (lower_is_better.includes(col_id)) { range.reverse() } - const colorScale = scaleSequential() - .domain(range) - .interpolator(d3sc.interpolateViridis) - const bg = colorScale(value) + // Use custom color scale if specified, otherwise fall back to viridis + const scale_name = col.color_scale || `interpolateViridis` + const interpolator = d3sc[scale_name] || d3sc.interpolateViridis + + const color_scale = scaleSequential().domain(range).interpolator(interpolator) + + const bg = color_scale(value) const text = choose_bw_for_contrast(null, bg) return { bg, text } } $: visible_columns = columns.filter((col) => !hide_cols.includes(col.label)) + + const sort_indicator = (col: { group?: string; label: string }) => { + const col_id = get_col_id(col) + if ($sort_state.column === col_id) { + return `${$sort_state.ascending ? `↑` : `↓`}` + } else if (higher_is_better.includes(col_id) || lower_is_better.includes(col_id)) { + return `${ + higher_is_better.includes(col_id) ? `↑` : `↓` + }` + } + return `` + }
+ + {#if visible_columns.some((col) => col.group)} + + + {#each visible_columns as col} + {#if !col.group} + + {/if} + {/if} + {/each} + + {/if} + - {#each visible_columns as { label, tooltip = null, style = null }, col_idx} - {/each} @@ -96,19 +140,19 @@ {#each clean_data as row (JSON.stringify(row))} - {#each visible_columns as { label, style = null }, col_idx} - {@const val = row[label]} - {@const color = calc_color(val, label)} + {#each visible_columns as col, col_idx} + {@const val = row[get_col_id(col)]} + {@const color = calc_color(val, col)}
+ {:else} + {@const group_cols = visible_columns.filter((c) => c.group === col.group)} + {#if columns.indexOf(col) === columns.findIndex((c) => c.group === col.group)} + {@html col.group}
sort_rows(label)} title={tooltip} {style}> - {@html label} + {#each visible_columns as col, col_idx} + sort_rows(col.label)} style={col.style}> + {@html col.label} + {@html sort_indicator(col)} {#if col_idx == 0 && sort_hint} {/if} - {#if $sort_state.column === label} - - {$sort_state.ascending ? `↑` : `↓`} - - {:else if higher_is_better.includes(label) || lower_is_better.includes(label)} - - {higher_is_better.includes(label) ? `↓` : `↑`} - - {/if}
- {#if typeof val === `number` && format[label]} - {@html pretty_num(val, format[label])} + {#if typeof val === `number` && format[get_col_id(col)]} + {@html pretty_num(val, format[get_col_id(col)])} {:else if [undefined, null].includes(val)} n/a {:else} @@ -176,4 +220,9 @@ td[data-sort-value] { cursor: default; } + + .group-header th { + border-bottom: 1px solid black; + text-align: center; + } diff --git a/site/src/lib/model-schema.d.ts b/site/src/lib/model-schema.d.ts index b08b6c43f..46afa64e0 100644 --- a/site/src/lib/model-schema.d.ts +++ b/site/src/lib/model-schema.d.ts @@ -89,12 +89,30 @@ export interface ModelMetadata { | { pred_file: string | null pred_col: string | null - rmsd?: number - n_sym_ops_mae?: number - symmetry_decrease?: number - symmetry_match?: number - symmetry_increase?: number - n_structures?: number + 'symprec=1e-5'?: { + rmsd?: number + n_sym_ops_mae?: number + symmetry_decrease?: number + symmetry_match?: number + symmetry_increase?: number + n_structures?: number + } + 'symprec=1e-3'?: { + rmsd?: number + n_sym_ops_mae?: number + symmetry_decrease?: number + symmetry_match?: number + symmetry_increase?: number + n_structures?: number + } + 'symprec=1e-2'?: { + rmsd?: number + n_sym_ops_mae?: number + symmetry_decrease?: number + symmetry_match?: number + symmetry_increase?: number + n_structures?: number + } } | ('not applicable' | 'not available') discovery?: { diff --git a/site/src/routes/+layout.svelte b/site/src/routes/+layout.svelte index ad8f0183a..808b3736c 100644 --- a/site/src/routes/+layout.svelte +++ b/site/src/routes/+layout.svelte @@ -5,7 +5,7 @@ import { repository } from '$site/package.json' import { CmdPalette } from 'svelte-multiselect' import Toc from 'svelte-toc' - import { CopyButton, GitHubCorner, PrevNext } from 'svelte-zoo' + import { CopyButton, GitHubCorner } from 'svelte-zoo' import '../app.css' const routes = Object.keys(import.meta.glob(`./*/+page.{svelte,md}`)).map( @@ -83,13 +83,4 @@ - - {href} → - ← {href} - -