forked from uncertainty-toolbox/uncertainty-toolbox
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path: viz_synth_sine.py
88 lines (64 loc) · 2.35 KB
/
viz_synth_sine.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
"""
Examples of code for visualizations.
"""
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import rc
import uncertainty_toolbox.data as udata
import uncertainty_toolbox.metrics as umetrics
import uncertainty_toolbox.viz as uviz
import neatplot
# Configure the plotting style via neatplot, then override one rc setting.
neatplot.set_style()
# Disable LaTeX text rendering; set to True to render text with a system
# LaTeX installation instead.
neatplot.update_rc("text.usetex", False) # Set to True for system latex

# Seed NumPy's global RNG so the synthetic data below is reproducible.
# NOTE(review): assumes udata.synthetic_sine_heteroscedastic draws from
# NumPy's global RNG — confirm.
np.random.seed(11)

# Generate the synthetic dataset used by every plot and metric below.
# NOTE(review): presumably f = true mean, std = true noise std,
# y = noisy targets, x = inputs — confirm against udata.
n_obs = 650
f, std, y, x = udata.synthetic_sine_heteroscedastic(n_obs)

# Global switch read by save_figure(): set to True to write figures to disk.
savefig = False
def save_figure(name_str, file_type="png"):
    """Save the current figure through neatplot, gated by the ``savefig`` flag.

    Args:
        name_str: Base file name forwarded to ``neatplot.save_figure``.
        file_type: File format/extension forwarded as-is (default ``"png"``).

    When the module-level ``savefig`` flag is False this is a no-op.
    """
    if not savefig:
        return
    neatplot.save_figure(name_str, file_type)
def update_rc_params():
    """Set matplotlib's base font size and both tick-label sizes to 14."""
    text_size = 14
    # One update call with all three keys yields the same rcParams state
    # as updating them one at a time.
    plt.rcParams.update(
        {
            "font.size": text_size,
            "xtick.labelsize": text_size,
            "ytick.labelsize": text_size,
        }
    )
def make_plots(pred_mean, pred_std, idx1, idx2):
    """Draw, optionally save, and show four diagnostic plots for one prediction.

    Produces, in order: an xy plot, a prediction-interval plot, a calibration
    plot and an ordered prediction-interval plot. After each one the figure is
    passed to ``save_figure`` and then displayed with ``plt.show()``.

    Args:
        pred_mean: Predictive means, forwarded to the ``uviz`` plotters.
        pred_std: Predictive standard deviations, forwarded likewise.
        idx1: First index embedded in the saved file names.
        idx2: Second index embedded in the saved file names.

    Also reads the module-level ``y`` and ``x`` arrays produced by the data
    generator.
    """
    update_rc_params()
    y_window = [-3, 3]  # shared y-limits; the calibration plot gets none
    points_shown = 50  # n_subset for both interval plots (xy plot uses 300)

    # (file name, zero-argument drawer) pairs, executed strictly in order.
    plot_jobs = [
        (
            f"xy_{idx1}_{idx2}",
            lambda: uviz.plot_xy(
                pred_mean, pred_std, y, x, n_subset=300, ylims=y_window, xlims=[0, 15]
            ),
        ),
        (
            f"intervals_{idx1}_{idx2}",
            lambda: uviz.plot_intervals(
                pred_mean, pred_std, y, n_subset=points_shown, ylims=y_window
            ),
        ),
        (
            f"calibration_{idx1}_{idx2}",
            lambda: uviz.plot_calibration(pred_mean, pred_std, y),
        ),
        (
            f"intervals_ordered_{idx1}_{idx2}",
            lambda: uviz.plot_intervals_ordered(
                pred_mean, pred_std, y, n_subset=points_shown, ylims=y_window
            ),
        ),
    ]

    for file_name, draw in plot_jobs:
        draw()
        save_figure(file_name)
        plt.show()
# Candidate predictive means: only the generated mean ``f`` for now.
pred_mean_list = [f]
# Candidate predictive stds derived from the generator's ``std``:
#   0.5x -> overconfident, 2.0x -> underconfident, 1x -> correct.
pred_std_list = [std * 0.5, std * 2.0, std]

# For every (mean, std) pairing: score calibration, draw the plots, then
# report the scores.
for i, pred_mean in enumerate(pred_mean_list):
    for j, pred_std in enumerate(pred_std_list):
        # Tuple elements are evaluated left to right, so the metric call
        # order is unchanged.
        mace, rmsce, ma = (
            umetrics.mean_absolute_calibration_error(pred_mean, pred_std, y),
            umetrics.root_mean_squared_calibration_error(pred_mean, pred_std, y),
            umetrics.miscalibration_area(pred_mean, pred_std, y),
        )
        make_plots(pred_mean, pred_std, i, j)
        print(f"MACE: {mace}, RMSCE: {rmsce}, MA: {ma}")