forked from uncertainty-toolbox/uncertainty-toolbox
-
Notifications
You must be signed in to change notification settings - Fork 0
/
viz_recalibrate_readme.py
90 lines (77 loc) · 2.94 KB
/
viz_recalibrate_readme.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
"""
This script produces the recalibration figures that appear in the README Recalibration
section:
https://github.com/uncertainty-toolbox/uncertainty-toolbox/blob/recal/README.md#recalibration
"""
import numpy as np
import matplotlib.pyplot as plt
import uncertainty_toolbox as uct
# Set plot style
uct.viz.set_style()
# text.usetex=True renders all plot text through a system LaTeX installation;
# set this to False if LaTeX is not available on this machine.
uct.viz.update_rc("text.usetex", True)
uct.viz.update_rc("font.size", 14)  # Set font size
uct.viz.update_rc("xtick.labelsize", 14)  # Set font size for xaxis tick labels
uct.viz.update_rc("ytick.labelsize", 14)  # Set font size for yaxis tick labels

# Set random seed so the synthetic data (and thus the README figures) are reproducible
np.random.seed(11)

# Generate synthetic predictive uncertainty results
n_obs = 650
f, std, y, x = uct.synthetic_sine_heteroscedastic(n_obs)

# Save figures to disk as SVG (set to False to only print the metrics)
savefig = True

# List of predictive means and standard deviations
pred_mean_list = [f]
pred_std_list = [
    std * 0.5,  # overconfident
    std * 2.0,  # underconfident
]


def _report_and_plot(pred_mean, pred_std, y_true, recal_model, label, fig_name, save):
    """Print calibration metrics for one prediction set and plot its calibration curve.

    Args:
        pred_mean: predictive means, passed straight to the uct metric functions.
        pred_std: predictive standard deviations, same length as ``pred_mean``.
        y_true: observed targets.
        recal_model: recalibration model forwarded as ``recal_model`` to the
            uct functions, or None to evaluate the raw (uncalibrated) predictions.
        label: prefix for the printed metrics line, e.g. "Before Recalibration".
        fig_name: base file name handed to ``uct.viz.save_figure``.
        save: when True, save the figure as SVG; otherwise only plot it.

    Returns:
        The ``(exp_props, obs_props)`` pair from
        ``uct.get_proportion_lists_vectorized``. The caller uses the
        uncalibrated pair to fit the recalibration model.
    """
    exp_props, obs_props = uct.get_proportion_lists_vectorized(
        pred_mean, pred_std, y_true, recal_model=recal_model
    )
    mace = uct.mean_absolute_calibration_error(
        pred_mean, pred_std, y_true, recal_model=recal_model
    )
    rmsce = uct.root_mean_squared_calibration_error(
        pred_mean, pred_std, y_true, recal_model=recal_model
    )
    ma = uct.miscalibration_area(pred_mean, pred_std, y_true, recal_model=recal_model)
    print(f"{label}: MACE: {mace:.5f}, RMSCE: {rmsce:.5f}, MA: {ma:.5f}")

    fig, ax = plt.subplots(1, 1, figsize=(5, 5))
    uct.plot_calibration(
        pred_mean,
        pred_std,
        y_true,
        exp_props=exp_props,
        obs_props=obs_props,
        ax=ax,
    )
    if save:
        uct.viz.save_figure(fig_name, "svg")
    # Release the figure; otherwise every loop iteration leaves one open.
    plt.close(fig)
    return exp_props, obs_props


# Loop through, make plots, and compute metrics
for i, pred_mean in enumerate(pred_mean_list):
    for j, pred_std in enumerate(pred_std_list):
        # Keep the README file names (before_recal_{j} / after_recal_{j}) when
        # there is a single mean; with several means, add the mean index so
        # files from different means do not overwrite each other.
        tag = f"{j}" if len(pred_mean_list) == 1 else f"{i}_{j}"

        # Before recalibration
        exp_props, obs_props = _report_and_plot(
            pred_mean,
            pred_std,
            y,
            recal_model=None,
            label="Before Recalibration",
            fig_name=f"before_recal_{tag}",
            save=savefig,
        )

        # After recalibration: fit a recalibration model from the uncalibrated
        # proportions via uct.iso_recal. NOTE: the model is fit and evaluated
        # on the same data; fine for an illustration, but optimistic.
        recal_model = uct.iso_recal(exp_props, obs_props)
        _report_and_plot(
            pred_mean,
            pred_std,
            y,
            recal_model=recal_model,
            label="After Recalibration",
            fig_name=f"after_recal_{tag}",
            save=savefig,
        )