From 8bbedee5632dbd691fe4db2f895359710db5260c Mon Sep 17 00:00:00 2001 From: Noam Bernstein Date: Wed, 31 Dec 2025 09:18:34 -0500 Subject: [PATCH 1/3] Add snapshot_save_old global param, and append to truncated existing clone history file in restarts --- pymatnext/cli/sample.py | 66 +++++++++++++++++++++----------------- pymatnext/ns.py | 4 +-- pymatnext/sample_params.py | 1 + 3 files changed, 40 insertions(+), 31 deletions(-) diff --git a/pymatnext/cli/sample.py b/pymatnext/cli/sample.py index 73d86bb..aadfd48 100755 --- a/pymatnext/cli/sample.py +++ b/pymatnext/cli/sample.py @@ -166,42 +166,43 @@ def sample(args, MPI, NS_comm, walker_comm): traj_interval = params_global["traj_interval"] sample_interval = params_global["sample_interval"] snapshot_interval = params_global["snapshot_interval"] + snapshot_save_old = params_global["snapshot_save_old"] stdout_report_interval_s = params_global["stdout_report_interval_s"] step_size_tune_interval = params_step_size_tune["interval"] - # WARNING: clone_history_file not restartable - if params_global["clone_history"]: - clone_history_file = open(f"{output_filename_prefix}.clone_history", "w") - clone_history_file.write(f'# {{"fields": ["loop_iter", "clone_source", "clone_target"], "n_walkers": {ns.n_configs_global}}}\n') - else: - clone_history_file = None ns_file_name = f"{output_filename_prefix}.NS_samples" + clone_history_file_name = f"{output_filename_prefix}.clone_history" if params_global["clone_history"] else None traj_file_name = f"{output_filename_prefix}.traj{config_suffix}" if NS_comm.rank == 0: if ns.snapshot_iter >= 0: - # snapshot, truncate existing NS_samples and .traj.suffix files + # snapshot, truncate existing .NS_samples, .clone_history, and .traj. files # NOTE: does this code belong here? Maybe refactor to a function, maybe # move trajectory truncation into NSConfig or something? - # truncate .NS_samples file - f_samples = open(ns_file_name, "r+") - # skip header - _ = f_samples.readline() - line_i = None - while True: - line = f_samples.readline() - if not line: - raise RuntimeError(f"Failed to find enough lines in .NS_samples file (last line {line_i}) to reach snapshot iter {ns.snapshot_iter}") - - line_i = int(line.split()[0]) - if line_i + sample_interval > ns.snapshot_iter: - cur_pos = f_samples.tell() - f_samples.truncate(cur_pos) - break - - f_samples.close() + def _truncate_file(filename, n_header, sample_interval): + warnings.warn(f"Truncating {filename}") + # truncate .NS_samples file + with open(ns_file_name, "r+") as fd: + # skip header + for _ in range(n_header): + _ = fd.readline() + line_i = None + while True: + line = fd.readline() + if not line: + raise RuntimeError(f"Failed to find enough lines in {filename} (last line {line_i}) to reach snapshot iter {ns.snapshot_iter}") + + line_i = int(line.split()[0]) + if line_i + sample_interval > ns.snapshot_iter: + warnings.warn(f"Truncated {filename} at iter {line_i}") + cur_pos = fd.tell() + fd.truncate(cur_pos) + break + + _truncate_file(ns_file_name, n_header=1, sample_interval=sample_interval) + _truncate_file(clone_history_file_name, n_header=1, sample_interval=1) # truncate .traj.suffix file f_configs = open(traj_file_name, "r+") @@ -220,15 +221,22 @@ def sample(args, MPI, NS_comm, walker_comm): ns_file = open(ns_file_name, "a") traj_file = open(traj_file_name, "a") + clone_history_file = open(clone_history_file_name, "a") else: - # run from start, open new .NS_samples and .traj.suffix files - + # run from start, open new .NS_samples, .clone_history, and .traj. files + # write header as needed ns_file = open(ns_file_name, "w") header_dict = { "n_walkers": ns.n_configs_global, "n_cull": 1 } header_dict.update(ns.local_configs[0].header_dict()) ns_file.write("# " + " ".join(json.dumps(header_dict, indent=0).splitlines()) + "\n") + if clone_history_file_name: + clone_history_file = open(clone_history_fine_name, "w") + clone_history_file.write(f'# {{"fields": ["loop_iter", "clone_source", "clone_target"], "n_walkers": {ns.n_configs_global}}}\n') + else: + clone_history_file = None + traj_file = open(traj_file_name, "w") max_iter = params_global["max_iter"] @@ -265,7 +273,7 @@ def sample(args, MPI, NS_comm, walker_comm): global_ind_of_clone_source = (global_ind_of_max + 1 + ns.rng_global.integers(0, ns.n_configs_global - 1)) % ns.n_configs_global rank_of_clone_source, local_ind_of_clone_source = ns.local_ind(global_ind_of_clone_source) - if clone_history_file is not None: + if clone_history_file: clone_history_file.write(f"{loop_iter} {global_ind_of_clone_source} {global_ind_of_max}\n") if loop_iter % 1000 == 1000 - 1: clone_history_file.flush() @@ -326,11 +334,11 @@ def sample(args, MPI, NS_comm, walker_comm): # NOTE: should this be a time rather than iteration interval? That'd basically be straightforward, # except it would require an additional communication so all processes agree that it's time for a snapshot if loop_iter > 0 and snapshot_interval > 0 and loop_iter % snapshot_interval == 0: - ns.snapshot(loop_iter, output_filename_prefix) + ns.snapshot(loop_iter, output_filename_prefix, save_old=snapshot_save_old) loop_iter += 1 - if clone_history_file is not None: + if clone_history_file: clone_history_file.close() diff --git a/pymatnext/ns.py b/pymatnext/ns.py index af89d75..0b6a2e8 100644 --- a/pymatnext/ns.py +++ b/pymatnext/ns.py @@ -480,11 +480,11 @@ def snapshot(self, loop_iter, output_filename_prefix, save_old=2): output_filename_prefix: str initial part of filenames that will be written to save_old: int, default 2 - number of old snapshots to save + number of old snapshots to save, negative to save all """ if self.comm.rank == 0: old_state_files = NS._old_state_files(output_filename_prefix) - if len(old_state_files) > save_old - 1: + if save_old >= 0 and len(old_state_files) > save_old - 1: old_state_files = old_state_files[:-(save_old-1)] else: old_state_files = [] diff --git a/pymatnext/sample_params.py b/pymatnext/sample_params.py index 8c3637f..3031184 100644 --- a/pymatnext/sample_params.py +++ b/pymatnext/sample_params.py @@ -11,6 +11,7 @@ "sample_interval": 1, "traj_interval": 100, "snapshot_interval": 10000, + "snapshot_save_old": 2, "step_size_tune": { "interval": 1000, "n_configs": 1, From 755576c0acd35b58c8baabe4038bf742e99503ca Mon Sep 17 00:00:00 2001 From: Noam Bernstein Date: Wed, 31 Dec 2025 11:00:08 -0500 Subject: [PATCH 2/3] typo --- pymatnext/cli/sample.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymatnext/cli/sample.py b/pymatnext/cli/sample.py index aadfd48..0141b0a 100755 --- a/pymatnext/cli/sample.py +++ b/pymatnext/cli/sample.py @@ -232,7 +232,7 @@ def _truncate_file(filename, n_header, sample_interval): ns_file.write("# " + " ".join(json.dumps(header_dict, indent=0).splitlines()) + "\n") if clone_history_file_name: - clone_history_file = open(clone_history_fine_name, "w") + clone_history_file = open(clone_history_file_name, "w") clone_history_file.write(f'# {{"fields": ["loop_iter", "clone_source", "clone_target"], "n_walkers": {ns.n_configs_global}}}\n') else: clone_history_file = None From 61b11bfc2b7222fe446ce11a368883ebf88d0bd8 Mon Sep 17 00:00:00 2001 From: Noam Bernstein Date: Wed, 31 Dec 2025 12:08:16 -0500 Subject: [PATCH 3/3] refactor truncation of samples, clone history files --- pymatnext/cli/sample.py | 46 +++++++++++++++------------------------ pymatnext/ns_utils.py | 0 pymatnext/sample_utils.py | 23 ++++++++++++++++++++ 3 files changed, 40 insertions(+), 29 deletions(-) mode change 100755 => 100644 pymatnext/ns_utils.py mode change 100755 => 100644 pymatnext/sample_utils.py diff --git a/pymatnext/cli/sample.py b/pymatnext/cli/sample.py index 0141b0a..e33fbf0 100755 --- a/pymatnext/cli/sample.py +++ b/pymatnext/cli/sample.py @@ -18,6 +18,7 @@ from pymatnext.ns import NS from pymatnext.params import check_fill_defaults from pymatnext.sample_params import param_defaults +from pymatnext.sample_utils import truncate_file_first_col_iter from pymatnext.loop_exit import NSLoopExit @@ -175,36 +176,15 @@ def sample(args, MPI, NS_comm, walker_comm): traj_file_name = f"{output_filename_prefix}.traj{config_suffix}" if NS_comm.rank == 0: + # set up I/O if ns.snapshot_iter >= 0: # snapshot, truncate existing .NS_samples, .clone_history, and .traj. files - # NOTE: does this code belong here? Maybe refactor to a function, maybe - # move trajectory truncation into NSConfig or something? - - def _truncate_file(filename, n_header, sample_interval): - warnings.warn(f"Truncating {filename}") - # truncate .NS_samples file - with open(ns_file_name, "r+") as fd: - # skip header - for _ in range(n_header): - _ = fd.readline() - line_i = None - while True: - line = fd.readline() - if not line: - raise RuntimeError(f"Failed to find enough lines in {filename} (last line {line_i}) to reach snapshot iter {ns.snapshot_iter}") - - line_i = int(line.split()[0]) - if line_i + sample_interval > ns.snapshot_iter: - warnings.warn(f"Truncated {filename} at iter {line_i}") - cur_pos = fd.tell() - fd.truncate(cur_pos) - break - - _truncate_file(ns_file_name, n_header=1, sample_interval=sample_interval) - _truncate_file(clone_history_file_name, n_header=1, sample_interval=1) - - # truncate .traj.suffix file + truncate_file_first_col_iter(ns_file_name, n_header=1, sample_interval=sample_interval, max_iter=ns.snapshot_iter) + truncate_file_first_col_iter(clone_history_file_name, n_header=1, sample_interval=1, max_iter=ns.snapshot_iter) + + # NOTE: should move trajectory truncation into NSConfig, since it's config file-format specific + # truncate .traj. file f_configs = open(traj_file_name, "r+") while True: try: @@ -238,6 +218,10 @@ def _truncate_file(filename, n_header, sample_interval): clone_history_file = None traj_file = open(traj_file_name, "w") + else: + ns_file = None + traj_file = None + clone_history_file = None max_iter = params_global["max_iter"] if max_iter > 0: @@ -258,7 +242,7 @@ def _truncate_file(filename, n_header, sample_interval): global_ind_of_max = ns.global_ind(ns.rank_of_max, ns.local_ind_of_max) # write quantities for max config which will be culled below - if NS_comm.rank == 0 and sample_interval > 0 and loop_iter % sample_interval == 0: + if ns_file and sample_interval > 0 and loop_iter % sample_interval == 0: ns_file.write(f"{loop_iter} {global_ind_of_max} {ns.max_val:.10f} " + " ".join([f"{quant:.10f}" for quant in ns.max_quants]) + "\n") ns_file.flush() @@ -280,7 +264,7 @@ def _truncate_file(filename, n_header, sample_interval): # write max to traj file if traj_interval > 0 and loop_iter % traj_interval == 0: - if NS_comm.rank == 0: + if traj_file: # only head node writes if NS_comm.rank == ns.rank_of_max: # already local @@ -338,8 +322,12 @@ def _truncate_file(filename, n_header, sample_interval): loop_iter += 1 + if ns_file: + ns_file.close() if clone_history_file: clone_history_file.close() + if traj_file: + traj_file.close() def main(args_list=None, mpi_finalize=True): diff --git a/pymatnext/ns_utils.py b/pymatnext/ns_utils.py old mode 100755 new mode 100644 diff --git a/pymatnext/sample_utils.py b/pymatnext/sample_utils.py old mode 100755 new mode 100644 index 790b8a8..22be3c2 --- a/pymatnext/sample_utils.py +++ b/pymatnext/sample_utils.py @@ -1,3 +1,5 @@ +import warnings + class NullComm(): """Fake alternative to mpi4py.MPI.Comm for serial run, which implements needed subset of mpi4py.MPI.Comm methods @@ -36,3 +38,24 @@ class MPI: def Finalize(): return + + +def truncate_file_first_col_iter(filename, n_header, sample_interval, max_iter): + warnings.warn(f"Truncating {filename}") + # truncate file after first col exceeds iteration + with open(filename, "r+") as fd: + # skip header + for _ in range(n_header): + _ = fd.readline() + line_i = None + while True: + line = fd.readline() + if not line: + raise RuntimeError(f"Failed to find enough lines in {filename} (last line {line_i}) to reach snapshot iter {max_iter}") + + line_i = int(line.split()[0]) + if line_i + sample_interval > max_iter: + warnings.warn(f"Truncated {filename} at iter {line_i}") + cur_pos = fd.tell() + fd.truncate(cur_pos) + break