Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion pypesto/visualize/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,13 @@
parameters_lowlevel,
)
from .profile_cis import profile_cis, profile_nested_cis
from .profiles import profile_lowlevel, profiles, profiles_lowlevel
from .profiles import (
profile_lowlevel,
profile_lowlevel_2d,
profiles,
profiles_lowlevel,
visualize_2d_profile,
)
from .reference_points import ReferencePoint, create_references
from .sampling import (
sampling_1d_marginals,
Expand Down
305 changes: 305 additions & 0 deletions pypesto/visualize/profiles.py
Original file line number Diff line number Diff line change
Expand Up @@ -609,3 +609,308 @@ def process_profile_indices(
)

return profile_indices_ret


def profile_lowlevel_2d(
result: Result,
profile_index: int,
second_par_index: int,
ax: plt.Axes,
profile_list_id: int = 0,
ratio_min: float = 0.0,
cmap: str = "viridis",
show_bounds: bool = False,
plot_objective_values: bool = False,
show_colorbar: bool = False,
) -> plt.Axes:
"""
Lowlevel routine for plotting a two-parameter profile visualization.

This function visualizes the profile of one parameter (x-axis) while showing
the values of a second parameter (y-axis), with colors indicating the
objective ratio or function value.

Parameters
----------
result:
A single `pypesto.Result` after profiling.
profile_index:
Integer index specifying which profile to plot (x-axis parameter).
second_par_index:
Integer index specifying which parameter to show on y-axis.
ax:
Axes object to use for plotting.
profile_list_id:
Index of the profile list to visualize.
ratio_min:
Minimum ratio below which to cut off.
cmap:
Colormap to use for the objective ratio/value colors.
show_bounds:
Whether to extend the plot to show parameter bounds.
plot_objective_values:
Whether to plot the objective function values instead of the likelihood
ratio values.
show_colorbar:
Whether to show a colorbar for this subplot.

Returns
-------
scatter:
The scatter plot object (for potential colorbar creation).
"""
# Get the profile result
if result.profile_result is None:
raise ValueError("Result does not contain profile results.")

profile_list = result.profile_result.list[profile_list_id]

if profile_list[profile_index] is None:
raise ValueError(f"Profile for parameter {profile_index} has not been computed.")

profiler_result = profile_list[profile_index]

# Extract data from the profile
x_path = profiler_result.x_path
ratio_path = profiler_result.ratio_path
fval_path = profiler_result.fval_path

# Get the parameter values
x_values = x_path[profile_index, :] # Profiled parameter values (x-axis)
y_values = x_path[second_par_index, :] # Second parameter values (y-axis)

# Get color values (either ratio or objective values)
if plot_objective_values:
color_values = fval_path
else:
color_values = ratio_path

# Filter based on ratio_min
indices = np.where(ratio_path > ratio_min)
x_values = x_values[indices]
y_values = y_values[indices]
color_values = color_values[indices]

# Create the scatter plot with color mapping
scatter = ax.scatter(
x_values,
y_values,
c=color_values,
cmap=cmap,
s=30,
edgecolors='black',
linewidths=0.3
)

# Add a line connecting the points to show the profile path
ax.plot(x_values, y_values, 'k-', alpha=0.2, linewidth=0.8, zorder=0)

# Set labels
x_label = result.problem.x_names[profile_index] if hasattr(result.problem, 'x_names') else f"Parameter {profile_index}"
y_label = result.problem.x_names[second_par_index] if hasattr(result.problem, 'x_names') else f"Parameter {second_par_index}"

ax.set_xlabel(x_label)
ax.set_ylabel(y_label)

# Optionally show bounds
if show_bounds:
if result.problem.lb_full is not None and result.problem.ub_full is not None:
lb_x = result.problem.lb_full[profile_index]
ub_x = result.problem.ub_full[profile_index]
lb_y = result.problem.lb_full[second_par_index]
ub_y = result.problem.ub_full[second_par_index]
ax.set_xlim([lb_x, ub_x])
ax.set_ylim([lb_y, ub_y])

# Draw boundary lines
ax.axhline(y=lb_y, color='red', linestyle='--', alpha=0.3, linewidth=0.8)
ax.axhline(y=ub_y, color='red', linestyle='--', alpha=0.3, linewidth=0.8)
ax.axvline(x=lb_x, color='red', linestyle='--', alpha=0.3, linewidth=0.8)
ax.axvline(x=ub_x, color='red', linestyle='--', alpha=0.3, linewidth=0.8)

ax.grid(True, alpha=0.3)

# Add colorbar if requested
if show_colorbar:
cbar = plt.colorbar(scatter, ax=ax)
if plot_objective_values:
cbar.set_label('Objective value', rotation=270, labelpad=15)
else:
cbar.set_label('Log-posterior ratio', rotation=270, labelpad=15)

return scatter


def visualize_2d_profile(
result: Result,
profile_indices: Sequence[int] = None,
size: tuple[float, float] = None,
profile_list_id: int = 0,
ratio_min: float = 0.0,
cmap: str = "viridis",
show_bounds: bool = False,
plot_objective_values: bool = False,
reference: ReferencePoint | Sequence[ReferencePoint] = None,
) -> tuple[plt.Figure, np.ndarray]:
"""
Create an n×n grid of profile plots.

Diagonal plots show 1D profiles, off-diagonal plots show 2D parameter
relationships during profiling.

Parameters
----------
result:
A single `pypesto.Result` after profiling.
profile_indices:
List of integer indices specifying which parameters to include.
If None, all parameters with computed profiles are included.
size:
Figure size (width, height) in inches. If None, automatically sized
based on number of parameters.
profile_list_id:
Index of the profile list to visualize.
ratio_min:
Minimum ratio below which to cut off.
cmap:
Colormap to use for the 2D plots.
show_bounds:
Whether to extend plots to show parameter bounds.
plot_objective_values:
Whether to plot the objective function values instead of the likelihood
ratio values.
reference:
List of reference points for optimization results.

Returns
-------
fig:
The figure object.
axes:
Array of axes objects (n×n grid).
"""
# Get the profile result
if result.profile_result is None:
raise ValueError("Result does not contain profile results.")

profile_list = result.profile_result.list[profile_list_id]

# Determine which profiles to plot
if profile_indices is None:
profile_indices = [
i for i, prof in enumerate(profile_list)
if prof is not None
]

n_params = len(profile_indices)

if n_params == 0:
raise ValueError("No profiles available to plot.")

# Determine figure size
if size is None:
size_per_subplot = 3
size = (n_params * size_per_subplot, n_params * size_per_subplot)

# Create the figure with n×n subplots
fig, axes = plt.subplots(n_params, n_params, figsize=size)

# Ensure axes is always 2D array
if n_params == 1:
axes = np.array([[axes]])
elif axes.ndim == 1:
axes = axes.reshape(-1, 1)

# Create reference points for 1D profiles
ref = create_references(references=reference)

# Track scatter objects for shared colorbar
scatter_objects = []

# Loop through all subplots
for i, row_param_idx in enumerate(profile_indices):
for j, col_param_idx in enumerate(profile_indices):
ax = axes[i, j]

if i == j:
# Diagonal: Plot 1D profile using existing function
# Get the profile data
fvals, _ = handle_inputs(
result,
profile_indices=[row_param_idx],
profile_list=profile_list_id,
ratio_min=ratio_min,
plot_objective_values=plot_objective_values,
)

if not fvals[row_param_idx] is None:
# Plot the 1D profile
profile_lowlevel(
fvals[row_param_idx],
ax,
show_bounds=show_bounds,
lb=result.problem.lb_full[row_param_idx] if result.problem.lb_full is not None else None,
ub=result.problem.ub_full[row_param_idx] if result.problem.ub_full is not None else None,
)

# Set labels
x_label = result.problem.x_names[row_param_idx] if hasattr(result.problem, 'x_names') else f"Par {row_param_idx}"
ax.set_xlabel(x_label)
if j == 0:
if plot_objective_values:
ax.set_ylabel("Objective value")
else:
ax.set_ylabel("Log-posterior ratio")
else:
ax.set_ylabel("")

# Add reference points
if len(ref) > 0:
for i_ref in ref:
current_x = i_ref["x"][row_param_idx]
ax.plot(
[current_x, current_x],
[0.0, 1.0],
color=i_ref.color,
label=i_ref.legend if i == 0 and j == 0 else None,
)
if i == 0 and j == 0 and i_ref.legend is not None:
ax.legend()

else:
# Off-diagonal: Plot 2D profile
# For subplot (i, j): profile col_param_idx (x-axis) vs row_param_idx (y-axis)
try:
scatter = profile_lowlevel_2d(
result=result,
profile_index=col_param_idx,
second_par_index=row_param_idx,
ax=ax,
profile_list_id=profile_list_id,
ratio_min=ratio_min,
cmap=cmap,
show_bounds=show_bounds,
plot_objective_values=plot_objective_values,
show_colorbar=False,
)
scatter_objects.append((scatter, ax))
except (ValueError, IndexError):
# If profile doesn't exist, leave the subplot empty
ax.text(0.5, 0.5, 'No profile', ha='center', va='center', transform=ax.transAxes)
ax.set_xticks([])
ax.set_yticks([])

# Add a shared colorbar on the right side for 2D plots
if scatter_objects:
# Use the last scatter object for the colorbar
scatter, _ = scatter_objects[-1]
cbar_ax = fig.add_axes([0.92, 0.15, 0.02, 0.7])
cbar = fig.colorbar(scatter, cax=cbar_ax)
if plot_objective_values:
cbar.set_label('Objective value', rotation=270, labelpad=20)
else:
cbar.set_label('Log-posterior ratio', rotation=270, labelpad=20)

plt.tight_layout(rect=[0, 0, 0.9, 1])

return fig, axes
Loading