Skip to content

Commit 407240c

Browse files
authored
* Clean Up * Fixed bug.
1 parent e4f8470 commit 407240c

File tree

11 files changed

+136
-145
lines changed

11 files changed

+136
-145
lines changed

examples/viz_minimal.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,10 @@
88
import uncertainty_toolbox.data as udata
99

1010
import neatplot
11+
1112
neatplot.set_style()
12-
neatplot.update_rc('figure.dpi', 150)
13-
neatplot.update_rc('text.usetex', False)
13+
neatplot.update_rc("figure.dpi", 150)
14+
neatplot.update_rc("text.usetex", False)
1415

1516

1617
# Set random seed
@@ -20,10 +21,10 @@
2021
(y_pred, y_std, y_true) = udata.synthetic_arange_random()
2122

2223
# Print details about the synthetic results
23-
print('* y_true: {}'.format(y_true))
24-
print('* y_pred: {}'.format(y_pred))
25-
print('* |y_true - y_pred|: {}'.format(np.abs(y_true - y_pred)))
26-
print('* y_std: {}'.format(y_std))
24+
print("* y_true: {}".format(y_true))
25+
print("* y_pred: {}".format(y_pred))
26+
print("* |y_true - y_pred|: {}".format(np.abs(y_true - y_pred)))
27+
print("* y_std: {}".format(y_std))
2728

2829
# Plot
2930
uviz.plot_intervals(y_pred, y_std, y_true, show=True)

examples/viz_recalibrate.py

Lines changed: 5 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
)
1313
import uncertainty_toolbox.viz as uviz
1414
from uncertainty_toolbox.recalibration import iso_recal
15-
from uncertainty_toolbox.viz import plot_calibration
1615

1716
import neatplot
1817

@@ -44,27 +43,18 @@ def update_rc_params():
4443

4544

4645
# List of predictive means and standard deviations
47-
pred_mean_list = [
48-
f,
49-
# f + 0.1,
50-
# f - 0.1,
51-
# f + 0.25,
52-
# f - 0.25,
53-
]
46+
pred_mean_list = [f]
5447

5548
pred_std_list = [
5649
std * 0.5, # overconfident
5750
std * 2.0, # underconfident
58-
# std, # correct
5951
]
6052

6153
# Loop through, make plots, and compute metrics
6254
for i, pred_mean in enumerate(pred_mean_list):
6355
for j, pred_std in enumerate(pred_std_list):
6456
# Before recalibration
65-
exp_props, obs_props = get_proportion_lists_vectorized(
66-
pred_mean, pred_std, y
67-
)
57+
exp_props, obs_props = get_proportion_lists_vectorized(pred_mean, pred_std, y)
6858
recal_model = None
6959
mace = umetrics.mean_absolute_calibration_error(
7060
pred_mean, pred_std, y, recal_model=recal_model
@@ -76,9 +66,7 @@ def update_rc_params():
7666
pred_mean, pred_std, y, recal_model=recal_model
7767
)
7868
print("Before Recalibration")
79-
print(
80-
" MACE: {:.5f}, RMSCE: {:.5f}, MA: {:.5f}".format(mace, rmsce, ma)
81-
)
69+
print(" MACE: {:.5f}, RMSCE: {:.5f}, MA: {:.5f}".format(mace, rmsce, ma))
8270

8371
uviz.plot_calibration(
8472
pred_mean,
@@ -104,11 +92,9 @@ def update_rc_params():
10492
pred_mean, pred_std, y, recal_model=recal_model
10593
)
10694
print(" After Recalibration")
107-
print(
108-
" MACE: {:.5f}, RMSCE: {:.5f}, MA: {:.5f}".format(mace, rmsce, ma)
109-
)
95+
print(" MACE: {:.5f}, RMSCE: {:.5f}, MA: {:.5f}".format(mace, rmsce, ma))
11096

111-
plot_calibration(
97+
uviz.plot_calibration(
11298
pred_mean,
11399
pred_std,
114100
y,

examples/viz_synth_sine.py

Lines changed: 15 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,15 @@
44
import numpy as np
55
import matplotlib.pyplot as plt
66
from matplotlib import rc
7-
#plt.ion()
87

98
import uncertainty_toolbox.data as udata
109
import uncertainty_toolbox.metrics as umetrics
1110
import uncertainty_toolbox.viz as uviz
1211

1312
import neatplot
13+
1414
neatplot.set_style()
15-
neatplot.update_rc('text.usetex', False) # Set to True for system latex
15+
neatplot.update_rc("text.usetex", False) # Set to True for system latex
1616

1717

1818
# Set random seed
@@ -26,17 +26,17 @@
2626
savefig = False
2727

2828

29-
def save_figure(name_str, file_type='png'):
29+
def save_figure(name_str, file_type="png"):
3030
"""Save figure, or do nothing if savefig is False."""
3131
if savefig:
3232
neatplot.save_figure(name_str, file_type)
3333

3434

3535
def update_rc_params():
3636
"""Update matplotlib rc params."""
37-
plt.rcParams.update({'font.size': 14})
38-
plt.rcParams.update({'xtick.labelsize': 14})
39-
plt.rcParams.update({'ytick.labelsize': 14})
37+
plt.rcParams.update({"font.size": 14})
38+
plt.rcParams.update({"xtick.labelsize": 14})
39+
plt.rcParams.update({"ytick.labelsize": 14})
4040

4141

4242
def make_plots(pred_mean, pred_std, idx1, idx2):
@@ -48,38 +48,32 @@ def make_plots(pred_mean, pred_std, idx1, idx2):
4848

4949
# Make xy plot
5050
uviz.plot_xy(pred_mean, pred_std, y, x, n_subset=300, ylims=ylims, xlims=[0, 15])
51-
save_figure(f'xy_{idx1}_{idx2}')
51+
save_figure(f"xy_{idx1}_{idx2}")
5252
plt.show()
5353

5454
# Make intervals plot
5555
uviz.plot_intervals(pred_mean, pred_std, y, n_subset=n_subset, ylims=ylims)
56-
save_figure(f'intervals_{idx1}_{idx2}')
56+
save_figure(f"intervals_{idx1}_{idx2}")
5757
plt.show()
5858

5959
# Make calibration plot
6060
uviz.plot_calibration(pred_mean, pred_std, y)
61-
save_figure(f'calibration_{idx1}_{idx2}')
61+
save_figure(f"calibration_{idx1}_{idx2}")
6262
plt.show()
6363

6464
# Make ordered intervals plot
6565
uviz.plot_intervals_ordered(pred_mean, pred_std, y, n_subset=n_subset, ylims=ylims)
66-
save_figure(f'intervals_ordered_{idx1}_{idx2}')
66+
save_figure(f"intervals_ordered_{idx1}_{idx2}")
6767
plt.show()
6868

6969

7070
# List of predictive means and standard deviations
71-
pred_mean_list = [
72-
f,
73-
#f + 0.1,
74-
#f - 0.1,
75-
#f + 0.25,
76-
#f - 0.25,
77-
]
71+
pred_mean_list = [f]
7872

7973
pred_std_list = [
80-
std * 0.5, # overconfident
81-
std * 2.0, # underconfident
82-
std, # correct
74+
std * 0.5, # overconfident
75+
std * 2.0, # underconfident
76+
std, # correct
8377
]
8478

8579
# Loop through, make plots, and compute metrics
@@ -91,4 +85,4 @@ def make_plots(pred_mean, pred_std, idx1, idx2):
9185

9286
make_plots(pred_mean, pred_std, i, j)
9387

94-
print(f'MACE: {mace}, RMSCE: {rmsce}, MA: {ma}')
88+
print(f"MACE: {mace}, RMSCE: {rmsce}, MA: {ma}")

uncertainty_toolbox/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
curvy_cosine,
99
)
1010

11-
from .metrics import(
11+
from .metrics import (
1212
get_all_metrics,
1313
)
1414

uncertainty_toolbox/data.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ def synthetic_arange_random(num_points=10):
1212
"""
1313
y_true = np.arange(num_points)
1414
y_pred = np.arange(num_points) + np.random.random((num_points,))
15-
y_std = np.abs(y_true - y_pred) + .1 * np.random.random((num_points,))
15+
y_std = np.abs(y_true - y_pred) + 0.1 * np.random.random((num_points,))
1616

1717
return (y_pred, y_std, y_true)
1818

uncertainty_toolbox/metrics.py

Lines changed: 39 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -34,12 +34,12 @@
3434
"crps": "CRPS",
3535
"check": "Check Score",
3636
"interval": "Interval Score",
37-
"rms_adv_group_cal": "Root-mean-squared Adversarial Group Calibration Error",
37+
"rms_adv_group_cal": ("Root-mean-squared Adversarial Group " "Calibration Error"),
3838
"ma_adv_group_cal": "Mean-absolute Adversarial Group Calibration Error",
3939
}
4040

4141

42-
def get_all_accuracy_metrics(y_pred, y_true, verbose):
42+
def get_all_accuracy_metrics(y_pred, y_true, verbose=True):
4343

4444
if verbose:
4545
print(" (1/n) Calculating accuracy metrics")
@@ -48,7 +48,7 @@ def get_all_accuracy_metrics(y_pred, y_true, verbose):
4848
return acc_metrics
4949

5050

51-
def get_all_average_calibration(y_pred, y_std, y_true, num_bins, verbose):
51+
def get_all_average_calibration(y_pred, y_std, y_true, num_bins, verbose=True):
5252

5353
if verbose:
5454
print(" (2/n) Calculating average calibration metrics")
@@ -67,7 +67,9 @@ def get_all_average_calibration(y_pred, y_std, y_true, num_bins, verbose):
6767
return cali_metrics
6868

6969

70-
def get_all_adversarial_group_calibration(y_pred, y_std, y_true, num_bins, verbose):
70+
def get_all_adversarial_group_calibration(
71+
y_pred, y_std, y_true, num_bins, verbose=True
72+
):
7173

7274
adv_group_cali_metrics = {}
7375
if verbose:
@@ -80,6 +82,7 @@ def get_all_adversarial_group_calibration(y_pred, y_std, y_true, num_bins, verbo
8082
y_true,
8183
cali_type="mean_abs",
8284
num_bins=num_bins,
85+
verbose=verbose,
8386
)
8487
ma_adv_group_size = ma_adv_group_cali.group_size
8588
ma_adv_group_cali_score_mean = ma_adv_group_cali.score_mean
@@ -99,6 +102,7 @@ def get_all_adversarial_group_calibration(y_pred, y_std, y_true, num_bins, verbo
99102
y_true,
100103
cali_type="root_mean_sq",
101104
num_bins=num_bins,
105+
verbose=verbose,
102106
)
103107
rms_adv_group_size = rms_adv_group_cali.group_size
104108
rms_adv_group_cali_score_mean = rms_adv_group_cali.score_mean
@@ -112,7 +116,7 @@ def get_all_adversarial_group_calibration(y_pred, y_std, y_true, num_bins, verbo
112116
return adv_group_cali_metrics
113117

114118

115-
def get_all_sharpness_metrics(y_std, verbose):
119+
def get_all_sharpness_metrics(y_std, verbose=True):
116120

117121
if verbose:
118122
print(" (4/n) Calculating sharpness metrics")
@@ -123,7 +127,9 @@ def get_all_sharpness_metrics(y_std, verbose):
123127
return sharp_metrics
124128

125129

126-
def get_all_scoring_rule_metrics(y_pred, y_std, y_true, resolution, scaled, verbose):
130+
def get_all_scoring_rule_metrics(
131+
y_pred, y_std, y_true, resolution, scaled, verbose=True
132+
):
127133

128134
if verbose:
129135
print(" (n/n) Calculating proper scoring rule metrics")
@@ -143,15 +149,15 @@ def get_all_scoring_rule_metrics(y_pred, y_std, y_true, resolution, scaled, verb
143149

144150
def _print_adversarial_group_calibration(adv_group_metric_dic, print_group_num=3):
145151

146-
for adv_group_cali_type, adv_group_cali_dic in adv_group_metric_dic.items():
147-
num_groups = adv_group_cali_dic["group_sizes"].shape[0]
152+
for a_group_cali_type, a_group_cali_dic in adv_group_metric_dic.items():
153+
num_groups = a_group_cali_dic["group_sizes"].shape[0]
148154
print_idxs = [int(x) for x in np.linspace(1, num_groups - 1, print_group_num)]
149-
print(" {}".format(METRIC_NAMES[adv_group_cali_type]))
155+
print(" {}".format(METRIC_NAMES[a_group_cali_type]))
150156
for idx in print_idxs:
151157
print(
152158
" Group Size: {:.2f} -- Calibration Error: {:.3f}".format(
153-
adv_group_cali_dic["group_sizes"][idx],
154-
adv_group_cali_dic["adv_group_cali_mean"][idx],
159+
a_group_cali_dic["group_sizes"][idx],
160+
a_group_cali_dic["adv_group_cali_mean"][idx],
155161
)
156162
)
157163

@@ -160,44 +166,45 @@ def get_all_metrics(
160166
y_pred, y_std, y_true, num_bins=100, resolution=99, scaled=True, verbose=True
161167
):
162168

163-
""" Accuracy """
169+
# Accuracy
164170
accuracy_metrics = get_all_accuracy_metrics(y_pred, y_true, verbose)
165171

166-
""" Calibration """
172+
# Calibration
167173
calibration_metrics = get_all_average_calibration(
168174
y_pred, y_std, y_true, num_bins, verbose
169175
)
170176

171-
""" Adversarial Group Calibration """
177+
# Adversarial Group Calibration
172178
adv_group_cali_metrics = get_all_adversarial_group_calibration(
173179
y_pred, y_std, y_true, num_bins, verbose
174180
)
175181

176-
""" Sharpness """
182+
# Sharpness
177183
sharpness_metrics = get_all_sharpness_metrics(y_std, verbose)
178184

179-
""" Proper Scoring Rules """
185+
# Proper Scoring Rules
180186
scoring_rule_metrics = get_all_scoring_rule_metrics(
181187
y_pred, y_std, y_true, resolution, scaled, verbose
182188
)
183-
print("**Finished Calculating All Metrics**")
184189

185190
# Print all outputs
186-
print("\n")
187-
print(" Accuracy Metrics ".center(60, "="))
188-
for acc_metric, acc_val in accuracy_metrics.items():
189-
print(" {:<13} {:.3f}".format(METRIC_NAMES[acc_metric], acc_val))
190-
print(" Average Calibration Metrics ".center(60, "="))
191-
for cali_metric, cali_val in calibration_metrics.items():
192-
print(" {:<37} {:.3f}".format(METRIC_NAMES[cali_metric], cali_val))
193-
print(" Adversarial Group Calibration Metrics ".center(60, "="))
194-
_print_adversarial_group_calibration(adv_group_cali_metrics, print_group_num=3)
195-
print(" Sharpness Metrics ".center(60, "="))
196-
for sharp_metric, sharp_val in sharpness_metrics.items():
197-
print(" {:} {:.3f}".format(METRIC_NAMES[sharp_metric], sharp_val))
198-
print(" Scoring Rule Metrics ".center(60, "="))
199-
for sr_metric, sr_val in scoring_rule_metrics.items():
200-
print(" {:<25} {:.3f}".format(METRIC_NAMES[sr_metric], sr_val))
191+
if verbose:
192+
print("**Finished Calculating All Metrics**")
193+
print("\n")
194+
print(" Accuracy Metrics ".center(60, "="))
195+
for acc_metric, acc_val in accuracy_metrics.items():
196+
print(" {:<13} {:.3f}".format(METRIC_NAMES[acc_metric], acc_val))
197+
print(" Average Calibration Metrics ".center(60, "="))
198+
for cali_metric, cali_val in calibration_metrics.items():
199+
print(" {:<37} {:.3f}".format(METRIC_NAMES[cali_metric], cali_val))
200+
print(" Adversarial Group Calibration Metrics ".center(60, "="))
201+
_print_adversarial_group_calibration(adv_group_cali_metrics, print_group_num=3)
202+
print(" Sharpness Metrics ".center(60, "="))
203+
for sharp_metric, sharp_val in sharpness_metrics.items():
204+
print(" {:} {:.3f}".format(METRIC_NAMES[sharp_metric], sharp_val))
205+
print(" Scoring Rule Metrics ".center(60, "="))
206+
for sr_metric, sr_val in scoring_rule_metrics.items():
207+
print(" {:<25} {:.3f}".format(METRIC_NAMES[sr_metric], sr_val))
201208

202209
all_scores = {
203210
"accuracy": accuracy_metrics,
@@ -208,10 +215,3 @@ def get_all_metrics(
208215
}
209216

210217
return all_scores
211-
212-
213-
if __name__ == "__main__":
214-
y_pred = np.array([1, 2, 3, 4])
215-
y_std = np.array([1, 2, 3, 4])
216-
y_true = np.array([1.3, 2.3, 3.3, 4])
217-
get_all_metrics(y_pred, y_std, y_true)

uncertainty_toolbox/metrics_accuracy.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
"""
22
Metrics for assessing the quality of predictive uncertainty quantification.
33
"""
4-
54
import numpy as np
65
from sklearn.metrics import (
76
mean_absolute_error,
@@ -10,8 +9,6 @@
109
median_absolute_error,
1110
)
1211

13-
""" Error, Calibration, Sharpness Metrics """
14-
1512

1613
def prediction_error_metrics(y_pred, y_true):
1714
"""

0 commit comments

Comments
 (0)