Skip to content

Commit c1407df

Browse files
authored
Add pairs plot for arbitrary quantities (#550)
Add pairs_quantity and plot_quantity functions that allow plotting of quantities that can be calculated for each individual dataset. Currently, for the provided metrics this is only useful for posterior contraction, but could be useful for posterior z-score and other quantities as well.
1 parent eead4f3 commit c1407df

File tree

9 files changed

+718
-38
lines changed

9 files changed

+718
-38
lines changed

bayesflow/diagnostics/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,9 @@
1919
mc_confusion_matrix,
2020
mmd_hypothesis_test,
2121
pairs_posterior,
22+
pairs_quantity,
2223
pairs_samples,
24+
plot_quantity,
2325
recovery,
2426
recovery_from_estimates,
2527
z_score_contraction,

bayesflow/diagnostics/metrics/posterior_contraction.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ def posterior_contraction(
1010
targets: Mapping[str, np.ndarray] | np.ndarray,
1111
variable_keys: Sequence[str] = None,
1212
variable_names: Sequence[str] = None,
13-
aggregation: Callable = np.median,
13+
aggregation: Callable | None = np.median,
1414
) -> dict[str, any]:
1515
"""
1616
Computes the posterior contraction (PC) from prior to posterior for the given samples.
@@ -27,16 +27,17 @@ def posterior_contraction(
2727
By default, select all keys.
2828
variable_names : Sequence[str], optional (default = None)
2929
Optional variable names to show in the output.
30-
aggregation : callable, optional (default = np.median)
30+
aggregation : callable or None, optional (default = np.median)
3131
Function to aggregate the PC across draws. Typically `np.mean` or `np.median`.
32+
If None is provided, the individual values are returned.
3233
3334
Returns
3435
-------
3536
result : dict
3637
Dictionary containing:
3738
3839
- "values" : float or np.ndarray
39-
The aggregated posterior contraction per variable
40+
The (optionally aggregated) posterior contraction per variable
4041
- "metric_name" : str
4142
The name of the metric ("Posterior Contraction").
4243
- "variable_names" : str
@@ -59,6 +60,7 @@ def posterior_contraction(
5960
post_vars = samples["estimates"].var(axis=1, ddof=1)
6061
prior_vars = samples["targets"].var(axis=0, keepdims=True, ddof=1)
6162
contraction = np.clip(1 - (post_vars / prior_vars), 0, 1)
62-
contraction = aggregation(contraction, axis=0)
63+
if aggregation is not None:
64+
contraction = aggregation(contraction, axis=0)
6365
variable_names = samples["estimates"].variable_names
6466
return {"values": contraction, "metric_name": "Posterior Contraction", "variable_names": variable_names}

bayesflow/diagnostics/plots/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
from .mc_confusion_matrix import mc_confusion_matrix
77
from .mmd_hypothesis_test import mmd_hypothesis_test
88
from .pairs_posterior import pairs_posterior
9+
from .pairs_quantity import pairs_quantity
10+
from .plot_quantity import plot_quantity
911
from .pairs_samples import pairs_samples
1012
from .recovery import recovery
1113
from .recovery_from_estimates import recovery_from_estimates

bayesflow/diagnostics/plots/calibration_ecdf.py

Lines changed: 12 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
from collections.abc import Callable, Mapping, Sequence
22

33
import numpy as np
4-
import keras
54
import matplotlib.pyplot as plt
65

6+
from ...utils.dict_utils import compute_test_quantities
77
from ...utils.plot_utils import prepare_plot_data, add_titles_and_labels, prettify_subplots
88
from ...utils.ecdf import simultaneous_ecdf_bands
99
from ...utils.ecdf.ranks import fractional_ranks, distance_ranks
@@ -136,38 +136,17 @@ def calibration_ecdf(
136136

137137
# Optionally, compute and prepend test quantities from draws
138138
if test_quantities is not None:
139-
test_quantities_estimates = {}
140-
test_quantities_targets = {}
141-
142-
for key, test_quantity_fn in test_quantities.items():
143-
# Apply test_quantity_func to ground-truths
144-
tq_targets = test_quantity_fn(data=targets)
145-
test_quantities_targets[key] = np.expand_dims(tq_targets, axis=1)
146-
147-
# Flatten estimates for batch processing in test_quantity_fn, apply function, and restore shape
148-
num_conditions, num_samples = next(iter(estimates.values())).shape[:2]
149-
flattened_estimates = keras.tree.map_structure(
150-
lambda t: np.reshape(t, (num_conditions * num_samples, *t.shape[2:]))
151-
if isinstance(t, np.ndarray)
152-
else t,
153-
estimates,
154-
)
155-
flat_tq_estimates = test_quantity_fn(data=flattened_estimates)
156-
test_quantities_estimates[key] = np.reshape(flat_tq_estimates, (num_conditions, num_samples, 1))
157-
158-
# Add custom test quantities to variable keys and names for plotting
159-
# keys and names are set to the test_quantities dict keys
160-
test_quantities_names = list(test_quantities.keys())
161-
162-
if variable_keys is None:
163-
variable_keys = list(estimates.keys())
164-
165-
if isinstance(variable_names, list):
166-
variable_names = test_quantities_names + variable_names
167-
168-
variable_keys = test_quantities_names + variable_keys
169-
estimates = test_quantities_estimates | estimates
170-
targets = test_quantities_targets | targets
139+
updated_data = compute_test_quantities(
140+
targets=targets,
141+
estimates=estimates,
142+
variable_keys=variable_keys,
143+
variable_names=variable_names,
144+
test_quantities=test_quantities,
145+
)
146+
variable_names = updated_data["variable_names"]
147+
variable_keys = updated_data["variable_keys"]
148+
estimates = updated_data["estimates"]
149+
targets = updated_data["targets"]
171150

172151
plot_data = prepare_plot_data(
173152
estimates=estimates,
Lines changed: 262 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,262 @@
1+
from collections.abc import Callable, Sequence, Mapping
2+
3+
import matplotlib
4+
import matplotlib.pyplot as plt
5+
6+
import numpy as np
7+
import pandas as pd
8+
import seaborn as sns
9+
10+
11+
from .plot_quantity import _prepare_values
12+
13+
14+
def pairs_quantity(
15+
values: Mapping[str, np.ndarray] | np.ndarray | Callable,
16+
targets: Mapping[str, np.ndarray] | np.ndarray,
17+
*,
18+
variable_keys: Sequence[str] = None,
19+
variable_names: Sequence[str] = None,
20+
estimates: Mapping[str, np.ndarray] | np.ndarray | None = None,
21+
test_quantities: dict[str, Callable] = None,
22+
height: float = 2.5,
23+
cmap: str | matplotlib.colors.Colormap = "viridis",
24+
alpha: float = 0.9,
25+
markersize: float = 8.0,
26+
marker: str = "o",
27+
label: str = None,
28+
label_fontsize: int = 14,
29+
tick_fontsize: int = 12,
30+
colorbar_label_fontsize: int = 14,
31+
colorbar_tick_fontsize: int = 12,
32+
colorbar_width: float = 1.8,
33+
colorbar_height: float = 0.06,
34+
colorbar_offset: float = 0.06,
35+
vmin: float = None,
36+
vmax: float = None,
37+
default_name: str = "v",
38+
**kwargs,
39+
) -> sns.PairGrid:
40+
"""
41+
A pair plot function to plot quantities against their generating
42+
parameter values.
43+
44+
The value is indicated by a colormap. The marginal distribution for
45+
each parameter is plotted on the diagonal. Each column displays the
46+
values of corresponding to the parameter in the column.
47+
48+
The function supports the following different combinations to pass
49+
or compute the values:
50+
51+
1. pass `values` as an array of shape (num_datasets,) or (num_datasets, num_variables)
52+
2. pass `values` as a dictionary with the keys 'values', 'metric_name' and 'variable_names'
53+
as provided by the metrics functions. Note that the functions have to be called
54+
without aggregation to obtain value per dataset.
55+
3. pass a function to `values`, as well as `estimates`. The function should have the
56+
signature fn(estimates, targets, [aggregation]) and return an object like the
57+
`values` described in the previous options.
58+
59+
Parameters
60+
----------
61+
values : dict[str, np.ndarray] | np.ndarray | Callable,
62+
The value of the quantity to plot. One of the following:
63+
64+
1. an array of shape (num_datasets,) or (num_datasets, num_variables)
65+
2. a dictionary with the keys 'values', 'metric_name' and 'variable_names'
66+
as provided by the metrics functions. Note that the functions have to be called
67+
without aggregation to obtain value per dataset.
68+
3. a callable, requires passing `estimates` as well. The function should have the
69+
signature fn(estimates, targets, [aggregation]) and return an object like the
70+
ones described in the previous options.
71+
targets : dict[str, np.ndarray] | np.ndarray,
72+
The parameter values plotted on the axis.
73+
variable_keys : list or None, optional, default: None
74+
Select keys from the dictionary provided in samples.
75+
By default, select all keys.
76+
variable_names : list or None, optional, default: None
77+
The parameter names for nice plot titles. Inferred if None
78+
estimates : np.ndarray of shape (n_data_sets, n_post_draws, n_params), optional, default: None
79+
The posterior draws obtained from n_data_sets. Can only be supplied if
80+
`values` is of type Callable.
81+
test_quantities : dict or None, optional, default: None
82+
A dict that maps plot titles to functions that compute
83+
test quantities based on estimate/target draws.
84+
Can only be supplied if `values` is a function.
85+
86+
The dict keys are automatically added to ``variable_keys``
87+
and ``variable_names``.
88+
Test quantity functions are expected to accept a dict of draws with
89+
shape ``(batch_size, ...)`` as the first (typically only)
90+
positional argument and return an NumPy array of shape
91+
``(batch_size,)``.
92+
The functions do not have to deal with an additional
93+
sample dimension, as appropriate reshaping is done internally.
94+
height : float, optional, default: 2.5
95+
The height of the pair plot
96+
cmap : str or Colormap, default: "viridis"
97+
The colormap for the plot.
98+
alpha : float in [0, 1], optional, default: 0.9
99+
The opacity of the plot
100+
markersize : float, optional, default: 8.0
101+
The marker size in points**2 for the scatter plot.
102+
marker : str, optional, default: 'o'
103+
The marker for the scatter plot.
104+
label : str, optional, default: None
105+
Label for the dataset to plot.
106+
label_fontsize : int, optional, default: 14
107+
The font size of the x and y-label texts (parameter names)
108+
tick_fontsize : int, optional, default: 12
109+
The font size of the axis tick labels
110+
colorbar_label_fontsize : int, optional, default: 14
111+
The font size of the colorbar label
112+
colorbar_tick_fontsize : int, optional, default: 12
113+
The font size of the colorbar tick labels
114+
colorbar_width : float, optional, default: 1.8
115+
The width of the colorbar in inches
116+
colorbar_height : float, optional, default: 0.06
117+
The height of the colorbar in inches
118+
colorbar_offset : float, optional, default: 0.06
119+
The vertical offset of the colorbar in inches
120+
vmin : float, optional, default: None
121+
Minimum value for the colormap. If None, the minimum value is
122+
determined from `values`.
123+
vmax : float, optional, default: None
124+
Maximum value for the colormap. If None, the maximum value is
125+
determined from `values`.
126+
default_name : str, optional (default = "v")
127+
The default name to use for estimates if None provided
128+
**kwargs : dict, optional
129+
Additional keyword arguments passed to the sns.PairGrid constructor
130+
131+
Returns
132+
-------
133+
plt.Figure
134+
The figure instance
135+
136+
Raises
137+
------
138+
ValueError
139+
If a callable is supplied as `values`, but `estimates` is None.
140+
"""
141+
142+
if isinstance(values, Callable) and estimates is None:
143+
raise ValueError("Supplied a callable as `values`, but no `estimates`.")
144+
if not isinstance(values, Callable) and test_quantities is not None:
145+
raise ValueError(
146+
"Supplied `test_quantities`, but `values` is not a function. "
147+
"As the values have to be calculated for the test quantities, "
148+
"passing a function is required."
149+
)
150+
151+
d = _prepare_values(
152+
values=values,
153+
targets=targets,
154+
estimates=estimates,
155+
variable_keys=variable_keys,
156+
variable_names=variable_names,
157+
test_quantities=test_quantities,
158+
label=label,
159+
default_name=default_name,
160+
)
161+
(values, targets, variable_keys, variable_names, test_quantities, label) = (
162+
d["values"],
163+
d["targets"],
164+
d["variable_keys"],
165+
d["variable_names"],
166+
d["test_quantities"],
167+
d["label"],
168+
)
169+
170+
# Convert samples to pd.DataFrame
171+
data_to_plot = pd.DataFrame(targets, columns=variable_names)
172+
173+
# initialize plot
174+
g = sns.PairGrid(
175+
data_to_plot,
176+
height=height,
177+
vars=variable_names,
178+
**kwargs,
179+
)
180+
181+
vmin = values.min() if vmin is None else vmin
182+
vmax = values.max() if vmax is None else vmax
183+
184+
# Generate grids
185+
dim = g.axes.shape[0]
186+
for i in range(dim):
187+
for j in range(dim):
188+
# if one value for each variable is supplied, use it for the corresponding column
189+
row_values = values[:, j] if values.ndim == 2 else values
190+
191+
if i == j:
192+
ax = g.axes[i, j].twinx()
193+
ax.scatter(
194+
targets[:, i],
195+
values[:, i],
196+
c=row_values,
197+
cmap=cmap,
198+
s=markersize,
199+
marker=marker,
200+
vmin=vmin,
201+
vmax=vmax,
202+
alpha=alpha,
203+
)
204+
ax.spines["left"].set_visible(False)
205+
ax.spines["top"].set_visible(False)
206+
ax.tick_params(axis="both", which="major", labelsize=tick_fontsize)
207+
ax.tick_params(axis="both", which="minor", labelsize=tick_fontsize)
208+
ax.set_ylim(vmin, vmax)
209+
210+
if i > 0:
211+
g.axes[i, j].get_yaxis().set_visible(False)
212+
g.axes[i, j].spines["left"].set_visible(False)
213+
if i == dim - 1:
214+
ax.set_ylabel(label, size=label_fontsize)
215+
else:
216+
g.axes[i, j].grid(alpha=0.5)
217+
g.axes[i, j].set_axisbelow(True)
218+
g.axes[i, j].scatter(
219+
targets[:, j],
220+
targets[:, i],
221+
c=row_values,
222+
cmap=cmap,
223+
s=markersize,
224+
vmin=vmin,
225+
vmax=vmax,
226+
alpha=alpha,
227+
marker=marker,
228+
)
229+
230+
def inches_to_figure(fig, values):
231+
return fig.transFigure.inverted().transform(fig.dpi_scale_trans.transform(values))
232+
233+
# position and draw colorbar
234+
_, yoffset = inches_to_figure(g.figure, [0, colorbar_offset])
235+
cwidth, cheight = inches_to_figure(g.figure, [colorbar_width, colorbar_offset])
236+
cax = g.figure.add_axes([0.5 - cwidth / 2, -yoffset - cheight, cwidth, cheight])
237+
238+
norm = matplotlib.colors.Normalize(vmin=vmin, vmax=vmax)
239+
cbar = plt.colorbar(
240+
matplotlib.cm.ScalarMappable(norm=norm, cmap=cmap),
241+
cax=cax,
242+
location="bottom",
243+
label=label,
244+
alpha=alpha,
245+
)
246+
247+
cbar.set_label(label, size=colorbar_label_fontsize)
248+
cax.tick_params(labelsize=colorbar_tick_fontsize)
249+
250+
dim = g.axes.shape[0]
251+
for i in range(dim):
252+
# Modify tick sizes
253+
for j in range(i + 1):
254+
g.axes[i, j].tick_params(axis="both", which="major", labelsize=tick_fontsize)
255+
g.axes[i, j].tick_params(axis="both", which="minor", labelsize=tick_fontsize)
256+
257+
# adjust the font size of labels
258+
# the labels themselves remain the same as before, i.e., variable_names
259+
g.axes[i, 0].set_ylabel(variable_names[i], fontsize=label_fontsize)
260+
g.axes[dim - 1, i].set_xlabel(variable_names[i], fontsize=label_fontsize)
261+
262+
return g

0 commit comments

Comments
 (0)