3434 "crps" : "CRPS" ,
3535 "check" : "Check Score" ,
3636 "interval" : "Interval Score" ,
37- "rms_adv_group_cal" : "Root-mean-squared Adversarial Group Calibration Error" ,
37+ "rms_adv_group_cal" : ( "Root-mean-squared Adversarial Group " " Calibration Error") ,
3838 "ma_adv_group_cal" : "Mean-absolute Adversarial Group Calibration Error" ,
3939}
4040
4141
42- def get_all_accuracy_metrics (y_pred , y_true , verbose ):
42+ def get_all_accuracy_metrics (y_pred , y_true , verbose = True ):
4343
4444 if verbose :
4545 print (" (1/n) Calculating accuracy metrics" )
@@ -48,7 +48,7 @@ def get_all_accuracy_metrics(y_pred, y_true, verbose):
     return acc_metrics
 
 
-def get_all_average_calibration(y_pred, y_std, y_true, num_bins, verbose):
+def get_all_average_calibration(y_pred, y_std, y_true, num_bins, verbose=True):
 
     if verbose:
         print(" (2/n) Calculating average calibration metrics")
@@ -67,7 +67,9 @@ def get_all_average_calibration(y_pred, y_std, y_true, num_bins, verbose):
     return cali_metrics
 
 
-def get_all_adversarial_group_calibration(y_pred, y_std, y_true, num_bins, verbose):
+def get_all_adversarial_group_calibration(
+    y_pred, y_std, y_true, num_bins, verbose=True
+):
 
     adv_group_cali_metrics = {}
     if verbose:
@@ -80,6 +82,7 @@ def get_all_adversarial_group_calibration(y_pred, y_std, y_true, num_bins, verbose):
         y_true,
         cali_type="mean_abs",
         num_bins=num_bins,
+        verbose=verbose,
     )
     ma_adv_group_size = ma_adv_group_cali.group_size
     ma_adv_group_cali_score_mean = ma_adv_group_cali.score_mean
@@ -99,6 +102,7 @@ def get_all_adversarial_group_calibration(y_pred, y_std, y_true, num_bins, verbose):
         y_true,
         cali_type="root_mean_sq",
         num_bins=num_bins,
+        verbose=verbose,
     )
     rms_adv_group_size = rms_adv_group_cali.group_size
     rms_adv_group_cali_score_mean = rms_adv_group_cali.score_mean
@@ -112,7 +116,7 @@ def get_all_adversarial_group_calibration(y_pred, y_std, y_true, num_bins, verbose):
     return adv_group_cali_metrics
 
 
-def get_all_sharpness_metrics(y_std, verbose):
+def get_all_sharpness_metrics(y_std, verbose=True):
 
     if verbose:
         print(" (4/n) Calculating sharpness metrics")
@@ -123,7 +127,9 @@ def get_all_sharpness_metrics(y_std, verbose):
     return sharp_metrics
 
 
-def get_all_scoring_rule_metrics(y_pred, y_std, y_true, resolution, scaled, verbose):
+def get_all_scoring_rule_metrics(
+    y_pred, y_std, y_true, resolution, scaled, verbose=True
+):
 
     if verbose:
         print(" (n/n) Calculating proper scoring rule metrics")
@@ -143,15 +149,15 @@ def get_all_scoring_rule_metrics(y_pred, y_std, y_true, resolution, scaled, verbose):
 
 def _print_adversarial_group_calibration(adv_group_metric_dic, print_group_num=3):
 
-    for adv_group_cali_type, adv_group_cali_dic in adv_group_metric_dic.items():
-        num_groups = adv_group_cali_dic["group_sizes"].shape[0]
+    for a_group_cali_type, a_group_cali_dic in adv_group_metric_dic.items():
+        num_groups = a_group_cali_dic["group_sizes"].shape[0]
         print_idxs = [int(x) for x in np.linspace(1, num_groups - 1, print_group_num)]
-        print(" {}".format(METRIC_NAMES[adv_group_cali_type]))
+        print(" {}".format(METRIC_NAMES[a_group_cali_type]))
         for idx in print_idxs:
             print(
                 " Group Size: {:.2f} -- Calibration Error: {:.3f}".format(
-                    adv_group_cali_dic["group_sizes"][idx],
-                    adv_group_cali_dic["adv_group_cali_mean"][idx],
+                    a_group_cali_dic["group_sizes"][idx],
+                    a_group_cali_dic["adv_group_cali_mean"][idx],
                 )
             )
 
@@ -160,44 +166,45 @@ def get_all_metrics(
     y_pred, y_std, y_true, num_bins=100, resolution=99, scaled=True, verbose=True
 ):
 
-    """ Accuracy """
+    # Accuracy
     accuracy_metrics = get_all_accuracy_metrics(y_pred, y_true, verbose)
 
-    """ Calibration """
+    # Calibration
     calibration_metrics = get_all_average_calibration(
         y_pred, y_std, y_true, num_bins, verbose
     )
 
-    """ Adversarial Group Calibration """
+    # Adversarial Group Calibration
     adv_group_cali_metrics = get_all_adversarial_group_calibration(
         y_pred, y_std, y_true, num_bins, verbose
     )
 
-    """ Sharpness """
+    # Sharpness
     sharpness_metrics = get_all_sharpness_metrics(y_std, verbose)
 
-    """ Proper Scoring Rules """
+    # Proper Scoring Rules
     scoring_rule_metrics = get_all_scoring_rule_metrics(
         y_pred, y_std, y_true, resolution, scaled, verbose
     )
-    print("**Finished Calculating All Metrics**")
 
     # Print all outputs
-    print("\n")
-    print(" Accuracy Metrics ".center(60, "="))
-    for acc_metric, acc_val in accuracy_metrics.items():
-        print(" {:<13} {:.3f}".format(METRIC_NAMES[acc_metric], acc_val))
-    print(" Average Calibration Metrics ".center(60, "="))
-    for cali_metric, cali_val in calibration_metrics.items():
-        print(" {:<37} {:.3f}".format(METRIC_NAMES[cali_metric], cali_val))
-    print(" Adversarial Group Calibration Metrics ".center(60, "="))
-    _print_adversarial_group_calibration(adv_group_cali_metrics, print_group_num=3)
-    print(" Sharpness Metrics ".center(60, "="))
-    for sharp_metric, sharp_val in sharpness_metrics.items():
-        print(" {:} {:.3f}".format(METRIC_NAMES[sharp_metric], sharp_val))
-    print(" Scoring Rule Metrics ".center(60, "="))
-    for sr_metric, sr_val in scoring_rule_metrics.items():
-        print(" {:<25} {:.3f}".format(METRIC_NAMES[sr_metric], sr_val))
+    if verbose:
+        print("**Finished Calculating All Metrics**")
+        print("\n")
+        print(" Accuracy Metrics ".center(60, "="))
+        for acc_metric, acc_val in accuracy_metrics.items():
+            print(" {:<13} {:.3f}".format(METRIC_NAMES[acc_metric], acc_val))
+        print(" Average Calibration Metrics ".center(60, "="))
+        for cali_metric, cali_val in calibration_metrics.items():
+            print(" {:<37} {:.3f}".format(METRIC_NAMES[cali_metric], cali_val))
+        print(" Adversarial Group Calibration Metrics ".center(60, "="))
+        _print_adversarial_group_calibration(adv_group_cali_metrics, print_group_num=3)
+        print(" Sharpness Metrics ".center(60, "="))
+        for sharp_metric, sharp_val in sharpness_metrics.items():
+            print(" {:} {:.3f}".format(METRIC_NAMES[sharp_metric], sharp_val))
+        print(" Scoring Rule Metrics ".center(60, "="))
+        for sr_metric, sr_val in scoring_rule_metrics.items():
+            print(" {:<25} {:.3f}".format(METRIC_NAMES[sr_metric], sr_val))
 
     all_scores = {
         "accuracy": accuracy_metrics,
@@ -208,10 +215,3 @@ def get_all_metrics(
     }
 
     return all_scores
-
-
-if __name__ == "__main__":
-    y_pred = np.array([1, 2, 3, 4])
-    y_std = np.array([1, 2, 3, 4])
-    y_true = np.array([1.3, 2.3, 3.3, 4])
-    get_all_metrics(y_pred, y_std, y_true)
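
For context (not part of the diff): a minimal, hypothetical smoke-test sketch that mirrors the removed __main__ block above and exercises the new verbose flag. The import path below is a placeholder, since the module name does not appear in this diff.

# Hypothetical usage sketch; "metrics" is a placeholder module name.
import numpy as np

from metrics import get_all_metrics  # placeholder import path

y_pred = np.array([1, 2, 3, 4])
y_std = np.array([1, 2, 3, 4])
y_true = np.array([1.3, 2.3, 3.3, 4])

# verbose defaults to True and prints the summary tables; verbose=False,
# which this change enables, suppresses all printing and only returns the
# nested dict of scores.
all_scores = get_all_metrics(y_pred, y_std, y_true, verbose=False)
print(sorted(all_scores.keys()))  # includes "accuracy", among others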