Skip to content
This repository was archived by the owner on Nov 5, 2022. It is now read-only.

Commit 4cbe4a5

Browse files
committed
fix: improve logging and minor performance improvements
1 parent 7f3ce95 commit 4cbe4a5

1 file changed

Lines changed: 33 additions & 22 deletions

File tree

cmpy/exactdiag.py

Lines changed: 33 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def compute_groundstate(model, thresh=50):
4646
def solve_sector(model: AbstractManyBodyModel, sector: Sector, cache: dict = None):
4747
sector_key = (sector.n_up, sector.n_dn)
4848
if cache is not None and sector_key in cache:
49-
logger.debug("Loading eig %d, %d", sector.n_up, sector.n_dn)
49+
logger.debug("Loading eig %d, %d (%s)", sector.n_up, sector.n_dn, sector.size)
5050
eigvals, eigvecs = cache[sector_key]
5151
else:
5252
logger.debug("Solving eig %d, %d (%s)", sector.n_up, sector.n_dn, sector.size)
@@ -121,11 +121,12 @@ def _accumulate_sum(gf, z, evals, evals_p1, evecs_p1, cdag_evec, beta, emin):
121121
num_m = len(evals_p1)
122122
num_n = len(evals)
123123
for m in prange(num_m):
124+
eig_m = evals_p1[m]
125+
z_m = z - eig_m
124126
for n in range(num_n):
125-
eig_m = evals_p1[m]
126127
eig_n = evals[n]
127128
weights = exp_evals[n] + exp_evals_p1[m]
128-
gf += overlap[m, n] * weights / (z + eig_n - eig_m)
129+
gf += overlap[m, n] * weights / (z_m + eig_n)
129130

130131

131132
def accumulate_gf(gf, z, cdag, evals, evecs, evals_p1, evecs_p1, beta, emin=0.0):
@@ -179,17 +180,13 @@ def _acc_gf(self, sector, sector_p1, evals, evecs, evals_p1, evecs_p1, factor):
179180
e0 = self._gs_energy
180181
accumulate_gf(self._gf, z, cdag, evals, evecs, evals_p1, evecs_p1, beta, e0)
181182

182-
def _acc_occ(self, sector, evals, evecs, factor):
183-
up = sector.up_states
184-
dn = sector.dn_states
183+
def _acc_occ(self, up, dn, evals, evecs, factor):
185184
beta = self.beta
186185
e0 = self._gs_energy
187186
self._occ *= factor
188187
self._occ += occupation(up, dn, evals, evecs, beta, e0, self.pos, self.sigma)
189188

190-
def _acc_occ_double(self, sector, evals, evecs, factor):
191-
up = sector.up_states
192-
dn = sector.dn_states
189+
def _acc_occ_double(self, up, dn, evals, evecs, factor):
193190
beta = self.beta
194191
e0 = self._gs_energy
195192
self._occ_double *= factor
@@ -203,31 +200,45 @@ def accumulate(self, sector, sector_p1, evals, evecs, evals_p1, evecs_p1):
203200
self._gs_energy = min_energy
204201
logger.debug("New ground state: E_0=%.4f", min_energy)
205202

206-
logger.debug("accumulating")
203+
logger.debug("Accumulating")
204+
up = np.array(sector.up_states, dtype=np.int64)
205+
dn = np.array(sector.dn_states, dtype=np.int64)
207206
self._acc_part(evals, factor)
208207
self._acc_gf(sector, sector_p1, evals, evecs, evals_p1, evecs_p1, factor)
209-
# self._acc_occ(sector, evals, evecs, factor)
210-
# self._acc_occ_double(sector, evals, evecs, factor)
208+
self._acc_occ(up, dn, evals, evecs, factor)
209+
self._acc_occ_double(up, dn, evals, evecs, factor)
211210

212211

213212
def greens_function_lehmann(model, z, beta, pos=0, sigma=UP, eig_cache=None):
214-
logger.debug("Accumulating Lehmann sum (pos=%s, sigma=%s)", pos, sigma)
213+
basis = model.basis
214+
215+
logger.info("Accumulating Lehmann sum (pos=%s, sigma=%s)", pos, sigma)
216+
logger.debug("Sites: %s (%s states)", basis.num_sites, basis.size)
217+
215218
data = GreensFunctionMeasurement(z, beta, pos, sigma)
216219
eig_cache = eig_cache if eig_cache is not None else dict()
217-
for n_up, n_dn in model.iter_fillings():
220+
221+
fillings = list(basis.iter_fillings())
222+
num = len(fillings)
223+
w = len(str(num))
224+
for i, (n_up, n_dn) in enumerate(fillings):
218225
sector = model.get_sector(n_up, n_dn)
219-
sector_p1 = model.basis.upper_sector(n_up, n_dn, sigma)
226+
logger.info("[%s/%s] Sector %s, %s", f"{i+1:>{w}}", num, n_up, n_dn)
227+
228+
sector_p1 = basis.upper_sector(n_up, n_dn, sigma)
220229
if sector_p1 is not None:
221230
eigvals, eigvecs = solve_sector(model, sector, cache=eig_cache)
222231
eigvals_p1, eigvecs_p1 = solve_sector(model, sector_p1, cache=eig_cache)
223232
data.accumulate(sector, sector_p1, eigvals, eigvecs, eigvals_p1, eigvecs_p1)
224-
# else: eig_cache.clear()
225-
226-
logger.debug("-" * 40)
227-
logger.debug("gs-energy: %+.4f", data.gs_energy)
228-
logger.debug("occupation: %.4f", data.occ)
229-
logger.debug("double-occ: %.4f", data.occ_double)
230-
logger.debug("-" * 40)
233+
else:
234+
logger.debug("No upper sector, skipping")
235+
# eig_cache.clear()
236+
237+
logger.info("-" * 40)
238+
logger.info("gs-energy: %+.4f", data.gs_energy)
239+
logger.info("occupation: %.4f", data.occ)
240+
logger.info("double-occ: %.4f", data.occ_double)
241+
logger.info("-" * 40)
231242
return data
232243

233244

0 commit comments

Comments
 (0)