Skip to content

Commit eead4f3

Browse files
authored
improvements to diagnostic plots (#556)
* Improvements to diagnostic plots: add markersize parameter, add tests, and support dataset_id for pairs_samples. Fixes #554. * Simplify test_calibration_ecdf_from_quantiles.
1 parent 74a5036 commit eead4f3

File tree

9 files changed

+121
-21
lines changed

9 files changed

+121
-21
lines changed

bayesflow/diagnostics/plots/calibration_ecdf_from_quantiles.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ def calibration_ecdf_from_quantiles(
2626
fill_color: str = "grey",
2727
num_row: int = None,
2828
num_col: int = None,
29+
markersize: float = None,
2930
**kwargs,
3031
) -> plt.Figure:
3132
"""
@@ -97,6 +98,8 @@ def calibration_ecdf_from_quantiles(
9798
num_col : int, optional, default: None
9899
The number of columns for the subplots.
99100
Dynamically determined if None.
101+
markersize : float, optional, default: None
102+
The marker size in points.
100103
**kwargs : dict, optional, default: {}
101104
Keyword arguments can be passed to control the behavior of
102105
ECDF simultaneous band computation through the ``ecdf_bands_kwargs``
@@ -142,11 +145,15 @@ def calibration_ecdf_from_quantiles(
142145

143146
if stacked:
144147
if j == 0:
145-
plot_data["axes"][0].plot(xx, yy, marker="o", color=rank_ecdf_color, alpha=0.95, label="Rank ECDFs")
148+
plot_data["axes"][0].plot(
149+
xx, yy, marker="o", color=rank_ecdf_color, markersize=markersize, alpha=0.95, label="Rank ECDFs"
150+
)
146151
else:
147-
plot_data["axes"][0].plot(xx, yy, marker="o", color=rank_ecdf_color, alpha=0.95)
152+
plot_data["axes"][0].plot(xx, yy, marker="o", color=rank_ecdf_color, markersize=markersize, alpha=0.95)
148153
else:
149-
plot_data["axes"].flat[j].plot(xx, yy, marker="o", color=rank_ecdf_color, alpha=0.95, label="Rank ECDF")
154+
plot_data["axes"].flat[j].plot(
155+
xx, yy, marker="o", color=rank_ecdf_color, markersize=markersize, alpha=0.95, label="Rank ECDF"
156+
)
150157

151158
# Compute uniform ECDF and bands
152159
alpha, z, L, U = pointwise_ecdf_bands(estimates.shape[0], **kwargs.pop("ecdf_bands_kwargs", {}))

bayesflow/diagnostics/plots/mc_calibration.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ def mc_calibration(
2727
color: str = "#132a70",
2828
num_col: int = None,
2929
num_row: int = None,
30+
markersize: float = None,
3031
) -> plt.Figure:
3132
"""Plots the calibration curves, the ECEs and the marginal histograms of predicted posterior model probabilities
3233
for a model comparison problem. The marginal histograms inform about the fraction of predictions in each bin.
@@ -60,6 +61,8 @@ def mc_calibration(
6061
The number of rows for the subplots. Dynamically determined if None.
6162
num_col : int, optional, default: None
6263
The number of columns for the subplots. Dynamically determined if None.
64+
markersize : float, optional, default: None
65+
The marker size in points.
6366
6467
Returns
6568
-------
@@ -88,7 +91,7 @@ def mc_calibration(
8891

8992
for j, ax in enumerate(plot_data["axes"].flat):
9093
# Plot calibration curve
91-
ax.plot(ece["probs_pred"][j], ece["probs_true"][j], "o-", color=color)
94+
ax.plot(ece["probs_pred"][j], ece["probs_true"][j], "o-", color=color, markersize=markersize)
9295

9396
# Plot PMP distribution over bins
9497
uniform_bins = np.linspace(0.0, 1.0, num_bins + 1)

bayesflow/diagnostics/plots/pairs_posterior.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ def pairs_posterior(
2424
prior_color: str | tuple = "gray",
2525
target_color: str | tuple = "red",
2626
alpha: float = 0.9,
27+
markersize: float = 40,
28+
target_markersize: float = 40,
2729
label_fontsize: int = 14,
2830
tick_fontsize: int = 12,
2931
legend_fontsize: int = 14,
@@ -62,6 +64,10 @@ def pairs_posterior(
6264
The color for the optional true parameter lines and points
6365
alpha : float in [0, 1], optional, default: 0.9
6466
The opacity of the posterior plots
67+
markersize : float, optional, default: 40
68+
The marker size in points**2 of the scatter plots
69+
target_markersize : float, optional, default: 40
70+
The marker size in points**2 of the target marker
6571
6672
**kwargs : dict, optional, default: {}
6773
Further optional keyword arguments propagated to `_pairs_samples`
@@ -101,6 +107,9 @@ def pairs_posterior(
101107
label_fontsize=label_fontsize,
102108
tick_fontsize=tick_fontsize,
103109
legend_fontsize=legend_fontsize,
110+
markersize=markersize,
111+
target_markersize=target_markersize,
112+
target_color=target_color,
104113
**kwargs,
105114
)
106115

@@ -114,7 +123,7 @@ def pairs_posterior(
114123
g.data = pd.DataFrame(targets, columns=targets.variable_names)
115124
g.data["_source"] = "True Parameter"
116125
g.map_diag(plot_true_params_as_lines, color=target_color)
117-
g.map_offdiag(plot_true_params_as_points, color=target_color)
126+
g.map_offdiag(plot_true_params_as_points, color=target_color, s=target_markersize)
118127

119128
create_legends(
120129
g,
@@ -124,6 +133,7 @@ def pairs_posterior(
124133
legend_fontsize=legend_fontsize,
125134
show_single_legend=False,
126135
target_color=target_color,
136+
target_markersize=target_markersize,
127137
)
128138

129139
return g

bayesflow/diagnostics/plots/pairs_samples.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
def pairs_samples(
1515
samples: Mapping[str, np.ndarray] | np.ndarray = None,
16+
dataset_id: int = None,
1617
variable_keys: Sequence[str] = None,
1718
variable_names: Sequence[str] = None,
1819
height: float = 2.5,
@@ -22,6 +23,7 @@ def pairs_samples(
2223
label_fontsize: int = 14,
2324
tick_fontsize: int = 12,
2425
show_single_legend: bool = False,
26+
markersize: float = 40,
2527
**kwargs,
2628
) -> sns.PairGrid:
2729
"""
@@ -32,6 +34,8 @@ def pairs_samples(
3234
----------
3335
samples : dict[str, Tensor], default: None
3436
Sample draws from any dataset
37+
dataset_id : int, optional, default: None
ID of the dataset whose posterior draws the pair plots are generated for.
38+
Should only be specified if samples contain posterior draws from multiple datasets.
3539
variable_keys : list or None, optional, default: None
3640
Select keys from the dictionary provided in samples.
3741
By default, select all keys.
@@ -52,15 +56,23 @@ def pairs_samples(
5256
show_single_legend : bool, optional, default: False
5357
Optional toggle for the user to choose whether a single dataset
5458
should also display legend
59+
markersize : float, optional, default: 40
60+
Marker size in points**2 of the scatter plot.
5561
**kwargs : dict, optional
5662
Additional keyword arguments passed to the sns.PairGrid constructor
5763
"""
5864

5965
plot_data = dicts_to_arrays(
6066
estimates=samples,
67+
dataset_ids=dataset_id,
6168
variable_keys=variable_keys,
6269
variable_names=variable_names,
6370
)
71+
# dicts_to_arrays will keep the dataset axis even if it is of length 1
72+
# however, pairs plotting requires the dataset axis to be removed
73+
estimates_shape = plot_data["estimates"].shape
74+
if len(estimates_shape) == 3 and estimates_shape[0] == 1:
75+
plot_data["estimates"] = np.squeeze(plot_data["estimates"], axis=0)
6476

6577
g = _pairs_samples(
6678
plot_data=plot_data,
@@ -71,6 +83,7 @@ def pairs_samples(
7183
label_fontsize=label_fontsize,
7284
tick_fontsize=tick_fontsize,
7385
show_single_legend=show_single_legend,
86+
markersize=markersize,
7487
**kwargs,
7588
)
7689

@@ -88,6 +101,9 @@ def _pairs_samples(
88101
tick_fontsize: int = 12,
89102
legend_fontsize: int = 14,
90103
show_single_legend: bool = False,
104+
markersize: float = 40,
105+
target_markersize: float = 40,
106+
target_color: str = "red",
91107
**kwargs,
92108
) -> sns.PairGrid:
93109
"""
@@ -101,6 +117,12 @@ def _pairs_samples(
101117
color2 : str, optional, default: 'gray'
102118
Secondary color for the pair plots.
103119
This is the color used for the prior draws.
120+
markersize : float, optional, default: 40
121+
Marker size in points**2 of the scatter plot.
122+
target_markersize : float, optional, default: 40
123+
Target marker size in points**2 of the scatter plot.
124+
target_color : str, optional, default: "red"
125+
Target marker color for the legend.
104126
105127
Other arguments are documented in pairs_samples
106128
"""
@@ -159,14 +181,14 @@ def _pairs_samples(
159181
)
160182

161183
# add scatter plots to the upper diagonal
162-
g.map_upper(sns.scatterplot, alpha=0.6, s=40, edgecolor="k", color=color, lw=0)
184+
g.map_upper(sns.scatterplot, alpha=0.6, s=markersize, edgecolor="k", color=color, lw=0)
163185

164186
# add KDEs to the lower diagonal
165187
try:
166188
g.map_lower(sns.kdeplot, fill=True, color=color, alpha=alpha, common_norm=False)
167189
except Exception as e:
168190
logging.exception("KDE failed due to the following exception:\n" + repr(e) + "\nSubstituting scatter plot.")
169-
g.map_lower(sns.scatterplot, alpha=0.6, s=40, edgecolor="k", color=color, lw=0)
191+
g.map_lower(sns.scatterplot, alpha=0.6, s=markersize, edgecolor="k", color=color, lw=0)
170192

171193
# Generate grids
172194
dim = g.axes.shape[0]
@@ -200,6 +222,9 @@ def _pairs_samples(
200222
legend_fontsize=legend_fontsize,
201223
label=label,
202224
show_single_legend=show_single_legend,
225+
markersize=markersize,
226+
target_markersize=target_markersize,
227+
target_color=target_color,
203228
)
204229

205230
# Return figure

bayesflow/diagnostics/plots/recovery.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ def recovery(
2626
num_row: int = None,
2727
xlabel: str = "Ground truth",
2828
ylabel: str = "Estimate",
29+
markersize: float = None,
2930
**kwargs,
3031
) -> plt.Figure:
3132
"""
@@ -76,8 +77,10 @@ def recovery(
7677
The number of rows for the subplots. Dynamically determined if None.
7778
num_col : int, optional, default: None
7879
The number of columns for the subplots. Dynamically determined if None.
79-
xlabel:
80-
ylabel:
80+
xlabel : str, optional, default: "Ground truth"
The label of the x-axis.
81+
ylabel : str, optional, default: "Estimate"
The label of the y-axis.
82+
markersize : float, optional, default: None
83+
The marker size in points.
8184
8285
Returns
8386
-------
@@ -122,10 +125,18 @@ def recovery(
122125
fmt="o",
123126
alpha=0.5,
124127
color=color,
128+
markersize=markersize,
125129
**kwargs,
126130
)
127131
else:
128-
_ = ax.scatter(targets[:, i], point_estimate[:, i], alpha=0.5, color=color, **kwargs)
132+
_ = ax.scatter(
133+
targets[:, i],
134+
point_estimate[:, i],
135+
alpha=0.5,
136+
color=color,
137+
s=None if markersize is None else markersize**2,
138+
**kwargs,
139+
)
129140

130141
make_quadratic(ax, targets[:, i], point_estimate[:, i])
131142

bayesflow/diagnostics/plots/recovery_from_estimates.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ def recovery_from_estimates(
2525
num_row: int = None,
2626
xlabel: str = "Ground truth",
2727
ylabel: str = "Estimate",
28+
markersize: float = None,
2829
**kwargs,
2930
) -> plt.Figure:
3031
"""
@@ -79,8 +80,10 @@ def recovery_from_estimates(
7980
The number of rows for the subplots. Dynamically determined if None.
8081
num_col : int, optional, default: None
8182
The number of columns for the subplots. Dynamically determined if None.
82-
xlabel:
83-
ylabel:
83+
xlabel : str, optional, default: "Ground truth"
The label of the x-axis.
84+
ylabel : str, optional, default: "Estimate"
The label of the y-axis.
85+
markersize : float, optional, default: None
86+
The marker size in points.
8487
8588
Returns
8689
-------
@@ -139,6 +142,7 @@ def recovery_from_estimates(
139142
marker=markers[q_idx],
140143
alpha=0.5,
141144
color=color,
145+
s=None if markersize is None else markersize**2,
142146
**kwargs,
143147
)
144148

bayesflow/diagnostics/plots/z_score_contraction.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ def z_score_contraction(
1818
color: str = "#132a70",
1919
num_col: int = None,
2020
num_row: int = None,
21+
markersize: float = None,
2122
) -> plt.Figure:
2223
"""
2324
Implements a graphical check for global model sensitivity by plotting the
@@ -76,6 +77,8 @@ def z_score_contraction(
7677
The number of rows for the subplots. Dynamically determined if None.
7778
num_col : int, optional, default: None
7879
The number of columns for the subplots. Dynamically determined if None.
80+
markersize : float, optional, default: None
81+
The marker size in points**2 of the scatter plot.
7982
8083
Returns
8184
-------
@@ -118,7 +121,7 @@ def z_score_contraction(
118121
if i >= plot_data["num_variables"]:
119122
break
120123

121-
ax.scatter(contraction[:, i], z_score[:, i], color=color, alpha=0.5)
124+
ax.scatter(contraction[:, i], z_score[:, i], color=color, alpha=0.5, s=markersize)
122125
ax.set_xlim([-0.05, 1.05])
123126

124127
prettify_subplots(plot_data["axes"], num_subplots=plot_data["num_variables"], tick_fontsize=tick_fontsize)

bayesflow/utils/plot_utils.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -374,7 +374,9 @@ def create_legends(
374374
label: str = "Posterior",
375375
show_single_legend: bool = False,
376376
legend_fontsize: int = 14,
377+
markersize: float = 40,
377378
target_color: str = "red",
379+
target_markersize: float = 40,
378380
):
379381
"""
380382
Helper function to create legends for pairplots.
@@ -396,8 +398,12 @@ def create_legends(
396398
should also display legend
397399
legend_fontsize : int, optional, default: 14
398400
fontsize for the legend
399-
target_color : str, optional, default "red"
401+
markersize : float, optional, default: 40
402+
The marker size in points**2
403+
target_color : str, optional, default: "red"
400404
Color for the target label
405+
target_markersize : float, optional, default: 40
406+
Marker size in points**2 of the target marker
401407
"""
402408
handles = []
403409
labels = []
@@ -414,7 +420,15 @@ def create_legends(
414420
labels.append(posterior_label)
415421

416422
if plot_data.get("targets") is not None:
417-
target_handle = plt.Line2D([0], [0], color=target_color, linestyle="--", marker="x", label="Targets")
423+
target_handle = plt.Line2D(
424+
[0],
425+
[0],
426+
color=target_color,
427+
linestyle="--",
428+
marker="x",
429+
markersize=np.sqrt(target_markersize),
430+
label="Targets",
431+
)
418432
target_label = "Targets"
419433
handles.append(target_handle)
420434
labels.append(target_label)

0 commit comments

Comments
 (0)