From 42ebece6dcdcfeb670ae0acdfc99869bc78e2846 Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Mon, 22 Jan 2024 20:49:30 +0100 Subject: [PATCH] add test for NEBAnalysis.get_plot() (#3570) --- pymatgen/analysis/transition_state.py | 29 ++++++++++--------------- tests/analysis/test_transition_state.py | 18 +++++++++++---- 2 files changed, 25 insertions(+), 22 deletions(-) diff --git a/pymatgen/analysis/transition_state.py b/pymatgen/analysis/transition_state.py index 5009ae4308b..7867d3a1691 100644 --- a/pymatgen/analysis/transition_state.py +++ b/pymatgen/analysis/transition_state.py @@ -166,32 +166,25 @@ def get_plot(self, normalize_rxn_coordinate: bool = True, label_barrier: bool = plt.Axes: matplotlib axes object. """ ax = pretty_plot(12, 8) - scale = 1 if not normalize_rxn_coordinate else 1 / self.r[-1] - x = np.arange(0, np.max(self.r), 0.01) - y = self.spline(x) * 1000 + scale = 1 / self.r[-1] if normalize_rxn_coordinate else 1 + xs = np.arange(0, np.max(self.r), 0.01) + ys = self.spline(xs) * 1000 relative_energies = self.energies - self.energies[0] - ax.plot( - self.r * scale, - relative_energies * 1000, - "ro", - x * scale, - y, - "k-", - linewidth=2, - markersize=10, - ) - ax.set_xlabel("Reaction coordinate") + ax.plot(self.r * scale, relative_energies * 1000, "ro", xs * scale, ys, "k-", linewidth=2, markersize=10) + + ax.set_xlabel("Reaction Coordinate") ax.set_ylabel("Energy (meV)") - ax.set_ylim((np.min(y) - 10, np.max(y) * 1.02 + 20)) + ax.set_ylim((np.min(ys) - 10, np.max(ys) * 1.02 + 20)) if label_barrier: - data = zip(x * scale, y) + data = zip(xs * scale, ys) barrier = max(data, key=lambda d: d[1]) - ax.plot([0, barrier[0]], [barrier[1], barrier[1]], "k--") + ax.plot([0, barrier[0]], [barrier[1], barrier[1]], "k--", linewidth=0.5) ax.annotate( - f"{np.max(y) - np.min(y):.0f} meV", + f"{np.max(ys) - np.min(ys):.0f} meV", xy=(barrier[0] / 2, barrier[1] * 1.02), xytext=(barrier[0] / 2, barrier[1] * 1.02), horizontalalignment="center", + fontsize=18, ) plt.tight_layout() return ax diff --git a/tests/analysis/test_transition_state.py b/tests/analysis/test_transition_state.py index f27cbd0eb7e..0b0267c5577 100644 --- a/tests/analysis/test_transition_state.py +++ b/tests/analysis/test_transition_state.py @@ -2,6 +2,7 @@ import json +from matplotlib import pyplot as plt from numpy.testing import assert_allclose from pymatgen.analysis.transition_state import NEBAnalysis, combine_neb_plots @@ -19,12 +20,12 @@ __date__ = "2/5/16" -test_dir = f"{TEST_FILES_DIR}/neb_analysis" +TEST_DIR = f"{TEST_FILES_DIR}/neb_analysis" class TestNEBAnalysis(PymatgenTest): def test_run(self): - neb_analysis1 = NEBAnalysis.from_dir(f"{test_dir}/neb1/neb") + neb_analysis1 = NEBAnalysis.from_dir(f"{TEST_DIR}/neb1/neb") neb_analysis1_from_dict = NEBAnalysis.from_dict(neb_analysis1.as_dict()) json_data = json.dumps(neb_analysis1.as_dict()) @@ -47,7 +48,7 @@ def test_run(self): neb_analysis1.setup_spline(spline_options={"saddle_point": "zero_slope"}) assert_allclose(neb_analysis1.get_extrema()[1][0], (0.50023335723480078, 325.20003984140203)) - with open(f"{test_dir}/neb2/neb_analysis2.json") as f: + with open(f"{TEST_DIR}/neb2/neb_analysis2.json") as f: neb_analysis2_dict = json.load(f) neb_analysis2 = NEBAnalysis.from_dict(neb_analysis2_dict) assert_allclose(neb_analysis2.get_extrema()[1][0], (0.37255257367467326, 562.40825334519991)) @@ -56,6 +57,15 @@ def test_run(self): assert_allclose(neb_analysis2.get_extrema()[1][0], (0.30371133723478794, 528.46229631648691)) def test_combine_neb_plots(self): - neb_dir = f"{test_dir}/neb1/neb" + neb_dir = f"{TEST_DIR}/neb1/neb" neb_analysis = NEBAnalysis.from_dir(neb_dir) combine_neb_plots([neb_analysis, neb_analysis]) + + def test_get_plot(self): + neb_dir = f"{TEST_DIR}/neb1/neb" + neb_analysis = NEBAnalysis.from_dir(neb_dir) + ax = neb_analysis.get_plot() + assert isinstance(ax, plt.Axes) + assert ax.texts[0].get_text() == "326 meV", "Unexpected annotation text" + assert ax.get_xlabel() == "Reaction Coordinate" + assert ax.get_ylabel() == "Energy (meV)"