Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 31 additions & 35 deletions pymatnext/cli/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from pymatnext.ns import NS
from pymatnext.params import check_fill_defaults
from pymatnext.sample_params import param_defaults
from pymatnext.sample_utils import truncate_file_first_col_iter

from pymatnext.loop_exit import NSLoopExit

Expand Down Expand Up @@ -166,44 +167,24 @@ def sample(args, MPI, NS_comm, walker_comm):
traj_interval = params_global["traj_interval"]
sample_interval = params_global["sample_interval"]
snapshot_interval = params_global["snapshot_interval"]
snapshot_save_old = params_global["snapshot_save_old"]
stdout_report_interval_s = params_global["stdout_report_interval_s"]
step_size_tune_interval = params_step_size_tune["interval"]
# WARNING: clone_history_file not restartable
if params_global["clone_history"]:
clone_history_file = open(f"{output_filename_prefix}.clone_history", "w")
clone_history_file.write(f'# {{"fields": ["loop_iter", "clone_source", "clone_target"], "n_walkers": {ns.n_configs_global}}}\n')
else:
clone_history_file = None

ns_file_name = f"{output_filename_prefix}.NS_samples"
clone_history_file_name = f"{output_filename_prefix}.clone_history" if params_global["clone_history"] else None
traj_file_name = f"{output_filename_prefix}.traj{config_suffix}"

if NS_comm.rank == 0:
# set up I/O
if ns.snapshot_iter >= 0:
# snapshot, truncate existing NS_samples and .traj.suffix files

# NOTE: does this code belong here? Maybe refactor to a function, maybe
# move trajectory truncation into NSConfig or something?

# truncate .NS_samples file
f_samples = open(ns_file_name, "r+")
# skip header
_ = f_samples.readline()
line_i = None
while True:
line = f_samples.readline()
if not line:
raise RuntimeError(f"Failed to find enough lines in .NS_samples file (last line {line_i}) to reach snapshot iter {ns.snapshot_iter}")

line_i = int(line.split()[0])
if line_i + sample_interval > ns.snapshot_iter:
cur_pos = f_samples.tell()
f_samples.truncate(cur_pos)
break
# snapshot, truncate existing .NS_samples, .clone_history, and .traj.<suffix> files

f_samples.close()
truncate_file_first_col_iter(ns_file_name, n_header=1, sample_interval=sample_interval, max_iter=ns.snapshot_iter)
truncate_file_first_col_iter(clone_history_file_name, n_header=1, sample_interval=1, max_iter=ns.snapshot_iter)

# truncate .traj.suffix file
# NOTE: should move trajectory truncation into NSConfig, since it's config file-format specific
# truncate .traj.<suffix> file
f_configs = open(traj_file_name, "r+")
while True:
try:
Expand All @@ -220,16 +201,27 @@ def sample(args, MPI, NS_comm, walker_comm):

ns_file = open(ns_file_name, "a")
traj_file = open(traj_file_name, "a")
clone_history_file = open(clone_history_file_name, "a")

else:
# run from start, open new .NS_samples and .traj.suffix files

# run from start, open new .NS_samples, .clone_history, and .traj.<suffix> files
# write header as needed
ns_file = open(ns_file_name, "w")
header_dict = { "n_walkers": ns.n_configs_global, "n_cull": 1 }
header_dict.update(ns.local_configs[0].header_dict())
ns_file.write("# " + " ".join(json.dumps(header_dict, indent=0).splitlines()) + "\n")

if clone_history_file_name:
clone_history_file = open(clone_history_file_name, "w")
clone_history_file.write(f'# {{"fields": ["loop_iter", "clone_source", "clone_target"], "n_walkers": {ns.n_configs_global}}}\n')
else:
clone_history_file = None

traj_file = open(traj_file_name, "w")
else:
ns_file = None
traj_file = None
clone_history_file = None

max_iter = params_global["max_iter"]
if max_iter > 0:
Expand All @@ -250,7 +242,7 @@ def sample(args, MPI, NS_comm, walker_comm):
global_ind_of_max = ns.global_ind(ns.rank_of_max, ns.local_ind_of_max)

# write quantities for max config which will be culled below
if NS_comm.rank == 0 and sample_interval > 0 and loop_iter % sample_interval == 0:
if ns_file and sample_interval > 0 and loop_iter % sample_interval == 0:
ns_file.write(f"{loop_iter} {global_ind_of_max} {ns.max_val:.10f} " + " ".join([f"{quant:.10f}" for quant in ns.max_quants]) + "\n")
ns_file.flush()

Expand All @@ -265,14 +257,14 @@ def sample(args, MPI, NS_comm, walker_comm):
global_ind_of_clone_source = (global_ind_of_max + 1 + ns.rng_global.integers(0, ns.n_configs_global - 1)) % ns.n_configs_global
rank_of_clone_source, local_ind_of_clone_source = ns.local_ind(global_ind_of_clone_source)

if clone_history_file is not None:
if clone_history_file:
clone_history_file.write(f"{loop_iter} {global_ind_of_clone_source} {global_ind_of_max}\n")
if loop_iter % 1000 == 1000 - 1:
clone_history_file.flush()

# write max to traj file
if traj_interval > 0 and loop_iter % traj_interval == 0:
if NS_comm.rank == 0:
if traj_file:
# only head node writes
if NS_comm.rank == ns.rank_of_max:
# already local
Expand Down Expand Up @@ -326,12 +318,16 @@ def sample(args, MPI, NS_comm, walker_comm):
# NOTE: should this be a time rather than iteration interval? That'd basically be straightforward,
# except it would require an additional communication so all processes agree that it's time for a snapshot
if loop_iter > 0 and snapshot_interval > 0 and loop_iter % snapshot_interval == 0:
ns.snapshot(loop_iter, output_filename_prefix)
ns.snapshot(loop_iter, output_filename_prefix, save_old=snapshot_save_old)

loop_iter += 1

if clone_history_file is not None:
if ns_file:
ns_file.close()
if clone_history_file:
clone_history_file.close()
if traj_file:
traj_file.close()


def main(args_list=None, mpi_finalize=True):
Expand Down
4 changes: 2 additions & 2 deletions pymatnext/ns.py
Original file line number Diff line number Diff line change
Expand Up @@ -480,11 +480,11 @@ def snapshot(self, loop_iter, output_filename_prefix, save_old=2):
output_filename_prefix: str
initial part of filenames that will be written to
save_old: int, default 2
number of old snapshots to save
number of old snapshots to save, negative to save all
"""
if self.comm.rank == 0:
old_state_files = NS._old_state_files(output_filename_prefix)
if len(old_state_files) > save_old - 1:
if save_old >= 0 and len(old_state_files) > save_old - 1:
old_state_files = old_state_files[:-(save_old-1)]
else:
old_state_files = []
Expand Down
Empty file modified pymatnext/ns_utils.py
100755 → 100644
Empty file.
1 change: 1 addition & 0 deletions pymatnext/sample_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
"sample_interval": 1,
"traj_interval": 100,
"snapshot_interval": 10000,
"snapshot_save_old": 2,
"step_size_tune": {
"interval": 1000,
"n_configs": 1,
Expand Down
23 changes: 23 additions & 0 deletions pymatnext/sample_utils.py
100755 → 100644
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import warnings

class NullComm():
"""Fake alternative to mpi4py.MPI.Comm for serial run, which implements needed subset
of mpi4py.MPI.Comm methods
Expand Down Expand Up @@ -36,3 +38,24 @@ class MPI:

def Finalize():
return


def truncate_file_first_col_iter(filename, n_header, sample_interval, max_iter):
    """Truncate a text file in place, based on the iteration number in its first column

    After skipping the header, data lines are read one at a time and the first
    whitespace-separated field of each is parsed as an integer iteration.  The
    file is cut immediately after the first line for which
    ``iteration + sample_interval > max_iter``, i.e. that line is the last one kept.

    Parameters
    ----------
    filename: str or None
        path of file to truncate. If None (e.g. an optional output file that
        is not in use) nothing is done.
    n_header: int
        number of leading header lines to skip before parsing data lines
    sample_interval: int
        interval between iterations of consecutive data lines
    max_iter: int
        iteration that the retained part of the file must reach

    Raises
    ------
    RuntimeError
        if the file ends before a line satisfying the truncation condition is found
    """
    if filename is None:
        # optional file that was never requested: nothing to truncate
        return

    warnings.warn(f"Truncating {filename}")
    with open(filename, "r+") as fd:
        # skip header
        for _ in range(n_header):
            fd.readline()

        line_i = None
        # explicit readline() calls (rather than iterating over fd directly) so that
        # fd.tell() remains usable in text mode
        for line in iter(fd.readline, ""):
            line_i = int(line.split()[0])
            if line_i + sample_interval > max_iter:
                warnings.warn(f"Truncated {filename} at iter {line_i}")
                # position is just past the line that was read, so that line is kept
                fd.truncate(fd.tell())
                break
        else:
            # reached end of file without getting close enough to max_iter
            raise RuntimeError(f"Failed to find enough lines in {filename} (last line {line_i}) to reach snapshot iter {max_iter}")