Skip to content

Commit c840e96

Browse files
committed
added flat_targets
1 parent 187230e commit c840e96

File tree

2 files changed

+28
-6
lines changed

2 files changed

+28
-6
lines changed

pyat/at/latticetools/observablelist.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -651,6 +651,26 @@ def get_sum_residuals(self, *obsid: str | int, err: float | None = None) -> floa
651651
"""
652652
return sum(np.sum(res) for res in self._collect("residual", *obsid, err=err))
653653

654+
def get_targets(self, *obsid: str | int, err: float | None = None) -> tuple:
655+
"""Return the target values of observables.
656+
657+
Args:
658+
*obsid: name or index of selected observables (Default all)
659+
err: Default observable value to be used when the evaluation failed. By
660+
default, an Exception is raised.
661+
"""
662+
return self._collect("target", *obsid)
663+
664+
def get_flat_targets(self, *obsid: str | int) -> tuple:
665+
"""Return a 1-D array of target values of observables.
666+
667+
Args:
668+
*obsid: name or index of selected observables (Default all)
669+
err: Default observable value to be used when the evaluation failed. By
670+
default, an Exception is raised.
671+
"""
672+
return _flatten(self._collect("target", *obsid))
673+
654674
shapes = property(get_shapes, doc="Shapes of all values")
655675
flat_shape = property(get_flat_shape, doc="Shape of the flattened values")
656676
values = property(get_values, doc="values of all observables")
@@ -675,3 +695,5 @@ def get_sum_residuals(self, *obsid: str | int, err: float | None = None) -> floa
675695
flat_weights = property(get_flat_weights, doc="1-D array of Observable weights")
676696
residuals = property(get_residuals, doc="Residuals of all observable")
677697
sum_residuals = property(get_sum_residuals, doc="Sum of all residual values")
698+
targets = property(get_targets, doc="Target values of all observables")
699+
flat_targets = property(get_flat_targets, doc="1-D array of target values")

pyat/at/latticetools/response_matrix.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -366,19 +366,19 @@ def correction_matrix(self, nvals: int | None = None) -> FloatArray:
366366
return cormat
367367

368368
def get_correction(
369-
self, observed: FloatArray, nvals: int | None = None
369+
self, deviation: FloatArray, nvals: int | None = None
370370
) -> FloatArray:
371371
"""Compute the correction of the given observation.
372372
373373
Args:
374-
observed: Vector of observed deviations,
374+
deviation: Vector of observed deviations,
375375
nvals: Desired number of singular values. If :py:obj:`None`, use
376376
all singular values
377377
378378
Returns:
379379
corr: Correction vector
380380
"""
381-
return -self.correction_matrix(nvals=nvals) @ observed
381+
return -self.correction_matrix(nvals=nvals) @ deviation
382382

383383
def save(self, file) -> None:
384384
"""Save a response matrix in the NumPy .npy format.
@@ -513,12 +513,12 @@ def correct(
513513
for it, nv in zip(range(niter), np.broadcast_to(nvals, (niter,))):
514514
print(f"step {it + 1}, nvals = {nv}")
515515
obs.evaluate(ring, **self._eval_args)
516-
err = obs.flat_deviations
517-
if np.any(np.isnan(err)):
516+
deviation = obs.flat_deviations
517+
if np.any(np.isnan(deviation)):
518518
raise AtError(
519519
f"Step {it + 1}: Invalid observables, cannot compute correction"
520520
)
521-
corr = self.get_correction(obs.flat_deviations, nvals=nv)
521+
corr = self.get_correction(deviation, nvals=nv)
522522
sumcorr = sumcorr + corr # non-broadcastable sumcorr
523523
if apply:
524524
self.variables.increment(corr, ring=ring)

0 commit comments

Comments
 (0)