Skip to content

Commit ff41e91

Browse files
authored
Merge pull request #203 from LSSTDESC/issue/202/summary_stats
Issue/202/summary stats
2 parents e4f223c + 9bf689a commit ff41e91

3 files changed

Lines changed: 50 additions & 0 deletions

File tree

src/rail/core/common_params.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,11 @@
8080
default=lsst_zp_errors,
8181
msg="BPZ adds these values in quadrature to the photometric errors",
8282
),
83+
calc_summary_stats=Param(
84+
dtype=bool,
85+
default=False,
86+
msg="Compute summary statistics",
87+
),
8388
calculated_point_estimates=Param(
8489
dtype=list,
8590
default=[],

src/rail/estimation/estimator.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ class CatEstimator(RailStage, PointEstimationMixin):
3939
nzbins=SHARED_PARAMS,
4040
id_col=SHARED_PARAMS,
4141
redshift_col=SHARED_PARAMS,
42+
calc_summary_stats=SHARED_PARAMS,
4243
)
4344
config_options.update(
4445
**PointEstimationMixin.config_options.copy(),
@@ -113,6 +114,46 @@ def _process_chunk(
113114
f"{self.name}._process_chunk is not implemented"
114115
) # pragma: no cover
115116

117+
def _calculate_summary_stats(
118+
self,
119+
qp_dstn: qp.Ensemble,
120+
) -> qp.Ensemble:
121+
122+
if qp_dstn.ancil is None: # pragma: no cover
123+
ancil_dict: dict[str, np.ndarray] = dict()
124+
qp_dstn.set_ancil(ancil_dict)
125+
126+
quantiles = [0.025, 0.16, 0.5, 0.85, 0.975]
127+
quant_names = ['q2p5', 'q16', 'median', 'q84', '97p5']
128+
129+
locs = qp_dstn.ppf(quantiles)
130+
for name_, vals_ in zip(quant_names, locs.T):
131+
qp_dstn.ancil[f"z_{name_}"] = np.expand_dims(vals_, -1)
132+
133+
grid: np.ndarray | None = None
134+
135+
if 'z_mode' not in qp_dstn.ancil:
136+
grid = np.linspace(self.config.zmin, self.config.zmax, self.config.nzbins)
137+
qp_dstn.ancil['z_mode'] = qp_dstn.mode(grid)
138+
139+
try:
140+
qp_dstn.ancil['z_mean'] = qp_dstn.mean()
141+
qp_dstn.ancil['z_std'] = qp_dstn.std()
142+
except IndexError: # pragma: no cover
143+
# this is needed b/c qp.MixMod pdf sometimes fails to compute moments
144+
grid = np.linspace(self.config.zmin, self.config.zmax, self.config.nzbins)
145+
pdfs = qp_dstn.pdf(grid)
146+
norms = pdfs.sum(axis=1)
147+
means = np.sum(pdfs * grid, axis=1) / norms
148+
diffs = (np.expand_dims(grid, -1) - means).T
149+
wt_diffs = diffs * diffs * pdfs
150+
stds = np.sqrt((wt_diffs).sum(axis=1)/norms)
151+
qp_dstn.ancil['z_mean'] = np.expand_dims(means, -1)
152+
qp_dstn.ancil['z_std'] = np.expand_dims(stds, -1)
153+
154+
155+
return qp_dstn
156+
116157
def _do_chunk_output(
117158
self,
118159
qp_dstn: qp.Ensemble,
@@ -123,6 +164,9 @@ def _do_chunk_output(
123164
) -> None:
124165
qp_dstn = self.calculate_point_estimates(qp_dstn)
125166

167+
if self.config.calc_summary_stats:
168+
qp_dstn = self._calculate_summary_stats(qp_dstn)
169+
126170
# if there is no ancil set by the calculate_point_estimate, initiate one
127171
if data is not None:
128172
if qp_dstn.ancil is None: # pragma: no cover

tests/estimation/test_algos.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ def test_random_pz() -> None:
2323
"rand_width": 0.025,
2424
"rand_zmin": 0.0,
2525
"rand_zmax": 3.0,
26+
"calc_summary_stats": True,
2627
"nzbins": 301,
2728
"hdf5_groupname": "photometry",
2829
"model": "None",

0 commit comments

Comments
 (0)