Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 31 additions & 8 deletions pymatnext/analysis/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,28 @@
import numpy as np

def calc_log_a(iters, n_walkers, n_cull, each_cull=False):
if each_cull:
def calc_log_a(iters, n_walkers, n_cull, discrete=False):
if discrete:
iters = np.asarray(iters)
n_cull = np.asarray(n_cull)
# need every iter, and the number culled for each iter
assert len(iters) == len(n_cull)
assert np.all(iters[1:] - iters[:-1] == 1)
# fraction remaining after each iteration
fracs = (n_walkers - n_cull) / n_walkers
# volume remaining after iteration i
# vol_i = \prod_{j=0..i} frac_i
# weight of configs culled in iteration i
# a_i = vol_{i-1} - vol_{i}
# = \prod_{j=0..i-1} frac_j - \prod_{j=0..i} frac_j
# = (\prod_{j=0..i-1} frac_j) (1 - frac_i)
# log(a_i) = (\sum_{j=0..i-1} log(frac_j)) + log(1 - frac_i)
frac_log_sums = np.append([0], np.cumsum(np.log(fracs)))
log_a = frac_log_sums[:-1] + np.log(1.0 - fracs)

else:

"""
# UNSUPPORTED MULTIPLE CULLS
# assume that for multiple culls, every energy is reported, use formula from
# SENS paper PRX v. 4 p 031034 (2014) Eq. 3
# also assume that iters array increments by one for each cull (i.e. not exactly NS iters)
Expand All @@ -15,8 +36,9 @@ def calc_log_a(iters, n_walkers, n_cull, each_cull=False):
# = prod(0..iter[i-1]) (N-i%P)/(N+1-i%P) - prod(0..iter[i]) (N-i%P)/(N+1-i%P)
# = [ prod(0..iter[i-1]) (N-i%P)/(N+1-i%P) ] * (1 - prod(iter[i-1]+1..iter[i]) (N-i%P)/(N+1-i%P))
# = [ prod(0..iter[i-1]) (N-i%P)/(N+1-i%P) ] * (1 - prod(iter[i-1]+1..iter[i]) (N-i%P)/(N+1-i%P))
raise RuntimeError('calc_log_a for each_cull not yet implemented')
else:
"""
if n_cull != 1:
raise RuntimeError(f'calc_log_a for n_cull = {n_cull} != 1 not yet implemented')
log_a = iters * np.log((n_walkers - n_cull + 1) / (n_walkers + 1))

return log_a
Expand Down Expand Up @@ -178,11 +200,12 @@ def analyse_T(T, Es, E_shift, Vs, extra_vals, log_a, flat_V_prior, N_atoms, kB,
problem = False
# one way to get bad sampling is to be too dominated by a few configurations
problem |= p_entropy < p_entropy_min
# another is to clip the top (high iteration #) of the distribution
low_percentile_mean = np.mean(Z_term[low_percentile_config:low_percentile_config + 1000] / Z_term_sum)
high_percentile_mean = np.mean(Z_term[high_percentile_config - 1000:high_percentile_config] / Z_term_sum)
# another is to clip the top (high iteration #) of the distribution, and therefore be very asymmetric
n_avg = (high_percentile_config - low_percentile_config) // 10
low_percentile_mean = np.mean(Z_term[low_percentile_config:low_percentile_config + n_avg] / Z_term_sum)
high_percentile_mean = np.mean(Z_term[high_percentile_config - n_avg:high_percentile_config] / Z_term_sum)
problem |= high_percentile_mean / low_percentile_mean > 2.0

results_dict['problem'] = 'true' if problem else 'false'
results_dict['problem'] = f'{problem}'.lower()

return results_dict
53 changes: 32 additions & 21 deletions pymatnext/cli/ns_analyse.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,12 +92,6 @@ def colname(colname_str):
return colname_str


# warn about P = 0
if comm_rank == 0:
if args.delta_P is None or args.delta_P == 0.0 or args.delta_P_GPa is None or args.delta_P_GPa == 0.0:
logging.warning("Analysis at P=0 with variable cell is ill defined, and we got --delta_P None or 0.0, "
"so be careful if run had cell moves and _sampling_ P was 0.0")

for infile_i, infile in enumerate(args.infile):
iters = []
Es = []
Expand Down Expand Up @@ -156,6 +150,16 @@ def colname(colname_str):
Es = np.asarray(Es)
vals = np.asarray(vals)

def _get_and_remove_from_extras(key, vals):
key_ind = header['extras'].index(key)
vals_of_key = vals[:, key_ind].copy()
inds = list(range(vals.shape[1]))
del inds[key_ind]
vals = vals[:, inds]
header['extras'].remove(key)

return vals_of_key, vals

# pointer to natoms
try:
natoms_ind = header['extras'].index('natoms')
Expand All @@ -164,30 +168,37 @@ def colname(colname_str):
natoms = None
# pull out Vs
try:
vol_ind = header['extras'].index('volume')
Vs = vals[:, vol_ind]
inds = list(range(vals.shape[1]))
del inds[vol_ind]
vals = vals[:, inds]
header['extras'].remove('volume')
Vs, vals = _get_and_remove_from_extras('volume', vals)
except (KeyError, ValueError):
Vs = None

# make into list of ndarrays, each of shape (Nsamples,)
vals = list(vals.T)
# warn about P = 0
if comm_rank == 0:
if Vs is not None and (args.delta_P is None or args.delta_P == 0.0 or args.delta_P_GPa is None or args.delta_P_GPa == 0.0):
logging.warning("Analysis at P=0 with variable cell is ill defined, and we got --delta_P None or 0.0, "
"so be careful if run had cell moves and _sampling_ P was 0.0")

# enthalpy if needed
if args.delta_P is not None and args.delta_P != 0.0:
if Vs is None:
raise RuntimeError('--delta_P != 0 requires volumes')
Es += args.delta_P*Vs

# shift energy 0
E_min = Es[-1]
Es -= E_min

# main
# header
n_walkers = header['n_walkers']
n_cull = header.get('n_cull', 1)
log_a = utils.calc_log_a(iters, n_walkers, n_cull)
discrete = header.get('discrete', False)
if discrete:
n_cull, vals = _get_and_remove_from_extras('n_cull', vals)
else:
n_cull = header.get('n_cull', 1)
log_a = utils.calc_log_a(iters, n_walkers, n_cull, discrete=discrete)

# make into list of ndarrays, each of shape (Nsamples,)
vals = list(vals.T)

flat_V_prior = False
if Vs is not None:
Expand Down Expand Up @@ -276,9 +287,9 @@ def str_format(fmt):
'U' : ('U', '{:11g}'),
'Cvp' : ('Cv or Cp', '{:11g}'),
'S' : ('S', '{:11g}'),
'low_percentile_config' : ('low % i', '{:10.0f}'),
'mode_config' : ('mode i', '{:10.0f}'),
'high_percentile_config' : ('high % i', '{:10.0f}'),
'low_percentile_config' : ('low % i', '{:10d}'),
'mode_config' : ('mode i', '{:10d}'),
'high_percentile_config' : ('high % i', '{:10d}'),
'p_entropy' : ('ent(p)', '{:6.3f}'),
'V' : ('V', '{:8g}'),
'thermal_exp' : ('alpha', '{:9.3g}'),
Expand All @@ -291,7 +302,7 @@ def str_format(fmt):

extensive_fields = ['log_Z', 'FG', 'U', 'Cvp', 'S', 'V', 'thermal_exp']
if comm_rank == 0:
print("# ", infile, "n_walkers", n_walkers, "n_cull", n_cull)
print("# ", infile, "n_walkers", n_walkers, "n_cull", n_cull if isinstance(n_cull, int) else "VARIABLE")

header_format = '# ' + T_format_s + ' ' + ' '.join([str_format(formats.get(k, default_format)[1]) for k in item_keys])
line_format = T_format[1] + ' ' + ' '.join([formats.get(k, default_format)[1] for k in item_keys])
Expand Down