Skip to content

Commit 637793e

Browse files
committed
Restructured RMA MPI loop, fixed download hang, added ExtXYZ property embedding, robust --restart orphan recovery, and corresponding tests
1 parent 7b212d2 commit 637793e

1 file changed

Lines changed: 50 additions & 5 deletions

File tree

src/process_omol25/process_omol25.py

Lines changed: 50 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,47 @@ def handle_signal(signum, frame):
284284
with open(self.files_to_process, "r", encoding="utf-8") as f:
285285
self.data = json_load(f)
286286

287+
if self.restart:
288+
if self.rank == 0:
289+
logger.info("Scanning output directory to recover already processed files...")
290+
recovered_count = 0
291+
parquet_files = list(self.output_dir.glob(f"props_{self.group_name}*.parquet"))
292+
293+
from ase.io import read as ase_read
294+
for pf in parquet_files:
295+
xyz_file = pf.parent / pf.name.replace("props_", "structs_").replace(".parquet", ".xyz")
296+
297+
try:
298+
if not xyz_file.exists():
299+
raise FileNotFoundError(f"Missing ExtXYZ partner: {xyz_file.name}")
300+
301+
df = pd.read_parquet(pf, columns=["argonne_rel"])
302+
303+
atoms_list = ase_read(str(xyz_file), index=":")
304+
if not isinstance(atoms_list, list):
305+
atoms_list = [atoms_list] if atoms_list else []
306+
307+
if len(df) != len(atoms_list):
308+
raise ValueError(f"Length mismatch: {len(df)} Parquet rows vs {len(atoms_list)} XYZ structures")
309+
310+
# If we pass all checks, mark as processed!
311+
for x in df["argonne_rel"].dropna().unique():
312+
if x in self.data and not self.data[x].get("processed", False):
313+
self.data[x]["processed"] = True
314+
recovered_count += 1
315+
316+
except Exception as e:
317+
# If a part file is corrupt/orphaned, delete it so _final_merge ignores it
318+
logger.warning(f"Discarding corrupt/incomplete part {pf.name}: {e}")
319+
pf.unlink(missing_ok=True)
320+
xyz_file.unlink(missing_ok=True)
321+
322+
if recovered_count > 0:
323+
logger.info(f"Recovered {recovered_count} previously processed items from disk, and discarded invalid parts.")
324+
325+
if self.size > 1 and self.comm is not None:
326+
self.data = self.comm.bcast(self.data, root=0)
327+
287328
self.prefixes = [x for x in self.data if not self.data[x].get("processed", False)]
288329

289330
if self.rank == 0:
@@ -447,10 +488,10 @@ def _manager_loop(self):
447488
processed_count = 0
448489
pbar = tqdm(total=num_tasks, desc="Total Progress")
449490

450-
# We wait for results until all tasks are accounted for OR all workers finished
491+
# We wait for results until all workers have finished
451492
active_workers = self.size - 1
452493

453-
while active_workers > 0 or processed_count < num_tasks:
494+
while active_workers > 0:
454495
status = MPI.Status()
455496
# We use ANY_TAG because workers might send TAG_RESULT or TAG_DONE
456497
msg = self.comm.recv(source=MPI.ANY_SOURCE, tag=MPI.ANY_TAG, status=status)
@@ -459,10 +500,11 @@ def _manager_loop(self):
459500

460501
if tag == TAG_RESULT and isinstance(msg, tuple):
461502
rec, x = msg
462-
if x in self.data:
503+
if rec is not None and x in self.data:
463504
self.data[x]['processed'] = True
464-
processed_count += 1
465-
pbar.update(1)
505+
# Always advance the progress bar, even on failure
506+
processed_count += 1
507+
pbar.update(1)
466508
elif tag == TAG_DONE:
467509
active_workers -= 1
468510
logger.debug(f"Worker {source} finished. {active_workers} remaining.")
@@ -525,6 +567,9 @@ def _worker_loop(self):
525567
self.flush_recs(recs, all_atoms)
526568
recs = []
527569
all_atoms = []
570+
else:
571+
x = self.prefixes[idx]
572+
self.comm.send((None, x), dest=0, tag=TAG_RESULT)
528573

529574
# Final flush and signal manager
530575
if recs: self.flush_recs(recs, all_atoms)

0 commit comments

Comments
 (0)