Skip to content

Commit b79da6f

Browse files
committed
add phonon bands
1 parent 0b69d59 commit b79da6f

1 file changed

Lines changed: 206 additions & 0 deletions

File tree

src/lavello_mlips/phonon_bands.py

Lines changed: 206 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,206 @@
1+
# -*- coding: utf-8 -*-
2+
# Author; alin m elena, alin@elena.re
3+
# Contribs;
4+
# Date: 21-03-2025
5+
# ©alin m elena, GPL v3 https://www.gnu.org/licenses/gpl-3.0.en.html
6+
import yaml
7+
import numpy as np
8+
import matplotlib.pyplot as plt
9+
import h5py
10+
import lzma
11+
import argparse
12+
from pathlib import Path
13+
from sklearn.metrics import root_mean_squared_error as rmse
14+
15+
def main():
16+
# Parse arguments:
17+
parser = argparse.ArgumentParser(
18+
description="distributions"
19+
)
20+
parser.add_argument(
21+
"--bands",
22+
nargs="+",
23+
help="input bands files, output from some calculations",
24+
)
25+
parser.add_argument(
26+
"--title",
27+
default="xxx",
28+
help="title for the graph",
29+
)
30+
31+
parser.add_argument(
32+
"--fmin",
33+
type=float,
34+
help="min frequency",
35+
)
36+
37+
parser.add_argument(
38+
"--fmax",
39+
type=float,
40+
help="max frequency",
41+
)
42+
43+
parser.add_argument("--dft", help="input dft bands file for comparison", default=None)
44+
parser.add_argument("--ml_labels", nargs="+", help="labels for ml bands", default=None)
45+
parser.add_argument("--dft_label", help="label for dft bands", default="DFT")
46+
parser.add_argument("--save", help="File to save the plot", default=None)
47+
args = parser.parse_args()
48+
49+
title = args.title
50+
save_file = args.save
51+
bands = args.bands
52+
fmin = args.fmin
53+
fmax = args.fmax
54+
dft_file = args.dft
55+
ml_labels = args.ml_labels
56+
dft_label_text = args.dft_label
57+
58+
assert bands is not None and len(bands) > 0
59+
if ml_labels is not None:
60+
assert len(ml_labels) == len(bands), "Number of labels must match number of band files"
61+
62+
data_list = []
63+
nqpoint = None
64+
labels = None
65+
sp = None
66+
67+
for band_file in bands:
68+
p = Path(band_file)
69+
assert p.exists(), f"File {band_file} does not exist"
70+
71+
ext = p.suffix
72+
print(f"{ext}")
73+
data = None
74+
if ext == '.xz':
75+
with lzma.open(p, 'r') as file:
76+
dc = file.read()
77+
data = yaml.safe_load(dc)
78+
elif ext == '.hdf5':
79+
data = h5py.File(p, 'r')
80+
print(f"{list(data.keys())}")
81+
else:
82+
with open(p, 'r') as file:
83+
data = yaml.safe_load(file)
84+
85+
if ext==".hdf5":
86+
if nqpoint is None:
87+
nqpoint = data["nqpoint"][:][0]
88+
labels = [ [ y.decode('utf-8') for y in list(x)] for x in data['label'][:] ]
89+
sp = data['segment_nqpoint'][:][0]
90+
num_modes = data["natom"][()]*3
91+
f = data['frequency'][:]
92+
frequencies = f.reshape(-1,f.shape[-1])
93+
else:
94+
if nqpoint is None:
95+
nqpoint = data["nqpoint"]
96+
labels = data['labels']
97+
sp = data['segment_nqpoint'][0]
98+
num_modes = data["natom"]*3
99+
frequencies = np.array([[band["frequency"] for band in phonon["band"]] for phonon in data["phonon"]])
100+
101+
data_list.append({'frequencies': frequencies, 'num_modes': num_modes})
102+
103+
dft_data = None
104+
if dft_file:
105+
p = Path(dft_file)
106+
assert p.exists(), f"DFT file {dft_file} does not exist"
107+
ext = p.suffix
108+
if ext == '.xz':
109+
with lzma.open(p, 'r') as file:
110+
dc = file.read()
111+
data = yaml.safe_load(dc)
112+
elif ext == '.hdf5':
113+
data = h5py.File(p, 'r')
114+
else:
115+
with open(p, 'r') as file:
116+
data = yaml.safe_load(file)
117+
118+
if ext == ".hdf5":
119+
f = data['frequency'][:]
120+
dft_frequencies = f.reshape(-1, f.shape[-1])
121+
num_modes = data["natom"][()]*3
122+
else:
123+
dft_frequencies = np.array([[band["frequency"] for band in phonon["band"]] for phonon in data["phonon"]])
124+
num_modes = data["natom"]*3
125+
dft_data = {'frequencies': dft_frequencies, 'num_modes': num_modes}
126+
127+
k_points = np.arange(nqpoint)
128+
129+
npa= -1
130+
seg_labels = {}
131+
seg_tick = {}
132+
for i,seg in enumerate(labels):
133+
if i > 0 and seg[0] == labels[i-1][1]:
134+
seg_labels[npa] += [seg[1]]
135+
else:
136+
npa += 1
137+
seg_labels[npa] = seg
138+
139+
for k in seg_labels:
140+
if k > 0:
141+
seg_tick[k] = [ seg_tick[k-1][-1]+i*sp for i in range(len(seg_labels[k]))]
142+
else:
143+
seg_tick[k] = [ i*sp for i in range(len(seg_labels[k]))]
144+
145+
npa += 1
146+
fs=8
147+
fsize=40
148+
# Add constant 4 inches to height for title/legend, and 3 inches to width for massive Y-axis labels
149+
fig, axs = plt.subplots(nrows=1, ncols=npa, figsize=(npa*fs + 3, fs + 4), squeeze=False, subplot_kw=dict(box_aspect=1))
150+
151+
colors = plt.cm.tab10.colors
152+
153+
for i in range(npa):
154+
for idx, d in enumerate(data_list):
155+
frequencies = d['frequencies']
156+
num_modes = d['num_modes']
157+
c = colors[idx % len(colors)]
158+
for mode in range(num_modes):
159+
label = Path(bands[idx]).stem if mode == 0 and i == 0 else None
160+
axs[0,i].plot(k_points[seg_tick[i][0]:seg_tick[i][-1]], frequencies[seg_tick[i][0]:seg_tick[i][-1], mode], color=c, alpha=0.5, linewidth=3, label=label)
161+
162+
if dft_data is not None:
163+
dft_frequencies = dft_data['frequencies']
164+
num_modes = dft_data['num_modes']
165+
for mode in range(num_modes):
166+
label = dft_label_text if mode == 0 and i == 0 else None
167+
axs[0,i].plot(k_points[seg_tick[i][0]:seg_tick[i][-1]], dft_frequencies[seg_tick[i][0]:seg_tick[i][-1], mode], color='red', alpha=0.5, linewidth=3, linestyle='--', label=label)
168+
169+
axs[0,i].tick_params(axis='both', labelsize=fsize)
170+
axs[0,i].set_xticks(seg_tick[i], labels=seg_labels[i])
171+
if fmin is not None and fmax is not None:
172+
axs[0,i].set_ylim([fmin, fmax])
173+
axs[0,i].set_xlim([k_points[seg_tick[i][0]], np.max(k_points[seg_tick[i][0]:seg_tick[i][-1]])+1])
174+
if i == 0:
175+
axs[0,i].set_ylabel(f'Frequency [THz]', fontsize=fsize)
176+
else:
177+
axs[0,i].set_yticklabels([])
178+
179+
if len(data_list) >= 1:
180+
total_items = len(data_list)
181+
for idx, d in enumerate(data_list):
182+
label = ml_labels[idx] if ml_labels is not None else Path(bands[idx]).stem
183+
c = colors[idx % len(colors)]
184+
185+
srmse = ""
186+
if dft_data is not None:
187+
val_rmse = rmse(dft_data['frequencies'], d['frequencies'])
188+
srmse = f" (RMSE: {val_rmse:.4f})"
189+
print(f"{label} RMSE: {val_rmse:.4f}")
190+
191+
axs[0,0].text(0.0, 1.05 + (total_items - 1 - idx)*0.08, label + srmse, color=c, fontsize=fsize//2,
192+
transform=axs[0,0].transAxes, verticalalignment='bottom')
193+
194+
if dft_data is not None:
195+
axs[0,0].text(0.0, 1.05 + total_items*0.08, dft_label_text, color='red', fontsize=fsize//2,
196+
transform=axs[0,0].transAxes, verticalalignment='bottom')
197+
198+
plt.suptitle(title,fontsize=fsize)
199+
plt.tight_layout(pad=1.5)
200+
if args.save is None:
201+
plt.show()
202+
else:
203+
plt.savefig(f"{save_file}",transparent=True, bbox_inches='tight')
204+
205+
if __name__ == "__main__":
206+
main()

0 commit comments

Comments
 (0)