@@ -39,6 +39,7 @@ class CatEstimator(RailStage, PointEstimationMixin):
3939 nzbins = SHARED_PARAMS ,
4040 id_col = SHARED_PARAMS ,
4141 redshift_col = SHARED_PARAMS ,
42+ calc_summary_stats = SHARED_PARAMS ,
4243 )
4344 config_options .update (
4445 ** PointEstimationMixin .config_options .copy (),
@@ -113,6 +114,46 @@ def _process_chunk(
113114 f"{ self .name } ._process_chunk is not implemented"
114115 ) # pragma: no cover
115116
117+ def _calculate_summary_stats (
118+ self ,
119+ qp_dstn : qp .Ensemble ,
120+ ) -> qp .Ensemble :
121+
122+ if qp_dstn .ancil is None : # pragma: no cover
123+ ancil_dict : dict [str , np .ndarray ] = dict ()
124+ qp_dstn .set_ancil (ancil_dict )
125+
126+ quantiles = [0.025 , 0.16 , 0.5 , 0.85 , 0.975 ]
127+ quant_names = ['q2p5' , 'q16' , 'median' , 'q84' , '97p5' ]
128+
129+ locs = qp_dstn .ppf (quantiles )
130+ for name_ , vals_ in zip (quant_names , locs .T ):
131+ qp_dstn .ancil [f"z_{ name_ } " ] = np .expand_dims (vals_ , - 1 )
132+
133+ grid : np .ndarray | None = None
134+
135+ if 'z_mode' not in qp_dstn .ancil :
136+ grid = np .linspace (self .config .zmin , self .config .zmax , self .config .nzbins )
137+ qp_dstn .ancil ['z_mode' ] = qp_dstn .mode (grid )
138+
139+ try :
140+ qp_dstn .ancil ['z_mean' ] = qp_dstn .mean ()
141+ qp_dstn .ancil ['z_std' ] = qp_dstn .std ()
142+ except IndexError : # pragma: no cover
143+ # this is needed b/c qp.MixMod pdf sometimes fails to compute moments
144+ grid = np .linspace (self .config .zmin , self .config .zmax , self .config .nzbins )
145+ pdfs = qp_dstn .pdf (grid )
146+ norms = pdfs .sum (axis = 1 )
147+ means = np .sum (pdfs * grid , axis = 1 ) / norms
148+ diffs = (np .expand_dims (grid , - 1 ) - means ).T
149+ wt_diffs = diffs * diffs * pdfs
150+ stds = np .sqrt ((wt_diffs ).sum (axis = 1 )/ norms )
151+ qp_dstn .ancil ['z_mean' ] = np .expand_dims (means , - 1 )
152+ qp_dstn .ancil ['z_std' ] = np .expand_dims (stds , - 1 )
153+
154+
155+ return qp_dstn
156+
116157 def _do_chunk_output (
117158 self ,
118159 qp_dstn : qp .Ensemble ,
@@ -123,6 +164,9 @@ def _do_chunk_output(
123164 ) -> None :
124165 qp_dstn = self .calculate_point_estimates (qp_dstn )
125166
167+ if self .config .calc_summary_stats :
168+ qp_dstn = self ._calculate_summary_stats (qp_dstn )
169+
126170 # if there is no ancil set by the calculate_point_estimate, initiate one
127171 if data is not None :
128172 if qp_dstn .ancil is None : # pragma: no cover
0 commit comments