diff --git a/pymatnext/cli/sample.py b/pymatnext/cli/sample.py index 73d86bb..e33fbf0 100755 --- a/pymatnext/cli/sample.py +++ b/pymatnext/cli/sample.py @@ -18,6 +18,7 @@ from pymatnext.ns import NS from pymatnext.params import check_fill_defaults from pymatnext.sample_params import param_defaults +from pymatnext.sample_utils import truncate_file_first_col_iter from pymatnext.loop_exit import NSLoopExit @@ -166,44 +167,24 @@ def sample(args, MPI, NS_comm, walker_comm): traj_interval = params_global["traj_interval"] sample_interval = params_global["sample_interval"] snapshot_interval = params_global["snapshot_interval"] + snapshot_save_old = params_global["snapshot_save_old"] stdout_report_interval_s = params_global["stdout_report_interval_s"] step_size_tune_interval = params_step_size_tune["interval"] - # WARNING: clone_history_file not restartable - if params_global["clone_history"]: - clone_history_file = open(f"{output_filename_prefix}.clone_history", "w") - clone_history_file.write(f'# {{"fields": ["loop_iter", "clone_source", "clone_target"], "n_walkers": {ns.n_configs_global}}}\n') - else: - clone_history_file = None ns_file_name = f"{output_filename_prefix}.NS_samples" + clone_history_file_name = f"{output_filename_prefix}.clone_history" if params_global["clone_history"] else None traj_file_name = f"{output_filename_prefix}.traj{config_suffix}" if NS_comm.rank == 0: + # set up I/O if ns.snapshot_iter >= 0: - # snapshot, truncate existing NS_samples and .traj.suffix files - - # NOTE: does this code belong here? Maybe refactor to a function, maybe - # move trajectory truncation into NSConfig or something? 
- - # truncate .NS_samples file - f_samples = open(ns_file_name, "r+") - # skip header - _ = f_samples.readline() - line_i = None - while True: - line = f_samples.readline() - if not line: - raise RuntimeError(f"Failed to find enough lines in .NS_samples file (last line {line_i}) to reach snapshot iter {ns.snapshot_iter}") - - line_i = int(line.split()[0]) - if line_i + sample_interval > ns.snapshot_iter: - cur_pos = f_samples.tell() - f_samples.truncate(cur_pos) - break + # snapshot, truncate existing .NS_samples, .clone_history, and .traj. files - f_samples.close() + truncate_file_first_col_iter(ns_file_name, n_header=1, sample_interval=sample_interval, max_iter=ns.snapshot_iter) + truncate_file_first_col_iter(clone_history_file_name, n_header=1, sample_interval=1, max_iter=ns.snapshot_iter) - # truncate .traj.suffix file + # NOTE: should move trajectory truncation into NSConfig, since it's config file-format specific + # truncate .traj. file f_configs = open(traj_file_name, "r+") while True: try: @@ -220,16 +201,27 @@ def sample(args, MPI, NS_comm, walker_comm): ns_file = open(ns_file_name, "a") traj_file = open(traj_file_name, "a") + clone_history_file = open(clone_history_file_name, "a") else: - # run from start, open new .NS_samples and .traj.suffix files - + # run from start, open new .NS_samples, .clone_history, and .traj. 
files + # write header as needed ns_file = open(ns_file_name, "w") header_dict = { "n_walkers": ns.n_configs_global, "n_cull": 1 } header_dict.update(ns.local_configs[0].header_dict()) ns_file.write("# " + " ".join(json.dumps(header_dict, indent=0).splitlines()) + "\n") + if clone_history_file_name: + clone_history_file = open(clone_history_file_name, "w") + clone_history_file.write(f'# {{"fields": ["loop_iter", "clone_source", "clone_target"], "n_walkers": {ns.n_configs_global}}}\n') + else: + clone_history_file = None + traj_file = open(traj_file_name, "w") + else: + ns_file = None + traj_file = None + clone_history_file = None max_iter = params_global["max_iter"] if max_iter > 0: @@ -250,7 +242,7 @@ def sample(args, MPI, NS_comm, walker_comm): global_ind_of_max = ns.global_ind(ns.rank_of_max, ns.local_ind_of_max) # write quantities for max config which will be culled below - if NS_comm.rank == 0 and sample_interval > 0 and loop_iter % sample_interval == 0: + if ns_file and sample_interval > 0 and loop_iter % sample_interval == 0: ns_file.write(f"{loop_iter} {global_ind_of_max} {ns.max_val:.10f} " + " ".join([f"{quant:.10f}" for quant in ns.max_quants]) + "\n") ns_file.flush() @@ -265,14 +257,14 @@ def sample(args, MPI, NS_comm, walker_comm): global_ind_of_clone_source = (global_ind_of_max + 1 + ns.rng_global.integers(0, ns.n_configs_global - 1)) % ns.n_configs_global rank_of_clone_source, local_ind_of_clone_source = ns.local_ind(global_ind_of_clone_source) - if clone_history_file is not None: + if clone_history_file: clone_history_file.write(f"{loop_iter} {global_ind_of_clone_source} {global_ind_of_max}\n") if loop_iter % 1000 == 1000 - 1: clone_history_file.flush() # write max to traj file if traj_interval > 0 and loop_iter % traj_interval == 0: - if NS_comm.rank == 0: + if traj_file: # only head node writes if NS_comm.rank == ns.rank_of_max: # already local @@ -326,12 +318,16 @@ def sample(args, MPI, NS_comm, walker_comm): # NOTE: should this be a time rather 
than iteration interval? That'd basically be straightforward, # except it would require an additional communication so all processes agree that it's time for a snapshot if loop_iter > 0 and snapshot_interval > 0 and loop_iter % snapshot_interval == 0: - ns.snapshot(loop_iter, output_filename_prefix) + ns.snapshot(loop_iter, output_filename_prefix, save_old=snapshot_save_old) loop_iter += 1 - if clone_history_file is not None: + if ns_file: + ns_file.close() + if clone_history_file: clone_history_file.close() + if traj_file: + traj_file.close() def main(args_list=None, mpi_finalize=True): diff --git a/pymatnext/ns.py b/pymatnext/ns.py index af89d75..0b6a2e8 100644 --- a/pymatnext/ns.py +++ b/pymatnext/ns.py @@ -480,11 +480,11 @@ def snapshot(self, loop_iter, output_filename_prefix, save_old=2): output_filename_prefix: str initial part of filenames that will be written to save_old: int, default 2 - number of old snapshots to save + number of old snapshots to save, negative to save all """ if self.comm.rank == 0: old_state_files = NS._old_state_files(output_filename_prefix) - if len(old_state_files) > save_old - 1: + if save_old >= 0 and len(old_state_files) > save_old - 1: old_state_files = old_state_files[:-(save_old-1)] else: old_state_files = [] diff --git a/pymatnext/ns_utils.py b/pymatnext/ns_utils.py old mode 100755 new mode 100644 diff --git a/pymatnext/sample_params.py b/pymatnext/sample_params.py index 8c3637f..3031184 100644 --- a/pymatnext/sample_params.py +++ b/pymatnext/sample_params.py @@ -11,6 +11,7 @@ "sample_interval": 1, "traj_interval": 100, "snapshot_interval": 10000, + "snapshot_save_old": 2, "step_size_tune": { "interval": 1000, "n_configs": 1, diff --git a/pymatnext/sample_utils.py b/pymatnext/sample_utils.py old mode 100755 new mode 100644 index 790b8a8..22be3c2 --- a/pymatnext/sample_utils.py +++ b/pymatnext/sample_utils.py @@ -1,3 +1,5 @@ +import warnings + class NullComm(): """Fake alternative to mpi4py.MPI.Comm for serial run, which 
def truncate_file_first_col_iter(filename, n_header, sample_interval, max_iter):
    """Truncate an iteration-indexed text file so it ends at a snapshot iteration

    The file is expected to consist of ``n_header`` header lines followed by
    data lines whose first whitespace-separated column is an integer iteration
    number.  Lines are read in order, and the file is cut immediately after
    the first data line for which ``iteration + sample_interval > max_iter``,
    i.e. that line is kept and everything after it is discarded.

    Parameters
    ----------
    filename: str or None
        file to truncate in place.  ``None`` (optional output that is not
        being written) is accepted and results in no action
    n_header: int
        number of header lines to skip before the data lines
    sample_interval: int
        iteration interval between consecutive data lines
    max_iter: int
        iteration (e.g. of the snapshot being restarted from) that the file
        should end at

    Raises
    ------
    RuntimeError
        if the end of the file is reached before any data line satisfies the
        truncation condition
    """
    if filename is None:
        # optional output is disabled, so there is no file to truncate
        return

    warnings.warn(f"Truncating {filename}")
    with open(filename, "r+") as fd:
        # skip header
        for _ in range(n_header):
            fd.readline()

        line_i = None
        # explicit readline() rather than "for line in fd", because tell() is
        # not available while a text file is being iterated over
        while True:
            line = fd.readline()
            if not line:
                raise RuntimeError(f"Failed to find enough lines in {filename} (last line {line_i}) to reach snapshot iter {max_iter}")

            line_i = int(line.split()[0])
            if line_i + sample_interval > max_iter:
                warnings.warn(f"Truncated {filename} at iter {line_i}")
                # file position is just past the line to keep, cut there
                fd.truncate(fd.tell())
                break