@@ -284,6 +284,47 @@ def handle_signal(signum, frame):
284284 with open (self .files_to_process , "r" , encoding = "utf-8" ) as f :
285285 self .data = json_load (f )
286286
287+ if self .restart :
288+ if self .rank == 0 :
289+ logger .info ("Scanning output directory to recover already processed files..." )
290+ recovered_count = 0
291+ parquet_files = list (self .output_dir .glob (f"props_{ self .group_name } *.parquet" ))
292+
293+ from ase .io import read as ase_read
294+ for pf in parquet_files :
295+ xyz_file = pf .parent / pf .name .replace ("props_" , "structs_" ).replace (".parquet" , ".xyz" )
296+
297+ try :
298+ if not xyz_file .exists ():
299+ raise FileNotFoundError (f"Missing ExtXYZ partner: { xyz_file .name } " )
300+
301+ df = pd .read_parquet (pf , columns = ["argonne_rel" ])
302+
303+ atoms_list = ase_read (str (xyz_file ), index = ":" )
304+ if not isinstance (atoms_list , list ):
305+ atoms_list = [atoms_list ] if atoms_list else []
306+
307+ if len (df ) != len (atoms_list ):
308+ raise ValueError (f"Length mismatch: { len (df )} Parquet rows vs { len (atoms_list )} XYZ structures" )
309+
310+ # If we pass all checks, mark as processed!
311+ for x in df ["argonne_rel" ].dropna ().unique ():
312+ if x in self .data and not self .data [x ].get ("processed" , False ):
313+ self .data [x ]["processed" ] = True
314+ recovered_count += 1
315+
316+ except Exception as e :
317+ # If a part file is corrupt/orphaned, delete it so _final_merge ignores it
318+ logger .warning (f"Discarding corrupt/incomplete part { pf .name } : { e } " )
319+ pf .unlink (missing_ok = True )
320+ xyz_file .unlink (missing_ok = True )
321+
322+ if recovered_count > 0 :
323+ logger .info (f"Recovered { recovered_count } previously processed items from disk, and discarded invalid parts." )
324+
325+ if self .size > 1 and self .comm is not None :
326+ self .data = self .comm .bcast (self .data , root = 0 )
327+
287328 self .prefixes = [x for x in self .data if not self .data [x ].get ("processed" , False )]
288329
289330 if self .rank == 0 :
@@ -447,10 +488,10 @@ def _manager_loop(self):
447488 processed_count = 0
448489 pbar = tqdm (total = num_tasks , desc = "Total Progress" )
449490
450- # We wait for results until all tasks are accounted for OR all workers finished
491+ # We wait for results until all workers have finished
451492 active_workers = self .size - 1
452493
453- while active_workers > 0 or processed_count < num_tasks :
494+ while active_workers > 0 :
454495 status = MPI .Status ()
455496 # We use ANY_TAG because workers might send TAG_RESULT or TAG_DONE
456497 msg = self .comm .recv (source = MPI .ANY_SOURCE , tag = MPI .ANY_TAG , status = status )
@@ -459,10 +500,11 @@ def _manager_loop(self):
459500
460501 if tag == TAG_RESULT and isinstance (msg , tuple ):
461502 rec , x = msg
462- if x in self .data :
503+ if rec is not None and x in self .data :
463504 self .data [x ]['processed' ] = True
464- processed_count += 1
465- pbar .update (1 )
505+ # Always advance the progress bar, even on failure
506+ processed_count += 1
507+ pbar .update (1 )
466508 elif tag == TAG_DONE :
467509 active_workers -= 1
468510 logger .debug (f"Worker { source } finished. { active_workers } remaining." )
@@ -525,6 +567,9 @@ def _worker_loop(self):
525567 self .flush_recs (recs , all_atoms )
526568 recs = []
527569 all_atoms = []
570+ else :
571+ x = self .prefixes [idx ]
572+ self .comm .send ((None , x ), dest = 0 , tag = TAG_RESULT )
528573
529574 # Final flush and signal manager
530575 if recs : self .flush_recs (recs , all_atoms )
0 commit comments