diff --git a/pypesto/visualize/__init__.py b/pypesto/visualize/__init__.py index 7318cf9c9..0362aeb83 100644 --- a/pypesto/visualize/__init__.py +++ b/pypesto/visualize/__init__.py @@ -44,7 +44,13 @@ parameters_lowlevel, ) from .profile_cis import profile_cis, profile_nested_cis -from .profiles import profile_lowlevel, profiles, profiles_lowlevel +from .profiles import ( + profile_lowlevel, + profile_lowlevel_2d, + profiles, + profiles_lowlevel, + visualize_2d_profile, +) from .reference_points import ReferencePoint, create_references from .sampling import ( sampling_1d_marginals, diff --git a/pypesto/visualize/profiles.py b/pypesto/visualize/profiles.py index a556af3e8..4353d2fef 100644 --- a/pypesto/visualize/profiles.py +++ b/pypesto/visualize/profiles.py @@ -609,3 +609,308 @@ def process_profile_indices( ) return profile_indices_ret + + +def profile_lowlevel_2d( + result: Result, + profile_index: int, + second_par_index: int, + ax: plt.Axes, + profile_list_id: int = 0, + ratio_min: float = 0.0, + cmap: str = "viridis", + show_bounds: bool = False, + plot_objective_values: bool = False, + show_colorbar: bool = False, +) -> plt.Axes: + """ + Lowlevel routine for plotting a two-parameter profile visualization. + + This function visualizes the profile of one parameter (x-axis) while showing + the values of a second parameter (y-axis), with colors indicating the + objective ratio or function value. + + Parameters + ---------- + result: + A single `pypesto.Result` after profiling. + profile_index: + Integer index specifying which profile to plot (x-axis parameter). + second_par_index: + Integer index specifying which parameter to show on y-axis. + ax: + Axes object to use for plotting. + profile_list_id: + Index of the profile list to visualize. + ratio_min: + Minimum ratio below which to cut off. + cmap: + Colormap to use for the objective ratio/value colors. + show_bounds: + Whether to extend the plot to show parameter bounds. + plot_objective_values: + Whether to plot the objective function values instead of the likelihood + ratio values. + show_colorbar: + Whether to show a colorbar for this subplot. + + Returns + ------- + scatter: + The scatter plot object (for potential colorbar creation). + """ + # Get the profile result + if result.profile_result is None: + raise ValueError("Result does not contain profile results.") + + profile_list = result.profile_result.list[profile_list_id] + + if profile_list[profile_index] is None: + raise ValueError(f"Profile for parameter {profile_index} has not been computed.") + + profiler_result = profile_list[profile_index] + + # Extract data from the profile + x_path = profiler_result.x_path + ratio_path = profiler_result.ratio_path + fval_path = profiler_result.fval_path + + # Get the parameter values + x_values = x_path[profile_index, :] # Profiled parameter values (x-axis) + y_values = x_path[second_par_index, :] # Second parameter values (y-axis) + + # Get color values (either ratio or objective values) + if plot_objective_values: + color_values = fval_path + else: + color_values = ratio_path + + # Filter based on ratio_min + indices = np.where(ratio_path > ratio_min) + x_values = x_values[indices] + y_values = y_values[indices] + color_values = color_values[indices] + + # Create the scatter plot with color mapping + scatter = ax.scatter( + x_values, + y_values, + c=color_values, + cmap=cmap, + s=30, + edgecolors='black', + linewidths=0.3 + ) + + # Add a line connecting the points to show the profile path + ax.plot(x_values, y_values, 'k-', alpha=0.2, linewidth=0.8, zorder=0) + + # Set labels + x_label = result.problem.x_names[profile_index] if hasattr(result.problem, 'x_names') else f"Parameter {profile_index}" + y_label = result.problem.x_names[second_par_index] if hasattr(result.problem, 'x_names') else f"Parameter {second_par_index}" + + ax.set_xlabel(x_label) + ax.set_ylabel(y_label) + + # Optionally show bounds + if show_bounds: + if result.problem.lb_full is not None and result.problem.ub_full is not None: + lb_x = result.problem.lb_full[profile_index] + ub_x = result.problem.ub_full[profile_index] + lb_y = result.problem.lb_full[second_par_index] + ub_y = result.problem.ub_full[second_par_index] + ax.set_xlim([lb_x, ub_x]) + ax.set_ylim([lb_y, ub_y]) + + # Draw boundary lines + ax.axhline(y=lb_y, color='red', linestyle='--', alpha=0.3, linewidth=0.8) + ax.axhline(y=ub_y, color='red', linestyle='--', alpha=0.3, linewidth=0.8) + ax.axvline(x=lb_x, color='red', linestyle='--', alpha=0.3, linewidth=0.8) + ax.axvline(x=ub_x, color='red', linestyle='--', alpha=0.3, linewidth=0.8) + + ax.grid(True, alpha=0.3) + + # Add colorbar if requested + if show_colorbar: + cbar = plt.colorbar(scatter, ax=ax) + if plot_objective_values: + cbar.set_label('Objective value', rotation=270, labelpad=15) + else: + cbar.set_label('Log-posterior ratio', rotation=270, labelpad=15) + + return scatter + + +def visualize_2d_profile( + result: Result, + profile_indices: Sequence[int] = None, + size: tuple[float, float] = None, + profile_list_id: int = 0, + ratio_min: float = 0.0, + cmap: str = "viridis", + show_bounds: bool = False, + plot_objective_values: bool = False, + reference: ReferencePoint | Sequence[ReferencePoint] = None, +) -> tuple[plt.Figure, np.ndarray]: + """ + Create an n×n grid of profile plots. + + Diagonal plots show 1D profiles, off-diagonal plots show 2D parameter + relationships during profiling. + + Parameters + ---------- + result: + A single `pypesto.Result` after profiling. + profile_indices: + List of integer indices specifying which parameters to include. + If None, all parameters with computed profiles are included. + size: + Figure size (width, height) in inches. If None, automatically sized + based on number of parameters. + profile_list_id: + Index of the profile list to visualize. + ratio_min: + Minimum ratio below which to cut off. + cmap: + Colormap to use for the 2D plots. + show_bounds: + Whether to extend plots to show parameter bounds. + plot_objective_values: + Whether to plot the objective function values instead of the likelihood + ratio values. + reference: + List of reference points for optimization results. + + Returns + ------- + fig: + The figure object. + axes: + Array of axes objects (n×n grid). + """ + # Get the profile result + if result.profile_result is None: + raise ValueError("Result does not contain profile results.") + + profile_list = result.profile_result.list[profile_list_id] + + # Determine which profiles to plot + if profile_indices is None: + profile_indices = [ + i for i, prof in enumerate(profile_list) + if prof is not None + ] + + n_params = len(profile_indices) + + if n_params == 0: + raise ValueError("No profiles available to plot.") + + # Determine figure size + if size is None: + size_per_subplot = 3 + size = (n_params * size_per_subplot, n_params * size_per_subplot) + + # Create the figure with n×n subplots + fig, axes = plt.subplots(n_params, n_params, figsize=size) + + # Ensure axes is always 2D array + if n_params == 1: + axes = np.array([[axes]]) + elif axes.ndim == 1: + axes = axes.reshape(-1, 1) + + # Create reference points for 1D profiles + ref = create_references(references=reference) + + # Track scatter objects for shared colorbar + scatter_objects = [] + + # Loop through all subplots + for i, row_param_idx in enumerate(profile_indices): + for j, col_param_idx in enumerate(profile_indices): + ax = axes[i, j] + + if i == j: + # Diagonal: Plot 1D profile using existing function + # Get the profile data + fvals, _ = handle_inputs( + result, + profile_indices=[row_param_idx], + profile_list=profile_list_id, + ratio_min=ratio_min, + plot_objective_values=plot_objective_values, + ) + + if not fvals[row_param_idx] is None: + # Plot the 1D profile + profile_lowlevel( + fvals[row_param_idx], + ax, + show_bounds=show_bounds, + lb=result.problem.lb_full[row_param_idx] if result.problem.lb_full is not None else None, + ub=result.problem.ub_full[row_param_idx] if result.problem.ub_full is not None else None, + ) + + # Set labels + x_label = result.problem.x_names[row_param_idx] if hasattr(result.problem, 'x_names') else f"Par {row_param_idx}" + ax.set_xlabel(x_label) + if j == 0: + if plot_objective_values: + ax.set_ylabel("Objective value") + else: + ax.set_ylabel("Log-posterior ratio") + else: + ax.set_ylabel("") + + # Add reference points + if len(ref) > 0: + for i_ref in ref: + current_x = i_ref["x"][row_param_idx] + ax.plot( + [current_x, current_x], + [0.0, 1.0], + color=i_ref.color, + label=i_ref.legend if i == 0 and j == 0 else None, + ) + if i == 0 and j == 0 and i_ref.legend is not None: + ax.legend() + + else: + # Off-diagonal: Plot 2D profile + # For subplot (i, j): profile col_param_idx (x-axis) vs row_param_idx (y-axis) + try: + scatter = profile_lowlevel_2d( + result=result, + profile_index=col_param_idx, + second_par_index=row_param_idx, + ax=ax, + profile_list_id=profile_list_id, + ratio_min=ratio_min, + cmap=cmap, + show_bounds=show_bounds, + plot_objective_values=plot_objective_values, + show_colorbar=False, + ) + scatter_objects.append((scatter, ax)) + except (ValueError, IndexError): + # If profile doesn't exist, leave the subplot empty + ax.text(0.5, 0.5, 'No profile', ha='center', va='center', transform=ax.transAxes) + ax.set_xticks([]) + ax.set_yticks([]) + + # Add a shared colorbar on the right side for 2D plots + if scatter_objects: + # Use the last scatter object for the colorbar + scatter, _ = scatter_objects[-1] + cbar_ax = fig.add_axes([0.92, 0.15, 0.02, 0.7]) + cbar = fig.colorbar(scatter, cax=cbar_ax) + if plot_objective_values: + cbar.set_label('Objective value', rotation=270, labelpad=20) + else: + cbar.set_label('Log-posterior ratio', rotation=270, labelpad=20) + + plt.tight_layout(rect=[0, 0, 0.9, 1]) + + return fig, axes