From 4cc930c11d340ed494a2864cffe88471cc65c29c Mon Sep 17 00:00:00 2001
From: Eduardo Hirata-Miyasaki
Date: Wed, 9 Jul 2025 12:14:16 -0700
Subject: [PATCH 001/101] simulate different embeddings

---
 tests/representation/test_distance.py         | 1149 +++++++++++++++++
 viscy/representation/evaluation/clustering.py |  112 ++
 2 files changed, 1261 insertions(+)
 create mode 100644 tests/representation/test_distance.py

diff --git a/tests/representation/test_distance.py b/tests/representation/test_distance.py
new file mode 100644
index 000000000..d1146f251
--- /dev/null
+++ b/tests/representation/test_distance.py
@@ -0,0 +1,1149 @@
# %%
from typing import Literal

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import xarray as xr
from scipy import stats

from viscy.representation.evaluation.clustering import (
    compare_time_offset,
    pairwise_distance_matrix,
)


def generate_directional_embeddings_corrected(
    n_timepoints: int = 100,
    embedding_dim: int = 3,
    n_tracks: int = 5,
    movement_type: Literal[
        "smooth", "mild_chaos", "moderate_chaos", "high_chaos"
    ] = "smooth",
    target_direction: np.ndarray | None = None,
    noise_std: float = 0.05,
    seed: int = 42,
    normalize_method: Literal["zscore", "l2"] | None = "zscore",
) -> xr.Dataset:
    """
    Generate embeddings with multiple chaos levels.

    Parameters
    ----------
    n_timepoints : int
        Number of timepoints per simulated track.
    embedding_dim : int
        Dimensionality of the embedding space.
    n_tracks : int
        Number of tracks to simulate.
    movement_type : str
        - "smooth": Consistent direction and step size
        - "mild_chaos": Slight randomness, similar to smooth
        - "moderate_chaos": Moderate randomness and variability
        - "high_chaos": High randomness and large jumps
    target_direction : np.ndarray, optional
        Drift direction in embedding space; defaults to the positive first axis.
    noise_std : float
        Standard deviation of the additive per-step noise.
    seed : int
        Random seed for reproducibility.
    normalize_method : {"zscore", "l2"} or None
        Optional normalization applied to each track after simulation.

    Returns
    -------
    xr.Dataset
        Dataset with a "features" variable indexed by (fov_name, track_id, t, id).
    """
    np.random.seed(seed)

    # Default target direction (toward positive x-axis)
    if target_direction is None:
        target_direction = np.zeros(embedding_dim)
        target_direction[0] = 2.0

    # Normalize target direction
    target_direction = target_direction / (np.linalg.norm(target_direction) + 1e-8)

    # Define chaos parameters for each movement type
    chaos_params = {
        "smooth": {
            "random_prob": 0.0,
            "noise_scale": 0.15,
            "jump_prob": 0.0,
            "base_step": 0.12,
            "step_std": 0.15,
        },
        "mild_chaos": {
            "random_prob": 0.1,
            "noise_scale": 0.2,
            "jump_prob": 0.03,
            "exp_scales": [0.15, 0.25],
            "jump_range": (1.5, 2.5),
        },
        "moderate_chaos": {
            "random_prob": 0.25,
            "noise_scale": 0.3,
            "jump_prob": 0.08,
            "exp_scales": [0.12, 0.3, 0.6],
            "jump_range": (2.0, 4.0),
        },
        "high_chaos": {
            "random_prob": 0.4,
            "noise_scale": 0.4,
            "jump_prob": 0.15,
            "exp_scales": [0.1, 0.3, 0.8],
            "jump_range": (3.0, 8.0),
        },
    }

    params = chaos_params[movement_type]

    all_embeddings = []
    all_indices = []
    fov_name = "000000"

    for track_id in range(n_tracks):
        timepoints = np.arange(n_timepoints)
        embeddings = np.zeros((n_timepoints, embedding_dim))
        embeddings[0] = np.random.randn(embedding_dim) * 0.5

        for t in range(1, n_timepoints):
            if movement_type == "smooth":
                # Smooth movement (original logic)
                random_component = (
                    np.random.randn(embedding_dim) * params["noise_scale"]
                )
                direction = target_direction + random_component
                direction = direction / (np.linalg.norm(direction) + 1e-8)

                step_size = params["base_step"] * (
                    1 + np.random.normal(0, params["step_std"])
                )
                step_size = max(0.05, step_size)

            else:
                # Chaotic movement with varying levels
                # Direction logic
                if np.random.random() < params["random_prob"]:
                    direction = np.random.randn(embedding_dim)
                    direction = direction / (np.linalg.norm(direction) + 1e-8)
                else:
                    random_component = (
np.random.randn(embedding_dim) * params["noise_scale"] + ) + direction = target_direction + random_component + direction = direction / (np.linalg.norm(direction) + 1e-8) + + # Step size distribution + exp_scales = params["exp_scales"] + if len(exp_scales) == 2: # mild_chaos + if np.random.random() < 0.5: + step_size = np.random.exponential(exp_scales[0]) + else: + step_size = np.random.exponential(exp_scales[1]) + else: # moderate_chaos, high_chaos + rand_val = np.random.random() + if rand_val < 0.2: + step_size = np.random.exponential(exp_scales[0]) + elif rand_val < 0.5: + step_size = np.random.exponential(exp_scales[1]) + else: + step_size = np.random.exponential(exp_scales[2]) + + # Large jumps + if np.random.random() < params["jump_prob"]: + step_size *= np.random.uniform(*params["jump_range"]) + + # Take step + step = step_size * direction + embeddings[t] = embeddings[t - 1] + step + embeddings[t] += np.random.normal(0, noise_std, embedding_dim) + + # Optional normalization + if normalize_method == "zscore": + embeddings = (embeddings - np.mean(embeddings, axis=0)) / ( + np.std(embeddings, axis=0) + 1e-8 + ) + if normalize_method == "l2": + embeddings = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True) + + all_embeddings.append(embeddings) + + # Create indices + for t in range(n_timepoints): + all_indices.append( + { + "fov_name": fov_name, + "track_id": track_id, + "t": timepoints[t], + "id": len(all_indices), + } + ) + + # Combine all tracks + all_embeddings = np.vstack(all_embeddings) + ultrack_indices = pd.DataFrame(all_indices) + index = pd.MultiIndex.from_frame(ultrack_indices) + + dataset_dict = {"features": (("sample", "features"), all_embeddings)} + dataset = xr.Dataset(dataset_dict, coords={"sample": index}).reset_index("sample") + + return dataset + + +def analyze_step_sizes_before_and_after_normalization( + n_tracks: int = 5, + n_timepoints: int = 100, + embedding_dim: int = 3, + target_direction: np.ndarray = None, + seed: int = 42, +) -> tuple[plt.Figure, plt.Axes]: + """ + Compare step size distributions before and after normalization. + + This demonstrates how normalization affects the step size magnitudes. 
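
    Examples
    --------
    A quick smoke test on a small synthetic dataset; this call is
    illustrative and only uses the function's own parameters:

    >>> fig, axes = analyze_step_sizes_before_and_after_normalization(
    ...     n_tracks=2, n_timepoints=50, seed=0
    ... )  # doctest: +SKIP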
+ """ + # Generate datasets with and without normalization + unnormalized_smooth = generate_directional_embeddings_corrected( + n_tracks=n_tracks, + n_timepoints=n_timepoints, + embedding_dim=embedding_dim, + movement_type="smooth", + target_direction=target_direction, + normalize_method=None, # Key difference + seed=seed, + ) + + unnormalized_chaotic = generate_directional_embeddings_corrected( + n_tracks=n_tracks, + n_timepoints=n_timepoints, + embedding_dim=embedding_dim, + movement_type="mild_chaos", + target_direction=target_direction, + normalize_method=None, # Key difference + seed=seed, + ) + + normalized_smooth = generate_directional_embeddings_corrected( + n_tracks=n_tracks, + n_timepoints=n_timepoints, + embedding_dim=embedding_dim, + movement_type="smooth", + target_direction=target_direction, + normalize_method=None, + seed=seed, + ) + + normalized_chaotic = generate_directional_embeddings_corrected( + n_tracks=n_tracks, + n_timepoints=n_timepoints, + embedding_dim=embedding_dim, + movement_type="mild_chaos", + target_direction=target_direction, + normalize_method=None, + seed=seed, + ) + + # Extract step sizes using the debug function logic + def extract_step_sizes_simple(dataset): + all_step_sizes = [] + unique_track_ids = np.unique(dataset["track_id"].values) + + for track_id in unique_track_ids: + track_mask = dataset["track_id"] == track_id + track_times = dataset["t"].values[track_mask] + track_embeddings = dataset["features"].values[track_mask] + + time_order = np.argsort(track_times) + sorted_embeddings = track_embeddings[time_order] + + if len(sorted_embeddings) > 1: + steps = np.diff(sorted_embeddings, axis=0) + step_sizes = np.linalg.norm(steps, axis=1) + all_step_sizes.extend(step_sizes) + + return np.array(all_step_sizes) + + # Extract step sizes + smooth_unnorm_steps = extract_step_sizes_simple(unnormalized_smooth) + chaotic_unnorm_steps = extract_step_sizes_simple(unnormalized_chaotic) + smooth_norm_steps = extract_step_sizes_simple(normalized_smooth) + chaotic_norm_steps = extract_step_sizes_simple(normalized_chaotic) + + # Create comparison plot + fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 10)) + + # Before normalization + ax1.hist( + smooth_unnorm_steps, + bins=50, + alpha=0.7, + color="#2ca02c", + label=f"Smooth (μ={np.mean(smooth_unnorm_steps):.3f}, σ={np.std(smooth_unnorm_steps):.3f})", + ) + ax1.hist( + chaotic_unnorm_steps, + bins=50, + alpha=0.7, + color="#d62728", + label=f"Chaotic (μ={np.mean(chaotic_unnorm_steps):.3f}, σ={np.std(chaotic_unnorm_steps):.3f})", + ) + ax1.set_title("Before Normalization") + ax1.set_xlabel("Step Size") + ax1.set_ylabel("Frequency") + ax1.legend() + + # After normalization + ax2.hist( + smooth_norm_steps, + bins=50, + alpha=0.7, + color="#2ca02c", + label=f"Smooth (μ={np.mean(smooth_norm_steps):.3f}, σ={np.std(smooth_norm_steps):.3f})", + ) + ax2.hist( + chaotic_norm_steps, + bins=50, + alpha=0.7, + color="#d62728", + label=f"Chaotic (μ={np.mean(chaotic_norm_steps):.3f}, σ={np.std(chaotic_norm_steps):.3f})", + ) + ax2.set_title("After Normalization") + ax2.set_xlabel("Step Size") + ax2.set_ylabel("Frequency") + ax2.legend() + + # Log-scale comparison (before normalization) + ax3.hist(smooth_unnorm_steps, bins=50, alpha=0.7, color="#2ca02c", label="Smooth") + ax3.hist(chaotic_unnorm_steps, bins=50, alpha=0.7, color="#d62728", label="Chaotic") + ax3.set_yscale("log") + ax3.set_title("Before Normalization (Log Scale)") + ax3.set_xlabel("Step Size") + ax3.set_ylabel("Frequency (log)") + ax3.legend() + + # 
Coefficient of variation comparison + cv_smooth_unnorm = np.std(smooth_unnorm_steps) / np.mean(smooth_unnorm_steps) + cv_chaotic_unnorm = np.std(chaotic_unnorm_steps) / np.mean(chaotic_unnorm_steps) + cv_smooth_norm = np.std(smooth_norm_steps) / np.mean(smooth_norm_steps) + cv_chaotic_norm = np.std(chaotic_norm_steps) / np.mean(chaotic_norm_steps) + + categories = [ + "Smooth\n(Unnorm)", + "Chaotic\n(Unnorm)", + "Smooth\n(Norm)", + "Chaotic\n(Norm)", + ] + cv_values = [cv_smooth_unnorm, cv_chaotic_unnorm, cv_smooth_norm, cv_chaotic_norm] + colors = ["#2ca02c", "#d62728", "#2ca02c", "#d62728"] + alphas = [1.0, 1.0, 0.5, 0.5] + + # Create individual bars with their own alpha values + bars = [] + for i, (cat, val, color, alpha) in enumerate( + zip(categories, cv_values, colors, alphas) + ): + bar = ax4.bar(cat, val, color=color, alpha=alpha) + bars.extend(bar) + + ax4.set_ylabel("Coefficient of Variation (σ/μ)") + ax4.set_title("Step Size Variability Comparison") + ax4.tick_params(axis="x", rotation=45) + + plt.tight_layout() + return fig, (ax1, ax2, ax3, ax4) + + +def compute_msd_pairwise_optimized( + embedding_dataset: xr.Dataset, + distance_metric: Literal["euclidean", "cosine"] = "euclidean", +) -> dict[int, list[float]]: + """ + Compute Mean Squared Displacement using pairwise distance matrix. + + Uses compare_time_offset for efficient diagonal extraction. + + Parameters + ---------- + embedding_dataset : xr.Dataset + Dataset containing embeddings and metadata + distance_metric : Literal["euclidean", "cosine"] + Distance metric to use + + Returns + ------- + dict[int, list[float]] + Dictionary mapping time lag τ to list of squared displacements + """ + from collections import defaultdict + + unique_tracks_df = ( + embedding_dataset[["fov_name", "track_id"]].to_dataframe().drop_duplicates() + ) + + displacement_per_tau = defaultdict(list) + + for fov_name, track_id in zip( + unique_tracks_df["fov_name"], unique_tracks_df["track_id"] + ): + # Filter data for this track + track_data = embedding_dataset.where( + (embedding_dataset["fov_name"] == fov_name) + & (embedding_dataset["track_id"] == track_id), + drop=True, + ) + + # Sort by time + time_order = np.argsort(track_data["t"].values) + times = track_data["t"].values[time_order] + track_embeddings = track_data["features"].values[time_order] + + # Compute pairwise distance matrix + if distance_metric == "euclidean": + distance_matrix = pairwise_distance_matrix( + track_embeddings, metric="euclidean" + ) + distance_matrix = distance_matrix**2 # Square for MSD + elif distance_metric == "cosine": + distance_matrix = pairwise_distance_matrix( + track_embeddings, metric="cosine" + ) + else: + raise ValueError(f"Unsupported distance metric: {distance_metric}") + + # Extract displacements using diagonal offsets + n_timepoints = len(times) + for time_offset in range(1, n_timepoints): + diagonal_displacements = compare_time_offset(distance_matrix, time_offset) + + for i, displacement in enumerate(diagonal_displacements): + tau = int(times[i + time_offset] - times[i]) + displacement_per_tau[tau].append(displacement) + + return dict(displacement_per_tau) + + +def normalize_msd_by_embedding_variance( + msd_data_dict: dict[str, dict[int, list[float]]], + datasets: dict[str, xr.Dataset], +) -> dict[str, dict[int, list[float]]]: + """ + Normalize MSD values by the embedding variance for each movement type. + + This enables fair comparison between different embedding models or movement types + by removing scale differences. 
+ + Parameters + ---------- + msd_data_dict : dict[str, dict[int, list[float]]] + Dictionary mapping movement type to MSD data + datasets : dict[str, xr.Dataset] + Dictionary mapping movement type to dataset (for computing variance) + + Returns + ------- + dict[str, dict[int, list[float]]] + Normalized MSD data with same structure as input + """ + normalized_msd_data = {} + + for movement_type, msd_data in msd_data_dict.items(): + # Calculate embedding variance for this movement type + embeddings = datasets[movement_type]["features"].values + embedding_variance = np.var(embeddings) + + print(f"{movement_type}: embedding_variance = {embedding_variance:.4f}") + + # Normalize all MSD values by this variance + normalized_msd_data[movement_type] = {} + for tau, displacements in msd_data.items(): + normalized_msd_data[movement_type][tau] = [ + disp / embedding_variance for disp in displacements + ] + + return normalized_msd_data + + +def normalize_step_sizes_by_embedding_variance( + datasets: dict[str, xr.Dataset], +) -> dict[str, dict[str, float]]: + """ + Normalize step size statistics by embedding variance for fair comparison. + + Parameters + ---------- + datasets : dict[str, xr.Dataset] + Dictionary mapping movement type to dataset + + Returns + ------- + dict[str, dict[str, float]] + Dictionary with normalized step size statistics + """ + step_stats = {} + + print("\n=== Step Size Statistics (Normalized by Embedding Variance) ===") + print("-" * 70) + + for movement_type, dataset in datasets.items(): + # Calculate embedding variance for normalization + embeddings = dataset["features"].values + embedding_variance = np.var(embeddings) + + # Extract step sizes + all_step_sizes = [] + unique_track_ids = np.unique(dataset["track_id"].values) + + for track_id in unique_track_ids: + track_mask = dataset["track_id"] == track_id + track_embeddings = dataset["features"].values[track_mask] + track_times = dataset["t"].values[track_mask] + + # Sort by time and remove duplicates + time_order = np.argsort(track_times) + sorted_embeddings = track_embeddings[time_order] + sorted_times = track_times[time_order] + unique_times, unique_indices = np.unique(sorted_times, return_index=True) + final_embeddings = sorted_embeddings[unique_indices] + + if len(final_embeddings) > 1: + steps = np.diff(final_embeddings, axis=0) + step_sizes = np.linalg.norm(steps, axis=1) + all_step_sizes.extend(step_sizes) + + step_sizes = np.array(all_step_sizes) + + # Calculate raw statistics + raw_mean = np.mean(step_sizes) + raw_std = np.std(step_sizes) + raw_cv = raw_std / raw_mean + + # Calculate normalized statistics + norm_mean = raw_mean / np.sqrt(embedding_variance) + norm_std = raw_std / np.sqrt(embedding_variance) + norm_cv = norm_std / norm_mean # CV remains the same after scaling + + step_stats[movement_type] = { + "raw_mean": raw_mean, + "raw_std": raw_std, + "raw_cv": raw_cv, + "norm_mean": norm_mean, + "norm_std": norm_std, + "norm_cv": norm_cv, + "embedding_variance": embedding_variance, + "n_steps": len(step_sizes), + } + + print( + f"{movement_type:15} | Raw: μ={raw_mean:.4f}, σ={raw_std:.4f}, CV={raw_cv:.4f}" + ) + print(f"{'':15} | Norm: μ={norm_mean:.4f}, σ={norm_std:.4f}, CV={norm_cv:.4f}") + print(f"{'':15} | Var={embedding_variance:.4f}, N={len(step_sizes)}") + print("-" * 70) + + return step_stats + + +def plot_msd_comparison( + msd_data_dict: dict[str, dict[int, list[float]]], + title: str = "MSD: Smooth vs Chaotic Diffusion (Same Direction)", + log_scale: bool = True, + show_power_law_fits: bool = True, 
+) -> tuple[plt.Figure, plt.Axes]: + """ + Plot MSD curves comparing smooth and chaotic diffusion. + + Parameters + ---------- + msd_data_dict : dict[str, dict[int, list[float]]] + Dictionary mapping movement type to MSD data + title : str + Plot title + log_scale : bool + Whether to use log-log scale + show_power_law_fits : bool + Whether to show power law fits + + Returns + ------- + tuple[plt.Figure, plt.Axes] + Figure and axes objects + """ + fig, ax = plt.subplots(figsize=(10, 7)) + + colors = {"smooth": "#2ca02c", "chaotic": "#d62728"} + + for movement_type, msd_data in msd_data_dict.items(): + time_lags = sorted(msd_data.keys()) + msd_means = [] + msd_stds = [] + + for tau in time_lags: + displacements = np.array(msd_data[tau]) + msd_means.append(np.mean(displacements)) + msd_stds.append(np.std(displacements) / np.sqrt(len(displacements))) + + time_lags = np.array(time_lags) + msd_means = np.array(msd_means) + msd_stds = np.array(msd_stds) + + # Plot with error bars + color = colors.get(movement_type, "#1f77b4") + ax.errorbar( + time_lags, + msd_means, + yerr=msd_stds, + marker="o", + label=f"{movement_type.title()} Diffusion", + color=color, + capsize=3, + capthick=1, + linewidth=2, + ) + + # Fit power law if requested + if show_power_law_fits and len(time_lags) > 3: + valid_mask = (time_lags > 0) & (msd_means > 0) + if np.sum(valid_mask) > 3: + log_tau = np.log(time_lags[valid_mask]) + log_msd = np.log(msd_means[valid_mask]) + + slope, intercept, r_value, p_value, std_err = stats.linregress( + log_tau, log_msd + ) + + # Plot fit line + tau_fit = np.linspace( + time_lags[valid_mask][0], time_lags[valid_mask][-1], 50 + ) + msd_fit = np.exp(intercept) * tau_fit**slope + + ax.plot( + tau_fit, + msd_fit, + "--", + color=color, + alpha=0.7, + label=f"{movement_type}: α={slope:.2f} (R²={r_value**2:.3f})", + ) + + ax.set_xlabel("Time Lag (τ)", fontsize=12) + ax.set_ylabel("Mean Squared Displacement", fontsize=12) + ax.set_title(title, fontsize=14) + + if log_scale: + ax.set_xscale("log") + ax.set_yscale("log") + ax.grid(True, alpha=0.3) + + ax.legend() + plt.tight_layout() + return fig, ax + + +def plot_trajectory_comparison_3d( + smooth_dataset: xr.Dataset, + chaotic_dataset: xr.Dataset, + target_direction: np.ndarray = None, + title: str = "3D Trajectory Comparison: Smooth vs Chaotic", +) -> tuple[plt.Figure, tuple[plt.Axes, plt.Axes]]: + """ + Plot 3D trajectories comparing smooth and chaotic diffusion side by side. 
+ + Parameters + ---------- + smooth_dataset : xr.Dataset + Dataset with smooth diffusion trajectories + chaotic_dataset : xr.Dataset + Dataset with chaotic diffusion trajectories + target_direction : np.ndarray + Target direction vector + title : str + Plot title + + Returns + ------- + tuple[plt.Figure, tuple[plt.Axes, plt.Axes]] + Figure and axes objects + """ + fig = plt.figure(figsize=(16, 7)) + + # Default target direction + if target_direction is None: + target_direction = np.array([2.0, 0.0, 0.0]) + + # Smooth diffusion plot + ax1 = fig.add_subplot(121, projection="3d") + plot_single_trajectory_3d(smooth_dataset, ax1, "Smooth Diffusion", target_direction) + + # Chaotic diffusion plot + ax2 = fig.add_subplot(122, projection="3d") + plot_single_trajectory_3d( + chaotic_dataset, ax2, "Chaotic Diffusion", target_direction + ) + + fig.suptitle(title, fontsize=16) + plt.tight_layout() + return fig, (ax1, ax2) + + +def plot_single_trajectory_3d( + dataset: xr.Dataset, + ax: plt.Axes, + subtitle: str, + target_direction: np.ndarray, +): + """ + Plot trajectories for a single dataset in 3D. + + Parameters + ---------- + dataset : xr.Dataset + Dataset containing trajectories + ax : plt.Axes + 3D axes object + subtitle : str + Subtitle for the plot + target_direction : np.ndarray + Target direction vector + """ + n_tracks = len(np.unique(dataset["track_id"].values)) + colors = plt.cm.tab10(np.linspace(0, 1, n_tracks)) + + unique_tracks_df = ( + dataset[["fov_name", "track_id"]].to_dataframe().drop_duplicates() + ) + + for i, (fov_name, track_id) in enumerate( + zip(unique_tracks_df["fov_name"], unique_tracks_df["track_id"]) + ): + track_data = dataset.where( + (dataset["fov_name"] == fov_name) & (dataset["track_id"] == track_id), + drop=True, + ) + + # Sort by time + time_order = np.argsort(track_data["t"].values) + embeddings = track_data["features"].values[time_order] + + x, y, z = embeddings[:, 0], embeddings[:, 1], embeddings[:, 2] + color = colors[int(track_id) % len(colors)] + + # Plot trajectory + ax.plot(x, y, z, "-", color=color, alpha=0.7, linewidth=2) + + # Start and end points + ax.scatter( + x[0], + y[0], + z[0], + color=color, + s=100, + marker="o", + edgecolors="black", + linewidth=1, + ) + ax.scatter( + x[-1], + y[-1], + z[-1], + color=color, + s=150, + marker="*", + edgecolors="black", + linewidth=1, + ) + + # Show target direction arrow + origin = np.array([0, 0, 0]) + ax.quiver( + origin[0], + origin[1], + origin[2], + target_direction[0], + target_direction[1], + target_direction[2], + color="red", + arrow_length_ratio=0.1, + linewidth=3, + label="Target Direction", + ) + + ax.set_xlabel("Dimension 1") + ax.set_ylabel("Dimension 2") + ax.set_zlabel("Dimension 3") + ax.set_title(subtitle) + ax.legend() + + +def analyze_step_size_distributions_debug( + smooth_dataset: xr.Dataset, + chaotic_dataset: xr.Dataset, +) -> tuple[plt.Figure, plt.Axes]: + """ + Analyze and plot step size distributions with debugging information. 
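
    Examples
    --------
    Illustrative usage; the inputs are synthetic datasets built with this
    module's generator, not real embeddings:

    >>> smooth = generate_directional_embeddings_corrected(movement_type="smooth")
    >>> chaotic = generate_directional_embeddings_corrected(movement_type="mild_chaos")
    >>> fig, (ax1, ax2) = analyze_step_size_distributions_debug(smooth, chaotic)  # doctest: +SKIP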
+ """ + fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5)) + + def extract_step_sizes_simple(dataset, dataset_name): + """Extract step sizes with simple coordinate access.""" + all_step_sizes = [] + + # Get unique track IDs + unique_track_ids = np.unique(dataset["track_id"].values) + + print(f"\n{dataset_name} Dataset:") + print(f"Total samples: {len(dataset['track_id'])}") + print(f"Unique track IDs: {unique_track_ids}") + + for track_id in unique_track_ids: + # Get all data for this track + track_mask = dataset["track_id"] == track_id + track_times = dataset["t"].values[track_mask] + track_embeddings = dataset["features"].values[track_mask] + + # Sort by time + time_order = np.argsort(track_times) + sorted_embeddings = track_embeddings[time_order] + sorted_times = track_times[time_order] + + # Remove duplicates in time (this might be the issue) + unique_times, unique_indices = np.unique(sorted_times, return_index=True) + final_embeddings = sorted_embeddings[unique_indices] + + print( + f"Track {track_id}: {len(sorted_times)} total, {len(unique_times)} unique timepoints" + ) + + # Calculate step sizes + if len(final_embeddings) > 1: + steps = np.diff(final_embeddings, axis=0) + step_sizes = np.linalg.norm(steps, axis=1) + all_step_sizes.extend(step_sizes) + print(f"Track {track_id}: {len(step_sizes)} steps") + + print(f"Total steps in {dataset_name}: {len(all_step_sizes)}") + return np.array(all_step_sizes) + + # Extract step sizes with debug info + smooth_steps = extract_step_sizes_simple(smooth_dataset, "Smooth") + chaotic_steps = extract_step_sizes_simple(chaotic_dataset, "Chaotic") + + # Plot histograms + ax1.hist( + smooth_steps, + bins=50, + alpha=0.7, + color="#2ca02c", + label=f"Smooth (n={len(smooth_steps)}, μ={np.mean(smooth_steps):.3f}, σ={np.std(smooth_steps):.3f})", + ) + ax1.hist( + chaotic_steps, + bins=50, + alpha=0.7, + color="#d62728", + label=f"Chaotic (n={len(chaotic_steps)}, μ={np.mean(chaotic_steps):.3f}, σ={np.std(chaotic_steps):.3f})", + ) + ax1.set_xlabel("Step Size") + ax1.set_ylabel("Frequency") + ax1.set_title("Step Size Distribution") + ax1.legend() + + # Plot coefficient of variation + cv_smooth = np.std(smooth_steps) / np.mean(smooth_steps) + cv_chaotic = np.std(chaotic_steps) / np.mean(chaotic_steps) + + ax2.bar( + ["Smooth", "Chaotic"], + [cv_smooth, cv_chaotic], + color=["#2ca02c", "#d62728"], + alpha=0.7, + ) + ax2.set_ylabel("Coefficient of Variation (σ/μ)") + ax2.set_title("Step Size Variability") + + plt.tight_layout() + return fig, (ax1, ax2) + + +def plot_trajectory_comparison_3d_multi( + datasets: dict[str, xr.Dataset], + target_direction: np.ndarray = None, + title: str = "3D Trajectory Comparison: Multiple Movement Types", +) -> tuple[plt.Figure, list[plt.Axes]]: + """ + Plot 3D trajectories for multiple movement types. 
+ + Parameters + ---------- + datasets : dict[str, xr.Dataset] + Dictionary mapping movement type name to dataset + target_direction : np.ndarray + Target direction vector + title : str + Plot title + """ + n_types = len(datasets) + cols = 2 + rows = (n_types + 1) // 2 + + fig = plt.figure(figsize=(12, 6 * rows)) + + # Default target direction + if target_direction is None: + target_direction = np.array([2.0, 0.0, 0.0]) + + axes = [] + for i, (movement_type, dataset) in enumerate(datasets.items()): + ax = fig.add_subplot(rows, cols, i + 1, projection="3d") + plot_single_trajectory_3d( + dataset, + ax, + f"{movement_type.replace('_', ' ').title()} Movement", + target_direction, + ) + axes.append(ax) + + fig.suptitle(title, fontsize=16) + plt.tight_layout() + return fig, axes + + +def plot_msd_comparison_multi( + msd_data_dict: dict[str, dict[int, list[float]]], + title: str = "MSD: Multiple Movement Types Comparison", + log_scale: bool = True, + show_power_law_fits: bool = True, +) -> tuple[plt.Figure, plt.Axes]: + """ + Plot MSD curves for multiple movement types. + """ + fig, ax = plt.subplots(figsize=(12, 8)) + + # Color palette for different movement types + colors = { + "smooth": "#2ca02c", + "mild_chaos": "#ff7f0e", + "moderate_chaos": "#d62728", + "high_chaos": "#9467bd", + } + + for movement_type, msd_data in msd_data_dict.items(): + time_lags = sorted(msd_data.keys()) + msd_means = [] + msd_stds = [] + + for tau in time_lags: + displacements = np.array(msd_data[tau]) + msd_means.append(np.mean(displacements)) + msd_stds.append(np.std(displacements) / np.sqrt(len(displacements))) + + time_lags = np.array(time_lags) + msd_means = np.array(msd_means) + msd_stds = np.array(msd_stds) + + # Plot with error bars + color = colors.get(movement_type, "#1f77b4") + ax.errorbar( + time_lags, + msd_means, + yerr=msd_stds, + marker="o", + label=f"{movement_type.replace('_', ' ').title()}", + color=color, + capsize=3, + capthick=1, + linewidth=2, + markersize=6, + ) + + # Fit power law if requested + if show_power_law_fits and len(time_lags) > 3: + valid_mask = (time_lags > 0) & (msd_means > 0) + if np.sum(valid_mask) > 3: + log_tau = np.log(time_lags[valid_mask]) + log_msd = np.log(msd_means[valid_mask]) + + slope, intercept, r_value, p_value, std_err = stats.linregress( + log_tau, log_msd + ) + + # Plot fit line + tau_fit = np.linspace( + time_lags[valid_mask][0], time_lags[valid_mask][-1], 50 + ) + msd_fit = np.exp(intercept) * tau_fit**slope + + ax.plot( + tau_fit, + msd_fit, + "--", + color=color, + alpha=0.7, + label=f"{movement_type}: α={slope:.2f} (R²={r_value**2:.3f})", + ) + + ax.set_xlabel("Time Lag (τ)", fontsize=12) + ax.set_ylabel("Mean Squared Displacement", fontsize=12) + ax.set_title(title, fontsize=14) + + if log_scale: + ax.set_xscale("log") + ax.set_yscale("log") + ax.grid(True, alpha=0.3) + + ax.legend(bbox_to_anchor=(1.05, 1), loc="upper left") + plt.tight_layout() + return fig, ax + + +def analyze_step_size_distributions_multi( + datasets: dict[str, xr.Dataset], +) -> tuple[plt.Figure, tuple[plt.Axes, plt.Axes]]: + """ + Analyze step size distributions for multiple movement types. 
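
    Examples
    --------
    Illustrative usage with a dictionary of synthetic datasets from this
    module's generator:

    >>> datasets = {
    ...     m: generate_directional_embeddings_corrected(movement_type=m)
    ...     for m in ("smooth", "high_chaos")
    ... }
    >>> fig, (ax1, ax2) = analyze_step_size_distributions_multi(datasets)  # doctest: +SKIP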
+ """ + fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6)) + + colors = { + "smooth": "#2ca02c", + "mild_chaos": "#ff7f0e", + "moderate_chaos": "#d62728", + "high_chaos": "#9467bd", + } + + def extract_step_sizes_simple(dataset): + """Extract step sizes with simple coordinate access.""" + all_step_sizes = [] + unique_track_ids = np.unique(dataset["track_id"].values) + + for track_id in unique_track_ids: + track_mask = dataset["track_id"] == track_id + track_times = dataset["t"].values[track_mask] + track_embeddings = dataset["features"].values[track_mask] + + time_order = np.argsort(track_times) + sorted_embeddings = track_embeddings[time_order] + sorted_times = track_times[time_order] + + # Remove duplicates in time + unique_times, unique_indices = np.unique(sorted_times, return_index=True) + final_embeddings = sorted_embeddings[unique_indices] + + if len(final_embeddings) > 1: + steps = np.diff(final_embeddings, axis=0) + step_sizes = np.linalg.norm(steps, axis=1) + all_step_sizes.extend(step_sizes) + + return np.array(all_step_sizes) + + # Extract step sizes for all datasets + all_step_data = {} + cv_values = [] + labels = [] + + for movement_type, dataset in datasets.items(): + steps = extract_step_sizes_simple(dataset) + all_step_data[movement_type] = steps + + # Calculate coefficient of variation + cv = np.std(steps) / np.mean(steps) + cv_values.append(cv) + labels.append(movement_type.replace("_", " ").title()) + + # Plot histograms + for movement_type, steps in all_step_data.items(): + color = colors.get(movement_type, "#1f77b4") + ax1.hist( + steps, + bins=50, + alpha=0.7, + color=color, + label=f"{movement_type.replace('_', ' ').title()} (n={len(steps)}, μ={np.mean(steps):.3f}, σ={np.std(steps):.3f})", + ) + + ax1.set_xlabel("Step Size") + ax1.set_ylabel("Frequency") + ax1.set_title("Step Size Distributions") + ax1.legend() + + # Plot coefficient of variation + bar_colors = [ + colors.get(movement_type, "#1f77b4") for movement_type in datasets.keys() + ] + bars = ax2.bar(labels, cv_values, color=bar_colors, alpha=0.7) + ax2.set_ylabel("Coefficient of Variation (σ/μ)") + ax2.set_title("Step Size Variability") + ax2.tick_params(axis="x", rotation=45) + + plt.tight_layout() + return fig, (ax1, ax2) + + +# %% +if __name__ == "__main__": + # Note: direction of the embedding to simulate movement/infection. 
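    # A large first-axis magnitude gives every simulated track a strong shared
    # drift. For the drift-dominated "smooth" tracks the fitted log-log MSD
    # slope should approach the ballistic limit (alpha ~ 2), while the chaotic
    # variants should decay toward diffusive scaling (alpha ~ 1). These are
    # rough expectations for eyeballing the plots, not asserted thresholds.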
+ target_direction = np.array([10.0, 0, 0.0]) + + movement_types = ["smooth", "mild_chaos", "moderate_chaos", "high_chaos"] + + datasets = {} + print("=== Generating Datasets ===") + for movement_type in movement_types: + print(f"Generating {movement_type} dataset...") + datasets[movement_type] = generate_directional_embeddings_corrected( + n_tracks=5, + n_timepoints=100, + movement_type=movement_type, + target_direction=target_direction, + normalize_method=None, + seed=42, + ) + + print("=== Computing MSD for All Movement Types ===") + msd_data_dict = {} + for movement_type, dataset in datasets.items(): + print(f"Computing MSD for {movement_type}...") + msd_data_dict[movement_type] = compute_msd_pairwise_optimized(dataset) + + print("\n=== Normalizing MSD by Embedding Variance ===") + normalized_msd_data_dict = normalize_msd_by_embedding_variance( + msd_data_dict, datasets + ) + + print("=== MSD vs Time Plot (Raw) ===") + fig_msd_raw, ax_msd_raw = plot_msd_comparison_multi( + msd_data_dict, title="MSD: Raw Values (All Movement Types)" + ) + plt.show() + + print("=== MSD vs Time Plot (Normalized by Embedding Variance) ===") + fig_msd_norm, ax_msd_norm = plot_msd_comparison_multi( + normalized_msd_data_dict, + title="MSD: Normalized by Embedding Variance (All Movement Types)", + ) + plt.show() + + print("=== 3D Trajectory Comparison (All Types) ===") + fig_3d, axes_3d = plot_trajectory_comparison_3d_multi(datasets, target_direction) + plt.show() + + print("=== Step Size Distribution Analysis (All Types) ===") + fig_step, (ax_step1, ax_step2) = analyze_step_size_distributions_multi(datasets) + plt.show() + + print("=== Step Size Normalization Analysis ===") + step_stats = normalize_step_sizes_by_embedding_variance(datasets) + + print("=== Summary Statistics ===") + for movement_type, dataset in datasets.items(): + print(f"\n{movement_type.replace('_', ' ').title()} Movement:") + print(f" Dataset shape: {dataset.dims}") + print(f" Total samples: {len(dataset.sample)}") + + # Calculate mean step size and CV + def get_step_stats(dataset): + all_step_sizes = [] + unique_track_ids = np.unique(dataset["track_id"].values) + for track_id in unique_track_ids: + track_mask = dataset["track_id"] == track_id + track_embeddings = dataset["features"].values[track_mask] + if len(track_embeddings) > 1: + steps = np.diff(track_embeddings, axis=0) + step_sizes = np.linalg.norm(steps, axis=1) + all_step_sizes.extend(step_sizes) + return np.array(all_step_sizes) + + steps = get_step_stats(dataset) + mean_step = np.mean(steps) + std_step = np.std(steps) + cv = std_step / mean_step + + print(f" Mean step size: {mean_step:.4f}") + print(f" Step size std: {std_step:.4f}") + print(f" Coefficient of variation: {cv:.4f}") + + +# %% diff --git a/viscy/representation/evaluation/clustering.py b/viscy/representation/evaluation/clustering.py index ebf49455f..6ac3fc0ef 100644 --- a/viscy/representation/evaluation/clustering.py +++ b/viscy/representation/evaluation/clustering.py @@ -150,3 +150,115 @@ def clustering_evaluation(embeddings, annotations, method="nmi"): raise ValueError("Invalid method. Choose 'nmi' or 'ari'.") return score + + +def compute_msd_from_distance_matrix( + distance_matrix: NDArray, timepoints: ArrayLike, squared: bool = True +) -> dict[int, list[float]]: + """ + Compute MSD from a precomputed distance matrix using diagonal extraction. + + This is the most efficient approach for MSD computation when you already + have a distance matrix. Uses the compare_time_offset function internally. 
+ + Parameters + ---------- + distance_matrix : NDArray + Square distance matrix (n_timepoints, n_timepoints) + timepoints : ArrayLike + Time points corresponding to each row/column + squared : bool, optional + Whether to square the distances (for true MSD), by default True + + Returns + ------- + dict[int, list[float]] + Dictionary mapping time lag τ to list of displacement values + """ + from collections import defaultdict + + if squared: + distance_matrix = distance_matrix**2 + + timepoints = np.array(timepoints) + displacement_per_tau = defaultdict(list) + n_timepoints = len(timepoints) + + # Use diagonal extraction for efficiency + for time_offset in range(1, n_timepoints): + # Extract diagonal at this offset using existing function + diagonal_displacements = compare_time_offset(distance_matrix, time_offset) + + # Map to actual time lags τ + for i, displacement in enumerate(diagonal_displacements): + tau = int(timepoints[i + time_offset] - timepoints[i]) + displacement_per_tau[tau].append(displacement) + + return dict(displacement_per_tau) + + +def compute_msd_from_pairwise_distances( + features: ArrayLike, timepoints: ArrayLike, metric: str = "euclidean" +) -> dict[int, list[float]]: + """ + Compute Mean Square Displacement (MSD) from pairwise distances. + + This is an efficient implementation that uses diagonal extraction + instead of nested loops for better performance. + + Parameters + ---------- + features : ArrayLike + Feature matrix (n_timepoints, n_features) for a single track + timepoints : ArrayLike + Time points corresponding to each feature vector + metric : str, optional + Distance metric to use, by default "euclidean" + + Returns + ------- + dict[int, list[float]] + Dictionary mapping time lag τ to list of displacement values + """ + # Ensure proper ordering by time + time_order = np.argsort(timepoints) + features = np.array(features)[time_order] + timepoints = np.array(timepoints)[time_order] + + # Compute pairwise distance matrix + distance_matrix = pairwise_distance_matrix(features, metric=metric) + + # Use the optimized diagonal extraction method + return compute_msd_from_distance_matrix( + distance_matrix, timepoints, squared=(metric == "euclidean") + ) + + +def compute_track_msd_statistics( + features: ArrayLike, timepoints: ArrayLike, metric: str = "euclidean" +) -> tuple[dict[int, float], dict[int, float]]: + """ + Compute MSD statistics (mean and std) for a single track. 
+
+    Parameters
+    ----------
+    features : ArrayLike
+        Feature matrix (n_timepoints, n_features) for a single track
+    timepoints : ArrayLike
+        Time points corresponding to each feature vector
+    metric : str, optional
+        Distance metric to use, by default "euclidean"
+
+    Returns
+    -------
+    tuple[dict[int, float], dict[int, float]]
+        Tuple of (mean_msd, std_msd) dictionaries mapping τ to statistics
+    """
+    msd_per_tau = compute_msd_from_pairwise_distances(features, timepoints, metric)
+
+    mean_msd = {
+        tau: np.mean(displacements) for tau, displacements in msd_per_tau.items()
+    }
+    std_msd = {
+        tau: np.std(displacements) for tau, displacements in msd_per_tau.items()
+    }
+
+    return mean_msd, std_msd

From 905247c03078a6c131984b198065ee0457f8e846 Mon Sep 17 00:00:00 2001
From: Eduardo Hirata-Miyasaki
Date: Thu, 10 Jul 2025 10:28:36 -0700
Subject: [PATCH 002/101] update the msd calculation to re-use cdist functions
 in the repo

---
 .../evaluation/ALFI_MSD_v2.py               | 356 +++++++++---------
 viscy/representation/evaluation/distance.py | 108 +++---
 2 files changed, 219 insertions(+), 245 deletions(-)

diff --git a/applications/contrastive_phenotyping/evaluation/ALFI_MSD_v2.py b/applications/contrastive_phenotyping/evaluation/ALFI_MSD_v2.py
index b79c7cbe7..94db31675 100644
--- a/applications/contrastive_phenotyping/evaluation/ALFI_MSD_v2.py
+++ b/applications/contrastive_phenotyping/evaluation/ALFI_MSD_v2.py
@@ -1,227 +1,207 @@
 # %%
 from pathlib import Path
+
 import matplotlib.pyplot as plt
 import numpy as np
+import xarray as xr
+from scipy import stats
+
 from viscy.representation.embedding_writer import read_embedding_dataset
 from viscy.representation.evaluation.distance import (
-    compute_displacement,
-    compute_displacement_statistics,
+    compute_msd,
 )

 # Paths to datasets
 feature_paths = {
-    "7 min interval": "/hpc/projects/organelle_phenotyping/ALFI_benchmarking/predictions_final/ALFI_opp_7mins.zarr",
-    "21 min interval": "/hpc/projects/organelle_phenotyping/ALFI_benchmarking/predictions_final/ALFI_opp_21mins.zarr",
+    "7 min interval": "/hpc/projects/organelle_phenotyping/ALFI_ntxent_loss/logs_alfi_ntxent_time_intervals/predictions/ALFI_7mins.zarr",
+    "14 min interval": "/hpc/projects/organelle_phenotyping/ALFI_ntxent_loss/logs_alfi_ntxent_time_intervals/predictions/ALFI_14mins.zarr",
+    "28 min interval": "/hpc/projects/organelle_phenotyping/ALFI_ntxent_loss/logs_alfi_ntxent_time_intervals/predictions/ALFI_28mins.zarr",
+    "56 min interval": "/hpc/projects/organelle_phenotyping/ALFI_ntxent_loss/logs_alfi_ntxent_time_intervals/predictions/ALFI_56mins.zarr",
+    "91 min interval": "/hpc/projects/organelle_phenotyping/ALFI_ntxent_loss/logs_alfi_ntxent_time_intervals/predictions/ALFI_91mins.zarr",
 }
-
-# Colors for different time intervals
-interval_colors = {
-    "7 min interval": "blue",
-    "21 min interval": "red",
-}
+
+cmap = plt.get_cmap("tab10")  # or use "Set2", "tab20", etc.
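+# Assumes one stable color per interval label; cmap(i % cmap.N) below wraps
+# around if there are ever more labels than colors in the palette.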
+labels = list(feature_paths.keys()) +interval_colors = {label: cmap(i % cmap.N) for i, label in enumerate(labels)} + +# Print and check each path +for label, path in feature_paths.items(): + print(f"{label} color: {interval_colors[label]}") + assert Path(path).exists(), f"Path {path} does not exist" # %% Compute MSD for each dataset results = {} raw_displacements = {} for label, path in feature_paths.items(): + results[label] = {} print(f"\nProcessing {label}...") embedding_dataset = read_embedding_dataset(Path(path)) # Compute displacements - displacements = compute_displacement( + displacements_per_tau = compute_msd( embedding_dataset=embedding_dataset, - distance_metric="euclidean_squared", + distance_metric="euclidean", ) - means, stds = compute_displacement_statistics(displacements) - results[label] = (means, stds) - raw_displacements[label] = displacements + embeddings_variance = np.var(embedding_dataset["features"].values) - # Print some statistics - taus = sorted(means.keys()) - print(f" Number of different τ values: {len(taus)}") - print(f" τ range: {min(taus)} to {max(taus)}") - print(f" MSD at τ=1: {means[1]:.4f} ± {stds[1]:.4f}") + # Normalize MSD by embeddings variance + for tau, displacements in displacements_per_tau.items(): + results[label][tau] = [disp / embeddings_variance for disp in displacements] -# %% Plot MSD vs time (linear scale) -plt.figure(figsize=(10, 6)) - -# Plot each time interval -for interval_label, path in feature_paths.items(): - means, stds = results[interval_label] - - # Sort by tau for plotting - taus = sorted(means.keys()) - mean_values = [means[tau] for tau in taus] - std_values = [stds[tau] for tau in taus] - - plt.plot( - taus, - mean_values, - "-", - color=interval_colors[interval_label], - alpha=0.5, - zorder=1, - ) - plt.scatter( - taus, - mean_values, - color=interval_colors[interval_label], - s=20, - label=interval_label, - zorder=2, - ) -plt.xlabel("Time Shift (τ)") -plt.ylabel("Mean Square Displacement") -plt.title("MSD vs Time Shift") -plt.grid(True, alpha=0.3) -plt.legend() -plt.tight_layout() -plt.show() - -# %% Plot MSD vs time (log-log scale with slopes) -plt.figure(figsize=(10, 6)) - -# Plot each time interval -for interval_label, path in feature_paths.items(): - means, stds = results[interval_label] - - # Sort by tau for plotting - taus = sorted(means.keys()) - mean_values = [means[tau] for tau in taus] - std_values = [stds[tau] for tau in taus] - - # Filter out non-positive values for log scale - valid_mask = np.array(mean_values) > 0 - valid_taus = np.array(taus)[valid_mask] - valid_means = np.array(mean_values)[valid_mask] - - # Calculate slopes for different regions - log_taus = np.log(valid_taus) - log_means = np.log(valid_means) - - # Early slope (first third of points) - n_points = len(log_taus) - early_end = n_points // 3 - early_slope, early_intercept = np.polyfit( - log_taus[:early_end], log_means[:early_end], 1 +# %% Plot MSD vs time (linear scale) +show_power_law_fits = True +log_scale = True +title = "MSD vs Time Shift" + +fig, ax = plt.subplots(figsize=(10, 7)) + +for model_type, msd_data in results.items(): + time_lags = sorted(msd_data.keys()) + msd_means = [] + msd_stds = [] + + for tau in time_lags: + displacements = np.array(msd_data[tau]) + msd_means.append(np.mean(displacements)) + msd_stds.append(np.std(displacements) / np.sqrt(len(displacements))) + + time_lags = np.array(time_lags) + msd_means = np.array(msd_means) + msd_stds = np.array(msd_stds) + + # Plot with error bars + color = 
interval_colors.get(model_type, "#1f77b4") + ax.errorbar( + time_lags, + msd_means, + yerr=msd_stds, + marker="o", + label=f"{model_type.replace('_', ' ').title()}", + color=color, + capsize=3, + capthick=1, + linewidth=2, + markersize=6, ) + # Fit power law if requested + if show_power_law_fits and len(time_lags) > 3: + valid_mask = (time_lags > 0) & (msd_means > 0) + if np.sum(valid_mask) > 3: + log_tau = np.log(time_lags[valid_mask]) + log_msd = np.log(msd_means[valid_mask]) + + slope, intercept, r_value, p_value, std_err = stats.linregress( + log_tau, log_msd + ) + + # Plot fit line + tau_fit = np.linspace( + time_lags[valid_mask][0], time_lags[valid_mask][-1], 50 + ) + msd_fit = np.exp(intercept) * tau_fit**slope + + ax.plot( + tau_fit, + msd_fit, + "--", + color=color, + alpha=0.7, + label=f"{model_type}: α={slope:.2f} (R²={r_value**2:.3f})", + ) + + ax.set_xlabel("Time Lag (τ)", fontsize=12) + ax.set_ylabel("Mean Squared Displacement", fontsize=12) + ax.set_title(title, fontsize=14) + + if log_scale: + ax.set_xscale("log") + ax.set_yscale("log") + ax.grid(True, alpha=0.3) + + ax.legend() + plt.tight_layout() +plt.savefig("msd_vs_time_shift.png", dpi=300) +# %% +# Step size analysis - # Late slope (last third of points) - late_start = 2 * (n_points // 3) - late_slope, late_intercept = np.polyfit( - log_taus[late_start:], log_means[late_start:], 1 - ) - plt.plot( - valid_taus, - valid_means, - "-", - color=interval_colors[interval_label], - alpha=0.5, - zorder=1, - ) - plt.scatter( - valid_taus, - valid_means, - color=interval_colors[interval_label], - s=20, - label=f"{interval_label} (α_early={early_slope:.2f}, α_late={late_slope:.2f})", - zorder=2, - ) +def extract_step_sizes(embedding_dataset: xr.Dataset): + """Extract step sizes with simple coordinate access.""" - # Plot fitted lines for early and late regions - early_fit = np.exp(early_intercept + early_slope * log_taus[:early_end]) - late_fit = np.exp(late_intercept + late_slope * log_taus[late_start:]) - - plt.plot( - valid_taus[:early_end], - early_fit, - "--", - color=interval_colors[interval_label], - alpha=0.3, - zorder=1, - ) - plt.plot( - valid_taus[late_start:], - late_fit, - "--", - color=interval_colors[interval_label], - alpha=0.3, - zorder=1, + unique_tracks_df = ( + embedding_dataset[["fov_name", "track_id"]].to_dataframe().drop_duplicates() ) + all_step_sizes = [] + + for fov_name, track_id in zip( + unique_tracks_df["fov_name"], unique_tracks_df["track_id"] + ): + track_data = embedding_dataset.where( + (embedding_dataset["fov_name"] == fov_name) + & (embedding_dataset["track_id"] == track_id), + drop=True, + ) + time_order = np.argsort(track_data["t"].values) + times = track_data["t"].values[time_order] + track_embeddings = track_data["features"].values[time_order] + if len(times) != len(np.unique(times)): + print(f"Duplicates found in FOV {fov_name}, track {track_id}") + + if len(track_embeddings) > 1: + steps = np.diff(track_embeddings, axis=0) + step_sizes = np.linalg.norm(steps, axis=1) + all_step_sizes.extend(step_sizes) + + return np.array(all_step_sizes) + + +all_step_data = {} +cv_values = [] +labels = [] -plt.xscale("log") -plt.yscale("log") -plt.xlabel("Time Shift (τ)") -plt.ylabel("Mean Square Displacement") -plt.title("MSD vs Time Shift (log-log)") -plt.grid(True, alpha=0.3, which="both") -plt.legend( - title="α = slope in log-log space", bbox_to_anchor=(1.05, 1), loc="upper left" -) -plt.tight_layout() -plt.show() - -# %% Plot slopes analysis -early_slopes = [] -late_slopes = [] -intervals = [] - 
-for interval_label in feature_paths.keys(): - means, _ = results[interval_label] - - # Calculate slopes - taus = np.array(sorted(means.keys())) - mean_values = np.array([means[tau] for tau in taus]) - valid_mask = mean_values > 0 - - if np.sum(valid_mask) > 3: # Need at least 4 points to calculate both slopes - log_taus = np.log(taus[valid_mask]) - log_means = np.log(mean_values[valid_mask]) - - # Calculate early and late slopes - n_points = len(log_taus) - early_end = n_points // 3 - late_start = 2 * (n_points // 3) - - early_slope, _ = np.polyfit(log_taus[:early_end], log_means[:early_end], 1) - late_slope, _ = np.polyfit(log_taus[late_start:], log_means[late_start:], 1) - - early_slopes.append(early_slope) - late_slopes.append(late_slope) - intervals.append(interval_label) - -# Create bar plot -plt.figure(figsize=(12, 6)) - -x = np.arange(len(intervals)) -width = 0.35 - -plt.bar(x - width / 2, early_slopes, width, label="Early slope", alpha=0.7) -plt.bar(x + width / 2, late_slopes, width, label="Late slope", alpha=0.7) - -# Add reference lines -plt.axhline(y=1, color="k", linestyle="--", alpha=0.3, label="Normal diffusion (α=1)") -plt.axhline(y=0, color="k", linestyle="-", alpha=0.2) +for label, path in feature_paths.items(): + print(f"\nProcessing {label}...") + embedding_dataset = read_embedding_dataset(Path(path)) + steps = extract_step_sizes(embedding_dataset) + all_step_data[label] = steps -plt.xlabel("Time Interval") -plt.ylabel("Slope (α)") -plt.title("MSD Slopes by Time Interval") -plt.xticks(x, intervals, rotation=45) -plt.legend() + # Calculate coefficient of variation + cv = np.std(steps) / np.mean(steps) + cv_values.append(cv) + labels.append(label.replace("_", " ").title()) -# Add annotations for diffusion regimes -plt.text( - plt.xlim()[1] * 1.2, 1.5, "Super-diffusion", rotation=90, verticalalignment="center" -) -plt.text( - plt.xlim()[1] * 1.2, 0.5, "Sub-diffusion", rotation=90, verticalalignment="center" -) +# %% +# Plot histograms +fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6)) + +for model_type, steps in all_step_data.items(): + color = interval_colors.get(model_type, "#1f77b4") + ax1.hist( + steps, + bins=50, + alpha=0.7, + color=color, + label=f"{model_type.replace('_', ' ').title()} (n={len(steps)}, μ={np.mean(steps):.3f}, σ={np.std(steps):.3f})", + ) -plt.grid(True, alpha=0.3) +ax1.set_xlabel("Step Size") +ax1.set_ylabel("Frequency") +ax1.set_title("Step Size Distributions") +ax1.legend() + +# Plot coefficient of variation +bar_colors = [ + interval_colors.get(model_type, "#1f77b4") for model_type in results.keys() +] +bars = ax2.bar(labels, cv_values, color=bar_colors, alpha=0.7) +ax2.set_ylabel("Coefficient of Variation (σ/μ)") +ax2.set_title("Step Size Variability") +ax2.tick_params(axis="x", rotation=45) plt.tight_layout() -plt.show() +# plt.show() +plt.savefig("step_size_distributions.png", dpi=300) # %% diff --git a/viscy/representation/evaluation/distance.py b/viscy/representation/evaluation/distance.py index a920eb072..2911c9b7a 100644 --- a/viscy/representation/evaluation/distance.py +++ b/viscy/representation/evaluation/distance.py @@ -1,9 +1,18 @@ +import logging from collections import defaultdict from typing import Literal import numpy as np +import xarray as xr from sklearn.metrics.pairwise import cosine_similarity +from viscy.representation.evaluation.clustering import ( + compare_time_offset, + pairwise_distance_matrix, +) + +_logger = logging.getLogger(__name__) + def calculate_cosine_similarity_cell(embedding_dataset, fov_name, track_id): 
"""Extract embeddings and calculate cosine similarities for a specific cell""" @@ -21,84 +30,69 @@ def calculate_cosine_similarity_cell(embedding_dataset, fov_name, track_id): return time_points, cosine_similarities.tolist() -def compute_displacement( - embedding_dataset, - distance_metric: Literal["euclidean_squared", "cosine"] = "euclidean_squared", +def compute_msd( + embedding_dataset: xr.Dataset, + distance_metric: Literal["euclidean", "cosine"] = "euclidean", ) -> dict[int, list[float]]: - """Compute the displacement or mean square displacement (MSD) of embeddings. - - For each time difference τ, computes either: - - |r(t + τ) - r(t)|² for squared Euclidean (MSD) - - cos_sim(r(t + τ), r(t)) for cosine - for all particles and initial times t. + """ + Compute Mean Squared Displacement using pairwise distance matrix. Parameters ---------- - embedding_dataset : xarray.Dataset + embedding_dataset : xr.Dataset Dataset containing embeddings and metadata - distance_metric : str - The metric to use for computing distances between embeddings. - Valid options are: - - "euclidean": Euclidean distance (L2 norm) - - "euclidean_squared": Squared Euclidean distance (for MSD, default) - - "cosine": Cosine similarity - - "cosine_dissimilarity": 1 - cosine similarity + distance_metric : Literal["euclidean", "cosine"] + Distance metric to use Returns ------- dict[int, list[float]] - Dictionary mapping τ to list of displacements for all particles and initial times + Dictionary mapping time lag τ to list of squared displacements """ - # Get unique tracks efficiently using pandas operations + from collections import defaultdict + unique_tracks_df = ( embedding_dataset[["fov_name", "track_id"]].to_dataframe().drop_duplicates() ) - # Get data from dataset - fov_names = embedding_dataset["fov_name"].values - track_ids = embedding_dataset["track_id"].values - timepoints = embedding_dataset["t"].values - embeddings = embedding_dataset["features"].values - - # Initialize results dictionary with empty lists displacement_per_tau = defaultdict(list) - # Process each track for fov_name, track_id in zip( unique_tracks_df["fov_name"], unique_tracks_df["track_id"] ): - # Get sorted track data - mask = (fov_names == fov_name) & (track_ids == track_id) - times = timepoints[mask] - track_embeddings = embeddings[mask] + # Filter data for this track + track_data = embedding_dataset.where( + (embedding_dataset["fov_name"] == fov_name) + & (embedding_dataset["track_id"] == track_id), + drop=True, + ) # Sort by time - time_order = np.argsort(times) - times = times[time_order] - track_embeddings = track_embeddings[time_order] - - # Process each time point - for t_idx, t in enumerate(times[:-1]): - current_embedding = track_embeddings[t_idx] - - # Check all possible future time points - for future_idx, future_time in enumerate( - times[t_idx + 1 :], start=t_idx + 1 - ): - tau = future_time - t - future_embedding = track_embeddings[future_idx] - - if distance_metric in ["cosine"]: - dot_product = np.dot(current_embedding, future_embedding) - norms = np.linalg.norm(current_embedding) * np.linalg.norm( - future_embedding - ) - similarity = dot_product / norms - displacement = similarity - else: # Euclidean metrics - diff_squared = np.sum((current_embedding - future_embedding) ** 2) - displacement = diff_squared - displacement_per_tau[int(tau)].append(displacement) + time_order = np.argsort(track_data["t"].values) + times = track_data["t"].values[time_order] + track_embeddings = track_data["features"].values[time_order] + + # 
Compute pairwise distance matrix + if distance_metric == "euclidean": + distance_matrix = pairwise_distance_matrix( + track_embeddings, metric="euclidean" + ) + distance_matrix = distance_matrix**2 # Square for MSD + elif distance_metric == "cosine": + distance_matrix = pairwise_distance_matrix( + track_embeddings, metric="cosine" + ) + else: + raise ValueError(f"Unsupported distance metric: {distance_metric}") + + # Extract displacements using diagonal offsets + n_timepoints = len(times) + for time_offset in range(1, n_timepoints): + diagonal_displacements = compare_time_offset(distance_matrix, time_offset) + + for i, displacement in enumerate(diagonal_displacements): + tau = int(times[i + time_offset] - times[i]) + displacement_per_tau[tau].append(displacement) return dict(displacement_per_tau) From c39d1d6c65592581634973407a6525e1528f3899 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Thu, 10 Jul 2025 10:30:35 -0700 Subject: [PATCH 003/101] adding a test for the msd --- .../tests}/test_distance.py | 197 +----------------- 1 file changed, 3 insertions(+), 194 deletions(-) rename {tests/representation => applications/contrastive_phenotyping/tests}/test_distance.py (82%) diff --git a/tests/representation/test_distance.py b/applications/contrastive_phenotyping/tests/test_distance.py similarity index 82% rename from tests/representation/test_distance.py rename to applications/contrastive_phenotyping/tests/test_distance.py index d1146f251..448e25fc7 100644 --- a/tests/representation/test_distance.py +++ b/applications/contrastive_phenotyping/tests/test_distance.py @@ -7,9 +7,8 @@ import xarray as xr from scipy import stats -from viscy.representation.evaluation.clustering import ( - compare_time_offset, - pairwise_distance_matrix, +from viscy.representation.evaluation.distance import ( + compute_msd, ) @@ -336,196 +335,6 @@ def extract_step_sizes_simple(dataset): return fig, (ax1, ax2, ax3, ax4) -def compute_msd_pairwise_optimized( - embedding_dataset: xr.Dataset, - distance_metric: Literal["euclidean", "cosine"] = "euclidean", -) -> dict[int, list[float]]: - """ - Compute Mean Squared Displacement using pairwise distance matrix. - - Uses compare_time_offset for efficient diagonal extraction. 
- - Parameters - ---------- - embedding_dataset : xr.Dataset - Dataset containing embeddings and metadata - distance_metric : Literal["euclidean", "cosine"] - Distance metric to use - - Returns - ------- - dict[int, list[float]] - Dictionary mapping time lag τ to list of squared displacements - """ - from collections import defaultdict - - unique_tracks_df = ( - embedding_dataset[["fov_name", "track_id"]].to_dataframe().drop_duplicates() - ) - - displacement_per_tau = defaultdict(list) - - for fov_name, track_id in zip( - unique_tracks_df["fov_name"], unique_tracks_df["track_id"] - ): - # Filter data for this track - track_data = embedding_dataset.where( - (embedding_dataset["fov_name"] == fov_name) - & (embedding_dataset["track_id"] == track_id), - drop=True, - ) - - # Sort by time - time_order = np.argsort(track_data["t"].values) - times = track_data["t"].values[time_order] - track_embeddings = track_data["features"].values[time_order] - - # Compute pairwise distance matrix - if distance_metric == "euclidean": - distance_matrix = pairwise_distance_matrix( - track_embeddings, metric="euclidean" - ) - distance_matrix = distance_matrix**2 # Square for MSD - elif distance_metric == "cosine": - distance_matrix = pairwise_distance_matrix( - track_embeddings, metric="cosine" - ) - else: - raise ValueError(f"Unsupported distance metric: {distance_metric}") - - # Extract displacements using diagonal offsets - n_timepoints = len(times) - for time_offset in range(1, n_timepoints): - diagonal_displacements = compare_time_offset(distance_matrix, time_offset) - - for i, displacement in enumerate(diagonal_displacements): - tau = int(times[i + time_offset] - times[i]) - displacement_per_tau[tau].append(displacement) - - return dict(displacement_per_tau) - - -def normalize_msd_by_embedding_variance( - msd_data_dict: dict[str, dict[int, list[float]]], - datasets: dict[str, xr.Dataset], -) -> dict[str, dict[int, list[float]]]: - """ - Normalize MSD values by the embedding variance for each movement type. - - This enables fair comparison between different embedding models or movement types - by removing scale differences. - - Parameters - ---------- - msd_data_dict : dict[str, dict[int, list[float]]] - Dictionary mapping movement type to MSD data - datasets : dict[str, xr.Dataset] - Dictionary mapping movement type to dataset (for computing variance) - - Returns - ------- - dict[str, dict[int, list[float]]] - Normalized MSD data with same structure as input - """ - normalized_msd_data = {} - - for movement_type, msd_data in msd_data_dict.items(): - # Calculate embedding variance for this movement type - embeddings = datasets[movement_type]["features"].values - embedding_variance = np.var(embeddings) - - print(f"{movement_type}: embedding_variance = {embedding_variance:.4f}") - - # Normalize all MSD values by this variance - normalized_msd_data[movement_type] = {} - for tau, displacements in msd_data.items(): - normalized_msd_data[movement_type][tau] = [ - disp / embedding_variance for disp in displacements - ] - - return normalized_msd_data - - -def normalize_step_sizes_by_embedding_variance( - datasets: dict[str, xr.Dataset], -) -> dict[str, dict[str, float]]: - """ - Normalize step size statistics by embedding variance for fair comparison. 
- - Parameters - ---------- - datasets : dict[str, xr.Dataset] - Dictionary mapping movement type to dataset - - Returns - ------- - dict[str, dict[str, float]] - Dictionary with normalized step size statistics - """ - step_stats = {} - - print("\n=== Step Size Statistics (Normalized by Embedding Variance) ===") - print("-" * 70) - - for movement_type, dataset in datasets.items(): - # Calculate embedding variance for normalization - embeddings = dataset["features"].values - embedding_variance = np.var(embeddings) - - # Extract step sizes - all_step_sizes = [] - unique_track_ids = np.unique(dataset["track_id"].values) - - for track_id in unique_track_ids: - track_mask = dataset["track_id"] == track_id - track_embeddings = dataset["features"].values[track_mask] - track_times = dataset["t"].values[track_mask] - - # Sort by time and remove duplicates - time_order = np.argsort(track_times) - sorted_embeddings = track_embeddings[time_order] - sorted_times = track_times[time_order] - unique_times, unique_indices = np.unique(sorted_times, return_index=True) - final_embeddings = sorted_embeddings[unique_indices] - - if len(final_embeddings) > 1: - steps = np.diff(final_embeddings, axis=0) - step_sizes = np.linalg.norm(steps, axis=1) - all_step_sizes.extend(step_sizes) - - step_sizes = np.array(all_step_sizes) - - # Calculate raw statistics - raw_mean = np.mean(step_sizes) - raw_std = np.std(step_sizes) - raw_cv = raw_std / raw_mean - - # Calculate normalized statistics - norm_mean = raw_mean / np.sqrt(embedding_variance) - norm_std = raw_std / np.sqrt(embedding_variance) - norm_cv = norm_std / norm_mean # CV remains the same after scaling - - step_stats[movement_type] = { - "raw_mean": raw_mean, - "raw_std": raw_std, - "raw_cv": raw_cv, - "norm_mean": norm_mean, - "norm_std": norm_std, - "norm_cv": norm_cv, - "embedding_variance": embedding_variance, - "n_steps": len(step_sizes), - } - - print( - f"{movement_type:15} | Raw: μ={raw_mean:.4f}, σ={raw_std:.4f}, CV={raw_cv:.4f}" - ) - print(f"{'':15} | Norm: μ={norm_mean:.4f}, σ={norm_std:.4f}, CV={norm_cv:.4f}") - print(f"{'':15} | Var={embedding_variance:.4f}, N={len(step_sizes)}") - print("-" * 70) - - return step_stats - - def plot_msd_comparison( msd_data_dict: dict[str, dict[int, list[float]]], title: str = "MSD: Smooth vs Chaotic Diffusion (Same Direction)", @@ -1086,7 +895,7 @@ def extract_step_sizes_simple(dataset): msd_data_dict = {} for movement_type, dataset in datasets.items(): print(f"Computing MSD for {movement_type}...") - msd_data_dict[movement_type] = compute_msd_pairwise_optimized(dataset) + msd_data_dict[movement_type] = compute_msd(dataset) print("\n=== Normalizing MSD by Embedding Variance ===") normalized_msd_data_dict = normalize_msd_by_embedding_variance( From cf97df429b27a01a6d9bb8035765d07ec7c0ef60 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Mon, 14 Jul 2025 14:28:55 -0700 Subject: [PATCH 004/101] removing unused msd functions --- viscy/representation/evaluation/clustering.py | 85 +------------------ viscy/representation/evaluation/distance.py | 1 - 2 files changed, 1 insertion(+), 85 deletions(-) diff --git a/viscy/representation/evaluation/clustering.py b/viscy/representation/evaluation/clustering.py index 6ac3fc0ef..fcd3964d6 100644 --- a/viscy/representation/evaluation/clustering.py +++ b/viscy/representation/evaluation/clustering.py @@ -152,90 +152,8 @@ def clustering_evaluation(embeddings, annotations, method="nmi"): return score -def compute_msd_from_distance_matrix( - distance_matrix: NDArray, 
timepoints: ArrayLike, squared: bool = True -) -> dict[int, list[float]]: - """ - Compute MSD from a precomputed distance matrix using diagonal extraction. - - This is the most efficient approach for MSD computation when you already - have a distance matrix. Uses the compare_time_offset function internally. - - Parameters - ---------- - distance_matrix : NDArray - Square distance matrix (n_timepoints, n_timepoints) - timepoints : ArrayLike - Time points corresponding to each row/column - squared : bool, optional - Whether to square the distances (for true MSD), by default True - - Returns - ------- - dict[int, list[float]] - Dictionary mapping time lag τ to list of displacement values - """ - from collections import defaultdict - - if squared: - distance_matrix = distance_matrix**2 - - timepoints = np.array(timepoints) - displacement_per_tau = defaultdict(list) - n_timepoints = len(timepoints) - - # Use diagonal extraction for efficiency - for time_offset in range(1, n_timepoints): - # Extract diagonal at this offset using existing function - diagonal_displacements = compare_time_offset(distance_matrix, time_offset) - - # Map to actual time lags τ - for i, displacement in enumerate(diagonal_displacements): - tau = int(timepoints[i + time_offset] - timepoints[i]) - displacement_per_tau[tau].append(displacement) - - return dict(displacement_per_tau) - - -def compute_msd_from_pairwise_distances( - features: ArrayLike, timepoints: ArrayLike, metric: str = "euclidean" -) -> dict[int, list[float]]: - """ - Compute Mean Square Displacement (MSD) from pairwise distances. - - This is an efficient implementation that uses diagonal extraction - instead of nested loops for better performance. - - Parameters - ---------- - features : ArrayLike - Feature matrix (n_timepoints, n_features) for a single track - timepoints : ArrayLike - Time points corresponding to each feature vector - metric : str, optional - Distance metric to use, by default "euclidean" - - Returns - ------- - dict[int, list[float]] - Dictionary mapping time lag τ to list of displacement values - """ - # Ensure proper ordering by time - time_order = np.argsort(timepoints) - features = np.array(features)[time_order] - timepoints = np.array(timepoints)[time_order] - - # Compute pairwise distance matrix - distance_matrix = pairwise_distance_matrix(features, metric=metric) - - # Use the optimized diagonal extraction method - return compute_msd_from_distance_matrix( - distance_matrix, timepoints, squared=(metric == "euclidean") - ) - - def compute_track_msd_statistics( - features: ArrayLike, timepoints: ArrayLike, metric: str = "euclidean" + msd_per_tau: dict[int, list[float]], ) -> tuple[dict[int, float], dict[int, float]]: """ Compute MSD statistics (mean and std) for a single track. 
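
With this change, compute_track_msd_statistics becomes a pure aggregation step
over precomputed displacements. A minimal usage sketch of the intended call
chain, assuming embeddings written by the embedding writer (the zarr path is
hypothetical):

    from viscy.representation.embedding_writer import read_embedding_dataset
    from viscy.representation.evaluation.clustering import (
        compute_track_msd_statistics,
    )
    from viscy.representation.evaluation.distance import compute_msd

    dataset = read_embedding_dataset("embeddings.zarr")  # hypothetical path
    # displacements per time lag tau, pooled over all tracks
    displacements = compute_msd(dataset, distance_metric="euclidean")
    mean_msd, std_msd = compute_track_msd_statistics(displacements)
    for tau in sorted(mean_msd):
        print(f"tau={tau}: {mean_msd[tau]:.4f} +/- {std_msd[tau]:.4f}")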
@@ -254,7 +172,6 @@ def compute_track_msd_statistics( tuple[dict[int, float], dict[int, float]] Tuple of (mean_msd, std_msd) dictionaries mapping τ to statistics """ - msd_per_tau = compute_msd_from_pairwise_distances(features, timepoints, metric) mean_msd = { tau: np.mean(displacements) for tau, displacements in msd_per_tau.items() diff --git a/viscy/representation/evaluation/distance.py b/viscy/representation/evaluation/distance.py index 2911c9b7a..236f989c5 100644 --- a/viscy/representation/evaluation/distance.py +++ b/viscy/representation/evaluation/distance.py @@ -49,7 +49,6 @@ def compute_msd( dict[int, list[float]] Dictionary mapping time lag τ to list of squared displacements """ - from collections import defaultdict unique_tracks_df = ( embedding_dataset[["fov_name", "track_id"]].to_dataframe().drop_duplicates() From 4c1a49227d5b3b5b329f4a5033e14b38cfd6afd2 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Mon, 14 Jul 2025 20:31:24 -0700 Subject: [PATCH 005/101] renaming msd to compute_track_displacement --- .../evaluation/ALFI_MSD_v2.py | 30 ++++++++++++------- viscy/representation/evaluation/distance.py | 2 +- 2 files changed, 20 insertions(+), 12 deletions(-) diff --git a/applications/contrastive_phenotyping/evaluation/ALFI_MSD_v2.py b/applications/contrastive_phenotyping/evaluation/ALFI_MSD_v2.py index 94db31675..98014a4f3 100644 --- a/applications/contrastive_phenotyping/evaluation/ALFI_MSD_v2.py +++ b/applications/contrastive_phenotyping/evaluation/ALFI_MSD_v2.py @@ -8,7 +8,7 @@ from viscy.representation.embedding_writer import read_embedding_dataset from viscy.representation.evaluation.distance import ( - compute_msd, + compute_track_displacement, ) # Paths to datasets @@ -34,27 +34,34 @@ results = {} raw_displacements = {} +DISTANCE_METRIC = "cosine" for label, path in feature_paths.items(): results[label] = {} print(f"\nProcessing {label}...") embedding_dataset = read_embedding_dataset(Path(path)) # Compute displacements - displacements_per_tau = compute_msd( + displacements_per_tau = compute_track_displacement( embedding_dataset=embedding_dataset, - distance_metric="euclidean", + distance_metric=DISTANCE_METRIC, ) - embeddings_variance = np.var(embedding_dataset["features"].values) - # Normalize MSD by embeddings variance - for tau, displacements in displacements_per_tau.items(): - results[label][tau] = [disp / embeddings_variance for disp in displacements] + # Store displacements with conditional normalization + if DISTANCE_METRIC == "cosine": + # Cosine distance is already scale-invariant, no normalization needed + for tau, displacements in displacements_per_tau.items(): + results[label][tau] = displacements + else: + # Normalize by embeddings variance for euclidean distance + embeddings_variance = np.var(embedding_dataset["features"].values) + for tau, displacements in displacements_per_tau.items(): + results[label][tau] = [disp / embeddings_variance for disp in displacements] # %% Plot MSD vs time (linear scale) show_power_law_fits = True log_scale = True -title = "MSD vs Time Shift" +title = "Mean Track Displacement vs Time Shift" fig, ax = plt.subplots(figsize=(10, 7)) @@ -63,6 +70,7 @@ msd_means = [] msd_stds = [] + # Compute mean and std of MSD for each time lag for tau in time_lags: displacements = np.array(msd_data[tau]) msd_means.append(np.mean(displacements)) @@ -113,7 +121,7 @@ ) ax.set_xlabel("Time Lag (τ)", fontsize=12) - ax.set_ylabel("Mean Squared Displacement", fontsize=12) + ax.set_ylabel("Mean Track Displacement", fontsize=12) 
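+    # Reading the label: "Mean Track Displacement" rather than "MSD" because
+    # the cosine branch of compute_track_displacement does not square the
+    # distances, so the classical MSD reading only holds for the euclidean
+    # metric (assumed interpretation of the rename in this patch).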
ax.set_title(title, fontsize=14) if log_scale: @@ -123,7 +131,7 @@ ax.legend() plt.tight_layout() -plt.savefig("msd_vs_time_shift.png", dpi=300) +plt.savefig(f"msd_vs_time_shift_{DISTANCE_METRIC}.png", dpi=300) # %% # Step size analysis @@ -202,6 +210,6 @@ def extract_step_sizes(embedding_dataset: xr.Dataset): ax2.tick_params(axis="x", rotation=45) plt.tight_layout() # plt.show() -plt.savefig("step_size_distributions.png", dpi=300) +plt.savefig(f"step_size_distributions_{DISTANCE_METRIC}.png", dpi=300) # %% diff --git a/viscy/representation/evaluation/distance.py b/viscy/representation/evaluation/distance.py index 236f989c5..037abe999 100644 --- a/viscy/representation/evaluation/distance.py +++ b/viscy/representation/evaluation/distance.py @@ -30,7 +30,7 @@ def calculate_cosine_similarity_cell(embedding_dataset, fov_name, track_id): return time_points, cosine_similarities.tolist() -def compute_msd( +def compute_track_displacement( embedding_dataset: xr.Dataset, distance_metric: Literal["euclidean", "cosine"] = "euclidean", ) -> dict[int, list[float]]: From 8638168af20774572ff446e7af1889ae8f68ca85 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Tue, 15 Jul 2025 09:30:09 -0700 Subject: [PATCH 006/101] default to cosine distance --- viscy/representation/evaluation/distance.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/viscy/representation/evaluation/distance.py b/viscy/representation/evaluation/distance.py index 037abe999..ab9f3e05c 100644 --- a/viscy/representation/evaluation/distance.py +++ b/viscy/representation/evaluation/distance.py @@ -32,7 +32,7 @@ def calculate_cosine_similarity_cell(embedding_dataset, fov_name, track_id): def compute_track_displacement( embedding_dataset: xr.Dataset, - distance_metric: Literal["euclidean", "cosine"] = "euclidean", + distance_metric: Literal["euclidean", "cosine"] = "cosine", ) -> dict[int, list[float]]: """ Compute Mean Squared Displacement using pairwise distance matrix. @@ -42,7 +42,7 @@ def compute_track_displacement( embedding_dataset : xr.Dataset Dataset containing embeddings and metadata distance_metric : Literal["euclidean", "cosine"] - Distance metric to use + Distance metric to use. Default is cosine. Returns ------- From c40b64ce582f6af4b56694094a6af3fd6093b0b9 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Wed, 16 Jul 2025 16:52:43 -0700 Subject: [PATCH 007/101] adding the gradient attribution video. 
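
The script assembles the trained contrastive encoder with a linear
classification head and slides an occluding window over the input volume to
attribute the infection and division predictions to image regions over time.
The core call it builds up to (kwargs as in attr_kwargs in the script):

    attribution = assembled_classifier.attribute_occlusion(
        img=img_tensor,
        sliding_window_shapes=(1, 15, 12, 12),
        strides=(1, 15, 4, 4),
        show_progress=True,
    )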
---
 .../figures/grad_attr_time.py | 833 ++++++++++++++++++
 1 file changed, 833 insertions(+)
 create mode 100644 applications/contrastive_phenotyping/figures/grad_attr_time.py

diff --git a/applications/contrastive_phenotyping/figures/grad_attr_time.py b/applications/contrastive_phenotyping/figures/grad_attr_time.py
new file mode 100644
index 000000000..fbe585241
--- /dev/null
+++ b/applications/contrastive_phenotyping/figures/grad_attr_time.py
@@ -0,0 +1,833 @@
+# %%
+import logging
+import warnings
+from pathlib import Path
+
+import matplotlib as mpl
+import matplotlib.animation as animation
+import matplotlib.pyplot as plt
+import numpy as np
+import pandas as pd
+import torch
+import xarray as xr
+from cmap import Colormap
+from lightning.pytorch import seed_everything
+from skimage.exposure import rescale_intensity
+from sklearn.metrics import (
+    accuracy_score,
+    auc,
+    f1_score,
+    precision_recall_curve,
+    roc_auc_score,
+)
+
+from viscy.data.triplet import TripletDataModule
+from viscy.representation.embedding_writer import read_embedding_dataset
+from viscy.representation.engine import ContrastiveEncoder, ContrastiveModule
+from viscy.representation.evaluation import load_annotation
+from viscy.representation.evaluation.lca import (
+    AssembledClassifier,
+    fit_logistic_regression,
+    linear_from_binary_logistic_regression,
+)
+from viscy.transforms import NormalizeSampled, ScaleIntensityRangePercentilesd
+
+seed_everything(42, workers=True)
+
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.DEBUG)
+
+
+# %%
+# Dataset for display and occlusion analysis
+data_path = "/hpc/projects/intracellular_dashboard/organelle_dynamics/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/registered_test.zarr"
+tracks_path = "/hpc/projects/intracellular_dashboard/organelle_dynamics/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/track_test.zarr"
+annotation_occlusion_infection_path = "/hpc/projects/intracellular_dashboard/organelle_dynamics/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/supervised_inf_pred/extracted_inf_state.csv"
+annotation_occlusion_division_path = "/hpc/projects/intracellular_dashboard/organelle_dynamics/2024_02_04_A549_DENV_ZIKV_timelapse/9-lineage-cell-division/lineages_gt/cell_division_state_test_set.csv"
+fov = "/B/4/8"
+track = [44, 46]
+
+# %%
+dm = TripletDataModule(
+    data_path=data_path,
+    tracks_path=tracks_path,
+    source_channel=["Phase3D", "RFP"],
+    z_range=[25, 40],
+    batch_size=48,
+    num_workers=0,
+    initial_yx_patch_size=(128, 128),
+    final_yx_patch_size=(128, 128),
+    normalizations=[
+        NormalizeSampled(
+            keys=["Phase3D"], level="fov_statistics", subtrahend="mean", divisor="std"
+        ),
+        ScaleIntensityRangePercentilesd(
+            keys=["RFP"], lower=50, upper=99, b_min=0.0, b_max=1.0
+        ),
+    ],
+    predict_cells=True,
+    include_fov_names=[fov] * len(track),
+    include_track_ids=track,
+)
+dm.setup("predict")
+len(dm.predict_dataset)
+
+# %%
+# load model
+model = ContrastiveModule.load_from_checkpoint(
+    "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/time_interval/predict/epoch=178-step=16826.ckpt",
+    encoder=ContrastiveEncoder(
+        backbone="convnext_tiny",
+        in_channels=2,
+        in_stack_depth=15,
+        stem_kernel_size=(5, 4, 4),
+        stem_stride=(5, 4, 4),
+        embedding_dim=768,
+        projection_dim=32,
+    ),
+).eval()

+# %%
+# TODO add the paths to the combination of sec61 and tomm20
+# train linear classifier
+# INFECTION
+## Embedding and Annotations
+
+path_infection_embedding_1 = Path(
+    "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/time_interval/predict/feb_test_time_interval_1_epoch_178.zarr"
+)
+
+path_annotations_infection_1 = Path(
+    "/hpc/projects/intracellular_dashboard/organelle_dynamics/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/supervised_inf_pred/extracted_inf_state.csv"
+)
+# TOMM20
+path_infection_embedding_2 = Path(
+    "/hpc/projects/intracellular_dashboard/organelle_dynamics/rerun/2024_11_21_A549_TOMM20_DENV/4-phenotyping/1-predictions/sensor_160patch_99ckpt_max.zarr"
+)
+path_annotations_infection_2 = Path(
+    "/hpc/projects/intracellular_dashboard/organelle_dynamics/rerun/2024_11_21_A549_TOMM20_DENV/4-phenotyping/0-annotations/track_infection_annotation.csv"
+)
+
+# SEC61
+path_infection_embedding_3 = Path(
+    "/hpc/projects/intracellular_dashboard/organelle_dynamics/rerun/2024_11_07_A549_SEC61_DENV/4-phenotyping/prediction_infection/2chan_192patch_100ckpt_timeAware_ntxent_rerun.zarr"
+)
+
+path_annotations_infection_3 = Path(
+    "/hpc/projects/intracellular_dashboard/organelle_dynamics/rerun/2024_11_07_A549_SEC61_DENV/4-phenotyping/0-annotation/track_infection_annotation.csv"
+)
+
+# CELL DIVISION
+path_annotations_division_1 = Path(
+    "/hpc/projects/intracellular_dashboard/organelle_dynamics/2024_02_04_A549_DENV_ZIKV_timelapse/9-lineage-cell-division/lineages_gt/cell_division_state_test_set.csv"
+)
+path_division_embedding_1 = Path(
+    "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/time_interval/predict/feb_test_time_interval_1_epoch_178_gt_tracks.zarr"
+)
+# TOMM20
+path_annotations_division_2 = Path(
+    "/hpc/projects/intracellular_dashboard/organelle_dynamics/rerun/2024_11_21_A549_TOMM20_DENV/4-phenotyping/0-annotations/track_cell_state_annotation.csv"
+)
+# SEC61
+path_annotations_division_3 = Path(
+    "/hpc/projects/intracellular_dashboard/organelle_dynamics/rerun/2024_11_07_A549_SEC61_DENV/4-phenotyping/0-annotation/track_cell_state_annotation.csv"
+)
+# %%
+#########
+# Make tuple of tuples of embedding and annotations
+
+# Train FOVs - use a broader set since we have multiple datasets
+
+infection_classifier_pairs = (
+    (
+        path_infection_embedding_1,
+        path_annotations_infection_1,
+        ["/A/3/7", "/A/3/8", "/A/3/9", "/B/4/6", "/B/4/7"],
+    ),
+    (path_infection_embedding_2, path_annotations_infection_2, "all"),
+    (path_infection_embedding_3, path_annotations_infection_3, "all"),
+)
+
+# NOTE: embedding 1 and annotations 1 are not used. They are not well annotated for division
+division_classifier_pairs = (
+    # (
+    #     path_division_embedding_1,
+    #     path_annotations_division_1,
+    #     ["/A/3/7", "/A/3/8", "/A/3/9", "/B/4/6", "/B/4/7"],
+    # ),
+    # (path_infection_embedding_2, path_annotations_division_2, "all"),
+    (path_infection_embedding_3, path_annotations_division_3, "all"),
+)
+
+
+def load_and_combine_datasets(
+    datasets,
+    target_type="infection",
+    standardization_mapping=None,
+):
+    """Load and combine multiple embedding datasets with their annotations.
+
+    Parameters
+    ----------
+    datasets : list of tuple
+        List of (embedding_path, annotation_path, train_fovs) tuples containing
+        paths to embedding files, annotation CSV files, and training FOVs.
+    target_type : str, default='infection'
+        Type of classification target. Either 'infection' or 'division' - determines
+        which column to look for in the annotation files.
+    standardization_mapping : dict, optional
+        Dictionary to standardize different annotation formats across datasets.
+        Maps original values to standardized values.
+        Example: {'infected': 2, 'uninfected': 1, 'background': 0,
+                 2.0: 2, 1.0: 1, 0.0: 0, 'mitosis': 2, 'interphase': 1, 'unknown': 0}
+
+    Returns
+    -------
+    combined_features : xarray.DataArray
+        Combined feature embeddings from all successfully loaded datasets.
+    combined_annotations : pandas.Series
+        Combined and standardized annotations from all datasets.
+
+    Raises
+    ------
+    ValueError
+        If no datasets were successfully loaded.
+    """
+
+    all_features = []
+    all_annotations = []
+
+    # Default standardization mappings
+    if standardization_mapping is None:
+        if target_type == "infection":
+            standardization_mapping = {
+                # String formats
+                "infected": 2,
+                "uninfected": 1,
+                "background": 0,
+                "unknown": 0,
+                # Numeric formats
+                2.0: 2,
+                1.0: 1,
+                0.0: 0,
+                2: 2,
+                1: 1,
+                0: 0,
+            }
+        elif target_type == "division":
+            standardization_mapping = {
+                # String formats
+                "mitosis": 2,
+                "interphase": 1,
+                "unknown": 0,
+                # Numeric formats
+                2.0: 2,
+                1.0: 1,
+                0.0: 0,
+                2: 2,
+                1: 1,
+                0: 0,
+            }
+
+    for emb_path, ann_path, train_fovs in datasets:
+        try:
+            logger.debug(f"Loading dataset: {emb_path}")
+            dataset = read_embedding_dataset(emb_path)
+
+            # Read annotation CSV to detect column names
+            logger.debug(f"Reading annotation CSV: {ann_path}")
+            ann_df = pd.read_csv(ann_path)
+            # make sure the fov_name values start with '/', otherwise add it, and strip whitespace
+            ann_df["fov_name"] = ann_df["fov_name"].apply(
+                lambda x: (
+                    "/" + x.strip() if not x.strip().startswith("/") else x.strip()
+                )
+            )
+
+            if train_fovs == "all":
+                train_fovs = np.unique(dataset["fov_name"])
+
+            # Auto-detect annotation column based on target_type
+            annotation_key = None
+            if target_type == "infection":
+                for col in [
+                    "infection_state",
+                    "infection",
+                    "infection_status",
+                ]:
+                    if col in ann_df.columns:
+                        annotation_key = col
+                        break
+
+            elif target_type == "division":
+                for col in ["division", "cell_division", "cell_state"]:
+                    if col in ann_df.columns:
+                        annotation_key = col
+                        break
+
+            if annotation_key is None:
+                print(f"    No {target_type} column found, skipping...")
+                continue
+
+            # Filter the dataset to only include the FOVs in the annotation
+            # Use xarray's native filtering methods
+            ann_fov_names = set(ann_df["fov_name"].unique())
+            train_fovs = set(train_fovs)
+
+            logger.debug(f"Dataset FOVs: {dataset['fov_name'].values}")
+            logger.debug(f"Annotation FOV names: {ann_fov_names}")
+            logger.debug(f"Train FOVs: {train_fovs}")
+            logger.debug(f"Dataset samples before filtering: {len(dataset.sample)}")
+
+            # Keep only the intersection of train_fovs and ann_fov_names
+            common_fovs = train_fovs.intersection(ann_fov_names)
+            # requested train FOVs that have no annotations
+            missed_fovs = train_fovs - common_fovs
+            # annotated FOVs that are not in the requested train set
+            missed_fovs_ann = ann_fov_names - common_fovs
+
+            if len(common_fovs) == 0:
+                raise ValueError(
+                    f"No common FOVs found between dataset and annotations: {train_fovs} not in {ann_fov_names}"
+                )
+            if len(missed_fovs) > 0:
+                warnings.warn(
+                    f"No matching annotations found for FOVs in the train dataset: {missed_fovs}"
+                )
+            if len(missed_fovs_ann) > 0:
+                warnings.warn(
+                    f"No matching train FOVs found for FOVs in the annotations: {missed_fovs_ann}"
+                )
+
+            logger.debug(f"Intersection of train_fovs and ann_fov_names: {common_fovs}")
+
+            # Filter the dataset to only include the intersection of train_fovs and ann_fov_names
+            dataset
= dataset.where( + dataset["fov_name"].isin(list(common_fovs)), drop=True + ) + + logger.debug(f"Dataset samples after filtering: {len(dataset.sample)}") + + # Load annotations without class mapping first + annotations = load_annotation(dataset, ann_path, annotation_key) + + # Check unique values before standardization + unique_vals = annotations.unique() + logger.debug(f"Original unique values: {unique_vals}") + + # Apply standardization mapping + standardized_annotations = annotations.copy() + if standardization_mapping: + for original_val, standard_val in standardization_mapping.items(): + mask = annotations == original_val + if mask.any(): + standardized_annotations[mask] = standard_val + logger.debug( + f"Mapped {original_val} -> {standard_val} ({mask.sum()} instances)" + ) + + # Check standardized values + std_unique_vals = standardized_annotations.unique() + logger.debug(f"Standardized unique values: {std_unique_vals}") + + # Convert to categorical for consistency + standardized_annotations = standardized_annotations.astype("category") + + # Keep features as xarray DataArray for compatibility with fit_logistic_regression + all_features.append(dataset["features"]) + all_annotations.append(standardized_annotations) + + logger.debug(f"Features shape: {dataset['features'].shape}") + logger.debug(f"Annotations shape: {standardized_annotations.shape}") + except Exception as e: + raise ValueError(f"Error loading dataset {emb_path}: {e}") + + # Combine all datasets + if all_features: + # Extract features and coordinates from each dataset + all_features_arrays = [] + all_coords = [] + + for dataset in all_features: + # Extract the features array + features_array = dataset["features"].values + all_features_arrays.append(features_array) + + # Extract coordinates + coords_dict = {} + for coord_name in dataset.coords: + if coord_name != "sample": # skip sample coordinate + coords_dict[coord_name] = dataset.coords[coord_name].values + all_coords.append(coords_dict) + + # Combine feature arrays + combined_features_array = np.concatenate(all_features_arrays, axis=0) + + # Combine coordinates (excluding 'features' from coordinates) + combined_coords = {} + for coord_name in all_coords[0].keys(): + if coord_name != "features": # Don't include 'features' in coordinates + coord_values = [] + for coords_dict in all_coords: + coord_values.extend(coords_dict[coord_name]) + combined_coords[coord_name] = coord_values + + # Create new combined dataset in the correct format + coords_dict = { + "sample": range(len(combined_features_array)), + } + + # Add each coordinate as a 1D coordinate along the sample dimension + for coord_name, coord_values in combined_coords.items(): + coords_dict[coord_name] = ("sample", coord_values) + + combined_dataset = xr.Dataset( + { + "features": (("sample", "features"), combined_features_array), + }, + coords=coords_dict, + ) + + # Set the index properly like the original datasets + if "fov_name" in combined_coords: + available_coords = [ + coord + for coord in combined_coords.keys() + if coord in ["fov_name", "track_id", "t"] + ] + combined_dataset = combined_dataset.set_index(sample=available_coords) + + combined_annotations = pd.concat(all_annotations, ignore_index=True) + + logger.debug(f"Combined features shape: {combined_dataset['features'].shape}") + logger.debug(f"Combined annotations shape: {combined_annotations.shape}") + + # Final check of combined annotations + final_unique = combined_annotations.unique() + logger.debug(f"Final combined unique values: 
{final_unique}") + + return combined_dataset["features"], combined_annotations + + +# %% + +# Load and combine infection datasets +logger.info("Loading infection classification datasets...") +infection_features, infection_labels = load_and_combine_datasets( + infection_classifier_pairs, + target_type="infection", +) +# %% +# Load and combine division datasets +logger.info("Loading division classification datasets...") +division_features, division_labels = load_and_combine_datasets( + division_classifier_pairs, + target_type="division", +) + + +# %% + +logistic_regression_infection, _ = fit_logistic_regression( + features=infection_features.copy(), + annotations=infection_labels.copy(), + train_ratio=0.8, + remove_background_class=True, + scale_features=True, + class_weight="balanced", + solver="liblinear", + random_state=42, +) +# %% + +logistic_regression_division, _ = fit_logistic_regression( + division_features.copy(), + division_labels.copy(), + train_ratio=0.8, + remove_background_class=True, + scale_features=True, + class_weight="balanced", + solver="liblinear", + random_state=42, +) + +# %% +linear_classifier_infection = linear_from_binary_logistic_regression( + logistic_regression_infection +) +assembled_classifier_infection = ( + AssembledClassifier(model.model, linear_classifier_infection) + .eval() + .to(model.device) +) + +# %% +linear_classifier_division = linear_from_binary_logistic_regression( + logistic_regression_division +) +assembled_classifier_division = ( + AssembledClassifier(model.model, linear_classifier_division).eval().to(model.device) +) + + +# %% +# Loading the lineage images +img = [] +for sample in dm.predict_dataloader(): + img.append(sample["anchor"].numpy()) +img = np.concatenate(img, axis=0) +print(img.shape) + +# %% +img_tensor = torch.from_numpy(img).to(model.device) + +with torch.inference_mode(): + infection_probs = assembled_classifier_infection(img_tensor).sigmoid() + division_probs = assembled_classifier_division(img_tensor).sigmoid() + +# %% +attr_kwargs = dict( + img=img_tensor, + sliding_window_shapes=(1, 15, 12, 12), + strides=(1, 15, 4, 4), + show_progress=True, +) + + +infection_attribution = ( + assembled_classifier_infection.attribute_occlusion(**attr_kwargs).cpu().numpy() +) +division_attribution = ( + assembled_classifier_division.attribute_occlusion(**attr_kwargs).cpu().numpy() +) + + +# %% +def clip_rescale(img, low, high): + return rescale_intensity(img.clip(low, high), out_range=(0, 1)) + + +def clim_percentile(heatmap, low=1, high=99): + lo, hi = np.percentile(heatmap, (low, high)) + return clip_rescale(heatmap, lo, hi) + + +g_lim = 1 +z_slice = 5 +phase = clim_percentile(img[:, 0, z_slice]) +rfp = clim_percentile(img[:, 1, z_slice]) +img_render = np.concatenate([phase, rfp], axis=2) +phase_heatmap_inf = infection_attribution[:, 0, z_slice] +rfp_heatmap_inf = infection_attribution[:, 1, z_slice] +inf_render = clip_rescale( + np.concatenate([phase_heatmap_inf, rfp_heatmap_inf], axis=2), -g_lim, g_lim +) +phase_heatmap_div = division_attribution[:, 0, z_slice] +rfp_heatmap_div = division_attribution[:, 1, z_slice] +div_render = clip_rescale( + np.concatenate([phase_heatmap_div, rfp_heatmap_div], axis=2), -g_lim, g_lim +) + + +# %% +# Filter the dataframe to only include the fovs and track_id of the current fov +infection = pd.read_csv(annotation_occlusion_infection_path) +infection = infection[infection["fov_name"] == fov[1:]] +infection = infection[infection["track_id"].isin(track)] +track_classes_infection = 
infection["infection_state"]
+
+# load division annotations
+division = pd.read_csv(annotation_occlusion_division_path)
+division = division[division["fov_name"] == fov[1:]]
+division = division[division["track_id"].isin(track)]
+
+division["division"] = 1  # default: not dividing
+division.loc[division["t"].between(16, 22, inclusive="both"), "division"] = (
+    2  # dividing for t in 16-22
+)
+
+track_classes_division = division["division"]
+
+
+# %%
+plt.style.use("./figure.mplstyle")
+
+all_time_points = list(range(len(img_render)))
+selected_time_points = all_time_points[
+    :: max(1, len(all_time_points) // 8)
+]  # Show up to 8 time points
+
+
+sps = len(selected_time_points)
+
+icefire = Colormap("icefire").to_mpl()
+
+f, ax = plt.subplots(3, sps, figsize=(2 * sps, 3), layout="compressed")
+for i, time in enumerate(selected_time_points):
+    hpi = 3 + 0.5 * time
+    prob = infection_probs[time].item()
+    inf_binary = str(bool(track_classes_infection.iloc[time] - 1)).lower()
+    div_binary = str(bool(track_classes_division.iloc[time] - 1)).lower()
+    ax[0, i].imshow(img_render[time], cmap="gray")
+    ax[0, i].set_title(f"{hpi} HPI")
+    ax[1, i].imshow(inf_render[time], cmap=icefire, vmin=0, vmax=1)
+    ax[1, i].set_title(
+        f"infected: {prob:.3f}\n" f"label: {inf_binary}",
+    )
+    ax[2, i].imshow(div_render[time], cmap=icefire, vmin=0, vmax=1)
+    ax[2, i].set_title(
+        f"dividing: {division_probs[time].item():.3f}\n" f"label: {div_binary}",
+    )
+for a in ax.ravel():
+    a.axis("off")
+norm = mpl.colors.Normalize(vmin=-g_lim, vmax=g_lim)
+cbar = f.colorbar(
+    mpl.cm.ScalarMappable(norm=norm, cmap=icefire),
+    orientation="vertical",
+    ax=ax[1:].ravel().tolist(),
+    format=mpl.ticker.StrMethodFormatter("{x:.1f}"),
+)
+cbar.set_label("occlusion attribution")
+
+# %%
+# f.savefig(
+#     Path.home()
+#     / "mydata"
+#     / "gdrive/publications/dynaCLR/2025_dynaCLR_paper/fig_manuscript_svg/figure_occlusion_analysis/figure_parts/fig_explanation_patch12_stride4.pdf",
+#     dpi=300,
+# )
+
+# %%
+# Create video animation of occlusion analysis
+icefire = Colormap("icefire").to_mpl()
+plt.style.use("./figure.mplstyle")
+
+fig, ax = plt.subplots(3, 1, figsize=(6, 8), layout="compressed")
+
+# Initialize plots
+im1 = ax[0].imshow(img_render[0], cmap="gray")
+ax[0].set_title("Original Image")
+ax[0].axis("off")
+
+im2 = ax[1].imshow(inf_render[0], cmap=icefire, vmin=0, vmax=1)
+ax[1].set_title("Infection Occlusion Attribution")
+ax[1].axis("off")
+
+im3 = ax[2].imshow(div_render[0], cmap=icefire, vmin=0, vmax=1)
+ax[2].set_title("Division Occlusion Attribution")
+ax[2].axis("off")
+
+# Store initial border colors
+for a in ax:
+    for spine in a.spines.values():
+        spine.set_linewidth(3)
+        spine.set_color("black")
+
+# Add colorbar
+norm = mpl.colors.Normalize(vmin=-g_lim, vmax=g_lim)
+cbar = fig.colorbar(
+    mpl.cm.ScalarMappable(norm=norm, cmap=icefire),
+    ax=ax[1:],
+    orientation="horizontal",
+    shrink=0.8,
+    pad=0.1,
+)
+cbar.set_label("Occlusion Attribution")
+
+
+# Animation function
+def animate(frame):
+    time = frame
+    hpi = 3 + 0.5 * time
+
+    # Update images
+    im1.set_array(img_render[time])
+    im2.set_array(inf_render[time])
+    im3.set_array(div_render[time])
+
+    # Update titles with probabilities
+    inf_prob = infection_probs[time].item()
+    div_prob = division_probs[time].item()
+    inf_binary = bool(track_classes_infection.iloc[time] - 1)
+    div_binary = bool(track_classes_division.iloc[time] - 1)
+
+    # Color code labels - darkorange for true, blue for false
+    inf_color = "darkorange" if inf_binary else "blue"
+    div_color = "darkorange" if div_binary else "blue"
+
+    # Make label text bold when true
+    inf_weight = "bold" if inf_binary else "normal"
+    div_weight = "bold" if div_binary else "normal"
+
+    # Update border colors to highlight true labels
+    for spine in ax[1].spines.values():
+        spine.set_color(inf_color)
+        spine.set_linewidth(4 if inf_binary else 2)
+
+    for spine in ax[2].spines.values():
+        spine.set_color(div_color)
+        spine.set_linewidth(4 if div_binary else 2)
+
+    ax[0].set_title(f"Original Image - {hpi:.1f} HPI", fontsize=12, fontweight="bold")
+    ax[1].set_title(
+        f"Infection Attribution - Prob: {inf_prob:.3f} (Label: {str(inf_binary).lower()})",
+        fontsize=12,
+        fontweight=inf_weight,
+        color=inf_color,
+    )
+    ax[2].set_title(
+        f"Division Attribution - Prob: {div_prob:.3f} (Label: {str(div_binary).lower()})",
+        fontsize=12,
+        fontweight=div_weight,
+        color=div_color,
+    )
+
+    return [im1, im2, im3]
+
+
+# Create animation
+anim = animation.FuncAnimation(
+    fig, animate, frames=len(img_render), interval=200, blit=True, repeat=True
+)
+
+# %%
+# Save as video
+video_path = (
+    Path.home()
+    / "mydata"
+    / "gdrive/2025_dynaCLR_paper/fig_manuscript_svg/figure_occlusion_analysis/figure_parts/occlusion_analysis_video.mp4"
+)
+video_path.parent.mkdir(parents=True, exist_ok=True)
+
+# Save as MP4
+Writer = animation.writers["ffmpeg"]
+writer = Writer(fps=5, metadata=dict(artist="VisCy"), bitrate=1800)
+anim.save(str(video_path), writer=writer)
+
+print(f"Video saved to: {video_path}")
+
+
+# %%
+# Performance metrics over time
+def calculate_metrics_over_time(y_true, y_pred_probs, threshold=0.5):
+    """Calculate cumulative accuracy, F1, and AUC over all frames up to each time point"""
+    y_pred = (y_pred_probs > threshold).astype(int)
+
+    metrics = {"accuracy": [], "f1": [], "auc": []}
+
+    for i in range(len(y_true)):
+        # Get predictions up to current time point
+        true_up_to_i = y_true[: i + 1]
+        pred_up_to_i = y_pred[: i + 1]
+        prob_up_to_i = y_pred_probs[: i + 1]
+
+        # Skip if we don't have both classes
+        if len(np.unique(true_up_to_i)) < 2:
+            metrics["accuracy"].append(np.nan)
+            metrics["f1"].append(np.nan)
+            metrics["auc"].append(np.nan)
+            continue
+
+        # Calculate metrics
+        acc = accuracy_score(true_up_to_i, pred_up_to_i)
+        f1 = f1_score(true_up_to_i, pred_up_to_i, average="binary")
+        try:
+            auc_score = roc_auc_score(true_up_to_i, prob_up_to_i)
+        except ValueError:
+            auc_score = np.nan
+
+        metrics["accuracy"].append(acc)
+        metrics["f1"].append(f1)
+        metrics["auc"].append(auc_score)
+
+    return metrics
+
+
+# Ensure we have matching lengths - use the minimum length
+min_length = min(
+    len(track_classes_infection), len(track_classes_division), len(infection_probs)
+)
+
+# Convert labels to binary for metrics calculation - truncate to min_length
+inf_true = (track_classes_infection.values[:min_length] - 1).astype(bool).astype(int)
+# division labels are 1 (interphase) / 2 (dividing): shift to 0/1 like inf_true
+div_true = (track_classes_division.values[:min_length] - 1).astype(bool).astype(int)
+
+inf_probs = infection_probs[:min_length].cpu().numpy()
+div_probs = division_probs[:min_length].cpu().numpy()
+
+print(f"Using {min_length} time points for metrics calculation")
+print(f"Infection labels shape: {inf_true.shape}")
+print(f"Division labels shape: {div_true.shape}")
+print(f"Infection probs shape: {inf_probs.shape}")
+print(f"Division probs shape: {div_probs.shape}")
+
+# Calculate metrics
+inf_metrics = calculate_metrics_over_time(inf_true, inf_probs)
+div_metrics = calculate_metrics_over_time(div_true, div_probs)
+
+# Time points
+time_points = np.arange(len(inf_true))
+hpi_values = 3 + 0.5 * time_points
+
+# Create metrics plot
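+# Note: each point at time t aggregates all frames up to t (cumulative
+# metrics), so values are NaN until both classes have been observed, and the
+# curves stabilize as frames accumulate.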
+fig, axes = plt.subplots(2, 3, figsize=(15, 8), layout="compressed") + +# Infection metrics +axes[0, 0].plot( + hpi_values, inf_metrics["accuracy"], "b-", linewidth=2, label="Accuracy" +) +axes[0, 0].set_title("Infection Classification Accuracy Over Time") +axes[0, 0].set_xlabel("Hours Post Infection (HPI)") +axes[0, 0].set_ylabel("Accuracy") +axes[0, 0].grid(True, alpha=0.3) +axes[0, 0].set_ylim(0, 1) + +axes[0, 1].plot(hpi_values, inf_metrics["f1"], "g-", linewidth=2, label="F1 Score") +axes[0, 1].set_title("Infection Classification F1 Score Over Time") +axes[0, 1].set_xlabel("Hours Post Infection (HPI)") +axes[0, 1].set_ylabel("F1 Score") +axes[0, 1].grid(True, alpha=0.3) +axes[0, 1].set_ylim(0, 1) + +axes[0, 2].plot(hpi_values, inf_metrics["auc"], "r-", linewidth=2, label="AUC") +axes[0, 2].set_title("Infection Classification AUC Over Time") +axes[0, 2].set_xlabel("Hours Post Infection (HPI)") +axes[0, 2].set_ylabel("AUC") +axes[0, 2].grid(True, alpha=0.3) +axes[0, 2].set_ylim(0, 1) + +# Division metrics +axes[1, 0].plot( + hpi_values, div_metrics["accuracy"], "b-", linewidth=2, label="Accuracy" +) +axes[1, 0].set_title("Division Classification Accuracy Over Time") +axes[1, 0].set_xlabel("Hours Post Infection (HPI)") +axes[1, 0].set_ylabel("Accuracy") +axes[1, 0].grid(True, alpha=0.3) +axes[1, 0].set_ylim(0, 1) + +axes[1, 1].plot(hpi_values, div_metrics["f1"], "g-", linewidth=2, label="F1 Score") +axes[1, 1].set_title("Division Classification F1 Score Over Time") +axes[1, 1].set_xlabel("Hours Post Infection (HPI)") +axes[1, 1].set_ylabel("F1 Score") +axes[1, 1].grid(True, alpha=0.3) +axes[1, 1].set_ylim(0, 1) + +axes[1, 2].plot(hpi_values, div_metrics["auc"], "r-", linewidth=2, label="AUC") +axes[1, 2].set_title("Division Classification AUC Over Time") +axes[1, 2].set_xlabel("Hours Post Infection (HPI)") +axes[1, 2].set_ylabel("AUC") +axes[1, 2].grid(True, alpha=0.3) +axes[1, 2].set_ylim(0, 1) + +plt.tight_layout() + +# %% +# Save metrics plot +metrics_path = ( + Path.home() + / "mydata" + / "gdrive/2025_dynaCLR_paper/fig_manuscript_svg/figure_occlusion_analysis/figure_parts/performance_metrics_over_time.pdf" +) +fig.savefig(str(metrics_path), dpi=300, bbox_inches="tight") +print(f"Metrics plot saved to: {metrics_path}") From b51c1b8affa64127f7ee109a0ecff7f854ff1717 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Wed, 16 Jul 2025 16:59:40 -0700 Subject: [PATCH 008/101] extend to training ratios --- viscy/representation/evaluation/lca.py | 57 +++++++++++++++++++------- 1 file changed, 43 insertions(+), 14 deletions(-) diff --git a/viscy/representation/evaluation/lca.py b/viscy/representation/evaluation/lca.py index 7c5216193..9965dcbac 100644 --- a/viscy/representation/evaluation/lca.py +++ b/viscy/representation/evaluation/lca.py @@ -9,6 +9,7 @@ from numpy.typing import NDArray from sklearn.linear_model import LogisticRegression from sklearn.metrics import classification_report +from sklearn.model_selection import train_test_split from sklearn.preprocessing import StandardScaler from torch import Tensor from xarray import DataArray @@ -19,7 +20,8 @@ def fit_logistic_regression( features: DataArray, annotations: pd.Series, - train_fovs: list[str], + train_fovs: list[str] | None = None, + train_ratio: float = 0.8, remove_background_class: bool = True, scale_features: bool = False, class_weight: Mapping | str | None = "balanced", @@ -38,8 +40,14 @@ def fit_logistic_regression( annotations : pd.Series Categorical class annotations with label values starting from 0. 
Must have 3 classes (when remove background is True) or 2 classes. - train_fovs : list[str] + train_fovs : list[str] | None, optional List of FOVs to use for training. The rest will be used for testing. + If None, uses stratified sampling based on train_ratio. + train_ratio : float, optional + Proportion of samples to use for training (0.0 to 1.0). + Used when train_fovs is None. + Uses stratified sampling to ensure balanced class representation. + Default is 0.8 (80% training, 20% testing). remove_background_class : bool, optional Remove background class (0), by default True scale_features : bool, optional @@ -56,23 +64,44 @@ def fit_logistic_regression( tuple[LogisticRegression, tuple[tuple[NDArray, NDArray], tuple[NDArray, NDArray]]] Trained classifier and data split [[X_train, y_train], [X_test, y_test]]. """ - fov_selection = features["fov_name"].isin(train_fovs) - train_selection = fov_selection - test_selection = ~fov_selection annotations = annotations.cat.codes.values.copy() + + # Handle background class removal before splitting for stratification if remove_background_class: - label_selection = annotations != 0 - train_selection &= label_selection - test_selection &= label_selection - annotations -= 1 - train_features = features.values[train_selection] - test_features = features.values[test_selection] + valid_indices = annotations != 0 + features_filtered = features[valid_indices] + annotations_filtered = annotations[valid_indices] - 1 + else: + features_filtered = features + annotations_filtered = annotations + + # Determine train/test split + if train_fovs is not None: + fov_selection = features_filtered["fov_name"].isin(train_fovs) + train_selection = fov_selection + test_selection = ~fov_selection + else: + # Use stratified sampling + n_samples = len(annotations_filtered) + indices = range(n_samples) + train_indices, test_indices = train_test_split( + indices, + test_size=1 - train_ratio, + stratify=annotations_filtered, + random_state=random_state, + ) + train_selection = pd.Series(False, index=range(n_samples)) + train_selection.iloc[train_indices] = True + test_selection = ~train_selection + train_features = features_filtered.values[train_selection] + test_features = features_filtered.values[test_selection] + train_annotations = annotations_filtered[train_selection] + test_annotations = annotations_filtered[test_selection] + if scale_features: scaler = StandardScaler() train_features = scaler.fit_transform(train_features) - test_features = scaler.fit_transform(test_features) - train_annotations = annotations[train_selection] - test_annotations = annotations[test_selection] + test_features = scaler.transform(test_features) logistic_regression = LogisticRegression( class_weight=class_weight, random_state=random_state, From 4eabce3b073c0678b9927783edcff31064212e94 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Mon, 21 Jul 2025 14:25:20 -0700 Subject: [PATCH 009/101] demo beta_vae 2.5D --- .../DynaCLR/BetaVAE/debug_dimensions.py | 169 +++++++ .../DynaCLR/BetaVAE/debug_stem.py | 39 ++ .../benchmarking/DynaCLR/BetaVAE/test_run.py | 206 ++++++++ .../representation/disentanglement_metrics.py | 366 ++++++++++++++ viscy/representation/engine.py | 379 +++++++++++++++ viscy/representation/vae.py | 257 ++++++++++ viscy/representation/vae_logging.py | 448 ++++++++++++++++++ 7 files changed, 1864 insertions(+) create mode 100644 applications/benchmarking/DynaCLR/BetaVAE/debug_dimensions.py create mode 100644 applications/benchmarking/DynaCLR/BetaVAE/debug_stem.py create mode 
100644 applications/benchmarking/DynaCLR/BetaVAE/test_run.py create mode 100644 viscy/representation/disentanglement_metrics.py create mode 100644 viscy/representation/vae.py create mode 100644 viscy/representation/vae_logging.py diff --git a/applications/benchmarking/DynaCLR/BetaVAE/debug_dimensions.py b/applications/benchmarking/DynaCLR/BetaVAE/debug_dimensions.py new file mode 100644 index 000000000..70be8f8a8 --- /dev/null +++ b/applications/benchmarking/DynaCLR/BetaVAE/debug_dimensions.py @@ -0,0 +1,169 @@ +# %% +import torch +from viscy.representation.vae import VaeEncoder, VaeDecoder + + +def debug_vae_dimensions(): + """Debug VAE encoder/decoder dimension compatibility.""" + + print("=== VAE Dimension Debugging ===\n") + + # Configuration from test_run.py + z_stack_depth = 32 + input_shape = (1, 1, z_stack_depth, 192, 192) + latent_dim = 256 + latent_spatial_size = 3 + + print(f"Input shape: {input_shape}") + print(f"Expected latent dim: {latent_dim}") + print(f"Expected latent spatial size: {latent_spatial_size}") + print() + + # Create encoder + encoder = VaeEncoder( + backbone="resnet50", + in_channels=1, + in_stack_depth=z_stack_depth, + embedding_dim=latent_dim, + stem_kernel_size=(8, 4, 4), + stem_stride=(8, 4, 4), + ) + + # Create decoder + decoder = VaeDecoder( + decoder_channels=[1024, 512, 256, 128], + latent_dim=latent_dim, + out_channels=1, + out_stack_depth=z_stack_depth, + latent_spatial_size=latent_spatial_size, + head_expansion_ratio=1, + head_pool=False, + upsample_mode="deconv", + conv_blocks=2, + norm_name="batch", + upsample_pre_conv=None, + strides=[2, 2, 2, 2], + ) + + print("=== ENCODER FORWARD PASS ===") + + # Test encoder + x = torch.randn(*input_shape) + print(f"Input to encoder: {x.shape}") + + try: + # Step through encoder + print("\n1. Stem processing:") + x_stem = encoder.stem(x) + print(f" After stem: {x_stem.shape}") + + print("\n2. Backbone processing:") + features = encoder.encoder(x_stem) + for i, feat in enumerate(features): + print(f" Feature {i}: {feat.shape}") + + print("\n3. Final processing:") + x_final = features[-1] + print(f" Final features: {x_final.shape}") + + x_pooled = encoder.global_pool(x_final) + print(f" After global pool: {x_pooled.shape}") + + x_flat = x_pooled.flatten(1) + print(f" After flatten: {x_flat.shape}") + + # Full encoder output + encoder_output = encoder(x) + mu = encoder_output.embedding + logvar = encoder_output.log_covariance + print(f" Final mu: {mu.shape}") + print(f" Final logvar: {logvar.shape}") + + print("\n=== DECODER FORWARD PASS ===") + + # Test decoder with latent vector + z = torch.randn(1, latent_dim) + print(f"Input to decoder: {z.shape}") + + print("\n1. Latent projection:") + x_proj = decoder.latent_proj(z) + print(f" After projection: {x_proj.shape}") + + x_reshaped = x_proj.view(1, -1, latent_spatial_size, latent_spatial_size) + print(f" After reshape: {x_reshaped.shape}") + + print("\n2. Decoder stages:") + x_current = x_reshaped + for i, stage in enumerate(decoder.decoder_stages): + x_current = stage(x_current) + print(f" After stage {i}: {x_current.shape}") + + print("\n3. 
Head processing:") + final_output = decoder.head(x_current) + print(f" Final output: {final_output.shape}") + + # Full decoder output + decoder_output = decoder(z) + reconstruction = decoder_output["reconstruction"] + print(f" Full reconstruction: {reconstruction.shape}") + + print("\n=== DIMENSION ANALYSIS ===") + print(f"✓ Encoder input: {input_shape}") + print(f"✓ Encoder output: {mu.shape}") + print(f"✓ Decoder input: {z.shape}") + print(f"✓ Decoder output: {reconstruction.shape}") + + # Calculate tensor sizes + input_size = torch.numel(x) + recon_size = torch.numel(reconstruction) + print(f" Input tensor size: {input_size}") + print(f" Reconstruction tensor size: {recon_size}") + print(f" Size ratio: {recon_size / input_size:.2f}") + + # Check if reconstruction matches input + if reconstruction.shape == x.shape: + print("✓ SUCCESS: Reconstruction shape matches input shape!") + else: + print(f"✗ ERROR: Shape mismatch!") + print(f" Input: {x.shape}") + print(f" Reconstruction: {reconstruction.shape}") + + # Analyze each dimension + for i, (inp_dim, recon_dim) in enumerate( + zip(x.shape, reconstruction.shape) + ): + if inp_dim != recon_dim: + print( + f" Dimension {i}: {inp_dim} → {recon_dim} (factor: {recon_dim/inp_dim:.2f})" + ) + + except Exception as e: + print(f"✗ ERROR during forward pass: {e}") + print(f"Error type: {type(e).__name__}") + import traceback + + traceback.print_exc() + + # Let's check what spatial size the encoder actually produces + print("\n=== ENCODER SPATIAL SIZE ANALYSIS ===") + try: + x_stem = encoder.stem(x) + features = encoder.encoder(x_stem) + final_feat = features[-1] + actual_spatial_size = final_feat.shape[-1] # Assuming square + print(f"Actual spatial size from encoder: {actual_spatial_size}") + print(f"Expected spatial size for decoder: {latent_spatial_size}") + + if actual_spatial_size != latent_spatial_size: + print( + f"✗ MISMATCH: Encoder produces {actual_spatial_size}x{actual_spatial_size}, decoder expects {latent_spatial_size}x{latent_spatial_size}" + ) + print(f" Suggested fix: Set latent_spatial_size={actual_spatial_size}") + + except Exception as inner_e: + print(f"Error in spatial size analysis: {inner_e}") + + +if __name__ == "__main__": + debug_vae_dimensions() +# %% diff --git a/applications/benchmarking/DynaCLR/BetaVAE/debug_stem.py b/applications/benchmarking/DynaCLR/BetaVAE/debug_stem.py new file mode 100644 index 000000000..9b52fab5b --- /dev/null +++ b/applications/benchmarking/DynaCLR/BetaVAE/debug_stem.py @@ -0,0 +1,39 @@ +#!/usr/bin/env python3 + +import torch +from viscy.representation.vae import VaeEncoder + +# Test the stem layer computation +z_stack_depth = 32 +encoder = VaeEncoder( + backbone="resnet50", + in_channels=1, + in_stack_depth=z_stack_depth, + embedding_dim=256, + stem_kernel_size=(8, 4, 4), + stem_stride=(8, 4, 4), +) + +# Create test input +x = torch.randn(1, 1, z_stack_depth, 192, 192) +print(f"Input shape: {x.shape}") + +# Test stem output +stem_output = encoder.stem(x) +print(f"Stem output shape: {stem_output.shape}") + +# Check what the ResNet expects +import timm +resnet50 = timm.create_model("resnet50", pretrained=True, features_only=True) +print(f"ResNet50 conv1 expects input channels: {resnet50.conv1.in_channels}") +print(f"ResNet50 conv1 produces output channels: {resnet50.conv1.out_channels}") + +# Test if we can pass stem output to ResNet +try: + # Remove conv1 like in the encoder + resnet50.conv1 = torch.nn.Identity() + resnet_output = resnet50(stem_output) + print(f"ResNet output shapes: {[f.shape 
for f in resnet_output]}") + print("SUCCESS: No channel mismatch!") +except Exception as e: + print(f"ERROR: {e}") \ No newline at end of file diff --git a/applications/benchmarking/DynaCLR/BetaVAE/test_run.py b/applications/benchmarking/DynaCLR/BetaVAE/test_run.py new file mode 100644 index 000000000..d23fee85e --- /dev/null +++ b/applications/benchmarking/DynaCLR/BetaVAE/test_run.py @@ -0,0 +1,206 @@ +# %% +import torch +from lightning.pytorch import seed_everything +from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint +from lightning.pytorch.loggers import TensorBoardLogger +from monai.transforms.intensity.dictionary import ( + RandAdjustContrastd, + RandGaussianNoised, + RandGaussianSmoothd, + RandScaleIntensityd, + ScaleIntensityRangePercentilesd, +) +from monai.transforms.spatial.dictionary import RandAffined + +from viscy.data.triplet import TripletDataModule +from viscy.representation.engine import VaeModule +from viscy.representation.vae import VaeDecoder, VaeEncoder +from viscy.trainer import VisCyTrainer +from viscy.transforms import ( + NormalizeSampled, +) + + +# %% +def channel_augmentations(processing_channel: str): + return [ + RandAffined( + keys=[processing_channel], + prob=0.8, + scale_range=[0, 0.2, 0.2], + rotate_range=[3.14, 0.0, 0.0], + shear_range=[0.0, 0.01, 0.01], + padding_mode="zeros", + ), + RandAdjustContrastd( + keys=[processing_channel], + prob=0.5, + gamma=(0.8, 1.2), + ), + RandScaleIntensityd( + keys=[processing_channel], + prob=0.5, + factors=0.5, + ), + RandGaussianSmoothd( + keys=[processing_channel], + prob=0.5, + sigma_x=(0.25, 0.75), + sigma_y=(0.25, 0.75), + sigma_z=(0.0, 0.0), + ), + RandGaussianNoised( + keys=[processing_channel], + prob=0.5, + mean=0.0, + std=0.2, + ), + ] + + +# %% +def channel_normalization( + phase_channel: str | None = None, + fl_channel: str | None = None, +): + if phase_channel: + return [ + NormalizeSampled( + keys=[phase_channel], + level="fov_statistics", + subtrahend="mean", + divisor="std", + ) + ] + elif fl_channel: + return [ + ScaleIntensityRangePercentilesd( + keys=[fl_channel], + lower=50, + upper=99, + b_min=0.0, + b_max=1.0, + ) + ] + else: + raise NotImplementedError("Either phase_channel or fl_channel must be provided") + + +if __name__ == "__main__": + seed_everything(42) + + # use tensor cores on Ampere GPUs (24-bit tensorfloat matmul) + torch.set_float32_matmul_precision("high") + + initial_yx_patch_size = (384, 384) + final_yx_patch_size = (192, 192) + batch_size = 64 + num_workers = 12 + time_interval = 1 + z_stack_depth = 32 + + print("Creating model components...") + + # Create encoder with debug info + encoder = VaeEncoder( + backbone="resnet50", + in_channels=1, + in_stack_depth=z_stack_depth, + embedding_dim=256, + stem_kernel_size=(8, 4, 4), + stem_stride=(8, 4, 4), + ) + print(f"Encoder created successfully") + + # Test encoder forward pass + test_input = torch.randn(1, 1, z_stack_depth, 192, 192) + try: + encoder_output = encoder(test_input) + print(f"Encoder test passed: {encoder_output.embedding.shape}") + except Exception as e: + print(f"Encoder test failed: {e}") + exit(1) + + # Create decoder + decoder = VaeDecoder( + decoder_channels=[1024, 512, 256, 128], + latent_dim=256, + out_channels=1, + out_stack_depth=z_stack_depth, + latent_spatial_size=3, + head_expansion_ratio=2, + head_pool=False, + upsample_mode="deconv", + conv_blocks=2, + norm_name="batch", + upsample_pre_conv=None, + strides=[2, 2, 2, 2], + ) + print(f"Decoder created successfully") + + # Create 
VaeModule + model = VaeModule( + encoder=encoder, + decoder=decoder, + example_input_array_shape=(1, 1, z_stack_depth, 192, 192), + latent_dim=256, + beta=3.0, + lr=2e-4, + ) + print(f"VaeModule created successfully") + + # Test model forward pass + try: + model_output = model(test_input) + print(f"Model test passed: loss={model_output['loss']}") + except Exception as e: + print(f"Model test failed: {e}") + exit(1) + + # Create data module + print("Creating data module...") + dm = TripletDataModule( + data_path="/hpc/projects/intracellular_dashboard/organelle_dynamics/rerun/2024_10_16_A549_SEC61_ZIKV_DENV/2-assemble/2024_10_16_A549_SEC61_ZIKV_DENV.zarr", + tracks_path="/hpc/projects/intracellular_dashboard/organelle_dynamics/rerun/2024_10_16_A549_SEC61_ZIKV_DENV/1-preprocess/label-free/3-track/2024_10_16_A549_SEC61_ZIKV_DENV_cropped.zarr", + source_channel=["Phase3D"], + z_range=(5, 37), + initial_yx_patch_size=initial_yx_patch_size, + final_yx_patch_size=final_yx_patch_size, + batch_size=batch_size, + num_workers=num_workers, + time_interval=time_interval, + augmentations=channel_augmentations("Phase3D"), + normalizations=channel_normalization(phase_channel="Phase3D"), + fit_include_wells=["B/3", "B/4", "C/3", "C/4"], + ) + print(f"DataModule created successfully") + + # Create trainer + trainer = VisCyTrainer( + accelerator="gpu", + strategy="ddp", + devices=4, + num_nodes=1, + precision="16-mixed", + # fast_dev_run=True, + max_epochs=100, + log_every_n_steps=10, + check_val_every_n_epoch=1, + logger=TensorBoardLogger( + save_dir="/hpc/projects/organelle_phenotyping/models/SEC61B/vae", + name="betavae_phase3D_ddp", + version="beta_3_16slice", + ), + callbacks=[ + LearningRateMonitor(logging_interval="step"), + ModelCheckpoint( + monitor="loss/val", save_top_k=5, save_last=True, every_n_epochs=1 + ), + ], + use_distributed_sampler=True, + ) + + print("Starting training...") + trainer.fit(model, dm) + +# %% diff --git a/viscy/representation/disentanglement_metrics.py b/viscy/representation/disentanglement_metrics.py new file mode 100644 index 000000000..195bcbabf --- /dev/null +++ b/viscy/representation/disentanglement_metrics.py @@ -0,0 +1,366 @@ +import logging +from typing import Dict, List, Optional, Tuple + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from scipy import stats +from sklearn.ensemble import RandomForestClassifier +from sklearn.linear_model import LogisticRegression +from sklearn.metrics import accuracy_score + +_logger = logging.getLogger(__name__) + + +class DisentanglementMetrics: + """ + Disentanglement metrics for VAE evaluation on microscopy data. + + Implements MIG, SAP, DCI, and Beta-VAE score metrics for evaluating + how well the VAE learns disentangled representations. + """ + + def __init__(self, device: str = "cuda"): + self.device = device + + def compute_all_metrics( + self, + vae_model: nn.Module, + dataloader: torch.utils.data.DataLoader, + max_samples: int = 1000, + n_factors: Optional[int] = None, + ) -> Dict[str, float]: + """ + Compute all disentanglement metrics. 
+ + Args: + vae_model: Trained VAE model + dataloader: DataLoader with labeled data + max_samples: Maximum number of samples to use + n_factors: Number of known generative factors (if available) + + Returns: + Dictionary of metric scores + """ + latents, factors = self._extract_latents_and_factors( + vae_model, dataloader, max_samples + ) + + metrics = {} + + # MIG Score + metrics["MIG"] = self.compute_mig(latents, factors) + + # SAP Score + metrics["SAP"] = self.compute_sap(latents, factors) + + # DCI Scores + dci_scores = self.compute_dci(latents, factors) + metrics.update(dci_scores) + + # Beta-VAE Score (unsupervised) + metrics["Beta_VAE_Score"] = self.compute_beta_vae_score( + vae_model, dataloader, max_samples + ) + + return metrics + + def _extract_latents_and_factors( + self, + vae_model: nn.Module, + dataloader: torch.utils.data.DataLoader, + max_samples: int, + ) -> Tuple[np.ndarray, np.ndarray]: + """ + Extract latent representations and generative factors. + + For microscopy data, we'll extract simple visual factors like: + - Cell size (approximated from pixel intensity) + - Cell count (approximated from connected components) + - Brightness (mean intensity) + - Contrast (std of intensity) + """ + vae_model.eval() + latents = [] + factors = [] + + samples_collected = 0 + + with torch.no_grad(): + for batch in dataloader: + if samples_collected >= max_samples: + break + + x = batch["anchor"].to(self.device) + batch_size = x.shape[0] + + # Extract latent representations + model_output = vae_model(x) + z = ( + model_output.z + if hasattr(model_output, "z") + else model_output.embedding + ) + latents.append(z.cpu().numpy()) + + # Extract visual factors from images + batch_factors = self._extract_visual_factors(x.cpu()) + factors.append(batch_factors) + + samples_collected += batch_size + + latents = np.vstack(latents)[:max_samples] + factors = np.vstack(factors)[:max_samples] + + return latents, factors + + def _extract_visual_factors(self, images: torch.Tensor) -> np.ndarray: + """ + Extract visual factors from microscopy images. + + Args: + images: Batch of images (B, C, D, H, W) + + Returns: + Array of shape (B, n_factors) with extracted factors + """ + batch_size = images.shape[0] + factors = [] + + for i in range(batch_size): + img = images[i].numpy() # (C, D, H, W) + + # Take middle z-slice for 2D analysis + mid_z = img.shape[1] // 2 + img_2d = img[:, mid_z, :, :] # (C, H, W) + + # Factor 1: Brightness (mean intensity) + brightness = np.mean(img_2d) + + # Factor 2: Contrast (std of intensity) + contrast = np.std(img_2d) + + # Factor 3: Cell size (approximated by high-intensity regions) + binary_mask = img_2d[0] > np.percentile(img_2d[0], 75) + cell_size = np.sum(binary_mask) / (img_2d.shape[1] * img_2d.shape[2]) + + # Factor 4: Texture complexity (gradient magnitude) + grad_x = np.gradient(img_2d[0], axis=1) + grad_y = np.gradient(img_2d[0], axis=0) + texture = np.mean(np.sqrt(grad_x**2 + grad_y**2)) + + factors.append([brightness, contrast, cell_size, texture]) + + return np.array(factors) + + def compute_mig(self, latents: np.ndarray, factors: np.ndarray) -> float: + """ + Compute Mutual Information Gap (MIG). 
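+
+        MIG rewards representations in which each factor is captured by a
+        single latent dimension: for every factor, the gap between its two
+        most informative latents is averaged over factors (normalized here by
+        the largest mutual information rather than the factor entropy).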
+ + MIG = (1/K) * Σ_k (I(z_j*; v_k) - I(z_j'; v_k)) + where j* = argmax_j I(z_j; v_k) and j' = argmax_{j≠j*} I(z_j; v_k) + """ + + def mutual_info_continuous(x, y): + """Estimate mutual information between continuous variables.""" + # Discretize continuous variables + x_discrete = self._discretize(x) + y_discrete = self._discretize(y) + + # Compute mutual information + return self._mutual_info_discrete(x_discrete, y_discrete) + + n_factors = factors.shape[1] + n_latents = latents.shape[1] + + # Compute mutual information matrix + mi_matrix = np.zeros((n_latents, n_factors)) + + for i in range(n_latents): + for j in range(n_factors): + mi_matrix[i, j] = mutual_info_continuous(latents[:, i], factors[:, j]) + + # Compute MIG + mig = 0.0 + for j in range(n_factors): + mi_values = mi_matrix[:, j] + sorted_indices = np.argsort(mi_values)[::-1] + + if len(sorted_indices) > 1: + gap = mi_values[sorted_indices[0]] - mi_values[sorted_indices[1]] + mig += gap / np.max(mi_values) if np.max(mi_values) > 0 else 0 + + return mig / n_factors + + def compute_sap(self, latents: np.ndarray, factors: np.ndarray) -> float: + """ + Compute Attribute Predictability Score (SAP). + + SAP measures how well a simple classifier can predict factors from latents. + """ + n_factors = factors.shape[1] + scores = [] + + for i in range(n_factors): + # Discretize factor for classification + factor_discrete = self._discretize(factors[:, i], n_bins=10) + + # Train classifier + clf = LogisticRegression(random_state=42, max_iter=1000) + clf.fit(latents, factor_discrete) + + # Evaluate + pred = clf.predict(latents) + score = accuracy_score(factor_discrete, pred) + scores.append(score) + + return np.mean(scores) + + def compute_dci(self, latents: np.ndarray, factors: np.ndarray) -> Dict[str, float]: + """ + Compute Disentanglement, Completeness, and Informativeness (DCI). + """ + n_factors = factors.shape[1] + n_latents = latents.shape[1] + + # Train predictors for each factor + importance_matrix = np.zeros((n_factors, n_latents)) + + for i in range(n_factors): + # Discretize factor + factor_discrete = self._discretize(factors[:, i], n_bins=10) + + # Train random forest to get feature importance + rf = RandomForestClassifier(n_estimators=100, random_state=42) + rf.fit(latents, factor_discrete) + + importance_matrix[i, :] = rf.feature_importances_ + + # Normalize importance matrix + importance_matrix = importance_matrix / ( + np.sum(importance_matrix, axis=1, keepdims=True) + 1e-8 + ) + + # Compute DCI metrics + disentanglement = self._compute_disentanglement(importance_matrix) + completeness = self._compute_completeness(importance_matrix) + informativeness = self._compute_informativeness(importance_matrix) + + return { + "DCI_Disentanglement": disentanglement, + "DCI_Completeness": completeness, + "DCI_Informativeness": informativeness, + } + + def compute_beta_vae_score( + self, + vae_model: nn.Module, + dataloader: torch.utils.data.DataLoader, + max_samples: int, + ) -> float: + """ + Compute Beta-VAE score (unsupervised disentanglement metric). + + Measures how well individual latent dimensions affect reconstruction + when perturbed independently. 
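+
+        Note that this is a reconstruction-sensitivity heuristic rather than
+        the classifier-based metric of Higgins et al. (2017): the score
+        approaches 1 when perturbing single dimensions barely changes the
+        reconstruction.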
+ """ + vae_model.eval() + scores = [] + + samples_collected = 0 + + with torch.no_grad(): + for batch in dataloader: + if samples_collected >= max_samples: + break + + x = batch["anchor"].to(self.device) + batch_size = x.shape[0] + + # Get latent representation + model_output = vae_model(x) + z = ( + model_output.z + if hasattr(model_output, "z") + else model_output.embedding + ) + + # Compute baseline reconstruction + baseline_recon = vae_model.decoder(z) + if hasattr(baseline_recon, "reconstruction"): + baseline_recon = baseline_recon.reconstruction + + # Perturb each latent dimension + for dim in range(z.shape[1]): + z_perturbed = z.clone() + z_perturbed[:, dim] += torch.randn_like(z_perturbed[:, dim]) * 0.5 + + # Get perturbed reconstruction + perturbed_recon = vae_model.decoder(z_perturbed) + if hasattr(perturbed_recon, "reconstruction"): + perturbed_recon = perturbed_recon.reconstruction + + # Compute reconstruction difference + diff = F.mse_loss(baseline_recon, perturbed_recon, reduction="none") + diff = diff.mean( + dim=(1, 2, 3, 4) + ) # Average over spatial dimensions + + # Score is inverse of reconstruction change + score = 1.0 / (1.0 + diff.mean().item()) + scores.append(score) + + samples_collected += batch_size + + return np.mean(scores) + + def _discretize(self, x: np.ndarray, n_bins: int = 20) -> np.ndarray: + """Discretize continuous variable into bins.""" + return np.digitize(x, np.linspace(x.min(), x.max(), n_bins)) + + def _mutual_info_discrete(self, x: np.ndarray, y: np.ndarray) -> float: + """Compute mutual information between discrete variables.""" + # Joint histogram + xy = np.stack([x, y], axis=1) + unique_xy, counts_xy = np.unique(xy, axis=0, return_counts=True) + p_xy = counts_xy / counts_xy.sum() + + # Marginal histograms + unique_x, counts_x = np.unique(x, return_counts=True) + p_x = counts_x / counts_x.sum() + + unique_y, counts_y = np.unique(y, return_counts=True) + p_y = counts_y / counts_y.sum() + + # Compute MI + mi = 0.0 + for i, (x_val, y_val) in enumerate(unique_xy): + p_joint = p_xy[i] + p_x_marginal = p_x[unique_x == x_val][0] + p_y_marginal = p_y[unique_y == y_val][0] + + if p_joint > 0 and p_x_marginal > 0 and p_y_marginal > 0: + mi += p_joint * np.log(p_joint / (p_x_marginal * p_y_marginal)) + + return mi + + def _compute_disentanglement(self, importance_matrix: np.ndarray) -> float: + """Compute disentanglement score from importance matrix.""" + disentanglement = 0.0 + for i in range(importance_matrix.shape[0]): + if np.sum(importance_matrix[i]) > 0: + disentanglement += 1.0 - stats.entropy(importance_matrix[i]) + return disentanglement / importance_matrix.shape[0] + + def _compute_completeness(self, importance_matrix: np.ndarray) -> float: + """Compute completeness score from importance matrix.""" + completeness = 0.0 + for j in range(importance_matrix.shape[1]): + if np.sum(importance_matrix[:, j]) > 0: + completeness += 1.0 - stats.entropy(importance_matrix[:, j]) + return completeness / importance_matrix.shape[1] + + def _compute_informativeness(self, importance_matrix: np.ndarray) -> float: + """Compute informativeness score from importance matrix.""" + return np.mean(np.sum(importance_matrix, axis=1)) diff --git a/viscy/representation/engine.py b/viscy/representation/engine.py index 7a35d93f1..5ab7fac36 100644 --- a/viscy/representation/engine.py +++ b/viscy/representation/engine.py @@ -11,6 +11,9 @@ from viscy.data.typing import TrackingIndex, TripletSample from viscy.representation.contrastive import ContrastiveEncoder +from 
viscy.representation.disentanglement_metrics import DisentanglementMetrics +from viscy.representation.vae import VaeDecoder, VaeEncoder +from viscy.representation.vae_logging import BetaVaeLogger from viscy.utils.log_images import detach_sample, render_images _logger = logging.getLogger("lightning.pytorch") @@ -245,3 +248,379 @@ def predict_step( "projections": projections, "index": batch["index"], } + + +class BetaVAE(nn.Module): + """Native Beta-VAE implementation with reparameterization trick. + + Parameters + ---------- + encoder : nn.Module + Encoder model + decoder : nn.Module + Decoder model + latent_dim : int + Latent dimension + beta : float + Beta parameter for the Beta-VAE loss. Default is 1.0 equivalent to a VAE. + + Returns + ------- + dict + Dictionary containing the reconstruction, latent, mu, logvar, recon_loss, kl_loss, reg_loss, and total_loss. + """ + + def __init__( + self, encoder: nn.Module, decoder: nn.Module, latent_dim: int, beta: float = 1.0 + ): + super().__init__() + self.encoder = encoder + self.decoder = decoder + self.latent_dim = latent_dim + self.beta = beta + + def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor: + """Reparameterization trick: sample from N(mu, var) using N(0,1).""" + std = torch.exp(0.5 * logvar) + eps = torch.randn_like(std) + return mu + eps * std + + def forward(self, x: Tensor) -> dict: + """Forward pass through Beta-VAE.""" + # Encode + encoder_output = self.encoder(x) + mu = encoder_output.embedding + logvar = encoder_output.log_covariance + + # Reparameterize + z = self.reparameterize(mu, logvar) + + # Decode + reconstruction = self.decoder(z) + + # Compute losses + recon_loss = F.mse_loss(reconstruction, x, reduction="sum") / x.size(0) + kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / x.size(0) + total_loss = recon_loss + self.beta * kl_loss + + return { + "recon_x": reconstruction, + "z": z, + "mu": mu, + "logvar": logvar, + "recon_loss": recon_loss, + "kl_loss": kl_loss, + "reg_loss": kl_loss, # For compatibility with logging + "loss": total_loss, + } + + +class VaeModule(LightningModule): + """Native PyTorch Lightning Beta-VAE implementation.""" + + def __init__( + self, + encoder: VaeEncoder, + decoder: VaeDecoder, + latent_dim: int = 128, + beta: float = 1.0, + lr: float = 1e-3, + log_batches_per_epoch: int = 8, + log_samples_per_batch: int = 1, + example_input_array_shape: Sequence[int] = (1, 2, 30, 256, 256), + compute_disentanglement: bool = True, + disentanglement_frequency: int = 10, + # Deprecated parameters for backward compatibility + model_name: str = "BetaVAE", + loss: str = "mse", + ): + super().__init__() + + self.encoder = encoder + self.decoder = decoder + self.latent_dim = latent_dim + self.beta = beta + self.lr = lr + self.log_batches_per_epoch = log_batches_per_epoch + self.log_samples_per_batch = log_samples_per_batch + + self.example_input_array = torch.rand(*example_input_array_shape) + self.compute_disentanglement = compute_disentanglement + self.disentanglement_frequency = disentanglement_frequency + + # Create the Beta-VAE model + self.model = BetaVAE( + encoder=self.encoder, decoder=self.decoder, latent_dim=latent_dim, beta=beta + ) + + # Initialize tracking lists + self.training_step_outputs = [] + self.validation_step_outputs = [] + + # Enhanced β-VAE logging - initialize early + self.vae_logger = BetaVaeLogger(latent_dim=latent_dim) + + # Note: DisentanglementMetrics will be initialized in setup() when device is available + self.disentanglement_metrics = None + + def 
setup(self, stage: str = None): + """Setup hook to initialize device-dependent components.""" + super().setup(stage) + # Initialize DisentanglementMetrics after device is available + if self.disentanglement_metrics is None: + self.disentanglement_metrics = DisentanglementMetrics(device=self.device) + + def forward(self, x: Tensor) -> dict: + """Forward pass through VAE model.""" + return self.model(x) + + def training_step(self, batch: TripletSample, batch_idx: int) -> Tensor: + """Training step with VAE loss computation.""" + x = batch["anchor"] + model_output = self(x) + + # Beta-VAE computes loss internally + loss = model_output["loss"] + + # Log basic metrics + self._log_metrics( + loss=loss, + recon_loss=model_output["recon_loss"], + kl_loss=model_output["kl_loss"], + stage="train", + ) + + # Log enhanced β-VAE metrics + self.vae_logger.log_enhanced_metrics( + lightning_module=self, model_output=model_output, batch=batch, stage="train" + ) + + # Log samples + self._log_step_samples(batch_idx, x, model_output["recon_x"], "train") + + return loss + + def validation_step(self, batch: TripletSample, batch_idx: int) -> Tensor: + """Validation step with VAE loss computation.""" + x = batch["anchor"] + model_output = self(x) + + # Beta-VAE computes loss internally + loss = model_output["loss"] + + # Log basic metrics + self._log_metrics( + loss=loss, + recon_loss=model_output["recon_loss"], + kl_loss=model_output["kl_loss"], + stage="val", + ) + + # Log enhanced β-VAE metrics + self.vae_logger.log_enhanced_metrics( + lightning_module=self, model_output=model_output, batch=batch, stage="val" + ) + + # Log samples + self._log_step_samples(batch_idx, x, model_output["recon_x"], "val") + + return loss + + def _log_metrics(self, loss, recon_loss, kl_loss, stage: Literal["train", "val"]): + """Log VAE-specific metrics.""" + metrics = { + f"loss/{stage}": loss, + f"recon_loss/{stage}": recon_loss, + f"kl_loss/{stage}": kl_loss, + } + + self.log_dict( + metrics, + on_step=True, + on_epoch=True, + prog_bar=True, + logger=True, + sync_dist=True, + ) + + def _log_step_samples( + self, batch_idx, original, reconstruction, stage: Literal["train", "val"] + ): + """Log sample reconstructions.""" + if batch_idx < self.log_batches_per_epoch: + output_list = ( + self.training_step_outputs + if stage == "train" + else self.validation_step_outputs + ) + + # Store samples for epoch end logging + samples = { + "original": original.detach().cpu()[: self.log_samples_per_batch], + "reconstruction": reconstruction.detach().cpu()[ + : self.log_samples_per_batch + ], + } + output_list.append(samples) + + def _log_samples(self, key: str, samples_list: list): + """Log reconstruction samples at epoch end.""" + if len(samples_list) > 0: + # Take middle z-slice for visualization + mid_z = samples_list[0]["original"].shape[2] // 2 + + originals = [] + reconstructions = [] + + for sample in samples_list: + orig = sample["original"][:, :, mid_z].numpy() + recon = sample["reconstruction"][:, :, mid_z].numpy() + + originals.extend([orig[i] for i in range(orig.shape[0])]) + reconstructions.extend([recon[i] for i in range(recon.shape[0])]) + + # Create grid with originals and reconstructions + combined = [] + for orig, recon in zip(originals[:4], reconstructions[:4]): + combined.append([orig, recon]) + + grid = render_images(combined, cmaps=["gray", "gray"]) + self.logger.experiment.add_image( + key, grid, self.current_epoch, dataformats="HWC" + ) + + def on_train_epoch_end(self) -> None: + """Log training samples at epoch end.""" 
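+        # Render the reconstructions collected during training_step and clear
+        # the buffer so logged samples do not accumulate across epochs.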
+ super().on_train_epoch_end() + self._log_samples("train_reconstructions", self.training_step_outputs) + self.training_step_outputs = [] + + def on_validation_epoch_end(self) -> None: + """Log validation samples at epoch end.""" + super().on_validation_epoch_end() + self._log_samples("val_reconstructions", self.validation_step_outputs) + self.validation_step_outputs = [] + + # Compute disentanglement metrics periodically + if ( + self.compute_disentanglement + and self.current_epoch % self.disentanglement_frequency == 0 + and self.current_epoch > 0 + ): + self._compute_and_log_disentanglement_metrics() + + # Log enhanced β-VAE visualizations periodically + if self.current_epoch % 20 == 0 and self.current_epoch > 0: + self._log_enhanced_visualizations() + + def _compute_and_log_disentanglement_metrics(self): + """Compute and log disentanglement metrics.""" + try: + # Check if disentanglement metrics are initialized + if self.disentanglement_metrics is None: + _logger.warning( + "DisentanglementMetrics not initialized, skipping computation" + ) + return + + # Get validation dataloader + val_dataloader = ( + self.trainer.val_dataloaders[0] + if self.trainer.val_dataloaders + else None + ) + + if val_dataloader is None: + _logger.warning( + "No validation dataloader available for disentanglement metrics" + ) + return + + # Compute metrics + _logger.info( + f"Computing disentanglement metrics at epoch {self.current_epoch}" + ) + metrics = self.disentanglement_metrics.compute_all_metrics( + vae_model=self, + dataloader=val_dataloader, + max_samples=200, + ) + + # Log metrics + for metric_name, metric_value in metrics.items(): + self.log( + f"disentanglement/{metric_name}", + metric_value, + on_step=False, + on_epoch=True, + logger=True, + sync_dist=True, + ) + + _logger.info(f"Disentanglement metrics: {metrics}") + + except Exception as e: + _logger.error(f"Error computing disentanglement metrics: {e}") + # Continue training even if metrics fail + + def _log_enhanced_visualizations(self): + """Log enhanced β-VAE visualizations.""" + try: + # Get validation dataloader + val_dataloader = ( + self.trainer.val_dataloaders[0] + if self.trainer.val_dataloaders + else None + ) + + if val_dataloader is None: + _logger.warning("No validation dataloader available for visualizations") + return + + _logger.info( + f"Logging enhanced β-VAE visualizations at epoch {self.current_epoch}" + ) + + # Log latent traversals + self.vae_logger.log_latent_traversal( + lightning_module=self, n_dims=8, n_steps=11 + ) + + # Log latent interpolations + self.vae_logger.log_latent_interpolation( + lightning_module=self, n_pairs=3, n_steps=11 + ) + + # Log factor traversal matrix + self.vae_logger.log_factor_traversal_matrix( + lightning_module=self, n_dims=8, n_steps=7 + ) + + # Log latent space visualization (every 40 epochs to avoid overhead) + if self.current_epoch % 40 == 0: + self.vae_logger.log_latent_space_visualization( + lightning_module=self, + dataloader=val_dataloader, + max_samples=500, + method="pca", + ) + + except Exception as e: + _logger.error(f"Error logging enhanced visualizations: {e}") + # Continue training even if visualizations fail + + def configure_optimizers(self): + """Configure optimizer for VAE training.""" + optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr) + return optimizer + + def predict_step(self, batch: TripletSample, batch_idx, dataloader_idx=0) -> dict: + """Prediction step for VAE inference.""" + x = batch["anchor"] + model_output = self(x) + + return { + "latent": 
model_output["z"], + "reconstruction": model_output["recon_x"], + "index": batch["index"], + } diff --git a/viscy/representation/vae.py b/viscy/representation/vae.py new file mode 100644 index 000000000..6078a1769 --- /dev/null +++ b/viscy/representation/vae.py @@ -0,0 +1,257 @@ +from types import SimpleNamespace +from typing import Callable, Literal + +import timm +from monai.networks.blocks import ResidualUnit, UpSample +from monai.networks.blocks.dynunet_block import get_conv_layer +from pythae.models.nn import BaseDecoder, BaseEncoder +from torch import Tensor, nn + +from viscy.unet.networks.unext2 import ( + PixelToVoxelHead, + StemDepthtoChannels, + UNeXt2Stem, +) + + +class VaeUpStage(nn.Module): + """VAE upsampling stage without skip connections.""" + + def __init__( + self, + in_channels: int, + out_channels: int, + scale_factor: int, + mode: Literal["deconv", "pixelshuffle"], + conv_blocks: int, + norm_name: str, + upsample_pre_conv: Literal["default"] | Callable | None, + ) -> None: + super().__init__() + spatial_dims = 2 + + if mode == "deconv": + self.upsample = get_conv_layer( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=out_channels, + stride=scale_factor, + kernel_size=scale_factor, + norm=norm_name, + is_transposed=True, + ) + # Simple conv blocks for deconv mode + self.conv = nn.Sequential( + ResidualUnit( + spatial_dims=spatial_dims, + in_channels=out_channels, + out_channels=out_channels, + norm=norm_name, + ), + nn.Conv2d(out_channels, out_channels, kernel_size=1), + ) + elif mode == "pixelshuffle": + mid_channels = in_channels // scale_factor**2 + self.upsample = UpSample( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=mid_channels, + scale_factor=scale_factor, + mode=mode, + pre_conv=upsample_pre_conv, + apply_pad_pool=False, + ) + conv_layers = [] + current_channels = mid_channels + + for i in range(conv_blocks): + block_out_channels = out_channels + conv_layers.extend( + [ + nn.Conv2d( + current_channels, + block_out_channels, + kernel_size=3, + padding=1, + ), + ( + nn.BatchNorm2d(block_out_channels) + if norm_name == "batch" + else nn.InstanceNorm2d(block_out_channels) + ), + nn.ReLU(inplace=True), + ] + ) + current_channels = block_out_channels + + self.conv = nn.Sequential(*conv_layers) + + def forward(self, inp: Tensor) -> Tensor: + """ + :param Tensor inp: Low resolution features + :return Tensor: High resolution features + """ + inp = self.upsample(inp) + return self.conv(inp) + + +class VaeEncoder(nn.Module): + """VAE encoder for microscopy data with 3D to 2D conversion.""" + + def __init__( + self, + backbone: str = "resnet50", + in_channels: int = 2, + in_stack_depth: int = 32, + embedding_dim: int = 128, + stem_kernel_size: tuple[int, int, int] = (8, 4, 4), + stem_stride: tuple[int, int, int] = (8, 2, 2), + drop_path_rate: float = 0.0, + ): + super().__init__() + self.backbone = backbone + self.embedding_dim = embedding_dim + + encoder = timm.create_model( + backbone, + pretrained=False, + features_only=True, + drop_path_rate=drop_path_rate, + ) + + if "resnet" in backbone: + in_channels_encoder = encoder.conv1.out_channels + # remove the original 3D stem for rgb imges to support the multichannel 3D input + encoder.conv1 = nn.Identity() + out_channels_encoder = encoder.feature_info.channels()[-1] + else: + raise ValueError(f"Backbone {backbone} not supported") + + # Stem for 3d multichannel and to convert 3D to 2D + self.stem = StemDepthtoChannels( + in_channels=in_channels, + 
in_stack_depth=in_stack_depth, + in_channels_encoder=in_channels_encoder, + stem_kernel_size=stem_kernel_size, + stem_stride=stem_stride, + ) + self.encoder = encoder + + self.global_pool = nn.AdaptiveAvgPool2d(1) + + self.fc_mu = nn.Linear(out_channels_encoder, embedding_dim) + self.fc_logvar = nn.Linear(out_channels_encoder, embedding_dim) + + def forward(self, x: Tensor) -> dict: + """Forward pass returning VAE encoder outputs.""" + x = self.stem(x) + + features = self.encoder(x) + + # Take highest resolution features + x = features[-1] + x = self.global_pool(x) + x = x.flatten(1) + + # VAE outputs + mu = self.fc_mu(x) + logvar = self.fc_logvar(x) + + return SimpleNamespace(embedding=mu, log_covariance=logvar) + + +class VaeDecoder(nn.Module): + """VAE decoder for microscopy data with 2D to 3D conversion.""" + + def __init__( + self, + decoder_channels: list[int] = [1024, 512, 256, 128], + latent_dim: int = 128, + out_channels: int = 2, + out_stack_depth: int = 20, + latent_spatial_size: int = 8, + head_expansion_ratio: int = 4, + head_pool: bool = False, + upsample_mode: Literal[ + "deconv", "pixelshuffle" + ] = "pixelshuffle", # Better quality + conv_blocks: int = 2, + norm_name: str = "batch", + upsample_pre_conv: Literal["default"] | Callable | None = None, + strides: list[int] | None = None, + ): + super().__init__() + self.out_channels = out_channels + self.out_stack_depth = out_stack_depth + self.latent_spatial_size = latent_spatial_size + + head_channels = ( + (out_stack_depth + 2) * out_channels * 2**2 * head_expansion_ratio + ) + + # Copy decoder_channels to avoid modifying the original list + decoder_channels_with_head = decoder_channels.copy() + [head_channels] + + # Set optimal default strides for ResNet50 if not provided + num_stages = len(decoder_channels_with_head) - 1 + if strides is None: + if ( + num_stages == 4 + ): # Default [1024, 512, 256, 128] + head = 5 channels, 4 stages + strides = [2, 2, 2, 4] # 8→16→32→64→256 (32x total upsampling) + else: + strides = [2] * num_stages # Fallback to uniform 2x upsampling + elif len(strides) != num_stages: + raise ValueError( + f"Length of strides ({len(strides)}) must match number of stages ({num_stages})" + ) + + # Project latent vector to first feature map + self.latent_proj = nn.Linear( + latent_dim, + decoder_channels_with_head[0] * latent_spatial_size * latent_spatial_size, + ) + + # Build the decoder stages + self.decoder_stages = nn.ModuleList() + + for i in range(num_stages): + in_channels = decoder_channels_with_head[i] + out_channels_stage = decoder_channels_with_head[i + 1] + stride = strides[i] + + stage = VaeUpStage( + in_channels=in_channels, + out_channels=out_channels_stage, + scale_factor=stride, + mode=upsample_mode, + conv_blocks=conv_blocks, + norm_name=norm_name, + upsample_pre_conv=upsample_pre_conv, + ) + self.decoder_stages.append(stage) + + # Head to convert back to 3D (no final_conv needed - last stage outputs head_channels) + self.head = PixelToVoxelHead( + in_channels=head_channels, + out_channels=self.out_channels, + out_stack_depth=self.out_stack_depth, + expansion_ratio=head_expansion_ratio, + pool=head_pool, + ) + + def forward(self, z: Tensor) -> dict: + """Forward pass converting latent to 3D output.""" + batch_size = z.shape[0] + + # Project latent to feature map + x = self.latent_proj(z) + x = x.view(batch_size, -1, self.latent_spatial_size, self.latent_spatial_size) + + for stage in self.decoder_stages: + x = stage(x) + + # Last stage outputs head_channels directly - no final_conv needed + 
output = self.head(x) + + return output diff --git a/viscy/representation/vae_logging.py b/viscy/representation/vae_logging.py new file mode 100644 index 000000000..44fdf810b --- /dev/null +++ b/viscy/representation/vae_logging.py @@ -0,0 +1,448 @@ +import io +import logging +from typing import Dict, List, Optional, Tuple + +import matplotlib.pyplot as plt +import numpy as np +import seaborn as sns +import torch +import torch.nn.functional as F +from PIL import Image +from sklearn.decomposition import PCA +from sklearn.manifold import TSNE +from torchvision.utils import make_grid + +_logger = logging.getLogger(__name__) + + +class BetaVaeLogger: + """ + Enhanced logging utilities for β-VAE training with TensorBoard. + + Provides comprehensive logging of β-VAE specific metrics, visualizations, + and latent space analysis for microscopy data. + """ + + def __init__(self, latent_dim: int = 128): + self.latent_dim = latent_dim + + def log_enhanced_metrics( + self, lightning_module, model_output: dict, batch: dict, stage: str = "train" + ): + """ + Log enhanced β-VAE metrics. + + Args: + lightning_module: Lightning module instance + model_output: VAE model output + batch: Input batch + stage: Training stage ("train" or "val") + """ + # Extract components + x = batch["anchor"] + # Handle both Pythae dict format and object format + if isinstance(model_output, dict): + z = model_output["z"] + recon_x = model_output["recon_x"] + recon_loss = model_output["recon_loss"] + kl_loss = model_output["reg_loss"] # Pythae uses 'reg_loss' for KL + else: + z = model_output.z if hasattr(model_output, "z") else model_output.embedding + recon_x = ( + model_output.recon_x + if hasattr(model_output, "recon_x") + else model_output.reconstruction + ) + recon_loss = model_output.recon_loss + kl_loss = model_output.kl_loss + + # Get β from model config + beta = ( + lightning_module.model.model_config.beta + if hasattr(lightning_module.model, "model_config") + else 1.0 + ) + + # 1. Core VAE Loss Components (organized in one TensorBoard group) + total_loss = recon_loss + beta * kl_loss + kl_recon_ratio = kl_loss / (recon_loss + 1e-8) + + metrics = { + # Core loss components + f"loss_components/reconstruction_loss/{stage}": recon_loss, + f"loss_components/kl_loss/{stage}": kl_loss, + f"loss_components/weighted_kl_loss/{stage}": beta * kl_loss, + f"loss_components/total_loss/{stage}": total_loss, + f"loss_components/beta_value/{stage}": beta, + + # Loss analysis ratios + f"loss_analysis/kl_recon_ratio/{stage}": kl_recon_ratio, + f"loss_analysis/recon_contribution/{stage}": recon_loss / total_loss, + } + + # 2. Latent space statistics + latent_mean = torch.mean(z, dim=0) + latent_std = torch.std(z, dim=0) + + metrics.update( + { + f"latent_stats/mean_avg/{stage}": torch.mean(latent_mean), + f"latent_stats/std_avg/{stage}": torch.mean(latent_std), + f"latent_stats/mean_max/{stage}": torch.max(latent_mean), + f"latent_stats/std_max/{stage}": torch.max(latent_std), + } + ) + + # 3. Reconstruction quality metrics + mse_loss = F.mse_loss(recon_x, x) + mae_loss = F.l1_loss(recon_x, x) + + metrics.update( + { + f"reconstruction_quality/mse/{stage}": mse_loss, + f"reconstruction_quality/mae/{stage}": mae_loss, + } + ) + + # 4. 
Latent capacity metrics + active_dims = torch.sum(torch.var(z, dim=0) > 0.01) + variances = torch.var(z, dim=0) + effective_dim = torch.sum(variances) ** 2 / torch.sum(variances**2) + + metrics.update( + { + f"latent_capacity/active_dims/{stage}": active_dims, + f"latent_capacity/effective_dim/{stage}": effective_dim, + f"latent_capacity/utilization/{stage}": active_dims / self.latent_dim, + } + ) + + # Log all metrics + lightning_module.log_dict( + metrics, + on_step=False, + on_epoch=True, + logger=True, + sync_dist=True, + ) + + # 5. Log latent dimension histograms (periodically) + if stage == "val" and lightning_module.current_epoch % 10 == 0: + self._log_latent_histograms(lightning_module, z, stage) + + def _log_latent_histograms(self, lightning_module, z: torch.Tensor, stage: str): + """Log histograms of latent dimensions.""" + z_np = z.detach().cpu().numpy() + + # Log first 16 dimensions to avoid clutter + n_dims_to_log = min(16, z_np.shape[1]) + + for i in range(n_dims_to_log): + lightning_module.logger.experiment.add_histogram( + f"latent_dim_{i}_distribution/{stage}", + z_np[:, i], + lightning_module.current_epoch, + ) + + def log_latent_traversal( + self, + lightning_module, + n_dims: int = 8, + n_steps: int = 11, + range_vals: Tuple[float, float] = (-3, 3), + ): + """ + Log latent space traversal visualizations. + + Args: + lightning_module: Lightning module instance + n_dims: Number of latent dimensions to traverse + n_steps: Number of steps in traversal + range_vals: Range of values to traverse + """ + if not hasattr(lightning_module, "model"): + return + + lightning_module.model.eval() + + with torch.no_grad(): + # Sample a base latent vector + z_base = torch.randn(1, self.latent_dim, device=lightning_module.device) + + # Traverse each dimension + for dim in range(min(n_dims, self.latent_dim)): + traversal_images = [] + + for val in np.linspace(range_vals[0], range_vals[1], n_steps): + z_modified = z_base.clone() + z_modified[0, dim] = val + + # Generate reconstruction + decoder_output = lightning_module.model.decoder(z_modified) + # Handle both Pythae dict format and object format + if isinstance(decoder_output, dict): + recon = decoder_output["reconstruction"] + else: + recon = ( + decoder_output.reconstruction + if hasattr(decoder_output, "reconstruction") + else decoder_output + ) + + # Take middle z-slice for visualization + mid_z = recon.shape[2] // 2 + img_2d = recon[0, 0, mid_z].cpu() # First channel, middle z-slice + + # Normalize for visualization + img_2d = (img_2d - img_2d.min()) / ( + img_2d.max() - img_2d.min() + 1e-8 + ) + traversal_images.append(img_2d) + + # Create grid + grid = make_grid( + torch.stack(traversal_images).unsqueeze(1), + nrow=n_steps, + normalize=True, + ) + + lightning_module.logger.experiment.add_image( + f"latent_traversal/dim_{dim}", grid, lightning_module.current_epoch + ) + + def log_latent_interpolation( + self, lightning_module, n_pairs: int = 3, n_steps: int = 11 + ): + """ + Log latent space interpolation between random pairs. 
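+
+        Endpoint latents are drawn from the prior N(0, I); each pair is
+        blended linearly in latent space, decoded, and tiled into a single
+        strip of n_steps frames.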
+ + Args: + lightning_module: Lightning module instance + n_pairs: Number of interpolation pairs + n_steps: Number of interpolation steps + """ + if not hasattr(lightning_module, "model"): + return + + lightning_module.model.eval() + + with torch.no_grad(): + for pair_idx in range(n_pairs): + # Sample two random latent vectors + z1 = torch.randn(1, self.latent_dim, device=lightning_module.device) + z2 = torch.randn(1, self.latent_dim, device=lightning_module.device) + + interp_images = [] + + for alpha in np.linspace(0, 1, n_steps): + z_interp = alpha * z1 + (1 - alpha) * z2 + + # Generate reconstruction + decoder_output = lightning_module.model.decoder(z_interp) + # Handle both Pythae dict format and object format + if isinstance(decoder_output, dict): + recon = decoder_output["reconstruction"] + else: + recon = ( + decoder_output.reconstruction + if hasattr(decoder_output, "reconstruction") + else decoder_output + ) + + # Take middle z-slice for visualization + mid_z = recon.shape[2] // 2 + img_2d = recon[0, 0, mid_z].cpu() # First channel, middle z-slice + + # Normalize for visualization + img_2d = (img_2d - img_2d.min()) / ( + img_2d.max() - img_2d.min() + 1e-8 + ) + interp_images.append(img_2d) + + # Create grid + grid = make_grid( + torch.stack(interp_images).unsqueeze(1), + nrow=n_steps, + normalize=True, + ) + + lightning_module.logger.experiment.add_image( + f"latent_interpolation/pair_{pair_idx}", + grid, + lightning_module.current_epoch, + ) + + def log_factor_traversal_matrix( + self, lightning_module, n_dims: int = 8, n_steps: int = 7 + ): + """ + Log factor traversal matrix showing effect of each latent dimension. + + Args: + lightning_module: Lightning module instance + n_dims: Number of latent dimensions to show + n_steps: Number of steps per dimension + """ + if not hasattr(lightning_module, "model"): + return + + lightning_module.model.eval() + + with torch.no_grad(): + # Base latent vector + z_base = torch.randn(1, self.latent_dim, device=lightning_module.device) + + matrix_rows = [] + + for dim in range(min(n_dims, self.latent_dim)): + row_images = [] + + for step in range(n_steps): + val = -3 + 6 * step / (n_steps - 1) # Range [-3, 3] + z_mod = z_base.clone() + z_mod[0, dim] = val + + # Generate reconstruction + decoder_output = lightning_module.model.decoder(z_mod) + # Handle both Pythae dict format and object format + if isinstance(decoder_output, dict): + recon = decoder_output["reconstruction"] + else: + recon = ( + decoder_output.reconstruction + if hasattr(decoder_output, "reconstruction") + else decoder_output + ) + + # Take middle z-slice for visualization + mid_z = recon.shape[2] // 2 + img_2d = recon[0, 0, mid_z].cpu() # First channel, middle z-slice + + # Normalize for visualization + img_2d = (img_2d - img_2d.min()) / ( + img_2d.max() - img_2d.min() + 1e-8 + ) + row_images.append(img_2d) + + matrix_rows.append(torch.stack(row_images)) + + # Create matrix grid + all_images = torch.cat(matrix_rows, dim=0) + grid = make_grid(all_images.unsqueeze(1), nrow=n_steps, normalize=True) + + lightning_module.logger.experiment.add_image( + "factor_traversal_matrix", grid, lightning_module.current_epoch + ) + + def log_latent_space_visualization( + self, lightning_module, dataloader, max_samples: int = 500, method: str = "pca" + ): + """ + Log 2D visualization of latent space using PCA or t-SNE. 
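+
+        Latents are pooled from the dataloader (up to max_samples), reduced
+        to two components, and logged to TensorBoard as a scatter-plot image;
+        the PCA title also reports the total explained variance ratio.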
+ + Args: + lightning_module: Lightning module instance + dataloader: DataLoader for samples + max_samples: Maximum samples to visualize + method: Visualization method ("pca" or "tsne") + """ + if not hasattr(lightning_module, "model"): + return + + lightning_module.model.eval() + + # Collect latent representations + latents = [] + samples_collected = 0 + + with torch.no_grad(): + for batch in dataloader: + if samples_collected >= max_samples: + break + + x = batch["anchor"].to(lightning_module.device) + model_output = lightning_module(x) # Use lightning module forward + # Handle both Pythae dict format and object format + if isinstance(model_output, dict): + z = model_output["z"] + else: + z = ( + model_output.z + if hasattr(model_output, "z") + else model_output.embedding + ) + + latents.append(z.cpu().numpy()) + samples_collected += x.shape[0] + + if not latents: + return + + latents = np.vstack(latents)[:max_samples] + + # Apply dimensionality reduction + if method == "pca": + reducer = PCA(n_components=2) + reduced = reducer.fit_transform(latents) + title = f"PCA Latent Space (Variance: {reducer.explained_variance_ratio_.sum():.2f})" + elif method == "tsne": + reducer = TSNE(n_components=2, random_state=42) + reduced = reducer.fit_transform(latents) + title = "t-SNE Latent Space" + else: + _logger.warning(f"Unknown method: {method}") + return + + # Create scatter plot + plt.figure(figsize=(10, 8)) + plt.scatter(reduced[:, 0], reduced[:, 1], alpha=0.6, s=20) + plt.title(title) + plt.xlabel("Component 1") + plt.ylabel("Component 2") + plt.grid(True, alpha=0.3) + + # Convert to image + buf = io.BytesIO() + plt.savefig(buf, format="png", dpi=150, bbox_inches="tight") + buf.seek(0) + + # Log to TensorBoard + img = Image.open(buf) + img_array = np.array(img) + img_tensor = torch.from_numpy(img_array).permute(2, 0, 1) / 255.0 + + lightning_module.logger.experiment.add_image( + f"latent_space_{method}", img_tensor, lightning_module.current_epoch + ) + + plt.close() + buf.close() + + def log_beta_schedule( + self, lightning_module, beta_schedule: Optional[callable] = None + ): + """ + Log β annealing schedule. 
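+
+        When no schedule is given, a default piecewise ramp is applied:
+        β = 0.1 for the first 10% of epochs, a linear increase toward 4.0
+        until 50% of training, then a constant β = 4.0.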
+ + Args: + lightning_module: Lightning module instance + beta_schedule: Function that returns β value for given epoch + """ + if beta_schedule is None: + # Default β schedule + max_epochs = lightning_module.trainer.max_epochs + epoch = lightning_module.current_epoch + + if epoch < max_epochs * 0.1: # Warm up + beta = 0.1 + elif epoch < max_epochs * 0.5: # Gradual increase + beta = 0.1 + (4.0 - 0.1) * (epoch - max_epochs * 0.1) / ( + max_epochs * 0.4 + ) + else: # Final β + beta = 4.0 + else: + beta = beta_schedule(lightning_module.current_epoch) + + lightning_module.log("beta_schedule", beta) + return beta From c976f981df98fb48d703e502748f100c441d6274 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Mon, 21 Jul 2025 15:43:32 -0700 Subject: [PATCH 010/101] improving the logging for readability and drop pythae baseclasses --- .../benchmarking/DynaCLR/BetaVAE/test_run.py | 19 ++- viscy/data/triplet.py | 16 +- viscy/representation/engine.py | 151 +++++++++--------- viscy/representation/vae.py | 6 +- viscy/representation/vae_logging.py | 119 ++++++++------ 5 files changed, 177 insertions(+), 134 deletions(-) diff --git a/applications/benchmarking/DynaCLR/BetaVAE/test_run.py b/applications/benchmarking/DynaCLR/BetaVAE/test_run.py index d23fee85e..d45881ea8 100644 --- a/applications/benchmarking/DynaCLR/BetaVAE/test_run.py +++ b/applications/benchmarking/DynaCLR/BetaVAE/test_run.py @@ -97,7 +97,7 @@ def channel_normalization( batch_size = 64 num_workers = 12 time_interval = 1 - z_stack_depth = 32 + z_stack_depth = 16 print("Creating model components...") @@ -107,8 +107,8 @@ def channel_normalization( in_channels=1, in_stack_depth=z_stack_depth, embedding_dim=256, - stem_kernel_size=(8, 4, 4), - stem_stride=(8, 4, 4), + stem_kernel_size=(4, 4, 4), + stem_stride=(4, 4, 4), ) print(f"Encoder created successfully") @@ -130,7 +130,7 @@ def channel_normalization( latent_spatial_size=3, head_expansion_ratio=2, head_pool=False, - upsample_mode="deconv", + upsample_mode="pixelshuffle", conv_blocks=2, norm_name="batch", upsample_pre_conv=None, @@ -144,8 +144,11 @@ def channel_normalization( decoder=decoder, example_input_array_shape=(1, 1, z_stack_depth, 192, 192), latent_dim=256, - beta=3.0, + beta=0.5, lr=2e-4, + beta_schedule="linear", + beta_min=0.1, + beta_warmup_epochs=15, ) print(f"VaeModule created successfully") @@ -163,7 +166,7 @@ def channel_normalization( data_path="/hpc/projects/intracellular_dashboard/organelle_dynamics/rerun/2024_10_16_A549_SEC61_ZIKV_DENV/2-assemble/2024_10_16_A549_SEC61_ZIKV_DENV.zarr", tracks_path="/hpc/projects/intracellular_dashboard/organelle_dynamics/rerun/2024_10_16_A549_SEC61_ZIKV_DENV/1-preprocess/label-free/3-track/2024_10_16_A549_SEC61_ZIKV_DENV_cropped.zarr", source_channel=["Phase3D"], - z_range=(5, 37), + z_range=(10, 26), initial_yx_patch_size=initial_yx_patch_size, final_yx_patch_size=final_yx_patch_size, batch_size=batch_size, @@ -171,6 +174,8 @@ def channel_normalization( time_interval=time_interval, augmentations=channel_augmentations("Phase3D"), normalizations=channel_normalization(phase_channel="Phase3D"), + augment_validation=False, + return_negative=False, fit_include_wells=["B/3", "B/4", "C/3", "C/4"], ) print(f"DataModule created successfully") @@ -189,7 +194,7 @@ def channel_normalization( logger=TensorBoardLogger( save_dir="/hpc/projects/organelle_phenotyping/models/SEC61B/vae", name="betavae_phase3D_ddp", - version="beta_3_16slice", + version="beta_0.5_16slice", ), callbacks=[ LearningRateMonitor(logging_interval="step"), diff --git 
a/viscy/data/triplet.py b/viscy/data/triplet.py index c25a0fc74..24f193998 100644 --- a/viscy/data/triplet.py +++ b/viscy/data/triplet.py @@ -310,6 +310,7 @@ def __init__( num_workers: int = 8, normalizations: list[MapTransform] = [], augmentations: list[MapTransform] = [], + augment_validation: bool = True, caching: bool = False, fit_include_wells: list[str] | None = None, fit_exclude_fovs: list[str] | None = None, @@ -348,6 +349,9 @@ def __init__( Normalization transforms, by default [] augmentations : list[MapTransform], optional Augmentation transforms, by default [] + augment_validation : bool, optional + Apply augmentations to validation data, by default True. + Set to False for VAE training where clean validation is needed. caching : bool, optional Whether to cache the dataset, by default False fit_include_wells : list[str], optional @@ -402,6 +406,7 @@ def __init__( self.include_track_ids = include_track_ids self.time_interval = time_interval self.return_negative = return_negative + self.augment_validation = augment_validation def _align_tracks_tables_with_positions( self, @@ -466,13 +471,18 @@ def _setup_fit(self, dataset_settings: dict): **dataset_settings, ) + # Choose transforms for validation based on augment_validation parameter + val_positive_transform = augment_transform if self.augment_validation else no_aug_transform + val_negative_transform = augment_transform if self.augment_validation else no_aug_transform + val_anchor_transform = anchor_transform if self.augment_validation else no_aug_transform + self.val_dataset = TripletDataset( positions=val_positions, tracks_tables=val_tracks_tables, initial_yx_patch_size=self.initial_yx_patch_size, - anchor_transform=anchor_transform, - positive_transform=augment_transform, - negative_transform=augment_transform, + anchor_transform=val_anchor_transform, + positive_transform=val_positive_transform, + negative_transform=val_negative_transform, fit=True, return_negative=self.return_negative, **dataset_settings, diff --git a/viscy/representation/engine.py b/viscy/representation/engine.py index 5ab7fac36..5bf5d9336 100644 --- a/viscy/representation/engine.py +++ b/viscy/representation/engine.py @@ -1,5 +1,5 @@ import logging -from typing import Literal, Sequence, TypedDict +from typing import Literal, Optional, Sequence, TypedDict import numpy as np import torch @@ -250,71 +250,6 @@ def predict_step( } -class BetaVAE(nn.Module): - """Native Beta-VAE implementation with reparameterization trick. - - Parameters - ---------- - encoder : nn.Module - Encoder model - decoder : nn.Module - Decoder model - latent_dim : int - Latent dimension - beta : float - Beta parameter for the Beta-VAE loss. Default is 1.0 equivalent to a VAE. - - Returns - ------- - dict - Dictionary containing the reconstruction, latent, mu, logvar, recon_loss, kl_loss, reg_loss, and total_loss. 
- """ - - def __init__( - self, encoder: nn.Module, decoder: nn.Module, latent_dim: int, beta: float = 1.0 - ): - super().__init__() - self.encoder = encoder - self.decoder = decoder - self.latent_dim = latent_dim - self.beta = beta - - def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor: - """Reparameterization trick: sample from N(mu, var) using N(0,1).""" - std = torch.exp(0.5 * logvar) - eps = torch.randn_like(std) - return mu + eps * std - - def forward(self, x: Tensor) -> dict: - """Forward pass through Beta-VAE.""" - # Encode - encoder_output = self.encoder(x) - mu = encoder_output.embedding - logvar = encoder_output.log_covariance - - # Reparameterize - z = self.reparameterize(mu, logvar) - - # Decode - reconstruction = self.decoder(z) - - # Compute losses - recon_loss = F.mse_loss(reconstruction, x, reduction="sum") / x.size(0) - kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / x.size(0) - total_loss = recon_loss + self.beta * kl_loss - - return { - "recon_x": reconstruction, - "z": z, - "mu": mu, - "logvar": logvar, - "recon_loss": recon_loss, - "kl_loss": kl_loss, - "reg_loss": kl_loss, # For compatibility with logging - "loss": total_loss, - } - - class VaeModule(LightningModule): """Native PyTorch Lightning Beta-VAE implementation.""" @@ -324,6 +259,9 @@ def __init__( decoder: VaeDecoder, latent_dim: int = 128, beta: float = 1.0, + beta_schedule: str = None, # "linear", "cosine", "warmup", or None + beta_min: float = 0.1, + beta_warmup_epochs: int = 50, lr: float = 1e-3, log_batches_per_epoch: int = 8, log_samples_per_batch: int = 1, @@ -340,6 +278,9 @@ def __init__( self.decoder = decoder self.latent_dim = latent_dim self.beta = beta + self.beta_schedule = beta_schedule + self.beta_min = beta_min + self.beta_warmup_epochs = beta_warmup_epochs self.lr = lr self.log_batches_per_epoch = log_batches_per_epoch self.log_samples_per_batch = log_samples_per_batch @@ -348,17 +289,14 @@ def __init__( self.compute_disentanglement = compute_disentanglement self.disentanglement_frequency = disentanglement_frequency - # Create the Beta-VAE model - self.model = BetaVAE( - encoder=self.encoder, decoder=self.decoder, latent_dim=latent_dim, beta=beta - ) + # Store model components directly (no separate BetaVAE class) # Initialize tracking lists self.training_step_outputs = [] self.validation_step_outputs = [] # Enhanced β-VAE logging - initialize early - self.vae_logger = BetaVaeLogger(latent_dim=latent_dim) + self.vae_logger = BetaVaeLogger(latent_dim=latent_dim, device="cuda") # Note: DisentanglementMetrics will be initialized in setup() when device is available self.disentanglement_metrics = None @@ -370,9 +308,76 @@ def setup(self, stage: str = None): if self.disentanglement_metrics is None: self.disentanglement_metrics = DisentanglementMetrics(device=self.device) + def _get_current_beta(self) -> float: + """Get current beta value based on scheduling.""" + if self.beta_schedule is None: + return self.beta + + epoch = self.current_epoch + + if self.beta_schedule == "linear": + # Linear warmup from beta_min to beta + if epoch < self.beta_warmup_epochs: + return ( + self.beta_min + + (self.beta - self.beta_min) * epoch / self.beta_warmup_epochs + ) + else: + return self.beta + + elif self.beta_schedule == "cosine": + # Cosine warmup from beta_min to beta + if epoch < self.beta_warmup_epochs: + import math + + progress = epoch / self.beta_warmup_epochs + return self.beta_min + (self.beta - self.beta_min) * 0.5 * ( + 1 + math.cos(math.pi * (1 - progress)) + ) + 
else: + return self.beta + + elif self.beta_schedule == "warmup": + # Keep beta_min for warmup epochs, then jump to beta + return self.beta_min if epoch < self.beta_warmup_epochs else self.beta + + else: + return self.beta + + def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor: + """Reparameterization trick: sample from N(mu, var) using N(0,1).""" + std = torch.exp(0.5 * logvar) + eps = torch.randn_like(std) + return mu + eps * std + def forward(self, x: Tensor) -> dict: - """Forward pass through VAE model.""" - return self.model(x) + """Forward pass through Beta-VAE.""" + # Encode + encoder_output = self.encoder(x) + mu = encoder_output.embedding + logvar = encoder_output.log_covariance + + # Reparameterize + z = self.reparameterize(mu, logvar) + + # Decode + reconstruction = self.decoder(z) + + # Compute losses with current beta (allows for scheduling) + current_beta = self._get_current_beta() + recon_loss = F.mse_loss(reconstruction, x, reduction="mean") + kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / x.size(0) + total_loss = recon_loss + current_beta * kl_loss + + return { + "recon_x": reconstruction, + "z": z, + "mu": mu, + "logvar": logvar, + "recon_loss": recon_loss, + "kl_loss": kl_loss, + "loss": total_loss, + } def training_step(self, batch: TripletSample, batch_idx: int) -> Tensor: """Training step with VAE loss computation.""" diff --git a/viscy/representation/vae.py b/viscy/representation/vae.py index 6078a1769..e297fa9d6 100644 --- a/viscy/representation/vae.py +++ b/viscy/representation/vae.py @@ -4,13 +4,11 @@ import timm from monai.networks.blocks import ResidualUnit, UpSample from monai.networks.blocks.dynunet_block import get_conv_layer -from pythae.models.nn import BaseDecoder, BaseEncoder from torch import Tensor, nn from viscy.unet.networks.unext2 import ( PixelToVoxelHead, StemDepthtoChannels, - UNeXt2Stem, ) @@ -172,9 +170,7 @@ def __init__( latent_spatial_size: int = 8, head_expansion_ratio: int = 4, head_pool: bool = False, - upsample_mode: Literal[ - "deconv", "pixelshuffle" - ] = "pixelshuffle", # Better quality + upsample_mode: Literal["deconv", "pixelshuffle"] = "pixelshuffle", conv_blocks: int = 2, norm_name: str = "batch", upsample_pre_conv: Literal["default"] | Callable | None = None, diff --git a/viscy/representation/vae_logging.py b/viscy/representation/vae_logging.py index 44fdf810b..5bcdf5bf7 100644 --- a/viscy/representation/vae_logging.py +++ b/viscy/representation/vae_logging.py @@ -1,10 +1,9 @@ import io import logging -from typing import Dict, List, Optional, Tuple +from typing import Optional, Tuple import matplotlib.pyplot as plt import numpy as np -import seaborn as sns import torch import torch.nn.functional as F from PIL import Image @@ -12,6 +11,8 @@ from sklearn.manifold import TSNE from torchvision.utils import make_grid +from viscy.representation.disentanglement_metrics import DisentanglementMetrics + _logger = logging.getLogger(__name__) @@ -23,8 +24,9 @@ class BetaVaeLogger: and latent space analysis for microscopy data. 
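 
     The logger also owns a DisentanglementMetrics evaluator so that periodic
     metric computation and TensorBoard logging share one device-aware
     instance.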
""" - def __init__(self, latent_dim: int = 128): + def __init__(self, latent_dim: int = 128, device: str = "cuda"): self.latent_dim = latent_dim + self.disentanglement_metrics = DisentanglementMetrics(device=device) def log_enhanced_metrics( self, lightning_module, model_output: dict, batch: dict, stage: str = "train" @@ -43,9 +45,9 @@ def log_enhanced_metrics( # Handle both Pythae dict format and object format if isinstance(model_output, dict): z = model_output["z"] - recon_x = model_output["recon_x"] + recon_x = model_output["recon_x"] recon_loss = model_output["recon_loss"] - kl_loss = model_output["reg_loss"] # Pythae uses 'reg_loss' for KL + kl_loss = model_output["kl_loss"] else: z = model_output.z if hasattr(model_output, "z") else model_output.embedding recon_x = ( @@ -56,12 +58,8 @@ def log_enhanced_metrics( recon_loss = model_output.recon_loss kl_loss = model_output.kl_loss - # Get β from model config - beta = ( - lightning_module.model.model_config.beta - if hasattr(lightning_module.model, "model_config") - else 1.0 - ) + # Get β directly from lightning module + beta = getattr(lightning_module, "beta", 1.0) # 1. Core VAE Loss Components (organized in one TensorBoard group) total_loss = recon_loss + beta * kl_loss @@ -74,7 +72,6 @@ def log_enhanced_metrics( f"loss_components/weighted_kl_loss/{stage}": beta * kl_loss, f"loss_components/total_loss/{stage}": total_loss, f"loss_components/beta_value/{stage}": beta, - # Loss analysis ratios f"loss_analysis/kl_recon_ratio/{stage}": kl_recon_ratio, f"loss_analysis/recon_contribution/{stage}": recon_loss / total_loss, @@ -139,7 +136,7 @@ def _log_latent_histograms(self, lightning_module, z: torch.Tensor, stage: str): for i in range(n_dims_to_log): lightning_module.logger.experiment.add_histogram( - f"latent_dim_{i}_distribution/{stage}", + f"latent_distributions/dim_{i}_{stage}", z_np[:, i], lightning_module.current_epoch, ) @@ -177,17 +174,8 @@ def log_latent_traversal( z_modified = z_base.clone() z_modified[0, dim] = val - # Generate reconstruction - decoder_output = lightning_module.model.decoder(z_modified) - # Handle both Pythae dict format and object format - if isinstance(decoder_output, dict): - recon = decoder_output["reconstruction"] - else: - recon = ( - decoder_output.reconstruction - if hasattr(decoder_output, "reconstruction") - else decoder_output - ) + # Generate reconstruction using lightning module's decoder + recon = lightning_module.decoder(z_modified) # Take middle z-slice for visualization mid_z = recon.shape[2] // 2 @@ -237,17 +225,8 @@ def log_latent_interpolation( for alpha in np.linspace(0, 1, n_steps): z_interp = alpha * z1 + (1 - alpha) * z2 - # Generate reconstruction - decoder_output = lightning_module.model.decoder(z_interp) - # Handle both Pythae dict format and object format - if isinstance(decoder_output, dict): - recon = decoder_output["reconstruction"] - else: - recon = ( - decoder_output.reconstruction - if hasattr(decoder_output, "reconstruction") - else decoder_output - ) + # Generate reconstruction using lightning module's decoder + recon = lightning_module.decoder(z_interp) # Take middle z-slice for visualization mid_z = recon.shape[2] // 2 @@ -302,17 +281,8 @@ def log_factor_traversal_matrix( z_mod = z_base.clone() z_mod[0, dim] = val - # Generate reconstruction - decoder_output = lightning_module.model.decoder(z_mod) - # Handle both Pythae dict format and object format - if isinstance(decoder_output, dict): - recon = decoder_output["reconstruction"] - else: - recon = ( - 
decoder_output.reconstruction - if hasattr(decoder_output, "reconstruction") - else decoder_output - ) + # Generate reconstruction using lightning module's decoder + recon = lightning_module.decoder(z_mod) # Take middle z-slice for visualization mid_z = recon.shape[2] // 2 @@ -446,3 +416,60 @@ def log_beta_schedule( lightning_module.log("beta_schedule", beta) return beta + + def log_disentanglement_metrics( + self, + lightning_module, + dataloader: torch.utils.data.DataLoader, + max_samples: int = 500, + ): + """ + Log disentanglement metrics to TensorBoard every 10 epochs. + + Args: + lightning_module: Lightning module instance + dataloader: DataLoader for evaluation + max_samples: Maximum samples to use for evaluation + """ + # Only compute every 10 epochs to save compute + if lightning_module.current_epoch % 10 != 0: + return + + _logger.info( + f"Computing disentanglement metrics at epoch {lightning_module.current_epoch}" + ) + + try: + # Use the lightning module directly (no separate model attribute after refactoring) + vae_model = lightning_module + + # Compute all disentanglement metrics + metrics = self.disentanglement_metrics.compute_all_metrics( + vae_model=vae_model, dataloader=dataloader, max_samples=max_samples + ) + + # Log metrics with organized naming + tensorboard_metrics = {} + for metric_name, value in metrics.items(): + tensorboard_metrics[f"disentanglement_metrics/{metric_name}"] = value + + lightning_module.log_dict( + tensorboard_metrics, + on_step=False, + on_epoch=True, + logger=True, + sync_dist=True, + ) + + _logger.info(f"Logged disentanglement metrics: {metrics}") + + except Exception as e: + _logger.warning(f"Failed to compute disentanglement metrics: {e}") + # Log a placeholder to indicate the attempt + lightning_module.log( + "disentanglement_metrics/computation_failed", + 1.0, + on_step=False, + on_epoch=True, + logger=True, + ) From 29a822ef1ecdca178fe70b2cc62bf5a72eeaa811 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Mon, 21 Jul 2025 15:59:44 -0700 Subject: [PATCH 011/101] condense the logging to have less tabs. --- viscy/representation/vae_logging.py | 60 ++++++++++++----------------- 1 file changed, 24 insertions(+), 36 deletions(-) diff --git a/viscy/representation/vae_logging.py b/viscy/representation/vae_logging.py index 5bcdf5bf7..64d8f4e46 100644 --- a/viscy/representation/vae_logging.py +++ b/viscy/representation/vae_logging.py @@ -61,56 +61,44 @@ def log_enhanced_metrics( # Get β directly from lightning module beta = getattr(lightning_module, "beta", 1.0) - # 1. 
Core VAE Loss Components (organized in one TensorBoard group) + # Record losses and reconstruction quality metrics total_loss = recon_loss + beta * kl_loss kl_recon_ratio = kl_loss / (recon_loss + 1e-8) + mse_loss = F.mse_loss(recon_x, x) + mae_loss = F.l1_loss(recon_x, x) + metrics = { - # Core loss components - f"loss_components/reconstruction_loss/{stage}": recon_loss, - f"loss_components/kl_loss/{stage}": kl_loss, - f"loss_components/weighted_kl_loss/{stage}": beta * kl_loss, - f"loss_components/total_loss/{stage}": total_loss, - f"loss_components/beta_value/{stage}": beta, - # Loss analysis ratios - f"loss_analysis/kl_recon_ratio/{stage}": kl_recon_ratio, - f"loss_analysis/recon_contribution/{stage}": recon_loss / total_loss, + # All losses in one consolidated group + f"losses/reconstruction/{stage}": recon_loss, + f"losses/kl/{stage}": kl_loss, + f"losses/weighted_kl/{stage}": beta * kl_loss, + f"losses/total/{stage}": total_loss, + f"losses/mse/{stage}": mse_loss, + f"losses/mae/{stage}": mae_loss, + f"losses/beta_value/{stage}": beta, + f"losses/kl_recon_ratio/{stage}": kl_recon_ratio, + f"losses/recon_contribution/{stage}": recon_loss / total_loss, } - # 2. Latent space statistics + # Latent space statistics latent_mean = torch.mean(z, dim=0) latent_std = torch.std(z, dim=0) - metrics.update( - { - f"latent_stats/mean_avg/{stage}": torch.mean(latent_mean), - f"latent_stats/std_avg/{stage}": torch.mean(latent_std), - f"latent_stats/mean_max/{stage}": torch.max(latent_mean), - f"latent_stats/std_max/{stage}": torch.max(latent_std), - } - ) - - # 3. Reconstruction quality metrics - mse_loss = F.mse_loss(recon_x, x) - mae_loss = F.l1_loss(recon_x, x) - - metrics.update( - { - f"reconstruction_quality/mse/{stage}": mse_loss, - f"reconstruction_quality/mae/{stage}": mae_loss, - } - ) - - # 4. Latent capacity metrics active_dims = torch.sum(torch.var(z, dim=0) > 0.01) variances = torch.var(z, dim=0) effective_dim = torch.sum(variances) ** 2 / torch.sum(variances**2) metrics.update( { - f"latent_capacity/active_dims/{stage}": active_dims, - f"latent_capacity/effective_dim/{stage}": effective_dim, - f"latent_capacity/utilization/{stage}": active_dims / self.latent_dim, + # Consolidated latent statistics + f"latent_statistics/mean_avg/{stage}": torch.mean(latent_mean), + f"latent_statistics/std_avg/{stage}": torch.mean(latent_std), + f"latent_statistics/mean_max/{stage}": torch.max(latent_mean), + f"latent_statistics/std_max/{stage}": torch.max(latent_std), + f"latent_statistics/active_dims/{stage}": active_dims, + f"latent_statistics/effective_dim/{stage}": effective_dim, + f"latent_statistics/utilization/{stage}": active_dims / self.latent_dim, } ) @@ -123,7 +111,7 @@ def log_enhanced_metrics( sync_dist=True, ) - # 5. 
Log latent dimension histograms (periodically)
+        # Log latent dimension histograms (periodically)
         if stage == "val" and lightning_module.current_epoch % 10 == 0:
             self._log_latent_histograms(lightning_module, z, stage)
 

From 86b3467359b0f862054bcdcc0dc8d9f479755672 Mon Sep 17 00:00:00 2001
From: Eduardo Hirata-Miyasaki
Date: Mon, 21 Jul 2025 16:50:58 -0700
Subject: [PATCH 012/101] fix disentangle metrics

---
 .../benchmarking/DynaCLR/BetaVAE/test_run.py  | 10 ++++++-
 viscy/representation/engine.py                | 28 +++++++++++--------
 2 files changed, 25 insertions(+), 13 deletions(-)

diff --git a/applications/benchmarking/DynaCLR/BetaVAE/test_run.py b/applications/benchmarking/DynaCLR/BetaVAE/test_run.py
index d45881ea8..aa4b5abb2 100644
--- a/applications/benchmarking/DynaCLR/BetaVAE/test_run.py
+++ b/applications/benchmarking/DynaCLR/BetaVAE/test_run.py
@@ -178,9 +178,17 @@ def channel_normalization(
         return_negative=False,
         fit_include_wells=["B/3", "B/4", "C/3", "C/4"],
     )
+    dm.setup("fit")
     print(f"DataModule created successfully")
+    train_size = len(dm.train_dataset)
+    val_size = len(dm.val_dataset)
+    batches_per_epoch = train_size // batch_size
 
-    # Create trainer
+    print(f"Training samples: {train_size:,}")
+    print(f"Validation samples: {val_size:,}")
+    print(f"Batches per epoch: {batches_per_epoch:,}")
+
+    # # Create trainer
     trainer = VisCyTrainer(
         accelerator="gpu",
         strategy="ddp",
diff --git a/viscy/representation/engine.py b/viscy/representation/engine.py
index 5bf5d9336..4560e1751 100644
--- a/viscy/representation/engine.py
+++ b/viscy/representation/engine.py
@@ -528,12 +528,14 @@ def _compute_and_log_disentanglement_metrics(self):
                 )
                 return
 
-            # Get validation dataloader
-            val_dataloader = (
-                self.trainer.val_dataloaders[0]
-                if self.trainer.val_dataloaders
-                else None
-            )
+            # Get validation dataloader - handle both single DataLoader and list cases
+            val_dataloaders = self.trainer.val_dataloaders
+            if val_dataloaders is None:
+                val_dataloader = None
+            elif isinstance(val_dataloaders, list):
+                val_dataloader = val_dataloaders[0] if val_dataloaders else None
+            else:
+                val_dataloader = val_dataloaders
 
             if val_dataloader is None:
                 _logger.warning(
@@ -571,12 +573,14 @@ def _log_enhanced_visualizations(self):
         """Log enhanced β-VAE visualizations."""
         try:
-            # Get validation dataloader
-            val_dataloader = (
-                self.trainer.val_dataloaders[0]
-                if self.trainer.val_dataloaders
-                else None
-            )
+            # Get validation dataloader - handle both single DataLoader and list cases
+            val_dataloaders = self.trainer.val_dataloaders
+            if val_dataloaders is None:
+                val_dataloader = None
+            elif isinstance(val_dataloaders, list):
+                val_dataloader = val_dataloaders[0] if val_dataloaders else None
+            else:
+                val_dataloader = val_dataloaders
 
             if val_dataloader is None:
                 _logger.warning("No validation dataloader available for visualizations")

From 2bb6d19bc7c804aef62ee058b52378c25d1b9924 Mon Sep 17 00:00:00 2001
From: Eduardo Hirata-Miyasaki
Date: Mon, 21 Jul 2025 17:54:06 -0700
Subject: [PATCH 013/101] fixing beta warmup bug

---
 .../benchmarking/DynaCLR/BetaVAE/test_run.py  |  8 +++---
 viscy/representation/engine.py                | 25 ++-----------------
 viscy/representation/vae_logging.py           |  8 ++++--
 3 files changed, 12 insertions(+), 29 deletions(-)

diff --git a/applications/benchmarking/DynaCLR/BetaVAE/test_run.py b/applications/benchmarking/DynaCLR/BetaVAE/test_run.py
index aa4b5abb2..336b7c369 100644
--- a/applications/benchmarking/DynaCLR/BetaVAE/test_run.py
+++ 
b/applications/benchmarking/DynaCLR/BetaVAE/test_run.py @@ -144,10 +144,10 @@ def channel_normalization( decoder=decoder, example_input_array_shape=(1, 1, z_stack_depth, 192, 192), latent_dim=256, - beta=0.5, + beta=1.5, lr=2e-4, beta_schedule="linear", - beta_min=0.1, + beta_min=0.5, beta_warmup_epochs=15, ) print(f"VaeModule created successfully") @@ -201,8 +201,8 @@ def channel_normalization( check_val_every_n_epoch=1, logger=TensorBoardLogger( save_dir="/hpc/projects/organelle_phenotyping/models/SEC61B/vae", - name="betavae_phase3D_ddp", - version="beta_0.5_16slice", + name="betavae_phase3D", + version="beta_1.5_16slice", ), callbacks=[ LearningRateMonitor(logging_interval="step"), diff --git a/viscy/representation/engine.py b/viscy/representation/engine.py index 4560e1751..9552933e0 100644 --- a/viscy/representation/engine.py +++ b/viscy/representation/engine.py @@ -289,22 +289,17 @@ def __init__( self.compute_disentanglement = compute_disentanglement self.disentanglement_frequency = disentanglement_frequency - # Store model components directly (no separate BetaVAE class) - - # Initialize tracking lists self.training_step_outputs = [] self.validation_step_outputs = [] - # Enhanced β-VAE logging - initialize early self.vae_logger = BetaVaeLogger(latent_dim=latent_dim, device="cuda") - # Note: DisentanglementMetrics will be initialized in setup() when device is available self.disentanglement_metrics = None def setup(self, stage: str = None): """Setup hook to initialize device-dependent components.""" super().setup(stage) - # Initialize DisentanglementMetrics after device is available + if self.disentanglement_metrics is None: self.disentanglement_metrics = DisentanglementMetrics(device=self.device) @@ -363,7 +358,7 @@ def forward(self, x: Tensor) -> dict: # Decode reconstruction = self.decoder(z) - # Compute losses with current beta (allows for scheduling) + # Compute losses with current beta current_beta = self._get_current_beta() recon_loss = F.mse_loss(reconstruction, x, reduction="mean") kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / x.size(0) @@ -387,14 +382,6 @@ def training_step(self, batch: TripletSample, batch_idx: int) -> Tensor: # Beta-VAE computes loss internally loss = model_output["loss"] - # Log basic metrics - self._log_metrics( - loss=loss, - recon_loss=model_output["recon_loss"], - kl_loss=model_output["kl_loss"], - stage="train", - ) - # Log enhanced β-VAE metrics self.vae_logger.log_enhanced_metrics( lightning_module=self, model_output=model_output, batch=batch, stage="train" @@ -413,14 +400,6 @@ def validation_step(self, batch: TripletSample, batch_idx: int) -> Tensor: # Beta-VAE computes loss internally loss = model_output["loss"] - # Log basic metrics - self._log_metrics( - loss=loss, - recon_loss=model_output["recon_loss"], - kl_loss=model_output["kl_loss"], - stage="val", - ) - # Log enhanced β-VAE metrics self.vae_logger.log_enhanced_metrics( lightning_module=self, model_output=model_output, batch=batch, stage="val" diff --git a/viscy/representation/vae_logging.py b/viscy/representation/vae_logging.py index 64d8f4e46..a9f3eefaf 100644 --- a/viscy/representation/vae_logging.py +++ b/viscy/representation/vae_logging.py @@ -58,8 +58,12 @@ def log_enhanced_metrics( recon_loss = model_output.recon_loss kl_loss = model_output.kl_loss - # Get β directly from lightning module - beta = getattr(lightning_module, "beta", 1.0) + # Get current β (scheduled value, not static) + beta = getattr( + lightning_module, + "_get_current_beta", + lambda: 
getattr(lightning_module, "beta", 1.0), + )() # Record losses and reconstruction quality metrics total_loss = recon_loss + beta * kl_loss From 8e8eba8211a2509fe5511349d5bdc68ad74d6ff6 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Mon, 21 Jul 2025 22:37:27 -0700 Subject: [PATCH 014/101] renaming to loss --- .../benchmarking/DynaCLR/BetaVAE/test_run.py | 5 ++++- viscy/representation/vae_logging.py | 18 +++++++++--------- 2 files changed, 13 insertions(+), 10 deletions(-) diff --git a/applications/benchmarking/DynaCLR/BetaVAE/test_run.py b/applications/benchmarking/DynaCLR/BetaVAE/test_run.py index 336b7c369..d1c0290b1 100644 --- a/applications/benchmarking/DynaCLR/BetaVAE/test_run.py +++ b/applications/benchmarking/DynaCLR/BetaVAE/test_run.py @@ -207,7 +207,10 @@ def channel_normalization( callbacks=[ LearningRateMonitor(logging_interval="step"), ModelCheckpoint( - monitor="loss/val", save_top_k=5, save_last=True, every_n_epochs=1 + monitor="loss/total/val", + save_top_k=5, + save_last=True, + every_n_epochs=1, ), ], use_distributed_sampler=True, diff --git a/viscy/representation/vae_logging.py b/viscy/representation/vae_logging.py index a9f3eefaf..9e103d0cb 100644 --- a/viscy/representation/vae_logging.py +++ b/viscy/representation/vae_logging.py @@ -74,15 +74,15 @@ def log_enhanced_metrics( metrics = { # All losses in one consolidated group - f"losses/reconstruction/{stage}": recon_loss, - f"losses/kl/{stage}": kl_loss, - f"losses/weighted_kl/{stage}": beta * kl_loss, - f"losses/total/{stage}": total_loss, - f"losses/mse/{stage}": mse_loss, - f"losses/mae/{stage}": mae_loss, - f"losses/beta_value/{stage}": beta, - f"losses/kl_recon_ratio/{stage}": kl_recon_ratio, - f"losses/recon_contribution/{stage}": recon_loss / total_loss, + f"loss/total/{stage}": total_loss, + f"loss/reconstruction/{stage}": recon_loss, + f"loss/kl/{stage}": kl_loss, + f"loss/weighted_kl/{stage}": beta * kl_loss, + f"loss/mse/{stage}": mse_loss, + f"loss/mae/{stage}": mae_loss, + f"loss/beta_value/{stage}": beta, + f"loss/kl_recon_ratio/{stage}": kl_recon_ratio, + f"loss/recon_contribution/{stage}": recon_loss / total_loss, } # Latent space statistics From 116183dce7061d63577d60eeab44480940c2ac9d Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Wed, 23 Jul 2025 14:55:04 -0700 Subject: [PATCH 015/101] updating architecture to flatten vs spatial VAE with convs --- .../DynaCLR/BetaVAE/debug_dimensions.py | 194 +++++++++++++----- viscy/representation/engine.py | 16 +- viscy/representation/vae.py | 100 ++++++--- viscy/representation/vae_logging.py | 78 ++++++- 4 files changed, 297 insertions(+), 91 deletions(-) diff --git a/applications/benchmarking/DynaCLR/BetaVAE/debug_dimensions.py b/applications/benchmarking/DynaCLR/BetaVAE/debug_dimensions.py index 70be8f8a8..aacedfe88 100644 --- a/applications/benchmarking/DynaCLR/BetaVAE/debug_dimensions.py +++ b/applications/benchmarking/DynaCLR/BetaVAE/debug_dimensions.py @@ -6,17 +6,15 @@ def debug_vae_dimensions(): """Debug VAE encoder/decoder dimension compatibility.""" - print("=== VAE Dimension Debugging ===\n") + print("=== VAE Dimension Debugging (Updated Architecture) ===\n") - # Configuration from test_run.py - z_stack_depth = 32 - input_shape = (1, 1, z_stack_depth, 192, 192) - latent_dim = 256 - latent_spatial_size = 3 + # Configuration matching current config + z_stack_depth = 16 + input_shape = (1, 1, z_stack_depth, 192, 192) # 1 channel to match config + latent_dim = 1024 # Updated to new default print(f"Input shape: {input_shape}") 
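     # Illustrative aside (hypothetical numbers, assuming ResNet50's usual 32x
     # total spatial downsampling): a 192x192 input reaches 6x6 at the deepest
     # stage, i.e. 2048 * 6 * 6 = 73,728 flattened features per sample.
     print(f"Assumed deepest spatial map: {input_shape[-1] // 32}x{input_shape[-1] // 32}")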
print(f"Expected latent dim: {latent_dim}") - print(f"Expected latent spatial size: {latent_spatial_size}") print() # Create encoder @@ -24,9 +22,9 @@ def debug_vae_dimensions(): backbone="resnet50", in_channels=1, in_stack_depth=z_stack_depth, - embedding_dim=latent_dim, - stem_kernel_size=(8, 4, 4), - stem_stride=(8, 4, 4), + latent_dim=latent_dim, + stem_kernel_size=(4, 2, 2), + stem_stride=(4, 2, 2), ) # Create decoder @@ -35,14 +33,12 @@ def debug_vae_dimensions(): latent_dim=latent_dim, out_channels=1, out_stack_depth=z_stack_depth, - latent_spatial_size=latent_spatial_size, - head_expansion_ratio=1, + head_expansion_ratio=2, head_pool=False, - upsample_mode="deconv", + upsample_mode="pixelshuffle", conv_blocks=2, norm_name="batch", - upsample_pre_conv=None, - strides=[2, 2, 2, 2], + strides=[2, 2, 2, 1], ) print("=== ENCODER FORWARD PASS ===") @@ -53,23 +49,22 @@ def debug_vae_dimensions(): try: # Step through encoder - print("\n1. Stem processing:") + print("\\n1. Stem processing:") x_stem = encoder.stem(x) print(f" After stem: {x_stem.shape}") - print("\n2. Backbone processing:") + print("\\n2. Backbone processing:") features = encoder.encoder(x_stem) for i, feat in enumerate(features): print(f" Feature {i}: {feat.shape}") - print("\n3. Final processing:") + print("\\n3. Final processing:") x_final = features[-1] print(f" Final features: {x_final.shape}") - x_pooled = encoder.global_pool(x_final) - print(f" After global pool: {x_pooled.shape}") - - x_flat = x_pooled.flatten(1) + # Flatten spatial dimensions (new approach) + batch_size = x_final.size(0) + x_flat = x_final.view(batch_size, -1) print(f" After flatten: {x_flat.shape}") # Full encoder output @@ -79,46 +74,59 @@ def debug_vae_dimensions(): print(f" Final mu: {mu.shape}") print(f" Final logvar: {logvar.shape}") - print("\n=== DECODER FORWARD PASS ===") + print("\\n=== DECODER FORWARD PASS ===") # Test decoder with latent vector z = torch.randn(1, latent_dim) print(f"Input to decoder: {z.shape}") - print("\n1. Latent projection:") - x_proj = decoder.latent_proj(z) - print(f" After projection: {x_proj.shape}") - - x_reshaped = x_proj.view(1, -1, latent_spatial_size, latent_spatial_size) - print(f" After reshape: {x_reshaped.shape}") - - print("\n2. Decoder stages:") - x_current = x_reshaped + print("\\n1. Reshape to spatial:") + batch_size = z.size(0) + z_spatial = decoder.latent_reshape(z) + print(f" After linear reshape: {z_spatial.shape}") + + z_spatial_reshaped = z_spatial.view( + batch_size, + decoder.spatial_channels, + decoder.spatial_size, + decoder.spatial_size, + ) + print(f" After view to spatial: {z_spatial_reshaped.shape}") + + print("\\n2. Latent projection:") + x_proj = decoder.latent_proj(z_spatial_reshaped) + print(f" After conv projection: {x_proj.shape}") + + print("\\n3. Decoder stages:") + x_current = x_proj for i, stage in enumerate(decoder.decoder_stages): x_current = stage(x_current) print(f" After stage {i}: {x_current.shape}") - print("\n3. Head processing:") + print("\\n4. 
Head processing:") final_output = decoder.head(x_current) print(f" Final output: {final_output.shape}") - # Full decoder output - decoder_output = decoder(z) - reconstruction = decoder_output["reconstruction"] + # Full decoder output (now returns tensor directly, not dict) + reconstruction = decoder(z) print(f" Full reconstruction: {reconstruction.shape}") - print("\n=== DIMENSION ANALYSIS ===") + print("\\n=== DIMENSION ANALYSIS ===") print(f"✓ Encoder input: {input_shape}") print(f"✓ Encoder output: {mu.shape}") print(f"✓ Decoder input: {z.shape}") print(f"✓ Decoder output: {reconstruction.shape}") - # Calculate tensor sizes + # Calculate tensor sizes and compression ratio input_size = torch.numel(x) + latent_size = torch.numel(mu) recon_size = torch.numel(reconstruction) - print(f" Input tensor size: {input_size}") - print(f" Reconstruction tensor size: {recon_size}") - print(f" Size ratio: {recon_size / input_size:.2f}") + + print(f" Input tensor size: {input_size:,}") + print(f" Latent tensor size: {latent_size:,}") + print(f" Reconstruction tensor size: {recon_size:,}") + print(f" Compression ratio: {input_size / latent_size:.1f}:1") + print(f" Size ratio (recon/input): {recon_size / input_size:.2f}") # Check if reconstruction matches input if reconstruction.shape == x.shape: @@ -137,6 +145,89 @@ def debug_vae_dimensions(): f" Dimension {i}: {inp_dim} → {recon_dim} (factor: {recon_dim/inp_dim:.2f})" ) + print("\\n=== VAE LOSS COMPUTATION TEST ===") + + # Simulate full VAE forward pass with loss computation + print("Testing full VAE forward pass with loss computation...") + + # Sample from latent distribution (reparameterization trick) + eps = torch.randn_like(mu) + z_sampled = mu + torch.exp(0.5 * logvar) * eps + print(f"Sampled latent z: {z_sampled.shape}") + + # Decode the sampled latent + reconstruction_from_sampled = decoder(z_sampled) + print(f"Reconstruction from sampled z: {reconstruction_from_sampled.shape}") + + # Compute VAE losses + import torch.nn.functional as F + + # Reconstruction loss (MSE) + recon_loss = F.mse_loss(reconstruction_from_sampled, x, reduction="mean") + print(f"Reconstruction loss (MSE): {recon_loss.item():.6e}") + + # KL divergence loss + kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / x.size(0) + print(f"KL divergence loss: {kl_loss.item():.6e}") + + # Total VAE loss with different beta values + betas = [0.1, 1.0, 1.5, 4.0] + for beta in betas: + total_loss = recon_loss + beta * kl_loss + print(f"Total loss (β={beta}): {total_loss.item():.6e}") + + # Check for problematic values + print("\\n=== LOSS HEALTH CHECK ===") + + if torch.isnan(recon_loss): + print("✗ CRITICAL: Reconstruction loss is NaN!") + elif torch.isinf(recon_loss): + print("✗ CRITICAL: Reconstruction loss is Inf!") + elif recon_loss.item() > 1e6: + print(f"⚠ WARNING: Very high reconstruction loss: {recon_loss.item():.2e}") + elif recon_loss.item() < 1e-10: + print(f"⚠ WARNING: Very low reconstruction loss: {recon_loss.item():.2e}") + else: + print(f"✓ Reconstruction loss looks reasonable: {recon_loss.item():.6f}") + + if torch.isnan(kl_loss): + print("✗ CRITICAL: KL loss is NaN!") + elif torch.isinf(kl_loss): + print("✗ CRITICAL: KL loss is Inf!") + else: + print(f"✓ KL loss looks reasonable: {kl_loss.item():.6f}") + + # Check reconstruction value ranges + recon_min, recon_max = ( + reconstruction_from_sampled.min(), + reconstruction_from_sampled.max(), + ) + input_min, input_max = x.min(), x.max() + + print(f"\\nValue ranges:") + print(f" Input range: 
[{input_min.item():.3f}, {input_max.item():.3f}]") + print( + f" Reconstruction range: [{recon_min.item():.3f}, {recon_max.item():.3f}]" + ) + + if recon_max.item() > 100 or recon_min.item() < -100: + print( + "⚠ WARNING: Reconstruction values are very large - possible gradient explosion" + ) + + # Check latent statistics + mu_mean, mu_std = mu.mean(), mu.std() + logvar_mean, logvar_std = logvar.mean(), logvar.std() + + print(f"\\nLatent statistics:") + print(f" μ mean/std: {mu_mean.item():.3f} / {mu_std.item():.3f}") + print(f" log(σ²) mean/std: {logvar_mean.item():.3f} / {logvar_std.item():.3f}") + + if mu_std.item() > 10: + print("⚠ WARNING: μ has very high variance - possible gradient explosion") + if logvar_mean.item() > 10: + print("⚠ WARNING: log(σ²) is very large - possible numerical instability") + except Exception as e: print(f"✗ ERROR during forward pass: {e}") print(f"Error type: {type(e).__name__}") @@ -144,24 +235,23 @@ def debug_vae_dimensions(): traceback.print_exc() - # Let's check what spatial size the encoder actually produces - print("\n=== ENCODER SPATIAL SIZE ANALYSIS ===") + # Check flattened feature size for new architecture + print("\\n=== ENCODER FLATTENED SIZE ANALYSIS ===") try: x_stem = encoder.stem(x) features = encoder.encoder(x_stem) final_feat = features[-1] - actual_spatial_size = final_feat.shape[-1] # Assuming square - print(f"Actual spatial size from encoder: {actual_spatial_size}") - print(f"Expected spatial size for decoder: {latent_spatial_size}") + print(f"Final feature shape: {final_feat.shape}") + + flattened_size = final_feat.view(1, -1).shape[1] + print(f"Flattened size: {flattened_size:,}") + print(f"Expected latent dim: {latent_dim:,}") - if actual_spatial_size != latent_spatial_size: - print( - f"✗ MISMATCH: Encoder produces {actual_spatial_size}x{actual_spatial_size}, decoder expects {latent_spatial_size}x{latent_spatial_size}" - ) - print(f" Suggested fix: Set latent_spatial_size={actual_spatial_size}") + compression_ratio = flattened_size / latent_dim + print(f"Compression ratio: {compression_ratio:.1f}:1") except Exception as inner_e: - print(f"Error in spatial size analysis: {inner_e}") + print(f"Error in flattened size analysis: {inner_e}") if __name__ == "__main__": diff --git a/viscy/representation/engine.py b/viscy/representation/engine.py index 9552933e0..56213126b 100644 --- a/viscy/representation/engine.py +++ b/viscy/representation/engine.py @@ -257,9 +257,8 @@ def __init__( self, encoder: VaeEncoder, decoder: VaeDecoder, - latent_dim: int = 128, beta: float = 1.0, - beta_schedule: str = None, # "linear", "cosine", "warmup", or None + beta_schedule: Literal["linear", "cosine", "warmup"] | None = None, beta_min: float = 0.1, beta_warmup_epochs: int = 50, lr: float = 1e-3, @@ -276,7 +275,16 @@ def __init__( self.encoder = encoder self.decoder = decoder - self.latent_dim = latent_dim + + # Infer latent dimension from encoder and validate decoder matches + self.latent_dim = encoder.latent_dim + + # Validate that decoder's latent_dim matches encoder's embedding_dim + if hasattr(decoder, "latent_dim") and decoder.latent_dim != self.latent_dim: + raise ValueError( + f"Encoder embedding_dim ({self.latent_dim}) must match " + f"decoder latent_dim ({decoder.latent_dim})" + ) self.beta = beta self.beta_schedule = beta_schedule self.beta_min = beta_min @@ -292,7 +300,7 @@ def __init__( self.training_step_outputs = [] self.validation_step_outputs = [] - self.vae_logger = BetaVaeLogger(latent_dim=latent_dim, device="cuda") + self.vae_logger = 
BetaVaeLogger(latent_dim=self.latent_dim, device="cuda") self.disentanglement_metrics = None diff --git a/viscy/representation/vae.py b/viscy/representation/vae.py index e297fa9d6..0bc548a3f 100644 --- a/viscy/representation/vae.py +++ b/viscy/representation/vae.py @@ -2,6 +2,7 @@ from typing import Callable, Literal import timm +import torch from monai.networks.blocks import ResidualUnit, UpSample from monai.networks.blocks.dynunet_block import get_conv_layer from torch import Tensor, nn @@ -96,19 +97,24 @@ def forward(self, inp: Tensor) -> Tensor: class VaeEncoder(nn.Module): """VAE encoder for microscopy data with 3D to 2D conversion.""" + # TODO: roll back the Conv2d to AveragePooling and linear layer to global pooling + # TODO: embedding dim + # TODO: check the OG VAE compression rate + # TODO do log grid search for the best embedding dim + def __init__( self, - backbone: str = "resnet50", + backbone: str = "resnet50", # [64, 256, 512, 1024, 2048] channels in_channels: int = 2, - in_stack_depth: int = 32, - embedding_dim: int = 128, - stem_kernel_size: tuple[int, int, int] = (8, 4, 4), - stem_stride: tuple[int, int, int] = (8, 2, 2), + in_stack_depth: int = 16, + latent_dim: int = 1024, + stem_kernel_size: tuple[int, int, int] = (4, 5, 5), + stem_stride: tuple[int, int, int] = (4, 5, 5), # same as kernel size drop_path_rate: float = 0.0, ): super().__init__() self.backbone = backbone - self.embedding_dim = embedding_dim + self.latent_dim = latent_dim encoder = timm.create_model( backbone, @@ -135,25 +141,33 @@ def __init__( ) self.encoder = encoder - self.global_pool = nn.AdaptiveAvgPool2d(1) - - self.fc_mu = nn.Linear(out_channels_encoder, embedding_dim) - self.fc_logvar = nn.Linear(out_channels_encoder, embedding_dim) + # Store for creating linear layers dynamically in forward pass + self.out_channels_encoder = out_channels_encoder + self.fc_mu = None + self.fc_logvar = None - def forward(self, x: Tensor) -> dict: + def forward(self, x: Tensor) -> SimpleNamespace: """Forward pass returning VAE encoder outputs.""" x = self.stem(x) features = self.encoder(x) - # Take highest resolution features - x = features[-1] - x = self.global_pool(x) - x = x.flatten(1) + # Take highest resolution features and flatten + x = features[-1] # [B, C, H, W] + + # Flatten spatial dimensions + batch_size = x.size(0) + x_flat = x.view(batch_size, -1) # [B, C*H*W] + + # Initialize linear layers on first forward pass + if self.fc_mu is None: + flattened_size = x_flat.size(1) + self.fc_mu = nn.Linear(flattened_size, self.latent_dim).to(x.device) + self.fc_logvar = nn.Linear(flattened_size, self.latent_dim).to(x.device) - # VAE outputs - mu = self.fc_mu(x) - logvar = self.fc_logvar(x) + # Apply linear layers to get 1D embeddings + mu = self.fc_mu(x_flat) # [B, embedding_dim] + logvar = self.fc_logvar(x_flat) # [B, embedding_dim] return SimpleNamespace(embedding=mu, log_covariance=logvar) @@ -164,11 +178,10 @@ class VaeDecoder(nn.Module): def __init__( self, decoder_channels: list[int] = [1024, 512, 256, 128], - latent_dim: int = 128, + latent_dim: int = 1024, out_channels: int = 2, - out_stack_depth: int = 20, - latent_spatial_size: int = 8, - head_expansion_ratio: int = 4, + out_stack_depth: int = 16, + head_expansion_ratio: int = 2, head_pool: bool = False, upsample_mode: Literal["deconv", "pixelshuffle"] = "pixelshuffle", conv_blocks: int = 2, @@ -179,7 +192,6 @@ def __init__( super().__init__() self.out_channels = out_channels self.out_stack_depth = out_stack_depth - self.latent_spatial_size = 
latent_spatial_size head_channels = ( (out_stack_depth + 2) * out_channels * 2**2 * head_expansion_ratio @@ -194,7 +206,12 @@ def __init__( if ( num_stages == 4 ): # Default [1024, 512, 256, 128] + head = 5 channels, 4 stages - strides = [2, 2, 2, 4] # 8→16→32→64→256 (32x total upsampling) + strides = [ + 2, + 2, + 2, + 1, + ] # Reduce to account for PixelToVoxelHead's 4x upsampling else: strides = [2] * num_stages # Fallback to uniform 2x upsampling elif len(strides) != num_stages: @@ -202,10 +219,16 @@ def __init__( f"Length of strides ({len(strides)}) must match number of stages ({num_stages})" ) - # Project latent vector to first feature map - self.latent_proj = nn.Linear( - latent_dim, - decoder_channels_with_head[0] * latent_spatial_size * latent_spatial_size, + # Store spatial dimensions for reshaping 1D latent back to spatial + self.spatial_size = 6 # Will be computed dynamically based on encoder output + self.spatial_channels = latent_dim // (self.spatial_size * self.spatial_size) + + # Project 1D latent to spatial format, then to first decoder channels + self.latent_reshape = nn.Linear( + latent_dim, self.spatial_channels * self.spatial_size * self.spatial_size + ) + self.latent_proj = nn.Conv2d( + self.spatial_channels, decoder_channels_with_head[0], kernel_size=1 ) # Build the decoder stages @@ -227,7 +250,7 @@ def __init__( ) self.decoder_stages.append(stage) - # Head to convert back to 3D (no final_conv needed - last stage outputs head_channels) + # Head to convert back to 3D self.head = PixelToVoxelHead( in_channels=head_channels, out_channels=self.out_channels, @@ -236,18 +259,27 @@ def __init__( pool=head_pool, ) - def forward(self, z: Tensor) -> dict: + def forward(self, z: Tensor) -> Tensor: """Forward pass converting latent to 3D output.""" - batch_size = z.shape[0] + # z is now 1D: [batch, latent_dim] + batch_size = z.size(0) + + # Reshape 1D latent back to spatial format + z_spatial = self.latent_reshape(z) # [batch, spatial_channels * H * W] + z_spatial = z_spatial.view( + batch_size, self.spatial_channels, self.spatial_size, self.spatial_size + ) - # Project latent to feature map - x = self.latent_proj(z) - x = x.view(batch_size, -1, self.latent_spatial_size, self.latent_spatial_size) + # Project spatial latent to first decoder channels using 1x1 conv + x = self.latent_proj( + z_spatial + ) # [batch, decoder_channels[0], spatial_H, spatial_W] for stage in self.decoder_stages: x = stage(x) # Last stage outputs head_channels directly - no final_conv needed output = self.head(x) + output = torch.sigmoid(output) # Constrain to [0,1] to match normalized input return output diff --git a/viscy/representation/vae_logging.py b/viscy/representation/vae_logging.py index 9e103d0cb..74a261f92 100644 --- a/viscy/representation/vae_logging.py +++ b/viscy/representation/vae_logging.py @@ -72,18 +72,31 @@ def log_enhanced_metrics( mse_loss = F.mse_loss(recon_x, x) mae_loss = F.l1_loss(recon_x, x) + # Add gradient explosion diagnostics + grad_diagnostics = self._compute_gradient_diagnostics(lightning_module) + + # Add NaN/Inf detection + nan_inf_diagnostics = self._check_nan_inf(recon_x, x, z) + + # Add shape diagnostics (log occasionally to avoid spam) + if lightning_module.current_epoch % 5 == 0: + self._log_tensor_shapes(lightning_module, x, recon_x, z, stage) + metrics = { # All losses in one consolidated group f"loss/total/{stage}": total_loss, f"loss/reconstruction/{stage}": recon_loss, f"loss/kl/{stage}": kl_loss, f"loss/weighted_kl/{stage}": beta * kl_loss, - 
f"loss/mse/{stage}": mse_loss, f"loss/mae/{stage}": mae_loss, f"loss/beta_value/{stage}": beta, f"loss/kl_recon_ratio/{stage}": kl_recon_ratio, f"loss/recon_contribution/{stage}": recon_loss / total_loss, } + + # Add diagnostic metrics + metrics.update(grad_diagnostics) + metrics.update(nan_inf_diagnostics) # Latent space statistics latent_mean = torch.mean(z, dim=0) @@ -119,6 +132,69 @@ def log_enhanced_metrics( if stage == "val" and lightning_module.current_epoch % 10 == 0: self._log_latent_histograms(lightning_module, z, stage) + def _compute_gradient_diagnostics(self, lightning_module): + """Compute gradient norms and parameter statistics for explosion detection.""" + grad_diagnostics = {} + + # Compute gradient norms for encoder and decoder + encoder_grad_norm = 0.0 + decoder_grad_norm = 0.0 + encoder_param_norm = 0.0 + decoder_param_norm = 0.0 + + for name, param in lightning_module.named_parameters(): + if param.grad is not None: + param_norm = param.grad.data.norm(2) + if 'encoder' in name: + encoder_grad_norm += param_norm.item() ** 2 + elif 'decoder' in name: + decoder_grad_norm += param_norm.item() ** 2 + + # Parameter magnitudes + if 'encoder' in name: + encoder_param_norm += param.data.norm(2).item() ** 2 + elif 'decoder' in name: + decoder_param_norm += param.data.norm(2).item() ** 2 + + grad_diagnostics.update({ + "diagnostics/encoder_grad_norm": encoder_grad_norm ** 0.5, + "diagnostics/decoder_grad_norm": decoder_grad_norm ** 0.5, + "diagnostics/encoder_param_norm": encoder_param_norm ** 0.5, + "diagnostics/decoder_param_norm": decoder_param_norm ** 0.5, + }) + + return grad_diagnostics + + def _check_nan_inf(self, recon_x, x, z): + """Check for NaN/Inf values in tensors.""" + diagnostics = { + "diagnostics/recon_has_nan": torch.isnan(recon_x).any().float(), + "diagnostics/recon_has_inf": torch.isinf(recon_x).any().float(), + "diagnostics/input_has_nan": torch.isnan(x).any().float(), + "diagnostics/latent_has_nan": torch.isnan(z).any().float(), + "diagnostics/recon_max_val": torch.max(torch.abs(recon_x)), + "diagnostics/recon_min_val": torch.min(recon_x), + } + return diagnostics + + def _log_tensor_shapes(self, lightning_module, x, recon_x, z, stage): + """Log tensor shapes to help diagnose architectural mismatches.""" + _logger.info(f"[{stage}] Input shape: {x.shape}") + _logger.info(f"[{stage}] Latent shape: {z.shape}") + _logger.info(f"[{stage}] Reconstruction shape: {recon_x.shape}") + + # Check for shape mismatches + if x.shape != recon_x.shape: + _logger.warning(f"SHAPE MISMATCH: Input {x.shape} != Reconstruction {recon_x.shape}") + + # Log as scalars for TensorBoard tracking + lightning_module.log_dict({ + f"shapes/input_numel_{stage}": x.numel(), + f"shapes/recon_numel_{stage}": recon_x.numel(), + f"shapes/latent_numel_{stage}": z.numel(), + f"shapes/spatial_dims_{stage}": len(z.shape) - 2, # Exclude batch and channel dims + }, on_step=False, on_epoch=True, logger=True) + def _log_latent_histograms(self, lightning_module, z: torch.Tensor, stage: str): """Log histograms of latent dimensions.""" z_np = z.detach().cpu().numpy() From e47de7ce2629c37171d7a63decabf4a02cf96364 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Fri, 25 Jul 2025 15:56:58 -0700 Subject: [PATCH 016/101] chaning to use mse with mean reduction and normalizing the kl loss by batch size. 
--- .../representation/disentanglement_metrics.py | 28 ++-- viscy/representation/engine.py | 127 ++++++------------ viscy/representation/vae.py | 72 ++++++---- viscy/representation/vae_logging.py | 118 ++++++++-------- 4 files changed, 163 insertions(+), 182 deletions(-) diff --git a/viscy/representation/disentanglement_metrics.py b/viscy/representation/disentanglement_metrics.py index 195bcbabf..5bb7b209d 100644 --- a/viscy/representation/disentanglement_metrics.py +++ b/viscy/representation/disentanglement_metrics.py @@ -97,11 +97,15 @@ def _extract_latents_and_factors( # Extract latent representations model_output = vae_model(x) - z = ( - model_output.z - if hasattr(model_output, "z") - else model_output.embedding - ) + # Handle both dict format and object format + if isinstance(model_output, dict): + z = model_output["z"] + else: + z = ( + model_output.z + if hasattr(model_output, "z") + else model_output.embedding + ) latents.append(z.cpu().numpy()) # Extract visual factors from images @@ -280,11 +284,15 @@ def compute_beta_vae_score( # Get latent representation model_output = vae_model(x) - z = ( - model_output.z - if hasattr(model_output, "z") - else model_output.embedding - ) + # Handle both dict format and object format + if isinstance(model_output, dict): + z = model_output["z"] + else: + z = ( + model_output.z + if hasattr(model_output, "z") + else model_output.embedding + ) # Compute baseline reconstruction baseline_recon = vae_model.decoder(z) diff --git a/viscy/representation/engine.py b/viscy/representation/engine.py index 56213126b..6ea532a5a 100644 --- a/viscy/representation/engine.py +++ b/viscy/representation/engine.py @@ -257,6 +257,7 @@ def __init__( self, encoder: VaeEncoder, decoder: VaeDecoder, + loss_function: nn.Module | nn.MSELoss = nn.MSELoss(reduction="sum"), beta: float = 1.0, beta_schedule: Literal["linear", "cosine", "warmup"] | None = None, beta_min: float = 0.1, @@ -267,9 +268,8 @@ def __init__( example_input_array_shape: Sequence[int] = (1, 2, 30, 256, 256), compute_disentanglement: bool = True, disentanglement_frequency: int = 10, - # Deprecated parameters for backward compatibility - model_name: str = "BetaVAE", - loss: str = "mse", + log_enhanced_visualizations: bool = False, + log_enhanced_visualizations_frequency: int = 30, ): super().__init__() @@ -278,7 +278,6 @@ def __init__( # Infer latent dimension from encoder and validate decoder matches self.latent_dim = encoder.latent_dim - # Validate that decoder's latent_dim matches encoder's embedding_dim if hasattr(decoder, "latent_dim") and decoder.latent_dim != self.latent_dim: raise ValueError( @@ -292,24 +291,24 @@ def __init__( self.lr = lr self.log_batches_per_epoch = log_batches_per_epoch self.log_samples_per_batch = log_samples_per_batch - + self.loss_function = loss_function self.example_input_array = torch.rand(*example_input_array_shape) self.compute_disentanglement = compute_disentanglement self.disentanglement_frequency = disentanglement_frequency - + self.log_enhanced_visualizations = log_enhanced_visualizations + self.log_enhanced_visualizations_frequency = ( + log_enhanced_visualizations_frequency + ) self.training_step_outputs = [] self.validation_step_outputs = [] - - self.vae_logger = BetaVaeLogger(latent_dim=self.latent_dim, device="cuda") - - self.disentanglement_metrics = None + self.vae_logger = BetaVaeLogger(latent_dim=self.latent_dim) def setup(self, stage: str = None): """Setup hook to initialize device-dependent components.""" super().setup(stage) - if 
self.disentanglement_metrics is None: - self.disentanglement_metrics = DisentanglementMetrics(device=self.device) + # Initialize the VAE logger with proper device + self.vae_logger.setup(device=self.device) def _get_current_beta(self) -> float: """Get current beta value based on scheduling.""" @@ -347,30 +346,33 @@ def _get_current_beta(self) -> float: else: return self.beta - def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor: - """Reparameterization trick: sample from N(mu, var) using N(0,1).""" - std = torch.exp(0.5 * logvar) - eps = torch.randn_like(std) - return mu + eps * std - def forward(self, x: Tensor) -> dict: """Forward pass through Beta-VAE.""" # Encode encoder_output = self.encoder(x) - mu = encoder_output.embedding + mu = encoder_output.mean logvar = encoder_output.log_covariance - - # Reparameterize - z = self.reparameterize(mu, logvar) + z = encoder_output.z # Decode reconstruction = self.decoder(z) - # Compute losses with current beta + # Compute losses with current beta (normalized by batch size) current_beta = self._get_current_beta() - recon_loss = F.mse_loss(reconstruction, x, reduction="mean") - kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / x.size(0) - total_loss = recon_loss + current_beta * kl_loss + batch_size = x.size(0) + + # MSE loss normalized by batch size + recon_loss = self.loss_function(reconstruction, x) + + # KL loss normalized by batch size + kl_loss = ( + -0.5 + * current_beta + * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) + / batch_size + ) + + total_loss = recon_loss + kl_loss return { "recon_x": reconstruction, @@ -379,22 +381,20 @@ def forward(self, x: Tensor) -> dict: "logvar": logvar, "recon_loss": recon_loss, "kl_loss": kl_loss, - "loss": total_loss, + "total_loss": total_loss, } def training_step(self, batch: TripletSample, batch_idx: int) -> Tensor: """Training step with VAE loss computation.""" + x = batch["anchor"] model_output = self(x) - - # Beta-VAE computes loss internally - loss = model_output["loss"] + loss = model_output["total_loss"] # Log enhanced β-VAE metrics self.vae_logger.log_enhanced_metrics( lightning_module=self, model_output=model_output, batch=batch, stage="train" ) - # Log samples self._log_step_samples(batch_idx, x, model_output["recon_x"], "train") @@ -404,9 +404,7 @@ def validation_step(self, batch: TripletSample, batch_idx: int) -> Tensor: """Validation step with VAE loss computation.""" x = batch["anchor"] model_output = self(x) - - # Beta-VAE computes loss internally - loss = model_output["loss"] + loss = model_output["total_loss"] # Log enhanced β-VAE metrics self.vae_logger.log_enhanced_metrics( @@ -418,23 +416,6 @@ def validation_step(self, batch: TripletSample, batch_idx: int) -> Tensor: return loss - def _log_metrics(self, loss, recon_loss, kl_loss, stage: Literal["train", "val"]): - """Log VAE-specific metrics.""" - metrics = { - f"loss/{stage}": loss, - f"recon_loss/{stage}": recon_loss, - f"kl_loss/{stage}": kl_loss, - } - - self.log_dict( - metrics, - on_step=True, - on_epoch=True, - prog_bar=True, - logger=True, - sync_dist=True, - ) - def _log_step_samples( self, batch_idx, original, reconstruction, stage: Literal["train", "val"] ): @@ -501,20 +482,16 @@ def on_validation_epoch_end(self) -> None: ): self._compute_and_log_disentanglement_metrics() - # Log enhanced β-VAE visualizations periodically - if self.current_epoch % 20 == 0 and self.current_epoch > 0: + if ( + self.log_enhanced_visualizations + and self.current_epoch % 
self.log_enhanced_visualizations_frequency == 0 + and self.current_epoch > 0 + ): self._log_enhanced_visualizations() def _compute_and_log_disentanglement_metrics(self): """Compute and log disentanglement metrics.""" try: - # Check if disentanglement metrics are initialized - if self.disentanglement_metrics is None: - _logger.warning( - "DisentanglementMetrics not initialized, skipping computation" - ) - return - # Get validation dataloader - handle both single DataLoader and list cases val_dataloaders = self.trainer.val_dataloaders if val_dataloaders is None: @@ -530,32 +507,15 @@ def _compute_and_log_disentanglement_metrics(self): ) return - # Compute metrics - _logger.info( - f"Computing disentanglement metrics at epoch {self.current_epoch}" - ) - metrics = self.disentanglement_metrics.compute_all_metrics( - vae_model=self, + # Use the logger's disentanglement metrics method + self.vae_logger.log_disentanglement_metrics( + lightning_module=self, dataloader=val_dataloader, max_samples=200, ) - # Log metrics - for metric_name, metric_value in metrics.items(): - self.log( - f"disentanglement/{metric_name}", - metric_value, - on_step=False, - on_epoch=True, - logger=True, - sync_dist=True, - ) - - _logger.info(f"Disentanglement metrics: {metrics}") - except Exception as e: _logger.error(f"Error computing disentanglement metrics: {e}") - # Continue training even if metrics fail def _log_enhanced_visualizations(self): """Log enhanced β-VAE visualizations.""" @@ -577,17 +537,17 @@ def _log_enhanced_visualizations(self): f"Logging enhanced β-VAE visualizations at epoch {self.current_epoch}" ) - # Log latent traversals + # Log latent traversals -for how recons change when moving along a latent dim self.vae_logger.log_latent_traversal( lightning_module=self, n_dims=8, n_steps=11 ) - # Log latent interpolations + # Log latent interpolations - smooth transitions between different data points in the latent space self.vae_logger.log_latent_interpolation( lightning_module=self, n_pairs=3, n_steps=11 ) - # Log factor traversal matrix + # Log factor traversal matrix - grid visualization how each dim affects the recon self.vae_logger.log_factor_traversal_matrix( lightning_module=self, n_dims=8, n_steps=7 ) @@ -603,7 +563,6 @@ def _log_enhanced_visualizations(self): except Exception as e: _logger.error(f"Error logging enhanced visualizations: {e}") - # Continue training even if visualizations fail def configure_optimizers(self): """Configure optimizer for VAE training.""" diff --git a/viscy/representation/vae.py b/viscy/representation/vae.py index 0bc548a3f..4e248b7a5 100644 --- a/viscy/representation/vae.py +++ b/viscy/representation/vae.py @@ -108,6 +108,7 @@ def __init__( in_channels: int = 2, in_stack_depth: int = 16, latent_dim: int = 1024, + input_spatial_size: tuple[int, int] = (256, 256), stem_kernel_size: tuple[int, int, int] = (4, 5, 5), stem_stride: tuple[int, int, int] = (4, 5, 5), # same as kernel size drop_path_rate: float = 0.0, @@ -141,10 +142,40 @@ def __init__( ) self.encoder = encoder - # Store for creating linear layers dynamically in forward pass + # Calculate spatial dimensions after encoder and initialize linear layers self.out_channels_encoder = out_channels_encoder - self.fc_mu = None - self.fc_logvar = None + + if "resnet50" in backbone: + # Calculate spatial size after stem, then ResNet50 downsampling + stem_spatial_h = ( + input_spatial_size[0] - stem_kernel_size[1] + ) // stem_stride[1] + 1 + stem_spatial_w = ( + input_spatial_size[1] - stem_kernel_size[2] + ) // 
stem_stride[2] + 1 + + # ResNet50 downsamples by 32x total, but stem already downsampled + total_downsample_factor = 32 + stem_downsample_factor = stem_stride[1] # Spatial downsampling from stem + resnet_downsample_factor = total_downsample_factor // stem_downsample_factor + final_h = stem_spatial_h // resnet_downsample_factor + final_w = stem_spatial_w // resnet_downsample_factor + flattened_size = out_channels_encoder * final_h * final_w + else: + raise ValueError( + f"Backbone {backbone} not supported for analytical calculation" + ) + + # Multi-layer perceptron for better representation learning + self.fc = nn.Linear(flattened_size, latent_dim) + self.fc_mu = nn.Linear(latent_dim, latent_dim) + self.fc_logvar = nn.Linear(latent_dim, latent_dim) + + def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor: + """Reparameterization trick: sample from N(mu, var) using N(0,1).""" + std = torch.exp(0.5 * logvar) + eps = torch.randn_like(std) + return mu + eps * std def forward(self, x: Tensor) -> SimpleNamespace: """Forward pass returning VAE encoder outputs.""" @@ -154,22 +185,17 @@ def forward(self, x: Tensor) -> SimpleNamespace: # Take highest resolution features and flatten x = features[-1] # [B, C, H, W] + x_flat = x.flatten(1) # [B, C*H*W] - flatten from dim 1 onwards - # Flatten spatial dimensions - batch_size = x.size(0) - x_flat = x.view(batch_size, -1) # [B, C*H*W] - - # Initialize linear layers on first forward pass - if self.fc_mu is None: - flattened_size = x_flat.size(1) - self.fc_mu = nn.Linear(flattened_size, self.latent_dim).to(x.device) - self.fc_logvar = nn.Linear(flattened_size, self.latent_dim).to(x.device) + # Apply intermediate FC layer + x_intermediate = self.fc(x_flat) # [B, intermediate_dim] # Apply linear layers to get 1D embeddings - mu = self.fc_mu(x_flat) # [B, embedding_dim] - logvar = self.fc_logvar(x_flat) # [B, embedding_dim] + mu = self.fc_mu(x_intermediate) # [B, latent_dim] + logvar = self.fc_logvar(x_intermediate) # [B, latent_dim] + z = self.reparameterize(mu, logvar) # [B, latent_dim] - return SimpleNamespace(embedding=mu, log_covariance=logvar) + return SimpleNamespace(mean=mu, log_covariance=logvar, z=z) class VaeDecoder(nn.Module): @@ -188,6 +214,10 @@ def __init__( norm_name: str = "batch", upsample_pre_conv: Literal["default"] | Callable | None = None, strides: list[int] | None = None, + input_spatial_size: tuple[int, int] = ( + 128, + 128, + ), # Input size to calculate spatial dimensions ): super().__init__() self.out_channels = out_channels @@ -197,10 +227,8 @@ def __init__( (out_stack_depth + 2) * out_channels * 2**2 * head_expansion_ratio ) - # Copy decoder_channels to avoid modifying the original list decoder_channels_with_head = decoder_channels.copy() + [head_channels] - # Set optimal default strides for ResNet50 if not provided num_stages = len(decoder_channels_with_head) - 1 if strides is None: if ( @@ -218,9 +246,8 @@ def __init__( raise ValueError( f"Length of strides ({len(strides)}) must match number of stages ({num_stages})" ) - - # Store spatial dimensions for reshaping 1D latent back to spatial - self.spatial_size = 6 # Will be computed dynamically based on encoder output + # Calculate spatial size based on input dimensions and ResNet50 32x downsampling + self.spatial_size = input_spatial_size[0] // 32 # ResNet50 downsamples by 32x self.spatial_channels = latent_dim // (self.spatial_size * self.spatial_size) # Project 1D latent to spatial format, then to first decoder channels @@ -261,10 +288,10 @@ def __init__( def 
forward(self, z: Tensor) -> Tensor: """Forward pass converting latent to 3D output.""" - # z is now 1D: [batch, latent_dim] + batch_size = z.size(0) - # Reshape 1D latent back to spatial format + # Reshape 1D latent back to spatial format so we can reconstruct the 2.5D image z_spatial = self.latent_reshape(z) # [batch, spatial_channels * H * W] z_spatial = z_spatial.view( batch_size, self.spatial_channels, self.spatial_size, self.spatial_size @@ -280,6 +307,5 @@ def forward(self, z: Tensor) -> Tensor: # Last stage outputs head_channels directly - no final_conv needed output = self.head(x) - output = torch.sigmoid(output) # Constrain to [0,1] to match normalized input return output diff --git a/viscy/representation/vae_logging.py b/viscy/representation/vae_logging.py index 74a261f92..a5b67be24 100644 --- a/viscy/representation/vae_logging.py +++ b/viscy/representation/vae_logging.py @@ -24,9 +24,16 @@ class BetaVaeLogger: and latent space analysis for microscopy data. """ - def __init__(self, latent_dim: int = 128, device: str = "cuda"): + def __init__(self, latent_dim: int = 128): self.latent_dim = latent_dim - self.disentanglement_metrics = DisentanglementMetrics(device=device) + self.device = None + self.disentanglement_metrics = None + + def setup(self, device: str): + """Initialize device-dependent components.""" + self.device = device + if self.disentanglement_metrics is None: + self.disentanglement_metrics = DisentanglementMetrics(device=device) def log_enhanced_metrics( self, lightning_module, model_output: dict, batch: dict, stage: str = "train" @@ -42,21 +49,12 @@ def log_enhanced_metrics( """ # Extract components x = batch["anchor"] - # Handle both Pythae dict format and object format - if isinstance(model_output, dict): - z = model_output["z"] - recon_x = model_output["recon_x"] - recon_loss = model_output["recon_loss"] - kl_loss = model_output["kl_loss"] - else: - z = model_output.z if hasattr(model_output, "z") else model_output.embedding - recon_x = ( - model_output.recon_x - if hasattr(model_output, "recon_x") - else model_output.reconstruction - ) - recon_loss = model_output.recon_loss - kl_loss = model_output.kl_loss + + z = model_output["z"] + recon_x = model_output["recon_x"] + recon_loss = model_output["recon_loss"] + kl_loss = model_output["kl_loss"] + total_loss = model_output["total_loss"] # Get current β (scheduled value, not static) beta = getattr( @@ -66,22 +64,18 @@ def log_enhanced_metrics( )() # Record losses and reconstruction quality metrics - total_loss = recon_loss + beta * kl_loss kl_recon_ratio = kl_loss / (recon_loss + 1e-8) - mse_loss = F.mse_loss(recon_x, x) mae_loss = F.l1_loss(recon_x, x) # Add gradient explosion diagnostics grad_diagnostics = self._compute_gradient_diagnostics(lightning_module) - + # Add NaN/Inf detection nan_inf_diagnostics = self._check_nan_inf(recon_x, x, z) - - # Add shape diagnostics (log occasionally to avoid spam) - if lightning_module.current_epoch % 5 == 0: - self._log_tensor_shapes(lightning_module, x, recon_x, z, stage) - + + # Shape diagnostics removed for cleaner logs + metrics = { # All losses in one consolidated group f"loss/total/{stage}": total_loss, @@ -89,11 +83,11 @@ def log_enhanced_metrics( f"loss/kl/{stage}": kl_loss, f"loss/weighted_kl/{stage}": beta * kl_loss, f"loss/mae/{stage}": mae_loss, - f"loss/beta_value/{stage}": beta, + f"beta/{stage}": beta, f"loss/kl_recon_ratio/{stage}": kl_recon_ratio, f"loss/recon_contribution/{stage}": recon_loss / total_loss, } - + # Add diagnostic metrics 
metrics.update(grad_diagnostics) metrics.update(nan_inf_diagnostics) @@ -135,36 +129,38 @@ def log_enhanced_metrics( def _compute_gradient_diagnostics(self, lightning_module): """Compute gradient norms and parameter statistics for explosion detection.""" grad_diagnostics = {} - + # Compute gradient norms for encoder and decoder encoder_grad_norm = 0.0 decoder_grad_norm = 0.0 - encoder_param_norm = 0.0 + encoder_param_norm = 0.0 decoder_param_norm = 0.0 - + for name, param in lightning_module.named_parameters(): if param.grad is not None: param_norm = param.grad.data.norm(2) - if 'encoder' in name: + if "encoder" in name: encoder_grad_norm += param_norm.item() ** 2 - elif 'decoder' in name: + elif "decoder" in name: decoder_grad_norm += param_norm.item() ** 2 - + # Parameter magnitudes - if 'encoder' in name: + if "encoder" in name: encoder_param_norm += param.data.norm(2).item() ** 2 - elif 'decoder' in name: + elif "decoder" in name: decoder_param_norm += param.data.norm(2).item() ** 2 - - grad_diagnostics.update({ - "diagnostics/encoder_grad_norm": encoder_grad_norm ** 0.5, - "diagnostics/decoder_grad_norm": decoder_grad_norm ** 0.5, - "diagnostics/encoder_param_norm": encoder_param_norm ** 0.5, - "diagnostics/decoder_param_norm": decoder_param_norm ** 0.5, - }) - + + grad_diagnostics.update( + { + "diagnostics/encoder_grad_norm": encoder_grad_norm**0.5, + "diagnostics/decoder_grad_norm": decoder_grad_norm**0.5, + "diagnostics/encoder_param_norm": encoder_param_norm**0.5, + "diagnostics/decoder_param_norm": decoder_param_norm**0.5, + } + ) + return grad_diagnostics - + def _check_nan_inf(self, recon_x, x, z): """Check for NaN/Inf values in tensors.""" diagnostics = { @@ -176,24 +172,6 @@ def _check_nan_inf(self, recon_x, x, z): "diagnostics/recon_min_val": torch.min(recon_x), } return diagnostics - - def _log_tensor_shapes(self, lightning_module, x, recon_x, z, stage): - """Log tensor shapes to help diagnose architectural mismatches.""" - _logger.info(f"[{stage}] Input shape: {x.shape}") - _logger.info(f"[{stage}] Latent shape: {z.shape}") - _logger.info(f"[{stage}] Reconstruction shape: {recon_x.shape}") - - # Check for shape mismatches - if x.shape != recon_x.shape: - _logger.warning(f"SHAPE MISMATCH: Input {x.shape} != Reconstruction {recon_x.shape}") - - # Log as scalars for TensorBoard tracking - lightning_module.log_dict({ - f"shapes/input_numel_{stage}": x.numel(), - f"shapes/recon_numel_{stage}": recon_x.numel(), - f"shapes/latent_numel_{stage}": z.numel(), - f"shapes/spatial_dims_{stage}": len(z.shape) - 2, # Exclude batch and channel dims - }, on_step=False, on_epoch=True, logger=True) def _log_latent_histograms(self, lightning_module, z: torch.Tensor, stage: str): """Log histograms of latent dimensions.""" @@ -263,7 +241,10 @@ def log_latent_traversal( ) lightning_module.logger.experiment.add_image( - f"latent_traversal/dim_{dim}", grid, lightning_module.current_epoch + f"latent_traversal/dim_{dim}", + grid, + lightning_module.current_epoch, + dataformats="CHW", ) def log_latent_interpolation( @@ -317,6 +298,7 @@ def log_latent_interpolation( f"latent_interpolation/pair_{pair_idx}", grid, lightning_module.current_epoch, + dataformats="CHW", ) def log_factor_traversal_matrix( @@ -369,7 +351,10 @@ def log_factor_traversal_matrix( grid = make_grid(all_images.unsqueeze(1), nrow=n_steps, normalize=True) lightning_module.logger.experiment.add_image( - "factor_traversal_matrix", grid, lightning_module.current_epoch + "factor_traversal_matrix", + grid, + 
lightning_module.current_epoch,
+            dataformats="CHW",
         )
 
     def log_latent_space_visualization(
@@ -450,7 +435,10 @@ def log_latent_space_visualization(
                 img_tensor = torch.from_numpy(img_array).permute(2, 0, 1) / 255.0
 
                 lightning_module.logger.experiment.add_image(
-                    f"latent_space_{method}", img_tensor, lightning_module.current_epoch
+                    f"latent_space_{method}",
+                    img_tensor,
+                    lightning_module.current_epoch,
+                    dataformats="CHW",
                 )
 
                 plt.close()

From cfdc51ae8642497eeeb68ad9c492f5391ae89213 Mon Sep 17 00:00:00 2001
From: Eduardo Hirata-Miyasaki
Date: Fri, 25 Jul 2025 15:59:15 -0700
Subject: [PATCH 017/101] optuna proof of concept

---
 pyproject.toml                                |   7 +-
 viscy/scripts/optimization/__init__.py        |   3 +
 viscy/scripts/optimization/optuna_utils.py    | 415 ++++++++++++++++++
 .../optimization/optuna_vae_parallel.sh       |  53 +++
 .../scripts/optimization/optuna_vae_search.py | 230 ++++++++++
 .../scripts/optimization/optuna_vae_slurm.sh  |  47 ++
 6 files changed, 754 insertions(+), 1 deletion(-)
 create mode 100644 viscy/scripts/optimization/__init__.py
 create mode 100644 viscy/scripts/optimization/optuna_utils.py
 create mode 100644 viscy/scripts/optimization/optuna_vae_parallel.sh
 create mode 100644 viscy/scripts/optimization/optuna_vae_search.py
 create mode 100644 viscy/scripts/optimization/optuna_vae_slurm.sh

diff --git a/pyproject.toml b/pyproject.toml
index 6d0fb1e17..b97ce711b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -44,6 +44,11 @@ metrics = [
 phate = [
     "phate",
 ]
+optimization = [
+    "optuna",
+    "optuna-dashboard",
+]
+
 examples = ["napari", "jupyter", "jupytext", "transformers>=4.51.3"]
 visual = [
     "ipykernel",
@@ -56,7 +61,7 @@ visual = [
     "dash",
 ]
 dev = [
-    "viscy[metrics,phate,examples,visual]",
+    "viscy[metrics,phate,examples,visual,optimization]",
     "pytest",
     "pytest-cov",
     "hypothesis",
diff --git a/viscy/scripts/optimization/__init__.py b/viscy/scripts/optimization/__init__.py
new file mode 100644
index 000000000..57b5ee39e
--- /dev/null
+++ b/viscy/scripts/optimization/__init__.py
@@ -0,0 +1,3 @@
+"""
+Optimization scripts for hyperparameter tuning using Optuna and other methods.
+"""
\ No newline at end of file
diff --git a/viscy/scripts/optimization/optuna_utils.py b/viscy/scripts/optimization/optuna_utils.py
new file mode 100644
index 000000000..91c9f5369
--- /dev/null
+++ b/viscy/scripts/optimization/optuna_utils.py
@@ -0,0 +1,415 @@
+import glob
+import os
+import subprocess
+import tempfile
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Union
+
+import numpy as np
+import yaml
+from tensorboard.backend.event_processing.event_accumulator import EventAccumulator
+
+
+def extract_tensorboard_metric(
+    log_dir: str, metric_name: str = "loss/total/val", aggregation: str = "min"
+) -> float:
+    """
+    Extract a metric from TensorBoard logs. 
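+
+    Scalars are read with TensorBoard's EventAccumulator from the first
+    events.out.tfevents.* file found in log_dir; any failure returns
+    float("inf") so a minimizing study treats the trial as failed.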
+ + Args: + log_dir: Path to the directory containing TensorBoard logs + metric_name: Name of the metric to extract (e.g., "loss/total/val") + aggregation: How to aggregate the metric values ("min", "max", "last", "mean") + + Returns: + The aggregated metric value, or float('inf') if extraction fails + + Examples: + >>> extract_tensorboard_metric("./logs/version_1", "loss/total/val", "min") + 0.234567 + + >>> extract_tensorboard_metric("./logs/version_1", "accuracy", "max") + 0.891234 + """ + try: + # Find the events file + events_files = list(Path(log_dir).glob("events.out.tfevents.*")) + if not events_files: + print(f"Warning: No events file found in {log_dir}") + return float("inf") + + # Load TensorBoard data + ea = EventAccumulator(str(events_files[0])) + ea.Reload() + + # Extract metric + if metric_name in ea.Tags()["scalars"]: + values = np.array([scalar.value for scalar in ea.Scalars(metric_name)]) + + if aggregation == "min": + result = float(np.min(values)) + elif aggregation == "max": + result = float(np.max(values)) + elif aggregation == "last": + result = float(values[-1]) + elif aggregation == "mean": + result = float(np.mean(values)) + else: + raise ValueError(f"Unknown aggregation: {aggregation}") + + print(f"Extracted {metric_name} ({aggregation}): {result:.6f}") + return result + else: + print(f"Warning: Metric '{metric_name}' not found in {log_dir}") + available_metrics = ea.Tags()["scalars"] + print(f"Available metrics: {available_metrics}") + return float("inf") + + except Exception as e: + print(f"Error extracting {metric_name} from {log_dir}: {e}") + return float("inf") + + +def modify_config( + base_config_path: str, + modifications: Dict[str, Any], + output_path: Optional[str] = None, +) -> str: + """ + Modify a YAML configuration file with new parameter values. + + Supports nested key modification using dot notation (e.g., "model.init_args.beta"). + + Args: + base_config_path: Path to the base configuration file + modifications: Dictionary with nested keys to modify + e.g., {"model.init_args.beta": 10, "trainer.max_epochs": 50} + output_path: Where to save the modified config (if None, creates temp file) + + Returns: + Path to the modified configuration file + + Examples: + >>> modify_config("base.yml", {"model.init_args.lr": 1e-3}, "modified.yml") + "modified.yml" + + >>> temp_path = modify_config("base.yml", {"trainer.max_epochs": 100}) + >>> # Returns path to temporary file + """ + # Load base config + with open(base_config_path, "r") as f: + config = yaml.safe_load(f) + + # Apply modifications + for key_path, value in modifications.items(): + keys = key_path.split(".") + current = config + + # Navigate to the nested dictionary + for key in keys[:-1]: + if key not in current: + current[key] = {} + current = current[key] + + # Set the final value + current[keys[-1]] = value + + # Save modified config + if output_path is None: + temp_file = tempfile.NamedTemporaryFile(mode="w", suffix=".yml", delete=False) + output_path = temp_file.name + temp_file.close() + + with open(output_path, "w") as f: + yaml.dump(config, f, default_flow_style=False) + + return output_path + + +def run_lightning_training( + config_path: str, + working_dir: str = ".", + timeout: int = 3600, + capture_output: bool = True, +) -> subprocess.CompletedProcess: + """ + Run Lightning training with the given configuration. 
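+    Runs `python -m viscy.cli.train fit --config <config_path>` as a blocking
+    subprocess. If the run exceeds `timeout`, subprocess.run raises
+    subprocess.TimeoutExpired, which callers should catch to fail the trial.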
+
+    Args:
+        config_path: Path to the configuration file
+        working_dir: Working directory for the training process
+        timeout: Timeout in seconds (default: 1 hour)
+        capture_output: Whether to capture stdout/stderr
+
+    Returns:
+        CompletedProcess object with training results
+
+    Examples:
+        >>> result = run_lightning_training("config.yml", timeout=1800)
+        >>> if result.returncode == 0:
+        ...     print("Training completed successfully")
+    """
+    cmd = ["python", "-m", "viscy.cli.train", "fit", "--config", config_path]
+
+    print(f"Running command: {' '.join(cmd)}")
+
+    return subprocess.run(
+        cmd, cwd=working_dir, capture_output=capture_output, text=True, timeout=timeout
+    )
+
+
+def suggest_hyperparameters(
+    trial, param_config: Dict[str, Dict[str, Any]]
+) -> Dict[str, Any]:
+    """
+    Suggest hyperparameters based on a configuration dictionary.
+
+    Supports different parameter types with flexible configuration options.
+
+    Args:
+        trial: Optuna trial object
+        param_config: Configuration for parameters with format:
+            {
+                "param_name": {
+                    "type": "float" | "int" | "categorical",
+                    "low": <min>,  # for float/int
+                    "high": <max>,  # for float/int
+                    "choices": [<values>],  # for categorical
+                    "log": True/False,  # for float/int (optional)
+                    "step": <step>  # for int (optional)
+                }
+            }
+
+    Returns:
+        Dictionary of suggested parameter values
+
+    Examples:
+        >>> param_config = {
+        ...     "lr": {"type": "float", "low": 1e-5, "high": 1e-2, "log": True},
+        ...     "batch_size": {"type": "categorical", "choices": [32, 64, 128]},
+        ...     "epochs": {"type": "int", "low": 10, "high": 100, "step": 10}
+        ... }
+        >>> params = suggest_hyperparameters(trial, param_config)
+        >>> # Returns: {"lr": 0.0001234, "batch_size": 64, "epochs": 50}
+    """
+    params = {}
+
+    for param_name, config in param_config.items():
+        param_type = config["type"]
+
+        if param_type == "float":
+            log_scale = config.get("log", False)
+            params[param_name] = trial.suggest_float(
+                param_name, config["low"], config["high"], log=log_scale
+            )
+        elif param_type == "int":
+            step = config.get("step", 1)
+            params[param_name] = trial.suggest_int(
+                param_name, config["low"], config["high"], step=step
+            )
+        elif param_type == "categorical":
+            params[param_name] = trial.suggest_categorical(
+                param_name, config["choices"]
+            )
+        else:
+            raise ValueError(f"Unknown parameter type: {param_type}")
+
+    return params
+
+
+def create_study_with_defaults(
+    study_name: str,
+    storage_url: str,
+    direction: str = "minimize",
+    sampler_name: str = "TPE",
+    pruner_name: str = "Median",
+    sampler_kwargs: Optional[Dict[str, Any]] = None,
+    pruner_kwargs: Optional[Dict[str, Any]] = None,
+):
+    """
+    Create an Optuna study with commonly used samplers and pruners.
+
+    Args:
+        study_name: Name of the study
+        storage_url: Storage URL (e.g., "sqlite:///study.db")
+        direction: Optimization direction ("minimize" or "maximize")
+        sampler_name: Sampler type ("TPE", "Random", "CmaEs")
+        pruner_name: Pruner type ("Median", "Hyperband", "None")
+        sampler_kwargs: Additional sampler arguments (e.g., {"seed": 42})
+        pruner_kwargs: Additional pruner arguments (e.g., {"n_startup_trials": 5})
+
+    Returns:
+        Optuna study object
+
+    Examples:
+        >>> study = create_study_with_defaults(
+        ...     "vae_optimization",
+        ...     "sqlite:///vae_study.db",
+        ...     sampler_kwargs={"seed": 42}
+        ...
) + >>> study.optimize(objective, n_trials=100) + """ + import optuna + + # Set up sampler + sampler_kwargs = sampler_kwargs or {} + if sampler_name == "TPE": + sampler = optuna.samplers.TPESampler(**sampler_kwargs) + elif sampler_name == "Random": + sampler = optuna.samplers.RandomSampler(**sampler_kwargs) + elif sampler_name == "CmaEs": + sampler = optuna.samplers.CmaEsSampler(**sampler_kwargs) + else: + raise ValueError(f"Unknown sampler: {sampler_name}") + + # Set up pruner + pruner_kwargs = pruner_kwargs or {} + if pruner_name == "Median": + pruner = optuna.pruners.MedianPruner(**pruner_kwargs) + elif pruner_name == "Hyperband": + pruner = optuna.pruners.HyperbandPruner(**pruner_kwargs) + elif pruner_name == "None": + pruner = optuna.pruners.NopPruner() + else: + raise ValueError(f"Unknown pruner: {pruner_name}") + + return optuna.create_study( + study_name=study_name, + storage=storage_url, + direction=direction, + load_if_exists=True, + sampler=sampler, + pruner=pruner, + ) + + +def save_best_config( + study, + base_config_path: str, + output_path: str, + param_mapping: Dict[str, str], + additional_modifications: Optional[Dict[str, Any]] = None, +) -> None: + """ + Save the best configuration found by Optuna to a file. + + Args: + study: Completed Optuna study + base_config_path: Path to the base configuration file + output_path: Where to save the best configuration + param_mapping: Mapping from Optuna parameter names to config keys + e.g., {"beta": "model.init_args.beta", "lr": "model.init_args.lr"} + additional_modifications: Additional modifications to apply to the config + e.g., {"trainer.max_epochs": 300, "model.init_args.loss_function.init_args.reduction": "mean"} + + Examples: + >>> param_mapping = { + ... "beta": "model.init_args.beta", + ... "lr": "model.init_args.lr" + ... } + >>> additional_mods = {"trainer.max_epochs": 300} + >>> save_best_config(study, "base.yml", "best.yml", param_mapping, additional_mods) + """ + if study.best_trial is None: + print("No best trial found") + return + + # Create modifications dictionary + modifications = {} + for optuna_param, config_key in param_mapping.items(): + if optuna_param in study.best_params: + modifications[config_key] = study.best_params[optuna_param] + + # Add additional modifications + if additional_modifications: + modifications.update(additional_modifications) + + # Create and save the best configuration + modify_config(base_config_path, modifications, output_path) + + print(f"Best configuration saved to: {output_path}") + print(f"Best value: {study.best_value:.6f}") + print("Best parameters:") + for key, value in study.best_params.items(): + print(f" {key}: {value}") + + +def cleanup_temp_files(file_patterns: List[str], working_dir: str = ".") -> None: + """ + Clean up temporary files matching the given patterns. + + Uses glob patterns to match files for deletion. Handles both files and + directories safely. 
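+    Deletion errors are printed and skipped rather than raised, so a locked or
+    already-removed file does not abort a hyperparameter sweep.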
+
+    Args:
+        file_patterns: List of glob patterns for files to delete
+                      e.g., ["trial_*.yml", "temp_*", "*.tmp"]
+        working_dir: Directory to search in (default: current directory)
+
+    Examples:
+        >>> cleanup_temp_files(["trial_*.yml", "temp_logs_*"])
+        Removed: trial_1.yml
+        Removed: trial_2.yml
+        Removed: temp_logs_experiment1
+    """
+    for pattern in file_patterns:
+        files = glob.glob(os.path.join(working_dir, pattern))
+        for file_path in files:
+            try:
+                if os.path.isfile(file_path):
+                    os.remove(file_path)
+                    print(f"Removed: {file_path}")
+                elif os.path.isdir(file_path):
+                    import shutil
+
+                    shutil.rmtree(file_path)
+                    print(f"Removed directory: {file_path}")
+            except Exception as e:
+                print(f"Failed to remove {file_path}: {e}")
+
+
+def validate_config_modifications(
+    base_config_path: str, modifications: Dict[str, Any]
+) -> bool:
+    """
+    Validate that configuration modifications are applicable to the base config.
+
+    Checks if the nested keys exist in the base configuration structure.
+
+    Args:
+        base_config_path: Path to the base configuration file
+        modifications: Dictionary of modifications to validate
+
+    Returns:
+        True if all modifications are valid, False otherwise
+
+    Examples:
+        >>> modifications = {"model.init_args.beta": 10, "invalid.key": 5}
+        >>> validate_config_modifications("config.yml", modifications)
+        False  # because "invalid.key" doesn't exist in base config
+    """
+    try:
+        with open(base_config_path, "r") as f:
+            config = yaml.safe_load(f)
+
+        for key_path in modifications.keys():
+            keys = key_path.split(".")
+            current = config
+
+            # Check if nested path exists
+            for key in keys[:-1]:
+                if not isinstance(current, dict) or key not in current:
+                    print(f"Invalid key path: {key_path} (missing: {key})")
+                    return False
+                current = current[key]
+
+            # Check final key (it's ok if it doesn't exist, we'll create it)
+            if not isinstance(current, dict):
+                print(f"Invalid key path: {key_path} (parent is not dict)")
+                return False
+
+        return True
+
+    except Exception as e:
+        print(f"Error validating config modifications: {e}")
+        return False
diff --git a/viscy/scripts/optimization/optuna_vae_parallel.sh b/viscy/scripts/optimization/optuna_vae_parallel.sh
new file mode 100644
index 000000000..4ddbaddee
--- /dev/null
+++ b/viscy/scripts/optimization/optuna_vae_parallel.sh
@@ -0,0 +1,53 @@
+#!/bin/bash
+#SBATCH --job-name=optuna_vae_parallel
+#SBATCH --output=optuna_vae_parallel_%A_%a.out
+#SBATCH --error=optuna_vae_parallel_%A_%a.err
+#SBATCH --time=12:00:00
+#SBATCH --partition=gpu
+#SBATCH --nodes=1
+#SBATCH --ntasks-per-node=1
+#SBATCH --cpus-per-task=12
+#SBATCH --gres=gpu:1
+#SBATCH --mem-per-cpu=6G
+#SBATCH --array=1-4  # Run 4 parallel workers
+
+
+# Print job info
+echo "Array Job ID: $SLURM_ARRAY_JOB_ID"
+echo "Array Task ID: $SLURM_ARRAY_TASK_ID"
+echo "Node: $SLURM_NODELIST"
+echo "Start Time: $(date)"
+
+# Load modules and activate the environment
+module load anaconda/25.3.1
+conda activate viscy
+
+# Path to the Optuna search script
+OPTUNA_SCRIPT='/home/eduardo.hirata/repos/viscy/viscy/scripts/optimization/optuna_vae_search.py'
+
+# Shared storage for Optuna study (all workers use same database)
+SHARED_DB="/hpc/projects/organelle_phenotyping/models/SEC61B/vae/optuna_results/optuna_parallel_study.db"
+OUTPUT_DIR="/hpc/projects/organelle_phenotyping/models/SEC61B/vae/optuna_results/parallel_job_${SLURM_ARRAY_JOB_ID}"
+mkdir -p $OUTPUT_DIR
+
+# Each worker runs a portion of trials
+TRIALS_PER_WORKER=15  # 4 workers × 15 trials = 60 total trials
+
+echo "Worker $SLURM_ARRAY_TASK_ID starting $TRIALS_PER_WORKER trials..."
+
+# Run Optuna optimization (all workers share the same database)
+python $OPTUNA_SCRIPT \
+    --storage_url "sqlite:///$SHARED_DB" \
+    --n_trials $TRIALS_PER_WORKER \
+    --timeout 43200 \
+    --study_name "vae_parallel_optimization"
+
+# Only the first worker saves the final results
+if [ $SLURM_ARRAY_TASK_ID -eq 1 ]; then
+    echo "Worker 1 saving final results..."
+    sleep 60  # Wait for other workers to finish
+    cp $SHARED_DB $OUTPUT_DIR/
+    cp best_vae_config.yml $OUTPUT_DIR/ 2>/dev/null || echo "No best config generated yet"
+fi
+
+echo "Worker $SLURM_ARRAY_TASK_ID completed at: $(date)"
\ No newline at end of file
diff --git a/viscy/scripts/optimization/optuna_vae_search.py b/viscy/scripts/optimization/optuna_vae_search.py
new file mode 100644
index 000000000..34d27c9ef
--- /dev/null
+++ b/viscy/scripts/optimization/optuna_vae_search.py
@@ -0,0 +1,230 @@
+#!/usr/bin/env python3
+"""
+Optuna hyperparameter optimization for VAE training with PyTorch Lightning.
+"""
+
+import os
+import shutil
+import subprocess
+import tempfile
+from pathlib import Path
+from typing import Any, Dict
+
+import click
+import optuna
+import torch
+import yaml
+from tensorboard.backend.event_processing.event_accumulator import EventAccumulator
+
+
+def extract_best_val_loss(log_dir: str) -> float:
+    """Extract the best validation loss from TensorBoard logs."""
+    try:
+        # Find the events file
+        events_files = list(Path(log_dir).glob("events.out.tfevents.*"))
+        if not events_files:
+            print(f"Warning: No events file found in {log_dir}")
+            return float("inf")
+
+        # Load TensorBoard data
+        ea = EventAccumulator(str(events_files[0]))
+        ea.Reload()
+
+        # Extract validation loss
+        if "loss/total/val" in ea.Tags()["scalars"]:
+            val_losses = ea.Scalars("loss/total/val")
+            best_val_loss = min([scalar.value for scalar in val_losses])
+            print(f"Best validation loss: {best_val_loss:.6f}")
+            return best_val_loss
+        else:
+            print(f"Warning: No validation loss found in {log_dir}")
+            return float("inf")
+
+    except Exception as e:
+        print(f"Error extracting validation loss from {log_dir}: {e}")
+        return float("inf")
+
+
+def create_trial_config(
+    base_config_path: str, trial: optuna.Trial, trial_dir: Path
+) -> str:
+    """Create a modified config file for the current trial."""
+
+    # Load base config
+    with open(base_config_path, "r") as f:
+        config = yaml.safe_load(f)
+
+    # Sample hyperparameters
+    beta = trial.suggest_float("beta", 0.1, 50.0, log=True)
+    lr = trial.suggest_float("lr", 5e-5, 5e-3, log=True)
+    warmup_epochs = trial.suggest_int("warmup_epochs", 10, 100)
+    latent_dim = trial.suggest_categorical("latent_dim", [512, 1024, 2048])
+    batch_size = trial.suggest_categorical("batch_size", [32, 64, 128])
+
+    # Modify model config
+    config["model"]["init_args"]["beta"] = beta
+    config["model"]["init_args"]["lr"] = lr
+    config["model"]["init_args"]["beta_warmup_epochs"] = warmup_epochs
+    # Keep encoder and decoder latent dimensions in sync with the sampled value
+    config["model"]["init_args"]["encoder"]["init_args"]["latent_dim"] = latent_dim
+    config["model"]["init_args"]["decoder"]["init_args"]["latent_dim"] = latent_dim
+
+    # Modify data config
+    config["data"]["init_args"]["batch_size"] = batch_size
+
+    # Reduce training for faster search
+    config["trainer"]["max_epochs"] = 30
+    config["trainer"]["check_val_every_n_epoch"] = 2
+
+    # Set unique logging directory
+    config["trainer"]["logger"]["init_args"]["save_dir"] = str(trial_dir)
+    config["trainer"]["logger"]["init_args"]["version"] = f"trial_{trial.number}"
+
+    # Fix loss function to use mean reduction
+    config["model"]["init_args"]["loss_function"]["init_args"]["reduction"] = "mean"
+
+    # Save trial config
+    trial_config_path = trial_dir / f"trial_{trial.number}_config.yml"
+    with open(trial_config_path, "w") as f:
+        yaml.dump(config, f, default_flow_style=False)
+
+    print(
+        f"Trial {trial.number} params: beta={beta:.4f}, lr={lr:.2e}, "
+        f"warmup={warmup_epochs}, latent={latent_dim}, batch={batch_size}"
+    )
+
+    return str(trial_config_path)
+
+
+def objective(trial: optuna.Trial) -> float:
+    """Optuna objective function."""
+
+    # Create temporary directory for this trial
+    with tempfile.TemporaryDirectory(
+        prefix=f"optuna_trial_{trial.number}_"
+    ) as temp_dir:
+        trial_dir = Path(temp_dir)
+
+        try:
+            # Create trial config
+            base_config = "/hpc/projects/organelle_phenotyping/models/SEC61B/vae/fit_phase_only.yml"
+            trial_config_path = create_trial_config(base_config, trial, trial_dir)
+
+            # Run training
+            cmd = [
+                "python",
+                "-m",
+                "viscy.cli.train",
+                "fit",
+                "--config",
+                trial_config_path,
+            ]
+
+            print(f"Running trial {trial.number}: {' '.join(cmd)}")
+
+            # Run with timeout to prevent hanging
+            result = subprocess.run(
+                cmd,
+                cwd="/hpc/mydata/eduardo.hirata/repos/viscy",
+                capture_output=True,
+                text=True,
+                timeout=3600,  # 1 hour timeout
+            )
+
+            if result.returncode != 0:
+                print(
+                    f"Trial {trial.number} failed with return code {result.returncode}"
+                )
+                print(f"STDERR: {result.stderr}")
+                return float("inf")
+
+            # Extract validation loss
+            log_dir = trial_dir / f"trial_{trial.number}"
+            val_loss = extract_best_val_loss(str(log_dir))
+
+            print(f"Trial {trial.number} completed with val_loss: {val_loss:.6f}")
+            return val_loss
+
+        except subprocess.TimeoutExpired:
+            print(f"Trial {trial.number} timed out")
+            return float("inf")
+        except Exception as e:
+            print(f"Trial {trial.number} failed with error: {e}")
+            return float("inf")
+
+
+# CLI options mirror the flags passed by optuna_vae_parallel.sh and
+# optuna_vae_slurm.sh; --output_dir is accepted so those wrappers parse
+# cleanly (the shell scripts copy the results there themselves).
+@click.command()
+@click.option("--storage_url", default="sqlite:///optuna_vae_study.db")
+@click.option("--study_name", default="vae_hyperparameter_optimization")
+@click.option("--n_trials", default=50, type=int)
+@click.option("--timeout", default=24 * 3600, type=int)
+@click.option("--output_dir", default=".", help="Unused here; see the shell scripts")
+def main(storage_url, study_name, n_trials, timeout, output_dir):
+    """Main optimization loop."""
+
+    # Set up study
+    study = optuna.create_study(
+        study_name=study_name,
+        storage=storage_url,
+        direction="minimize",
+        load_if_exists=True,  # Resume if study exists
+        sampler=optuna.samplers.TPESampler(seed=42),
+        pruner=optuna.pruners.MedianPruner(n_startup_trials=5, n_warmup_steps=10),
+    )
+
+    print(
+        f"Starting Optuna optimization with {torch.cuda.device_count()} GPUs available"
+    )
+    print(f"Study storage: {storage_url}")
+
+    try:
+        # Run optimization
+        study.optimize(objective, n_trials=n_trials, timeout=timeout)
+
+        # Print results
+        print("\nOptimization completed!")
+        print(f"Best trial: {study.best_trial.number}")
+        print(f"Best value: {study.best_value:.6f}")
+        print("Best params:")
+        for key, value in study.best_params.items():
+            print(f"  {key}: {value}")
+
+        # Save best config
+        best_config_path = "best_vae_config.yml"
+        base_config = (
+            "/hpc/projects/organelle_phenotyping/models/SEC61B/vae/fit_phase_only.yml"
+        )
+
+        with open(base_config, "r") as f:
+            config = yaml.safe_load(f)
+
+        # Apply best parameters
+        best_params = study.best_params
+        config["model"]["init_args"]["beta"] = best_params["beta"]
+        config["model"]["init_args"]["lr"] = best_params["lr"]
+        config["model"]["init_args"]["beta_warmup_epochs"] = best_params[
+            "warmup_epochs"
+        ]
+        config["model"]["init_args"]["encoder"]["init_args"]["latent_dim"] = (
+            best_params["latent_dim"]
+        )
+        config["model"]["init_args"]["decoder"]["init_args"]["latent_dim"] = (
+            best_params["latent_dim"]
+        )
+        config["data"]["init_args"]["batch_size"] = best_params["batch_size"]
+        config["model"]["init_args"]["loss_function"]["init_args"]["reduction"] = "mean"
+
+        # Restore full training settings (each trial above ran only 30 epochs
+        # with check_val_every_n_epoch=2)
+        
config["trainer"]["max_epochs"] = 300 + config["trainer"]["check_val_every_n_epoch"] = 1 + + with open(best_config_path, "w") as f: + yaml.dump(config, f, default_flow_style=False) + + print(f"Best configuration saved to: {best_config_path}") + + except KeyboardInterrupt: + print("\nOptimization interrupted by user") + print( + f"Current best trial: {study.best_trial.number if study.best_trial else 'None'}" + ) + if study.best_trial: + print(f"Current best value: {study.best_value:.6f}") + + +if __name__ == "__main__": + main() diff --git a/viscy/scripts/optimization/optuna_vae_slurm.sh b/viscy/scripts/optimization/optuna_vae_slurm.sh new file mode 100644 index 000000000..0394aa0db --- /dev/null +++ b/viscy/scripts/optimization/optuna_vae_slurm.sh @@ -0,0 +1,47 @@ +#!/bin/bash +#SBATCH --job-name=optuna_vae_search +#SBATCH --output=/slurm_out/optuna_vae_%j.out +#SBATCH --error=/slurm_out/optuna_vae_%j.err +#SBATCH --time=0-20:00:00 +#SBATCH --partition=gpu +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=1 +#SBATCH --cpus-per-task=12 +#SBATCH --gres=gpu:1 +#SBATCH --mem=64G + +# Print job info +echo "Job ID: $SLURM_JOB_ID" +echo "Job Name: $SLURM_JOB_NAME" +echo "Node: $SLURM_NODELIST" +echo "Start Time: $(date)" + +# Load modules/environment +module load anaconda/25.3.1 +conda activate viscy + +# Set environment variables +export CUDA_VISIBLE_DEVICES=0 +export OMP_NUM_THREADS=$SLURM_CPUS_PER_TASK + +# Change to repo directory +ROOT_DIR="/home/eduardo.hirata/repos/viscy" +cd $ROOT_DIR + +# Create output directory for this job +OUTPUT_DIR="/hpc/projects/organelle_phenotyping/models/SEC61B/vae/optuna_results/job_${SLURM_JOB_ID}" +mkdir -p $OUTPUT_DIR + +# Run Optuna optimization +echo "Starting Optuna VAE hyperparameter search..." +python viscy/scripts/optimization/optuna_vae_search.py \ + --output_dir $OUTPUT_DIR \ + --n_trials 50 \ + --timeout 86400 + +# Copy results to output directory +cp optuna_vae_study.db $OUTPUT_DIR/ +cp best_vae_config.yml $OUTPUT_DIR/ 2>/dev/null || echo "No best config generated yet" + +echo "Job completed at: $(date)" +echo "Results saved to: $OUTPUT_DIR" \ No newline at end of file From bcc14063f9bfbaec5e513885dda506240133f8c6 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Sat, 26 Jul 2025 14:19:30 -0700 Subject: [PATCH 018/101] add normalized sampled into the transforms so we can use it with MONAIs vae --- viscy/transforms/__init__.py | 2 ++ viscy/transforms/_redef.py | 10 ++++++++++ 2 files changed, 12 insertions(+) diff --git a/viscy/transforms/__init__.py b/viscy/transforms/__init__.py index 12177b64b..495069184 100644 --- a/viscy/transforms/__init__.py +++ b/viscy/transforms/__init__.py @@ -1,5 +1,6 @@ from viscy.transforms._redef import ( CenterSpatialCropd, + NormalizeIntensityd, RandAdjustContrastd, RandAffined, RandFlipd, @@ -34,4 +35,5 @@ "ScaleIntensityRangePercentilesd", "StackChannelsd", "TiledSpatialCropSamplesd", + "NormalizeIntensityd", ] diff --git a/viscy/transforms/_redef.py b/viscy/transforms/_redef.py index fe168603a..d79a4aff9 100644 --- a/viscy/transforms/_redef.py +++ b/viscy/transforms/_redef.py @@ -4,6 +4,7 @@ from monai.transforms import ( CenterSpatialCropd, + NormalizeIntensityd, RandAdjustContrastd, RandAffined, RandFlipd, @@ -17,6 +18,15 @@ from numpy.typing import DTypeLike +class NormalizeIntensityd(NormalizeIntensityd): + def __init__( + self, + keys: Sequence[str] | str, + **kwargs, + ): + super().__init__(keys=keys, **kwargs) + + class RandWeightedCropd(RandWeightedCropd): def __init__( self, From 
53a3e2d560fea292b1b3e0719c056edfb5374490 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Sat, 26 Jul 2025 14:21:15 -0700 Subject: [PATCH 019/101] update loss debugging code --- .../DynaCLR/BetaVAE/debug_dimensions.py | 43 +++++++++++++------ 1 file changed, 31 insertions(+), 12 deletions(-) diff --git a/applications/benchmarking/DynaCLR/BetaVAE/debug_dimensions.py b/applications/benchmarking/DynaCLR/BetaVAE/debug_dimensions.py index aacedfe88..742ad5719 100644 --- a/applications/benchmarking/DynaCLR/BetaVAE/debug_dimensions.py +++ b/applications/benchmarking/DynaCLR/BetaVAE/debug_dimensions.py @@ -9,20 +9,29 @@ def debug_vae_dimensions(): print("=== VAE Dimension Debugging (Updated Architecture) ===\n") # Configuration matching current config - z_stack_depth = 16 - input_shape = (1, 1, z_stack_depth, 192, 192) # 1 channel to match config + z_stack_depth = 8 + input_shape = (1, 1, z_stack_depth, 128, 128) # 1 channel to match config latent_dim = 1024 # Updated to new default print(f"Input shape: {input_shape}") print(f"Expected latent dim: {latent_dim}") print() + # Debug encoder channel expectations + import timm + + debug_encoder = timm.create_model("resnet50", pretrained=False, features_only=True) + print(f"ResNet50 conv1.out_channels: {debug_encoder.conv1.out_channels}") + print(f"ResNet50 expects input channels: {debug_encoder.conv1.in_channels}") + print() + # Create encoder encoder = VaeEncoder( backbone="resnet50", in_channels=1, in_stack_depth=z_stack_depth, latent_dim=latent_dim, + input_spatial_size=(128, 128), # Match the actual input size stem_kernel_size=(4, 2, 2), stem_stride=(4, 2, 2), ) @@ -39,6 +48,10 @@ def debug_vae_dimensions(): conv_blocks=2, norm_name="batch", strides=[2, 2, 2, 1], + input_spatial_size=( + 128, + 128, + ), # Add input spatial size for correct spatial_size calculation ) print("=== ENCODER FORWARD PASS ===") @@ -63,22 +76,26 @@ def debug_vae_dimensions(): print(f" Final features: {x_final.shape}") # Flatten spatial dimensions (new approach) - batch_size = x_final.size(0) - x_flat = x_final.view(batch_size, -1) + x_flat = x_final.flatten(1) # Use flatten(1) like updated code print(f" After flatten: {x_flat.shape}") + print("\\n3b. Intermediate FC layer:") + # Test intermediate FC layer (new addition) + x_intermediate = encoder.fc(x_flat) + print(f" After intermediate FC: {x_intermediate.shape}") + # Full encoder output encoder_output = encoder(x) - mu = encoder_output.embedding + mu = encoder_output.mean logvar = encoder_output.log_covariance + z = encoder_output.z print(f" Final mu: {mu.shape}") print(f" Final logvar: {logvar.shape}") + print(f" Sampled z: {z.shape}") print("\\n=== DECODER FORWARD PASS ===") - # Test decoder with latent vector - z = torch.randn(1, latent_dim) - print(f"Input to decoder: {z.shape}") + print(f"Input to decoder (sampled z): {z.shape}") print("\\n1. 
Reshape to spatial:") batch_size = z.size(0) @@ -155,9 +172,11 @@ def debug_vae_dimensions(): z_sampled = mu + torch.exp(0.5 * logvar) * eps print(f"Sampled latent z: {z_sampled.shape}") - # Decode the sampled latent - reconstruction_from_sampled = decoder(z_sampled) - print(f"Reconstruction from sampled z: {reconstruction_from_sampled.shape}") + # Use the z from encoder (already sampled) + reconstruction_from_sampled = decoder(z) + print( + f"Reconstruction from encoder's sampled z: {reconstruction_from_sampled.shape}" + ) # Compute VAE losses import torch.nn.functional as F @@ -243,7 +262,7 @@ def debug_vae_dimensions(): final_feat = features[-1] print(f"Final feature shape: {final_feat.shape}") - flattened_size = final_feat.view(1, -1).shape[1] + flattened_size = final_feat.flatten(1).shape[1] print(f"Flattened size: {flattened_size:,}") print(f"Expected latent dim: {latent_dim:,}") From a3510d0cc39efb777e131e3b834d8930426d0483 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Sat, 26 Jul 2025 22:21:51 -0700 Subject: [PATCH 020/101] adding sync for disentaglement metrics --- viscy/representation/vae_logging.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/viscy/representation/vae_logging.py b/viscy/representation/vae_logging.py index a5b67be24..b97042bfb 100644 --- a/viscy/representation/vae_logging.py +++ b/viscy/representation/vae_logging.py @@ -478,6 +478,7 @@ def log_disentanglement_metrics( lightning_module, dataloader: torch.utils.data.DataLoader, max_samples: int = 500, + sync_dist: bool = True, ): """ Log disentanglement metrics to TensorBoard every 10 epochs. @@ -501,7 +502,10 @@ def log_disentanglement_metrics( # Compute all disentanglement metrics metrics = self.disentanglement_metrics.compute_all_metrics( - vae_model=vae_model, dataloader=dataloader, max_samples=max_samples + vae_model=vae_model, + dataloader=dataloader, + max_samples=max_samples, + sync_dist=sync_dist, ) # Log metrics with organized naming @@ -514,7 +518,7 @@ def log_disentanglement_metrics( on_step=False, on_epoch=True, logger=True, - sync_dist=True, + sync_dist=sync_dist, ) _logger.info(f"Logged disentanglement metrics: {metrics}") @@ -528,4 +532,5 @@ def log_disentanglement_metrics( on_step=False, on_epoch=True, logger=True, + sync_dist=sync_dist, ) From 252a4d0c2d4481f302869526ed626a9ffd8cfa30 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Thu, 31 Jul 2025 10:26:30 -0700 Subject: [PATCH 021/101] adding the dataloader for rpe1 dataset and plotting utils --- .../rpe1_fucci/linear_classifier.py | 142 +++++++ .../evaluation/rpe1_fucci/phate_plot.py | 172 ++++++++ viscy/data/cell_division_triplet.py | 395 ++++++++++++++++++ 3 files changed, 709 insertions(+) create mode 100644 applications/contrastive_phenotyping/evaluation/rpe1_fucci/linear_classifier.py create mode 100644 applications/contrastive_phenotyping/evaluation/rpe1_fucci/phate_plot.py create mode 100644 viscy/data/cell_division_triplet.py diff --git a/applications/contrastive_phenotyping/evaluation/rpe1_fucci/linear_classifier.py b/applications/contrastive_phenotyping/evaluation/rpe1_fucci/linear_classifier.py new file mode 100644 index 000000000..d3b984d6e --- /dev/null +++ b/applications/contrastive_phenotyping/evaluation/rpe1_fucci/linear_classifier.py @@ -0,0 +1,142 @@ +# %% +from pathlib import Path + +import numpy as np +import pandas as pd +from sklearn.linear_model import LogisticRegression +from sklearn.metrics import accuracy_score, classification_report +from sklearn.model_selection 
import train_test_split + +from viscy.representation.embedding_writer import read_embedding_dataset + +test_data_features_path = Path( + "/hpc/projects/intracellular_dashboard/organelle_dynamics/2025_rpe_fucci_leger_weigert/0-phenotyping/rpe_fucci_test_data_ckpt264.zarr" +) +cell_cycle_labels_path = "/hpc/projects/organelle_phenotyping/models/rpe_fucci/pseudolabels/cell_cycle_labels.csv" + +# %% +# Load the data +cell_cycle_labels_df = pd.read_csv(cell_cycle_labels_path, dtype={"dataset_name": str}) +test_embeddings = read_embedding_dataset(test_data_features_path) + +# Extract features (768-dimensional embeddings) +features = test_embeddings.features.values + +# %% +# Create a combined identifier for matching +# The sample coordinate contains (fov_name, id) tuples +sample_coords = test_embeddings.coords["sample"].values +fov_names = [coord[0] for coord in sample_coords] +ids = [coord[1] for coord in sample_coords] + +# Create DataFrame with embeddings and identifiers +embedding_df = pd.DataFrame( + { + "dataset_name": fov_names, + "timepoint": ids, + } +) + +# Merge with cell cycle labels +merged_data = embedding_df.merge( + cell_cycle_labels_df, on=["dataset_name", "timepoint"], how="inner" +) + +print(f"Original embeddings: {len(embedding_df)}") +print(f"Cell cycle labels: {len(cell_cycle_labels_df)}") +print(f"Merged data: {len(merged_data)}") +print(f"Cell cycle distribution:\n{merged_data['cell_cycle_state'].value_counts()}") + +# Get corresponding features for merged samples +merged_indices = merged_data.index.values +X = features[merged_indices] +y = merged_data["cell_cycle_state"].values + +# %% +# First split: 80% train+val, 20% test +X_train, X_test, y_train, y_test = train_test_split( + X, y, test_size=0.2, random_state=42, stratify=y +) +print(f"Training set: {X_train.shape[0]} samples") +print(f"Test set: {X_test.shape[0]} samples") + +# %% +# Train logistic regression model +clf = LogisticRegression(random_state=42, max_iter=1000) +clf.fit(X_train, y_train) + +y_test_pred = clf.predict(X_test) +test_accuracy = accuracy_score(y_test, y_test_pred) +print(f"Test accuracy: {test_accuracy:.4f}") + +print("\nTest set classification report:") +print(classification_report(y_test, y_test_pred)) + +# %% +# Enhanced evaluation and visualization +import matplotlib.pyplot as plt +import seaborn as sns +from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay + +# 1. Confusion Matrix - shows which classes are confused with each other +cm = confusion_matrix(y_test, y_test_pred) +plt.figure(figsize=(8, 6)) +ConfusionMatrixDisplay(cm, display_labels=["G1", "G2", "S"]).plot(cmap="Blues") +plt.title("Confusion Matrix") +plt.show() + +# 2. Per-class errors breakdown +print("\nDetailed per-class analysis:") +for class_name in ["G1", "G2", "S"]: + mask = y_test == class_name + correct = (y_test_pred[mask] == class_name).sum() + total = mask.sum() + print(f"{class_name}: {correct}/{total} correct ({correct/total:.3f})") + + # Show what this class was misclassified as + if total > correct: + wrong_preds = y_test_pred[mask & (y_test_pred != class_name)] + unique, counts = np.unique(wrong_preds, return_counts=True) + print(f" Misclassified as: {dict(zip(unique, counts))}") + +# 3. 
Prediction confidence (probabilities) +y_test_proba = clf.predict_proba(X_test) +class_names = clf.classes_ + +plt.figure(figsize=(12, 4)) +for i, class_name in enumerate(class_names): + plt.subplot(1, 3, i + 1) + plt.hist( + y_test_proba[:, i], bins=20, alpha=0.7, color=["blue", "orange", "green"][i] + ) + plt.title(f"Confidence for {class_name}") + plt.xlabel("Probability") + plt.ylabel("Count") +plt.tight_layout() +plt.show() + +# 4. Most confident correct and incorrect predictions +print("\nMost confident predictions:") +max_proba = np.max(y_test_proba, axis=1) +pred_correct = y_test == y_test_pred + +# Most confident correct predictions +correct_idx = np.where(pred_correct)[0] +most_confident_correct = correct_idx[np.argsort(max_proba[correct_idx])[-5:]] +print("Top 5 most confident CORRECT predictions:") +for idx in most_confident_correct: + print( + f" True: {y_test[idx]}, Pred: {y_test_pred[idx]}, Confidence: {max_proba[idx]:.3f}" + ) + +# Most confident incorrect predictions +incorrect_idx = np.where(~pred_correct)[0] +if len(incorrect_idx) > 0: + most_confident_wrong = incorrect_idx[np.argsort(max_proba[incorrect_idx])[-5:]] + print("\nTop 5 most confident WRONG predictions:") + for idx in most_confident_wrong: + print( + f" True: {y_test[idx]}, Pred: {y_test_pred[idx]}, Confidence: {max_proba[idx]:.3f}" + ) + +# %% diff --git a/applications/contrastive_phenotyping/evaluation/rpe1_fucci/phate_plot.py b/applications/contrastive_phenotyping/evaluation/rpe1_fucci/phate_plot.py new file mode 100644 index 000000000..e7c3b54cf --- /dev/null +++ b/applications/contrastive_phenotyping/evaluation/rpe1_fucci/phate_plot.py @@ -0,0 +1,172 @@ +# %% Imports +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import seaborn as sns + +from viscy.representation.embedding_writer import read_embedding_dataset +from viscy.representation.evaluation.dimensionality_reduction import compute_phate + +# %% +test_data_features_path = Path( + "/hpc/projects/intracellular_dashboard/organelle_dynamics/2025_rpe_fucci_leger_weigert/0-phenotyping/rpe_fucci_test_data_ckpt264.zarr" +) +test_drugs_path = Path( + "/hpc/projects/intracellular_dashboard/organelle_dynamics/2025_rpe_fucci_leger_weigert/0-phenotyping/rpe_fucci_test_drugs_ckpt264.zarr" +) +cell_cycle_labels_path = "/hpc/projects/organelle_phenotyping/models/rpe_fucci/pseudolabels/cell_cycle_labels.csv" + +# %% Load embeddings and annotations. 
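+# The embedding dataset indexes samples by (fov_name, id) tuples; these are
+# unpacked below and inner-joined with the (dataset_name, timepoint) columns
+# of the pseudolabel CSV, so only labeled samples survive the merge.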
+ +test_features = read_embedding_dataset(test_data_features_path) +# test_drugs = read_embedding_dataset(test_drugs_path) + +# Load cell cycle labels +cell_cycle_labels_df = pd.read_csv(cell_cycle_labels_path, dtype={"dataset_name": str}) + +# Create a combined identifier for matching +sample_coords = test_features.coords["sample"].values +fov_names = [coord[0] for coord in sample_coords] +ids = [coord[1] for coord in sample_coords] + +# Create DataFrame with embeddings and identifiers +embedding_df = pd.DataFrame( + { + "dataset_name": fov_names, + "timepoint": ids, + } +) + +# Merge with cell cycle labels +merged_data = embedding_df.merge( + cell_cycle_labels_df, on=["dataset_name", "timepoint"], how="inner" +) + +print(f"Original embeddings: {len(embedding_df)}") +print(f"Cell cycle labels: {len(cell_cycle_labels_df)}") +print(f"Merged data: {len(merged_data)}") +print(f"Cell cycle distribution:\n{merged_data['cell_cycle_state'].value_counts()}") + +# Get corresponding features for merged samples +merged_indices = merged_data.index.values +cell_cycle_states = merged_data["cell_cycle_state"].values + +# %% +# compute phate +phate_kwargs = { + "knn": 10, + "decay": 20, + "n_components": 2, + "gamma": 1, + "t": "auto", + "n_jobs": -1, +} + +phate_model, phate_embedding = compute_phate(test_features, **phate_kwargs) +# %% + +# Define colorblind-friendly palette for cell cycle states (blue/orange as requested) +cycle_colors = {"G1": "#1f77b4", "G2": "#ff7f0e", "S": "#9467bd"} + +plt.figure(figsize=(10, 10)) +sns.scatterplot( + x=phate_embedding[merged_indices, 0], + y=phate_embedding[merged_indices, 1], + hue=cell_cycle_states, + palette=cycle_colors, + alpha=0.6, +) +plt.title("PHATE Embedding Colored by Cell Cycle State") +plt.legend(bbox_to_anchor=(1.05, 1), loc="upper left") + + +# %% +# Plot the PHATE embedding from the xarray + +plt.figure(figsize=(10, 10)) +sns.scatterplot( + x=test_features["PHATE1"][merged_indices], + y=test_features["PHATE2"][merged_indices], + hue=cell_cycle_states, + palette=cycle_colors, + alpha=0.6, +) +plt.title("PHATE1 vs PHATE2 Colored by Cell Cycle State") +plt.legend(bbox_to_anchor=(1.05, 1), loc="upper left") +# %% +# plot the 3D PHATE embedding (Note: seaborn scatterplot doesn't support 3D, using matplotlib) + +fig = plt.figure(figsize=(10, 10)) +ax = fig.add_subplot(111, projection='3d') + +for state in ["G1", "G2", "S"]: + mask = cell_cycle_states == state + ax.scatter( + test_features["PHATE1"][merged_indices][mask], + test_features["PHATE2"][merged_indices][mask], + test_features["PHATE3"][merged_indices][mask], + c=cycle_colors[state], + alpha=0.6, + label=state + ) + +ax.set_xlabel("PHATE1") +ax.set_ylabel("PHATE2") +ax.set_zlabel("PHATE3") +ax.set_title("3D PHATE Embedding Colored by Cell Cycle State") +ax.legend() + +# %% +# Plot the PHATE embedding from test_drugs (commented out since not loaded) +# plt.figure(figsize=(10, 10)) +# sns.scatterplot( +# x=test_drugs["PHATE1"], +# y=test_drugs["PHATE2"], +# # hue=test_drugs["t"], +# alpha=0.5, +# ) +# plt.legend(bbox_to_anchor=(1.05, 1), loc="upper left") +# %% +fig, axes = plt.subplots(1, 3, figsize=(18, 6)) + +# PHATE1 vs PHATE2 +sns.scatterplot( + x=test_features["PHATE1"][merged_indices], + y=test_features["PHATE2"][merged_indices], + hue=cell_cycle_states, + palette=cycle_colors, + alpha=0.6, + ax=axes[0], +) +axes[0].set_title("PHATE1 vs PHATE2") +axes[0].legend(bbox_to_anchor=(1.05, 1), loc="upper left") + +# PHATE1 vs PHATE3 +sns.scatterplot( + x=test_features["PHATE1"][merged_indices], + 
y=test_features["PHATE3"][merged_indices], + hue=cell_cycle_states, + palette=cycle_colors, + alpha=0.6, + ax=axes[1], +) +axes[1].set_title("PHATE1 vs PHATE3") +axes[1].legend(bbox_to_anchor=(1.05, 1), loc="upper left") + +# PHATE2 vs PHATE3 +sns.scatterplot( + x=test_features["PHATE2"][merged_indices], + y=test_features["PHATE3"][merged_indices], + hue=cell_cycle_states, + palette=cycle_colors, + alpha=0.6, + ax=axes[2], +) +axes[2].set_title("PHATE2 vs PHATE3") +axes[2].legend(bbox_to_anchor=(1.05, 1), loc="upper left") + +plt.tight_layout() +plt.show() +# %% diff --git a/viscy/data/cell_division_triplet.py b/viscy/data/cell_division_triplet.py new file mode 100644 index 000000000..4c1b1383b --- /dev/null +++ b/viscy/data/cell_division_triplet.py @@ -0,0 +1,395 @@ +import logging +import random +from pathlib import Path +from typing import Literal, Sequence + +import numpy as np +import pandas as pd +import torch +from monai.transforms import Compose, MapTransform +from natsort import natsorted +from torch import Tensor +from torch.utils.data import Dataset + +from viscy.data.hcs import HCSDataModule +from viscy.data.triplet import ( + _gather_channels, + _scatter_channels, + _transform_channel_wise, +) +from viscy.data.typing import DictTransform, TripletSample + +_logger = logging.getLogger("lightning.pytorch") + + +class CellDivisionTripletDataset(Dataset): + def __init__( + self, + data_paths: list[Path], + channel_names: list[str], + anchor_transform: DictTransform | None = None, + positive_transform: DictTransform | None = None, + negative_transform: DictTransform | None = None, + fit: bool = True, + time_interval: Literal["any"] | int = "any", + return_negative: bool = True, + ) -> None: + """Dataset for triplet sampling of cell division data from npy files. + + Parameters + ---------- + data_paths : list[Path] + List of paths to npy files containing cell division tracks (T,C,Y,X format) + channel_names : list[str] + Input channel names + anchor_transform : DictTransform | None, optional + Transforms applied to the anchor sample, by default None + positive_transform : DictTransform | None, optional + Transforms applied to the positive sample, by default None + negative_transform : DictTransform | None, optional + Transforms applied to the negative sample, by default None + fit : bool, optional + Fitting mode in which the full triplet will be sampled, + only sample anchor if False, by default True + time_interval : Literal["any"] | int, optional + Future time interval to sample positive and anchor from, + by default "any" + return_negative : bool, optional + Whether to return the negative sample during the fit stage, by default True + """ + self.channel_names = channel_names + self.anchor_transform = anchor_transform + self.positive_transform = positive_transform + self.negative_transform = negative_transform + self.fit = fit + self.time_interval = time_interval + self.return_negative = return_negative + + # Load and process all data files + self.cell_tracks = self._load_data(data_paths) + self.valid_anchors = self._filter_anchors() + + def _load_data(self, data_paths: list[Path]) -> list[dict]: + """Load npy files.""" + all_tracks = [] + + for path in data_paths: + data = np.load(path) # Shape: (T, C, Y, X) + T, C, Y, X = data.shape + + # Create track info for this file + # NOTE: using the filename as track ID as UID. 
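+            # A hypothetical file "track_0001.npy" yields track_id "track_0001",
+            # so file stems must be unique across data_paths for negative
+            # sampling to distinguish tracks.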
+ track_info = { + "data": torch.from_numpy(data.astype(np.float32)), + "file_path": str(path), + "track_id": path.stem, + "num_timepoints": T, + "shape": (T, C, Y, X), + } + all_tracks.append(track_info) + + _logger.info(f"Loaded {len(all_tracks)} tracks") + return all_tracks + + def _filter_anchors(self) -> list[dict]: + """Create valid anchor points based on time interval constraints.""" + valid_anchors = [] + + for track in self.cell_tracks: + num_timepoints = track["num_timepoints"] + + if self.time_interval == "any" or not self.fit: + valid_timepoints = list(range(num_timepoints)) + else: + # Only timepoints that have a future timepoint at the specified interval + valid_timepoints = list(range(num_timepoints - self.time_interval)) + + for t in valid_timepoints: + anchor_info = { + "track": track, + "timepoint": t, + "track_id": track["track_id"], + "file_path": track["file_path"], + } + valid_anchors.append(anchor_info) + + return valid_anchors + + def __len__(self) -> int: + return len(self.valid_anchors) + + def _sample_positive(self, anchor_info: dict) -> Tensor: + """Select a positive sample from the same track.""" + track = anchor_info["track"] + anchor_t = anchor_info["timepoint"] + + if self.time_interval == "any": + # Use the same anchor patch (will be augmented differently) + positive_t = anchor_t + else: + # Use future timepoint + positive_t = anchor_t + self.time_interval + + positive_patch = track["data"][positive_t] # Shape: (C, Y, X) + # Add depth dimension: (C, Y, X) -> (C, D=1, Y, X) + positive_patch = positive_patch.unsqueeze(1) # Shape: (C, 1, Y, X) + return positive_patch + + def _sample_negative(self, anchor_info: dict) -> Tensor: + """Select a negative sample from a different track.""" + anchor_track_id = anchor_info["track_id"] + + negative_candidates = [ + t for t in self.cell_tracks if t["track_id"] != anchor_track_id + ] + + if not negative_candidates: + # Fallback: use different timepoint from same track + track = anchor_info["track"] + anchor_t = anchor_info["timepoint"] + available_times = [ + t for t in range(track["num_timepoints"]) if t != anchor_t + ] + if available_times: + neg_t = random.choice(available_times) + negative_patch = track["data"][neg_t] + else: + # Ultimate fallback: use same patch (transforms will differentiate) + negative_patch = track["data"][anchor_t] + else: + # Sample from different track + neg_track = random.choice(negative_candidates) + + if self.time_interval == "any": + neg_t = random.randint(0, neg_track["num_timepoints"] - 1) + else: + # Try to use same relative timepoint, fallback to random + anchor_t = anchor_info["timepoint"] + target_t = anchor_t + self.time_interval + if target_t < neg_track["num_timepoints"]: + neg_t = target_t + else: + neg_t = random.randint(0, neg_track["num_timepoints"] - 1) + + negative_patch = neg_track["data"][neg_t] + + # Add depth dimension: (C, Y, X) -> (C, D=1, Y, X) + negative_patch = negative_patch.unsqueeze(1) # Shape: (C, 1, Y, X) + return negative_patch + + def __getitem__(self, index: int) -> TripletSample: + anchor_info = self.valid_anchors[index] + track = anchor_info["track"] + anchor_t = anchor_info["timepoint"] + + # Get anchor patch and add depth dimension + anchor_patch = track["data"][anchor_t] # Shape: (C, Y, X) + anchor_patch = anchor_patch.unsqueeze(1) # Shape: (C, 1, Y, X) + + sample = {"anchor": anchor_patch} + + if self.fit: + positive_patch = self._sample_positive(anchor_info) + + if self.positive_transform: + positive_patch = _transform_channel_wise( + 
transform=self.positive_transform, + channel_names=self.channel_names, + patch=positive_patch, + norm_meta=None, + ) + + if self.return_negative: + negative_patch = self._sample_negative(anchor_info) + + if self.negative_transform: + negative_patch = _transform_channel_wise( + transform=self.negative_transform, + channel_names=self.channel_names, + patch=negative_patch, + norm_meta=None, + ) + + sample.update({"positive": positive_patch, "negative": negative_patch}) + else: + sample.update({"positive": positive_patch}) + else: + # For prediction mode, include index information + index_dict = { + "fov_name": anchor_info["track_id"], + "id": anchor_t, + } + sample.update({"index": index_dict}) + + if self.anchor_transform: + sample["anchor"] = _transform_channel_wise( + transform=self.anchor_transform, + channel_names=self.channel_names, + patch=sample["anchor"], + norm_meta=None, + ) + + return sample + + +class CellDivisionTripletDataModule(HCSDataModule): + def __init__( + self, + data_path: str, + source_channel: str | Sequence[str], + final_yx_patch_size: tuple[int, int] = (64, 64), # Match dataset size + split_ratio: float = 0.8, + batch_size: int = 16, + num_workers: int = 8, + normalizations: list[MapTransform] = [], + augmentations: list[MapTransform] = [], + augment_validation: bool = True, + time_interval: Literal["any"] | int = "any", + return_negative: bool = True, + persistent_workers: bool = False, + prefetch_factor: int | None = None, + pin_memory: bool = False, + ): + """Lightning data module for cell division triplet sampling. + + Parameters + ---------- + data_path : str + Path to directory containing npy files + source_channel : str | Sequence[str] + List of input channel names + final_yx_patch_size : tuple[int, int], optional + Output patch size, by default (64, 64) + split_ratio : float, optional + Ratio of training samples, by default 0.8 + batch_size : int, optional + Batch size, by default 16 + num_workers : int, optional + Number of data-loading workers, by default 8 + normalizations : list[MapTransform], optional + Normalization transforms, by default [] + augmentations : list[MapTransform], optional + Augmentation transforms, by default [] + augment_validation : bool, optional + Apply augmentations to validation data, by default True + time_interval : Literal["any"] | int, optional + Future time interval to sample positive and anchor from, by default "any" + return_negative : bool, optional + Whether to return the negative sample during the fit stage, by default True + persistent_workers : bool, optional + Whether to keep worker processes alive between iterations, by default False + prefetch_factor : int | None, optional + Number of batches loaded in advance by each worker, by default None + pin_memory : bool, optional + Whether to pin memory in CPU for faster GPU transfer, by default False + """ + # Initialize parent class with minimal required parameters + super().__init__( + data_path=data_path, + source_channel=source_channel, + target_channel=[], + z_window_size=1, + split_ratio=split_ratio, + batch_size=batch_size, + num_workers=num_workers, + target_2d=False, # Set to False since we're adding depth dimension + yx_patch_size=final_yx_patch_size, + normalizations=normalizations, + augmentations=augmentations, + caching=False, # NOTE: Not applicable for npy files + persistent_workers=persistent_workers, + prefetch_factor=prefetch_factor, + pin_memory=pin_memory, + ) + self.split_ratio = split_ratio + self.data_path = Path(data_path) + self.time_interval = 
time_interval
+        self.return_negative = return_negative
+        self.augment_validation = augment_validation
+
+        # Find all npy files in the data directory
+        self.npy_files = list(self.data_path.glob("*.npy"))
+        if not self.npy_files:
+            raise ValueError(f"No .npy files found in {data_path}")
+
+        _logger.info(f"Found {len(self.npy_files)} .npy files in {data_path}")
+
+    @property
+    def _base_dataset_settings(self) -> dict:
+        return {
+            "channel_names": self.source_channel,
+            "time_interval": self.time_interval,
+        }
+
+    def _setup_fit(self, dataset_settings: dict):
+        augment_transform, no_aug_transform = self._fit_transform()
+
+        # Shuffle and split the npy files
+        shuffled_indices = self._set_fit_global_state(len(self.npy_files))
+        npy_files = [self.npy_files[i] for i in shuffled_indices]
+
+        # Set the train and val file splits
+        num_train_files = int(len(self.npy_files) * self.split_ratio)
+        train_npy_files = npy_files[:num_train_files]
+        val_npy_files = npy_files[num_train_files:]
+
+        _logger.debug(f"Number of training files: {len(train_npy_files)}")
+        _logger.debug(f"Number of validation files: {len(val_npy_files)}")
+
+        # Determine anchor transform based on time interval
+        anchor_transform = (
+            no_aug_transform
+            if (self.time_interval == "any" or self.time_interval == 0)
+            else augment_transform
+        )
+
+        # Create training dataset
+        self.train_dataset = CellDivisionTripletDataset(
+            data_paths=train_npy_files,
+            anchor_transform=anchor_transform,
+            positive_transform=augment_transform,
+            negative_transform=augment_transform,
+            fit=True,
+            return_negative=self.return_negative,
+            **dataset_settings,
+        )
+
+        # Choose transforms for validation based on augment_validation parameter
+        val_positive_transform = (
+            augment_transform if self.augment_validation else no_aug_transform
+        )
+        val_negative_transform = (
+            augment_transform if self.augment_validation else no_aug_transform
+        )
+        val_anchor_transform = (
+            anchor_transform if self.augment_validation else no_aug_transform
+        )
+
+        # Create validation dataset
+        self.val_dataset = CellDivisionTripletDataset(
+            data_paths=val_npy_files,
+            anchor_transform=val_anchor_transform,
+            positive_transform=val_positive_transform,
+            negative_transform=val_negative_transform,
+            fit=True,
+            return_negative=self.return_negative,
+            **dataset_settings,
+        )
+
+        _logger.info(f"Training dataset size: {len(self.train_dataset)}")
+        _logger.info(f"Validation dataset size: {len(self.val_dataset)}")
+
+    def _setup_predict(self, dataset_settings: dict):
+        self._set_predict_global_state()
+
+        # For prediction, use all data
+        self.predict_dataset = CellDivisionTripletDataset(
+            data_paths=self.npy_files,
+            anchor_transform=Compose(self.normalizations),
+            fit=False,
+            **dataset_settings,
+        )
+
+    def _setup_test(self, *args, **kwargs):
+        raise NotImplementedError("Self-supervised model does not support testing")

From 385322b9842c9477c818d83a85334c18a2dfab29 Mon Sep 17 00:00:00 2001
From: Eduardo Hirata-Miyasaki
Date: Sat, 2 Aug 2025 13:03:29 -0700
Subject: [PATCH 022/101] cleanup the vae and add the monai to lightning.
adding configs --- .../DynaCLR/BetaVAE/config_betavae.yml | 130 +++++++++ .../BetaVAE/config_betavae_convnext.yml | 146 ++++++++++ viscy/representation/engine.py | 82 +++--- viscy/representation/vae.py | 262 ++++++++++++------ 4 files changed, 499 insertions(+), 121 deletions(-) create mode 100644 applications/benchmarking/DynaCLR/BetaVAE/config_betavae.yml create mode 100644 applications/benchmarking/DynaCLR/BetaVAE/config_betavae_convnext.yml diff --git a/applications/benchmarking/DynaCLR/BetaVAE/config_betavae.yml b/applications/benchmarking/DynaCLR/BetaVAE/config_betavae.yml new file mode 100644 index 000000000..93784faf1 --- /dev/null +++ b/applications/benchmarking/DynaCLR/BetaVAE/config_betavae.yml @@ -0,0 +1,130 @@ +seed_everything: 42 +trainer: + accelerator: gpu + devices: 1 + num_nodes: 1 + strategy: auto + precision: 16-mixed + max_epochs: 200 + log_every_n_steps: 10 + check_val_every_n_epoch: 1 + logger: + class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: "/hpc/projects/organelle_phenotyping/models/SEC61B/vae" + version: "sensor_phase3d_zikv_denv_lr2e-4_beta1.5" + log_graph: false + callbacks: + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: "loss/total/val" + save_top_k: 5 + save_last: true + every_n_epochs: 1 + fast_dev_run: true + enable_checkpointing: true + # inference_mode: true + use_distributed_sampler: true + +model: + class_path: viscy.representation.engine.BetaVaeModule + init_args: + architecture: "monai_beta" + model_config: + spatial_dims: 3 + in_shape: [2, 16, 192, 192] + out_channels: 2 + latent_size: 1024 + channels: [64, 128, 256, 512] + strides: [[2, 2, 2], [2, 2, 2], [2, 2, 2], [1, 2, 2]] + beta: 1.0 # Conservative target - can increase later + beta_schedule: cosine + beta_min: 0.1 # Start low to learn reconstructions first + beta_warmup_epochs: 50 # Half of training for gradual ramp + lr: 0.0002 + example_input_array_shape: [1, 2, 16, 192, 192] + loss_function: + class_path: torch.nn.MSELoss + init_args: {reduction: 'mean'} +data: + class_path: viscy.data.triplet.TripletDataModule + init_args: + data_path: "/hpc/projects/organelle_phenotyping/datasets/organelle/SEC61B/2024_10_16_A549_SEC61_ZIKV_DENV/2024_10_16_A549_SEC61_ZIKV_DENV_2.zarr" + tracks_path: "/hpc/projects/intracellular_dashboard/organelle_dynamics/rerun/2024_10_16_A549_SEC61_ZIKV_DENV/1-preprocess/label-free/3-track/2024_10_16_A549_SEC61_ZIKV_DENV_cropped.zarr" + source_channel: + - &phase Phase3D + - &mcherry raw mCherry EX561 EM600-37 + z_range: [10, 26] + initial_yx_patch_size: [384, 384] + final_yx_patch_size: [192, 192] + batch_size: 64 + num_workers: 12 + time_interval: 1 + augment_validation: false + return_negative: false + fit_include_wells: ["B/3", "B/4", "C/3", "C/4"] + augmentations: + - class_path: viscy.transforms.RandAffined + init_args: + keys: [*phase, *mcherry] + prob: 0.8 + scale_range: [0, 0.2, 0.2] + rotate_range: [3.14, 0.0, 0.0] + shear_range: [0.0, 0.01, 0.01] + padding_mode: zeros + - class_path: viscy.transforms.RandAdjustContrastd + init_args: + keys: [*mcherry] + prob: 0.5 + gamma: [0.8, 1.2] + - class_path: viscy.transforms.RandAdjustContrastd + init_args: + keys: [*phase] + prob: 0.5 + gamma: [0.8, 1.2] + - class_path: viscy.transforms.RandScaleIntensityd + init_args: + keys: [*mcherry] + prob: 0.5 + factors: 0.5 + - class_path: viscy.transforms.RandScaleIntensityd + init_args: + keys: [*phase] + 
prob: 0.5 + factors: 0.5 + - class_path: viscy.transforms.RandGaussianSmoothd + init_args: + keys: [*phase, *mcherry] + prob: 0.5 + sigma_x: [0.25, 0.75] + sigma_y: [0.25, 0.75] + sigma_z: [0.0, 0.0] + - class_path: viscy.transforms.RandGaussianNoised + init_args: + keys: [*mcherry] + prob: 0.5 + mean: 0.0 + std: 0.2 + - class_path: viscy.transforms.RandGaussianNoised + init_args: + keys: [*phase] + prob: 0.5 + mean: 0.0 + std: 0.2 + normalizations: + - class_path: viscy.transforms.NormalizeSampled + init_args: + keys: [*phase] + level: fov_statistics + subtrahend: mean + divisor: std + - class_path: viscy.transforms.ScaleIntensityRangePercentilesd + init_args: + keys: [*mcherry] + lower: 50 + upper: 99 + b_min: 0.0 + b_max: 1.0 \ No newline at end of file diff --git a/applications/benchmarking/DynaCLR/BetaVAE/config_betavae_convnext.yml b/applications/benchmarking/DynaCLR/BetaVAE/config_betavae_convnext.yml new file mode 100644 index 000000000..c586855e0 --- /dev/null +++ b/applications/benchmarking/DynaCLR/BetaVAE/config_betavae_convnext.yml @@ -0,0 +1,146 @@ +seed_everything: 42 +trainer: + accelerator: gpu + devices: 1 + num_nodes: 1 + strategy: auto + precision: 16-mixed + max_epochs: 200 + log_every_n_steps: 10 + check_val_every_n_epoch: 1 + logger: + class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: "/hpc/projects/organelle_phenotyping/models/SEC61B/vae" + version: "sensor_phase3d_zikv_denv_lr2e-4_beta1.5" + log_graph: false + callbacks: + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: "loss/total/val" + save_top_k: 5 + save_last: true + every_n_epochs: 1 + fast_dev_run: true + enable_checkpointing: true + # inference_mode: true + use_distributed_sampler: true + +model: + class_path: viscy.representation.engine.BetaVaeModule + init_args: + architecture: "2.5D" + model_config: + backbone: convnext_tiny + in_channels: 2 + in_stack_depth: 16 + out_stack_depth: 16 + latent_dim: 1024 + input_spatial_size: [192, 192] + stem_kernel_size: [4, 2, 2] + stem_stride: [4, 2, 2] + decoder_stages: 4 + head_expansion_ratio: 2 + head_pool: false + upsample_mode: pixelshuffle + conv_blocks: 2 + norm_name: batch + beta: 1.0 # Conservative target - can increase later + beta_schedule: cosine + beta_min: 0.1 # Start low to learn reconstructions first + beta_warmup_epochs: 50 # Half of training for gradual ramp + lr: 0.0002 + example_input_array_shape: [1, 2, 16, 192, 192] + loss_function: + class_path: torch.nn.MSELoss + init_args: + reduction: mean + log_batches_per_epoch: 8 + log_samples_per_batch: 1 + compute_disentanglement: false + disentanglement_frequency: 10 + log_enhanced_visualizations: false + log_enhanced_visualizations_frequency: 30 + +data: + class_path: viscy.data.triplet.TripletDataModule + init_args: + data_path: "/hpc/projects/organelle_phenotyping/datasets/organelle/SEC61B/2024_10_16_A549_SEC61_ZIKV_DENV/2024_10_16_A549_SEC61_ZIKV_DENV.zarr" + tracks_path: "/hpc/projects/intracellular_dashboard/organelle_dynamics/rerun/2024_10_16_A549_SEC61_ZIKV_DENV/1-preprocess/label-free/3-track/2024_10_16_A549_SEC61_ZIKV_DENV_cropped.zarr" + source_channel: + - &phase Phase3D + - &mcherry raw mCherry EX561 EM600-37 + z_range: [10, 26] + initial_yx_patch_size: [384, 384] + final_yx_patch_size: [192, 192] + batch_size: 64 + num_workers: 12 + time_interval: 1 + augment_validation: false + return_negative: false + fit_include_wells: 
["B/3", "B/4", "C/3", "C/4"] + augmentations: + - class_path: viscy.transforms.RandAffined + init_args: + keys: [*phase, *mcherry] + prob: 0.8 + scale_range: [0, 0.2, 0.2] + rotate_range: [3.14, 0.0, 0.0] + shear_range: [0.0, 0.01, 0.01] + padding_mode: zeros + - class_path: viscy.transforms.RandAdjustContrastd + init_args: + keys: [*mcherry] + prob: 0.5 + gamma: [0.8, 1.2] + - class_path: viscy.transforms.RandAdjustContrastd + init_args: + keys: [*phase] + prob: 0.5 + gamma: [0.8, 1.2] + - class_path: viscy.transforms.RandScaleIntensityd + init_args: + keys: [*mcherry] + prob: 0.5 + factors: 0.5 + - class_path: viscy.transforms.RandScaleIntensityd + init_args: + keys: [*phase] + prob: 0.5 + factors: 0.5 + - class_path: viscy.transforms.RandGaussianSmoothd + init_args: + keys: [*phase, *mcherry] + prob: 0.5 + sigma_x: [0.25, 0.75] + sigma_y: [0.25, 0.75] + sigma_z: [0.0, 0.0] + - class_path: viscy.transforms.RandGaussianNoised + init_args: + keys: [*mcherry] + prob: 0.5 + mean: 0.0 + std: 0.2 + - class_path: viscy.transforms.RandGaussianNoised + init_args: + keys: [*phase] + prob: 0.5 + mean: 0.0 + std: 0.2 + normalizations: + - class_path: viscy.transforms.NormalizeSampled + init_args: + keys: [*phase] + level: fov_statistics + subtrahend: mean + divisor: std + - class_path: viscy.transforms.ScaleIntensityRangePercentilesd + init_args: + keys: [*mcherry] + lower: 50 + upper: 99 + b_min: 0.0 + b_max: 1.0 \ No newline at end of file diff --git a/viscy/representation/engine.py b/viscy/representation/engine.py index 6ea532a5a..a47940ec2 100644 --- a/viscy/representation/engine.py +++ b/viscy/representation/engine.py @@ -1,5 +1,5 @@ import logging -from typing import Literal, Optional, Sequence, TypedDict +from typing import Literal, Sequence, TypedDict import numpy as np import torch @@ -12,12 +12,16 @@ from viscy.data.typing import TrackingIndex, TripletSample from viscy.representation.contrastive import ContrastiveEncoder from viscy.representation.disentanglement_metrics import DisentanglementMetrics -from viscy.representation.vae import VaeDecoder, VaeEncoder +from viscy.representation.vae import BetaVae25D, BetaVaeMonai from viscy.representation.vae_logging import BetaVaeLogger from viscy.utils.log_images import detach_sample, render_images _logger = logging.getLogger("lightning.pytorch") +_VAE_ARCHITECTURE = { + "2.5D": BetaVae25D, + "monai_beta": BetaVaeMonai, +} class ContrastivePrediction(TypedDict): features: Tensor @@ -249,20 +253,18 @@ def predict_step( "index": batch["index"], } - -class VaeModule(LightningModule): - """Native PyTorch Lightning Beta-VAE implementation.""" - +class BetaVaeModule(LightningModule): def __init__( self, - encoder: VaeEncoder, - decoder: VaeDecoder, - loss_function: nn.Module | nn.MSELoss = nn.MSELoss(reduction="sum"), + architecture: Literal["monai_beta","2.5D"], + model_config: dict = {}, + loss_function: nn.Module | nn.MSELoss = nn.MSELoss(reduction="mean"), beta: float = 1.0, beta_schedule: Literal["linear", "cosine", "warmup"] | None = None, beta_min: float = 0.1, beta_warmup_epochs: int = 50, - lr: float = 1e-3, + lr: float = 1e-5, + lr_schedule: Literal["WarmupCosine", "Constant"] = "Constant", log_batches_per_epoch: int = 8, log_samples_per_batch: int = 1, example_input_array_shape: Sequence[int] = (1, 2, 30, 256, 256), @@ -273,35 +275,50 @@ def __init__( ): super().__init__() - self.encoder = encoder - self.decoder = decoder - - # Infer latent dimension from encoder and validate decoder matches - self.latent_dim = encoder.latent_dim - # 
Validate that decoder's latent_dim matches encoder's embedding_dim - if hasattr(decoder, "latent_dim") and decoder.latent_dim != self.latent_dim: + net_class= _VAE_ARCHITECTURE.get(architecture) + if not net_class: raise ValueError( - f"Encoder embedding_dim ({self.latent_dim}) must match " - f"decoder latent_dim ({decoder.latent_dim})" + f"Architecture {architecture} not in {_VAE_ARCHITECTURE.keys()}" ) + + self.model = net_class(**model_config) + self.model_config = model_config + self.loss_function = loss_function + self.beta = beta self.beta_schedule = beta_schedule self.beta_min = beta_min self.beta_warmup_epochs = beta_warmup_epochs + self.lr = lr + self.lr_schedule = lr_schedule + self.log_batches_per_epoch = log_batches_per_epoch self.log_samples_per_batch = log_samples_per_batch - self.loss_function = loss_function + self.example_input_array = torch.rand(*example_input_array_shape) self.compute_disentanglement = compute_disentanglement self.disentanglement_frequency = disentanglement_frequency + self.log_enhanced_visualizations = log_enhanced_visualizations self.log_enhanced_visualizations_frequency = ( log_enhanced_visualizations_frequency ) self.training_step_outputs = [] self.validation_step_outputs = [] - self.vae_logger = BetaVaeLogger(latent_dim=self.latent_dim) + + # Handle different parameter names for latent dimensions + latent_dim = None + if "latent_dim" in self.model_config: + latent_dim = self.model_config["latent_dim"] + elif "latent_size" in self.model_config: + latent_dim = self.model_config["latent_size"] + + if latent_dim is not None: + self.vae_logger = BetaVaeLogger(latent_dim=latent_dim) + else: + _logger.warning("No latent dimension provided for BetaVaeLogger. Using default with 128 dimensions.") + self.vae_logger = BetaVaeLogger() def setup(self, stage: str = None): """Setup hook to initialize device-dependent components.""" @@ -348,23 +365,20 @@ def _get_current_beta(self) -> float: def forward(self, x: Tensor) -> dict: """Forward pass through Beta-VAE.""" - # Encode - encoder_output = self.encoder(x) - mu = encoder_output.mean - logvar = encoder_output.log_covariance - z = encoder_output.z + # Handle different model output formats + model_output = self.model(x) + + recon_x = model_output.recon_x + mu = model_output.mean + logvar = model_output.logvar + z = model_output.z - # Decode - reconstruction = self.decoder(z) - # Compute losses with current beta (normalized by batch size) current_beta = self._get_current_beta() batch_size = x.size(0) - # MSE loss normalized by batch size - recon_loss = self.loss_function(reconstruction, x) - - # KL loss normalized by batch size + # NOTE: normalizing by the batch size + recon_loss = self.loss_function(recon_x, x) kl_loss = ( -0.5 * current_beta @@ -375,7 +389,7 @@ def forward(self, x: Tensor) -> dict: total_loss = recon_loss + kl_loss return { - "recon_x": reconstruction, + "recon_x": recon_x, "z": z, "mu": mu, "logvar": logvar, diff --git a/viscy/representation/vae.py b/viscy/representation/vae.py index 4e248b7a5..d92ab3fad 100644 --- a/viscy/representation/vae.py +++ b/viscy/representation/vae.py @@ -1,3 +1,4 @@ +from collections.abc import Sequence from types import SimpleNamespace from typing import Callable, Literal @@ -5,6 +6,7 @@ import torch from monai.networks.blocks import ResidualUnit, UpSample from monai.networks.blocks.dynunet_block import get_conv_layer +from monai.networks.nets import VarAutoEncoder from torch import Tensor, nn from viscy.unet.networks.unext2 import ( @@ -97,21 +99,17 @@ def 
forward(self, inp: Tensor) -> Tensor: class VaeEncoder(nn.Module): """VAE encoder for microscopy data with 3D to 2D conversion.""" - # TODO: roll back the Conv2d to AveragePooling and linear layer to global pooling - # TODO: embedding dim - # TODO: check the OG VAE compression rate - # TODO do log grid search for the best embedding dim - def __init__( self, - backbone: str = "resnet50", # [64, 256, 512, 1024, 2048] channels + backbone: Literal["resnet50", "convnext_tiny"] = "resnet50", in_channels: int = 2, in_stack_depth: int = 16, latent_dim: int = 1024, input_spatial_size: tuple[int, int] = (256, 256), - stem_kernel_size: tuple[int, int, int] = (4, 5, 5), - stem_stride: tuple[int, int, int] = (4, 5, 5), # same as kernel size + stem_kernel_size: tuple[int, int, int] = (2, 4, 4), + stem_stride: tuple[int, int, int] = (2, 4, 4), drop_path_rate: float = 0.0, + pretrained: bool = True, ): super().__init__() self.backbone = backbone @@ -119,18 +117,23 @@ def __init__( encoder = timm.create_model( backbone, - pretrained=False, + pretrained=pretrained, features_only=True, drop_path_rate=drop_path_rate, ) - if "resnet" in backbone: - in_channels_encoder = encoder.conv1.out_channels - # remove the original 3D stem for rgb imges to support the multichannel 3D input + if "convnext" in backbone: + num_channels = encoder.feature_info.channels() + in_channels_encoder = num_channels[0] + encoder.stem_0 = nn.Identity() + out_channels_encoder = num_channels[-1] + elif "resnet" in backbone: + num_channels = encoder.feature_info.channels() + in_channels_encoder = num_channels[0] encoder.conv1 = nn.Identity() - out_channels_encoder = encoder.feature_info.channels()[-1] + out_channels_encoder = num_channels[-1] else: - raise ValueError(f"Backbone {backbone} not supported") + raise ValueError(f"Backbone {backbone} not supported. 
Use 'resnet50', 'convnext_tiny', or 'convnextv2_tiny'") # Stem for 3d multichannel and to convert 3D to 2D self.stem = StemDepthtoChannels( @@ -141,35 +144,27 @@ def __init__( stem_stride=stem_stride, ) self.encoder = encoder - - # Calculate spatial dimensions after encoder and initialize linear layers + self.num_channels = num_channels + self.in_channels_encoder = in_channels_encoder self.out_channels_encoder = out_channels_encoder + + # Calculate spatial size after stem + stem_spatial_size_h = input_spatial_size[0] // stem_stride[1] + stem_spatial_size_w = input_spatial_size[1] // stem_stride[2] + + # Spatial size after backbone + backbone_reduction = 2 ** (len(num_channels) - 1) + final_spatial_size_h = stem_spatial_size_h // backbone_reduction + final_spatial_size_w = stem_spatial_size_w // backbone_reduction + + flattened_size = out_channels_encoder * final_spatial_size_h * final_spatial_size_w - if "resnet50" in backbone: - # Calculate spatial size after stem, then ResNet50 downsampling - stem_spatial_h = ( - input_spatial_size[0] - stem_kernel_size[1] - ) // stem_stride[1] + 1 - stem_spatial_w = ( - input_spatial_size[1] - stem_kernel_size[2] - ) // stem_stride[2] + 1 - - # ResNet50 downsamples by 32x total, but stem already downsampled - total_downsample_factor = 32 - stem_downsample_factor = stem_stride[1] # Spatial downsampling from stem - resnet_downsample_factor = total_downsample_factor // stem_downsample_factor - final_h = stem_spatial_h // resnet_downsample_factor - final_w = stem_spatial_w // resnet_downsample_factor - flattened_size = out_channels_encoder * final_h * final_w - else: - raise ValueError( - f"Backbone {backbone} not supported for analytical calculation" - ) - - # Multi-layer perceptron for better representation learning self.fc = nn.Linear(flattened_size, latent_dim) self.fc_mu = nn.Linear(latent_dim, latent_dim) self.fc_logvar = nn.Linear(latent_dim, latent_dim) + + # Store final spatial size for decoder (assuming square for simplicity) + self.encoder_spatial_size = final_spatial_size_h # Assuming square output def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor: """Reparameterization trick: sample from N(mu, var) using N(0,1).""" @@ -183,17 +178,19 @@ def forward(self, x: Tensor) -> SimpleNamespace: features = self.encoder(x) - # Take highest resolution features and flatten - x = features[-1] # [B, C, H, W] + # NOTE: taking the highest resolution features and flatten + # When features_only=False, encoder returns single tensor, not list + if isinstance(features, list): + x = features[-1] # [B, C, H, W] + else: + x = features # [B, C, H, W] x_flat = x.flatten(1) # [B, C*H*W] - flatten from dim 1 onwards - # Apply intermediate FC layer - x_intermediate = self.fc(x_flat) # [B, intermediate_dim] + x_intermediate = self.fc(x_flat) - # Apply linear layers to get 1D embeddings - mu = self.fc_mu(x_intermediate) # [B, latent_dim] - logvar = self.fc_logvar(x_intermediate) # [B, latent_dim] - z = self.reparameterize(mu, logvar) # [B, latent_dim] + mu = self.fc_mu(x_intermediate) + logvar = self.fc_logvar(x_intermediate) + z = self.reparameterize(mu, logvar) return SimpleNamespace(mean=mu, log_covariance=logvar, z=z) @@ -208,68 +205,40 @@ def __init__( out_channels: int = 2, out_stack_depth: int = 16, head_expansion_ratio: int = 2, + strides: list[int] = [2, 2, 2, 1], + encoder_spatial_size: int=16, head_pool: bool = False, upsample_mode: Literal["deconv", "pixelshuffle"] = "pixelshuffle", conv_blocks: int = 2, norm_name: str = "batch", upsample_pre_conv: 
Literal["default"] | Callable | None = None, - strides: list[int] | None = None, - input_spatial_size: tuple[int, int] = ( - 128, - 128, - ), # Input size to calculate spatial dimensions ): super().__init__() + self.decoder_channels = decoder_channels + self.latent_dim = latent_dim self.out_channels = out_channels self.out_stack_depth = out_stack_depth + self.decoder_channels = decoder_channels - head_channels = ( - (out_stack_depth + 2) * out_channels * 2**2 * head_expansion_ratio - ) - decoder_channels_with_head = decoder_channels.copy() + [head_channels] - - num_stages = len(decoder_channels_with_head) - 1 - if strides is None: - if ( - num_stages == 4 - ): # Default [1024, 512, 256, 128] + head = 5 channels, 4 stages - strides = [ - 2, - 2, - 2, - 1, - ] # Reduce to account for PixelToVoxelHead's 4x upsampling - else: - strides = [2] * num_stages # Fallback to uniform 2x upsampling - elif len(strides) != num_stages: - raise ValueError( - f"Length of strides ({len(strides)}) must match number of stages ({num_stages})" - ) - # Calculate spatial size based on input dimensions and ResNet50 32x downsampling - self.spatial_size = input_spatial_size[0] // 32 # ResNet50 downsamples by 32x + self.spatial_size = encoder_spatial_size self.spatial_channels = latent_dim // (self.spatial_size * self.spatial_size) - # Project 1D latent to spatial format, then to first decoder channels self.latent_reshape = nn.Linear( latent_dim, self.spatial_channels * self.spatial_size * self.spatial_size ) self.latent_proj = nn.Conv2d( - self.spatial_channels, decoder_channels_with_head[0], kernel_size=1 + self.spatial_channels, decoder_channels[0], kernel_size=1 ) # Build the decoder stages self.decoder_stages = nn.ModuleList() - + num_stages = len(self.decoder_channels) - 1 for i in range(num_stages): - in_channels = decoder_channels_with_head[i] - out_channels_stage = decoder_channels_with_head[i + 1] - stride = strides[i] - stage = VaeUpStage( - in_channels=in_channels, - out_channels=out_channels_stage, - scale_factor=stride, + in_channels=self.decoder_channels[i], + out_channels=self.decoder_channels[i + 1], + scale_factor=strides[i], mode=upsample_mode, conv_blocks=conv_blocks, norm_name=norm_name, @@ -279,7 +248,7 @@ def __init__( # Head to convert back to 3D self.head = PixelToVoxelHead( - in_channels=head_channels, + in_channels=decoder_channels[-1], out_channels=self.out_channels, out_stack_depth=self.out_stack_depth, expansion_ratio=head_expansion_ratio, @@ -305,7 +274,126 @@ def forward(self, z: Tensor) -> Tensor: for stage in self.decoder_stages: x = stage(x) - # Last stage outputs head_channels directly - no final_conv needed output = self.head(x) return output + + +class BetaVae25D(nn.Module): + """2.5D Beta-VAE combining VaeEncoder and VaeDecoder.""" + + def __init__( + self, + backbone: Literal["resnet50", "convnext_tiny"] = "resnet50", + in_channels: int = 2, + in_stack_depth: int = 16, + out_stack_depth: int = 16, + latent_dim: int = 1024, + input_spatial_size: tuple[int, int] = (256, 256), + stem_kernel_size: tuple[int, int, int] = (2, 4, 4), + stem_stride: tuple[int, int, int] = (2, 4, 4), + drop_path_rate: float = 0.0, + decoder_stages: int = 4, + head_expansion_ratio: int = 2, + head_pool: bool = False, + upsample_mode: Literal["deconv", "pixelshuffle"] = "pixelshuffle", + conv_blocks: int = 2, + norm_name: str = "batch", + upsample_pre_conv: Literal["default"] | Callable | None = None, + ): + super().__init__() + + self.encoder = VaeEncoder( + backbone=backbone, + in_channels=in_channels, + 
in_stack_depth=in_stack_depth, + latent_dim=latent_dim, + input_spatial_size=input_spatial_size, + stem_kernel_size=stem_kernel_size, + stem_stride=stem_stride, + drop_path_rate=drop_path_rate, + ) + + decoder_channels = self.encoder.num_channels.copy() + decoder_channels.reverse() + decoder_channels[-1] = ( + (out_stack_depth + 2) * in_channels * 2**2 * head_expansion_ratio + ) + + strides = [2] * (len(decoder_channels) - 1) + [1] + + self.decoder = VaeDecoder( + decoder_channels=decoder_channels, + latent_dim=latent_dim, + out_channels=in_channels, + out_stack_depth=out_stack_depth, + head_expansion_ratio=head_expansion_ratio, + head_pool=head_pool, + upsample_mode=upsample_mode, + conv_blocks=conv_blocks, + norm_name=norm_name, + upsample_pre_conv=upsample_pre_conv, + strides=strides, + encoder_spatial_size=self.encoder.encoder_spatial_size, + ) + + def forward(self, x: Tensor) -> SimpleNamespace: + """Forward pass returning VAE outputs.""" + encoder_output = self.encoder(x) + recon_x = self.decoder(encoder_output.z) + + return SimpleNamespace( + recon_x=recon_x, + mean=encoder_output.mean, + logvar=encoder_output.log_covariance, + z=encoder_output.z + ) + + +class BetaVaeMonai(nn.Module): + """Beta-VAE with Monai architecture.""" + + def __init__(self, + spatial_dims: int, + in_shape: Sequence[int], + out_channels: int, + latent_size: int, + channels: Sequence[int], + strides: Sequence[int], + kernel_size: Sequence[int] | int = 3, + up_kernel_size: Sequence[int] | int = 3, + num_res_units: int = 0, + use_sigmoid: bool = False, + **kwargs + ): + super().__init__() + + self.spatial_dims = spatial_dims + self.in_shape = in_shape + self.out_channels = out_channels + self.latent_size = latent_size + self.channels = channels + self.strides = strides + self.kernel_size = kernel_size + self.up_kernel_size = up_kernel_size + self.num_res_units = num_res_units + self.use_sigmoid = use_sigmoid + + self.model = VarAutoEncoder( + spatial_dims=self.spatial_dims, + in_shape=self.in_shape, + out_channels=self.out_channels, + latent_size=self.latent_size, + channels=self.channels, + strides=self.strides, + kernel_size=self.kernel_size, + up_kernel_size=self.up_kernel_size, + num_res_units=self.num_res_units, + use_sigmoid=self.use_sigmoid, + **kwargs + ) + + def forward(self, x: Tensor) -> SimpleNamespace: + """Forward pass returning VAE encoder outputs.""" + recon_x, mu, logvar, z = self.model(x) + return SimpleNamespace(recon_x=recon_x, mean=mu, logvar=logvar, z=z) \ No newline at end of file From e0bf81351de20de3a34a2a331f7bd0bdee15b60c Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Sun, 3 Aug 2025 15:05:12 -0700 Subject: [PATCH 023/101] add saving hyperparameters --- viscy/representation/engine.py | 120 +++++++++++++++++++++++++++++---- 1 file changed, 108 insertions(+), 12 deletions(-) diff --git a/viscy/representation/engine.py b/viscy/representation/engine.py index a47940ec2..f082591ce 100644 --- a/viscy/representation/engine.py +++ b/viscy/representation/engine.py @@ -43,6 +43,7 @@ def __init__( log_batches_per_epoch: int = 8, log_samples_per_batch: int = 1, log_embeddings: bool = False, + embedding_log_frequency: int = 10, example_input_array_shape: Sequence[int] = (1, 2, 15, 256, 256), ) -> None: super().__init__() @@ -56,6 +57,50 @@ def __init__( self.training_step_outputs = [] self.validation_step_outputs = [] self.log_embeddings = log_embeddings + self.embedding_log_frequency = embedding_log_frequency + + def on_train_start(self) -> None: + """Log comprehensive 
hyperparameters including model architecture details.""" + super().on_train_start() + + # Collect comprehensive hyperparameters + hparams = { + # Training hyperparameters + "lr": self.lr, + "schedule": self.schedule, + "input_shape": self.example_input_array_shape, + "loss_function_class": self.loss_function.__class__.__name__, + } + + # Add loss function specific parameters + if hasattr(self.loss_function, 'margin'): + hparams["loss_margin"] = self.loss_function.margin + if hasattr(self.loss_function, 'temperature'): + hparams["loss_temperature"] = self.loss_function.temperature + if hasattr(self.loss_function, 'normalize_embeddings'): + hparams["loss_normalize_embeddings"] = self.loss_function.normalize_embeddings + + # Add encoder details if it's a ContrastiveEncoder + if hasattr(self.model, 'backbone'): + hparams["encoder_backbone"] = self.model.backbone + if hasattr(self.model, 'in_channels'): + hparams["encoder_in_channels"] = self.model.in_channels + if hasattr(self.model, 'in_stack_depth'): + hparams["encoder_in_stack_depth"] = self.model.in_stack_depth + if hasattr(self.model, 'embedding_dim'): + hparams["encoder_embedding_dim"] = self.model.embedding_dim + if hasattr(self.model, 'projection_dim'): + hparams["encoder_projection_dim"] = self.model.projection_dim + if hasattr(self.model, 'drop_path_rate'): + hparams["encoder_drop_path_rate"] = self.model.drop_path_rate + if hasattr(self.model, 'stem_kernel_size'): + hparams["encoder_stem_kernel_size"] = str(self.model.stem_kernel_size) + if hasattr(self.model, 'stem_stride'): + hparams["encoder_stem_stride"] = str(self.model.stem_stride) + + # Log to TensorBoard + if self.logger is not None: + self.logger.log_hyperparams(hparams) def forward(self, x: Tensor) -> tuple[Tensor, Tensor]: """Return both features and projections. 
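# Sketch (not part of the patch): the hasattr scan above is equivalent to iterating
# a tuple of the encoder attribute names it probes, stringifying only the tuple-valued
# ones, assuming the same ContrastiveEncoder attributes:
# encoder_attrs = (
#     "backbone", "in_channels", "in_stack_depth", "embedding_dim",
#     "projection_dim", "drop_path_rate", "stem_kernel_size", "stem_stride",
# )
# for name in encoder_attrs:
#     if hasattr(self.model, name):
#         value = getattr(self.model, name)
#         hparams[f"encoder_{name}"] = str(value) if isinstance(value, (tuple, list)) else value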
@@ -186,12 +231,6 @@ def training_step(self, batch: TripletSample, batch_idx: int) -> Tensor: def on_train_epoch_end(self) -> None: super().on_train_epoch_end() self._log_samples("train_samples", self.training_step_outputs) - # Log UMAP embeddings for validation - if self.log_embeddings: - embeddings = torch.cat( - [output["embeddings"] for output in self.validation_step_outputs] - ) - self.log_embedding_umap(embeddings, tag="train") self.training_step_outputs = [] def validation_step(self, batch: TripletSample, batch_idx: int) -> Tensor: @@ -229,15 +268,72 @@ def validation_step(self, batch: TripletSample, batch_idx: int) -> Tensor: def on_validation_epoch_end(self) -> None: super().on_validation_epoch_end() self._log_samples("val_samples", self.validation_step_outputs) - # Log UMAP embeddings for training - if self.log_embeddings: - embeddings = torch.cat( - [output["embeddings"] for output in self.training_step_outputs] - ) - self.log_embedding_umap(embeddings, tag="val") + + # Log UMAP embeddings from validation set every N epochs + if ( + self.log_embeddings + and self.current_epoch % self.embedding_log_frequency == 0 + and self.current_epoch > 0 + ): + self._collect_and_log_embeddings() self.validation_step_outputs = [] + def _collect_and_log_embeddings(self): + """Collect embeddings from validation dataloader and log UMAP visualization.""" + try: + # Get validation dataloader + val_dataloaders = self.trainer.val_dataloaders + if val_dataloaders is None: + _logger.warning("No validation dataloader available for embedding logging") + return + elif isinstance(val_dataloaders, list): + val_dataloader = val_dataloaders[0] if val_dataloaders else None + else: + val_dataloader = val_dataloaders + + if val_dataloader is None: + _logger.warning("No validation dataloader available for embedding logging") + return + + _logger.info(f"Collecting embeddings for visualization at epoch {self.current_epoch}") + + # Collect embeddings from validation set + embeddings_list = [] + max_samples = 1000 # Limit samples for performance + sample_count = 0 + + self.eval() + with torch.no_grad(): + for batch in val_dataloader: + if sample_count >= max_samples: + break + + # Move batch to device + anchor = batch["anchor"].to(self.device) + + # Get embeddings (features, not projections) + features, _ = self(anchor) + embeddings_list.append(features.cpu()) + + sample_count += features.size(0) + + if embeddings_list: + embeddings = torch.cat(embeddings_list, dim=0)[:max_samples] + self.log_embedding_umap(embeddings, tag="validation") + + # Also log to TensorBoard's embedding projector + self.logger.experiment.add_embedding( + embeddings, + global_step=self.current_epoch, + tag="validation_embeddings" + ) + else: + _logger.warning("No embeddings collected from validation set") + + except Exception as e: + _logger.error(f"Error collecting embeddings: {e}") + def configure_optimizers(self): optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr) return optimizer From a1ad2dc241409827d84ccd908b6bfa628182d2fd Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Sun, 3 Aug 2025 20:17:15 -0700 Subject: [PATCH 024/101] fix hyperparameter logging --- viscy/representation/engine.py | 48 ++++++++++++++++++++++++++++++---- 1 file changed, 43 insertions(+), 5 deletions(-) diff --git a/viscy/representation/engine.py b/viscy/representation/engine.py index f082591ce..f1b944fdd 100644 --- a/viscy/representation/engine.py +++ b/viscy/representation/engine.py @@ -68,7 +68,7 @@ def on_train_start(self) -> None: # 
Training hyperparameters "lr": self.lr, "schedule": self.schedule, - "input_shape": self.example_input_array_shape, + "input_shape": self.example_input_array, "loss_function_class": self.loss_function.__class__.__name__, } @@ -298,9 +298,11 @@ def _collect_and_log_embeddings(self): _logger.info(f"Collecting embeddings for visualization at epoch {self.current_epoch}") - # Collect embeddings from validation set + # Collect embeddings, images, and metadata from validation set embeddings_list = [] - max_samples = 1000 # Limit samples for performance + images_list = [] + labels_list = [] + max_samples = 500 # Reduced for memory efficiency with images sample_count = 0 self.eval() @@ -311,23 +313,59 @@ def _collect_and_log_embeddings(self): # Move batch to device anchor = batch["anchor"].to(self.device) + batch_size = anchor.size(0) # Get embeddings (features, not projections) features, _ = self(anchor) embeddings_list.append(features.cpu()) - sample_count += features.size(0) + # Collect images for sprite visualization + # Take middle slice for 3D data and first channel if multi-channel + if anchor.ndim == 5: # (B, C, D, H, W) + mid_z = anchor.size(2) // 2 + img_slice = anchor[:, 0, mid_z].cpu() # (B, H, W) + else: # (B, C, H, W) + img_slice = anchor[:, 0].cpu() # (B, H, W) + images_list.append(img_slice) + + # Collect labels from index information + if "index" in batch and batch["index"] is not None: + for i, idx_info in enumerate(batch["index"][:batch_size]): + if isinstance(idx_info, dict): + # Create label from track_id and time info + track_id = idx_info.get("track_id", "unknown") + t = idx_info.get("t", "unknown") + labels_list.append(f"track_{track_id}_t_{t}") + else: + labels_list.append(f"sample_{sample_count + i}") + else: + # Fallback labels + for i in range(batch_size): + labels_list.append(f"sample_{sample_count + i}") + + sample_count += batch_size if embeddings_list: embeddings = torch.cat(embeddings_list, dim=0)[:max_samples] + images = torch.cat(images_list, dim=0)[:max_samples] + labels = labels_list[:max_samples] + + # Normalize images for visualization (0-1 range) + images = (images - images.min()) / (images.max() - images.min() + 1e-8) + + # Log UMAP visualization self.log_embedding_umap(embeddings, tag="validation") - # Also log to TensorBoard's embedding projector + # Log to TensorBoard's embedding projector with images and labels self.logger.experiment.add_embedding( embeddings, + metadata=labels, + label_img=images.unsqueeze(1), # Add channel dimension global_step=self.current_epoch, tag="validation_embeddings" ) + + _logger.info(f"Logged {len(embeddings)} embeddings with images and labels") else: _logger.warning("No embeddings collected from validation set") From 17c1e89afead82deb9bddca87148e8e4a2dacb9f Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Sun, 3 Aug 2025 20:18:06 -0700 Subject: [PATCH 025/101] add embedding logging to the CLIP version --- viscy/representation/multi_modal.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/viscy/representation/multi_modal.py b/viscy/representation/multi_modal.py index 55481d434..ad4f48717 100644 --- a/viscy/representation/multi_modal.py +++ b/viscy/representation/multi_modal.py @@ -50,6 +50,7 @@ def __init__( log_batches_per_epoch: int = 8, log_samples_per_batch: int = 1, log_embeddings: bool = False, + embedding_log_frequency: int = 10, example_input_array_shape: Sequence[int] = (1, 2, 15, 256, 256), prediction_arm: Literal["source", "target"] = "source", ) -> None: @@ -61,6 +62,7 @@ def __init__( 
log_batches_per_epoch=log_batches_per_epoch, log_samples_per_batch=log_samples_per_batch, log_embeddings=log_embeddings, + embedding_log_frequency=embedding_log_frequency, example_input_array_shape=example_input_array_shape, ) self.example_input_array = (self.example_input_array, self.example_input_array) From 4269787a4b6da9a3105d3db1326d374a45b128d0 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Wed, 6 Aug 2025 16:04:00 -0700 Subject: [PATCH 026/101] test and plot of monaivae --- .../DynaCLR/MonaiVAE/test_vae_magnitudes.py | 235 ++++++++++++++++++ 1 file changed, 235 insertions(+) create mode 100644 applications/benchmarking/DynaCLR/MonaiVAE/test_vae_magnitudes.py diff --git a/applications/benchmarking/DynaCLR/MonaiVAE/test_vae_magnitudes.py b/applications/benchmarking/DynaCLR/MonaiVAE/test_vae_magnitudes.py new file mode 100644 index 000000000..d0f72e407 --- /dev/null +++ b/applications/benchmarking/DynaCLR/MonaiVAE/test_vae_magnitudes.py @@ -0,0 +1,235 @@ +#%% +from pathlib import Path + +import numpy as np +import torch +import torch.nn.functional as F +from monai.transforms import ( + NormalizeIntensity, +) +from torchview import draw_graph + +from viscy.representation.vae import BetaVae25D, BetaVaeMonai + + +def compute_vae_losses(model_output, target, beta=1.0): + """Compute VAE losses: reconstruction (MSE) and KL divergence. + """ + recon_loss = F.mse_loss(model_output.recon_x, target, reduction='mean') + + kl_loss = -0.5 * torch.sum(1 + model_output.logvar - model_output.mean.pow(2) - model_output.logvar.exp()) + kl_loss = kl_loss / model_output.mean.size(0) + + total_loss = recon_loss + beta * kl_loss + + return { + 'mu': model_output.mean, + 'logvar': model_output.logvar, + 'recon_loss': recon_loss.item(), + 'kl_loss': kl_loss.item(), + 'total_loss': total_loss.item(), + 'beta': beta, + # 'recon_magnitude': torch.abs(model_output.recon_x).mean().item(), + # 'target_magnitude': torch.abs(target).mean().item(), + # 'latent_mean_magnitude': torch.abs(model_output.mean).mean().item(), + # 'latent_std_magnitude': torch.exp(0.5 * model_output.logvar).mean().item(), + } + + +def create_synthetic_data(batch_size=2, channels=2, depth=16, height=256, width=256): + """Create synthetic microscopy-like data with known statistics. + These are from one FOV of the Phase3D + - mean: 8.196415001293644e-05 ≈ 0.0001 + - std: 0.09095408767461777 ≈ 0.091 + """ + torch.manual_seed(42) + synthetic_data = torch.randn(batch_size, channels, depth, height, width) * 0.091 + 0.0001 + + for b in range(batch_size): + for c in range(channels): + for d in range(depth): + # Add some blob-like structures + y_center, x_center = np.random.randint(50, height-50), np.random.randint(50, width-50) + y, x = np.ogrid[:height, :width] + mask = (y - y_center)**2 + (x - x_center)**2 < np.random.randint(400, 1600) + synthetic_data[b, c, d][mask] += np.random.normal(0.05, 0.02) + + synthetic_data = torch.clamp(synthetic_data, min=0) + + return synthetic_data + + +def create_known_target(input_data, noise_level=0.1): + """Create a target with known relationship to input for testing MSE magnitude. 
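+
+    With the construction below, target = clamp(0.95 * (input + eps) + 0.01, min=0),
+    where eps ~ N(0, (noise_level * std(input))**2), so MSE(input, target) has an
+    expected floor of roughly 0.05**2 * E[input**2] + (0.95 * noise_level * std(input))**2,
+    ignoring the small +0.01 offset and the clamp.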
+ """ + target = input_data.clone() + + noise = torch.randn_like(target) * noise_level * target.std() + target = target + noise + + target = target * 0.95 + 0.01 + + return torch.clamp(target, min=0) + + +def test_vae_magnitudes(): + """Test VAE models with both real dataloader and synthetic data.""" + print("=== VAE Magnitude Testing ===\n") + + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + print(f"Using device: {device}") + + model_configs = [ + # { + # 'name': 'BetaVae25D_ResNet50', + # 'model_class': BetaVae25D, + # 'kwargs': { + # 'backbone': 'resnet50', + # 'in_channels': 2, + # 'in_stack_depth': 16, + # 'latent_dim': 1024, + # 'input_spatial_size': (256, 256), + # } + # }, + # Uncomment to test MONAI version + { + 'name': 'BetaVaeMonai', + 'model_class': BetaVaeMonai, + 'kwargs': { + 'spatial_dims': 3, + 'in_shape': (2, 16, 256, 256), # (C, D, H, W) + 'out_channels': 2, + 'latent_size': 1024, + 'channels': (32, 64, 128, 256), + 'strides': (2, 2, 2, 2), + } + } + ] + + # Test different beta values + beta_values = [0.1, 1.0, 4.0, 10.0] + + for model_config in model_configs: + print(f"\n{'='*50}") + print(f"Testing {model_config['name']}") + print(f"{'='*50}") + + # Initialize model + model = model_config['model_class'](**model_config['kwargs']) + model = model.to(device) + model.eval() + + print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}") + + # Draw model graph + print(f"\n--- Model Architecture ---") + sample_input = create_synthetic_data(batch_size=1).to(device) + try: + model_graph = draw_graph( + model, + input_data=sample_input, + expand_nested=True, + depth=6, + save_graph=True, + filename=f'{model_config["name"]}_graph', + directory='./model_graphs/' + ) + print(f"Model graph saved to: ./model_graphs/{model_config['name']}_graph.png") + except Exception as e: + print(f"Could not generate model graph: {e}") + + # Test 1: Synthetic data with known target + print(f"\n--- Test 1: Synthetic Data ---") + synthetic_input = create_synthetic_data().to(device) + synthetic_target = create_known_target(synthetic_input).to(device) + + print(f"Input shape: {synthetic_input.shape}") + print(f"Input stats - mean: {synthetic_input.mean():.6f}, std: {synthetic_input.std():.6f}") + print(f"Target stats - mean: {synthetic_target.mean():.6f}, std: {synthetic_target.std():.6f}") + + with torch.no_grad(): + synthetic_output = model(synthetic_input) + + print(f"Output shape: {synthetic_output.recon_x.shape}") + print(f"Latent shape: {synthetic_output.z.shape}") + + for beta in beta_values: + losses = compute_vae_losses(synthetic_output, synthetic_target, beta) + print(f"\nBeta = {beta}:") + print(f" Mu shape: {losses['mu'].shape}, mean: {losses['mu'].mean():.6f}, std: {losses['mu'].std():.6f}") + print(f" Logvar shape: {losses['logvar'].shape}, mean: {losses['logvar'].mean():.6f}, std: {losses['logvar'].std():.6f}") + print(f" Reconstruction Loss: {losses['recon_loss']:.6f}") + print(f" KL Loss: {losses['kl_loss']:.6f}") + print(f" Total Loss: {losses['total_loss']:.6f}") + # print(f" Recon magnitude: {losses['recon_magnitude']:.6f}") + # print(f" Target magnitude: {losses['target_magnitude']:.6f}") + # print(f" Latent mean magnitude: {losses['latent_mean_magnitude']:.6f}") + # print(f" Latent std magnitude: {losses['latent_std_magnitude']:.6f}") + + #TODO: use the dataloader to run it with real data + # data_path = "/hpc/projects/organelle_phenotyping/datasets/organelle/SEC61B/2024_10_16_A549_SEC61_ZIKV_DENV" + # zarr_path = Path(data_path) / 
"2024_10_16_A549_SEC61_ZIKV_DENV_2.zarr" + zarr_path = None + if not zarr_path.exists() or zarr_path is None: + print(f"Found real data at: {zarr_path}") + + normalizations = [ + NormalizeIntensity() + ] + + print("Testing with real data format...") + + real_like_data = create_synthetic_data(batch_size=1, channels=2, depth=16, height=256, width=256) + + normalized_data = (real_like_data - real_like_data.mean()) / real_like_data.std() + normalized_data = normalized_data.to(device) + + print(f"Normalized data stats - mean: {normalized_data.mean():.6f}, std: {normalized_data.std():.6f}") + + with torch.no_grad(): + real_output = model(normalized_data) + + losses = compute_vae_losses(real_output, normalized_data, beta=1.0) + print(f"\nPerfect reconstruction test (beta=1.0):") + print(f" Reconstruction Loss: {losses['recon_loss']:.6f}") + print(f" KL Loss: {losses['kl_loss']:.6f}") + print(f" Total Loss: {losses['total_loss']:.6f}") + + else: + raise NotImplementedError("not implemented") + + +def print_expected_ranges(): + """Print expected ranges for VAE loss components.""" + print("\n" + "="*60) + print("EXPECTED LOSS MAGNITUDE RANGES") + print("="*60) + print(""" +For Beta-VAE with normalized input (0-mean, 1-std): + +NOTES + +1. RECONSTRUCTION LOSS (MSE): + - Well-trained model: 0.01 - 0.1 + - Untrained/poorly trained: 0.5 - 2.0 + - Perfect reconstruction: < 0.001 + +2. KL DIVERGENCE LOSS: + - Posterior collapse (BAD): < 10 (model ignores latent space) + - Well-regularized: depends on latent dim, but should allow reconstruction + - Over-regularized (BAD): Forces posterior too close to prior, hurts reconstruction + - Typical untrained: can be very high as posterior is random + +3. BETA PARAMETER EFFECTS: + - Beta < 1.0: Prioritizes reconstruction (lower MSE, higher KL) + - Beta = 1.0: Standard VAE balance + - Beta > 1.0: Prioritizes disentanglement (higher MSE, lower KL) + """ + ) + + +if __name__ == "__main__": + print_expected_ranges() + test_vae_magnitudes() + print("\n=== Testing Complete ===") +# %% From ed7f5a7d1cf13f74d5f2b4cff33ad6865ea13022 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Wed, 13 Aug 2025 16:26:22 -0700 Subject: [PATCH 027/101] handle monai_vae 2d --- viscy/representation/engine.py | 17 +++++++++++++++-- viscy/representation/vae.py | 8 +++++--- 2 files changed, 20 insertions(+), 5 deletions(-) diff --git a/viscy/representation/engine.py b/viscy/representation/engine.py index f1b944fdd..35b22447d 100644 --- a/viscy/representation/engine.py +++ b/viscy/representation/engine.py @@ -42,8 +42,8 @@ def __init__( schedule: Literal["WarmupCosine", "Constant"] = "Constant", log_batches_per_epoch: int = 8, log_samples_per_batch: int = 1, - log_embeddings: bool = False, - embedding_log_frequency: int = 10, + log_embeddings: bool = True, + embedding_log_frequency: int = 20, example_input_array_shape: Sequence[int] = (1, 2, 15, 256, 256), ) -> None: super().__init__() @@ -417,6 +417,7 @@ def __init__( self.model = net_class(**model_config) self.model_config = model_config + self.architecture = architecture self.loss_function = loss_function self.beta = beta @@ -499,6 +500,13 @@ def _get_current_beta(self) -> float: def forward(self, x: Tensor) -> dict: """Forward pass through Beta-VAE.""" + + original_shape = x.shape + is_monai_2d = (self.architecture == "monai_beta" and + self.model_config.get("spatial_dims") == 2) + if is_monai_2d and len(x.shape) == 5 and x.shape[2] == 1: + x = x.squeeze(2) + # Handle different model output formats model_output = self.model(x) @@ 
-506,6 +514,10 @@ def forward(self, x: Tensor) -> dict: mu = model_output.mean logvar = model_output.logvar z = model_output.z + + if is_monai_2d and len(original_shape) == 5 and original_shape[2] == 1: + # Convert back (B, C, H, W) to (B, C, 1, H, W) + recon_x = recon_x.unsqueeze(2) current_beta = self._get_current_beta() @@ -513,6 +525,7 @@ # NOTE: normalizing by the batch size recon_loss = self.loss_function(recon_x, x) + kl_loss = ( -0.5 * current_beta diff --git a/viscy/representation/vae.py index d92ab3fad..fcf74e22d 100644 --- a/viscy/representation/vae.py +++ b/viscy/representation/vae.py @@ -6,6 +6,7 @@ import torch from monai.networks.blocks import ResidualUnit, UpSample from monai.networks.blocks.dynunet_block import get_conv_layer +from monai.networks.layers.factories import Norm from monai.networks.nets import VarAutoEncoder from torch import Tensor, nn @@ -25,7 +26,7 @@ def __init__( scale_factor: int, mode: Literal["deconv", "pixelshuffle"], conv_blocks: int, - norm_name: str, + norm_name: Literal["batch", "instance"], upsample_pre_conv: Literal["default"] | Callable | None, ) -> None: super().__init__() @@ -178,7 +179,7 @@ def forward(self, x: Tensor) -> SimpleNamespace: features = self.encoder(x) - # NOTE: taking the highest resolution features and flatten + # NOTE: taking the highest semantic features and flatten # When features_only=False, encoder returns single tensor, not list if isinstance(features, list): x = features[-1] # [B, C, H, W] @@ -210,7 +211,7 @@ def __init__( head_pool: bool = False, upsample_mode: Literal["deconv", "pixelshuffle"] = "pixelshuffle", conv_blocks: int = 2, - norm_name: str = "batch", + norm_name: Literal["batch", "instance"] = "batch", upsample_pre_conv: Literal["default"] | Callable | None = None, ): super().__init__() @@ -364,6 +365,7 @@ def __init__(self, up_kernel_size: Sequence[int] | int = 3, num_res_units: int = 0, use_sigmoid: bool = False, + norm: str= Norm.BATCH, **kwargs ): super().__init__() From 5e27e7c820155d790d2f4938ca9df7f8a7243752 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Wed, 13 Aug 2025 16:26:49 -0700 Subject: [PATCH 028/101] redefining rotation augmentations --- viscy/transforms/__init__.py | 2 ++ viscy/transforms/_redef.py | 11 +++++++++++ 2 files changed, 13 insertions(+) diff --git a/viscy/transforms/__init__.py b/viscy/transforms/__init__.py index 495069184..2803df97f 100644 --- a/viscy/transforms/__init__.py +++ b/viscy/transforms/__init__.py @@ -6,6 +6,7 @@ RandFlipd, RandGaussianNoised, RandGaussianSmoothd, + RandRotate90d, RandScaleIntensityd, RandSpatialCropd, RandWeightedCropd, @@ -28,6 +29,7 @@ "RandFlipd", "RandGaussianNoised", "RandGaussianSmoothd", + "RandRotate90d", "RandInvertIntensityd", "RandScaleIntensityd", "RandSpatialCropd", diff --git a/viscy/transforms/_redef.py b/viscy/transforms/_redef.py index d79a4aff9..4d827f3dc 100644 --- a/viscy/transforms/_redef.py +++ b/viscy/transforms/_redef.py @@ -10,6 +10,7 @@ RandFlipd, RandGaussianNoised, RandGaussianSmoothd, + RandRotate90d, RandScaleIntensityd, RandSpatialCropd, RandWeightedCropd, @@ -182,3 +183,13 @@ def __init__( **kwargs, ): super().__init__(keys=keys, prob=prob, spatial_axis=spatial_axis, **kwargs) + +class RandRotate90d(RandRotate90d): + def __init__( + self, + keys: Sequence[str] | str, + prob: float, + spatial_axes: Sequence[int] | int, + **kwargs, + ): + super().__init__(keys=keys, prob=prob, spatial_axes=spatial_axes, **kwargs) \ No newline at end of
file From f12204086beea64c57490a9bbb325c3abf68e503 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Wed, 13 Aug 2025 16:27:15 -0700 Subject: [PATCH 029/101] adding optional scaling to phate --- .../evaluation/dimensionality_reduction.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/viscy/representation/evaluation/dimensionality_reduction.py b/viscy/representation/evaluation/dimensionality_reduction.py index eb5d43f91..2ba38a8ec 100644 --- a/viscy/representation/evaluation/dimensionality_reduction.py +++ b/viscy/representation/evaluation/dimensionality_reduction.py @@ -10,6 +10,7 @@ def compute_phate( embedding_dataset, + scale_embeddings: bool = False, n_components: int = 2, knn: int = 5, decay: int = 40, @@ -59,11 +60,18 @@ def compute_phate( else embedding_dataset ) + if scale_embeddings: + scaler = StandardScaler() + embeddings_scaled = scaler.fit_transform(embeddings) + else: + embeddings_scaled = embeddings + # Compute PHATE embeddings phate_model = phate.PHATE( - n_components=n_components, knn=knn, decay=decay, **phate_kwargs + n_components=n_components, knn=knn, decay=decay, random_state=42, **phate_kwargs ) - phate_embedding = phate_model.fit_transform(embeddings) + + phate_embedding = phate_model.fit_transform(embeddings_scaled) # Update dataset if requested if update_dataset and isinstance(embedding_dataset, Dataset): From 5a5f4a99a3fbcd8aeae314a4a3f3cefed9827aa8 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Thu, 14 Aug 2025 14:37:30 -0700 Subject: [PATCH 030/101] adding alias and output 2d --- viscy/data/cell_division_triplet.py | 34 ++++++++++++++++++++++++----- 1 file changed, 28 insertions(+), 6 deletions(-) diff --git a/viscy/data/cell_division_triplet.py b/viscy/data/cell_division_triplet.py index 4c1b1383b..c9259ea44 100644 --- a/viscy/data/cell_division_triplet.py +++ b/viscy/data/cell_division_triplet.py @@ -23,6 +23,16 @@ class CellDivisionTripletDataset(Dataset): + # Hardcoded channel mapping for .npy files + CHANNEL_MAPPING = { + # Channel 0 aliases (brightfield) + 'bf': 0, + 'brightfield': 0, + # Channel 1 aliases (h2b) + 'h2b': 1, + 'nuclei': 1, + } + def __init__( self, data_paths: list[Path], @@ -33,6 +43,7 @@ def __init__( fit: bool = True, time_interval: Literal["any"] | int = "any", return_negative: bool = True, + output_2d: bool = False, ) -> None: """Dataset for triplet sampling of cell division data from npy files. 
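# Shape sketch for the output_2d flag introduced above (illustrative, using torch only):
# import torch
# patch = torch.zeros(2, 64, 64)   # (C, Y, X) as stored per timepoint in the .npy track
# patch3d = patch.unsqueeze(1)     # (C, 1, Y, X) returned when output_2d=False
# # with output_2d=True the (C, Y, X) tensor is returned unchanged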
@@ -56,6 +67,8 @@ def __init__( by default "any" return_negative : bool, optional Whether to return the negative sample during the fit stage, by default True + output_2d : bool, optional + Whether to return 2D tensors (C,Y,X) instead of 3D (C,1,Y,X), by default False """ self.channel_names = channel_names self.anchor_transform = anchor_transform @@ -64,6 +77,7 @@ def __init__( self.fit = fit self.time_interval = time_interval self.return_negative = return_negative + self.output_2d = output_2d # Load and process all data files self.cell_tracks = self._load_data(data_paths) @@ -131,8 +145,9 @@ def _sample_positive(self, anchor_info: dict) -> Tensor: positive_t = anchor_t + self.time_interval positive_patch = track["data"][positive_t] # Shape: (C, Y, X) - # Add depth dimension: (C, Y, X) -> (C, D=1, Y, X) - positive_patch = positive_patch.unsqueeze(1) # Shape: (C, 1, Y, X) + # Add depth dimension only if not output_2d: (C, Y, X) -> (C, D=1, Y, X) + if not self.output_2d: + positive_patch = positive_patch.unsqueeze(1) # Shape: (C, 1, Y, X) return positive_patch def _sample_negative(self, anchor_info: dict) -> Tensor: @@ -173,8 +188,9 @@ def _sample_negative(self, anchor_info: dict) -> Tensor: negative_patch = neg_track["data"][neg_t] - # Add depth dimension: (C, Y, X) -> (C, D=1, Y, X) - negative_patch = negative_patch.unsqueeze(1) # Shape: (C, 1, Y, X) + # Add depth dimension only if not output_2d: (C, Y, X) -> (C, D=1, Y, X) + if not self.output_2d: + negative_patch = negative_patch.unsqueeze(1) # Shape: (C, 1, Y, X) return negative_patch def __getitem__(self, index: int) -> TripletSample: @@ -182,9 +198,10 @@ def __getitem__(self, index: int) -> TripletSample: track = anchor_info["track"] anchor_t = anchor_info["timepoint"] - # Get anchor patch and add depth dimension + # Get anchor patch and add depth dimension only if not output_2d anchor_patch = track["data"][anchor_t] # Shape: (C, Y, X) - anchor_patch = anchor_patch.unsqueeze(1) # Shape: (C, 1, Y, X) + if not self.output_2d: + anchor_patch = anchor_patch.unsqueeze(1) # Shape: (C, 1, Y, X) sample = {"anchor": anchor_patch} @@ -246,6 +263,7 @@ def __init__( augment_validation: bool = True, time_interval: Literal["any"] | int = "any", return_negative: bool = True, + output_2d: bool = False, persistent_workers: bool = False, prefetch_factor: int | None = None, pin_memory: bool = False, @@ -276,6 +294,8 @@ def __init__( Future time interval to sample positive and anchor from, by default "any" return_negative : bool, optional Whether to return the negative sample during the fit stage, by default True + output_2d : bool, optional + Whether to return 2D tensors (C,Y,X) instead of 3D (C,1,Y,X), by default False persistent_workers : bool, optional Whether to keep worker processes alive between iterations, by default False prefetch_factor : int | None, optional @@ -305,6 +325,7 @@ def __init__( self.data_path = Path(data_path) self.time_interval = time_interval self.return_negative = return_negative + self.output_2d = output_2d self.augment_validation = augment_validation # Find all npy files in the data directory @@ -319,6 +340,7 @@ def _base_dataset_settings(self) -> dict: return { "channel_names": self.source_channel, "time_interval": self.time_interval, + "output_2d": self.output_2d, } def _setup_fit(self, dataset_settings: dict): From 5597aece168b4ef3f2c25192c90d37e6e2add832 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Fri, 15 Aug 2025 14:40:19 -0700 Subject: [PATCH 031/101] normalizing by also the latent dim and swapping to 
FP32 for forward pass to avoid overflow with log and exp --- viscy/representation/engine.py | 23 ++++++++++++++--------- viscy/representation/vae_logging.py | 2 -- 2 files changed, 14 insertions(+), 11 deletions(-) diff --git a/viscy/representation/engine.py b/viscy/representation/engine.py index 35b22447d..b1b0351da 100644 --- a/viscy/representation/engine.py +++ b/viscy/representation/engine.py @@ -465,19 +465,20 @@ def setup(self, stage: str = None): def _get_current_beta(self) -> float: """Get current beta value based on scheduling.""" if self.beta_schedule is None: - return self.beta + return max(self.beta, 1e-6) epoch = self.current_epoch if self.beta_schedule == "linear": # Linear warmup from beta_min to beta if epoch < self.beta_warmup_epochs: - return ( + beta_val = ( self.beta_min + (self.beta - self.beta_min) * epoch / self.beta_warmup_epochs ) + return max(beta_val, 1e-6) else: - return self.beta + return max(self.beta, 1e-6) elif self.beta_schedule == "cosine": # Cosine warmup from beta_min to beta @@ -485,19 +486,22 @@ def _get_current_beta(self) -> float: import math progress = epoch / self.beta_warmup_epochs - return self.beta_min + (self.beta - self.beta_min) * 0.5 * ( + beta_val = self.beta_min + (self.beta - self.beta_min) * 0.5 * ( 1 + math.cos(math.pi * (1 - progress)) ) + return max(beta_val, 1e-6) else: - return self.beta + return max(self.beta, 1e-6) elif self.beta_schedule == "warmup": # Keep beta_min for warmup epochs, then jump to beta - return self.beta_min if epoch < self.beta_warmup_epochs else self.beta + beta_val = self.beta_min if epoch < self.beta_warmup_epochs else self.beta + return max(beta_val, 1e-6) else: - return self.beta + return max(self.beta, 1e-6) + @torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float32) def forward(self, x: Tensor) -> dict: """Forward pass through Beta-VAE.""" @@ -522,15 +526,16 @@ def forward(self, x: Tensor) -> dict: current_beta = self._get_current_beta() batch_size = x.size(0) + latent_dim = mu.size(1) + normalizer = batch_size * latent_dim - # NOTE: normalizing by the batch size recon_loss = self.loss_function(recon_x, x) kl_loss = ( -0.5 * current_beta * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) - / batch_size + / normalizer ) total_loss = recon_loss + kl_loss diff --git a/viscy/representation/vae_logging.py b/viscy/representation/vae_logging.py index b97042bfb..3477e53cf 100644 --- a/viscy/representation/vae_logging.py +++ b/viscy/representation/vae_logging.py @@ -81,7 +81,6 @@ def log_enhanced_metrics( f"loss/total/{stage}": total_loss, f"loss/reconstruction/{stage}": recon_loss, f"loss/kl/{stage}": kl_loss, - f"loss/weighted_kl/{stage}": beta * kl_loss, f"loss/mae/{stage}": mae_loss, f"beta/{stage}": beta, f"loss/kl_recon_ratio/{stage}": kl_recon_ratio, @@ -505,7 +504,6 @@ def log_disentanglement_metrics( vae_model=vae_model, dataloader=dataloader, max_samples=max_samples, - sync_dist=sync_dist, ) # Log metrics with organized naming From c782928046854edbf7223c85beed427699e93f01 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Wed, 20 Aug 2025 14:38:49 -0700 Subject: [PATCH 032/101] update test for magnitudes --- .../DynaCLR/MonaiVAE/test_vae_magnitudes.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/applications/benchmarking/DynaCLR/MonaiVAE/test_vae_magnitudes.py b/applications/benchmarking/DynaCLR/MonaiVAE/test_vae_magnitudes.py index d0f72e407..8a6a1612e 100644 --- a/applications/benchmarking/DynaCLR/MonaiVAE/test_vae_magnitudes.py +++ 
b/applications/benchmarking/DynaCLR/MonaiVAE/test_vae_magnitudes.py @@ -7,6 +7,7 @@ from monai.transforms import ( NormalizeIntensity, ) +from torch.nn import KLDivLoss, MSELoss from torchview import draw_graph from viscy.representation.vae import BetaVae25D, BetaVaeMonai @@ -15,10 +16,17 @@ def compute_vae_losses(model_output, target, beta=1.0): """Compute VAE losses: reconstruction (MSE) and KL divergence. """ - recon_loss = F.mse_loss(model_output.recon_x, target, reduction='mean') + mse_loss_fn = MSELoss(reduction='mean') + recon_loss = mse_loss_fn(model_output.recon_x, target) + + # Standard VAE: per-sample, per-dimension KL loss normalization + batch_size = target.size(0) + latent_dim = model_output.mean.size(1) # Get latent dimension + normalizer = batch_size * latent_dim # Normalize by both batch size and latent dim kl_loss = -0.5 * torch.sum(1 + model_output.logvar - model_output.mean.pow(2) - model_output.logvar.exp()) - kl_loss = kl_loss / model_output.mean.size(0) + print(f" Debug - KL raw: {kl_loss.item():.6f}, normalizer: {normalizer}, batch_size: {target.size(0)}") + kl_loss = kl_loss / normalizer total_loss = recon_loss + beta * kl_loss @@ -154,7 +162,7 @@ def test_vae_magnitudes(): print(f"Latent shape: {synthetic_output.z.shape}") for beta in beta_values: - losses = compute_vae_losses(synthetic_output, synthetic_target, beta) + losses = compute_vae_losses(model_output=synthetic_output, target=synthetic_target, beta=beta) print(f"\nBeta = {beta}:") print(f" Mu shape: {losses['mu'].shape}, mean: {losses['mu'].mean():.6f}, std: {losses['mu'].std():.6f}") print(f" Logvar shape: {losses['logvar'].shape}, mean: {losses['logvar'].mean():.6f}, std: {losses['logvar'].std():.6f}") @@ -170,7 +178,7 @@ def test_vae_magnitudes(): # data_path = "/hpc/projects/organelle_phenotyping/datasets/organelle/SEC61B/2024_10_16_A549_SEC61_ZIKV_DENV" # zarr_path = Path(data_path) / "2024_10_16_A549_SEC61_ZIKV_DENV_2.zarr" zarr_path = None - if not zarr_path.exists() or zarr_path is None: + if not zarr_path: print(f"Found real data at: {zarr_path}") normalizations = [ @@ -189,7 +197,7 @@ def test_vae_magnitudes(): with torch.no_grad(): real_output = model(normalized_data) - losses = compute_vae_losses(real_output, normalized_data, beta=1.0) + losses = compute_vae_losses(model_output=real_output, target=normalized_data, beta=1.0) print(f"\nPerfect reconstruction test (beta=1.0):") print(f" Reconstruction Loss: {losses['recon_loss']:.6f}") print(f" KL Loss: {losses['kl_loss']:.6f}") From 6fce186527e808b400fcbd0580b20def41543491 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Wed, 20 Aug 2025 14:39:14 -0700 Subject: [PATCH 033/101] expose the normalization for vae --- viscy/representation/vae.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/viscy/representation/vae.py b/viscy/representation/vae.py index fcf74e22d..1e2d00ddf 100644 --- a/viscy/representation/vae.py +++ b/viscy/representation/vae.py @@ -365,7 +365,7 @@ def __init__(self, up_kernel_size: Sequence[int] | int = 3, num_res_units: int = 0, use_sigmoid: bool = False, - norm: str= Norm.BATCH, + norm: Literal[Norm.BATCH, Norm.INSTANCE] = Norm.INSTANCE, **kwargs ): super().__init__() @@ -380,6 +380,7 @@ def __init__(self, self.up_kernel_size = up_kernel_size self.num_res_units = num_res_units self.use_sigmoid = use_sigmoid + self.norm = norm self.model = VarAutoEncoder( spatial_dims=self.spatial_dims, @@ -392,6 +393,7 @@ def __init__(self, up_kernel_size=self.up_kernel_size, num_res_units=self.num_res_units, 
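# note: `norm` is MONAI's layer-factory argument (e.g. Norm.BATCH / Norm.INSTANCE); assumed here to be forwarded unchanged to VarAutoEncoder's conv blocks +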
use_sigmoid=self.use_sigmoid, + norm=self.norm, **kwargs ) From dd3da5df8ca7c701f77b2d99c4455a7edad33ccc Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Sun, 24 Aug 2025 17:50:06 -0700 Subject: [PATCH 034/101] add sam 2 test --- .../benchmarking/DynaCLR/SAM2/run_sam2.sh | 18 + .../benchmarking/DynaCLR/SAM2/sam2_config.yml | 59 +++ .../DynaCLR/SAM2/sam2_embeddings.py | 352 ++++++++++++++++++ .../DynaCLR/SAM2/test_sam2_visualization.py | 204 ++++++++++ 4 files changed, 633 insertions(+) create mode 100644 applications/benchmarking/DynaCLR/SAM2/run_sam2.sh create mode 100644 applications/benchmarking/DynaCLR/SAM2/sam2_config.yml create mode 100644 applications/benchmarking/DynaCLR/SAM2/sam2_embeddings.py create mode 100644 applications/benchmarking/DynaCLR/SAM2/test_sam2_visualization.py diff --git a/applications/benchmarking/DynaCLR/SAM2/run_sam2.sh b/applications/benchmarking/DynaCLR/SAM2/run_sam2.sh new file mode 100644 index 000000000..405499ec8 --- /dev/null +++ b/applications/benchmarking/DynaCLR/SAM2/run_sam2.sh @@ -0,0 +1,18 @@ +#!/bin/bash + +#SBATCH --job-name=dynaclr_imagenet +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=1 +#SBATCH --partition=gpu +#SBATCH --gres=gpu:1 +#SBATCH --cpus-per-task=64 +#SBATCH --mem-per-cpu=7G +#SBATCH --time=0-02:00:00 +#SBATCH --output=./slurm_logs/%j_dynaclr_sam2.out + + +module load anaconda/latest +conda activate viscy + +CONFIG_PATH=/home/eduardo.hirata/repos/viscy/applications/benchmarking/DynaCLR/SAM2/sam2_sensor_only.yml +python /home/eduardo.hirata/repos/viscy/applications/benchmarking/DynaCLR/SAM2/sam2_embeddings.py -c $CONFIG_PATH \ No newline at end of file diff --git a/applications/benchmarking/DynaCLR/SAM2/sam2_config.yml b/applications/benchmarking/DynaCLR/SAM2/sam2_config.yml new file mode 100644 index 000000000..19f3ea51f --- /dev/null +++ b/applications/benchmarking/DynaCLR/SAM2/sam2_config.yml @@ -0,0 +1,59 @@ +datamodule: + batch_size: 32 + final_yx_patch_size: + - 192 + - 192 + include_fov_names: null + include_track_ids: null + initial_yx_patch_size: + - 192 + - 192 + normalizations: + - class_path: viscy.transforms.ScaleIntensityRangePercentilesd + init_args: + b_max: 1.0 + b_min: 0.0 + keys: + - Phase3D + lower: 50 + upper: 99 + - class_path: viscy.transforms.ScaleIntensityRangePercentilesd + init_args: + b_max: 1.0 + b_min: 0.0 + keys: + - raw GFP EX488 EM525-45 + lower: 50 + upper: 99 + num_workers: 10 + source_channel: + - Phase3D + - raw GFP EX488 EM525-45 + z_range: + - 25 + - 40 +embedding: + pca_kwargs: + n_components: 8 + phate_kwargs: + decay: 40 + knn: 5 + n_components: 2 + n_jobs: -1 + random_state: 42 + reductions: + - PHATE + - PCA +execution: + overwrite: false + save_config: true + show_config: true +model: + model_name: facebook/sam2-hiera-base-plus + channel_reduction_methods: + Phase3D: middle_slice + raw GFP EX488 EM525-45: max +paths: + data_path: /hpc/projects/intracellular_dashboard/organelle_dynamics/2024_11_07_A549_SEC61_ZIKV_DENV/2-assemble/2024_11_07_A549_SEC61_DENV.zarr + output_path: /home/eduardo.hirata/repos/viscy/applications/benchmarking/DynaCLR/SAM2/sam2_sec61b_n_phase_all_highresfeats0.zarr + tracks_path: /hpc/projects/intracellular_dashboard/organelle_dynamics/2024_11_07_A549_SEC61_ZIKV_DENV/1-preprocess/label-free/4-track-gt/2024_11_07_A549_SEC61_ZIKV_DENV_2_cropped.zarr diff --git a/applications/benchmarking/DynaCLR/SAM2/sam2_embeddings.py b/applications/benchmarking/DynaCLR/SAM2/sam2_embeddings.py new file mode 100644 index 000000000..ef823bc46 --- /dev/null +++ 
b/applications/benchmarking/DynaCLR/SAM2/sam2_embeddings.py @@ -0,0 +1,352 @@ +import importlib +import logging +import os +from pathlib import Path +from typing import Dict, List, Literal, Optional + +import click +import torch +import yaml +from lightning.pytorch import LightningModule +from sam2.sam2_image_predictor import SAM2ImagePredictor +from skimage.exposure import rescale_intensity + +from viscy.data.triplet import TripletDataModule +from viscy.representation.embedding_writer import EmbeddingWriter +from viscy.trainer import VisCyTrainer + + +class SAM2Module(LightningModule): + def __init__( + self, + model_name: str = "facebook/sam2-hiera-base-plus", + channel_reduction_methods: Optional[ + Dict[str, Literal["middle_slice", "mean", "max"]] + ] = None, + channel_names: Optional[List[str]] = None, + ): + super().__init__() + self.model_name = model_name + self.channel_reduction_methods = channel_reduction_methods or {} + self.channel_names = channel_names or [] + + torch.set_float32_matmul_precision("high") + self.model = None # Initialize in on_predict_start when device is set + + def on_predict_start(self): + """Initialize model with proper device when prediction starts""" + if self.model is None: + self.model = SAM2ImagePredictor.from_pretrained( + self.model_name, device=self.device + ) + + def _reduce_5d_input(self, x: torch.Tensor) -> torch.Tensor: + """Reduce 5D input (B, C, D, H, W) to 4D (B, C, H, W) using specified methods. + + Args: + x: 5D input tensor + + Returns: + 4D tensor after applying reduction methods + """ + if x.dim() != 5: + return x + + B, C, D, H, W = x.shape + result = torch.zeros((B, C, H, W), device=x.device) + + # Process all channels at once for each reduction method to minimize loops + middle_slice_indices = [] + mean_indices = [] + max_indices = [] + + # Group channels by reduction method + for c in range(C): + channel_name = ( + self.channel_names[c] if c < len(self.channel_names) else f"channel_{c}" + ) + method = self.channel_reduction_methods.get(channel_name, "middle_slice") + + if method == "mean": + mean_indices.append(c) + elif method == "max": + max_indices.append(c) + else: # Default to middle_slice for any unknown method + middle_slice_indices.append(c) + + # Apply middle_slice reduction to all relevant channels at once + if middle_slice_indices: + indices = torch.tensor(middle_slice_indices, device=x.device) + result[:, indices] = x[:, indices, D // 2] + + # Apply mean reduction to all relevant channels at once + if mean_indices: + indices = torch.tensor(mean_indices, device=x.device) + result[:, indices] = x[:, indices].mean(dim=2) + + # Apply max reduction to all relevant channels at once + if max_indices: + indices = torch.tensor(max_indices, device=x.device) + result[:, indices] = x[:, indices].max(dim=2)[0] + + return result + + def _convert_to_rgb(self, x: torch.Tensor) -> list: + """Convert input tensor to 3-channel RGB format as needed for SAM2. 
+
+        Args:
+            x: Input tensor with 1, 2, or 3+ channels
+
+        Returns:
+            List of numpy arrays in HWC format for SAM2
+        """
+        # Convert to RGB and scale to [0, 255] range for SAM2
+        if x.shape[1] == 1:
+            x_rgb = x.repeat(1, 3, 1, 1) * 255.0
+        elif x.shape[1] == 2:
+            x_3ch = torch.zeros(
+                (x.shape[0], 3, x.shape[2], x.shape[3]), device=x.device
+            )
+            # Scale each channel to [0, 255] with torch ops; skimage's
+            # rescale_intensity expects NumPy arrays, not (possibly CUDA) tensors
+            x[:, 0] = (x[:, 0] - x[:, 0].min()) / (x[:, 0].max() - x[:, 0].min() + 1e-8) * 255.0
+            x[:, 1] = (x[:, 1] - x[:, 1].min()) / (x[:, 1].max() - x[:, 1].min() + 1e-8) * 255.0
+
+            x_3ch[:, 0] = x[:, 0]
+            x_3ch[:, 1] = x[:, 1]
+            x_3ch[:, 2] = 0.5 * (x[:, 0] + x[:, 1])  # B channel as blend
+            x_rgb = x_3ch  # previously unassigned, leaving x_rgb undefined for 2-channel input
+
+        elif x.shape[1] == 3:
+            x_rgb = rescale_intensity(x, out_range="uint8")
+        else:
+            # More than 3 channels, normalize first 3 and scale
+            x_3ch = x[:, :3]
+            x_rgb = rescale_intensity(x_3ch, out_range="uint8")
+
+        # Convert to list of numpy arrays in HWC format for SAM2
+        return [
+            x_rgb[i].cpu().numpy().transpose(1, 2, 0) for i in range(x_rgb.shape[0])
+        ]
+
+    def predict_step(self, batch, batch_idx, dataloader_idx=0):
+        """Extract features from the input images.
+
+        Returns:
+            Dictionary with features, properly shaped empty projections tensor, and index information
+        """
+        x = batch["anchor"]
+
+        # Handle 5D input (B, C, D, H, W) using configured reduction methods
+        if x.dim() == 5:
+            x = self._reduce_5d_input(x)
+
+        # Convert input to RGB format and get list of numpy arrays in HWC format for SAM2
+        image_list = self._convert_to_rgb(x)
+        self.model.set_image_batch(image_list)
+
+        # Extract features
+        # features_0 = self.model._features["image_embed"].mean(dim=(2, 3))
+        # features_1 = self.model._features["high_res_feats"][0].mean(dim=(2, 3))
+        # features_2 = self.model._features["high_res_feats"][1].mean(dim=(2, 3))
+        # features = torch.concat([features_0, features_1, features_2], dim=1)
+        features = self.model._features["high_res_feats"][0].mean(dim=(2, 3))
+
+        # Return features and empty projections with correct batch dimension
+        return {
+            "features": features,
+            "projections": torch.zeros((features.shape[0], 0), device=features.device),
+            "index": batch["index"],
+        }
+
+
+def load_config(config_file):
+    """Load configuration from a YAML file."""
+    with open(config_file, "r") as f:
+        config = yaml.safe_load(f)
+    return config
+
+
+def load_normalization_from_config(norm_config):
+    """Load a normalization transform from a configuration dictionary."""
+    class_path = norm_config["class_path"]
+    init_args = norm_config.get("init_args", {})
+
+    # Split module and class name
+    module_path, class_name = class_path.rsplit(".", 1)
+
+    # Import the module
+    module = importlib.import_module(module_path)
+
+    # Get the class
+    transform_class = getattr(module, class_name)
+
+    # Instantiate the transform
+    return transform_class(**init_args)
+
+
+@click.command()
+@click.option(
+    "--config",
+    "-c",
+    type=click.Path(exists=True),
+    required=True,
+    help="Path to YAML configuration file",
+)
+def main(config):
+    """Extract SAM2 embeddings and save to zarr format using VisCy Trainer."""
+    # Configure logging
+    logging.basicConfig(level=logging.INFO)
+    logger = logging.getLogger(__name__)
+
+    # Load config file
+    cfg = load_config(config)
+    logger.info(f"Loaded configuration from {config}")
+
+    # Prepare datamodule parameters
+    dm_params = {}
+
+    # Add data and tracks paths from the paths section
+    if "paths" not in cfg:
+        raise ValueError("Configuration must contain a 'paths' section")
+
+    if "data_path" not in cfg["paths"]:
+        raise ValueError(
+            "Data path is required in the configuration file (paths.data_path)"
+        )
+
+    
dm_params["data_path"] = cfg["paths"]["data_path"] + + if "tracks_path" not in cfg["paths"]: + raise ValueError( + "Tracks path is required in the configuration file (paths.tracks_path)" + ) + dm_params["tracks_path"] = cfg["paths"]["tracks_path"] + + # Add datamodule parameters + if "datamodule" not in cfg: + raise ValueError("Configuration must contain a 'datamodule' section") + + # Prepare normalizations + if ( + "normalizations" not in cfg["datamodule"] + or not cfg["datamodule"]["normalizations"] + ): + raise ValueError( + "Normalizations are required in the configuration file (datamodule.normalizations)" + ) + + norm_configs = cfg["datamodule"]["normalizations"] + normalizations = [load_normalization_from_config(norm) for norm in norm_configs] + dm_params["normalizations"] = normalizations + + # Copy all other datamodule parameters + for param, value in cfg["datamodule"].items(): + if param != "normalizations": + # Handle patch sizes + if param == "patch_size": + dm_params["initial_yx_patch_size"] = value + dm_params["final_yx_patch_size"] = value + else: + dm_params[param] = value + + # Set up the data module + logger.info("Setting up data module") + dm = TripletDataModule(**dm_params) + + # Get model parameters for handling 5D inputs + channel_reduction_methods = {} + + if "model" in cfg and "channel_reduction_methods" in cfg["model"]: + channel_reduction_methods = cfg["model"]["channel_reduction_methods"] + + # Initialize SAM2 model with reduction settings + logger.info("Loading SAM2 model") + model = SAM2Module( + model_name=cfg["model"]["model_name"], + channel_reduction_methods=channel_reduction_methods, + ) + + # Get dimensionality reduction parameters from config + phate_kwargs = None + pca_kwargs = None + + if "embedding" in cfg: + if "phate_kwargs" in cfg["embedding"]: + phate_kwargs = cfg["embedding"]["phate_kwargs"] + if "pca_kwargs" in cfg["embedding"]: + pca_kwargs = cfg["embedding"]["pca_kwargs"] + # Check if output path exists and should be overwritten + if "output_path" not in cfg["paths"]: + raise ValueError( + "Output path is required in the configuration file (paths.output_path)" + ) + + output_path = Path(cfg["paths"]["output_path"]) + output_dir = output_path.parent + output_dir.mkdir(parents=True, exist_ok=True) + + overwrite = False + if "execution" in cfg and "overwrite" in cfg["execution"]: + overwrite = cfg["execution"]["overwrite"] + elif output_path.exists(): + logger.warning(f"Output path {output_path} already exists, will overwrite") + overwrite = True + + # Set up EmbeddingWriter callback + embedding_writer = EmbeddingWriter( + output_path=output_path, + phate_kwargs=phate_kwargs, + pca_kwargs=pca_kwargs, + overwrite=overwrite, + ) + + # Set up and run VisCy trainer + logger.info("Setting up VisCy trainer") + trainer = VisCyTrainer( + accelerator="gpu" if torch.cuda.is_available() else "cpu", + devices=1, + callbacks=[embedding_writer], + inference_mode=True, + ) + + logger.info(f"Running prediction and saving to {output_path}") + trainer.predict(model, datamodule=dm) + + # Save configuration if requested + save_config_flag = True + show_config_flag = True + + if "execution" in cfg: + if "save_config" in cfg["execution"]: + save_config_flag = cfg["execution"]["save_config"] + if "show_config" in cfg["execution"]: + show_config_flag = cfg["execution"]["show_config"] + + # Save configuration if requested + if save_config_flag: + config_path = os.path.join(output_dir, "config.yml") + with open(config_path, "w") as f: + yaml.dump(cfg, f, 
default_flow_style=False) + logger.info(f"Configuration saved to {config_path}") + + # Display configuration if requested + if show_config_flag: + click.echo("\nConfiguration used:") + click.echo("-" * 40) + for key, value in cfg.items(): + click.echo(f"{key}:") + if isinstance(value, dict): + for subkey, subvalue in value.items(): + if isinstance(subvalue, list) and subkey == "normalizations": + click.echo(f" {subkey}:") + for norm in subvalue: + click.echo(f" - class_path: {norm['class_path']}") + click.echo(f" init_args: {norm['init_args']}") + else: + click.echo(f" {subkey}: {subvalue}") + else: + click.echo(f" {value}") + click.echo("-" * 40) + + logger.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/applications/benchmarking/DynaCLR/SAM2/test_sam2_visualization.py b/applications/benchmarking/DynaCLR/SAM2/test_sam2_visualization.py new file mode 100644 index 000000000..1ab3b1082 --- /dev/null +++ b/applications/benchmarking/DynaCLR/SAM2/test_sam2_visualization.py @@ -0,0 +1,204 @@ +# %% +""" +Test script to visualize SAM2 input images and feature processing. +This script helps debug what images are being passed to SAM2 and how they're processed. +""" + +import matplotlib.pyplot as plt +import numpy as np +import torch +import yaml +from pathlib import Path +import sys +import os + +from sam2_embeddings import SAM2Module, load_config, load_normalization_from_config +from viscy.data.triplet import TripletDataModule + + +def visualize_rgb_conversion(x_original, x_rgb_list, save_dir="./debug_images"): + """Visualize the RGB conversion process""" + os.makedirs(save_dir, exist_ok=True) + + print(f"Original input shape: {x_original.shape}") + print(f"Original input range: [{x_original.min():.3f}, {x_original.max():.3f}]") + + # Plot original channels + B, C = x_original.shape[:2] + fig, axes = plt.subplots(3, max(3, C), figsize=(15, 12)) + + # Plot original channels + for c in range(C): + ax = axes[0, c] if C > 1 else axes[0, 0] + img = x_original[0, c].cpu().numpy() + im = ax.imshow(img, cmap="gray") + ax.set_title(f"Original Channel {c}") + ax.axis("off") + plt.colorbar(im, ax=ax) + + # Plot RGB conversion + rgb_img = x_rgb_list[0] # First batch item + print(f"RGB image shape: {rgb_img.shape}") + print(f"RGB image range: [{rgb_img.min():.3f}, {rgb_img.max():.3f}]") + + for c in range(3): + ax = axes[1, c] + im = ax.imshow(rgb_img[:, :, c], cmap="gray") + ax.set_title(f"RGB Channel {c}") + ax.axis("off") + plt.colorbar(im, ax=ax) + + # Plot merged RGB image + ax = axes[2, 0] + # Normalize to 0-1 for display + rgb_display = rgb_img.copy() + rgb_display = (rgb_display - rgb_display.min()) / (rgb_display.max() - rgb_display.min()) + im = ax.imshow(rgb_display) + ax.set_title("Merged RGB Image") + ax.axis("off") + + # Check if RGB is properly scaled to 0-255 + ax = axes[2, 1] + ax.text(0.1, 0.8, f"RGB Range: [{rgb_img.min():.1f}, {rgb_img.max():.1f}]", transform=ax.transAxes) + ax.text(0.1, 0.6, f"Expected: [0, 255]", transform=ax.transAxes) + ax.text(0.1, 0.4, f"Properly scaled: {rgb_img.min() >= 0 and rgb_img.max() <= 255}", transform=ax.transAxes) + ax.text(0.1, 0.2, f"Mean: {rgb_img.mean():.1f}", transform=ax.transAxes) + ax.set_title("RGB Scaling Check") + ax.axis("off") + + plt.tight_layout() + plt.savefig(f"{save_dir}/rgb_conversion.png", dpi=150, bbox_inches="tight") + plt.close() + + +def test_sam2_processing(config_path, num_samples=3): + """Test SAM2 processing with visualization""" + + # Load configuration + cfg = load_config(config_path) + print(f"Loaded config 
from: {config_path}")
+
+    # Setup data module (same as in main function)
+    dm_params = {}
+    dm_params["data_path"] = cfg["paths"]["data_path"]
+    dm_params["tracks_path"] = cfg["paths"]["tracks_path"]
+
+    # Setup normalizations
+    norm_configs = cfg["datamodule"]["normalizations"]
+    normalizations = [load_normalization_from_config(norm) for norm in norm_configs]
+    dm_params["normalizations"] = normalizations
+
+    # Copy other datamodule parameters
+    for param, value in cfg["datamodule"].items():
+        if param != "normalizations":
+            if param == "patch_size":
+                dm_params["initial_yx_patch_size"] = value
+                dm_params["final_yx_patch_size"] = value
+            else:
+                dm_params[param] = value
+
+    print("Setting up data module...")
+    dm = TripletDataModule(**dm_params)
+    dm.setup(stage="predict")
+
+    # Get model parameters
+    channel_reduction_methods = {}
+    if "model" in cfg and "channel_reduction_methods" in cfg["model"]:
+        channel_reduction_methods = cfg["model"]["channel_reduction_methods"]
+
+    # Initialize SAM2 model
+    print("Loading SAM2 model...")
+    model = SAM2Module(
+        model_name=cfg["model"]["model_name"],
+        channel_reduction_methods=channel_reduction_methods,
+    )
+
+    # Get dataloader
+    predict_dataloader = dm.predict_dataloader()
+
+    print(f"Testing with {num_samples} samples...")
+
+    # Test processing
+    for i, batch in enumerate(predict_dataloader):
+        if i >= num_samples:
+            break
+
+        print(f"\n--- Sample {i+1} ---")
+        x = batch["anchor"]
+        print(f"Input tensor shape: {x.shape}")
+        print(f"Input tensor range: [{x.min():.3f}, {x.max():.3f}]")
+
+        # Test 5D reduction if needed
+        if x.dim() == 5:
+            print("Applying 5D reduction...")
+            x_reduced = model._reduce_5d_input(x)
+            print(f"After 5D reduction: {x_reduced.shape}")
+            print(f"Reduction methods: {model.channel_reduction_methods}")
+        else:
+            x_reduced = x
+
+        # Test RGB conversion
+        print("Converting to RGB...")
+        x_rgb_list = model._convert_to_rgb(x_reduced)
+        print(f"RGB conversion result: {len(x_rgb_list)} images")
+        print(f"First RGB image shape: {x_rgb_list[0].shape}")
+
+        # Visualize the conversion
+        visualize_rgb_conversion(x_reduced, x_rgb_list, f"./debug_images/sample_{i}")
+
+        # Test feature extraction (if model is available)
+        try:
+            print("Testing feature extraction...")
+            # on_predict_start() initializes the predictor in place and returns None,
+            # so it must not be used on the right-hand side of an assignment
+            if model.model is None:
+                model.on_predict_start()
+            model.model.set_image_batch(x_rgb_list)
+
+            # Check what features are available
+            features_dict = model.model._features
+            print(f"Available features: {list(features_dict.keys())}")
+
+            if "high_res_feats" in features_dict:
+                high_res_feats = features_dict["high_res_feats"]
+                print(f"High-res features length: {len(high_res_feats)}")
+                for j, feat in enumerate(high_res_feats):
+                    print(f"  Layer {j}: {feat.shape}")
+
+            if "image_embed" in features_dict:
+                image_embed = features_dict["image_embed"]
+                print(f"Image embed shape: {image_embed.shape}")
+
+            # Extract final features (current approach)
+            features = model.model._features["high_res_feats"][1].mean(dim=(2, 3))
+            print(f"Final features shape: {features.shape}")
+            print(f"Final features range: [{features.min():.3f}, {features.max():.3f}]")
+
+        except Exception as e:
+            print(f"Feature extraction failed: {e}")
+
+        print("-" * 50)
+
+
+def main():
+    """Main function to run the test"""
+    config_path = "/home/eduardo.hirata/repos/viscy/applications/benchmarking/DynaCLR/SAM2/sam2_sensor_only.yml"
+
+    if not Path(config_path).exists():
+        print(f"Config file not found: {config_path}")
+        print("Please provide a valid config file path")
+        return
+
+    try:
+        
test_sam2_processing(config_path, num_samples=3) + print("\nTest completed successfully!") + print("Check ./debug_images/ for visualization outputs") + except Exception as e: + print(f"Test failed: {e}") + import traceback + + traceback.print_exc() + + +# %% +if __name__ == "__main__": + main() + +# %% From 8ae92f79ae98c32a41cc77cf4ff7bd64dcb561fe Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Tue, 26 Aug 2025 13:56:17 -0700 Subject: [PATCH 035/101] refactor smoothness metrics --- .../smoothness/compute_smoothness.py | 127 +++++++++++ viscy/representation/evaluation/smoothness.py | 209 ++++++++++++++++++ 2 files changed, 336 insertions(+) create mode 100644 applications/contrastive_phenotyping/evaluation/smoothness/compute_smoothness.py create mode 100644 viscy/representation/evaluation/smoothness.py diff --git a/applications/contrastive_phenotyping/evaluation/smoothness/compute_smoothness.py b/applications/contrastive_phenotyping/evaluation/smoothness/compute_smoothness.py new file mode 100644 index 000000000..ad3bdb6a5 --- /dev/null +++ b/applications/contrastive_phenotyping/evaluation/smoothness/compute_smoothness.py @@ -0,0 +1,127 @@ +# %% +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import seaborn as sns +from lightning.pytorch import seed_everything +from matplotlib.patches import FancyArrowPatch + +from viscy.representation.embedding_writer import read_embedding_dataset +from viscy.representation.evaluation.smoothness import compute_embeddings_smoothness + +colormap = { + 2: "orange", + 1: "steelblue", +} +#%% +# FEATURES + +# openphenom_features_path = Path("/home/jason/projects/contrastive_phenotyping/data/open_phenom/features/open_phenom_features.csv") +# imagenet_features_path = Path("/home/jason/projects/contrastive_phenotyping/data/imagenet/features/imagenet_features.csv") +dynaclr_features_path = Path("/hpc/projects/intracellular_dashboard/organelle_dynamics/rerun/2024_11_07_A549_SEC61_DENV/4-phenotyping/dtw_evaluation/SAM2/sam2_sensor_only.zarr") + +# ANNOTATIONS +ann_root = Path( + "/hpc/projects/intracellular_dashboard/organelle_dynamics/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/supervised_inf_pred" +) + +# TRACKS + +tracks_path = Path( + "/hpc/projects/intracellular_dashboard/organelle_dynamics/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/track_test.zarr" +) + +# LOADING DATASETS +# openphenom_features = read_embedding_dataset(openphenom_features_path) +# imagenet_features = read_embedding_dataset(imagenet_features_path) +dynaclr_embedding_dataset = read_embedding_dataset(dynaclr_features_path) +#%% +# Compute the smoothness of the features +DISTANCE_METRIC = "cosine" +feature_paths ={ + "dynaclr": dynaclr_features_path, +} +cmap = plt.get_cmap("tab10") # or use "Set2", "tab20", etc. 
+labels = list(feature_paths.keys())
+interval_colors = {label: cmap(i % cmap.N) for i, label in enumerate(labels)}
+# Print and check each path
+for label, path in feature_paths.items():
+    print(f"{label} color: {interval_colors[label]}")
+    assert Path(path).exists(), f"Path {path} does not exist"
+
+output_dir = Path("./smoothness_metrics")
+output_dir.mkdir(parents=True, exist_ok=True)
+
+results = {}
+for label, path in feature_paths.items():
+    results[label] = {}
+    print(f"\nProcessing - {label}")
+    embedding_dataset = read_embedding_dataset(Path(path))
+
+    # Compute displacements
+    stats, distributions, _ = compute_embeddings_smoothness(
+        prediction_path=Path(path),
+        distance_metric=DISTANCE_METRIC,
+        verbose=True,
+    )
+
+    # Plot the piecewise distances
+    fig = plt.figure()
+    sns.histplot(
+        distributions["adjacent_frame_distribution"],
+        bins=30,
+        kde=True,
+        color="cyan",
+        alpha=0.5,
+        stat="density",
+    )
+    sns.histplot(
+        distributions["random_frame_distribution"],
+        bins=30,
+        kde=True,
+        color="red",
+        alpha=0.5,
+        stat="density",
+    )
+    plt.xlabel(f"{DISTANCE_METRIC} Distance")
+    plt.ylabel("Density")
+    # Add vertical lines for the peaks
+    plt.axvline(
+        x=stats["adjacent_frame_peak"], color="cyan", linestyle="--", alpha=0.8
+    )
+    plt.axvline(x=stats["random_frame_peak"], color="red", linestyle="--", alpha=0.8)
+    plt.tight_layout()
+    plt.legend(["Adjacent Frame", "Random Sample", "Adjacent Peak", "Random Peak"])
+    plt.savefig(output_dir / f"{label}_smoothness.pdf", dpi=300)
+    plt.savefig(output_dir / f"{label}_smoothness.png", dpi=300)
+    plt.close()
+
+    # metrics to csv
+    scalar_metrics = {
+        "adjacent_frame_mean": stats["adjacent_frame_mean"],
+        "adjacent_frame_std": stats["adjacent_frame_std"],
+        "adjacent_frame_median": stats["adjacent_frame_median"],
+        "adjacent_frame_peak": stats["adjacent_frame_peak"],
+        "random_frame_mean": stats["random_frame_mean"],
+        "random_frame_std": stats["random_frame_std"],
+        "random_frame_median": stats["random_frame_median"],
+        "random_frame_peak": stats["random_frame_peak"],
+        "dynamic_range": stats["dynamic_range"],
+    }
+    # Create a single-row DataFrame (the scalar dict must be wrapped in a list)
+    stats_df = pd.DataFrame([scalar_metrics])
+    stats_df.to_csv(output_dir / f"{label}_smoothness_stats.csv", index=False)
+
+# %%
\ No newline at end of file
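A note on the single-row export above: pandas raises "If using all scalar values, you must pass an index" when a dict of scalars is passed directly to the constructor, which is why the metrics dict is wrapped in a list. A minimal sketch with illustrative values:

    import pandas as pd

    # Wrapping the scalar dict in a list yields a one-row frame; passing the bare
    # dict would raise "If using all scalar values, you must pass an index".
    scalar_metrics = {"adjacent_frame_peak": 0.12, "random_frame_peak": 0.85, "dynamic_range": 0.73}
    stats_df = pd.DataFrame([scalar_metrics])
    stats_df.to_csv("example_smoothness_stats.csv", index=False)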
diff --git a/viscy/representation/evaluation/smoothness.py b/viscy/representation/evaluation/smoothness.py
new file mode 100644
index 000000000..51be25f7c
--- /dev/null
+++ b/viscy/representation/evaluation/smoothness.py
@@ -0,0 +1,209 @@
+from pathlib import Path
+from typing import Literal
+
+import numpy as np
+import pandas as pd
+from numpy.typing import NDArray
+from scipy.signal import find_peaks
+from scipy.stats import gaussian_kde
+from sklearn.preprocessing import StandardScaler
+
+from viscy.representation.embedding_writer import read_embedding_dataset
+from viscy.representation.evaluation.clustering import (
+    compare_time_offset,
+    pairwise_distance_matrix,
+    rank_nearest_neighbors,
+    select_block,
+)
+
+
+def compute_piece_wise_distance(
+    features_df: pd.DataFrame,
+    cross_dist: NDArray,
+    rank_fractions: NDArray,
+    groupby: list[str] = ["fov_name", "track_id"],
+) -> tuple[list[list[float]], list[list[float]]]:
+    """
+    Compute the piece-wise distance and rank difference:
+    - Get the off-diagonal per block and compute the mode
+    - The blocks are not square, so we need to get the off-diagonal elements
+    - Get the 1st and 99th percentiles of the off-diagonal per block
+
+    Parameters
+    ----------
+    features_df : pd.DataFrame
+        DataFrame containing the features
+    cross_dist : NDArray
+        Cross-distance matrix
+    rank_fractions : NDArray
+        Rank fractions
+    groupby : list[str], optional
+        Columns to group by, by default ["fov_name", "track_id"]
+
+    Returns
+    -------
+    piece_wise_dissimilarity_per_track : list
+        Piece-wise dissimilarity per track
+    piece_wise_rank_difference_per_track : list
+        Piece-wise rank difference per track
+    """
+    piece_wise_dissimilarity_per_track = []
+    piece_wise_rank_difference_per_track = []
+    for _, subdata in features_df.groupby(groupby):
+        if len(subdata) > 1:
+            indices = subdata.index.values
+            single_track_dissimilarity = select_block(cross_dist, indices)
+            single_track_rank_fraction = select_block(rank_fractions, indices)
+            piece_wise_dissimilarity = compare_time_offset(
+                single_track_dissimilarity, time_offset=1
+            )
+            piece_wise_rank_difference = compare_time_offset(
+                single_track_rank_fraction, time_offset=1
+            )
+            piece_wise_dissimilarity_per_track.append(piece_wise_dissimilarity)
+            piece_wise_rank_difference_per_track.append(piece_wise_rank_difference)
+    return piece_wise_dissimilarity_per_track, piece_wise_rank_difference_per_track
+
+
+def find_distribution_peak(
+    data: np.ndarray, method: Literal["histogram", "kde_robust"] = "kde_robust"
+) -> float:
+    """Find the peak of a distribution.
+
+    Parameters
+    ----------
+    data : np.ndarray
+        The data to find the peak of
+    method : Literal["histogram", "kde_robust"], optional
+        The method used to find the peak, by default "kde_robust"
+
+    Returns
+    -------
+    float
+        The peak of the distribution (highest peak if multiple)
+    """
+    if method == "histogram":
+        # Simple histogram-based peak finding
+        hist, bin_edges = np.histogram(data, bins=50, density=True)
+        bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2
+        peaks, properties = find_peaks(hist, height=np.max(hist) * 0.1)  # 10% of max height
+        if len(peaks) == 0:
+            return bin_centers[np.argmax(hist)]  # Fallback to global max
+        # Return peak with highest density
+        peak_heights = properties["peak_heights"]
+        return bin_centers[peaks[np.argmax(peak_heights)]]
+
+    elif method == "kde_robust":
+        # More robust KDE approach
+        kde = gaussian_kde(data)
+        x_range = np.linspace(np.min(data), np.max(data), 1000)
+        kde_vals = kde(x_range)
+        peaks, properties = find_peaks(kde_vals, height=np.max(kde_vals) * 0.1)
+        if len(peaks) == 0:
+            return x_range[np.argmax(kde_vals)]  # Fallback to global max
+        # Return peak with highest KDE value
+        peak_heights = properties["peak_heights"]
+        return x_range[peaks[np.argmax(peak_heights)]]
+
+
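The "kde_robust" branch above can be sanity-checked on synthetic bimodal data; a minimal sketch (illustrative values only, not part of the module):

    import numpy as np
    from scipy.signal import find_peaks
    from scipy.stats import gaussian_kde

    # Two modes; the taller one near 0.0 should be reported as the peak.
    rng = np.random.default_rng(0)
    data = np.concatenate([rng.normal(0.0, 0.1, 2000), rng.normal(1.0, 0.1, 500)])
    kde = gaussian_kde(data)
    x_range = np.linspace(data.min(), data.max(), 1000)
    kde_vals = kde(x_range)
    peaks, props = find_peaks(kde_vals, height=kde_vals.max() * 0.1)
    print(x_range[peaks[np.argmax(props["peak_heights"])]])  # ~0.0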
+def compute_embeddings_smoothness(
+    prediction_path: Path,
+    distance_metric: Literal["cosine", "euclidean"] = "cosine",
+    verbose: bool = False,
+) -> tuple[dict, dict, list[list[float]]]:
+    """
+    Compute the smoothness statistics of embeddings.
+
+    Parameters
+    ----------
+    prediction_path : Path
+        Path to the embedding dataset
+    distance_metric : Literal["cosine", "euclidean"], optional
+        Distance metric to use, by default "cosine"
+    verbose : bool, optional
+        Print the computed statistics, by default False
+
+    Returns
+    -------
+    stats : dict
+        Dictionary of scalar metrics:
+        - adjacent_frame_mean/std/median/peak: statistics of adjacent-frame dissimilarity
+        - random_frame_mean/std/median/peak: statistics of random-sampling dissimilarity
+        - dynamic_range: difference between the random and adjacent distribution peaks
+    distributions : dict
+        Dictionary of full distributions:
+        - adjacent_frame_distribution: adjacent-frame dissimilarities
+        - random_frame_distribution: random-sampling dissimilarities
+    piecewise_distance_per_track : list[list[float]]
+        Piece-wise distance per track
+    """
+
+    # Read the dataset
+    embeddings = read_embedding_dataset(prediction_path)
+    features = embeddings["features"]
+    scaled_features = StandardScaler().fit_transform(features.values)
+
+    # Compute the distance matrix
+    cross_dist = pairwise_distance_matrix(scaled_features, metric=distance_metric)
+    rank_fractions = rank_nearest_neighbors(cross_dist, normalize=True)
+
+    # Compute piece-wise distance and rank difference
+    features_df = features["sample"].to_dataframe().reset_index(drop=True)
+    piecewise_distance_per_track, _ = compute_piece_wise_distance(
+        features_df, cross_dist, rank_fractions
+    )
+
+    all_piecewise_distances = np.concatenate(piecewise_distance_per_track)
+
+    # p99_piece_wise_distance = np.array(
+    #     [np.percentile(track, 99) for track in piecewise_distance_per_track]
+    # )
+    # p1_percentile_piece_wise_distance = np.array(
+    #     [np.percentile(track, 1) for track in piecewise_distance_per_track]
+    # )
+
+    # Randomly sample values from the distance matrix, matching the number of
+    # adjacent-frame measurements
+    n_samples = len(all_piecewise_distances)
+    # Avoid sampling the diagonal elements
+    np.random.seed(42)
+    i_indices = np.random.randint(0, len(cross_dist), size=n_samples)
+    j_indices = np.random.randint(0, len(cross_dist), size=n_samples)
+
+    diagonal_mask = i_indices == j_indices
+    while diagonal_mask.any():
+        j_indices[diagonal_mask] = np.random.randint(
+            0, len(cross_dist), size=diagonal_mask.sum()
+        )
+        diagonal_mask = i_indices == j_indices
+    sampled_values = cross_dist[i_indices, j_indices]
+
+    # Compute the peaks of both distributions using KDE
+    adjacent_peak = find_distribution_peak(all_piecewise_distances, method="kde_robust")
+    random_peak = find_distribution_peak(sampled_values, method="kde_robust")
+    dynamic_range = random_peak - adjacent_peak
+
+    stats = {
+        "adjacent_frame_mean": float(np.mean(all_piecewise_distances)),
+        "adjacent_frame_std": float(np.std(all_piecewise_distances)),
+        "adjacent_frame_median": float(np.median(all_piecewise_distances)),
+        "adjacent_frame_peak": float(adjacent_peak),
+        "random_frame_mean": float(np.mean(sampled_values)),
+        "random_frame_std": float(np.std(sampled_values)),
+        "random_frame_median": float(np.median(sampled_values)),
+        "random_frame_peak": float(random_peak),
+        "dynamic_range": float(dynamic_range),
+    }
+    distributions = {
+        "adjacent_frame_distribution": all_piecewise_distances,
+        "random_frame_distribution": sampled_values,
+    }
+
+    if verbose:
+        for key, value in stats.items():
+            print(f"{key}: {value}")
+
+    return stats, distributions, piecewise_distance_per_track
+
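The dynamic range reported by compute_embeddings_smoothness is the gap between the random-pair and adjacent-frame distance modes; a self-contained sketch of that comparison on a toy random-walk track (hypothetical shapes, cosine metric):

    import numpy as np
    from scipy.spatial.distance import pdist, squareform

    rng = np.random.default_rng(42)
    track = np.cumsum(rng.normal(0, 0.05, (100, 8)), axis=0)  # one smooth toy track
    dist = squareform(pdist(track, metric="cosine"))

    adjacent = np.diagonal(dist, offset=1)  # consecutive-frame distances
    i = rng.integers(0, len(dist), 1000)
    j = rng.integers(0, len(dist), 1000)
    random_pairs = dist[i[i != j], j[i != j]]  # drop self-pairs

    print(f"adjacent mean: {adjacent.mean():.3f}, random mean: {random_pairs.mean():.3f}")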
From 9d36a6f49db0db22369bf51e340505c1c762fcff Mon Sep 17 00:00:00 2001
From: Eduardo Hirata-Miyasaki
Date: Sun, 31 Aug 2025 07:44:57 -0700
Subject: [PATCH 036/101] revert to normalizing KL wrt batch size and remove
 the hardcoded beta min value

--- viscy/representation/engine.py | 18 ++++++++++--------
 1 file changed, 10 insertions(+), 8 deletions(-)

diff --git a/viscy/representation/engine.py b/viscy/representation/engine.py
index b1b0351da..3c6296b5d 100644
--- a/viscy/representation/engine.py
+++ b/viscy/representation/engine.py
@@ -442,6 +442,8 @@ def __init__(
         self.training_step_outputs = []
         self.validation_step_outputs = []
 
+        self._min_beta = 1e-15
+
         # Handle different parameter names for latent dimensions
         latent_dim = None
         if "latent_dim" in self.model_config:
@@ -465,7 +467,7 @@ def setup(self, stage: str = None):
 
     def _get_current_beta(self) -> float:
         """Get current beta value based on scheduling."""
         if self.beta_schedule is None:
-            return max(self.beta, 1e-6)
+            return max(self.beta, self._min_beta)
 
         epoch = self.current_epoch
 
@@ -476,9 +478,9 @@
                     self.beta_min
                     + (self.beta - self.beta_min) * epoch / self.beta_warmup_epochs
                 )
-                return max(beta_val, 1e-6)
+                return max(beta_val, self._min_beta)
             else:
-                return max(self.beta, 1e-6)
+                return max(self.beta, self._min_beta)
 
         elif self.beta_schedule == "cosine":
             # Cosine warmup from beta_min to beta
@@ -489,17 +491,17 @@
                 beta_val = self.beta_min + (self.beta - self.beta_min) * 0.5 * (
                     1 + math.cos(math.pi * (1 - progress))
                 )
-                return max(beta_val, 1e-6)
+                return max(beta_val, self._min_beta)
             else:
-                return max(self.beta, 1e-6)
+                return max(self.beta, self._min_beta)
 
         elif self.beta_schedule == "warmup":
             # Keep beta_min for warmup epochs, then jump to beta
             beta_val = self.beta_min if epoch < self.beta_warmup_epochs else self.beta
-            return max(beta_val, 1e-6)
+            return max(beta_val, self._min_beta)
 
         else:
-            return max(self.beta, 1e-6)
+            return max(self.beta, self._min_beta)
 
     @torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float32)
     def forward(self, x: Tensor) -> dict:
@@ -535,7 +537,7 @@
             -0.5
             * current_beta
             * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
-            / normalizer
+            / batch_size
         )
 
         total_loss = recon_loss + kl_loss
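For reference, the scheduling branches in _get_current_beta reduce to the following standalone sketch (same linear/cosine/warmup semantics, assuming beta_min <= beta):

    import math

    def beta_at_epoch(epoch, beta=1.0, beta_min=0.1, warmup_epochs=10,
                      schedule="linear", min_beta=1e-15):
        # Warm up from beta_min to beta, then hold; floor everything at min_beta.
        if schedule in ("linear", "cosine") and epoch < warmup_epochs:
            if schedule == "linear":
                val = beta_min + (beta - beta_min) * epoch / warmup_epochs
            else:
                progress = epoch / warmup_epochs
                val = beta_min + (beta - beta_min) * 0.5 * (1 + math.cos(math.pi * (1 - progress)))
        elif schedule == "warmup":
            val = beta_min if epoch < warmup_epochs else beta
        else:
            val = beta
        return max(val, min_beta)

    print([round(beta_at_epoch(e), 3) for e in (0, 5, 10, 20)])  # [0.1, 0.55, 1.0, 1.0]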
From d5bac4ac2dc6562a07384144dc6e5a059ca4882a Mon Sep 17 00:00:00 2001
From: Eduardo Hirata-Miyasaki
Date: Tue, 2 Sep 2025 17:18:04 -0700
Subject: [PATCH 037/101] add DTW embeddings comparison with SAM2

--- .../evaluation/compare_dtw_embeddings_sam2.py | 766 ++++++++++++++++++
 1 file changed, 766 insertions(+)
 create mode 100644 applications/pseudotime_analysis/evaluation/compare_dtw_embeddings_sam2.py

diff --git a/applications/pseudotime_analysis/evaluation/compare_dtw_embeddings_sam2.py b/applications/pseudotime_analysis/evaluation/compare_dtw_embeddings_sam2.py
new file mode 100644
index 000000000..78a5a719b
--- /dev/null
+++ b/applications/pseudotime_analysis/evaluation/compare_dtw_embeddings_sam2.py
@@ -0,0 +1,766 @@
+# %%
+import ast
+import logging
+from pathlib import Path
+
+import matplotlib.pyplot as plt
+import numpy as np
+import pandas as pd
+import seaborn as sns
+from plotting_utils import (
+    find_pattern_matches,
+    identify_lineages,
+    plot_pc_trajectories,
+    plot_reference_vs_full_lineages,
+)
+from sklearn.decomposition import PCA
+from sklearn.preprocessing import StandardScaler
+from tqdm import tqdm
+
+from viscy.data.triplet import TripletDataModule
+from viscy.representation.embedding_writer import read_embedding_dataset
+from viscy.representation.evaluation.dimensionality_reduction import compute_pca
+
+logger = logging.getLogger("viscy")
+logger.setLevel(logging.INFO)
+console_handler = logging.StreamHandler()
+console_handler.setLevel(logging.INFO)
+formatter = logging.Formatter("%(message)s")  # Simplified format
+console_handler.setFormatter(formatter)
+logger.addHandler(console_handler)
+
+
+NAPARI = True
+if NAPARI:
+    import os
+
+    import napari
+
+    os.environ["DISPLAY"] = ":1"
+    viewer = napari.Viewer()
+# %%
+# Organelle and Phate aligned to infection
+
+input_data_path = Path(
+    "/hpc/projects/intracellular_dashboard/organelle_dynamics/2024_11_07_A549_SEC61_ZIKV_DENV/2-assemble/2024_11_07_A549_SEC61_DENV.zarr"
+)
+tracks_path = Path(
+    "/hpc/projects/intracellular_dashboard/organelle_dynamics/2024_11_07_A549_SEC61_ZIKV_DENV/1-preprocess/label-free/4-track-gt/2024_11_07_A549_SEC61_ZIKV_DENV_2_cropped.zarr"
+)
+infection_annotations_path = Path(
+    "/hpc/projects/intracellular_dashboard/organelle_dynamics/rerun/2024_11_07_A549_SEC61_DENV/4-phenotyping/0-annotation/combined_annotations_n_tracks_infection.csv"
+)
+
+pretrain_features_root = Path(
+    "/hpc/projects/intracellular_dashboard/organelle_dynamics/2024_11_07_A549_SEC61_ZIKV_DENV/4-phenotyping/prediction_pretrained_models"
+)
+# Phase n organelle
+# dynaclr_features_path = "/hpc/projects/intracellular_dashboard/organelle_dynamics/2024_11_07_A549_SEC61_ZIKV_DENV/4-phenotyping/predictions/timeAware_2chan__ntxent_192patch_70ckpt_rev7_GT.zarr"
+
+# Phase n sensor
+dynaclr_features_path = "/hpc/projects/intracellular_dashboard/organelle_dynamics/2024_11_07_A549_SEC61_ZIKV_DENV/3-phenotyping/predictions_infection/2chan_192patch_100ckpt_timeAware_ntxent_GT.zarr"
+
+output_root = Path(
+    "/hpc/projects/intracellular_dashboard/organelle_dynamics/2024_11_07_A549_SEC61_ZIKV_DENV/4-phenotyping/figure/SEC61B/model_comparison"
+)
+
+
+# Load embeddings
+imagenet_features_path = (
+    pretrain_features_root / "ImageNet/20241107_sensor_n_phase_imagenet.zarr"
+)
+openphenom_features_path = (
+    pretrain_features_root / "OpenPhenom/20241107_sensor_n_phase_openphenom.zarr"
+)
+
+dynaclr_embeddings = read_embedding_dataset(dynaclr_features_path)
+imagenet_embeddings = read_embedding_dataset(imagenet_features_path)
+openphenom_embeddings = read_embedding_dataset(openphenom_features_path)
+
+# Load infection annotations
+infection_annotations_df = pd.read_csv(infection_annotations_path)
+infection_annotations_df["fov_name"] = "/C/2/000001"
+
+process_embeddings = [
+    (dynaclr_embeddings, "dynaclr"),
+    (imagenet_embeddings, "imagenet"),
+    (openphenom_embeddings, "openphenom"),
+]
+
+
+output_root.mkdir(parents=True, exist_ok=True)
+# %%
+feature_df = dynaclr_embeddings["sample"].to_dataframe().reset_index(drop=True)
+
+# Logic to find lineages
+lineages = identify_lineages(feature_df)
+logger.info(f"Found {len(lineages)} distinct lineages")
+filtered_lineages = []
+min_timepoints = 20
+for fov_id, track_ids in lineages:
+    # Get all rows for this lineage
+    lineage_rows = feature_df[
+        (feature_df["fov_name"] == fov_id) & (feature_df["track_id"].isin(track_ids))
+    ]
+
+    # Count the total number of timepoints
+    total_timepoints = len(lineage_rows)
+
+    # Only keep lineages with at least min_timepoints
+    if total_timepoints >= min_timepoints:
+        filtered_lineages.append((fov_id, track_ids))
+logger.info(
+    f"Found {len(filtered_lineages)} lineages with at least {min_timepoints} timepoints"
+)
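The min_timepoints filter above can also be expressed without the explicit loop; an equivalent sketch on a toy stand-in for feature_df (hypothetical values):

    import pandas as pd

    df = pd.DataFrame({"fov_name": ["/A/1"] * 25 + ["/B/1"] * 5,
                       "track_id": [1] * 25 + [2] * 5})
    lineages = [("/A/1", [1]), ("/B/1", [2])]
    min_timepoints = 20
    filtered = [
        (fov, tracks) for fov, tracks in lineages
        if ((df["fov_name"] == fov) & (df["track_id"].isin(tracks))).sum() >= min_timepoints
    ]
    print(filtered)  # [('/A/1', [1])]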
+
+# %%
+# Aligning condition embeddings to infection
+# OPTION 1: Use the infection annotations to find the reference lineage
+reference_lineage_fov = "/C/2/001000"
+reference_lineage_track_id = [129]
+reference_timepoints = [8, 70]  # sensor relocalization and partial remodelling
+
+# OPTION 2: from the filtered lineages find one from FOV C/2/000001
+reference_lineage_fov = "/C/2/000001"
+for fov_id, track_ids in filtered_lineages:
+    if reference_lineage_fov == fov_id:
+        break
+reference_lineage_track_id = track_ids
+reference_timepoints = [8, 70]  # sensor relocalization and partial remodelling
+
+# %%
+# Dictionary to store alignment results for comparison
+alignment_results = {}
+
+for embeddings, name in process_embeddings:
+    # Get the reference pattern from the current embedding space
+    reference_pattern = None
+    reference_lineage = []
+    for fov_id, track_ids in filtered_lineages:
+        if fov_id == reference_lineage_fov and all(
+            track_id in track_ids for track_id in reference_lineage_track_id
+        ):
+            logger.info(
+                f"Found reference pattern for {fov_id} {reference_lineage_track_id} using {name} embeddings"
+            )
+            reference_pattern = embeddings.sel(
+                sample=(fov_id, reference_lineage_track_id)
+            ).features.values
+            reference_lineage.append(reference_pattern)
+            break
+    if reference_pattern is None:
+        logger.info(f"Reference pattern not found for {name} embeddings. Skipping.")
+        continue
+    reference_pattern = np.concatenate(reference_lineage)
+    reference_pattern = reference_pattern[
+        reference_timepoints[0] : reference_timepoints[1]
+    ]
+
+    # Find all matches to the reference pattern
+    metric = "cosine"
+    all_match_positions = find_pattern_matches(
+        reference_pattern,
+        filtered_lineages,
+        embeddings,
+        window_step_fraction=0.1,
+        num_candidates=4,
+        method="bernd_clifford",
+        save_path=output_root / f"{name}_matching_lineages_{metric}.csv",
+        metric=metric,
+    )
+
+    # Store results for later comparison
+    alignment_results[name] = all_match_positions
+
+# Visualize warping paths in PC space instead of raw embedding dimensions
+for name, match_positions in alignment_results.items():
+    if match_positions is not None and not match_positions.empty:
+        # Call the new function from plotting_utils
+        plot_pc_trajectories(
+            reference_lineage_fov=reference_lineage_fov,
+            reference_lineage_track_id=reference_lineage_track_id,
+            reference_timepoints=reference_timepoints,
+            match_positions=match_positions,
+            embeddings_dataset=next(
+                emb for emb, emb_name in process_embeddings if emb_name == name
+            ),
+            filtered_lineages=filtered_lineages,
+            name=name,
+            save_path=output_root / f"{name}_pc_lineage_alignment.png",
+        )
+
+
+# %%
+# Compare DTW performance between embedding methods
+
+# Create a DataFrame to collect the alignment statistics for comparison
+match_data = []
+for name, match_positions in alignment_results.items():
+    if match_positions is not None and not match_positions.empty:
+        for i, row in match_positions.head(10).iterrows():  # Take top 10 matches
+            warping_path = (
+                ast.literal_eval(row["warp_path"])
+                if isinstance(row["warp_path"], str)
+                else row["warp_path"]
+            )
+            match_data.append(
+                {
+                    "model": name,
+                    "match_position": row["start_timepoint"],
+                    "dtw_distance": row["distance"],
+                    "path_skewness": row["skewness"],
+                    "path_length": len(warping_path),
+                }
+            )
+
+comparison_df = pd.DataFrame(match_data)
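The skewness column compared below comes from find_pattern_matches in plotting_utils; one common definition, shown here only as an assumption about what that metric measures, is the mean absolute deviation of the warping path from the ideal diagonal:

    import numpy as np

    def path_skewness(path):
        # 0.0 for a perfectly diagonal path; grows as the path hugs one axis.
        path = np.asarray(path, dtype=float)
        n, m = path[-1] + 1  # implied sequence lengths
        diagonal = path[:, 0] * (m - 1) / max(n - 1, 1)
        return float(np.mean(np.abs(path[:, 1] - diagonal)) / max(m - 1, 1))

    print(path_skewness([(0, 0), (1, 1), (2, 2), (3, 3)]))  # 0.0
    print(path_skewness([(0, 0), (0, 1), (0, 2), (1, 3), (2, 3), (3, 3)]))  # ~0.33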
+
+# Create visualizations to compare alignment quality
+plt.figure(figsize=(12, 10))
+
+# 1. Compare DTW distances
+plt.subplot(2, 2, 1)
+sns.boxplot(x="model", y="dtw_distance", data=comparison_df)
+plt.title("DTW Distance by Model")
+plt.ylabel("DTW Distance (lower is better)")
+
+# 2. Compare path skewness
+plt.subplot(2, 2, 2)
+sns.boxplot(x="model", y="path_skewness", data=comparison_df)
+plt.title("Path Skewness by Model")
+plt.ylabel("Skewness (lower is better)")
+
+# 3. Compare path lengths
+plt.subplot(2, 2, 3)
+sns.boxplot(x="model", y="path_length", data=comparison_df)
+plt.title("Warping Path Length by Model")
+plt.ylabel("Path Length")
+
+# 4. Scatterplot of distance vs skewness
+plt.subplot(2, 2, 4)
+scatter = sns.scatterplot(
+    x="dtw_distance", y="path_skewness", hue="model", data=comparison_df
+)
+plt.title("DTW Distance vs Path Skewness")
+plt.xlabel("DTW Distance")
+plt.ylabel("Path Skewness")
+plt.legend(title="Model")
+
+plt.tight_layout()
+plt.savefig(output_root / "dtw_alignment_comparison.png", dpi=300)
+plt.close()
+
+# %%
+# Analyze warping path step patterns for better understanding of alignment quality
+
+# Step pattern analysis
+step_pattern_counts = {
+    name: {"diagonal": 0, "horizontal": 0, "vertical": 0, "total": 0}
+    for name in alignment_results.keys()
+}
+
+for name, match_positions in alignment_results.items():
+    if match_positions is not None and not match_positions.empty:
+        # Get the top match
+        top_match = match_positions.iloc[0]
+        path = (
+            ast.literal_eval(top_match["warp_path"])
+            if isinstance(top_match["warp_path"], str)
+            else top_match["warp_path"]
+        )
+
+        # Count step types
+        for i in range(1, len(path)):
+            prev_i, prev_j = path[i - 1]
+            curr_i, curr_j = path[i]
+
+            step_i = curr_i - prev_i
+            step_j = curr_j - prev_j
+
+            if step_i == 1 and step_j == 1:
+                step_pattern_counts[name]["diagonal"] += 1
+            elif step_i == 1 and step_j == 0:
+                step_pattern_counts[name]["vertical"] += 1
+            elif step_i == 0 and step_j == 1:
+                step_pattern_counts[name]["horizontal"] += 1
+
+            step_pattern_counts[name]["total"] += 1
+
+# Convert to percentages
+for name in step_pattern_counts:
+    total = step_pattern_counts[name]["total"]
+    if total > 0:
+        for key in ["diagonal", "horizontal", "vertical"]:
+            step_pattern_counts[name][key] = (
+                step_pattern_counts[name][key] / total
+            ) * 100
+
+# Visualize step pattern distributions
+# Keep percentages in (model, step_type) order so rows align with the other columns
+step_df = pd.DataFrame(
+    {
+        "model": [name for name in step_pattern_counts.keys() for _ in range(3)],
+        "step_type": ["diagonal", "horizontal", "vertical"] * len(step_pattern_counts),
+        "percentage": [
+            step_pattern_counts[name][key]
+            for name in step_pattern_counts.keys()
+            for key in ["diagonal", "horizontal", "vertical"]
+        ],
+    }
+)
+
+plt.figure(figsize=(10, 6))
+sns.barplot(x="model", y="percentage", hue="step_type", data=step_df)
+plt.title("Step Pattern Distribution in Warping Paths")
+plt.ylabel("Percentage (%)")
+plt.savefig(output_root / "step_pattern_distribution.png", dpi=300)
+plt.close()
+
+# %%
+# Find all matches to the reference pattern
+MODEL = "openphenom"
+alignment_df_path = output_root / f"{MODEL}_matching_lineages_cosine.csv"
+alignment_df = pd.read_csv(alignment_df_path)
+
+# Get the top N aligned cells
+
+source_channels = [
+    "Phase3D",
+    "raw GFP EX488 EM525-45",
+    "raw mCherry EX561 EM600-37",
+]
+yx_patch_size = (192, 192)
+z_range = (10, 30)
+view_ref_sector_only = True
+
+all_lineage_images = []
+all_aligned_stacks = []
+all_unaligned_stacks = []
+
+# Get aligned and 
unaligned stacks +top_aligned_cells = alignment_df.head(5) +napari_viewer = viewer if NAPARI else None +# Plot the aligned and unaligned stacks +for idx, row in tqdm( + top_aligned_cells.iterrows(), + total=len(top_aligned_cells), + desc="Aligning images", +): + fov_name = row["fov_name"] + track_ids = ast.literal_eval(row["track_ids"]) + warp_path = ast.literal_eval(row["warp_path"]) + start_time = int(row["start_timepoint"]) + + print(f"Aligning images for {fov_name} with track ids: {track_ids}") + data_module = TripletDataModule( + data_path=input_data_path, + tracks_path=tracks_path, + source_channel=source_channels, + z_range=z_range, + initial_yx_patch_size=yx_patch_size, + final_yx_patch_size=yx_patch_size, + batch_size=1, + num_workers=12, + predict_cells=True, + include_fov_names=[fov_name] * len(track_ids), + include_track_ids=track_ids, + ) + data_module.setup("predict") + + # Get the images for the lineage + lineage_images = [] + for batch in data_module.predict_dataloader(): + image = batch["anchor"].numpy()[0] + lineage_images.append(image) + + lineage_images = np.array(lineage_images) + all_lineage_images.append(lineage_images) + print(f"Lineage images shape: {np.array(lineage_images).shape}") + + # Create an aligned stack based on the warping path + if view_ref_sector_only: + aligned_stack = np.zeros( + (len(reference_pattern),) + lineage_images.shape[-4:], + dtype=lineage_images.dtype, + ) + unaligned_stack = np.zeros( + (len(reference_pattern),) + lineage_images.shape[-4:], + dtype=lineage_images.dtype, + ) + + # Map each reference timepoint to the corresponding lineage timepoint + for ref_idx in range(len(reference_pattern)): + # Find matches in warping path for this reference index + matches = [(i, q) for i, q in warp_path if i == ref_idx] + unaligned_stack[ref_idx] = lineage_images[ref_idx] + if matches: + # Get the corresponding lineage timepoint (first match if multiple) + print(f"Found match for ref idx: {ref_idx}") + match = matches[0] + query_idx = match[1] + lineage_idx = int(start_time + query_idx) + print( + f"Lineage index: {lineage_idx}, start time: {start_time}, query idx: {query_idx}, ref idx: {ref_idx}" + ) + # Copy the image if it's within bounds + if 0 <= lineage_idx < len(lineage_images): + aligned_stack[ref_idx] = lineage_images[lineage_idx] + else: + # Find nearest valid timepoint if out of bounds + nearest_idx = min(max(0, lineage_idx), len(lineage_images) - 1) + aligned_stack[ref_idx] = lineage_images[nearest_idx] + else: + # If no direct match, find closest reference timepoint in warping path + print(f"No match found for ref idx: {ref_idx}") + all_ref_indices = [i for i, _ in warp_path] + if all_ref_indices: + closest_ref_idx = min( + all_ref_indices, key=lambda x: abs(x - ref_idx) + ) + closest_matches = [ + (i, q) for i, q in warp_path if i == closest_ref_idx + ] + + if closest_matches: + closest_query_idx = closest_matches[0][1] + lineage_idx = int(start_time + closest_query_idx) + + if 0 <= lineage_idx < len(lineage_images): + aligned_stack[ref_idx] = lineage_images[lineage_idx] + else: + # Bound to valid range + nearest_idx = min( + max(0, lineage_idx), len(lineage_images) - 1 + ) + aligned_stack[ref_idx] = lineage_images[nearest_idx] + + all_aligned_stacks.append(aligned_stack) + all_unaligned_stacks.append(unaligned_stack) + +all_aligned_stacks = np.array(all_aligned_stacks) +all_unaligned_stacks = np.array(all_unaligned_stacks) +# %% +if NAPARI: + for idx, row in tqdm( + top_aligned_cells.reset_index().iterrows(), + 
total=len(top_aligned_cells),
+        desc="Plotting aligned and unaligned stacks",
+    ):
+        fov_name = row["fov_name"]
+        # Parse stringified track lists so track_ids[0] is an ID, not a character
+        track_ids = (
+            ast.literal_eval(row["track_ids"])
+            if isinstance(row["track_ids"], str)
+            else row["track_ids"]
+        )
+
+        aligned_stack = all_aligned_stacks[idx]
+        unaligned_stack = all_unaligned_stacks[idx]
+
+        unaligned_gfp_mip = np.max(unaligned_stack[:, 1, :, :], axis=1)
+        aligned_gfp_mip = np.max(aligned_stack[:, 1, :, :], axis=1)
+        unaligned_mcherry_mip = np.max(unaligned_stack[:, 2, :, :], axis=1)
+        aligned_mcherry_mip = np.max(aligned_stack[:, 2, :, :], axis=1)
+
+        z_slice = 15
+        unaligned_phase = unaligned_stack[:, 0, z_slice, :]
+        aligned_phase = aligned_stack[:, 0, z_slice, :]
+
+        # unaligned
+        viewer.add_image(
+            unaligned_gfp_mip,
+            name=f"unaligned_gfp_{fov_name}_{track_ids[0]}",
+            colormap="green",
+            contrast_limits=(106, 215),
+        )
+        viewer.add_image(
+            unaligned_mcherry_mip,
+            name=f"unaligned_mcherry_{fov_name}_{track_ids[0]}",
+            colormap="magenta",
+            contrast_limits=(106, 190),
+        )
+        viewer.add_image(
+            unaligned_phase,
+            name=f"unaligned_phase_{fov_name}_{track_ids[0]}",
+            colormap="gray",
+            contrast_limits=(-0.74, 0.4),
+        )
+        # aligned
+        viewer.add_image(
+            aligned_gfp_mip,
+            name=f"aligned_gfp_{fov_name}_{track_ids[0]}",
+            colormap="green",
+            contrast_limits=(106, 215),
+        )
+        viewer.add_image(
+            aligned_mcherry_mip,
+            name=f"aligned_mcherry_{fov_name}_{track_ids[0]}",
+            colormap="magenta",
+            contrast_limits=(106, 190),
+        )
+        viewer.add_image(
+            aligned_phase,
+            name=f"aligned_phase_{fov_name}_{track_ids[0]}",
+            colormap="gray",
+            contrast_limits=(-0.74, 0.4),
+        )
+    viewer.grid.enabled = True
+    viewer.grid.shape = (-1, 6)
+# %%
+# Evaluate model performance based on infection state warping accuracy
+# Check unique infection status values
+unique_infection_statuses = infection_annotations_df["infection_status"].unique()
+logger.info(f"Unique infection status values: {unique_infection_statuses}")
+
+# If "infected" is not in the unique values, this could explain zero precision/recall
+if "infected" not in unique_infection_statuses:
+    logger.warning('The label "infected" is not found in the infection_status column!')
+    logger.info(f"Using these values instead: {unique_infection_statuses}")
+
+    # If we need to map values, we could do it here
+    if len(unique_infection_statuses) >= 2:
+        logger.info(
+            f'Will treat "{unique_infection_statuses[1]}" as "infected" for metrics calculation'
+        )
+        infection_target_value = unique_infection_statuses[1]
+    else:
+        infection_target_value = unique_infection_statuses[0]
+else:
+    infection_target_value = "infected"
+
+logger.info(f'Using "{infection_target_value}" as positive class for F1 calculation')
+
+# Check if the reference track is in the annotations
+logger.info(
+    f"Looking for infection annotations for reference lineage: {reference_lineage_fov}, tracks: {reference_lineage_track_id}"
+)
+print(f"Sample of infection_annotations_df: {infection_annotations_df.head()}")
+
+reference_infection_states = {}
+for track_id in reference_lineage_track_id:
+    reference_annotations = infection_annotations_df[
+        (infection_annotations_df["fov_name"] == reference_lineage_fov)
+        & (infection_annotations_df["track_id"] == track_id)
+    ]
+
+    # Add annotations for this reference track
+    annotation_count = len(reference_annotations)
+    logger.info(f"Found {annotation_count} annotations for track {track_id}")
+    if annotation_count > 0:
+        print(
+            f"Sample annotations for track {track_id}: {reference_annotations.head()}"
+        )
+
+    for _, row in reference_annotations.iterrows():
+        
reference_infection_states[row["t"]] = row["infection_status"] + +if reference_infection_states: + logger.info( + f"Total reference timepoints with infection status: {len(reference_infection_states)}" + ) + reference_t_range = range(reference_timepoints[0], reference_timepoints[1]) + reference_gt_states = [ + reference_infection_states.get(t, "unknown") for t in reference_t_range + ] + logger.info(f"Reference track infection states: {reference_gt_states[:5]}...") + + # Evaluate warping accuracy for each model + model_performance = [] + + for name, match_positions in alignment_results.items(): + if match_positions is not None and not match_positions.empty: + total_correct = 0 + total_predictions = 0 + true_positives = 0 + false_positives = 0 + false_negatives = 0 + + # Analyze top alignments for this model + alignment_details = [] + for i, row in match_positions.head(10).iterrows(): + fov_name = row["fov_name"] + track_ids = row[ + "track_ids" + ] # This is already a list of track IDs for the lineage + warp_path = ( + ast.literal_eval(row["warp_path"]) + if isinstance(row["warp_path"], str) + else row["warp_path"] + ) + start_time = int(row["start_timepoint"]) + + # Get annotations for all tracks in this lineage + track_infection_states = {} + for track_id in track_ids: + track_annotations = infection_annotations_df[ + (infection_annotations_df["fov_name"] == fov_name) + & (infection_annotations_df["track_id"] == track_id) + ] + + # Add annotations for this track to the combined dictionary + for _, annotation_row in track_annotations.iterrows(): + # Use t + track-specific offset if needed to handle timepoint overlaps between tracks + track_infection_states[annotation_row["t"]] = annotation_row[ + "infection_status" + ] + + # Only proceed if we found annotations for at least one track + if track_infection_states: + # For each reference timepoint, check if the warped timepoint maintains the infection state + track_correct = 0 + track_predictions = 0 + track_tp = 0 + track_fp = 0 + track_fn = 0 + + for ref_idx, query_idx in warp_path: + # Map to actual timepoints + ref_t = reference_timepoints[0] + ref_idx + query_t = start_time + query_idx + + # Get ground truth infection states + ref_state = reference_infection_states.get(ref_t, "unknown") + query_state = track_infection_states.get(query_t, "unknown") + + # Skip unknown states + if ref_state != "unknown" and query_state != "unknown": + track_predictions += 1 + + # Count correct alignments + if ref_state == query_state: + track_correct += 1 + + # Calculate F1 score components for "infected" state + if ( + ref_state == infection_target_value + and query_state == infection_target_value + ): + track_tp += 1 + elif ( + ref_state != infection_target_value + and query_state == infection_target_value + ): + track_fp += 1 + elif ( + ref_state == infection_target_value + and query_state != infection_target_value + ): + track_fn += 1 + + # Calculate track-specific metrics + if track_predictions > 0: + track_accuracy = track_correct / track_predictions + track_precision = ( + track_tp / (track_tp + track_fp) + if (track_tp + track_fp) > 0 + else 0 + ) + track_recall = ( + track_tp / (track_tp + track_fn) + if (track_tp + track_fn) > 0 + else 0 + ) + track_f1 = ( + 2 + * (track_precision * track_recall) + / (track_precision + track_recall) + if (track_precision + track_recall) > 0 + else 0 + ) + + alignment_details.append( + { + "fov_name": fov_name, + "track_ids": track_ids, + "accuracy": track_accuracy, + "precision": track_precision, + "recall": 
track_recall, + "f1_score": track_f1, + "correct": track_correct, + "total": track_predictions, + } + ) + + # Add to model totals + total_correct += track_correct + total_predictions += track_predictions + true_positives += track_tp + false_positives += track_fp + false_negatives += track_fn + + # Calculate metrics + accuracy = total_correct / total_predictions if total_predictions > 0 else 0 + precision = ( + true_positives / (true_positives + false_positives) + if (true_positives + false_positives) > 0 + else 0 + ) + recall = ( + true_positives / (true_positives + false_negatives) + if (true_positives + false_negatives) > 0 + else 0 + ) + f1 = ( + 2 * (precision * recall) / (precision + recall) + if (precision + recall) > 0 + else 0 + ) + + # Store alignment details for this model + if alignment_details: + alignment_details_df = pd.DataFrame(alignment_details) + print(f"\nDetailed alignment results for {name}:") + print(alignment_details_df) + alignment_details_df.to_csv( + output_root / f"{name}_alignment_details.csv", index=False + ) + + model_performance.append( + { + "model": name, + "accuracy": accuracy, + "precision": precision, + "recall": recall, + "f1_score": f1, + "total_predictions": total_predictions, + } + ) + + # Create performance DataFrame and visualize + performance_df = pd.DataFrame(model_performance) + print(performance_df) + + # Plot performance metrics + plt.figure(figsize=(12, 8)) + + # Accuracy plot + plt.subplot(2, 2, 1) + sns.barplot(x="model", y="accuracy", data=performance_df) + plt.title("Infection State Warping Accuracy") + plt.ylabel("Accuracy") + + # Precision plot + plt.subplot(2, 2, 2) + sns.barplot(x="model", y="precision", data=performance_df) + plt.title("Precision for Infected State") + plt.ylabel("Precision") + + # Recall plot + plt.subplot(2, 2, 3) + sns.barplot(x="model", y="recall", data=performance_df) + plt.title("Recall for Infected State") + plt.ylabel("Recall") + + # F1 score plot + plt.subplot(2, 2, 4) + sns.barplot(x="model", y="f1_score", data=performance_df) + plt.title("F1 Score for Infected State") + plt.ylabel("F1 Score") + + plt.tight_layout() + # plt.savefig(output_root / "infection_state_warping_performance.png", dpi=300) + # plt.close() +else: + logger.warning("Reference track annotations not found in infection_annotations_df") + +# %% From 1b1f9b7cf550206852e0b2971bcb4299f83599ba Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Fri, 5 Sep 2025 10:27:45 -0700 Subject: [PATCH 038/101] added a clamp to logvar, switch to mse loss sum reduction like the original formulation. 
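The loss in this patch reduces to a sum-reduction MSE normalized by batch size plus a per-sample KL term whose linear logvar component is clamped; a standalone sketch of the same computation (toy shapes, not the engine code itself):

    import torch
    import torch.nn.functional as F

    def beta_vae_loss(recon_x, x, mu, logvar, beta=1.0, logvar_minmax=(-20.0, 20.0)):
        batch_size = x.shape[0]
        recon_loss = F.mse_loss(recon_x, x, reduction="sum") / batch_size
        # Clamp only the linear logvar term, as the patch below does
        kl = -0.5 * torch.sum(
            1 + torch.clamp(logvar, *logvar_minmax) - mu.pow(2) - logvar.exp(), dim=1
        )
        return recon_loss + beta * kl.mean()

    x = torch.randn(4, 1, 8, 8)
    print(beta_vae_loss(torch.randn_like(x), x, torch.zeros(4, 16), torch.zeros(4, 16)).item())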
---
 viscy/representation/engine.py | 25 +++++++++++++------------
 1 file changed, 13 insertions(+), 12 deletions(-)

diff --git a/viscy/representation/engine.py b/viscy/representation/engine.py
index 3c6296b5d..59485b614 100644
--- a/viscy/representation/engine.py
+++ b/viscy/representation/engine.py
@@ -392,7 +392,7 @@ def __init__(
         self,
         architecture: Literal["monai_beta","2.5D"],
         model_config: dict = {},
-        loss_function: nn.Module | nn.MSELoss = nn.MSELoss(reduction="mean"),
+        loss_function: nn.Module | nn.MSELoss = nn.MSELoss(reduction="sum"),
         beta: float = 1.0,
         beta_schedule: Literal["linear", "cosine", "warmup"] | None = None,
         beta_min: float = 0.1,
@@ -443,6 +443,7 @@ def __init__(
         self.validation_step_outputs = []

         self._min_beta = 1e-15
+        self._logvar_minmax = (-20, 20)

         # Handle different parameter names for latent dimensions
         latent_dim = None
@@ -527,20 +528,20 @@ def forward(self, x: Tensor) -> dict:

         current_beta = self._get_current_beta()

-        batch_size = x.size(0)
-        latent_dim = mu.size(1)
-        normalizer = batch_size * latent_dim
+        batch_size = original_shape[0]

-        recon_loss = self.loss_function(recon_x, x)
+        # Use the original input for loss computation to ensure shape consistency
+        x_original = x if not (is_monai_2d and len(original_shape) == 5 and original_shape[2] == 1) else x.unsqueeze(2)
+        recon_loss = self.loss_function(recon_x, x_original)
+        if isinstance(self.loss_function, nn.MSELoss):
+            if hasattr(self.loss_function, "reduction") and self.loss_function.reduction == "sum":
+                recon_loss = recon_loss / batch_size

-        kl_loss = (
-            -0.5
-            * current_beta
-            * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
-            / batch_size
-        )
+        logvar = torch.clamp(logvar, self._logvar_minmax[0], self._logvar_minmax[1])
+        kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)
+        kl_loss = torch.mean(kl_loss)

-        total_loss = recon_loss + kl_loss
+        total_loss = recon_loss + current_beta * kl_loss

         return {
             "recon_x": recon_x,

From b012bad22a3135914a072b090c707228eb79026b Mon Sep 17 00:00:00 2001
From: Eduardo Hirata-Miyasaki 
Date: Fri, 5 Sep 2025 10:28:11 -0700
Subject: [PATCH 039/101] remove unnecessary vae logging losses.
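
As a hypothetical usage sketch (not part of this patch), the consolidated "loss/<stage>/<name>" scheme groups all VAE losses for a stage under one TensorBoard prefix, so train and validation curves sit side by side; the helper name below is illustrative:

    def vae_loss_metrics(stage: str, total, recon, kl, beta: float) -> dict:
        # One "loss/<stage>/..." prefix per stage keeps the remaining
        # losses grouped together per training stage.
        return {
            f"loss/{stage}/total": total,
            f"loss/{stage}/reconstruction": recon,
            f"loss/{stage}/kl": kl,
            f"beta/{stage}": beta,
        }

    print(vae_loss_metrics("val", 1.23, 1.0, 0.23, beta=0.5))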
---
 viscy/representation/vae_logging.py | 21 ++++-----------------
 1 file changed, 4 insertions(+), 17 deletions(-)

diff --git a/viscy/representation/vae_logging.py b/viscy/representation/vae_logging.py
index 3477e53cf..822fbcb59 100644
--- a/viscy/representation/vae_logging.py
+++ b/viscy/representation/vae_logging.py
@@ -63,28 +63,15 @@ def log_enhanced_metrics(
             lambda: getattr(lightning_module, "beta", 1.0),
         )()

-        # Record losses and reconstruction quality metrics
-        kl_recon_ratio = kl_loss / (recon_loss + 1e-8)
-
-        mae_loss = F.l1_loss(recon_x, x)
-
-        # Add gradient explosion diagnostics
+        # Check for explosion and NaN/Inf
         grad_diagnostics = self._compute_gradient_diagnostics(lightning_module)
-
-        # Add NaN/Inf detection
         nan_inf_diagnostics = self._check_nan_inf(recon_x, x, z)

-        # Shape diagnostics removed for cleaner logs
-
         metrics = {
-            # All losses in one consolidated group
-            f"loss/total/{stage}": total_loss,
-            f"loss/reconstruction/{stage}": recon_loss,
-            f"loss/kl/{stage}": kl_loss,
-            f"loss/mae/{stage}": mae_loss,
+            f"loss/{stage}/total": total_loss,
+            f"loss/{stage}/reconstruction": recon_loss,
+            f"loss/{stage}/kl": kl_loss,
             f"beta/{stage}": beta,
-            f"loss/kl_recon_ratio/{stage}": kl_recon_ratio,
-            f"loss/recon_contribution/{stage}": recon_loss / total_loss,
         }

         # Add diagnostic metrics

From 9f4be8da39dc2594f3d10ce9dcce77af79d29ca6 Mon Sep 17 00:00:00 2001
From: Eduardo Hirata-Miyasaki 
Date: Mon, 8 Sep 2025 15:28:14 -0700
Subject: [PATCH 040/101] add a way to handle when using 'mean' reduction for proper scaling

---
 viscy/representation/engine.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/viscy/representation/engine.py b/viscy/representation/engine.py
index 59485b614..730cdd98a 100644
--- a/viscy/representation/engine.py
+++ b/viscy/representation/engine.py
@@ -536,6 +536,10 @@ def forward(self, x: Tensor) -> dict:
         if isinstance(self.loss_function, nn.MSELoss):
             if hasattr(self.loss_function, "reduction") and self.loss_function.reduction == "sum":
                 recon_loss = recon_loss / batch_size
+            elif hasattr(self.loss_function, "reduction") and self.loss_function.reduction == "mean":
+                # Correct the over-normalization from PyTorch's mean reduction by multiplying by the number of elements per image
+                num_elements_per_image = x_original[0].numel()
+                recon_loss = recon_loss * num_elements_per_image

         logvar = torch.clamp(logvar, self._logvar_minmax[0], self._logvar_minmax[1])
         kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)

From e97dad133047acc9c394b6d270beb55f5cef81d8 Mon Sep 17 00:00:00 2001
From: Eduardo Hirata-Miyasaki 
Date: Mon, 8 Sep 2025 15:29:10 -0700
Subject: [PATCH 041/101] adding optional config for middle slice index for computing sam2 embeddings and dinov3

---
 .../DINOV3/config_dinov3_convnext_tiny.yml    |  64 +++
 .../DynaCLR/DINOV3/dinov3_embeddings.py       | 425 ++++++++++++++++++
 .../DynaCLR/SAM2/sam2_embeddings.py           |  14 +-
 3 files changed, 500 insertions(+), 3 deletions(-)
 create mode 100644 applications/benchmarking/DynaCLR/DINOV3/config_dinov3_convnext_tiny.yml
 create mode 100644 applications/benchmarking/DynaCLR/DINOV3/dinov3_embeddings.py

diff --git a/applications/benchmarking/DynaCLR/DINOV3/config_dinov3_convnext_tiny.yml b/applications/benchmarking/DynaCLR/DINOV3/config_dinov3_convnext_tiny.yml
new file mode 100644
index 000000000..4d7fe1a03
--- /dev/null
+++ b/applications/benchmarking/DynaCLR/DINOV3/config_dinov3_convnext_tiny.yml
@@ -0,0 +1,64 @@
+datamodule:
+  batch_size: 32
+  final_yx_patch_size:
+  - 224
+  - 224
+  include_fov_names:
null + include_track_ids: null + initial_yx_patch_size: + - 224 + - 224 + normalizations: + - class_path: viscy.transforms.ScaleIntensityRangePercentilesd + init_args: + b_max: 1.0 + b_min: 0.0 + keys: + - RFP + lower: 50 + upper: 99 + - class_path: viscy.transforms.NormalizeIntensityd + init_args: + keys: + - Phase3D + num_workers: 10 + source_channel: + - RFP + - Phase3D + z_range: + - 15 + - 45 + +embedding: + pca_kwargs: + n_components: 8 + phate_kwargs: + decay: 40 + knn: 5 + n_components: 2 + n_jobs: -1 + random_state: 42 + reductions: + - PHATE + - PCA + +execution: + overwrite: false + save_config: true + show_config: true + +model: + model_name: facebook/dinov3-convnext-tiny-pretrain-lvd1689m + pooling_method: mean # Options: "mean", "max", "cls_token" + middle_slice_index: 30 # Specific z-slice index (if null, uses D//2) + channel_reduction_methods: + Phase3D: middle_slice + RFP: max + channel_names: + - RFP + - Phase3D + +paths: + data_path: /hpc/projects/intracellular_dashboard/organelle_dynamics/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/registered_test.zarr + output_path: /hpc/mydata/eduardo.hirata/repos/viscy/applications/benchmarking/DynaCLR/DINOV3/embeddings_convnext_tiny_mean.zarr + tracks_path: /hpc/projects/intracellular_dashboard/organelle_dynamics/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/track_test.zarr \ No newline at end of file diff --git a/applications/benchmarking/DynaCLR/DINOV3/dinov3_embeddings.py b/applications/benchmarking/DynaCLR/DINOV3/dinov3_embeddings.py new file mode 100644 index 000000000..5f16b1e9a --- /dev/null +++ b/applications/benchmarking/DynaCLR/DINOV3/dinov3_embeddings.py @@ -0,0 +1,425 @@ +import importlib +import logging +import os +from pathlib import Path +from typing import Dict, List, Literal, Optional + +import click +import numpy as np +import torch +import yaml +from lightning.pytorch import LightningModule +from PIL import Image +from skimage.exposure import rescale_intensity +from transformers import AutoImageProcessor, AutoModel + +from viscy.data.triplet import TripletDataModule +from viscy.representation.embedding_writer import EmbeddingWriter +from viscy.trainer import VisCyTrainer + + +class DINOv3Module(LightningModule): + def __init__( + self, + model_name: str = "facebook/dinov3-vitb16-pretrain-lvd1689m", + channel_reduction_methods: Optional[ + Dict[str, Literal["middle_slice", "mean", "max"]] + ] = None, + channel_names: Optional[List[str]] = None, + pooling_method: str = "mean", # "mean", "max", or "cls_token" + middle_slice_index: Optional[int] = None, + ): + """ + DINOv3 module for feature extraction. 
+ + Args: + model_name: DINOv3 model name from HuggingFace + channel_reduction_methods: How to reduce 5D inputs per channel + channel_names: Names of channels for reduction mapping + pooling_method: How to pool spatial tokens ("mean", "max", "cls_token") + middle_slice_index: Specific z-slice index to use (if None, uses D//2) + """ + super().__init__() + self.model_name = model_name + self.channel_reduction_methods = channel_reduction_methods or {} + self.channel_names = channel_names or [] + self.pooling_method = pooling_method + self.middle_slice_index = middle_slice_index + + torch.set_float32_matmul_precision("high") + self.model = None + self.processor = None + + def on_predict_start(self): + """Initialize model and processor when prediction starts""" + if self.model is None: + self.processor = AutoImageProcessor.from_pretrained(self.model_name) + self.model = AutoModel.from_pretrained(self.model_name) + self.model.eval() + self.model.to(self.device) + + def _reduce_5d_input(self, x: torch.Tensor) -> torch.Tensor: + """Reduce 5D input (B, C, D, H, W) to 4D (B, C, H, W) using specified methods. + + Args: + x: 5D input tensor + + Returns: + 4D tensor after applying reduction methods + """ + if x.dim() != 5: + return x + + B, C, D, H, W = x.shape + result = torch.zeros((B, C, H, W), device=x.device) + + # Group channels by reduction method + middle_slice_indices = [] + mean_indices = [] + max_indices = [] + + for c in range(C): + channel_name = ( + self.channel_names[c] if c < len(self.channel_names) else f"channel_{c}" + ) + method = self.channel_reduction_methods.get(channel_name, "middle_slice") + + if method == "mean": + mean_indices.append(c) + elif method == "max": + max_indices.append(c) + else: # Default to middle_slice + middle_slice_indices.append(c) + + # Apply reductions + if middle_slice_indices: + indices = torch.tensor(middle_slice_indices, device=x.device) + slice_idx = self.middle_slice_index if self.middle_slice_index is not None else D // 2 + result[:, indices] = x[:, indices, slice_idx] + + if mean_indices: + indices = torch.tensor(mean_indices, device=x.device) + result[:, indices] = x[:, indices].mean(dim=2) + + if max_indices: + indices = torch.tensor(max_indices, device=x.device) + result[:, indices] = x[:, indices].max(dim=2)[0] + + return result + + def _convert_to_pil_images(self, x: torch.Tensor) -> List[Image.Image]: + """Convert tensor to list of PIL Images for DINOv3 processing. 
+ + Args: + x: Input tensor (B, C, H, W) + + Returns: + List of PIL Images + """ + images = [] + + for b in range(x.shape[0]): + img_tensor = x[b] # (C, H, W) + + if img_tensor.shape[0] == 1: + # Single channel - convert to grayscale PIL + img_array = img_tensor[0].cpu().numpy() + # Normalize to 0-255 + img_normalized = ((img_array - img_array.min()) / + (img_array.max() - img_array.min()) * 255).astype(np.uint8) + pil_img = Image.fromarray(img_normalized, mode='L') + + elif img_tensor.shape[0] == 2: + # Two channels - create RGB with blend in blue + img_array = img_tensor.cpu().numpy() + rgb_array = np.zeros((img_array.shape[1], img_array.shape[2], 3), dtype=np.uint8) + + # Normalize each channel to 0-255 + ch0_norm = rescale_intensity(img_array[0], out_range=(0, 255)).astype(np.uint8) + ch1_norm = rescale_intensity(img_array[1], out_range=(0, 255)).astype(np.uint8) + + rgb_array[:, :, 0] = ch0_norm # Red + rgb_array[:, :, 1] = ch1_norm # Green + rgb_array[:, :, 2] = (ch0_norm + ch1_norm) // 2 # Blue as blend + + pil_img = Image.fromarray(rgb_array, mode='RGB') + + elif img_tensor.shape[0] == 3: + # Three channels - direct RGB + img_array = img_tensor.cpu().numpy().transpose(1, 2, 0) # HWC + img_normalized = rescale_intensity(img_array, out_range=(0, 255)).astype(np.uint8) + pil_img = Image.fromarray(img_normalized, mode='RGB') + + else: + # More than 3 channels - use first 3 + img_array = img_tensor[:3].cpu().numpy().transpose(1, 2, 0) # HWC + img_normalized = rescale_intensity(img_array, out_range=(0, 255)).astype(np.uint8) + pil_img = Image.fromarray(img_normalized, mode='RGB') + + images.append(pil_img) + + return images + + def _pool_features(self, features: torch.Tensor) -> torch.Tensor: + """Pool spatial features from DINOv3 tokens. + + Args: + features: Token features (B, num_tokens, hidden_dim) + + Returns: + Pooled features (B, hidden_dim) + """ + if self.pooling_method == "cls_token": + # For ViT models, first token is usually CLS token + if "vit" in self.model_name.lower(): + return features[:, 0, :] # CLS token + else: + # For ConvNeXt, no CLS token, fall back to mean + return features.mean(dim=1) + + elif self.pooling_method == "max": + return features.max(dim=1)[0] + else: # mean pooling + return features.mean(dim=1) + + def predict_step(self, batch, batch_idx, dataloader_idx=0): + """Extract features from input images using DINOv3. 
+ + Returns: + Dictionary with pooled features, empty projections, and index information + """ + x = batch["anchor"] + + # Handle 5D input (B, C, D, H, W) + if x.dim() == 5: + x = self._reduce_5d_input(x) + + # Convert to PIL Images for DINOv3 processing + pil_images = self._convert_to_pil_images(x) + + # Process all images in batch + batch_features = [] + + for pil_img in pil_images: + # Process single image + inputs = self.processor(pil_img, return_tensors="pt") + inputs = {k: v.to(self.device) for k, v in inputs.items()} + + with torch.no_grad(): + outputs = self.model(**inputs) + # Get all tokens from last hidden state + token_features = outputs.last_hidden_state # (1, num_tokens, hidden_dim) + + # Pool spatial tokens to get single feature vector + pooled_features = self._pool_features(token_features) # (1, hidden_dim) + + batch_features.append(pooled_features) + + # Concatenate all features in batch + features = torch.cat(batch_features, dim=0) # (B, hidden_dim) + + return { + "features": features, + "projections": torch.zeros((features.shape[0], 0), device=features.device), + "index": batch["index"], + } + + +def load_config(config_file): + """Load configuration from a YAML file.""" + with open(config_file, "r") as f: + config = yaml.safe_load(f) + return config + + +def load_normalization_from_config(norm_config): + """Load a normalization transform from a configuration dictionary.""" + class_path = norm_config["class_path"] + init_args = norm_config.get("init_args", {}) + + # Split module and class name + module_path, class_name = class_path.rsplit(".", 1) + + # Import the module + module = importlib.import_module(module_path) + + # Get the class + transform_class = getattr(module, class_name) + + # Instantiate the transform + return transform_class(**init_args) + + +@click.command() +@click.option( + "--config", + "-c", + type=click.Path(exists=True), + required=True, + help="Path to YAML configuration file", +) +def main(config): + """Extract DINOv3 embeddings and save to zarr format using VisCy Trainer.""" + # Configure logging + logging.basicConfig(level=logging.INFO) + logger = logging.getLogger(__name__) + + # Load config file + cfg = load_config(config) + logger.info(f"Loaded configuration from {config}") + + # Prepare datamodule parameters + dm_params = {} + + # Add data and tracks paths from the paths section + if "paths" not in cfg: + raise ValueError("Configuration must contain a 'paths' section") + + if "data_path" not in cfg["paths"]: + raise ValueError( + "Data path is required in the configuration file (paths.data_path)" + ) + dm_params["data_path"] = cfg["paths"]["data_path"] + + if "tracks_path" not in cfg["paths"]: + raise ValueError( + "Tracks path is required in the configuration file (paths.tracks_path)" + ) + dm_params["tracks_path"] = cfg["paths"]["tracks_path"] + + # Add datamodule parameters + if "datamodule" not in cfg: + raise ValueError("Configuration must contain a 'datamodule' section") + + # Prepare normalizations + if ( + "normalizations" not in cfg["datamodule"] + or not cfg["datamodule"]["normalizations"] + ): + raise ValueError( + "Normalizations are required in the configuration file (datamodule.normalizations)" + ) + + norm_configs = cfg["datamodule"]["normalizations"] + normalizations = [load_normalization_from_config(norm) for norm in norm_configs] + dm_params["normalizations"] = normalizations + + # Copy all other datamodule parameters + for param, value in cfg["datamodule"].items(): + if param != "normalizations": + # Handle patch sizes + if 
param == "patch_size": + dm_params["initial_yx_patch_size"] = value + dm_params["final_yx_patch_size"] = value + else: + dm_params[param] = value + + # Set up the data module + logger.info("Setting up data module") + dm = TripletDataModule(**dm_params) + + # Get model parameters + model_name = cfg["model"].get("model_name", "facebook/dinov3-vitb16-pretrain-lvd1689m") + pooling_method = cfg["model"].get("pooling_method", "mean") + channel_reduction_methods = cfg["model"].get("channel_reduction_methods", {}) + channel_names = cfg["model"].get("channel_names", []) + middle_slice_index = cfg["model"].get("middle_slice_index", None) + + # Initialize DINOv3 model + logger.info(f"Loading DINOv3 model: {model_name}") + model = DINOv3Module( + model_name=model_name, + pooling_method=pooling_method, + channel_reduction_methods=channel_reduction_methods, + channel_names=channel_names, + middle_slice_index=middle_slice_index, + ) + + # Get dimensionality reduction parameters from config + phate_kwargs = None + pca_kwargs = None + + if "embedding" in cfg: + if "phate_kwargs" in cfg["embedding"]: + phate_kwargs = cfg["embedding"]["phate_kwargs"] + if "pca_kwargs" in cfg["embedding"]: + pca_kwargs = cfg["embedding"]["pca_kwargs"] + + # Check if output path exists and should be overwritten + if "output_path" not in cfg["paths"]: + raise ValueError( + "Output path is required in the configuration file (paths.output_path)" + ) + + output_path = Path(cfg["paths"]["output_path"]) + output_dir = output_path.parent + output_dir.mkdir(parents=True, exist_ok=True) + + overwrite = False + if "execution" in cfg and "overwrite" in cfg["execution"]: + overwrite = cfg["execution"]["overwrite"] + elif output_path.exists(): + logger.warning(f"Output path {output_path} already exists, will overwrite") + overwrite = True + + # Set up EmbeddingWriter callback + embedding_writer = EmbeddingWriter( + output_path=output_path, + phate_kwargs=phate_kwargs, + pca_kwargs=pca_kwargs, + overwrite=overwrite, + ) + + # Set up and run VisCy trainer + logger.info("Setting up VisCy trainer") + trainer = VisCyTrainer( + accelerator="gpu" if torch.cuda.is_available() else "cpu", + devices=1, + callbacks=[embedding_writer], + inference_mode=True, + ) + + logger.info(f"Running prediction and saving to {output_path}") + trainer.predict(model, datamodule=dm) + + # Save configuration if requested + save_config_flag = True + show_config_flag = True + + if "execution" in cfg: + if "save_config" in cfg["execution"]: + save_config_flag = cfg["execution"]["save_config"] + if "show_config" in cfg["execution"]: + show_config_flag = cfg["execution"]["show_config"] + + # Save configuration if requested + if save_config_flag: + config_path = os.path.join(output_dir, "config.yml") + with open(config_path, "w") as f: + yaml.dump(cfg, f, default_flow_style=False) + logger.info(f"Configuration saved to {config_path}") + + # Display configuration if requested + if show_config_flag: + click.echo("\nConfiguration used:") + click.echo("-" * 40) + for key, value in cfg.items(): + click.echo(f"{key}:") + if isinstance(value, dict): + for subkey, subvalue in value.items(): + if isinstance(subvalue, list) and subkey == "normalizations": + click.echo(f" {subkey}:") + for norm in subvalue: + click.echo(f" - class_path: {norm['class_path']}") + click.echo(f" init_args: {norm['init_args']}") + else: + click.echo(f" {subkey}: {subvalue}") + else: + click.echo(f" {value}") + click.echo("-" * 40) + + logger.info("Done!") + + +if __name__ == "__main__": + main() \ No 
newline at end of file diff --git a/applications/benchmarking/DynaCLR/SAM2/sam2_embeddings.py b/applications/benchmarking/DynaCLR/SAM2/sam2_embeddings.py index ef823bc46..3d13d5a5a 100644 --- a/applications/benchmarking/DynaCLR/SAM2/sam2_embeddings.py +++ b/applications/benchmarking/DynaCLR/SAM2/sam2_embeddings.py @@ -24,11 +24,13 @@ def __init__( Dict[str, Literal["middle_slice", "mean", "max"]] ] = None, channel_names: Optional[List[str]] = None, + middle_slice_index: Optional[int] = None, ): super().__init__() self.model_name = model_name self.channel_reduction_methods = channel_reduction_methods or {} self.channel_names = channel_names or [] + self.middle_slice_index = middle_slice_index torch.set_float32_matmul_precision("high") self.model = None # Initialize in on_predict_start when device is set @@ -77,7 +79,8 @@ def _reduce_5d_input(self, x: torch.Tensor) -> torch.Tensor: # Apply middle_slice reduction to all relevant channels at once if middle_slice_indices: indices = torch.tensor(middle_slice_indices, device=x.device) - result[:, indices] = x[:, indices, D // 2] + slice_idx = self.middle_slice_index if self.middle_slice_index is not None else D // 2 + result[:, indices] = x[:, indices, slice_idx] # Apply mean reduction to all relevant channels at once if mean_indices: @@ -252,15 +255,20 @@ def main(config): # Get model parameters for handling 5D inputs channel_reduction_methods = {} + middle_slice_index = None - if "model" in cfg and "channel_reduction_methods" in cfg["model"]: - channel_reduction_methods = cfg["model"]["channel_reduction_methods"] + if "model" in cfg: + if "channel_reduction_methods" in cfg["model"]: + channel_reduction_methods = cfg["model"]["channel_reduction_methods"] + if "middle_slice_index" in cfg["model"]: + middle_slice_index = cfg["model"]["middle_slice_index"] # Initialize SAM2 model with reduction settings logger.info("Loading SAM2 model") model = SAM2Module( model_name=cfg["model"]["model_name"], channel_reduction_methods=channel_reduction_methods, + middle_slice_index=middle_slice_index, ) # Get dimensionality reduction parameters from config From 3ec674479472984d2bb91591b09443f5bdb1d21a Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Mon, 8 Sep 2025 15:29:31 -0700 Subject: [PATCH 042/101] converting latent stats active_dimensions parameter to float to remove warning --- viscy/representation/vae_logging.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/viscy/representation/vae_logging.py b/viscy/representation/vae_logging.py index 822fbcb59..a9b395ba9 100644 --- a/viscy/representation/vae_logging.py +++ b/viscy/representation/vae_logging.py @@ -93,7 +93,7 @@ def log_enhanced_metrics( f"latent_statistics/std_avg/{stage}": torch.mean(latent_std), f"latent_statistics/mean_max/{stage}": torch.max(latent_mean), f"latent_statistics/std_max/{stage}": torch.max(latent_std), - f"latent_statistics/active_dims/{stage}": active_dims, + f"latent_statistics/active_dims/{stage}": active_dims.float(), f"latent_statistics/effective_dim/{stage}": effective_dim, f"latent_statistics/utilization/{stage}": active_dims / self.latent_dim, } From fb0ecc4ff72b702ee06fdf6f509ddc5cf6bc31bb Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Mon, 8 Sep 2025 17:05:46 -0700 Subject: [PATCH 043/101] ruff --- viscy/data/cell_division_triplet.py | 4 ---- viscy/representation/disentanglement_metrics.py | 2 +- viscy/representation/engine.py | 1 - viscy/representation/evaluation/smoothness.py | 1 - viscy/representation/vae_logging.py | 1 - 
viscy/scripts/optimization/optuna_utils.py | 2 +- viscy/scripts/optimization/optuna_vae_search.py | 6 +----- 7 files changed, 3 insertions(+), 14 deletions(-) diff --git a/viscy/data/cell_division_triplet.py b/viscy/data/cell_division_triplet.py index c9259ea44..f5a5180d1 100644 --- a/viscy/data/cell_division_triplet.py +++ b/viscy/data/cell_division_triplet.py @@ -4,17 +4,13 @@ from typing import Literal, Sequence import numpy as np -import pandas as pd import torch from monai.transforms import Compose, MapTransform -from natsort import natsorted from torch import Tensor from torch.utils.data import Dataset from viscy.data.hcs import HCSDataModule from viscy.data.triplet import ( - _gather_channels, - _scatter_channels, _transform_channel_wise, ) from viscy.data.typing import DictTransform, TripletSample diff --git a/viscy/representation/disentanglement_metrics.py b/viscy/representation/disentanglement_metrics.py index 5bb7b209d..5b721b390 100644 --- a/viscy/representation/disentanglement_metrics.py +++ b/viscy/representation/disentanglement_metrics.py @@ -1,5 +1,5 @@ import logging -from typing import Dict, List, Optional, Tuple +from typing import Dict, Optional, Tuple import numpy as np import torch diff --git a/viscy/representation/engine.py b/viscy/representation/engine.py index 730cdd98a..5d007ba4b 100644 --- a/viscy/representation/engine.py +++ b/viscy/representation/engine.py @@ -11,7 +11,6 @@ from viscy.data.typing import TrackingIndex, TripletSample from viscy.representation.contrastive import ContrastiveEncoder -from viscy.representation.disentanglement_metrics import DisentanglementMetrics from viscy.representation.vae import BetaVae25D, BetaVaeMonai from viscy.representation.vae_logging import BetaVaeLogger from viscy.utils.log_images import detach_sample, render_images diff --git a/viscy/representation/evaluation/smoothness.py b/viscy/representation/evaluation/smoothness.py index 51be25f7c..27174cd8c 100644 --- a/viscy/representation/evaluation/smoothness.py +++ b/viscy/representation/evaluation/smoothness.py @@ -4,7 +4,6 @@ import numpy as np import pandas as pd from numpy.typing import NDArray -from scipy.optimize import minimize_scalar from scipy.signal import find_peaks from scipy.stats import gaussian_kde from sklearn.preprocessing import StandardScaler diff --git a/viscy/representation/vae_logging.py b/viscy/representation/vae_logging.py index a9b395ba9..875ac0edf 100644 --- a/viscy/representation/vae_logging.py +++ b/viscy/representation/vae_logging.py @@ -5,7 +5,6 @@ import matplotlib.pyplot as plt import numpy as np import torch -import torch.nn.functional as F from PIL import Image from sklearn.decomposition import PCA from sklearn.manifold import TSNE diff --git a/viscy/scripts/optimization/optuna_utils.py b/viscy/scripts/optimization/optuna_utils.py index 91c9f5369..bfcc88481 100644 --- a/viscy/scripts/optimization/optuna_utils.py +++ b/viscy/scripts/optimization/optuna_utils.py @@ -3,7 +3,7 @@ import subprocess import tempfile from pathlib import Path -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional import numpy as np import yaml diff --git a/viscy/scripts/optimization/optuna_vae_search.py b/viscy/scripts/optimization/optuna_vae_search.py index 34d27c9ef..b75d582f0 100644 --- a/viscy/scripts/optimization/optuna_vae_search.py +++ b/viscy/scripts/optimization/optuna_vae_search.py @@ -3,14 +3,10 @@ Optuna hyperparameter optimization for VAE training with PyTorch Lightning. 
""" -import os -import shutil import subprocess import tempfile from pathlib import Path -from typing import Any, Dict -import click import optuna import torch import yaml @@ -155,7 +151,7 @@ def main(): # Set up study study_name = "vae_hyperparameter_optimization" - storage_url = f"sqlite:///optuna_vae_study.db" + storage_url = "sqlite:///optuna_vae_study.db" study = optuna.create_study( study_name=study_name, From 7f295cc07ed741f539728888b89ace98ff216838 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Mon, 8 Sep 2025 17:07:27 -0700 Subject: [PATCH 044/101] removing the optuna config --- viscy/scripts/optimization/__init__.py | 3 - viscy/scripts/optimization/optuna_utils.py | 415 ------------------ .../optimization/optuna_vae_parallel.sh | 53 --- .../scripts/optimization/optuna_vae_search.py | 226 ---------- .../scripts/optimization/optuna_vae_slurm.sh | 47 -- 5 files changed, 744 deletions(-) delete mode 100644 viscy/scripts/optimization/__init__.py delete mode 100644 viscy/scripts/optimization/optuna_utils.py delete mode 100644 viscy/scripts/optimization/optuna_vae_parallel.sh delete mode 100644 viscy/scripts/optimization/optuna_vae_search.py delete mode 100644 viscy/scripts/optimization/optuna_vae_slurm.sh diff --git a/viscy/scripts/optimization/__init__.py b/viscy/scripts/optimization/__init__.py deleted file mode 100644 index 57b5ee39e..000000000 --- a/viscy/scripts/optimization/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -""" -Optimization scripts for hyperparameter tuning using Optuna and other methods. -""" \ No newline at end of file diff --git a/viscy/scripts/optimization/optuna_utils.py b/viscy/scripts/optimization/optuna_utils.py deleted file mode 100644 index bfcc88481..000000000 --- a/viscy/scripts/optimization/optuna_utils.py +++ /dev/null @@ -1,415 +0,0 @@ -import glob -import os -import subprocess -import tempfile -from pathlib import Path -from typing import Any, Dict, List, Optional - -import numpy as np -import yaml -from tensorboard.backend.event_processing.event_accumulator import EventAccumulator - - -def extract_tensorboard_metric( - log_dir: str, metric_name: str = "loss/total/val", aggregation: str = "min" -) -> float: - """ - Extract a metric from TensorBoard logs. 
- - Args: - log_dir: Path to the directory containing TensorBoard logs - metric_name: Name of the metric to extract (e.g., "loss/total/val") - aggregation: How to aggregate the metric values ("min", "max", "last", "mean") - - Returns: - The aggregated metric value, or float('inf') if extraction fails - - Examples: - >>> extract_tensorboard_metric("./logs/version_1", "loss/total/val", "min") - 0.234567 - - >>> extract_tensorboard_metric("./logs/version_1", "accuracy", "max") - 0.891234 - """ - try: - # Find the events file - events_files = list(Path(log_dir).glob("events.out.tfevents.*")) - if not events_files: - print(f"Warning: No events file found in {log_dir}") - return float("inf") - - # Load TensorBoard data - ea = EventAccumulator(str(events_files[0])) - ea.Reload() - - # Extract metric - if metric_name in ea.Tags()["scalars"]: - values = np.array([scalar.value for scalar in ea.Scalars(metric_name)]) - - if aggregation == "min": - result = float(np.min(values)) - elif aggregation == "max": - result = float(np.max(values)) - elif aggregation == "last": - result = float(values[-1]) - elif aggregation == "mean": - result = float(np.mean(values)) - else: - raise ValueError(f"Unknown aggregation: {aggregation}") - - print(f"Extracted {metric_name} ({aggregation}): {result:.6f}") - return result - else: - print(f"Warning: Metric '{metric_name}' not found in {log_dir}") - available_metrics = ea.Tags()["scalars"] - print(f"Available metrics: {available_metrics}") - return float("inf") - - except Exception as e: - print(f"Error extracting {metric_name} from {log_dir}: {e}") - return float("inf") - - -def modify_config( - base_config_path: str, - modifications: Dict[str, Any], - output_path: Optional[str] = None, -) -> str: - """ - Modify a YAML configuration file with new parameter values. - - Supports nested key modification using dot notation (e.g., "model.init_args.beta"). - - Args: - base_config_path: Path to the base configuration file - modifications: Dictionary with nested keys to modify - e.g., {"model.init_args.beta": 10, "trainer.max_epochs": 50} - output_path: Where to save the modified config (if None, creates temp file) - - Returns: - Path to the modified configuration file - - Examples: - >>> modify_config("base.yml", {"model.init_args.lr": 1e-3}, "modified.yml") - "modified.yml" - - >>> temp_path = modify_config("base.yml", {"trainer.max_epochs": 100}) - >>> # Returns path to temporary file - """ - # Load base config - with open(base_config_path, "r") as f: - config = yaml.safe_load(f) - - # Apply modifications - for key_path, value in modifications.items(): - keys = key_path.split(".") - current = config - - # Navigate to the nested dictionary - for key in keys[:-1]: - if key not in current: - current[key] = {} - current = current[key] - - # Set the final value - current[keys[-1]] = value - - # Save modified config - if output_path is None: - temp_file = tempfile.NamedTemporaryFile(mode="w", suffix=".yml", delete=False) - output_path = temp_file.name - temp_file.close() - - with open(output_path, "w") as f: - yaml.dump(config, f, default_flow_style=False) - - return output_path - - -def run_lightning_training( - config_path: str, - working_dir: str = ".", - timeout: int = 3600, - capture_output: bool = True, -) -> subprocess.CompletedProcess: - """ - Run Lightning training with the given configuration. 
- - Args: - config_path: Path to the configuration file - working_dir: Working directory for the training process - timeout: Timeout in seconds (default: 1 hour) - capture_output: Whether to capture stdout/stderr - - Returns: - CompletedProcess object with training results - - Examples: - >>> result = run_lightning_training("config.yml", timeout=1800) - >>> if result.returncode == 0: - ... print("Training completed successfully") - """ - cmd = ["python", "-m", "viscy.cli.train", "fit", "--config", config_path] - - print(f"Running command: {' '.join(cmd)}") - - return subprocess.run( - cmd, cwd=working_dir, capture_output=capture_output, text=True, timeout=timeout - ) - - -def suggest_hyperparameters( - trial, param_config: Dict[str, Dict[str, Any]] -) -> Dict[str, Any]: - """ - Suggest hyperparameters based on a configuration dictionary. - - Supports different parameter types with flexible configuration options. - - Args: - trial: Optuna trial object - param_config: Configuration for parameters with format: - { - "param_name": { - "type": "float" | "int" | "categorical", - "low": , # for float/int - "high": , # for float/int - "choices": [], # for categorical - "log": True/False, # for float/int (optional) - "step": # for int (optional) - } - } - - Returns: - Dictionary of suggested parameter values - - Examples: - >>> param_config = { - ... "lr": {"type": "float", "low": 1e-5, "high": 1e-2, "log": True}, - ... "batch_size": {"type": "categorical", "choices": [32, 64, 128]}, - ... "epochs": {"type": "int", "low": 10, "high": 100, "step": 10} - ... } - >>> params = suggest_hyperparameters(trial, param_config) - >>> # Returns: {"lr": 0.0001234, "batch_size": 64, "epochs": 50} - """ - params = {} - - for param_name, config in param_config.items(): - param_type = config["type"] - - if param_type == "float": - log_scale = config.get("log", False) - params[param_name] = trial.suggest_float( - param_name, config["low"], config["high"], log=log_scale - ) - elif param_type == "int": - step = config.get("step", 1) - params[param_name] = trial.suggest_int( - param_name, config["low"], config["high"], step=step - ) - elif param_type == "categorical": - params[param_name] = trial.suggest_categorical( - param_name, config["choices"] - ) - else: - raise ValueError(f"Unknown parameter type: {param_type}") - - return params - - -def create_study_with_defaults( - study_name: str, - storage_url: str, - direction: str = "minimize", - sampler_name: str = "TPE", - pruner_name: str = "Median", - sampler_kwargs: Optional[Dict[str, Any]] = None, - pruner_kwargs: Optional[Dict[str, Any]] = None, -): - """ - Create an Optuna study with commonly used samplers and pruners. - - Args: - study_name: Name of the study - storage_url: Storage URL (e.g., "sqlite:///study.db") - direction: Optimization direction ("minimize" or "maximize") - sampler_name: Sampler type ("TPE", "Random", "CmaEs") - pruner_name: Pruner type ("Median", "Hyperband", "None") - sampler_kwargs: Additional sampler arguments (e.g., {"seed": 42}) - pruner_kwargs: Additional pruner arguments (e.g., {"n_startup_trials": 5}) - - Returns: - Optuna study object - - Examples: - >>> study = create_study_with_defaults( - ... "vae_optimization", - ... "sqlite:///vae_study.db", - ... sampler_kwargs={"seed": 42} - ... 
) - >>> study.optimize(objective, n_trials=100) - """ - import optuna - - # Set up sampler - sampler_kwargs = sampler_kwargs or {} - if sampler_name == "TPE": - sampler = optuna.samplers.TPESampler(**sampler_kwargs) - elif sampler_name == "Random": - sampler = optuna.samplers.RandomSampler(**sampler_kwargs) - elif sampler_name == "CmaEs": - sampler = optuna.samplers.CmaEsSampler(**sampler_kwargs) - else: - raise ValueError(f"Unknown sampler: {sampler_name}") - - # Set up pruner - pruner_kwargs = pruner_kwargs or {} - if pruner_name == "Median": - pruner = optuna.pruners.MedianPruner(**pruner_kwargs) - elif pruner_name == "Hyperband": - pruner = optuna.pruners.HyperbandPruner(**pruner_kwargs) - elif pruner_name == "None": - pruner = optuna.pruners.NopPruner() - else: - raise ValueError(f"Unknown pruner: {pruner_name}") - - return optuna.create_study( - study_name=study_name, - storage=storage_url, - direction=direction, - load_if_exists=True, - sampler=sampler, - pruner=pruner, - ) - - -def save_best_config( - study, - base_config_path: str, - output_path: str, - param_mapping: Dict[str, str], - additional_modifications: Optional[Dict[str, Any]] = None, -) -> None: - """ - Save the best configuration found by Optuna to a file. - - Args: - study: Completed Optuna study - base_config_path: Path to the base configuration file - output_path: Where to save the best configuration - param_mapping: Mapping from Optuna parameter names to config keys - e.g., {"beta": "model.init_args.beta", "lr": "model.init_args.lr"} - additional_modifications: Additional modifications to apply to the config - e.g., {"trainer.max_epochs": 300, "model.init_args.loss_function.init_args.reduction": "mean"} - - Examples: - >>> param_mapping = { - ... "beta": "model.init_args.beta", - ... "lr": "model.init_args.lr" - ... } - >>> additional_mods = {"trainer.max_epochs": 300} - >>> save_best_config(study, "base.yml", "best.yml", param_mapping, additional_mods) - """ - if study.best_trial is None: - print("No best trial found") - return - - # Create modifications dictionary - modifications = {} - for optuna_param, config_key in param_mapping.items(): - if optuna_param in study.best_params: - modifications[config_key] = study.best_params[optuna_param] - - # Add additional modifications - if additional_modifications: - modifications.update(additional_modifications) - - # Create and save the best configuration - modify_config(base_config_path, modifications, output_path) - - print(f"Best configuration saved to: {output_path}") - print(f"Best value: {study.best_value:.6f}") - print("Best parameters:") - for key, value in study.best_params.items(): - print(f" {key}: {value}") - - -def cleanup_temp_files(file_patterns: List[str], working_dir: str = ".") -> None: - """ - Clean up temporary files matching the given patterns. - - Uses glob patterns to match files for deletion. Handles both files and - directories safely. 
- - Args: - file_patterns: List of glob patterns for files to delete - e.g., ["trial_*.yml", "temp_*", "*.tmp"] - working_dir: Directory to search in (default: current directory) - - Examples: - >>> cleanup_temp_files(["trial_*.yml", "temp_logs_*"]) - Removed: trial_1.yml - Removed: trial_2.yml - Removed: temp_logs_experiment1 - """ - for pattern in file_patterns: - files = glob.glob(os.path.join(working_dir, pattern)) - for file_path in files: - try: - if os.path.isfile(file_path): - os.remove(file_path) - print(f"Removed: {file_path}") - elif os.path.isdir(file_path): - import shutil - - shutil.rmtree(file_path) - print(f"Removed directory: {file_path}") - except Exception as e: - print(f"Failed to remove {file_path}: {e}") - - -def validate_config_modifications( - base_config_path: str, modifications: Dict[str, Any] -) -> bool: - """ - Validate that configuration modifications are applicable to the base config. - - Checks if the nested keys exist in the base configuration structure. - - Args: - base_config_path: Path to the base configuration file - modifications: Dictionary of modifications to validate - - Returns: - True if all modifications are valid, False otherwise - - Examples: - >>> modifications = {"model.init_args.beta": 10, "invalid.key": 5} - >>> validate_config_modifications("config.yml", modifications) - False # because "invalid.key" doesn't exist in base config - """ - try: - with open(base_config_path, "r") as f: - config = yaml.safe_load(f) - - for key_path in modifications.keys(): - keys = key_path.split(".") - current = config - - # Check if nested path exists - for key in keys[:-1]: - if not isinstance(current, dict) or key not in current: - print(f"Invalid key path: {key_path} (missing: {key})") - return False - current = current[key] - - # Check final key (it's ok if it doesn't exist, we'll create it) - if not isinstance(current, dict): - print(f"Invalid key path: {key_path} (parent is not dict)") - return False - - return True - - except Exception as e: - print(f"Error validating config modifications: {e}") - return False diff --git a/viscy/scripts/optimization/optuna_vae_parallel.sh b/viscy/scripts/optimization/optuna_vae_parallel.sh deleted file mode 100644 index 4ddbaddee..000000000 --- a/viscy/scripts/optimization/optuna_vae_parallel.sh +++ /dev/null @@ -1,53 +0,0 @@ -#!/bin/bash -#SBATCH --job-name=optuna_vae_parallel -#SBATCH --output=optuna_vae_parallel_%A_%a.out -#SBATCH --error=optuna_vae_parallel_%A_%a.err -#SBATCH --time=12:00:00 -#SBATCH --partition=gpu -#SBATCH --nodes=1 -#SBATCH --ntasks-per-node=1 -#SBATCH --cpus-per-task=12 -#SBATCH --gres=gpu:1 -#SBATCH --mem-per-cpu=6G -#SBATCH --array=1-4 # Run 4 parallel workers - - -# Print job info -echo "Array Job ID: $SLURM_ARRAY_JOB_ID" -echo "Array Task ID: $SLURM_ARRAY_TASK_ID" -echo "Node: $SLURM_NODELIST" -echo "Start Time: $(date)" - -# Change to repo directory -module load anaconda/25.3.1 -conda activate viscy - -# Load environment -OPTUNA_SCRIPT='/home/eduardo.hirata/repos/viscy/viscy/scripts/optimization/optuna_vae_search.py' - -# Shared storage for Optuna study (all workers use same database) -SHARED_DB="/hpc/projects/organelle_phenotyping/models/SEC61B/vae/optuna_results/optuna_parallel_study.db" -OUTPUT_DIR="/hpc/projects/organelle_phenotyping/models/SEC61B/vae//optuna_results/parallel_job_${SLURM_ARRAY_JOB_ID}" -mkdir -p $OUTPUT_DIR - -# Each worker runs a portion of trials -TRIALS_PER_WORKER=15 # 4 workers × 15 trials = 60 total trials - -echo "Worker $SLURM_ARRAY_TASK_ID starting 
$TRIALS_PER_WORKER trials..." - -# Run Optuna optimization (all workers share the same database) -python \ - --storage_url "sqlite:///$pp" \ - --n_trials $TRIALS_PER_WORKER \ - --timeout 43200 \ - --study_name "vae_parallel_optimization" - -# Only the first worker saves the final results -if [ $SLURM_ARRAY_TASK_ID -eq 1 ]; then - echo "Worker 1 saving final results..." - sleep 60 # Wait for other workers to finish - cp $SHARED_DB $OUTPUT_DIR/ - cp best_vae_config.yml $OUTPUT_DIR/ 2>/dev/null || echo "No best config generated yet" -fi - -echo "Worker $SLURM_ARRAY_TASK_ID completed at: $(date)" \ No newline at end of file diff --git a/viscy/scripts/optimization/optuna_vae_search.py b/viscy/scripts/optimization/optuna_vae_search.py deleted file mode 100644 index b75d582f0..000000000 --- a/viscy/scripts/optimization/optuna_vae_search.py +++ /dev/null @@ -1,226 +0,0 @@ -#!/usr/bin/env python3 -""" -Optuna hyperparameter optimization for VAE training with PyTorch Lightning. -""" - -import subprocess -import tempfile -from pathlib import Path - -import optuna -import torch -import yaml -from tensorboard.backend.event_processing.event_accumulator import EventAccumulator - - -def extract_best_val_loss(log_dir: str) -> float: - """Extract the best validation loss from TensorBoard logs.""" - try: - # Find the events file - events_files = list(Path(log_dir).glob("events.out.tfevents.*")) - if not events_files: - print(f"Warning: No events file found in {log_dir}") - return float("inf") - - # Load TensorBoard data - ea = EventAccumulator(str(events_files[0])) - ea.Reload() - - # Extract validation loss - if "loss/total/val" in ea.Tags()["scalars"]: - val_losses = ea.Scalars("loss/total/val") - best_val_loss = min([scalar.value for scalar in val_losses]) - print(f"Best validation loss: {best_val_loss:.6f}") - return best_val_loss - else: - print(f"Warning: No validation loss found in {log_dir}") - return float("inf") - - except Exception as e: - print(f"Error extracting validation loss from {log_dir}: {e}") - return float("inf") - - -def create_trial_config( - base_config_path: str, trial: optuna.Trial, trial_dir: Path -) -> str: - """Create a modified config file for the current trial.""" - - # Load base config - with open(base_config_path, "r") as f: - config = yaml.safe_load(f) - - # Sample hyperparameters - beta = trial.suggest_float("beta", 0.1, 50.0, log=True) - lr = trial.suggest_float("lr", 5e-5, 5e-3, log=True) - warmup_epochs = trial.suggest_int("warmup_epochs", 10, 100) - latent_dim = trial.suggest_categorical("latent_dim", [512, 1024, 2048]) - batch_size = trial.suggest_categorical("batch_size", [32, 64, 128]) - - # Modify model config - config["model"]["init_args"]["beta"] = beta - config["model"]["init_args"]["lr"] = lr - config["model"]["init_args"]["beta_warmup_epochs"] = warmup_epochs - - # Modify data config - config["data"]["init_args"]["batch_size"] = batch_size - - # Reduce training for faster search - config["trainer"]["max_epochs"] = 30 - config["trainer"]["check_val_every_n_epoch"] = 2 - - # Set unique logging directory - config["trainer"]["logger"]["init_args"]["save_dir"] = str(trial_dir) - config["trainer"]["logger"]["init_args"]["version"] = f"trial_{trial.number}" - - # Fix loss function to use mean reduction - config["model"]["init_args"]["loss_function"]["init_args"]["reduction"] = "mean" - - # Save trial config - trial_config_path = trial_dir / f"trial_{trial.number}_config.yml" - with open(trial_config_path, "w") as f: - yaml.dump(config, f, default_flow_style=False) - - 
print( - f"Trial {trial.number} params: beta={beta:.4f}, lr={lr:.2e}, " - f"warmup={warmup_epochs}, latent={latent_dim}, batch={batch_size}" - ) - - return str(trial_config_path) - - -def objective(trial: optuna.Trial) -> float: - """Optuna objective function.""" - - # Create temporary directory for this trial - with tempfile.TemporaryDirectory( - prefix=f"optuna_trial_{trial.number}_" - ) as temp_dir: - trial_dir = Path(temp_dir) - - try: - # Create trial config - base_config = "/hpc/projects/organelle_phenotyping/models/SEC61B/vae/fit_phase_only.yml" - trial_config_path = create_trial_config(base_config, trial, trial_dir) - - # Run training - cmd = [ - "python", - "-m", - "viscy.cli.train", - "fit", - "--config", - trial_config_path, - ] - - print(f"Running trial {trial.number}: {' '.join(cmd)}") - - # Run with timeout to prevent hanging - result = subprocess.run( - cmd, - cwd="/hpc/mydata/eduardo.hirata/repos/viscy", - capture_output=True, - text=True, - timeout=3600, # 1 hour timeout - ) - - if result.returncode != 0: - print( - f"Trial {trial.number} failed with return code {result.returncode}" - ) - print(f"STDERR: {result.stderr}") - return float("inf") - - # Extract validation loss - log_dir = trial_dir / f"trial_{trial.number}" - val_loss = extract_best_val_loss(str(log_dir)) - - print(f"Trial {trial.number} completed with val_loss: {val_loss:.6f}") - return val_loss - - except subprocess.TimeoutExpired: - print(f"Trial {trial.number} timed out") - return float("inf") - except Exception as e: - print(f"Trial {trial.number} failed with error: {e}") - return float("inf") - - -def main(): - """Main optimization loop.""" - - # Set up study - study_name = "vae_hyperparameter_optimization" - storage_url = "sqlite:///optuna_vae_study.db" - - study = optuna.create_study( - study_name=study_name, - storage=storage_url, - direction="minimize", - load_if_exists=True, # Resume if study exists - sampler=optuna.samplers.TPESampler(seed=42), - pruner=optuna.pruners.MedianPruner(n_startup_trials=5, n_warmup_steps=10), - ) - - print( - f"Starting Optuna optimization with {torch.cuda.device_count()} GPUs available" - ) - print(f"Study storage: {storage_url}") - - try: - # Run optimization - study.optimize(objective, n_trials=50, timeout=24 * 3600) # 24 hour timeout - - # Print results - print("\nOptimization completed!") - print(f"Best trial: {study.best_trial.number}") - print(f"Best value: {study.best_value:.6f}") - print("Best params:") - for key, value in study.best_params.items(): - print(f" {key}: {value}") - - # Save best config - best_config_path = "best_vae_config.yml" - base_config = ( - "/hpc/projects/organelle_phenotyping/models/SEC61B/vae/fit_phase_only.yml" - ) - - with open(base_config, "r") as f: - config = yaml.safe_load(f) - - # Apply best parameters - best_params = study.best_params - config["model"]["init_args"]["beta"] = best_params["beta"] - config["model"]["init_args"]["lr"] = best_params["lr"] - config["model"]["init_args"]["beta_warmup_epochs"] = best_params[ - "warmup_epochs" - ] - config["model"]["init_args"]["encoder"]["init_args"]["latent_dim"] = ( - best_params["latent_dim"] - ) - config["model"]["init_args"]["decoder"]["init_args"]["latent_dim"] = ( - best_params["latent_dim"] - ) - config["data"]["init_args"]["batch_size"] = best_params["batch_size"] - config["model"]["init_args"]["loss_function"]["init_args"]["reduction"] = "mean" - - # Restore full training settings - config["trainer"]["max_epochs"] = 300 - config["trainer"]["check_val_every_n_epoch"] = 1 - - with 
open(best_config_path, "w") as f: - yaml.dump(config, f, default_flow_style=False) - - print(f"Best configuration saved to: {best_config_path}") - - except KeyboardInterrupt: - print("\nOptimization interrupted by user") - print( - f"Current best trial: {study.best_trial.number if study.best_trial else 'None'}" - ) - if study.best_trial: - print(f"Current best value: {study.best_value:.6f}") - - -if __name__ == "__main__": - main() diff --git a/viscy/scripts/optimization/optuna_vae_slurm.sh b/viscy/scripts/optimization/optuna_vae_slurm.sh deleted file mode 100644 index 0394aa0db..000000000 --- a/viscy/scripts/optimization/optuna_vae_slurm.sh +++ /dev/null @@ -1,47 +0,0 @@ -#!/bin/bash -#SBATCH --job-name=optuna_vae_search -#SBATCH --output=/slurm_out/optuna_vae_%j.out -#SBATCH --error=/slurm_out/optuna_vae_%j.err -#SBATCH --time=0-20:00:00 -#SBATCH --partition=gpu -#SBATCH --nodes=1 -#SBATCH --ntasks-per-node=1 -#SBATCH --cpus-per-task=12 -#SBATCH --gres=gpu:1 -#SBATCH --mem=64G - -# Print job info -echo "Job ID: $SLURM_JOB_ID" -echo "Job Name: $SLURM_JOB_NAME" -echo "Node: $SLURM_NODELIST" -echo "Start Time: $(date)" - -# Load modules/environment -module load anaconda/25.3.1 -conda activate viscy - -# Set environment variables -export CUDA_VISIBLE_DEVICES=0 -export OMP_NUM_THREADS=$SLURM_CPUS_PER_TASK - -# Change to repo directory -ROOT_DIR="/home/eduardo.hirata/repos/viscy" -cd $ROOT_DIR - -# Create output directory for this job -OUTPUT_DIR="/hpc/projects/organelle_phenotyping/models/SEC61B/vae/optuna_results/job_${SLURM_JOB_ID}" -mkdir -p $OUTPUT_DIR - -# Run Optuna optimization -echo "Starting Optuna VAE hyperparameter search..." -python viscy/scripts/optimization/optuna_vae_search.py \ - --output_dir $OUTPUT_DIR \ - --n_trials 50 \ - --timeout 86400 - -# Copy results to output directory -cp optuna_vae_study.db $OUTPUT_DIR/ -cp best_vae_config.yml $OUTPUT_DIR/ 2>/dev/null || echo "No best config generated yet" - -echo "Job completed at: $(date)" -echo "Results saved to: $OUTPUT_DIR" \ No newline at end of file From d2f36597f6087b2259d0cbfaebeca94e7169600e Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Mon, 8 Sep 2025 17:08:25 -0700 Subject: [PATCH 045/101] numpy docstring --- .../DynaCLR/SAM2/sam2_embeddings.py | 115 +++++++++++++++--- 1 file changed, 97 insertions(+), 18 deletions(-) diff --git a/applications/benchmarking/DynaCLR/SAM2/sam2_embeddings.py b/applications/benchmarking/DynaCLR/SAM2/sam2_embeddings.py index 3d13d5a5a..66f7dfb83 100644 --- a/applications/benchmarking/DynaCLR/SAM2/sam2_embeddings.py +++ b/applications/benchmarking/DynaCLR/SAM2/sam2_embeddings.py @@ -26,6 +26,21 @@ def __init__( channel_names: Optional[List[str]] = None, middle_slice_index: Optional[int] = None, ): + """ + SAM2 module for feature extraction. + + Parameters + ---------- + model_name : str, optional + SAM2 model name from HuggingFace Model Hub (default: "facebook/sam2-hiera-base-plus"). + channel_reduction_methods : dict[str, {"middle_slice", "mean", "max"}], optional + Dictionary mapping channel names to reduction methods for 5D inputs (default: None, uses "middle_slice"). + channel_names : list of str, optional + List of channel names corresponding to input channels (default: None). + middle_slice_index : int, optional + Specific z-slice index to use for "middle_slice" reduction (default: None, uses D//2). 
+ + """ super().__init__() self.model_name = model_name self.channel_reduction_methods = channel_reduction_methods or {} @@ -36,20 +51,32 @@ def __init__( self.model = None # Initialize in on_predict_start when device is set def on_predict_start(self): - """Initialize model with proper device when prediction starts""" + """ + Initialize model with proper device when prediction starts. + + Notes + ----- + This method is called automatically by Lightning when prediction begins. + It ensures the SAM2 model is properly initialized on the correct device. + """ if self.model is None: self.model = SAM2ImagePredictor.from_pretrained( self.model_name, device=self.device ) def _reduce_5d_input(self, x: torch.Tensor) -> torch.Tensor: - """Reduce 5D input (B, C, D, H, W) to 4D (B, C, H, W) using specified methods. + """ + Reduce 5D input (B, C, D, H, W) to 4D (B, C, H, W) using specified methods. - Args: - x: 5D input tensor + Parameters + ---------- + x : torch.Tensor + 5D input tensor with shape (B, C, D, H, W). - Returns: - 4D tensor after applying reduction methods + Returns + ------- + torch.Tensor + 4D tensor after applying reduction methods with shape (B, C, H, W). """ if x.dim() != 5: return x @@ -95,13 +122,18 @@ def _reduce_5d_input(self, x: torch.Tensor) -> torch.Tensor: return result def _convert_to_rgb(self, x: torch.Tensor) -> list: - """Convert input tensor to 3-channel RGB format as needed for SAM2. + """ + Convert input tensor to 3-channel RGB format as needed for SAM2. - Args: - x: Input tensor with 1, 2, or 3+ channels + Parameters + ---------- + x : torch.Tensor + Input tensor with 1, 2, or 3+ channels and shape (B, C, H, W). - Returns: - List of numpy arrays in HWC format for SAM2 + Returns + ------- + list of numpy.ndarray + List of numpy arrays in HWC format for SAM2 processing. """ # Convert to RGB and scale to [0, 255] range for SAM2 if x.shape[1] == 1: @@ -130,10 +162,25 @@ def _convert_to_rgb(self, x: torch.Tensor) -> list: ] def predict_step(self, batch, batch_idx, dataloader_idx=0): - """Extract features from the input images. - - Returns: - Dictionary with features, properly shaped empty projections tensor, and index information + """ + Extract features from the input images. + + Parameters + ---------- + batch : dict + Batch dictionary containing "anchor" key with input tensors. + batch_idx : int + Index of the current batch. + dataloader_idx : int, optional + Index of the dataloader (default: 0). + + Returns + ------- + dict + Dictionary containing: + - "features": Extracted features tensor + - "projections": Empty tensor for compatibility (B, 0) + - "index": Batch index information """ x = batch["anchor"] @@ -161,14 +208,38 @@ def predict_step(self, batch, batch_idx, dataloader_idx=0): def load_config(config_file): - """Load configuration from a YAML file.""" + """ + Load configuration from a YAML file. + + Parameters + ---------- + config_file : str or Path + Path to the YAML configuration file. + + Returns + ------- + dict + Configuration dictionary loaded from the YAML file. + """ with open(config_file, "r") as f: config = yaml.safe_load(f) return config def load_normalization_from_config(norm_config): - """Load a normalization transform from a configuration dictionary.""" + """ + Load a normalization transform from a configuration dictionary. + + Parameters + ---------- + norm_config : dict + Configuration dictionary containing "class_path" and optional "init_args". + + Returns + ------- + object + Instantiated normalization transform object. 
+ """ class_path = norm_config["class_path"] init_args = norm_config.get("init_args", {}) @@ -194,7 +265,15 @@ def load_normalization_from_config(norm_config): help="Path to YAML configuration file", ) def main(config): - """Extract SAM2 embeddings and save to zarr format using VisCy Trainer.""" + """ + Extract SAM2 embeddings and save to zarr format using VisCy Trainer. + + Parameters + ---------- + config : str or Path + Path to the YAML configuration file containing all parameters for + data loading, model configuration, and output settings. + """ # Configure logging logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) From a36318aa9700d0a7fc8d8784b9515fc566ac309c Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Mon, 8 Sep 2025 17:33:44 -0700 Subject: [PATCH 046/101] fix compute smoothness script --- .../smoothness/compute_smoothness.py | 37 +++---------------- 1 file changed, 5 insertions(+), 32 deletions(-) diff --git a/applications/contrastive_phenotyping/evaluation/smoothness/compute_smoothness.py b/applications/contrastive_phenotyping/evaluation/smoothness/compute_smoothness.py index ad3bdb6a5..6653ff373 100644 --- a/applications/contrastive_phenotyping/evaluation/smoothness/compute_smoothness.py +++ b/applications/contrastive_phenotyping/evaluation/smoothness/compute_smoothness.py @@ -2,46 +2,31 @@ from pathlib import Path import matplotlib.pyplot as plt -import numpy as np import pandas as pd import seaborn as sns -from lightning.pytorch import seed_everything -from matplotlib.patches import FancyArrowPatch from viscy.representation.embedding_writer import read_embedding_dataset from viscy.representation.evaluation.smoothness import compute_embeddings_smoothness -colormap = { - 2: "orange", - 1: "steelblue", -} #%% # FEATURES # openphenom_features_path = Path("/home/jason/projects/contrastive_phenotyping/data/open_phenom/features/open_phenom_features.csv") # imagenet_features_path = Path("/home/jason/projects/contrastive_phenotyping/data/imagenet/features/imagenet_features.csv") dynaclr_features_path = Path("/hpc/projects/intracellular_dashboard/organelle_dynamics/rerun/2024_11_07_A549_SEC61_DENV/4-phenotyping/dtw_evaluation/SAM2/sam2_sensor_only.zarr") - -# ANNOTATIONS -ann_root = Path( - "/hpc/projects/intracellular_dashboard/organelle_dynamics/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/supervised_inf_pred" -) - -# TRACKS - -tracks_path = Path( - "/hpc/projects/intracellular_dashboard/organelle_dynamics/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/track_test.zarr" -) +dinov3_features_path = Path("/home/eduardo.hirata/repos/viscy/applications/benchmarking/DynaCLR/DINOV3/embeddings_convnext_tiny_phase_only_2.zarr") # LOADING DATASETS # openphenom_features = read_embedding_dataset(openphenom_features_path) # imagenet_features = read_embedding_dataset(imagenet_features_path) dynaclr_embedding_dataset = read_embedding_dataset(dynaclr_features_path) +dinov3_embedding_dataset = read_embedding_dataset(dinov3_features_path) #%% # Compute the smoothness of the features DISTANCE_METRIC = "cosine" feature_paths ={ - "dynaclr": dynaclr_features_path, + # "dynaclr": dynaclr_features_path, + "dinov3": dinov3_features_path, } cmap = plt.get_cmap("tab10") # or use "Set2", "tab20", etc. 
labels = list(feature_paths.keys()) @@ -111,17 +96,5 @@ "dynamic_range": stats["dynamic_range"] } # Create DataFrame with single row - stats_df = pd.DataFrame(stats) # Note the list wrapper + stats_df = pd.DataFrame(scalar_metrics, index=[0]) stats_df.to_csv(output_dir/f"{label}_smoothness_stats.csv", index=False) - - - - - - - - - -#%% - -#%% \ No newline at end of file From 173f29748adcc27799d608c42a2c1c0e4e524970 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Mon, 8 Sep 2025 17:48:19 -0700 Subject: [PATCH 047/101] archiving old scripts --- .../evaluation/{ => archive}/ALFI_MSD_v2.py | 0 .../evaluation/{ => archive}/analyze_embeddings.py | 0 .../evaluation/{ => archive}/cosine_dissimilarity_dataset.py | 0 .../evaluation/{ => archive}/displacement.py | 0 .../evaluation/{ => archive}/linear_probing.py | 0 .../evaluation/{ => archive}/log_regresssion_training.py | 0 .../evaluation/{ => archive}/time_decay_knn.py | 0 .../{ => knowledge_distillation}/knowledge_distillation.py | 0 .../knowledge_distillation_teacher.py | 0 9 files changed, 0 insertions(+), 0 deletions(-) rename applications/contrastive_phenotyping/evaluation/{ => archive}/ALFI_MSD_v2.py (100%) rename applications/contrastive_phenotyping/evaluation/{ => archive}/analyze_embeddings.py (100%) rename applications/contrastive_phenotyping/evaluation/{ => archive}/cosine_dissimilarity_dataset.py (100%) rename applications/contrastive_phenotyping/evaluation/{ => archive}/displacement.py (100%) rename applications/contrastive_phenotyping/evaluation/{ => archive}/linear_probing.py (100%) rename applications/contrastive_phenotyping/evaluation/{ => archive}/log_regresssion_training.py (100%) rename applications/contrastive_phenotyping/evaluation/{ => archive}/time_decay_knn.py (100%) rename applications/contrastive_phenotyping/evaluation/{ => knowledge_distillation}/knowledge_distillation.py (100%) rename applications/contrastive_phenotyping/evaluation/{ => knowledge_distillation}/knowledge_distillation_teacher.py (100%) diff --git a/applications/contrastive_phenotyping/evaluation/ALFI_MSD_v2.py b/applications/contrastive_phenotyping/evaluation/archive/ALFI_MSD_v2.py similarity index 100% rename from applications/contrastive_phenotyping/evaluation/ALFI_MSD_v2.py rename to applications/contrastive_phenotyping/evaluation/archive/ALFI_MSD_v2.py diff --git a/applications/contrastive_phenotyping/evaluation/analyze_embeddings.py b/applications/contrastive_phenotyping/evaluation/archive/analyze_embeddings.py similarity index 100% rename from applications/contrastive_phenotyping/evaluation/analyze_embeddings.py rename to applications/contrastive_phenotyping/evaluation/archive/analyze_embeddings.py diff --git a/applications/contrastive_phenotyping/evaluation/cosine_dissimilarity_dataset.py b/applications/contrastive_phenotyping/evaluation/archive/cosine_dissimilarity_dataset.py similarity index 100% rename from applications/contrastive_phenotyping/evaluation/cosine_dissimilarity_dataset.py rename to applications/contrastive_phenotyping/evaluation/archive/cosine_dissimilarity_dataset.py diff --git a/applications/contrastive_phenotyping/evaluation/displacement.py b/applications/contrastive_phenotyping/evaluation/archive/displacement.py similarity index 100% rename from applications/contrastive_phenotyping/evaluation/displacement.py rename to applications/contrastive_phenotyping/evaluation/archive/displacement.py diff --git a/applications/contrastive_phenotyping/evaluation/linear_probing.py 
b/applications/contrastive_phenotyping/evaluation/archive/linear_probing.py similarity index 100% rename from applications/contrastive_phenotyping/evaluation/linear_probing.py rename to applications/contrastive_phenotyping/evaluation/archive/linear_probing.py diff --git a/applications/contrastive_phenotyping/evaluation/log_regresssion_training.py b/applications/contrastive_phenotyping/evaluation/archive/log_regresssion_training.py similarity index 100% rename from applications/contrastive_phenotyping/evaluation/log_regresssion_training.py rename to applications/contrastive_phenotyping/evaluation/archive/log_regresssion_training.py diff --git a/applications/contrastive_phenotyping/evaluation/time_decay_knn.py b/applications/contrastive_phenotyping/evaluation/archive/time_decay_knn.py similarity index 100% rename from applications/contrastive_phenotyping/evaluation/time_decay_knn.py rename to applications/contrastive_phenotyping/evaluation/archive/time_decay_knn.py diff --git a/applications/contrastive_phenotyping/evaluation/knowledge_distillation.py b/applications/contrastive_phenotyping/evaluation/knowledge_distillation/knowledge_distillation.py similarity index 100% rename from applications/contrastive_phenotyping/evaluation/knowledge_distillation.py rename to applications/contrastive_phenotyping/evaluation/knowledge_distillation/knowledge_distillation.py diff --git a/applications/contrastive_phenotyping/evaluation/knowledge_distillation_teacher.py b/applications/contrastive_phenotyping/evaluation/knowledge_distillation/knowledge_distillation_teacher.py similarity index 100% rename from applications/contrastive_phenotyping/evaluation/knowledge_distillation_teacher.py rename to applications/contrastive_phenotyping/evaluation/knowledge_distillation/knowledge_distillation_teacher.py From f12bc33eca6500f8d1c39f80e22eefbdde178247 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Mon, 8 Sep 2025 17:54:41 -0700 Subject: [PATCH 048/101] re org the pc features scripts --- .../evaluation/{ => imagenet}/imagenet_pretrained_features.py | 0 .../{ => pc_vs_computed_features}/PC_vs_computed_features.py | 0 .../{ => pc_vs_computed_features}/compute_pca_features.py | 0 .../evaluation/{ => pc_vs_computed_features}/cosine_similarity.py | 0 4 files changed, 0 insertions(+), 0 deletions(-) rename applications/contrastive_phenotyping/evaluation/{ => imagenet}/imagenet_pretrained_features.py (100%) rename applications/contrastive_phenotyping/evaluation/{ => pc_vs_computed_features}/PC_vs_computed_features.py (100%) rename applications/contrastive_phenotyping/evaluation/{ => pc_vs_computed_features}/compute_pca_features.py (100%) rename applications/contrastive_phenotyping/evaluation/{ => pc_vs_computed_features}/cosine_similarity.py (100%) diff --git a/applications/contrastive_phenotyping/evaluation/imagenet_pretrained_features.py b/applications/contrastive_phenotyping/evaluation/imagenet/imagenet_pretrained_features.py similarity index 100% rename from applications/contrastive_phenotyping/evaluation/imagenet_pretrained_features.py rename to applications/contrastive_phenotyping/evaluation/imagenet/imagenet_pretrained_features.py diff --git a/applications/contrastive_phenotyping/evaluation/PC_vs_computed_features.py b/applications/contrastive_phenotyping/evaluation/pc_vs_computed_features/PC_vs_computed_features.py similarity index 100% rename from applications/contrastive_phenotyping/evaluation/PC_vs_computed_features.py rename to 
applications/contrastive_phenotyping/evaluation/pc_vs_computed_features/PC_vs_computed_features.py diff --git a/applications/contrastive_phenotyping/evaluation/compute_pca_features.py b/applications/contrastive_phenotyping/evaluation/pc_vs_computed_features/compute_pca_features.py similarity index 100% rename from applications/contrastive_phenotyping/evaluation/compute_pca_features.py rename to applications/contrastive_phenotyping/evaluation/pc_vs_computed_features/compute_pca_features.py diff --git a/applications/contrastive_phenotyping/evaluation/cosine_similarity.py b/applications/contrastive_phenotyping/evaluation/pc_vs_computed_features/cosine_similarity.py similarity index 100% rename from applications/contrastive_phenotyping/evaluation/cosine_similarity.py rename to applications/contrastive_phenotyping/evaluation/pc_vs_computed_features/cosine_similarity.py From f04bfa7afd4eb44829eac616e45ab4dbdd51055f Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Mon, 8 Sep 2025 17:56:25 -0700 Subject: [PATCH 049/101] embeddings for phase --- .../DINOV3/config_dinov3_convnext_tiny.yml | 10 +- .../DynaCLR/DINOV3/dinov3_embeddings.py | 128 +++++++++--------- 2 files changed, 66 insertions(+), 72 deletions(-) diff --git a/applications/benchmarking/DynaCLR/DINOV3/config_dinov3_convnext_tiny.yml b/applications/benchmarking/DynaCLR/DINOV3/config_dinov3_convnext_tiny.yml index 4d7fe1a03..ab6bb52cc 100644 --- a/applications/benchmarking/DynaCLR/DINOV3/config_dinov3_convnext_tiny.yml +++ b/applications/benchmarking/DynaCLR/DINOV3/config_dinov3_convnext_tiny.yml @@ -1,13 +1,13 @@ datamodule: batch_size: 32 final_yx_patch_size: - - 224 - - 224 + - 256 + - 256 include_fov_names: null include_track_ids: null initial_yx_patch_size: - - 224 - - 224 + - 256 + - 256 normalizations: - class_path: viscy.transforms.ScaleIntensityRangePercentilesd init_args: @@ -50,7 +50,7 @@ execution: model: model_name: facebook/dinov3-convnext-tiny-pretrain-lvd1689m pooling_method: mean # Options: "mean", "max", "cls_token" - middle_slice_index: 30 # Specific z-slice index (if null, uses D//2) + middle_slice_index: 18 # Specific z-slice index (if null, uses D//2) channel_reduction_methods: Phase3D: middle_slice RFP: max diff --git a/applications/benchmarking/DynaCLR/DINOV3/dinov3_embeddings.py b/applications/benchmarking/DynaCLR/DINOV3/dinov3_embeddings.py index 5f16b1e9a..984894c94 100644 --- a/applications/benchmarking/DynaCLR/DINOV3/dinov3_embeddings.py +++ b/applications/benchmarking/DynaCLR/DINOV3/dinov3_embeddings.py @@ -26,18 +26,25 @@ def __init__( Dict[str, Literal["middle_slice", "mean", "max"]] ] = None, channel_names: Optional[List[str]] = None, - pooling_method: str = "mean", # "mean", "max", or "cls_token" + pooling_method: Literal["mean", "max", "cls_token"] = "mean", middle_slice_index: Optional[int] = None, ): """ DINOv3 module for feature extraction. - - Args: - model_name: DINOv3 model name from HuggingFace - channel_reduction_methods: How to reduce 5D inputs per channel - channel_names: Names of channels for reduction mapping - pooling_method: How to pool spatial tokens ("mean", "max", "cls_token") - middle_slice_index: Specific z-slice index to use (if None, uses D//2) + + Parameters + ---------- + model_name : str, optional + DINOv3 model name from HuggingFace Model Hub (default: "facebook/dinov3-vitb16-pretrain-lvd1689m"). 
+ channel_reduction_methods : dict[str, {"middle_slice", "mean", "max"}], optional + Dictionary mapping channel names to reduction methods for 5D inputs (default: None, uses "middle_slice"). + channel_names : list of str, optional + List of channel names corresponding to input channels (default: None). + pooling_method : Literal["mean", "max", "cls_token"], optional + Method to pool spatial tokens from the model output (default: "mean"). + middle_slice_index : int, optional + Specific z-slice index to use for "middle_slice" reduction (default: None, uses D//2). + """ super().__init__() self.model_name = model_name @@ -51,7 +58,6 @@ def __init__( self.processor = None def on_predict_start(self): - """Initialize model and processor when prediction starts""" if self.model is None: self.processor = AutoImageProcessor.from_pretrained(self.model_name) self.model = AutoModel.from_pretrained(self.model_name) @@ -59,13 +65,18 @@ def on_predict_start(self): self.model.to(self.device) def _reduce_5d_input(self, x: torch.Tensor) -> torch.Tensor: - """Reduce 5D input (B, C, D, H, W) to 4D (B, C, H, W) using specified methods. + """ + Reduce 5D input (B, C, D, H, W) to 4D (B, C, H, W) using specified methods. - Args: - x: 5D input tensor + Parameters + ---------- + x : torch.Tensor + 5D input tensor with shape (B, C, D, H, W). - Returns: - 4D tensor after applying reduction methods + Returns + ------- + torch.Tensor + 4D tensor after applying reduction methods with shape (B, C, H, W). """ if x.dim() != 5: return x @@ -108,13 +119,18 @@ def _reduce_5d_input(self, x: torch.Tensor) -> torch.Tensor: return result def _convert_to_pil_images(self, x: torch.Tensor) -> List[Image.Image]: - """Convert tensor to list of PIL Images for DINOv3 processing. + """ + Convert tensor to list of PIL Images for DINOv3 processing. - Args: - x: Input tensor (B, C, H, W) + Parameters + ---------- + x : torch.Tensor + Input tensor with shape (B, C, H, W). - Returns: - List of PIL Images + Returns + ------- + list of PIL.Image.Image + List of PIL Images ready for DINOv3 processing. """ images = [] @@ -161,13 +177,18 @@ def _convert_to_pil_images(self, x: torch.Tensor) -> List[Image.Image]: return images def _pool_features(self, features: torch.Tensor) -> torch.Tensor: - """Pool spatial features from DINOv3 tokens. + """ + Pool spatial features from DINOv3 tokens. - Args: - features: Token features (B, num_tokens, hidden_dim) + Parameters + ---------- + features : torch.Tensor + Token features with shape (B, num_tokens, hidden_dim). - Returns: - Pooled features (B, hidden_dim) + Returns + ------- + torch.Tensor + Pooled features with shape (B, hidden_dim). """ if self.pooling_method == "cls_token": # For ViT models, first token is usually CLS token @@ -183,11 +204,6 @@ def _pool_features(self, features: torch.Tensor) -> torch.Tensor: return features.mean(dim=1) def predict_step(self, batch, batch_idx, dataloader_idx=0): - """Extract features from input images using DINOv3. 
- - Returns: - Dictionary with pooled features, empty projections, and index information - """ x = batch["anchor"] # Handle 5D input (B, C, D, H, W) @@ -197,26 +213,14 @@ def predict_step(self, batch, batch_idx, dataloader_idx=0): # Convert to PIL Images for DINOv3 processing pil_images = self._convert_to_pil_images(x) - # Process all images in batch - batch_features = [] - - for pil_img in pil_images: - # Process single image - inputs = self.processor(pil_img, return_tensors="pt") - inputs = {k: v.to(self.device) for k, v in inputs.items()} - - with torch.no_grad(): - outputs = self.model(**inputs) - # Get all tokens from last hidden state - token_features = outputs.last_hidden_state # (1, num_tokens, hidden_dim) - - # Pool spatial tokens to get single feature vector - pooled_features = self._pool_features(token_features) # (1, hidden_dim) - - batch_features.append(pooled_features) + # Batch process all images at once for better GPU utilization + inputs = self.processor(pil_images, return_tensors="pt") + inputs = {k: v.to(self.device) for k, v in inputs.items()} - # Concatenate all features in batch - features = torch.cat(batch_features, dim=0) # (B, hidden_dim) + with torch.no_grad(): + outputs = self.model(**inputs) + token_features = outputs.last_hidden_state # (B, num_tokens, hidden_dim) + features = self._pool_features(token_features) # (B, hidden_dim) return { "features": features, @@ -226,27 +230,21 @@ def predict_step(self, batch, batch_idx, dataloader_idx=0): def load_config(config_file): - """Load configuration from a YAML file.""" with open(config_file, "r") as f: config = yaml.safe_load(f) return config def load_normalization_from_config(norm_config): - """Load a normalization transform from a configuration dictionary.""" class_path = norm_config["class_path"] init_args = norm_config.get("init_args", {}) - # Split module and class name module_path, class_name = class_path.rsplit(".", 1) - # Import the module module = importlib.import_module(module_path) - # Get the class transform_class = getattr(module, class_name) - # Instantiate the transform return transform_class(**init_args) @@ -259,19 +257,23 @@ def load_normalization_from_config(norm_config): help="Path to YAML configuration file", ) def main(config): - """Extract DINOv3 embeddings and save to zarr format using VisCy Trainer.""" - # Configure logging + """ + Extract DINOv3 embeddings and save to zarr format using VisCy Trainer. + + Parameters + ---------- + config : str or Path + Path to the YAML configuration file containing all parameters for + data loading, model configuration, and output settings. 
+ """ logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) - # Load config file cfg = load_config(config) logger.info(f"Loaded configuration from {config}") - # Prepare datamodule parameters dm_params = {} - # Add data and tracks paths from the paths section if "paths" not in cfg: raise ValueError("Configuration must contain a 'paths' section") @@ -287,11 +289,9 @@ def main(config): ) dm_params["tracks_path"] = cfg["paths"]["tracks_path"] - # Add datamodule parameters if "datamodule" not in cfg: raise ValueError("Configuration must contain a 'datamodule' section") - # Prepare normalizations if ( "normalizations" not in cfg["datamodule"] or not cfg["datamodule"]["normalizations"] @@ -304,7 +304,6 @@ def main(config): normalizations = [load_normalization_from_config(norm) for norm in norm_configs] dm_params["normalizations"] = normalizations - # Copy all other datamodule parameters for param, value in cfg["datamodule"].items(): if param != "normalizations": # Handle patch sizes @@ -314,7 +313,6 @@ def main(config): else: dm_params[param] = value - # Set up the data module logger.info("Setting up data module") dm = TripletDataModule(**dm_params) @@ -335,7 +333,6 @@ def main(config): middle_slice_index=middle_slice_index, ) - # Get dimensionality reduction parameters from config phate_kwargs = None pca_kwargs = None @@ -345,7 +342,6 @@ def main(config): if "pca_kwargs" in cfg["embedding"]: pca_kwargs = cfg["embedding"]["pca_kwargs"] - # Check if output path exists and should be overwritten if "output_path" not in cfg["paths"]: raise ValueError( "Output path is required in the configuration file (paths.output_path)" @@ -362,7 +358,6 @@ def main(config): logger.warning(f"Output path {output_path} already exists, will overwrite") overwrite = True - # Set up EmbeddingWriter callback embedding_writer = EmbeddingWriter( output_path=output_path, phate_kwargs=phate_kwargs, @@ -370,7 +365,6 @@ def main(config): overwrite=overwrite, ) - # Set up and run VisCy trainer logger.info("Setting up VisCy trainer") trainer = VisCyTrainer( accelerator="gpu" if torch.cuda.is_available() else "cpu", From 13fbd578cbf2fd379827d39e972b137b06acc230 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Tue, 9 Sep 2025 09:39:34 -0700 Subject: [PATCH 050/101] add smoothness (mean rand vs adj frame) to the csv --- .../evaluation/smoothness/compute_smoothness.py | 1 + viscy/representation/evaluation/smoothness.py | 10 +++------- 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/applications/contrastive_phenotyping/evaluation/smoothness/compute_smoothness.py b/applications/contrastive_phenotyping/evaluation/smoothness/compute_smoothness.py index 6653ff373..9c0bf5111 100644 --- a/applications/contrastive_phenotyping/evaluation/smoothness/compute_smoothness.py +++ b/applications/contrastive_phenotyping/evaluation/smoothness/compute_smoothness.py @@ -93,6 +93,7 @@ "random_frame_std": stats["random_frame_std"], "random_frame_median": stats["random_frame_median"], "random_frame_peak": stats["random_frame_peak"], + "smoothness_score": stats["smoothness_score"], "dynamic_range": stats["dynamic_range"] } # Create DataFrame with single row diff --git a/viscy/representation/evaluation/smoothness.py b/viscy/representation/evaluation/smoothness.py index 27174cd8c..605f17ce6 100644 --- a/viscy/representation/evaluation/smoothness.py +++ b/viscy/representation/evaluation/smoothness.py @@ -129,6 +129,7 @@ def compute_embeddings_smoothness( - random_frame_median: Median of random sampling 
dissimilarity - random_frame_peak: Peak of random sampling distribution - random_frame_distribution: Full distribution of random sampling dissimilarities + - smoothness_score: Ratio of mean adjacent-frame to mean random-pair dissimilarity (lower is smoother) - dynamic_range: Difference between random and adjacent peaks distributions: dict: Dictionary containing distributions including: - adjacent_frame_distribution: Full distribution of adjacent frame dissimilarities @@ -154,13 +155,6 @@ def compute_embeddings_smoothness( all_piecewise_distances = np.concatenate(piecewise_distance_per_track) - # p99_piece_wise_distance = np.array( - # [np.percentile(track, 99) for track in piecewise_distance_per_track] - # ) - # p1_percentile_piece_wise_distance = np.array( - # [np.percentile(track, 1) for track in piecewise_distance_per_track] - # ) - # Random sampling values in the distance matrix with same size as adjacent frame measurements n_samples = len(all_piecewise_distances) # Avoid sampling the diagonal elements @@ -178,6 +172,7 @@ # Compute the peaks of both distributions using KDE adjacent_peak = find_distribution_peak(all_piecewise_distances, method="kde_robust") random_peak = find_distribution_peak(sampled_values, method="kde_robust") + smoothness_score = np.mean(all_piecewise_distances) / np.mean(sampled_values) dynamic_range = random_peak - adjacent_peak stats = { @@ -193,6 +188,7 @@ "random_frame_median": float(np.median(sampled_values)), "random_frame_peak": float(random_peak), # "random_frame_distribution": sampled_values, + "smoothness_score": float(smoothness_score), "dynamic_range": float(dynamic_range), } distributions = { From 395ddc5088d003f1bfcd6d80f290b7e2430dba45 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Tue, 9 Sep 2025 15:26:37 -0700 Subject: [PATCH 051/101] archiving old beta vae code --- .../benchmarking/DynaCLR/BetaVAE/{ => archive}/config_betavae.yml | 0 .../DynaCLR/BetaVAE/{ => archive}/config_betavae_convnext.yml | 0 .../DynaCLR/BetaVAE/{ => archive}/debug_dimensions.py | 0 .../benchmarking/DynaCLR/BetaVAE/{ => archive}/debug_stem.py | 0 .../benchmarking/DynaCLR/BetaVAE/{ => archive}/test_run.py | 0 5 files changed, 0 insertions(+), 0 deletions(-) rename applications/benchmarking/DynaCLR/BetaVAE/{ => archive}/config_betavae.yml (100%) rename applications/benchmarking/DynaCLR/BetaVAE/{ => archive}/config_betavae_convnext.yml (100%) rename applications/benchmarking/DynaCLR/BetaVAE/{ => archive}/debug_dimensions.py (100%) rename applications/benchmarking/DynaCLR/BetaVAE/{ => archive}/debug_stem.py (100%) rename applications/benchmarking/DynaCLR/BetaVAE/{ => archive}/test_run.py (100%) diff --git a/applications/benchmarking/DynaCLR/BetaVAE/config_betavae.yml b/applications/benchmarking/DynaCLR/BetaVAE/archive/config_betavae.yml similarity index 100% rename from applications/benchmarking/DynaCLR/BetaVAE/config_betavae.yml rename to applications/benchmarking/DynaCLR/BetaVAE/archive/config_betavae.yml diff --git a/applications/benchmarking/DynaCLR/BetaVAE/config_betavae_convnext.yml b/applications/benchmarking/DynaCLR/BetaVAE/archive/config_betavae_convnext.yml similarity index 100% rename from applications/benchmarking/DynaCLR/BetaVAE/config_betavae_convnext.yml rename to applications/benchmarking/DynaCLR/BetaVAE/archive/config_betavae_convnext.yml diff --git a/applications/benchmarking/DynaCLR/BetaVAE/debug_dimensions.py b/applications/benchmarking/DynaCLR/BetaVAE/archive/debug_dimensions.py similarity index 100% rename from
applications/benchmarking/DynaCLR/BetaVAE/debug_dimensions.py rename to applications/benchmarking/DynaCLR/BetaVAE/archive/debug_dimensions.py diff --git a/applications/benchmarking/DynaCLR/BetaVAE/debug_stem.py b/applications/benchmarking/DynaCLR/BetaVAE/archive/debug_stem.py similarity index 100% rename from applications/benchmarking/DynaCLR/BetaVAE/debug_stem.py rename to applications/benchmarking/DynaCLR/BetaVAE/archive/debug_stem.py diff --git a/applications/benchmarking/DynaCLR/BetaVAE/test_run.py b/applications/benchmarking/DynaCLR/BetaVAE/archive/test_run.py similarity index 100% rename from applications/benchmarking/DynaCLR/BetaVAE/test_run.py rename to applications/benchmarking/DynaCLR/BetaVAE/archive/test_run.py From a73e7b679942cecdae80e5d824a71568855c09a3 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Tue, 16 Sep 2025 13:19:20 -0700 Subject: [PATCH 052/101] ruff --- viscy/transforms/_redef.py | 1 - 1 file changed, 1 deletion(-) diff --git a/viscy/transforms/_redef.py b/viscy/transforms/_redef.py index 394f30f0c..2e4363849 100644 --- a/viscy/transforms/_redef.py +++ b/viscy/transforms/_redef.py @@ -5,7 +5,6 @@ from monai.transforms import ( CenterSpatialCropd, Decollated, - NormalizeIntensityd, RandAdjustContrastd, RandAffined, RandFlipd, From 7fbf6c88fdfb9f3f6b159cb45f5b7de96f3166fe Mon Sep 17 00:00:00 2001 From: Ziwen Liu <67518483+ziw-liu@users.noreply.github.com> Date: Tue, 16 Sep 2025 15:15:11 -0700 Subject: [PATCH 053/101] fix format --- viscy/data/cell_division_triplet.py | 11 +- viscy/data/triplet.py | 14 +- viscy/representation/engine.py | 143 +++++++++++------- viscy/representation/evaluation/smoothness.py | 44 +++--- viscy/representation/vae.py | 50 +++--- viscy/transforms/__init__.py | 2 +- viscy/transforms/_redef.py | 2 +- 7 files changed, 156 insertions(+), 110 deletions(-) diff --git a/viscy/data/cell_division_triplet.py b/viscy/data/cell_division_triplet.py index f5a5180d1..86039ad37 100644 --- a/viscy/data/cell_division_triplet.py +++ b/viscy/data/cell_division_triplet.py @@ -22,11 +22,11 @@ class CellDivisionTripletDataset(Dataset): # Hardcoded channel mapping for .npy files CHANNEL_MAPPING = { # Channel 0 aliases (brightfield) - 'bf': 0, - 'brightfield': 0, + "bf": 0, + "brightfield": 0, # Channel 1 aliases (h2b) - 'h2b': 1, - 'nuclei': 1, + "h2b": 1, + "nuclei": 1, } def __init__( @@ -346,7 +346,7 @@ def _setup_fit(self, dataset_settings: dict): shuffled_indices = self._set_fit_global_state(len(self.npy_files)) npy_files = [self.npy_files[i] for i in shuffled_indices] - #Se the train an dval positions + # Se the train an dval positions num_train_files = int(len(self.npy_files) * self.split_ratio) train_npy_files = npy_files[:num_train_files] val_npy_files = npy_files[num_train_files:] @@ -354,7 +354,6 @@ def _setup_fit(self, dataset_settings: dict): _logger.debug(f"Number of training files: {len(train_npy_files)}") _logger.debug(f"Number of validation files: {len(val_npy_files)}") - # Determine anchor transform based on time interval anchor_transform = ( no_aug_transform diff --git a/viscy/data/triplet.py b/viscy/data/triplet.py index 28270af7a..f913ffb9e 100644 --- a/viscy/data/triplet.py +++ b/viscy/data/triplet.py @@ -511,10 +511,16 @@ def _setup_fit(self, dataset_settings: dict): ) # Choose transforms for validation based on augment_validation parameter - val_positive_transform = augment_transform if self.augment_validation else no_aug_transform - val_negative_transform = augment_transform if self.augment_validation else no_aug_transform - 
val_anchor_transform = anchor_transform if self.augment_validation else no_aug_transform - + val_positive_transform = ( + augment_transform if self.augment_validation else no_aug_transform + ) + val_negative_transform = ( + augment_transform if self.augment_validation else no_aug_transform + ) + val_anchor_transform = ( + anchor_transform if self.augment_validation else no_aug_transform + ) + self.val_dataset = TripletDataset( positions=val_positions, tracks_tables=val_tracks_tables, diff --git a/viscy/representation/engine.py b/viscy/representation/engine.py index 5d007ba4b..67b8c594b 100644 --- a/viscy/representation/engine.py +++ b/viscy/representation/engine.py @@ -22,6 +22,7 @@ "monai_beta": BetaVaeMonai, } + class ContrastivePrediction(TypedDict): features: Tensor projections: Tensor @@ -61,42 +62,44 @@ def __init__( def on_train_start(self) -> None: """Log comprehensive hyperparameters including model architecture details.""" super().on_train_start() - + # Collect comprehensive hyperparameters hparams = { # Training hyperparameters "lr": self.lr, "schedule": self.schedule, - "input_shape": self.example_input_array, + "input_shape": self.example_input_array, "loss_function_class": self.loss_function.__class__.__name__, } - + # Add loss function specific parameters - if hasattr(self.loss_function, 'margin'): + if hasattr(self.loss_function, "margin"): hparams["loss_margin"] = self.loss_function.margin - if hasattr(self.loss_function, 'temperature'): + if hasattr(self.loss_function, "temperature"): hparams["loss_temperature"] = self.loss_function.temperature - if hasattr(self.loss_function, 'normalize_embeddings'): - hparams["loss_normalize_embeddings"] = self.loss_function.normalize_embeddings - + if hasattr(self.loss_function, "normalize_embeddings"): + hparams["loss_normalize_embeddings"] = ( + self.loss_function.normalize_embeddings + ) + # Add encoder details if it's a ContrastiveEncoder - if hasattr(self.model, 'backbone'): + if hasattr(self.model, "backbone"): hparams["encoder_backbone"] = self.model.backbone - if hasattr(self.model, 'in_channels'): + if hasattr(self.model, "in_channels"): hparams["encoder_in_channels"] = self.model.in_channels - if hasattr(self.model, 'in_stack_depth'): + if hasattr(self.model, "in_stack_depth"): hparams["encoder_in_stack_depth"] = self.model.in_stack_depth - if hasattr(self.model, 'embedding_dim'): + if hasattr(self.model, "embedding_dim"): hparams["encoder_embedding_dim"] = self.model.embedding_dim - if hasattr(self.model, 'projection_dim'): + if hasattr(self.model, "projection_dim"): hparams["encoder_projection_dim"] = self.model.projection_dim - if hasattr(self.model, 'drop_path_rate'): + if hasattr(self.model, "drop_path_rate"): hparams["encoder_drop_path_rate"] = self.model.drop_path_rate - if hasattr(self.model, 'stem_kernel_size'): + if hasattr(self.model, "stem_kernel_size"): hparams["encoder_stem_kernel_size"] = str(self.model.stem_kernel_size) - if hasattr(self.model, 'stem_stride'): + if hasattr(self.model, "stem_stride"): hparams["encoder_stem_stride"] = str(self.model.stem_stride) - + # Log to TensorBoard if self.logger is not None: self.logger.log_hyperparams(hparams) @@ -267,11 +270,11 @@ def validation_step(self, batch: TripletSample, batch_idx: int) -> Tensor: def on_validation_epoch_end(self) -> None: super().on_validation_epoch_end() self._log_samples("val_samples", self.validation_step_outputs) - + # Log UMAP embeddings from validation set every N epochs if ( - self.log_embeddings - and self.current_epoch % 
self.embedding_log_frequency == 0 + self.log_embeddings + and self.current_epoch % self.embedding_log_frequency == 0 and self.current_epoch > 0 ): self._collect_and_log_embeddings() @@ -284,7 +287,9 @@ def _collect_and_log_embeddings(self): # Get validation dataloader val_dataloaders = self.trainer.val_dataloaders if val_dataloaders is None: - _logger.warning("No validation dataloader available for embedding logging") + _logger.warning( + "No validation dataloader available for embedding logging" + ) return elif isinstance(val_dataloaders, list): val_dataloader = val_dataloaders[0] if val_dataloaders else None @@ -292,32 +297,36 @@ def _collect_and_log_embeddings(self): val_dataloader = val_dataloaders if val_dataloader is None: - _logger.warning("No validation dataloader available for embedding logging") + _logger.warning( + "No validation dataloader available for embedding logging" + ) return - _logger.info(f"Collecting embeddings for visualization at epoch {self.current_epoch}") - + _logger.info( + f"Collecting embeddings for visualization at epoch {self.current_epoch}" + ) + # Collect embeddings, images, and metadata from validation set embeddings_list = [] images_list = [] labels_list = [] max_samples = 500 # Reduced for memory efficiency with images sample_count = 0 - + self.eval() with torch.no_grad(): for batch in val_dataloader: if sample_count >= max_samples: break - + # Move batch to device anchor = batch["anchor"].to(self.device) batch_size = anchor.size(0) - + # Get embeddings (features, not projections) features, _ = self(anchor) embeddings_list.append(features.cpu()) - + # Collect images for sprite visualization # Take middle slice for 3D data and first channel if multi-channel if anchor.ndim == 5: # (B, C, D, H, W) @@ -326,7 +335,7 @@ def _collect_and_log_embeddings(self): else: # (B, C, H, W) img_slice = anchor[:, 0].cpu() # (B, H, W) images_list.append(img_slice) - + # Collect labels from index information if "index" in batch and batch["index"] is not None: for i, idx_info in enumerate(batch["index"][:batch_size]): @@ -341,33 +350,35 @@ def _collect_and_log_embeddings(self): # Fallback labels for i in range(batch_size): labels_list.append(f"sample_{sample_count + i}") - + sample_count += batch_size - + if embeddings_list: embeddings = torch.cat(embeddings_list, dim=0)[:max_samples] images = torch.cat(images_list, dim=0)[:max_samples] labels = labels_list[:max_samples] - + # Normalize images for visualization (0-1 range) images = (images - images.min()) / (images.max() - images.min() + 1e-8) - - # Log UMAP visualization + + # Log UMAP visualization self.log_embedding_umap(embeddings, tag="validation") - + # Log to TensorBoard's embedding projector with images and labels self.logger.experiment.add_embedding( embeddings, metadata=labels, label_img=images.unsqueeze(1), # Add channel dimension global_step=self.current_epoch, - tag="validation_embeddings" + tag="validation_embeddings", + ) + + _logger.info( + f"Logged {len(embeddings)} embeddings with images and labels" ) - - _logger.info(f"Logged {len(embeddings)} embeddings with images and labels") else: _logger.warning("No embeddings collected from validation set") - + except Exception as e: _logger.error(f"Error collecting embeddings: {e}") @@ -386,10 +397,11 @@ def predict_step( "index": batch["index"], } + class BetaVaeModule(LightningModule): def __init__( self, - architecture: Literal["monai_beta","2.5D"], + architecture: Literal["monai_beta", "2.5D"], model_config: dict = {}, loss_function: nn.Module | nn.MSELoss = 
nn.MSELoss(reduction="sum"), beta: float = 1.0, @@ -408,7 +420,7 @@ def __init__( ): super().__init__() - net_class= _VAE_ARCHITECTURE.get(architecture) + net_class = _VAE_ARCHITECTURE.get(architecture) if not net_class: raise ValueError( f"Architecture {architecture} not in {_VAE_ARCHITECTURE.keys()}" @@ -426,14 +438,14 @@ def __init__( self.lr = lr self.lr_schedule = lr_schedule - + self.log_batches_per_epoch = log_batches_per_epoch self.log_samples_per_batch = log_samples_per_batch self.example_input_array = torch.rand(*example_input_array_shape) self.compute_disentanglement = compute_disentanglement self.disentanglement_frequency = disentanglement_frequency - + self.log_enhanced_visualizations = log_enhanced_visualizations self.log_enhanced_visualizations_frequency = ( log_enhanced_visualizations_frequency @@ -442,7 +454,7 @@ def __init__( self.validation_step_outputs = [] self._min_beta = 1e-15 - self._logvar_minmax = (-20,20) + self._logvar_minmax = (-20, 20) # Handle different parameter names for latent dimensions latent_dim = None @@ -450,11 +462,13 @@ def __init__( latent_dim = self.model_config["latent_dim"] elif "latent_size" in self.model_config: latent_dim = self.model_config["latent_size"] - + if latent_dim is not None: self.vae_logger = BetaVaeLogger(latent_dim=latent_dim) else: - _logger.warning("No latent dimension provided for BetaVaeLogger. Using default with 128 dimensions.") + _logger.warning( + "No latent dimension provided for BetaVaeLogger. Using default with 128 dimensions." + ) self.vae_logger = BetaVaeLogger() def setup(self, stage: str = None): @@ -508,39 +522,56 @@ def forward(self, x: Tensor) -> dict: """Forward pass through Beta-VAE.""" original_shape = x.shape - is_monai_2d = (self.architecture == "monai_beta" and - self.model_config.get("spatial_dims") == 2) + is_monai_2d = ( + self.architecture == "monai_beta" + and self.model_config.get("spatial_dims") == 2 + ) if is_monai_2d and len(x.shape) == 5 and x.shape[2] == 1: x = x.squeeze(2) - + # Handle different model output formats model_output = self.model(x) - + recon_x = model_output.recon_x mu = model_output.mean logvar = model_output.logvar z = model_output.z - + if is_monai_2d and len(original_shape) == 5 and original_shape[2] == 1: # Convert back (B, C, H, W) to (B, C, 1, H, W) recon_x = recon_x.unsqueeze(2) - current_beta = self._get_current_beta() batch_size = original_shape[0] # Use original input for loss computation to ensure shape consistency - x_original = x if not (is_monai_2d and len(original_shape) == 5 and original_shape[2] == 1) else x.unsqueeze(2) + x_original = ( + x + if not (is_monai_2d and len(original_shape) == 5 and original_shape[2] == 1) + else x.unsqueeze(2) + ) recon_loss = self.loss_function(recon_x, x_original) if isinstance(self.loss_function, nn.MSELoss): - if hasattr(self.loss_function, 'reduction') and self.loss_function.reduction == 'sum': + if ( + hasattr(self.loss_function, "reduction") + and self.loss_function.reduction == "sum" + ): recon_loss = recon_loss / batch_size - elif hasattr(self.loss_function, 'reduction') and self.loss_function.reduction == 'mean': + elif ( + hasattr(self.loss_function, "reduction") + and self.loss_function.reduction == "mean" + ): # Correct the over-normalization by PyTorch's mean reduction by multiplying by the number of elements per image num_elements_per_image = x_original[0].numel() recon_loss = recon_loss * num_elements_per_image - kl_loss = -0.5 * torch.sum(1 + torch.clamp(logvar,self._logvar_minmax[0],self._logvar_minmax[1]) - 
mu.pow(2) - logvar.exp(), dim=1) + kl_loss = -0.5 * torch.sum( + 1 + + torch.clamp(logvar, self._logvar_minmax[0], self._logvar_minmax[1]) + - mu.pow(2) + - logvar.exp(), + dim=1, + ) kl_loss = torch.mean(kl_loss) total_loss = recon_loss + current_beta * kl_loss diff --git a/viscy/representation/evaluation/smoothness.py b/viscy/representation/evaluation/smoothness.py index 605f17ce6..1abfcc00a 100644 --- a/viscy/representation/evaluation/smoothness.py +++ b/viscy/representation/evaluation/smoothness.py @@ -18,14 +18,17 @@ def compute_piece_wise_distance( - features_df: pd.DataFrame, cross_dist: NDArray, rank_fractions: NDArray,groupby:list[str] = ["fov_name", "track_id"] -)->tuple[list[list[float]], list[list[float]]]: + features_df: pd.DataFrame, + cross_dist: NDArray, + rank_fractions: NDArray, + groupby: list[str] = ["fov_name", "track_id"], +) -> tuple[list[list[float]], list[list[float]]]: """ Computing the piece-wise distance and rank difference - Get the off diagonal per block and compute the mode - The blocks are not square, so we need to get the off diagonal elements - Get the 1 and 99 percentile of the off diagonal per block - + Parameters ---------- features_df : pd.DataFrame @@ -62,9 +65,11 @@ def compute_piece_wise_distance( return piece_wise_dissimilarity_per_track, piece_wise_rank_difference_per_track -def find_distribution_peak(data: np.ndarray, method: Literal["histogram", "kde_robust"] = "kde_robust") -> float: - """ Find the peak of a distribution - +def find_distribution_peak( + data: np.ndarray, method: Literal["histogram", "kde_robust"] = "kde_robust" +) -> float: + """Find the peak of a distribution + Parameters ---------- data: np.ndarray @@ -76,18 +81,20 @@ def find_distribution_peak(data: np.ndarray, method: Literal["histogram", "kde_r ------- float: The peak of the distribution (highest peak if multiple) """ - if method == 'histogram': + if method == "histogram": # Simple histogram-based peak finding hist, bin_edges = np.histogram(data, bins=50, density=True) bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2 - peaks, properties = find_peaks(hist, height=np.max(hist) * 0.1) # 10% of max height + peaks, properties = find_peaks( + hist, height=np.max(hist) * 0.1 + ) # 10% of max height if len(peaks) == 0: return bin_centers[np.argmax(hist)] # Fallback to global max # Return peak with highest density - peak_heights = properties['peak_heights'] + peak_heights = properties["peak_heights"] return bin_centers[peaks[np.argmax(peak_heights)]] - elif method == 'kde_robust': + elif method == "kde_robust": # More robust KDE approach kde = gaussian_kde(data) x_range = np.linspace(np.min(data), np.max(data), 1000) @@ -96,9 +103,8 @@ def find_distribution_peak(data: np.ndarray, method: Literal["histogram", "kde_r if len(peaks) == 0: return x_range[np.argmax(kde_vals)] # Fallback to global max # Return peak with highest KDE value - peak_heights = properties['peak_heights'] + peak_heights = properties["peak_heights"] return x_range[peaks[np.argmax(peak_heights)]] - def compute_embeddings_smoothness( @@ -149,8 +155,8 @@ def compute_embeddings_smoothness( # Compute piece-wise distance and rank difference features_df = features["sample"].to_dataframe().reset_index(drop=True) - piecewise_distance_per_track, _ = ( - compute_piece_wise_distance(features_df, cross_dist, rank_fractions) + piecewise_distance_per_track, _ = compute_piece_wise_distance( + features_df, cross_dist, rank_fractions ) all_piecewise_distances = np.concatenate(piecewise_distance_per_track) @@ -161,11 +167,12 @@ 
def compute_embeddings_smoothness( np.random.seed(42) i_indices = np.random.randint(0, len(cross_dist), size=n_samples) j_indices = np.random.randint(0, len(cross_dist), size=n_samples) - + diagonal_mask = i_indices == j_indices while diagonal_mask.any(): - j_indices[diagonal_mask] = np.random.randint(0, len(cross_dist), - size=diagonal_mask.sum()) + j_indices[diagonal_mask] = np.random.randint( + 0, len(cross_dist), size=diagonal_mask.sum() + ) diagonal_mask = i_indices == j_indices sampled_values = cross_dist[i_indices, j_indices] @@ -189,7 +196,7 @@ def compute_embeddings_smoothness( "random_frame_peak": float(random_peak), # "random_frame_distribution": sampled_values, "smoothness_score": float(smoothness_score), - "dynamic_range": float(dynamic_range), + "dynamic_range": float(dynamic_range), } distributions = { "adjacent_frame_distribution": all_piecewise_distances, @@ -201,4 +208,3 @@ def compute_embeddings_smoothness( print(f"{key}: {value}") return stats, distributions, piecewise_distance_per_track - diff --git a/viscy/representation/vae.py b/viscy/representation/vae.py index 1e2d00ddf..a496ef569 100644 --- a/viscy/representation/vae.py +++ b/viscy/representation/vae.py @@ -108,7 +108,7 @@ def __init__( latent_dim: int = 1024, input_spatial_size: tuple[int, int] = (256, 256), stem_kernel_size: tuple[int, int, int] = (2, 4, 4), - stem_stride: tuple[int, int, int] = (2, 4, 4), + stem_stride: tuple[int, int, int] = (2, 4, 4), drop_path_rate: float = 0.0, pretrained: bool = True, ): @@ -134,7 +134,9 @@ def __init__( encoder.conv1 = nn.Identity() out_channels_encoder = num_channels[-1] else: - raise ValueError(f"Backbone {backbone} not supported. Use 'resnet50', 'convnext_tiny', or 'convnextv2_tiny'") + raise ValueError( + f"Backbone {backbone} not supported. 
Use 'resnet50', 'convnext_tiny', or 'convnextv2_tiny'" + ) # Stem for 3d multichannel and to convert 3D to 2D self.stem = StemDepthtoChannels( @@ -148,22 +150,24 @@ def __init__( self.num_channels = num_channels self.in_channels_encoder = in_channels_encoder self.out_channels_encoder = out_channels_encoder - + # Calculate spatial size after stem stem_spatial_size_h = input_spatial_size[0] // stem_stride[1] stem_spatial_size_w = input_spatial_size[1] // stem_stride[2] - + # Spatial size after backbone backbone_reduction = 2 ** (len(num_channels) - 1) final_spatial_size_h = stem_spatial_size_h // backbone_reduction final_spatial_size_w = stem_spatial_size_w // backbone_reduction - - flattened_size = out_channels_encoder * final_spatial_size_h * final_spatial_size_w + + flattened_size = ( + out_channels_encoder * final_spatial_size_h * final_spatial_size_w + ) self.fc = nn.Linear(flattened_size, latent_dim) self.fc_mu = nn.Linear(latent_dim, latent_dim) self.fc_logvar = nn.Linear(latent_dim, latent_dim) - + # Store final spatial size for decoder (assuming square for simplicity) self.encoder_spatial_size = final_spatial_size_h # Assuming square output @@ -189,9 +193,9 @@ def forward(self, x: Tensor) -> SimpleNamespace: x_intermediate = self.fc(x_flat) - mu = self.fc_mu(x_intermediate) - logvar = self.fc_logvar(x_intermediate) - z = self.reparameterize(mu, logvar) + mu = self.fc_mu(x_intermediate) + logvar = self.fc_logvar(x_intermediate) + z = self.reparameterize(mu, logvar) return SimpleNamespace(mean=mu, log_covariance=logvar, z=z) @@ -207,7 +211,7 @@ def __init__( out_stack_depth: int = 16, head_expansion_ratio: int = 2, strides: list[int] = [2, 2, 2, 1], - encoder_spatial_size: int=16, + encoder_spatial_size: int = 16, head_pool: bool = False, upsample_mode: Literal["deconv", "pixelshuffle"] = "pixelshuffle", conv_blocks: int = 2, @@ -221,7 +225,6 @@ def __init__( self.out_stack_depth = out_stack_depth self.decoder_channels = decoder_channels - self.spatial_size = encoder_spatial_size self.spatial_channels = latent_dim // (self.spatial_size * self.spatial_size) @@ -303,7 +306,7 @@ def __init__( upsample_pre_conv: Literal["default"] | Callable | None = None, ): super().__init__() - + self.encoder = VaeEncoder( backbone=backbone, in_channels=in_channels, @@ -320,8 +323,8 @@ def __init__( decoder_channels[-1] = ( (out_stack_depth + 2) * in_channels * 2**2 * head_expansion_ratio ) - - strides = [2] * (len(decoder_channels) - 1) + [1] + + strides = [2] * (len(decoder_channels) - 1) + [1] self.decoder = VaeDecoder( decoder_channels=decoder_channels, @@ -342,19 +345,20 @@ def forward(self, x: Tensor) -> SimpleNamespace: """Forward pass returning VAE outputs.""" encoder_output = self.encoder(x) recon_x = self.decoder(encoder_output.z) - + return SimpleNamespace( recon_x=recon_x, mean=encoder_output.mean, logvar=encoder_output.log_covariance, - z=encoder_output.z + z=encoder_output.z, ) class BetaVaeMonai(nn.Module): """Beta-VAE with Monai architecture.""" - def __init__(self, + def __init__( + self, spatial_dims: int, in_shape: Sequence[int], out_channels: int, @@ -366,8 +370,8 @@ def __init__(self, num_res_units: int = 0, use_sigmoid: bool = False, norm: Literal[Norm.BATCH, Norm.INSTANCE] = Norm.INSTANCE, - **kwargs - ): + **kwargs, + ): super().__init__() self.spatial_dims = spatial_dims @@ -381,7 +385,7 @@ def __init__(self, self.num_res_units = num_res_units self.use_sigmoid = use_sigmoid self.norm = norm - + self.model = VarAutoEncoder( spatial_dims=self.spatial_dims, in_shape=self.in_shape, 
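As a usage sketch for this wrapper, with shapes taken from the archived Beta-VAE config deleted later in this series (keyword names follow the constructor above, so treat them as assumptions rather than a stable API):

import torch

from viscy.representation.vae import BetaVaeMonai

vae = BetaVaeMonai(
    spatial_dims=3,
    in_shape=(2, 16, 192, 192),  # (C, D, H, W) per the archived config
    out_channels=2,
    latent_size=1024,
    channels=(64, 128, 256, 512),
    strides=((2, 2, 2), (2, 2, 2), (2, 2, 2), (1, 2, 2)),
)
x = torch.rand(1, 2, 16, 192, 192)
out = vae(x)  # SimpleNamespace with recon_x, mean, logvar, and z fields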
@@ -394,10 +398,10 @@ def __init__(self, num_res_units=self.num_res_units, use_sigmoid=self.use_sigmoid, norm=self.norm, - **kwargs + **kwargs, ) def forward(self, x: Tensor) -> SimpleNamespace: """Forward pass returning VAE encoder outputs.""" recon_x, mu, logvar, z = self.model(x) - return SimpleNamespace(recon_x=recon_x, mean=mu, logvar=logvar, z=z) \ No newline at end of file + return SimpleNamespace(recon_x=recon_x, mean=mu, logvar=logvar, z=z) diff --git a/viscy/transforms/__init__.py b/viscy/transforms/__init__.py index cbabd39ec..abe088876 100644 --- a/viscy/transforms/__init__.py +++ b/viscy/transforms/__init__.py @@ -46,4 +46,4 @@ "StackChannelsd", "TiledSpatialCropSamplesd", "ToDeviced", -] \ No newline at end of file +] diff --git a/viscy/transforms/_redef.py b/viscy/transforms/_redef.py index 2e4363849..696c81abc 100644 --- a/viscy/transforms/_redef.py +++ b/viscy/transforms/_redef.py @@ -196,4 +196,4 @@ def __init__( spatial_axis: Sequence[int] | int, **kwargs, ): - super().__init__(keys=keys, prob=prob, spatial_axis=spatial_axis, **kwargs) \ No newline at end of file + super().__init__(keys=keys, prob=prob, spatial_axis=spatial_axis, **kwargs) From a6c2c69c5586d448e5076f9cb4b4c62d86fa238d Mon Sep 17 00:00:00 2001 From: Ziwen Liu <67518483+ziw-liu@users.noreply.github.com> Date: Wed, 17 Sep 2025 11:32:26 -0700 Subject: [PATCH 054/101] fix typo --- viscy/data/cell_division_triplet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/viscy/data/cell_division_triplet.py b/viscy/data/cell_division_triplet.py index 86039ad37..90ad2910f 100644 --- a/viscy/data/cell_division_triplet.py +++ b/viscy/data/cell_division_triplet.py @@ -346,7 +346,7 @@ def _setup_fit(self, dataset_settings: dict): shuffled_indices = self._set_fit_global_state(len(self.npy_files)) npy_files = [self.npy_files[i] for i in shuffled_indices] - # Se the train an dval positions + # Set the train and eval positions num_train_files = int(len(self.npy_files) * self.split_ratio) train_npy_files = npy_files[:num_train_files] val_npy_files = npy_files[num_train_files:] From 9ed78b4c63f929a0401ca0b3f897c51d5551eff Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Wed, 17 Sep 2025 11:47:48 -0700 Subject: [PATCH 055/101] remove the archived unnecessary files --- .../BetaVAE/archive/config_betavae.yml | 130 -------- .../archive/config_betavae_convnext.yml | 146 --------- .../BetaVAE/archive/debug_dimensions.py | 278 ------------------ .../DynaCLR/BetaVAE/archive/debug_stem.py | 39 --- 4 files changed, 593 deletions(-) delete mode 100644 applications/benchmarking/DynaCLR/BetaVAE/archive/config_betavae.yml delete mode 100644 applications/benchmarking/DynaCLR/BetaVAE/archive/config_betavae_convnext.yml delete mode 100644 applications/benchmarking/DynaCLR/BetaVAE/archive/debug_dimensions.py delete mode 100644 applications/benchmarking/DynaCLR/BetaVAE/archive/debug_stem.py diff --git a/applications/benchmarking/DynaCLR/BetaVAE/archive/config_betavae.yml b/applications/benchmarking/DynaCLR/BetaVAE/archive/config_betavae.yml deleted file mode 100644 index 93784faf1..000000000 --- a/applications/benchmarking/DynaCLR/BetaVAE/archive/config_betavae.yml +++ /dev/null @@ -1,130 +0,0 @@ -seed_everything: 42 -trainer: - accelerator: gpu - devices: 1 - num_nodes: 1 - strategy: auto - precision: 16-mixed - max_epochs: 200 - log_every_n_steps: 10 - check_val_every_n_epoch: 1 - logger: - class_path: lightning.pytorch.loggers.TensorBoardLogger - init_args: - save_dir:
"/hpc/projects/organelle_phenotyping/models/SEC61B/vae" - version: "sensor_phase3d_zikv_denv_lr2e-4_beta1.5" - log_graph: false - callbacks: - - class_path: lightning.pytorch.callbacks.LearningRateMonitor - init_args: - logging_interval: step - - class_path: lightning.pytorch.callbacks.ModelCheckpoint - init_args: - monitor: "loss/total/val" - save_top_k: 5 - save_last: true - every_n_epochs: 1 - fast_dev_run: true - enable_checkpointing: true - # inference_mode: true - use_distributed_sampler: true - -model: - class_path: viscy.representation.engine.BetaVaeModule - init_args: - architecture: "monai_beta" - model_config: - spatial_dims: 3 - in_shape: [2, 16, 192, 192] - out_channels: 2 - latent_size: 1024 - channels: [64, 128, 256, 512] - strides: [[2, 2, 2], [2, 2, 2], [2, 2, 2], [1, 2, 2]] - beta: 1.0 # Conservative target - can increase later - beta_schedule: cosine - beta_min: 0.1 # Start low to learn reconstructions first - beta_warmup_epochs: 50 # Half of training for gradual ramp - lr: 0.0002 - example_input_array_shape: [1, 2, 16, 192, 192] - loss_function: - class_path: torch.nn.MSELoss - init_args: {reduction: 'mean'} -data: - class_path: viscy.data.triplet.TripletDataModule - init_args: - data_path: "/hpc/projects/organelle_phenotyping/datasets/organelle/SEC61B/2024_10_16_A549_SEC61_ZIKV_DENV/2024_10_16_A549_SEC61_ZIKV_DENV_2.zarr" - tracks_path: "/hpc/projects/intracellular_dashboard/organelle_dynamics/rerun/2024_10_16_A549_SEC61_ZIKV_DENV/1-preprocess/label-free/3-track/2024_10_16_A549_SEC61_ZIKV_DENV_cropped.zarr" - source_channel: - - &phase Phase3D - - &mcherry raw mCherry EX561 EM600-37 - z_range: [10, 26] - initial_yx_patch_size: [384, 384] - final_yx_patch_size: [192, 192] - batch_size: 64 - num_workers: 12 - time_interval: 1 - augment_validation: false - return_negative: false - fit_include_wells: ["B/3", "B/4", "C/3", "C/4"] - augmentations: - - class_path: viscy.transforms.RandAffined - init_args: - keys: [*phase, *mcherry] - prob: 0.8 - scale_range: [0, 0.2, 0.2] - rotate_range: [3.14, 0.0, 0.0] - shear_range: [0.0, 0.01, 0.01] - padding_mode: zeros - - class_path: viscy.transforms.RandAdjustContrastd - init_args: - keys: [*mcherry] - prob: 0.5 - gamma: [0.8, 1.2] - - class_path: viscy.transforms.RandAdjustContrastd - init_args: - keys: [*phase] - prob: 0.5 - gamma: [0.8, 1.2] - - class_path: viscy.transforms.RandScaleIntensityd - init_args: - keys: [*mcherry] - prob: 0.5 - factors: 0.5 - - class_path: viscy.transforms.RandScaleIntensityd - init_args: - keys: [*phase] - prob: 0.5 - factors: 0.5 - - class_path: viscy.transforms.RandGaussianSmoothd - init_args: - keys: [*phase, *mcherry] - prob: 0.5 - sigma_x: [0.25, 0.75] - sigma_y: [0.25, 0.75] - sigma_z: [0.0, 0.0] - - class_path: viscy.transforms.RandGaussianNoised - init_args: - keys: [*mcherry] - prob: 0.5 - mean: 0.0 - std: 0.2 - - class_path: viscy.transforms.RandGaussianNoised - init_args: - keys: [*phase] - prob: 0.5 - mean: 0.0 - std: 0.2 - normalizations: - - class_path: viscy.transforms.NormalizeSampled - init_args: - keys: [*phase] - level: fov_statistics - subtrahend: mean - divisor: std - - class_path: viscy.transforms.ScaleIntensityRangePercentilesd - init_args: - keys: [*mcherry] - lower: 50 - upper: 99 - b_min: 0.0 - b_max: 1.0 \ No newline at end of file diff --git a/applications/benchmarking/DynaCLR/BetaVAE/archive/config_betavae_convnext.yml b/applications/benchmarking/DynaCLR/BetaVAE/archive/config_betavae_convnext.yml deleted file mode 100644 index c586855e0..000000000 --- 
a/applications/benchmarking/DynaCLR/BetaVAE/archive/config_betavae_convnext.yml +++ /dev/null @@ -1,146 +0,0 @@ -seed_everything: 42 -trainer: - accelerator: gpu - devices: 1 - num_nodes: 1 - strategy: auto - precision: 16-mixed - max_epochs: 200 - log_every_n_steps: 10 - check_val_every_n_epoch: 1 - logger: - class_path: lightning.pytorch.loggers.TensorBoardLogger - init_args: - save_dir: "/hpc/projects/organelle_phenotyping/models/SEC61B/vae" - version: "sensor_phase3d_zikv_denv_lr2e-4_beta1.5" - log_graph: false - callbacks: - - class_path: lightning.pytorch.callbacks.LearningRateMonitor - init_args: - logging_interval: step - - class_path: lightning.pytorch.callbacks.ModelCheckpoint - init_args: - monitor: "loss/total/val" - save_top_k: 5 - save_last: true - every_n_epochs: 1 - fast_dev_run: true - enable_checkpointing: true - # inference_mode: true - use_distributed_sampler: true - -model: - class_path: viscy.representation.engine.BetaVaeModule - init_args: - architecture: "2.5D" - model_config: - backbone: convnext_tiny - in_channels: 2 - in_stack_depth: 16 - out_stack_depth: 16 - latent_dim: 1024 - input_spatial_size: [192, 192] - stem_kernel_size: [4, 2, 2] - stem_stride: [4, 2, 2] - decoder_stages: 4 - head_expansion_ratio: 2 - head_pool: false - upsample_mode: pixelshuffle - conv_blocks: 2 - norm_name: batch - beta: 1.0 # Conservative target - can increase later - beta_schedule: cosine - beta_min: 0.1 # Start low to learn reconstructions first - beta_warmup_epochs: 50 # Half of training for gradual ramp - lr: 0.0002 - example_input_array_shape: [1, 2, 16, 192, 192] - loss_function: - class_path: torch.nn.MSELoss - init_args: - reduction: mean - log_batches_per_epoch: 8 - log_samples_per_batch: 1 - compute_disentanglement: false - disentanglement_frequency: 10 - log_enhanced_visualizations: false - log_enhanced_visualizations_frequency: 30 - -data: - class_path: viscy.data.triplet.TripletDataModule - init_args: - data_path: "/hpc/projects/organelle_phenotyping/datasets/organelle/SEC61B/2024_10_16_A549_SEC61_ZIKV_DENV/2024_10_16_A549_SEC61_ZIKV_DENV.zarr" - tracks_path: "/hpc/projects/intracellular_dashboard/organelle_dynamics/rerun/2024_10_16_A549_SEC61_ZIKV_DENV/1-preprocess/label-free/3-track/2024_10_16_A549_SEC61_ZIKV_DENV_cropped.zarr" - source_channel: - - &phase Phase3D - - &mcherry raw mCherry EX561 EM600-37 - z_range: [10, 26] - initial_yx_patch_size: [384, 384] - final_yx_patch_size: [192, 192] - batch_size: 64 - num_workers: 12 - time_interval: 1 - augment_validation: false - return_negative: false - fit_include_wells: ["B/3", "B/4", "C/3", "C/4"] - augmentations: - - class_path: viscy.transforms.RandAffined - init_args: - keys: [*phase, *mcherry] - prob: 0.8 - scale_range: [0, 0.2, 0.2] - rotate_range: [3.14, 0.0, 0.0] - shear_range: [0.0, 0.01, 0.01] - padding_mode: zeros - - class_path: viscy.transforms.RandAdjustContrastd - init_args: - keys: [*mcherry] - prob: 0.5 - gamma: [0.8, 1.2] - - class_path: viscy.transforms.RandAdjustContrastd - init_args: - keys: [*phase] - prob: 0.5 - gamma: [0.8, 1.2] - - class_path: viscy.transforms.RandScaleIntensityd - init_args: - keys: [*mcherry] - prob: 0.5 - factors: 0.5 - - class_path: viscy.transforms.RandScaleIntensityd - init_args: - keys: [*phase] - prob: 0.5 - factors: 0.5 - - class_path: viscy.transforms.RandGaussianSmoothd - init_args: - keys: [*phase, *mcherry] - prob: 0.5 - sigma_x: [0.25, 0.75] - sigma_y: [0.25, 0.75] - sigma_z: [0.0, 0.0] - - class_path: viscy.transforms.RandGaussianNoised - init_args: - keys: 
[*mcherry] - prob: 0.5 - mean: 0.0 - std: 0.2 - - class_path: viscy.transforms.RandGaussianNoised - init_args: - keys: [*phase] - prob: 0.5 - mean: 0.0 - std: 0.2 - normalizations: - - class_path: viscy.transforms.NormalizeSampled - init_args: - keys: [*phase] - level: fov_statistics - subtrahend: mean - divisor: std - - class_path: viscy.transforms.ScaleIntensityRangePercentilesd - init_args: - keys: [*mcherry] - lower: 50 - upper: 99 - b_min: 0.0 - b_max: 1.0 \ No newline at end of file diff --git a/applications/benchmarking/DynaCLR/BetaVAE/archive/debug_dimensions.py b/applications/benchmarking/DynaCLR/BetaVAE/archive/debug_dimensions.py deleted file mode 100644 index 742ad5719..000000000 --- a/applications/benchmarking/DynaCLR/BetaVAE/archive/debug_dimensions.py +++ /dev/null @@ -1,278 +0,0 @@ -# %% -import torch -from viscy.representation.vae import VaeEncoder, VaeDecoder - - -def debug_vae_dimensions(): - """Debug VAE encoder/decoder dimension compatibility.""" - - print("=== VAE Dimension Debugging (Updated Architecture) ===\n") - - # Configuration matching current config - z_stack_depth = 8 - input_shape = (1, 1, z_stack_depth, 128, 128) # 1 channel to match config - latent_dim = 1024 # Updated to new default - - print(f"Input shape: {input_shape}") - print(f"Expected latent dim: {latent_dim}") - print() - - # Debug encoder channel expectations - import timm - - debug_encoder = timm.create_model("resnet50", pretrained=False, features_only=True) - print(f"ResNet50 conv1.out_channels: {debug_encoder.conv1.out_channels}") - print(f"ResNet50 expects input channels: {debug_encoder.conv1.in_channels}") - print() - - # Create encoder - encoder = VaeEncoder( - backbone="resnet50", - in_channels=1, - in_stack_depth=z_stack_depth, - latent_dim=latent_dim, - input_spatial_size=(128, 128), # Match the actual input size - stem_kernel_size=(4, 2, 2), - stem_stride=(4, 2, 2), - ) - - # Create decoder - decoder = VaeDecoder( - decoder_channels=[1024, 512, 256, 128], - latent_dim=latent_dim, - out_channels=1, - out_stack_depth=z_stack_depth, - head_expansion_ratio=2, - head_pool=False, - upsample_mode="pixelshuffle", - conv_blocks=2, - norm_name="batch", - strides=[2, 2, 2, 1], - input_spatial_size=( - 128, - 128, - ), # Add input spatial size for correct spatial_size calculation - ) - - print("=== ENCODER FORWARD PASS ===") - - # Test encoder - x = torch.randn(*input_shape) - print(f"Input to encoder: {x.shape}") - - try: - # Step through encoder - print("\\n1. Stem processing:") - x_stem = encoder.stem(x) - print(f" After stem: {x_stem.shape}") - - print("\\n2. Backbone processing:") - features = encoder.encoder(x_stem) - for i, feat in enumerate(features): - print(f" Feature {i}: {feat.shape}") - - print("\\n3. Final processing:") - x_final = features[-1] - print(f" Final features: {x_final.shape}") - - # Flatten spatial dimensions (new approach) - x_flat = x_final.flatten(1) # Use flatten(1) like updated code - print(f" After flatten: {x_flat.shape}") - - print("\\n3b. 
Intermediate FC layer:") - # Test intermediate FC layer (new addition) - x_intermediate = encoder.fc(x_flat) - print(f" After intermediate FC: {x_intermediate.shape}") - - # Full encoder output - encoder_output = encoder(x) - mu = encoder_output.mean - logvar = encoder_output.log_covariance - z = encoder_output.z - print(f" Final mu: {mu.shape}") - print(f" Final logvar: {logvar.shape}") - print(f" Sampled z: {z.shape}") - - print("\\n=== DECODER FORWARD PASS ===") - - print(f"Input to decoder (sampled z): {z.shape}") - - print("\\n1. Reshape to spatial:") - batch_size = z.size(0) - z_spatial = decoder.latent_reshape(z) - print(f" After linear reshape: {z_spatial.shape}") - - z_spatial_reshaped = z_spatial.view( - batch_size, - decoder.spatial_channels, - decoder.spatial_size, - decoder.spatial_size, - ) - print(f" After view to spatial: {z_spatial_reshaped.shape}") - - print("\\n2. Latent projection:") - x_proj = decoder.latent_proj(z_spatial_reshaped) - print(f" After conv projection: {x_proj.shape}") - - print("\\n3. Decoder stages:") - x_current = x_proj - for i, stage in enumerate(decoder.decoder_stages): - x_current = stage(x_current) - print(f" After stage {i}: {x_current.shape}") - - print("\\n4. Head processing:") - final_output = decoder.head(x_current) - print(f" Final output: {final_output.shape}") - - # Full decoder output (now returns tensor directly, not dict) - reconstruction = decoder(z) - print(f" Full reconstruction: {reconstruction.shape}") - - print("\\n=== DIMENSION ANALYSIS ===") - print(f"✓ Encoder input: {input_shape}") - print(f"✓ Encoder output: {mu.shape}") - print(f"✓ Decoder input: {z.shape}") - print(f"✓ Decoder output: {reconstruction.shape}") - - # Calculate tensor sizes and compression ratio - input_size = torch.numel(x) - latent_size = torch.numel(mu) - recon_size = torch.numel(reconstruction) - - print(f" Input tensor size: {input_size:,}") - print(f" Latent tensor size: {latent_size:,}") - print(f" Reconstruction tensor size: {recon_size:,}") - print(f" Compression ratio: {input_size / latent_size:.1f}:1") - print(f" Size ratio (recon/input): {recon_size / input_size:.2f}") - - # Check if reconstruction matches input - if reconstruction.shape == x.shape: - print("✓ SUCCESS: Reconstruction shape matches input shape!") - else: - print(f"✗ ERROR: Shape mismatch!") - print(f" Input: {x.shape}") - print(f" Reconstruction: {reconstruction.shape}") - - # Analyze each dimension - for i, (inp_dim, recon_dim) in enumerate( - zip(x.shape, reconstruction.shape) - ): - if inp_dim != recon_dim: - print( - f" Dimension {i}: {inp_dim} → {recon_dim} (factor: {recon_dim/inp_dim:.2f})" - ) - - print("\\n=== VAE LOSS COMPUTATION TEST ===") - - # Simulate full VAE forward pass with loss computation - print("Testing full VAE forward pass with loss computation...") - - # Sample from latent distribution (reparameterization trick) - eps = torch.randn_like(mu) - z_sampled = mu + torch.exp(0.5 * logvar) * eps - print(f"Sampled latent z: {z_sampled.shape}") - - # Use the z from encoder (already sampled) - reconstruction_from_sampled = decoder(z) - print( - f"Reconstruction from encoder's sampled z: {reconstruction_from_sampled.shape}" - ) - - # Compute VAE losses - import torch.nn.functional as F - - # Reconstruction loss (MSE) - recon_loss = F.mse_loss(reconstruction_from_sampled, x, reduction="mean") - print(f"Reconstruction loss (MSE): {recon_loss.item():.6e}") - - # KL divergence loss - kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / x.size(0) - print(f"KL 
divergence loss: {kl_loss.item():.6e}") - - # Total VAE loss with different beta values - betas = [0.1, 1.0, 1.5, 4.0] - for beta in betas: - total_loss = recon_loss + beta * kl_loss - print(f"Total loss (β={beta}): {total_loss.item():.6e}") - - # Check for problematic values - print("\\n=== LOSS HEALTH CHECK ===") - - if torch.isnan(recon_loss): - print("✗ CRITICAL: Reconstruction loss is NaN!") - elif torch.isinf(recon_loss): - print("✗ CRITICAL: Reconstruction loss is Inf!") - elif recon_loss.item() > 1e6: - print(f"⚠ WARNING: Very high reconstruction loss: {recon_loss.item():.2e}") - elif recon_loss.item() < 1e-10: - print(f"⚠ WARNING: Very low reconstruction loss: {recon_loss.item():.2e}") - else: - print(f"✓ Reconstruction loss looks reasonable: {recon_loss.item():.6f}") - - if torch.isnan(kl_loss): - print("✗ CRITICAL: KL loss is NaN!") - elif torch.isinf(kl_loss): - print("✗ CRITICAL: KL loss is Inf!") - else: - print(f"✓ KL loss looks reasonable: {kl_loss.item():.6f}") - - # Check reconstruction value ranges - recon_min, recon_max = ( - reconstruction_from_sampled.min(), - reconstruction_from_sampled.max(), - ) - input_min, input_max = x.min(), x.max() - - print(f"\\nValue ranges:") - print(f" Input range: [{input_min.item():.3f}, {input_max.item():.3f}]") - print( - f" Reconstruction range: [{recon_min.item():.3f}, {recon_max.item():.3f}]" - ) - - if recon_max.item() > 100 or recon_min.item() < -100: - print( - "⚠ WARNING: Reconstruction values are very large - possible gradient explosion" - ) - - # Check latent statistics - mu_mean, mu_std = mu.mean(), mu.std() - logvar_mean, logvar_std = logvar.mean(), logvar.std() - - print(f"\\nLatent statistics:") - print(f" μ mean/std: {mu_mean.item():.3f} / {mu_std.item():.3f}") - print(f" log(σ²) mean/std: {logvar_mean.item():.3f} / {logvar_std.item():.3f}") - - if mu_std.item() > 10: - print("⚠ WARNING: μ has very high variance - possible gradient explosion") - if logvar_mean.item() > 10: - print("⚠ WARNING: log(σ²) is very large - possible numerical instability") - - except Exception as e: - print(f"✗ ERROR during forward pass: {e}") - print(f"Error type: {type(e).__name__}") - import traceback - - traceback.print_exc() - - # Check flattened feature size for new architecture - print("\\n=== ENCODER FLATTENED SIZE ANALYSIS ===") - try: - x_stem = encoder.stem(x) - features = encoder.encoder(x_stem) - final_feat = features[-1] - print(f"Final feature shape: {final_feat.shape}") - - flattened_size = final_feat.flatten(1).shape[1] - print(f"Flattened size: {flattened_size:,}") - print(f"Expected latent dim: {latent_dim:,}") - - compression_ratio = flattened_size / latent_dim - print(f"Compression ratio: {compression_ratio:.1f}:1") - - except Exception as inner_e: - print(f"Error in flattened size analysis: {inner_e}") - - -if __name__ == "__main__": - debug_vae_dimensions() -# %% diff --git a/applications/benchmarking/DynaCLR/BetaVAE/archive/debug_stem.py b/applications/benchmarking/DynaCLR/BetaVAE/archive/debug_stem.py deleted file mode 100644 index 9b52fab5b..000000000 --- a/applications/benchmarking/DynaCLR/BetaVAE/archive/debug_stem.py +++ /dev/null @@ -1,39 +0,0 @@ -#!/usr/bin/env python3 - -import torch -from viscy.representation.vae import VaeEncoder - -# Test the stem layer computation -z_stack_depth = 32 -encoder = VaeEncoder( - backbone="resnet50", - in_channels=1, - in_stack_depth=z_stack_depth, - embedding_dim=256, - stem_kernel_size=(8, 4, 4), - stem_stride=(8, 4, 4), -) - -# Create test input -x = torch.randn(1, 1, 
z_stack_depth, 192, 192) -print(f"Input shape: {x.shape}") - -# Test stem output -stem_output = encoder.stem(x) -print(f"Stem output shape: {stem_output.shape}") - -# Check what the ResNet expects -import timm -resnet50 = timm.create_model("resnet50", pretrained=True, features_only=True) -print(f"ResNet50 conv1 expects input channels: {resnet50.conv1.in_channels}") -print(f"ResNet50 conv1 produces output channels: {resnet50.conv1.out_channels}") - -# Test if we can pass stem output to ResNet -try: - # Remove conv1 like in the encoder - resnet50.conv1 = torch.nn.Identity() - resnet_output = resnet50(stem_output) - print(f"ResNet output shapes: {[f.shape for f in resnet_output]}") - print("SUCCESS: No channel mismatch!") -except Exception as e: - print(f"ERROR: {e}") \ No newline at end of file From 6bf7a4a7b4e31cddc9fbba89323c6687faec9937 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Wed, 17 Sep 2025 11:48:47 -0700 Subject: [PATCH 056/101] remove the test run archived file --- .../DynaCLR/BetaVAE/archive/test_run.py | 222 ------------------ 1 file changed, 222 deletions(-) delete mode 100644 applications/benchmarking/DynaCLR/BetaVAE/archive/test_run.py diff --git a/applications/benchmarking/DynaCLR/BetaVAE/archive/test_run.py b/applications/benchmarking/DynaCLR/BetaVAE/archive/test_run.py deleted file mode 100644 index d1c0290b1..000000000 --- a/applications/benchmarking/DynaCLR/BetaVAE/archive/test_run.py +++ /dev/null @@ -1,222 +0,0 @@ -# %% -import torch -from lightning.pytorch import seed_everything -from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint -from lightning.pytorch.loggers import TensorBoardLogger -from monai.transforms.intensity.dictionary import ( - RandAdjustContrastd, - RandGaussianNoised, - RandGaussianSmoothd, - RandScaleIntensityd, - ScaleIntensityRangePercentilesd, -) -from monai.transforms.spatial.dictionary import RandAffined - -from viscy.data.triplet import TripletDataModule -from viscy.representation.engine import VaeModule -from viscy.representation.vae import VaeDecoder, VaeEncoder -from viscy.trainer import VisCyTrainer -from viscy.transforms import ( - NormalizeSampled, -) - - -# %% -def channel_augmentations(processing_channel: str): - return [ - RandAffined( - keys=[processing_channel], - prob=0.8, - scale_range=[0, 0.2, 0.2], - rotate_range=[3.14, 0.0, 0.0], - shear_range=[0.0, 0.01, 0.01], - padding_mode="zeros", - ), - RandAdjustContrastd( - keys=[processing_channel], - prob=0.5, - gamma=(0.8, 1.2), - ), - RandScaleIntensityd( - keys=[processing_channel], - prob=0.5, - factors=0.5, - ), - RandGaussianSmoothd( - keys=[processing_channel], - prob=0.5, - sigma_x=(0.25, 0.75), - sigma_y=(0.25, 0.75), - sigma_z=(0.0, 0.0), - ), - RandGaussianNoised( - keys=[processing_channel], - prob=0.5, - mean=0.0, - std=0.2, - ), - ] - - -# %% -def channel_normalization( - phase_channel: str | None = None, - fl_channel: str | None = None, -): - if phase_channel: - return [ - NormalizeSampled( - keys=[phase_channel], - level="fov_statistics", - subtrahend="mean", - divisor="std", - ) - ] - elif fl_channel: - return [ - ScaleIntensityRangePercentilesd( - keys=[fl_channel], - lower=50, - upper=99, - b_min=0.0, - b_max=1.0, - ) - ] - else: - raise NotImplementedError("Either phase_channel or fl_channel must be provided") - - -if __name__ == "__main__": - seed_everything(42) - - # use tensor cores on Ampere GPUs (24-bit tensorfloat matmul) - torch.set_float32_matmul_precision("high") - - initial_yx_patch_size = (384, 384) - 
final_yx_patch_size = (192, 192) - batch_size = 64 - num_workers = 12 - time_interval = 1 - z_stack_depth = 16 - - print("Creating model components...") - - # Create encoder with debug info - encoder = VaeEncoder( - backbone="resnet50", - in_channels=1, - in_stack_depth=z_stack_depth, - embedding_dim=256, - stem_kernel_size=(4, 4, 4), - stem_stride=(4, 4, 4), - ) - print(f"Encoder created successfully") - - # Test encoder forward pass - test_input = torch.randn(1, 1, z_stack_depth, 192, 192) - try: - encoder_output = encoder(test_input) - print(f"Encoder test passed: {encoder_output.embedding.shape}") - except Exception as e: - print(f"Encoder test failed: {e}") - exit(1) - - # Create decoder - decoder = VaeDecoder( - decoder_channels=[1024, 512, 256, 128], - latent_dim=256, - out_channels=1, - out_stack_depth=z_stack_depth, - latent_spatial_size=3, - head_expansion_ratio=2, - head_pool=False, - upsample_mode="pixelshuffle", - conv_blocks=2, - norm_name="batch", - upsample_pre_conv=None, - strides=[2, 2, 2, 2], - ) - print(f"Decoder created successfully") - - # Create VaeModule - model = VaeModule( - encoder=encoder, - decoder=decoder, - example_input_array_shape=(1, 1, z_stack_depth, 192, 192), - latent_dim=256, - beta=1.5, - lr=2e-4, - beta_schedule="linear", - beta_min=0.5, - beta_warmup_epochs=15, - ) - print(f"VaeModule created successfully") - - # Test model forward pass - try: - model_output = model(test_input) - print(f"Model test passed: loss={model_output['loss']}") - except Exception as e: - print(f"Model test failed: {e}") - exit(1) - - # Create data module - print("Creating data module...") - dm = TripletDataModule( - data_path="/hpc/projects/intracellular_dashboard/organelle_dynamics/rerun/2024_10_16_A549_SEC61_ZIKV_DENV/2-assemble/2024_10_16_A549_SEC61_ZIKV_DENV.zarr", - tracks_path="/hpc/projects/intracellular_dashboard/organelle_dynamics/rerun/2024_10_16_A549_SEC61_ZIKV_DENV/1-preprocess/label-free/3-track/2024_10_16_A549_SEC61_ZIKV_DENV_cropped.zarr", - source_channel=["Phase3D"], - z_range=(10, 26), - initial_yx_patch_size=initial_yx_patch_size, - final_yx_patch_size=final_yx_patch_size, - batch_size=batch_size, - num_workers=num_workers, - time_interval=time_interval, - augmentations=channel_augmentations("Phase3D"), - normalizations=channel_normalization(phase_channel="Phase3D"), - augment_validation=False, - return_negative=False, - fit_include_wells=["B/3", "B/4", "C/3", "C/4"], - ) - dm.setup("fit") - print(f"DataModule created successfully") - train_size = len(dm.train_dataset) - val_size = len(dm.val_dataset) - batches_per_epoch = train_size // batch_size - - print(f"Training samples: {train_size:,}") - print(f"Validation samples: {val_size:,}") - print(f"Batches per epoch: {batches_per_epoch:,}") - - # # Create trainer - trainer = VisCyTrainer( - accelerator="gpu", - strategy="ddp", - devices=4, - num_nodes=1, - precision="16-mixed", - # fast_dev_run=True, - max_epochs=100, - log_every_n_steps=10, - check_val_every_n_epoch=1, - logger=TensorBoardLogger( - save_dir="/hpc/projects/organelle_phenotyping/models/SEC61B/vae", - name="betavae_phase3D", - version="beta_1.5_16slice", - ), - callbacks=[ - LearningRateMonitor(logging_interval="step"), - ModelCheckpoint( - monitor="loss/total/val", - save_top_k=5, - save_last=True, - every_n_epochs=1, - ), - ], - use_distributed_sampler=True, - ) - - print("Starting training...") - trainer.fit(model, dm) - -# %% From 8de14e4ba1e776358a556060fc451553f8487e04 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Wed, 17 
Sep 2025 13:04:38 -0700 Subject: [PATCH 057/101] adding normalizeintensity --- .../DynaCLR/MonaiVAE/test_vae_magnitudes.py | 243 ------------------ viscy/transforms/_redef.py | 24 +- 2 files changed, 15 insertions(+), 252 deletions(-) delete mode 100644 applications/benchmarking/DynaCLR/MonaiVAE/test_vae_magnitudes.py diff --git a/applications/benchmarking/DynaCLR/MonaiVAE/test_vae_magnitudes.py b/applications/benchmarking/DynaCLR/MonaiVAE/test_vae_magnitudes.py deleted file mode 100644 index 8a6a1612e..000000000 --- a/applications/benchmarking/DynaCLR/MonaiVAE/test_vae_magnitudes.py +++ /dev/null @@ -1,243 +0,0 @@ -#%% -from pathlib import Path - -import numpy as np -import torch -import torch.nn.functional as F -from monai.transforms import ( - NormalizeIntensity, -) -from torch.nn import KLDivLoss, MSELoss -from torchview import draw_graph - -from viscy.representation.vae import BetaVae25D, BetaVaeMonai - - -def compute_vae_losses(model_output, target, beta=1.0): - """Compute VAE losses: reconstruction (MSE) and KL divergence. - """ - mse_loss_fn = MSELoss(reduction='mean') - recon_loss = mse_loss_fn(model_output.recon_x, target) - - # Standard VAE: per-sample, per-dimension KL loss normalization - batch_size = target.size(0) - latent_dim = model_output.mean.size(1) # Get latent dimension - normalizer = batch_size * latent_dim # Normalize by both batch size and latent dim - - kl_loss = -0.5 * torch.sum(1 + model_output.logvar - model_output.mean.pow(2) - model_output.logvar.exp()) - print(f" Debug - KL raw: {kl_loss.item():.6f}, normalizer: {normalizer}, batch_size: {target.size(0)}") - kl_loss = kl_loss / normalizer - - total_loss = recon_loss + beta * kl_loss - - return { - 'mu': model_output.mean, - 'logvar': model_output.logvar, - 'recon_loss': recon_loss.item(), - 'kl_loss': kl_loss.item(), - 'total_loss': total_loss.item(), - 'beta': beta, - # 'recon_magnitude': torch.abs(model_output.recon_x).mean().item(), - # 'target_magnitude': torch.abs(target).mean().item(), - # 'latent_mean_magnitude': torch.abs(model_output.mean).mean().item(), - # 'latent_std_magnitude': torch.exp(0.5 * model_output.logvar).mean().item(), - } - - -def create_synthetic_data(batch_size=2, channels=2, depth=16, height=256, width=256): - """Create synthetic microscopy-like data with known statistics. - These are from one FOV of the Phase3D - - mean: 8.196415001293644e-05 ≈ 0.0001 - - std: 0.09095408767461777 ≈ 0.091 - """ - torch.manual_seed(42) - synthetic_data = torch.randn(batch_size, channels, depth, height, width) * 0.091 + 0.0001 - - for b in range(batch_size): - for c in range(channels): - for d in range(depth): - # Add some blob-like structures - y_center, x_center = np.random.randint(50, height-50), np.random.randint(50, width-50) - y, x = np.ogrid[:height, :width] - mask = (y - y_center)**2 + (x - x_center)**2 < np.random.randint(400, 1600) - synthetic_data[b, c, d][mask] += np.random.normal(0.05, 0.02) - - synthetic_data = torch.clamp(synthetic_data, min=0) - - return synthetic_data - - -def create_known_target(input_data, noise_level=0.1): - """Create a target with known relationship to input for testing MSE magnitude. 
- """ - target = input_data.clone() - - noise = torch.randn_like(target) * noise_level * target.std() - target = target + noise - - target = target * 0.95 + 0.01 - - return torch.clamp(target, min=0) - - -def test_vae_magnitudes(): - """Test VAE models with both real dataloader and synthetic data.""" - print("=== VAE Magnitude Testing ===\n") - - device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - print(f"Using device: {device}") - - model_configs = [ - # { - # 'name': 'BetaVae25D_ResNet50', - # 'model_class': BetaVae25D, - # 'kwargs': { - # 'backbone': 'resnet50', - # 'in_channels': 2, - # 'in_stack_depth': 16, - # 'latent_dim': 1024, - # 'input_spatial_size': (256, 256), - # } - # }, - # Uncomment to test MONAI version - { - 'name': 'BetaVaeMonai', - 'model_class': BetaVaeMonai, - 'kwargs': { - 'spatial_dims': 3, - 'in_shape': (2, 16, 256, 256), # (C, D, H, W) - 'out_channels': 2, - 'latent_size': 1024, - 'channels': (32, 64, 128, 256), - 'strides': (2, 2, 2, 2), - } - } - ] - - # Test different beta values - beta_values = [0.1, 1.0, 4.0, 10.0] - - for model_config in model_configs: - print(f"\n{'='*50}") - print(f"Testing {model_config['name']}") - print(f"{'='*50}") - - # Initialize model - model = model_config['model_class'](**model_config['kwargs']) - model = model.to(device) - model.eval() - - print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}") - - # Draw model graph - print(f"\n--- Model Architecture ---") - sample_input = create_synthetic_data(batch_size=1).to(device) - try: - model_graph = draw_graph( - model, - input_data=sample_input, - expand_nested=True, - depth=6, - save_graph=True, - filename=f'{model_config["name"]}_graph', - directory='./model_graphs/' - ) - print(f"Model graph saved to: ./model_graphs/{model_config['name']}_graph.png") - except Exception as e: - print(f"Could not generate model graph: {e}") - - # Test 1: Synthetic data with known target - print(f"\n--- Test 1: Synthetic Data ---") - synthetic_input = create_synthetic_data().to(device) - synthetic_target = create_known_target(synthetic_input).to(device) - - print(f"Input shape: {synthetic_input.shape}") - print(f"Input stats - mean: {synthetic_input.mean():.6f}, std: {synthetic_input.std():.6f}") - print(f"Target stats - mean: {synthetic_target.mean():.6f}, std: {synthetic_target.std():.6f}") - - with torch.no_grad(): - synthetic_output = model(synthetic_input) - - print(f"Output shape: {synthetic_output.recon_x.shape}") - print(f"Latent shape: {synthetic_output.z.shape}") - - for beta in beta_values: - losses = compute_vae_losses(model_output=synthetic_output, target=synthetic_target, beta=beta) - print(f"\nBeta = {beta}:") - print(f" Mu shape: {losses['mu'].shape}, mean: {losses['mu'].mean():.6f}, std: {losses['mu'].std():.6f}") - print(f" Logvar shape: {losses['logvar'].shape}, mean: {losses['logvar'].mean():.6f}, std: {losses['logvar'].std():.6f}") - print(f" Reconstruction Loss: {losses['recon_loss']:.6f}") - print(f" KL Loss: {losses['kl_loss']:.6f}") - print(f" Total Loss: {losses['total_loss']:.6f}") - # print(f" Recon magnitude: {losses['recon_magnitude']:.6f}") - # print(f" Target magnitude: {losses['target_magnitude']:.6f}") - # print(f" Latent mean magnitude: {losses['latent_mean_magnitude']:.6f}") - # print(f" Latent std magnitude: {losses['latent_std_magnitude']:.6f}") - - #TODO: use the dataloader to run it with real data - # data_path = "/hpc/projects/organelle_phenotyping/datasets/organelle/SEC61B/2024_10_16_A549_SEC61_ZIKV_DENV" - # zarr_path = 
Path(data_path) / "2024_10_16_A549_SEC61_ZIKV_DENV_2.zarr" - zarr_path = None - if not zarr_path: - print(f"Found real data at: {zarr_path}") - - normalizations = [ - NormalizeIntensity() - ] - - print("Testing with real data format...") - - real_like_data = create_synthetic_data(batch_size=1, channels=2, depth=16, height=256, width=256) - - normalized_data = (real_like_data - real_like_data.mean()) / real_like_data.std() - normalized_data = normalized_data.to(device) - - print(f"Normalized data stats - mean: {normalized_data.mean():.6f}, std: {normalized_data.std():.6f}") - - with torch.no_grad(): - real_output = model(normalized_data) - - losses = compute_vae_losses(model_output=real_output, target=normalized_data, beta=1.0) - print(f"\nPerfect reconstruction test (beta=1.0):") - print(f" Reconstruction Loss: {losses['recon_loss']:.6f}") - print(f" KL Loss: {losses['kl_loss']:.6f}") - print(f" Total Loss: {losses['total_loss']:.6f}") - - else: - raise NotImplementedError("not implemented") - - -def print_expected_ranges(): - """Print expected ranges for VAE loss components.""" - print("\n" + "="*60) - print("EXPECTED LOSS MAGNITUDE RANGES") - print("="*60) - print(""" -For Beta-VAE with normalized input (0-mean, 1-std): - -NOTES - -1. RECONSTRUCTION LOSS (MSE): - - Well-trained model: 0.01 - 0.1 - - Untrained/poorly trained: 0.5 - 2.0 - - Perfect reconstruction: < 0.001 - -2. KL DIVERGENCE LOSS: - - Posterior collapse (BAD): < 10 (model ignores latent space) - - Well-regularized: depends on latent dim, but should allow reconstruction - - Over-regularized (BAD): Forces posterior too close to prior, hurts reconstruction - - Typical untrained: can be very high as posterior is random - -3. BETA PARAMETER EFFECTS: - - Beta < 1.0: Prioritizes reconstruction (lower MSE, higher KL) - - Beta = 1.0: Standard VAE balance - - Beta > 1.0: Prioritizes disentanglement (higher MSE, lower KL) - """ - ) - - -if __name__ == "__main__": - print_expected_ranges() - test_vae_magnitudes() - print("\n=== Testing Complete ===") -# %% diff --git a/viscy/transforms/_redef.py b/viscy/transforms/_redef.py index 696c81abc..8976094a5 100644 --- a/viscy/transforms/_redef.py +++ b/viscy/transforms/_redef.py @@ -5,6 +5,7 @@ from monai.transforms import ( CenterSpatialCropd, Decollated, + NormalizeIntensityd, RandAdjustContrastd, RandAffined, RandFlipd, @@ -188,12 +189,17 @@ def __init__( super().__init__(keys=keys, roi_size=roi_size, **kwargs) -class RandFlipd(RandFlipd): - def __init__( - self, - keys: Sequence[str] | str, - prob: float, - spatial_axis: Sequence[int] | int, - **kwargs, - ): - super().__init__(keys=keys, prob=prob, spatial_axis=spatial_axis, **kwargs) + class RandFlipd(RandFlipd): + def __init__( + self, + keys: Sequence[str] | str, + prob: float, + spatial_axis: Sequence[int] | int, + **kwargs, + ): + super().__init__(keys=keys, prob=prob, spatial_axis=spatial_axis, **kwargs) + + +class NormalizeIntensityd(NormalizeIntensityd): + def __init__(self, keys: Sequence[str] | str, **kwargs): + super().__init__(keys=keys, **kwargs) \ No newline at end of file From f5412d51838526f9f0a7d34a80180bd87951acda Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Wed, 17 Sep 2025 13:08:18 -0700 Subject: [PATCH 058/101] fixing the vae_logging typing and removing PC plotting from here --- viscy/representation/engine.py | 9 --- viscy/representation/vae_logging.py | 91 +---------------------------- 2 files changed, 2 insertions(+), 98 deletions(-) diff --git a/viscy/representation/engine.py 
b/viscy/representation/engine.py index 67b8c594b..fa66528e7 100644 --- a/viscy/representation/engine.py +++ b/viscy/representation/engine.py @@ -754,15 +754,6 @@ def _log_enhanced_visualizations(self): lightning_module=self, n_dims=8, n_steps=7 ) - # Log latent space visualization (every 40 epochs to avoid overhead) - if self.current_epoch % 40 == 0: - self.vae_logger.log_latent_space_visualization( - lightning_module=self, - dataloader=val_dataloader, - max_samples=500, - method="pca", - ) - except Exception as e: _logger.error(f"Error logging enhanced visualizations: {e}") diff --git a/viscy/representation/vae_logging.py b/viscy/representation/vae_logging.py index 875ac0edf..310f909d3 100644 --- a/viscy/representation/vae_logging.py +++ b/viscy/representation/vae_logging.py @@ -1,6 +1,6 @@ import io import logging -from typing import Optional, Tuple +from typing import Callable, Optional, Tuple import matplotlib.pyplot as plt import numpy as np @@ -342,95 +342,8 @@ def log_factor_traversal_matrix( dataformats="CHW", ) - def log_latent_space_visualization( - self, lightning_module, dataloader, max_samples: int = 500, method: str = "pca" - ): - """ - Log 2D visualization of latent space using PCA or t-SNE. - - Args: - lightning_module: Lightning module instance - dataloader: DataLoader for samples - max_samples: Maximum samples to visualize - method: Visualization method ("pca" or "tsne") - """ - if not hasattr(lightning_module, "model"): - return - - lightning_module.model.eval() - - # Collect latent representations - latents = [] - samples_collected = 0 - - with torch.no_grad(): - for batch in dataloader: - if samples_collected >= max_samples: - break - - x = batch["anchor"].to(lightning_module.device) - model_output = lightning_module(x) # Use lightning module forward - # Handle both Pythae dict format and object format - if isinstance(model_output, dict): - z = model_output["z"] - else: - z = ( - model_output.z - if hasattr(model_output, "z") - else model_output.embedding - ) - - latents.append(z.cpu().numpy()) - samples_collected += x.shape[0] - - if not latents: - return - - latents = np.vstack(latents)[:max_samples] - - # Apply dimensionality reduction - if method == "pca": - reducer = PCA(n_components=2) - reduced = reducer.fit_transform(latents) - title = f"PCA Latent Space (Variance: {reducer.explained_variance_ratio_.sum():.2f})" - elif method == "tsne": - reducer = TSNE(n_components=2, random_state=42) - reduced = reducer.fit_transform(latents) - title = "t-SNE Latent Space" - else: - _logger.warning(f"Unknown method: {method}") - return - - # Create scatter plot - plt.figure(figsize=(10, 8)) - plt.scatter(reduced[:, 0], reduced[:, 1], alpha=0.6, s=20) - plt.title(title) - plt.xlabel("Component 1") - plt.ylabel("Component 2") - plt.grid(True, alpha=0.3) - - # Convert to image - buf = io.BytesIO() - plt.savefig(buf, format="png", dpi=150, bbox_inches="tight") - buf.seek(0) - - # Log to TensorBoard - img = Image.open(buf) - img_array = np.array(img) - img_tensor = torch.from_numpy(img_array).permute(2, 0, 1) / 255.0 - - lightning_module.logger.experiment.add_image( - f"latent_space_{method}", - img_tensor, - lightning_module.current_epoch, - dataformats="CHW", - ) - - plt.close() - buf.close() - def log_beta_schedule( - self, lightning_module, beta_schedule: Optional[callable] = None + self, lightning_module, beta_schedule: Optional[Callable] = None ): """ Log β annealing schedule. 
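Patch 058 drops the training-time latent-space plot (log_latent_space_visualization) from the VAE logger; the comment in the removed code notes it only ran every 40 epochs to avoid overhead. The same figure can still be produced offline from saved embeddings. Below is a minimal sketch under stated assumptions: the zarr store was written by the embedding writer used elsewhere in this series, and the input path and output filename are placeholders, not paths from the patches.

# offline_latent_pca.py -- hedged sketch of an offline replacement for the
# removed training-time latent-space plot; paths below are placeholders.
from pathlib import Path

import matplotlib.pyplot as plt
from sklearn.decomposition import PCA

from viscy.representation.embedding_writer import read_embedding_dataset

# Load embeddings saved during a prediction pass (placeholder path).
dataset = read_embedding_dataset(Path("predictions.zarr"))
latents = dataset["features"].values  # (n_samples, latent_dim)

# Project to 2D, mirroring the removed log_latent_space_visualization logic.
pca = PCA(n_components=2)
reduced = pca.fit_transform(latents)

plt.figure(figsize=(10, 8))
plt.scatter(reduced[:, 0], reduced[:, 1], alpha=0.6, s=20)
plt.title(f"PCA Latent Space (Variance: {pca.explained_variance_ratio_.sum():.2f})")
plt.xlabel("Component 1")
plt.ylabel("Component 2")
plt.grid(True, alpha=0.3)
plt.savefig("latent_space_pca.png", dpi=150, bbox_inches="tight")

Running this after a prediction pass reproduces the PCA panel without paying the cost inside the training loop, which is consistent with the rationale for the removal.
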
From 2892af09fb7ab0682694c8685ade54d4cba48432 Mon Sep 17 00:00:00 2001
From: Eduardo Hirata-Miyasaki
Date: Wed, 17 Sep 2025 13:10:21 -0700
Subject: [PATCH 059/101] fixing the compute_embeddings_smoothness docstring

---
 viscy/representation/evaluation/smoothness.py | 14 +++++---------
 1 file changed, 5 insertions(+), 9 deletions(-)

diff --git a/viscy/representation/evaluation/smoothness.py b/viscy/representation/evaluation/smoothness.py
index 1abfcc00a..2e9dedaf1 100644
--- a/viscy/representation/evaluation/smoothness.py
+++ b/viscy/representation/evaluation/smoothness.py
@@ -123,21 +123,17 @@ def compute_embeddings_smoothness(
     Returns:
     -------
     stats: dict: Dictionary containing metrics including:
-        - adj_frame_mean: Mean of adjacent frame dissimilarity
-        - adj_frame_std: Standard deviation of adjacent frame dissimilarity
-        - adj_frame_median: Median of adjacent frame dissimilarity
-        - adj_frame_peak: Peak of adjacent frame distribution
-        - adj_frame_p99: 99th percentile of adjacent frame dissimilarity
-        - adj_frame_p1: 1st percentile of adjacent frame dissimilarity
-        - adj_frame_distribution: Full distribution of adjacent frame dissimilarities
+        - adjacent_frame_mean: Mean of adjacent frame dissimilarity
+        - adjacent_frame_std: Standard deviation of adjacent frame dissimilarity
+        - adjacent_frame_median: Median of adjacent frame dissimilarity
+        - adjacent_frame_peak: Peak of adjacent frame distribution
         - random_frame_mean: Mean of random sampling dissimilarity
         - random_frame_std: Standard deviation of random sampling dissimilarity
         - random_frame_median: Median of random sampling dissimilarity
         - random_frame_peak: Peak of random sampling distribution
-        - random_frame_distribution: Full distribution of random sampling dissimilarities
         - smoothness_score: Score of smoothness
         - dynamic_range: Difference between random and adjacent peaks
-        distributions: dict: Dictionary containing distributions including:
+    distributions: dict: Dictionary containing distributions including:
         - adjacent_frame_distribution: Full distribution of adjacent frame dissimilarities
         - random_frame_distribution: Full distribution of random sampling dissimilarities
     piecewise_distance_per_track: list[list[float]]

From dc78397a989aef88cc2746cd96a3812b4f717d1f Mon Sep 17 00:00:00 2001
From: Eduardo Hirata-Miyasaki
Date: Wed, 17 Sep 2025 13:48:26 -0700
Subject: [PATCH 060/101] simplify the distance metrics and remove deprecated
 functions and scripts

---
 .../cosine_similarity.py                      |   0
 .../evaluation/archive/displacement.py        | 108 --------------
 .../evaluation/test_distance.py               | 107 ++++++++++++++
 viscy/representation/evaluation/distance.py   | 134 ++----------------
 4 files changed, 117 insertions(+), 232 deletions(-)
 rename applications/contrastive_phenotyping/evaluation/{pc_vs_computed_features => archive}/cosine_similarity.py (100%)
 delete mode 100644 applications/contrastive_phenotyping/evaluation/archive/displacement.py
 create mode 100644 tests/representation/evaluation/test_distance.py

diff --git a/applications/contrastive_phenotyping/evaluation/archive/displacement.py
b/applications/contrastive_phenotyping/evaluation/archive/displacement.py deleted file mode 100644 index a807c6134..000000000 --- a/applications/contrastive_phenotyping/evaluation/archive/displacement.py +++ /dev/null @@ -1,108 +0,0 @@ -# %% -from pathlib import Path - -import matplotlib.pyplot as plt -import numpy as np - -from viscy.representation.embedding_writer import read_embedding_dataset -from viscy.representation.evaluation.distance import ( - calculate_normalized_euclidean_distance_cell, - compute_displacement_mean_std_full, -) - -# %% paths - -features_path_30_min = Path( - "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/time_interval/predict/feb_test_time_interval_1_epoch_178.zarr" -) - -feature_path_no_track = Path("/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/negpair_random_sampling2/feb_fixed_test_predict.zarr") - -features_path_any_time = Path("/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/negpair_difcell_randomtime_sampling/Ver2_updateTracking_refineModel/predictions/Feb_2chan_128patch_32projDim/2chan_128patch_56ckpt_FebTest.zarr") - -data_path = Path( - "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/registered_test.zarr" -) - -tracks_path = Path( - "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/track_test.zarr" -) - -# %% Load embedding datasets for all three sampling -fov_name = '/B/4/6' -track_id = 52 - -embedding_dataset_30_min = read_embedding_dataset(features_path_30_min) -embedding_dataset_no_track = read_embedding_dataset(feature_path_no_track) -embedding_dataset_any_time = read_embedding_dataset(features_path_any_time) - -#%% -# Calculate displacement for each sampling -time_points_30_min, cosine_similarities_30_min = calculate_normalized_euclidean_distance_cell(embedding_dataset_30_min, fov_name, track_id) -time_points_no_track, cosine_similarities_no_track = calculate_normalized_euclidean_distance_cell(embedding_dataset_no_track, fov_name, track_id) -time_points_any_time, cosine_similarities_any_time = calculate_normalized_euclidean_distance_cell(embedding_dataset_any_time, fov_name, track_id) - -# %% Plot displacement over time for all three conditions - -plt.figure(figsize=(10, 6)) - -plt.plot(time_points_no_track, cosine_similarities_no_track, marker='o', label='classical contrastive (no tracking)') -plt.plot(time_points_any_time, cosine_similarities_any_time, marker='o', label='cell aware') -plt.plot(time_points_30_min, cosine_similarities_30_min, marker='o', label='cell & time aware (interval 30 min)') - -plt.xlabel("Time Delay (t)", fontsize=10) -plt.ylabel("Normalized Euclidean Distance with First Time Point", fontsize=10) -plt.title("Normalized Euclidean Distance (Features) Over Time for Infected Cell", fontsize=12) - -plt.grid(True) -plt.legend(fontsize=10) - -#plt.savefig('4_euc_dist_full.svg', format='svg') -plt.show() - - -# %% Paths to datasets -features_path_30_min = Path("/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/time_interval/predict/feb_test_time_interval_1_epoch_178.zarr") -feature_path_no_track = Path("/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/negpair_random_sampling2/feb_fixed_test_predict.zarr") - 
-embedding_dataset_30_min = read_embedding_dataset(features_path_30_min) -embedding_dataset_no_track = read_embedding_dataset(feature_path_no_track) - - -# %% -max_tau = 10 - -mean_displacement_30_min_euc, std_displacement_30_min_euc = compute_displacement_mean_std_full(embedding_dataset_30_min, max_tau) -mean_displacement_no_track_euc, std_displacement_no_track_euc = compute_displacement_mean_std_full(embedding_dataset_no_track, max_tau) - -# %% Plot 2: Cosine Displacements -plt.figure(figsize=(10, 6)) - -taus = list(mean_displacement_30_min_euc.keys()) - -mean_values_30_min_euc = list(mean_displacement_30_min_euc.values()) -std_values_30_min_euc = list(std_displacement_30_min_euc.values()) - -plt.plot(taus, mean_values_30_min_euc, marker='o', label='Cell & Time Aware (30 min interval)', color='green') -plt.fill_between(taus, - np.array(mean_values_30_min_euc) - np.array(std_values_30_min_euc), - np.array(mean_values_30_min_euc) + np.array(std_values_30_min_euc), - color='green', alpha=0.3, label='Std Dev (30 min interval)') - -mean_values_no_track_euc = list(mean_displacement_no_track_euc.values()) -std_values_no_track_euc = list(std_displacement_no_track_euc.values()) - -plt.plot(taus, mean_values_no_track_euc, marker='o', label='Classical Contrastive (No Tracking)', color='blue') -plt.fill_between(taus, - np.array(mean_values_no_track_euc) - np.array(std_values_no_track_euc), - np.array(mean_values_no_track_euc) + np.array(std_values_no_track_euc), - color='blue', alpha=0.3, label='Std Dev (No Tracking)') - -plt.xlabel('Time Shift (τ)') -plt.ylabel('Euclidean Distance') -plt.title('Embedding Displacement Over Time (Features)') - -plt.grid(True) -plt.legend() - -plt.show() diff --git a/tests/representation/evaluation/test_distance.py b/tests/representation/evaluation/test_distance.py new file mode 100644 index 000000000..bdd75ecd3 --- /dev/null +++ b/tests/representation/evaluation/test_distance.py @@ -0,0 +1,107 @@ +import numpy as np +import pytest +import xarray as xr + +from viscy.representation.evaluation.distance import ( + calculate_cosine_similarity_cell, + compute_track_displacement, +) + + +@pytest.fixture +def sample_embedding_dataset(): + """Create a sample embedding dataset for testing.""" + n_samples = 10 + n_features = 5 + + features = np.random.rand(n_samples, n_features) + fov_names = ["fov1"] * 5 + ["fov2"] * 5 + track_ids = [1, 1, 1, 2, 2, 3, 3, 3, 4, 4] + time_points = [0, 1, 2, 0, 1, 0, 1, 2, 0, 1] + + dataset = xr.Dataset( + { + "features": (["sample", "features"], features), + "fov_name": (["sample"], fov_names), + "track_id": (["sample"], track_ids), + "t": (["sample"], time_points), + } + ) + return dataset + + +def test_calculate_cosine_similarity_cell(sample_embedding_dataset): + """Test cosine similarity calculation for a single track.""" + time_points, similarities = calculate_cosine_similarity_cell( + sample_embedding_dataset, "fov1", 1 + ) + + assert len(time_points) == len(similarities) + assert len(time_points) == 3 + assert np.isclose(similarities[0], 1.0, atol=1e-6) + assert all(-1 <= sim <= 1 for sim in similarities) + + +@pytest.mark.parametrize("distance_metric", ["cosine", "euclidean","sqeuclidean"]) +def test_compute_track_displacement(sample_embedding_dataset, distance_metric): + """Test track displacement computation with different metrics.""" + result = compute_track_displacement( + sample_embedding_dataset, distance_metric=distance_metric + ) + + assert isinstance(result, dict) + assert all(isinstance(tau, int) for tau in result.keys()) + 
assert all(isinstance(displacements, list) for displacements in result.values())
+    assert all(
+        all(isinstance(d, (int, float)) and d >= 0 for d in displacements)
+        for displacements in result.values()
+    )
+
+def test_compute_track_displacement_numerical():
+    """Test compute_track_displacement with known embeddings and expected results."""
+    features = np.array([
+        [1.0, 0.0],
+        [0.0, 1.0],
+        [1.0, 1.0],
+    ])
+
+    dataset = xr.Dataset({
+        "features": (["sample", "features"], features),
+        "fov_name": (["sample"], ["fov1", "fov1", "fov1"]),
+        "track_id": (["sample"], [1, 1, 1]),
+        "t": (["sample"], [0, 1, 2]),
+    })
+    result_euclidean = compute_track_displacement(dataset, distance_metric="euclidean")
+
+    assert 1 in result_euclidean
+    assert 2 in result_euclidean
+    assert len(result_euclidean[1]) == 2
+    assert len(result_euclidean[2]) == 1
+
+    result_sqeuclidean = compute_track_displacement(dataset, distance_metric="sqeuclidean")
+    expected_tau1 = [2.0, 1.0]
+    expected_tau2 = [1.0]
+
+    assert np.allclose(sorted(result_sqeuclidean[1]), sorted(expected_tau1), atol=1e-10)
+    assert np.allclose(result_sqeuclidean[2], expected_tau2, atol=1e-10)
+
+
+    result_cosine = compute_track_displacement(dataset, distance_metric="cosine")
+    expected_cosine_tau1 = [1.0, 1 - 1/np.sqrt(2)]
+    expected_cosine_tau2 = [1 - 1/np.sqrt(2)]
+
+    assert np.allclose(sorted(result_cosine[1]), sorted(expected_cosine_tau1), atol=1e-10)
+    assert np.allclose(result_cosine[2], expected_cosine_tau2, atol=1e-10)
+
+
+def test_compute_track_displacement_empty_dataset():
+    """Test behavior with empty dataset."""
+    empty_dataset = xr.Dataset({
+        "features": (["sample", "features"], np.empty((0, 5))),
+        "fov_name": (["sample"], []),
+        "track_id": (["sample"], []),
+        "t": (["sample"], []),
+    })
+
+    result = compute_track_displacement(empty_dataset)
+    assert result == {}
\ No newline at end of file
diff --git a/viscy/representation/evaluation/distance.py b/viscy/representation/evaluation/distance.py
index ab9f3e05c..e55354be1 100644
--- a/viscy/representation/evaluation/distance.py
+++ b/viscy/representation/evaluation/distance.py
@@ -11,8 +11,6 @@
     pairwise_distance_matrix,
 )
 
-_logger = logging.getLogger(__name__)
-
 
 def calculate_cosine_similarity_cell(embedding_dataset, fov_name, track_id):
     """Extract embeddings and calculate cosine similarities for a specific cell"""
@@ -21,8 +19,8 @@ def calculate_cosine_similarity_cell(embedding_dataset, fov_name, track_id):
         & (embedding_dataset["track_id"] == track_id),
         drop=True,
     )
-    features = filtered_data["features"].values  # (sample, features)
-    time_points = filtered_data["t"].values  # (sample,)
+    features = filtered_data["features"].values
+    time_points = filtered_data["t"].values
     first_time_point_embedding = features[0].reshape(1, -1)
     cosine_similarities = cosine_similarity(
         first_time_point_embedding, features
@@ -32,7 +30,7 @@ def calculate_cosine_similarity_cell(embedding_dataset, fov_name, track_id):
 
 def compute_track_displacement(
     embedding_dataset: xr.Dataset,
-    distance_metric: Literal["euclidean", "cosine"] = "cosine",
+    distance_metric: str = "cosine",
 ) -> dict[int, list[float]]:
     """
     Compute Mean Squared Displacement using pairwise distance matrix.
@@ -41,8 +39,10 @@
     ----------
     embedding_dataset : xr.Dataset
         Dataset containing embeddings and metadata
-    distance_metric : Literal["euclidean", "cosine"]
+    distance_metric : str
         Distance metric to use. Default is cosine.
+        See the SciPy distance module for other supported metrics:
+        https://github.com/scipy/scipy/blob/main/scipy/spatial/distance.py
 
     Returns
     -------
@@ -72,17 +72,9 @@
         track_embeddings = track_data["features"].values[time_order]
 
         # Compute pairwise distance matrix
-        if distance_metric == "euclidean":
-            distance_matrix = pairwise_distance_matrix(
-                track_embeddings, metric="euclidean"
-            )
-            distance_matrix = distance_matrix**2  # Square for MSD
-        elif distance_metric == "cosine":
-            distance_matrix = pairwise_distance_matrix(
-                track_embeddings, metric="cosine"
-            )
-        else:
-            raise ValueError(f"Unsupported distance metric: {distance_metric}")
+        distance_matrix = pairwise_distance_matrix(
+            track_embeddings, metric=distance_metric
+        )
 
         # Extract displacements using diagonal offsets
         n_timepoints = len(times)
@@ -93,110 +85,4 @@
                 tau = int(times[i + time_offset] - times[i])
                 displacement_per_tau[tau].append(displacement)
 
-    return dict(displacement_per_tau)
-
-
-def compute_displacement_statistics(
-    displacement_per_tau: dict[int, list[float]],
-) -> tuple[dict[int, float], dict[int, float]]:
-    """Compute mean and standard deviation of displacements for each tau.
-
-    Parameters
-    ----------
-    displacement_per_tau : dict[int, list[float]]
-        Dictionary mapping τ to list of displacements
-
-    Returns
-    -------
-    tuple[dict[int, float], dict[int, float]]
-        Tuple of (mean_displacements, std_displacements) where each is a
-        dictionary mapping τ to the statistic
-    """
-    mean_displacement_per_tau = {
-        tau: np.mean(displacements)
-        for tau, displacements in displacement_per_tau.items()
-    }
-    std_displacement_per_tau = {
-        tau: np.std(displacements)
-        for tau, displacements in displacement_per_tau.items()
-    }
-    return mean_displacement_per_tau, std_displacement_per_tau
-
-
-def compute_dynamic_range(mean_displacement_per_tau):
-    """
-    Compute the dynamic range as the difference between the maximum
-    and minimum mean displacement per τ.
-
-    Parameters:
-    mean_displacement_per_tau: dict with τ as key and mean displacement as value
-
-    Returns:
-    float: dynamic range (max displacement - min displacement)
-    """
-    displacements = list(mean_displacement_per_tau.values())
-    return max(displacements) - min(displacements)
-
-
-def compute_rms_per_track(embedding_dataset):
-    """
-    Compute RMS of the time derivative of embeddings per track.
-
-    Parameters:
-    embedding_dataset : xarray.Dataset
-        The dataset containing embeddings, timepoints, fov_name, and track_id.
-
-    Returns:
-    list: A list of RMS values, one for each track.
- """ - fov_names = embedding_dataset["fov_name"].values - track_ids = embedding_dataset["track_id"].values - timepoints = embedding_dataset["t"].values - embeddings = embedding_dataset["features"].values - - cell_identifiers = np.array( - list(zip(fov_names, track_ids)), - dtype=[("fov_name", "O"), ("track_id", "int64")], - ) - unique_cells = np.unique(cell_identifiers) - - rms_values = [] - - for cell in unique_cells: - fov_name = cell["fov_name"] - track_id = cell["track_id"] - indices = np.where((fov_names == fov_name) & (track_ids == track_id))[0] - cell_timepoints = timepoints[indices] - cell_embeddings = embeddings[indices] - - if len(cell_embeddings) < 2: - continue - - sorted_indices = np.argsort(cell_timepoints) - cell_embeddings = cell_embeddings[sorted_indices] - differences = np.diff(cell_embeddings, axis=0) - - if differences.shape[0] == 0: - continue - - norms = np.linalg.norm(differences, axis=1) - rms = np.sqrt(np.mean(norms**2)) - rms_values.append(rms) - - return rms_values - - -def calculate_normalized_euclidean_distance_cell(embedding_dataset, fov_name, track_id): - filtered_data = embedding_dataset.where( - (embedding_dataset["fov_name"] == fov_name) - & (embedding_dataset["track_id"] == track_id), - drop=True, - ) - features = filtered_data["features"].values # (sample, features) - time_points = filtered_data["t"].values # (sample,) - normalized_features = features / np.linalg.norm(features, axis=1, keepdims=True) - first_time_point_embedding = normalized_features[0].reshape(1, -1) - euclidean_distances = np.linalg.norm( - first_time_point_embedding - normalized_features, axis=1 - ) - return time_points, euclidean_distances.tolist() + return dict(displacement_per_tau) \ No newline at end of file From 9a1af7054666197bf7a9cb6ff25eecddbecada14 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Wed, 17 Sep 2025 13:50:11 -0700 Subject: [PATCH 061/101] remove deprecated functions from clustering.py --- viscy/representation/evaluation/clustering.py | 28 ------------------- 1 file changed, 28 deletions(-) diff --git a/viscy/representation/evaluation/clustering.py b/viscy/representation/evaluation/clustering.py index fcd3964d6..66d9bda7b 100644 --- a/viscy/representation/evaluation/clustering.py +++ b/viscy/representation/evaluation/clustering.py @@ -151,31 +151,3 @@ def clustering_evaluation(embeddings, annotations, method="nmi"): return score - -def compute_track_msd_statistics( - msd_per_tau: dict[int, list[float]], -) -> tuple[dict[int, float], dict[int, float]]: - """ - Compute MSD statistics (mean and std) for a single track. 
-
-    Parameters
-    ----------
-    features : ArrayLike
-        Feature matrix (n_timepoints, n_features) for a single track
-    timepoints : ArrayLike
-        Time points corresponding to each feature vector
-    metric : str, optional
-        Distance metric to use, by default "euclidean"
-
-    Returns
-    -------
-    tuple[dict[int, float], dict[int, float]]
-        Tuple of (mean_msd, std_msd) dictionaries mapping τ to statistics
-    """
-
-    mean_msd = {
-        tau: np.mean(displacements) for tau, displacements in msd_per_tau.items()
-    }
-    std_msd = {tau: np.std(displacements) for tau, displacements in msd_per_tau.items()}
-
-    return mean_msd, std_msd

From 6daf62e7bffdc03da3c6f96ad5af1e701130df4 Mon Sep 17 00:00:00 2001
From: Eduardo Hirata-Miyasaki
Date: Wed, 17 Sep 2025 15:03:32 -0700
Subject: [PATCH 062/101] add timelapse to grad_attr.py script

---
 .../figures/grad_attr.py                      | 414 ++++++++-
 .../figures/grad_attr_time.py                 | 833 ------------------
 2 files changed, 399 insertions(+), 848 deletions(-)
 delete mode 100644 applications/contrastive_phenotyping/figures/grad_attr_time.py

diff --git a/applications/contrastive_phenotyping/figures/grad_attr.py b/applications/contrastive_phenotyping/figures/grad_attr.py
index f0873c288..9d9ce7246 100644
--- a/applications/contrastive_phenotyping/figures/grad_attr.py
+++ b/applications/contrastive_phenotyping/figures/grad_attr.py
@@ -1,7 +1,10 @@
 # %%
+import logging
+import warnings
 from pathlib import Path
 
 import matplotlib as mpl
+import matplotlib.animation as animation
 import matplotlib.pyplot as plt
 import numpy as np
 import pandas as pd
@@ -9,6 +12,13 @@
 from cmap import Colormap
 from lightning.pytorch import seed_everything
 from skimage.exposure import rescale_intensity
+from sklearn.metrics import (
+    accuracy_score,
+    auc,
+    f1_score,
+    precision_recall_curve,
+    roc_auc_score,
+)
 
 from viscy.data.triplet import TripletDataModule
 from viscy.representation.embedding_writer import read_embedding_dataset
@@ -19,21 +29,34 @@
     fit_logistic_regression,
     linear_from_binary_logistic_regression,
 )
-from viscy.transforms import NormalizeSampled, ScaleIntensityRangePercentilesd
+from viscy.transforms import (
+    Decollated,
+    NormalizeSampled,
+    ScaleIntensityRangePercentilesd,
+)
 
-# %%
-seed_everything(42, workers=True)
+seed_everything(42, workers=True)
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.DEBUG)
+
+
+# %%
+# Dataset for display and occlusion analysis
+data_path = "/hpc/projects/intracellular_dashboard/organelle_dynamics/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/registered_test.zarr"
+tracks_path = "/hpc/projects/intracellular_dashboard/organelle_dynamics/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/track_test.zarr"
+annotation_occlusion_infection_path = "/hpc/projects/intracellular_dashboard/organelle_dynamics/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/supervised_inf_pred/extracted_inf_state.csv"
+annotation_occlusion_division_path = "/hpc/projects/intracellular_dashboard/organelle_dynamics/2024_02_04_A549_DENV_ZIKV_timelapse/9-lineage-cell-division/lineages_gt/cell_division_state_test_set.csv"
 
 fov = "/B/4/8"
-track = 44
+track = [44, 46]
 
 # %%
 dm = TripletDataModule(
-    data_path="/hpc/projects/intracellular_dashboard/organelle_dynamics/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/registered_test.zarr",
-    tracks_path="/hpc/projects/intracellular_dashboard/organelle_dynamics/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/track_test.zarr",
+    data_path=data_path,
+    tracks_path=tracks_path,
     source_channel=["Phase3D", "RFP"],
     z_range=[25, 40],
-    batch_size=48,
+
batch_size=1, num_workers=0, initial_yx_patch_size=(128, 128), final_yx_patch_size=(128, 128), @@ -44,10 +67,13 @@ ScaleIntensityRangePercentilesd( keys=["RFP"], lower=50, upper=99, b_min=0.0, b_max=1.0 ), + Decollated( + keys=["Phase3D","RFP"], + ) ], predict_cells=True, - include_fov_names=[fov], - include_track_ids=[track], + include_fov_names=[fov] * len(track), + include_track_ids=track, ) dm.setup("predict") len(dm.predict_dataset) @@ -67,6 +93,256 @@ ), ).eval() +# %% +def load_and_combine_datasets( + datasets, + target_type="infection", + standardization_mapping=None, +): + """Load and combine multiple embedding datasets with their annotations. + + Parameters + ---------- + datasets : list of tuple + List of (embedding_path, annotation_path, train_fovs) tuples containing + paths to embedding files, annotation CSV files, and training FOVs. + target_type : str, default='infection' + Type of classification target. Either 'infection' or 'division' - determines + which column to look for in the annotation files. + standardization_mapping : dict, optional + Dictionary to standardize different annotation formats across datasets. + Maps original values to standardized values. + Example: {'infected': 2, 'uninfected': 1, 'background': 0, + 2.0: 2, 1.0: 1, 0.0: 0, 'mitosis': 2, 'interphase': 1, 'unknown': 0} + + Returns + ------- + combined_features : xarray.DataArray + Combined feature embeddings from all successfully loaded datasets. + combined_annotations : pandas.Series + Combined and standardized annotations from all datasets. + + Raises + ------ + ValueError + If no datasets were successfully loaded. + """ + + all_features = [] + all_annotations = [] + + # Default standardization mappings + if standardization_mapping is None: + if target_type == "infection": + standardization_mapping = { + # String formats + "infected": 2, + "uninfected": 1, + "background": 0, + "unknown": 0, + # Numeric formats + 2.0: 2, + 1.0: 1, + 0.0: 0, + 2: 2, + 1: 1, + 0: 0, + } + elif target_type == "division": + standardization_mapping = { + # String formats + "mitosis": 2, + "interphase": 1, + "unknown": 0, + # Numeric formats + 2.0: 2, + 1.0: 1, + 0.0: 0, + 2: 2, + 1: 1, + 0: 0, + } + + for emb_path, ann_path, train_fovs in datasets: + try: + logger.debug(f"Loading dataset: {emb_path}") + dataset = read_embedding_dataset(emb_path) + + # Read annotation CSV to detect column names + logger.debug(f"Reading annotation CSV: {ann_path}") + ann_df = pd.read_csv(ann_path) + # make sure the ann_fov_names start with '/' otherwise add it, and strip whitespace + ann_df["fov_name"] = ann_df["fov_name"].apply( + lambda x: ( + "/" + x.strip() if not x.strip().startswith("/") else x.strip() + ) + ) + + if train_fovs == "all": + train_fovs = np.unique(dataset["fov_name"]) + + # Auto-detect annotation column based on target_type + annotation_key = None + if target_type == "infection": + for col in [ + "infection_state", + "infection", + "infection_status", + ]: + if col in ann_df.columns: + annotation_key = col + break + + elif target_type == "division": + for col in ["division", "cell_division", "cell_state"]: + if col in ann_df.columns: + annotation_key = col + break + + if annotation_key is None: + print(f" No {target_type} column found, skipping...") + continue + + # Filter the dataset to only include the FOVs in the annotation + # Use xarray's native filtering methods + ann_fov_names = set(ann_df["fov_name"].unique()) + train_fovs = set(train_fovs) + + logger.debug(f"Dataset FOVs: {dataset['fov_name'].values}") + 
logger.debug(f"Annotation FOV names: {ann_fov_names}") + logger.debug(f"Train FOVs: {train_fovs}") + logger.debug(f"Dataset samples before filtering: {len(dataset.sample)}") + + # Filter and get only the intersection of train_fovs and ann_fov_names + common_fovs = train_fovs.intersection(ann_fov_names) + # missed out fovs in the dataset + missed_fovs = train_fovs - common_fovs + # missed out fovs in the annotations + missed_fovs_ann = ann_fov_names - common_fovs + + if len(common_fovs) == 0: + raise ValueError( + f"No common FOVs found between dataset and annotations: {train_fovs} not in {ann_fov_names}" + ) + elif len(missed_fovs) > 0: + warnings.warn( + f"No matching found for FOVs in the train dataset: {missed_fovs}" + ) + elif len(missed_fovs_ann) > 0: + warnings.warn( + f"No matching found for FOVs in the annotations: {missed_fovs_ann}" + ) + + logger.debug(f"Intersection of train_fovs and ann_fov_names: {common_fovs}") + + # Filter the dataset to only include the intersection of train_fovs and ann_fov_names + dataset = dataset.where( + dataset["fov_name"].isin(list(common_fovs)), drop=True + ) + + logger.debug(f"Dataset samples after filtering: {len(dataset.sample)}") + + # Load annotations without class mapping first + annotations = load_annotation(dataset, ann_path, annotation_key) + + # Check unique values before standardization + unique_vals = annotations.unique() + logger.debug(f"Original unique values: {unique_vals}") + + # Apply standardization mapping + standardized_annotations = annotations.copy() + if standardization_mapping: + for original_val, standard_val in standardization_mapping.items(): + mask = annotations == original_val + if mask.any(): + standardized_annotations[mask] = standard_val + logger.debug( + f"Mapped {original_val} -> {standard_val} ({mask.sum()} instances)" + ) + + # Check standardized values + std_unique_vals = standardized_annotations.unique() + logger.debug(f"Standardized unique values: {std_unique_vals}") + + # Convert to categorical for consistency + standardized_annotations = standardized_annotations.astype("category") + + # Keep features as xarray DataArray for compatibility with fit_logistic_regression + all_features.append(dataset["features"]) + all_annotations.append(standardized_annotations) + + logger.debug(f"Features shape: {dataset['features'].shape}") + logger.debug(f"Annotations shape: {standardized_annotations.shape}") + except Exception as e: + raise ValueError(f"Error loading dataset {emb_path}: {e}") + + # Combine all datasets + if all_features: + # Extract features and coordinates from each dataset + all_features_arrays = [] + all_coords = [] + + for dataset in all_features: + # Extract the features array + features_array = dataset["features"].values + all_features_arrays.append(features_array) + + # Extract coordinates + coords_dict = {} + for coord_name in dataset.coords: + if coord_name != "sample": # skip sample coordinate + coords_dict[coord_name] = dataset.coords[coord_name].values + all_coords.append(coords_dict) + + # Combine feature arrays + combined_features_array = np.concatenate(all_features_arrays, axis=0) + + # Combine coordinates (excluding 'features' from coordinates) + combined_coords = {} + for coord_name in all_coords[0].keys(): + if coord_name != "features": # Don't include 'features' in coordinates + coord_values = [] + for coords_dict in all_coords: + coord_values.extend(coords_dict[coord_name]) + combined_coords[coord_name] = coord_values + + # Create new combined dataset in the correct format + coords_dict = 
{ + "sample": range(len(combined_features_array)), + } + + # Add each coordinate as a 1D coordinate along the sample dimension + for coord_name, coord_values in combined_coords.items(): + coords_dict[coord_name] = ("sample", coord_values) + + combined_dataset = xr.Dataset( + { + "features": (("sample", "features"), combined_features_array), + }, + coords=coords_dict, + ) + + # Set the index properly like the original datasets + if "fov_name" in combined_coords: + available_coords = [ + coord + for coord in combined_coords.keys() + if coord in ["fov_name", "track_id", "t"] + ] + combined_dataset = combined_dataset.set_index(sample=available_coords) + + combined_annotations = pd.concat(all_annotations, ignore_index=True) + + logger.debug(f"Combined features shape: {combined_dataset['features'].shape}") + logger.debug(f"Combined annotations shape: {combined_annotations.shape}") + + # Final check of combined annotations + final_unique = combined_annotations.unique() + logger.debug(f"Final combined unique values: {final_unique}") + + return combined_dataset["features"], combined_annotations + + # %% # train linear classifier path_infection_embedding = Path( @@ -149,7 +425,7 @@ ) track_classes_infection = infection[infection["fov_name"] == fov[1:]] track_classes_infection = track_classes_infection[ - track_classes_infection["track_id"] == track + track_classes_infection["track_id"].isin(track) ]["infection_state"] # %% @@ -159,13 +435,17 @@ ) track_classes_division = division[division["fov_name"] == fov[1:]] track_classes_division = track_classes_division[ - track_classes_division["track_id"] == track + track_classes_division["track_id"].isin(track) ]["division"] # %% +# Loading the lineage images +img = [] for sample in dm.predict_dataloader(): - img = sample["anchor"].numpy() + img.append(sample["anchor"].numpy()) +img = np.concatenate(img, axis=0) +print(f"Loaded images with shape: {img.shape}") # %% img_tensor = torch.from_numpy(img).to(model.device) @@ -217,18 +497,15 @@ def clim_percentile(heatmap, low=1, high=99): np.concatenate([phase_heatmap_div, rfp_heatmap_div], axis=2), -g_lim, g_lim ) - # %% plt.style.use("./figure.mplstyle") selected_time_points = [3, 6, 15, 16] selected_div_states = [False] * 3 + [True] -sps = len(selected_time_points) - icefire = Colormap("icefire").to_mpl() -f, ax = plt.subplots(3, sps, figsize=(5.5, 3), layout="compressed") +f, ax = plt.subplots(3, len(selected_time_points), figsize=(5.5, 3), layout="compressed") for i, time in enumerate(selected_time_points): hpi = 3 + 0.5 * time prob = infection_probs[time].item() @@ -264,3 +541,110 @@ def clim_percentile(heatmap, low=1, high=99): ) # %% +# Create video animation of occlusion analysis +icefire = Colormap("icefire").to_mpl() +plt.style.use("./figure.mplstyle") + +fig, ax = plt.subplots(3, 1, figsize=(6, 8), layout="compressed") + +# Initialize plots +im1 = ax[0].imshow(img_render[0], cmap="gray") +ax[0].set_title("Original Image") +ax[0].axis("off") + +im2 = ax[1].imshow(inf_render[0], cmap=icefire, vmin=0, vmax=1) +ax[1].set_title("Infection Occlusion Attribution") +ax[1].axis("off") + +im3 = ax[2].imshow(div_render[0], cmap=icefire, vmin=0, vmax=1) +ax[2].set_title("Division Occlusion Attribution") +ax[2].axis("off") + +# Store initial border colors +for a in ax: + for spine in a.spines.values(): + spine.set_linewidth(3) + spine.set_color("black") + +# Add colorbar +norm = mpl.colors.Normalize(vmin=-g_lim, vmax=g_lim) +cbar = fig.colorbar( + mpl.cm.ScalarMappable(norm=norm, cmap=icefire), + ax=ax[1:], + 
orientation="horizontal", + shrink=0.8, + pad=0.1, +) +cbar.set_label("Occlusion Attribution") + + +# Animation function +def animate(frame): + time = frame + hpi = 3 + 0.5 * time + + # Update images + im1.set_array(img_render[time]) + im2.set_array(inf_render[time]) + im3.set_array(div_render[time]) + + # Update titles with probabilities + inf_prob = infection_probs[time].item() + div_prob = division_probs[time].item() + inf_binary = bool(track_classes_infection.iloc[time] - 1) + div_binary = bool(track_classes_division.iloc[time] - 1) + + # Color code labels - red for true, green for false + inf_color = "darkorange" if inf_binary else "blue" + div_color = "darkorange" if div_binary else "blue" + + # Make label text bold when true + inf_weight = "bold" if inf_binary else "normal" + div_weight = "bold" if div_binary else "normal" + + # Update border colors to highlight true labels + for spine in ax[1].spines.values(): + spine.set_color(inf_color) + spine.set_linewidth(4 if inf_binary else 2) + + for spine in ax[2].spines.values(): + spine.set_color(div_color) + spine.set_linewidth(4 if div_binary else 2) + + ax[0].set_title(f"Original Image - {hpi:.1f} HPI", fontsize=12, fontweight="bold") + ax[1].set_title( + f"Infection Attribution - Prob: {inf_prob:.3f} (Label: {str(inf_binary).lower()})", + fontsize=12, + fontweight=inf_weight, + color=inf_color, + ) + ax[2].set_title( + f"Division Attribution - Prob: {div_prob:.3f} (Label: {str(div_binary).lower()})", + fontsize=12, + fontweight=div_weight, + color=div_color, + ) + + return [im1, im2, im3] + +#%% + +# Create animation +anim = animation.FuncAnimation( + fig, animate, frames=len(img_render), interval=200, blit=True, repeat=True +) + +# Save as video +video_path = ( + Path.home() + / "mydata" + / "gdrive/2025_dynaCLR_paper/fig_manuscript_svg/figure_occlusion_analysis/figure_parts/occlusion_analysis_video.mp4" +) +video_path.parent.mkdir(parents=True, exist_ok=True) + +# Save as MP4 +Writer = animation.writers["ffmpeg"] +writer = Writer(fps=5, metadata=dict(artist="VisCy"), bitrate=1800) +anim.save(str(video_path), writer=writer) + +print(f"Video saved to: {video_path}") \ No newline at end of file diff --git a/applications/contrastive_phenotyping/figures/grad_attr_time.py b/applications/contrastive_phenotyping/figures/grad_attr_time.py deleted file mode 100644 index fbe585241..000000000 --- a/applications/contrastive_phenotyping/figures/grad_attr_time.py +++ /dev/null @@ -1,833 +0,0 @@ -# %% -import logging -import warnings -from pathlib import Path - -import matplotlib as mpl -import matplotlib.animation as animation -import matplotlib.pyplot as plt -import numpy as np -import pandas as pd -import torch -import xarray as xr -from cmap import Colormap -from lightning.pytorch import seed_everything -from skimage.exposure import rescale_intensity -from sklearn.metrics import ( - accuracy_score, - auc, - f1_score, - precision_recall_curve, - roc_auc_score, -) - -from viscy.data.triplet import TripletDataModule -from viscy.representation.embedding_writer import read_embedding_dataset -from viscy.representation.engine import ContrastiveEncoder, ContrastiveModule -from viscy.representation.evaluation import load_annotation -from viscy.representation.evaluation.lca import ( - AssembledClassifier, - fit_logistic_regression, - linear_from_binary_logistic_regression, -) -from viscy.transforms import NormalizeSampled, ScaleIntensityRangePercentilesd - -seed_everything(42, workers=True) - -logger = logging.getLogger(__name__) 
-logger.setLevel(logging.DEBUG) - - -# %% -# Dataset for display and occlusion analysis -data_path = "/hpc/projects/intracellular_dashboard/organelle_dynamics/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/registered_test.zarr" -tracks_path = "/hpc/projects/intracellular_dashboard/organelle_dynamics/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/track_test.zarr" -annotation_occlusion_infection_path = "/hpc/projects/intracellular_dashboard/organelle_dynamics/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/supervised_inf_pred/extracted_inf_state.csv" -annotation_occlusion_division_path = "/hpc/projects/intracellular_dashboard/organelle_dynamics/2024_02_04_A549_DENV_ZIKV_timelapse/9-lineage-cell-division/lineages_gt/cell_division_state_test_set.csv" -fov = "/B/4/8" -track = [44, 46] - -# %% -dm = TripletDataModule( - data_path=data_path, - tracks_path=tracks_path, - source_channel=["Phase3D", "RFP"], - z_range=[25, 40], - batch_size=48, - num_workers=0, - initial_yx_patch_size=(128, 128), - final_yx_patch_size=(128, 128), - normalizations=[ - NormalizeSampled( - keys=["Phase3D"], level="fov_statistics", subtrahend="mean", divisor="std" - ), - ScaleIntensityRangePercentilesd( - keys=["RFP"], lower=50, upper=99, b_min=0.0, b_max=1.0 - ), - ], - predict_cells=True, - include_fov_names=[fov] * len(track), - include_track_ids=track, -) -dm.setup("predict") -len(dm.predict_dataset) - -# %% -# load model -model = ContrastiveModule.load_from_checkpoint( - "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/time_interval/predict/epoch=178-step=16826.ckpt", - encoder=ContrastiveEncoder( - backbone="convnext_tiny", - in_channels=2, - in_stack_depth=15, - stem_kernel_size=(5, 4, 4), - stem_stride=(5, 4, 4), - embedding_dim=768, - projection_dim=32, - ), -).eval() - -# %% -# TODO add the patsh to the combination of sec61 and tomm20 -# train linear classifier -# INFECTION -## Embedding and Annotations - -path_infection_embedding_1 = Path( - "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/time_interval/predict/feb_test_time_interval_1_epoch_178.zarr" -) - -path_annotations_infection_1 = Path( - "/hpc/projects/intracellular_dashboard/organelle_dynamics/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/supervised_inf_pred/extracted_inf_state.csv" -) -# TOMM20 -path_infection_embedding_2 = Path( - "/hpc/projects/intracellular_dashboard/organelle_dynamics/rerun/2024_11_21_A549_TOMM20_DENV/4-phenotyping/1-predictions/sensor_160patch_99ckpt_max.zarr" -) -path_annotations_infection_2 = Path( - "/hpc/projects/intracellular_dashboard/organelle_dynamics/rerun/2024_11_21_A549_TOMM20_DENV/4-phenotyping/0-annotations/track_infection_annotation.csv" -) - -# SEC61 -path_infection_embedding_3 = Path( - "/hpc/projects/intracellular_dashboard/organelle_dynamics/rerun/2024_11_07_A549_SEC61_DENV/4-phenotyping/prediction_infection/2chan_192patch_100ckpt_timeAware_ntxent_rerun.zarr" -) - -path_annotations_infection_3 = Path( - "/hpc/projects/intracellular_dashboard/organelle_dynamics/rerun/2024_11_07_A549_SEC61_DENV/4-phenotyping/0-annotation/track_infection_annotation.csv" -) - -# CELL DIVISION -path_annotations_division_1 = Path( - "/hpc/projects/intracellular_dashboard/organelle_dynamics/2024_02_04_A549_DENV_ZIKV_timelapse/9-lineage-cell-division/lineages_gt/cell_division_state_test_set.csv" -) -path_division_embedding_1 = Path( - 
"/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/time_interval/predict/feb_test_time_interval_1_epoch_178_gt_tracks.zarr" -) -# TOMM20 -path_annotations_division_2 = Path( - "/hpc/projects/intracellular_dashboard/organelle_dynamics/rerun/2024_11_21_A549_TOMM20_DENV/4-phenotyping/0-annotations/track_cell_state_annotation.csv" -) -# SEC61 -path_annotations_division_3 = Path( - "/hpc/projects/intracellular_dashboard/organelle_dynamics/rerun/2024_11_07_A549_SEC61_DENV/4-phenotyping/0-annotation/track_cell_state_annotation.csv" -) -# %% -######### -# Make tuple of tuples of embedding and annotations - -# Train FOVs - use a broader set since we have multiple datasets - -infection_classifier_pairs = ( - ( - path_infection_embedding_1, - path_annotations_infection_1, - ["/A/3/7", "/A/3/8", "/A/3/9", "/B/4/6", "/B/4/7"], - ), - (path_infection_embedding_2, path_annotations_infection_2, "all"), - (path_infection_embedding_3, path_annotations_infection_3, "all"), -) - -# NOTE: embedding 1 and annotations 1 are not used. They are not wll annotated for division -division_classifier_pairs = ( - # ( - # path_division_embedding_1, - # path_annotations_division_1, - # ["/A/3/7", "/A/3/8", "/A/3/9", "/B/4/6", "/B/4/7"], - # ), - # (path_infection_embedding_2, path_annotations_division_2, "all"), - (path_infection_embedding_3, path_annotations_division_3, "all"), -) - - -def load_and_combine_datasets( - datasets, - target_type="infection", - standardization_mapping=None, -): - """Load and combine multiple embedding datasets with their annotations. - - Parameters - ---------- - datasets : list of tuple - List of (embedding_path, annotation_path, train_fovs) tuples containing - paths to embedding files, annotation CSV files, and training FOVs. - target_type : str, default='infection' - Type of classification target. Either 'infection' or 'division' - determines - which column to look for in the annotation files. - standardization_mapping : dict, optional - Dictionary to standardize different annotation formats across datasets. - Maps original values to standardized values. - Example: {'infected': 2, 'uninfected': 1, 'background': 0, - 2.0: 2, 1.0: 1, 0.0: 0, 'mitosis': 2, 'interphase': 1, 'unknown': 0} - - Returns - ------- - combined_features : xarray.DataArray - Combined feature embeddings from all successfully loaded datasets. - combined_annotations : pandas.Series - Combined and standardized annotations from all datasets. - - Raises - ------ - ValueError - If no datasets were successfully loaded. 
- """ - - all_features = [] - all_annotations = [] - - # Default standardization mappings - if standardization_mapping is None: - if target_type == "infection": - standardization_mapping = { - # String formats - "infected": 2, - "uninfected": 1, - "background": 0, - "unknown": 0, - # Numeric formats - 2.0: 2, - 1.0: 1, - 0.0: 0, - 2: 2, - 1: 1, - 0: 0, - } - elif target_type == "division": - standardization_mapping = { - # String formats - "mitosis": 2, - "interphase": 1, - "unknown": 0, - # Numeric formats - 2.0: 2, - 1.0: 1, - 0.0: 0, - 2: 2, - 1: 1, - 0: 0, - } - - for emb_path, ann_path, train_fovs in datasets: - try: - logger.debug(f"Loading dataset: {emb_path}") - dataset = read_embedding_dataset(emb_path) - - # Read annotation CSV to detect column names - logger.debug(f"Reading annotation CSV: {ann_path}") - ann_df = pd.read_csv(ann_path) - # make sure the ann_fov_names start with '/' otherwise add it, and strip whitespace - ann_df["fov_name"] = ann_df["fov_name"].apply( - lambda x: ( - "/" + x.strip() if not x.strip().startswith("/") else x.strip() - ) - ) - - if train_fovs == "all": - train_fovs = np.unique(dataset["fov_name"]) - - # Auto-detect annotation column based on target_type - annotation_key = None - if target_type == "infection": - for col in [ - "infection_state", - "infection", - "infection_status", - ]: - if col in ann_df.columns: - annotation_key = col - break - - elif target_type == "division": - for col in ["division", "cell_division", "cell_state"]: - if col in ann_df.columns: - annotation_key = col - break - - if annotation_key is None: - print(f" No {target_type} column found, skipping...") - continue - - # Filter the dataset to only include the FOVs in the annotation - # Use xarray's native filtering methods - ann_fov_names = set(ann_df["fov_name"].unique()) - train_fovs = set(train_fovs) - - logger.debug(f"Dataset FOVs: {dataset['fov_name'].values}") - logger.debug(f"Annotation FOV names: {ann_fov_names}") - logger.debug(f"Train FOVs: {train_fovs}") - logger.debug(f"Dataset samples before filtering: {len(dataset.sample)}") - - # Filter and get only the intersection of train_fovs and ann_fov_names - common_fovs = train_fovs.intersection(ann_fov_names) - # missed out fovs in the dataset - missed_fovs = train_fovs - common_fovs - # missed out fovs in the annotations - missed_fovs_ann = ann_fov_names - common_fovs - - if len(common_fovs) == 0: - raise ValueError( - f"No common FOVs found between dataset and annotations: {train_fovs} not in {ann_fov_names}" - ) - elif len(missed_fovs) > 0: - warnings.warn( - f"No matching found for FOVs in the train dataset: {missed_fovs}" - ) - elif len(missed_fovs_ann) > 0: - warnings.warn( - f"No matching found for FOVs in the annotations: {missed_fovs_ann}" - ) - - logger.debug(f"Intersection of train_fovs and ann_fov_names: {common_fovs}") - - # Filter the dataset to only include the intersection of train_fovs and ann_fov_names - dataset = dataset.where( - dataset["fov_name"].isin(list(common_fovs)), drop=True - ) - - logger.debug(f"Dataset samples after filtering: {len(dataset.sample)}") - - # Load annotations without class mapping first - annotations = load_annotation(dataset, ann_path, annotation_key) - - # Check unique values before standardization - unique_vals = annotations.unique() - logger.debug(f"Original unique values: {unique_vals}") - - # Apply standardization mapping - standardized_annotations = annotations.copy() - if standardization_mapping: - for original_val, standard_val in standardization_mapping.items(): - 
mask = annotations == original_val - if mask.any(): - standardized_annotations[mask] = standard_val - logger.debug( - f"Mapped {original_val} -> {standard_val} ({mask.sum()} instances)" - ) - - # Check standardized values - std_unique_vals = standardized_annotations.unique() - logger.debug(f"Standardized unique values: {std_unique_vals}") - - # Convert to categorical for consistency - standardized_annotations = standardized_annotations.astype("category") - - # Keep features as xarray DataArray for compatibility with fit_logistic_regression - all_features.append(dataset["features"]) - all_annotations.append(standardized_annotations) - - logger.debug(f"Features shape: {dataset['features'].shape}") - logger.debug(f"Annotations shape: {standardized_annotations.shape}") - except Exception as e: - raise ValueError(f"Error loading dataset {emb_path}: {e}") - - # Combine all datasets - if all_features: - # Extract features and coordinates from each dataset - all_features_arrays = [] - all_coords = [] - - for dataset in all_features: - # Extract the features array - features_array = dataset["features"].values - all_features_arrays.append(features_array) - - # Extract coordinates - coords_dict = {} - for coord_name in dataset.coords: - if coord_name != "sample": # skip sample coordinate - coords_dict[coord_name] = dataset.coords[coord_name].values - all_coords.append(coords_dict) - - # Combine feature arrays - combined_features_array = np.concatenate(all_features_arrays, axis=0) - - # Combine coordinates (excluding 'features' from coordinates) - combined_coords = {} - for coord_name in all_coords[0].keys(): - if coord_name != "features": # Don't include 'features' in coordinates - coord_values = [] - for coords_dict in all_coords: - coord_values.extend(coords_dict[coord_name]) - combined_coords[coord_name] = coord_values - - # Create new combined dataset in the correct format - coords_dict = { - "sample": range(len(combined_features_array)), - } - - # Add each coordinate as a 1D coordinate along the sample dimension - for coord_name, coord_values in combined_coords.items(): - coords_dict[coord_name] = ("sample", coord_values) - - combined_dataset = xr.Dataset( - { - "features": (("sample", "features"), combined_features_array), - }, - coords=coords_dict, - ) - - # Set the index properly like the original datasets - if "fov_name" in combined_coords: - available_coords = [ - coord - for coord in combined_coords.keys() - if coord in ["fov_name", "track_id", "t"] - ] - combined_dataset = combined_dataset.set_index(sample=available_coords) - - combined_annotations = pd.concat(all_annotations, ignore_index=True) - - logger.debug(f"Combined features shape: {combined_dataset['features'].shape}") - logger.debug(f"Combined annotations shape: {combined_annotations.shape}") - - # Final check of combined annotations - final_unique = combined_annotations.unique() - logger.debug(f"Final combined unique values: {final_unique}") - - return combined_dataset["features"], combined_annotations - - -# %% - -# Load and combine infection datasets -logger.info("Loading infection classification datasets...") -infection_features, infection_labels = load_and_combine_datasets( - infection_classifier_pairs, - target_type="infection", -) -# %% -# Load and combine division datasets -logger.info("Loading division classification datasets...") -division_features, division_labels = load_and_combine_datasets( - division_classifier_pairs, - target_type="division", -) - - -# %% - -logistic_regression_infection, _ = 
fit_logistic_regression( - features=infection_features.copy(), - annotations=infection_labels.copy(), - train_ratio=0.8, - remove_background_class=True, - scale_features=True, - class_weight="balanced", - solver="liblinear", - random_state=42, -) -# %% - -logistic_regression_division, _ = fit_logistic_regression( - division_features.copy(), - division_labels.copy(), - train_ratio=0.8, - remove_background_class=True, - scale_features=True, - class_weight="balanced", - solver="liblinear", - random_state=42, -) - -# %% -linear_classifier_infection = linear_from_binary_logistic_regression( - logistic_regression_infection -) -assembled_classifier_infection = ( - AssembledClassifier(model.model, linear_classifier_infection) - .eval() - .to(model.device) -) - -# %% -linear_classifier_division = linear_from_binary_logistic_regression( - logistic_regression_division -) -assembled_classifier_division = ( - AssembledClassifier(model.model, linear_classifier_division).eval().to(model.device) -) - - -# %% -# Loading the lineage images -img = [] -for sample in dm.predict_dataloader(): - img.append(sample["anchor"].numpy()) -img = np.concatenate(img, axis=0) -print(img.shape) - -# %% -img_tensor = torch.from_numpy(img).to(model.device) - -with torch.inference_mode(): - infection_probs = assembled_classifier_infection(img_tensor).sigmoid() - division_probs = assembled_classifier_division(img_tensor).sigmoid() - -# %% -attr_kwargs = dict( - img=img_tensor, - sliding_window_shapes=(1, 15, 12, 12), - strides=(1, 15, 4, 4), - show_progress=True, -) - - -infection_attribution = ( - assembled_classifier_infection.attribute_occlusion(**attr_kwargs).cpu().numpy() -) -division_attribution = ( - assembled_classifier_division.attribute_occlusion(**attr_kwargs).cpu().numpy() -) - - -# %% -def clip_rescale(img, low, high): - return rescale_intensity(img.clip(low, high), out_range=(0, 1)) - - -def clim_percentile(heatmap, low=1, high=99): - lo, hi = np.percentile(heatmap, (low, high)) - return clip_rescale(heatmap, lo, hi) - - -g_lim = 1 -z_slice = 5 -phase = clim_percentile(img[:, 0, z_slice]) -rfp = clim_percentile(img[:, 1, z_slice]) -img_render = np.concatenate([phase, rfp], axis=2) -phase_heatmap_inf = infection_attribution[:, 0, z_slice] -rfp_heatmap_inf = infection_attribution[:, 1, z_slice] -inf_render = clip_rescale( - np.concatenate([phase_heatmap_inf, rfp_heatmap_inf], axis=2), -g_lim, g_lim -) -phase_heatmap_div = division_attribution[:, 0, z_slice] -rfp_heatmap_div = division_attribution[:, 1, z_slice] -div_render = clip_rescale( - np.concatenate([phase_heatmap_div, rfp_heatmap_div], axis=2), -g_lim, g_lim -) - - -# %% -# Filter the dataframe to only include the fovs and track_id of the current fov -infection = pd.read_csv(annotation_occlusion_infection_path) -infection = infection[infection["fov_name"] == fov[1:]] -infection = infection[infection["track_id"].isin(track)] -track_classes_infection = infection["infection_state"] - -# load division annotations -division = pd.read_csv(annotation_occlusion_division_path) -division = division[division["fov_name"] == fov[1:]] -division = division[division["track_id"].isin(track)] - -division["division"] = 1 # default: not dividing -division.loc[division["t"].between(16, 22, inclusive="both"), "division"] = ( - 2 # dividing for t in 16-20 -) - -track_classes_division = division["division"] - - -# %% -plt.style.use("./figure.mplstyle") - -all_time_points = list(range(len(img_render))) -selected_time_points = all_time_points[ - :: max(1, len(all_time_points) // 8) 
-] # Show up to 8 time points - - -sps = len(selected_time_points) - -icefire = Colormap("icefire").to_mpl() - -f, ax = plt.subplots(3, sps, figsize=(2 * sps, 3), layout="compressed") -for i, time in enumerate(selected_time_points): - hpi = 3 + 0.5 * time - prob = infection_probs[time].item() - inf_binary = str(bool(track_classes_infection.iloc[time] - 1)).lower() - div_binary = str(bool(track_classes_division.iloc[time] - 1)).lower() - ax[0, i].imshow(img_render[time], cmap="gray") - ax[0, i].set_title(f"{hpi} HPI") - ax[1, i].imshow(inf_render[time], cmap=icefire, vmin=0, vmax=1) - ax[1, i].set_title( - f"infected: {prob:.3f}\n" f"label: {inf_binary}", - ) - ax[2, i].imshow(div_render[time], cmap=icefire, vmin=0, vmax=1) - ax[2, i].set_title( - f"dividing: {division_probs[time].item():.3f}\n" f"label: {div_binary}", - ) -for a in ax.ravel(): - a.axis("off") -norm = mpl.colors.Normalize(vmin=-g_lim, vmax=g_lim) -cbar = f.colorbar( - mpl.cm.ScalarMappable(norm=norm, cmap=icefire), - orientation="vertical", - ax=ax[1:].ravel().tolist(), - format=mpl.ticker.StrMethodFormatter("{x:.1f}"), -) -cbar.set_label("occlusion attribution") - -# %% -# f.savefig( -# Path.home() -# / "mydata" -# / "gdrive/publications/dynaCLR/2025_dynaCLR_paper/fig_manuscript_svg/figure_occlusion_analysis/figure_parts/fig_explanation_patch12_stride4.pdf", -# dpi=300, -# ) - -# %% -# Create video animation of occlusion analysis -icefire = Colormap("icefire").to_mpl() -plt.style.use("./figure.mplstyle") - -fig, ax = plt.subplots(3, 1, figsize=(6, 8), layout="compressed") - -# Initialize plots -im1 = ax[0].imshow(img_render[0], cmap="gray") -ax[0].set_title("Original Image") -ax[0].axis("off") - -im2 = ax[1].imshow(inf_render[0], cmap=icefire, vmin=0, vmax=1) -ax[1].set_title("Infection Occlusion Attribution") -ax[1].axis("off") - -im3 = ax[2].imshow(div_render[0], cmap=icefire, vmin=0, vmax=1) -ax[2].set_title("Division Occlusion Attribution") -ax[2].axis("off") - -# Store initial border colors -for a in ax: - for spine in a.spines.values(): - spine.set_linewidth(3) - spine.set_color("black") - -# Add colorbar -norm = mpl.colors.Normalize(vmin=-g_lim, vmax=g_lim) -cbar = fig.colorbar( - mpl.cm.ScalarMappable(norm=norm, cmap=icefire), - ax=ax[1:], - orientation="horizontal", - shrink=0.8, - pad=0.1, -) -cbar.set_label("Occlusion Attribution") - - -# Animation function -def animate(frame): - time = frame - hpi = 3 + 0.5 * time - - # Update images - im1.set_array(img_render[time]) - im2.set_array(inf_render[time]) - im3.set_array(div_render[time]) - - # Update titles with probabilities - inf_prob = infection_probs[time].item() - div_prob = division_probs[time].item() - inf_binary = bool(track_classes_infection.iloc[time] - 1) - div_binary = bool(track_classes_division.iloc[time] - 1) - - # Color code labels - red for true, green for false - inf_color = "darkorange" if inf_binary else "blue" - div_color = "darkorange" if div_binary else "blue" - - # Make label text bold when true - inf_weight = "bold" if inf_binary else "normal" - div_weight = "bold" if div_binary else "normal" - - # Update border colors to highlight true labels - for spine in ax[1].spines.values(): - spine.set_color(inf_color) - spine.set_linewidth(4 if inf_binary else 2) - - for spine in ax[2].spines.values(): - spine.set_color(div_color) - spine.set_linewidth(4 if div_binary else 2) - - ax[0].set_title(f"Original Image - {hpi:.1f} HPI", fontsize=12, fontweight="bold") - ax[1].set_title( - f"Infection Attribution - Prob: {inf_prob:.3f} (Label: 
{str(inf_binary).lower()})", - fontsize=12, - fontweight=inf_weight, - color=inf_color, - ) - ax[2].set_title( - f"Division Attribution - Prob: {div_prob:.3f} (Label: {str(div_binary).lower()})", - fontsize=12, - fontweight=div_weight, - color=div_color, - ) - - return [im1, im2, im3] - - -# Create animation -anim = animation.FuncAnimation( - fig, animate, frames=len(img_render), interval=200, blit=True, repeat=True -) - -# %% -# Save as video -video_path = ( - Path.home() - / "mydata" - / "gdrive/2025_dynaCLR_paper/fig_manuscript_svg/figure_occlusion_analysis/figure_parts/occlusion_analysis_video.mp4" -) -video_path.parent.mkdir(parents=True, exist_ok=True) - -# Save as MP4 -Writer = animation.writers["ffmpeg"] -writer = Writer(fps=5, metadata=dict(artist="VisCy"), bitrate=1800) -anim.save(str(video_path), writer=writer) - -print(f"Video saved to: {video_path}") - - -# %% -# Performance metrics over time -def calculate_metrics_over_time(y_true, y_pred_probs, threshold=0.5): - """Calculate accuracy, F1, and AUC for each time point""" - y_pred = (y_pred_probs > threshold).astype(int) - - metrics = {"accuracy": [], "f1": [], "auc": []} - - for i in range(len(y_true)): - # Get predictions up to current time point - true_up_to_i = y_true[: i + 1] - pred_up_to_i = y_pred[: i + 1] - prob_up_to_i = y_pred_probs[: i + 1] - - # Skip if we don't have both classes - if len(np.unique(true_up_to_i)) < 2: - metrics["accuracy"].append(np.nan) - metrics["f1"].append(np.nan) - metrics["auc"].append(np.nan) - continue - - # Calculate metrics - acc = accuracy_score(true_up_to_i, pred_up_to_i) - f1 = f1_score(true_up_to_i, pred_up_to_i, average="binary") - try: - auc_score = roc_auc_score(true_up_to_i, prob_up_to_i) - except: - auc_score = np.nan - - metrics["accuracy"].append(acc) - metrics["f1"].append(f1) - metrics["auc"].append(auc_score) - - return metrics - - -# Ensure we have matching lengths - use the minimum length -min_length = min( - len(track_classes_infection), len(track_classes_division), len(infection_probs) -) - -# Convert labels to binary for metrics calculation - truncate to min_length -inf_true = (track_classes_infection.values[:min_length] - 1).astype(bool).astype(int) -div_true = track_classes_division.values[:min_length].astype(bool).astype(int) - -inf_probs = infection_probs[:min_length].cpu().numpy() -div_probs = division_probs[:min_length].cpu().numpy() - -print(f"Using {min_length} time points for metrics calculation") -print(f"Infection labels shape: {inf_true.shape}") -print(f"Division labels shape: {div_true.shape}") -print(f"Infection probs shape: {inf_probs.shape}") -print(f"Division probs shape: {div_probs.shape}") - -# Calculate metrics -inf_metrics = calculate_metrics_over_time(inf_true, inf_probs) -div_metrics = calculate_metrics_over_time(div_true, div_probs) - -# Time points -time_points = np.arange(len(inf_true)) -hpi_values = 3 + 0.5 * time_points - -# Create metrics plot -fig, axes = plt.subplots(2, 3, figsize=(15, 8), layout="compressed") - -# Infection metrics -axes[0, 0].plot( - hpi_values, inf_metrics["accuracy"], "b-", linewidth=2, label="Accuracy" -) -axes[0, 0].set_title("Infection Classification Accuracy Over Time") -axes[0, 0].set_xlabel("Hours Post Infection (HPI)") -axes[0, 0].set_ylabel("Accuracy") -axes[0, 0].grid(True, alpha=0.3) -axes[0, 0].set_ylim(0, 1) - -axes[0, 1].plot(hpi_values, inf_metrics["f1"], "g-", linewidth=2, label="F1 Score") -axes[0, 1].set_title("Infection Classification F1 Score Over Time") -axes[0, 1].set_xlabel("Hours Post Infection 
(HPI)") -axes[0, 1].set_ylabel("F1 Score") -axes[0, 1].grid(True, alpha=0.3) -axes[0, 1].set_ylim(0, 1) - -axes[0, 2].plot(hpi_values, inf_metrics["auc"], "r-", linewidth=2, label="AUC") -axes[0, 2].set_title("Infection Classification AUC Over Time") -axes[0, 2].set_xlabel("Hours Post Infection (HPI)") -axes[0, 2].set_ylabel("AUC") -axes[0, 2].grid(True, alpha=0.3) -axes[0, 2].set_ylim(0, 1) - -# Division metrics -axes[1, 0].plot( - hpi_values, div_metrics["accuracy"], "b-", linewidth=2, label="Accuracy" -) -axes[1, 0].set_title("Division Classification Accuracy Over Time") -axes[1, 0].set_xlabel("Hours Post Infection (HPI)") -axes[1, 0].set_ylabel("Accuracy") -axes[1, 0].grid(True, alpha=0.3) -axes[1, 0].set_ylim(0, 1) - -axes[1, 1].plot(hpi_values, div_metrics["f1"], "g-", linewidth=2, label="F1 Score") -axes[1, 1].set_title("Division Classification F1 Score Over Time") -axes[1, 1].set_xlabel("Hours Post Infection (HPI)") -axes[1, 1].set_ylabel("F1 Score") -axes[1, 1].grid(True, alpha=0.3) -axes[1, 1].set_ylim(0, 1) - -axes[1, 2].plot(hpi_values, div_metrics["auc"], "r-", linewidth=2, label="AUC") -axes[1, 2].set_title("Division Classification AUC Over Time") -axes[1, 2].set_xlabel("Hours Post Infection (HPI)") -axes[1, 2].set_ylabel("AUC") -axes[1, 2].grid(True, alpha=0.3) -axes[1, 2].set_ylim(0, 1) - -plt.tight_layout() - -# %% -# Save metrics plot -metrics_path = ( - Path.home() - / "mydata" - / "gdrive/2025_dynaCLR_paper/fig_manuscript_svg/figure_occlusion_analysis/figure_parts/performance_metrics_over_time.pdf" -) -fig.savefig(str(metrics_path), dpi=300, bbox_inches="tight") -print(f"Metrics plot saved to: {metrics_path}") From 221cbeb992b8376daf911f2be82d5c29bfe7b941 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Wed, 17 Sep 2025 15:23:56 -0700 Subject: [PATCH 063/101] refactoring the betavaemodule. 
removing the hyperparameter logging, adding the nn.Module as an input
 for typing purposes, and removing the fp32 custom forward
---
 viscy/representation/engine.py | 198 ++------------------------------
 1 file changed, 12 insertions(+), 186 deletions(-)

diff --git a/viscy/representation/engine.py b/viscy/representation/engine.py
index fa66528e7..c2d2d3ec5 100644
--- a/viscy/representation/engine.py
+++ b/viscy/representation/engine.py
@@ -17,11 +17,6 @@
 
 _logger = logging.getLogger("lightning.pytorch")
 
-_VAE_ARCHITECTURE = {
-    "2.5D": BetaVae25D,
-    "monai_beta": BetaVaeMonai,
-}
-
 
 class ContrastivePrediction(TypedDict):
     features: Tensor
@@ -59,50 +54,7 @@ def __init__(
         self.log_embeddings = log_embeddings
         self.embedding_log_frequency = embedding_log_frequency
 
-    def on_train_start(self) -> None:
-        """Log comprehensive hyperparameters including model architecture details."""
-        super().on_train_start()
-
-        # Collect comprehensive hyperparameters
-        hparams = {
-            # Training hyperparameters
-            "lr": self.lr,
-            "schedule": self.schedule,
-            "input_shape": self.example_input_array,
-            "loss_function_class": self.loss_function.__class__.__name__,
-        }
-
-        # Add loss function specific parameters
-        if hasattr(self.loss_function, "margin"):
-            hparams["loss_margin"] = self.loss_function.margin
-        if hasattr(self.loss_function, "temperature"):
-            hparams["loss_temperature"] = self.loss_function.temperature
-        if hasattr(self.loss_function, "normalize_embeddings"):
-            hparams["loss_normalize_embeddings"] = (
-                self.loss_function.normalize_embeddings
-            )
-
-        # Add encoder details if it's a ContrastiveEncoder
-        if hasattr(self.model, "backbone"):
-            hparams["encoder_backbone"] = self.model.backbone
-        if hasattr(self.model, "in_channels"):
-            hparams["encoder_in_channels"] = self.model.in_channels
-        if hasattr(self.model, "in_stack_depth"):
-            hparams["encoder_in_stack_depth"] = self.model.in_stack_depth
-        if hasattr(self.model, "embedding_dim"):
-            hparams["encoder_embedding_dim"] = self.model.embedding_dim
-        if hasattr(self.model, "projection_dim"):
-            hparams["encoder_projection_dim"] = self.model.projection_dim
-        if hasattr(self.model, "drop_path_rate"):
-            hparams["encoder_drop_path_rate"] = self.model.drop_path_rate
-        if hasattr(self.model, "stem_kernel_size"):
-            hparams["encoder_stem_kernel_size"] = str(self.model.stem_kernel_size)
-        if hasattr(self.model, "stem_stride"):
-            hparams["encoder_stem_stride"] = str(self.model.stem_stride)
-
-        # Log to TensorBoard
-        if self.logger is not None:
-            self.logger.log_hyperparams(hparams)
+        self.save_hyperparameters()
 
     def forward(self, x: Tensor) -> tuple[Tensor, Tensor]:
         """Return both features and projections.
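
Lightning's save_hyperparameters() records the arguments passed to __init__ and exposes them as self.hparams (and in checkpoints), which is what the deleted on_train_start block assembled by hand. A minimal sketch of the pattern this patch adopts, with a hypothetical ToyVae standing in for BetaVae25D / BetaVaeMonai:

    import torch.nn as nn
    from lightning.pytorch import LightningModule

    class ToyVae(nn.Module):
        """Illustrative placeholder for an injected VAE backbone."""

        def __init__(self, latent_dim: int = 16):
            super().__init__()
            self.latent_dim = latent_dim
            self.net = nn.Linear(8, latent_dim)

    class ToyModule(LightningModule):
        def __init__(self, vae: nn.Module, lr: float = 1e-3):
            super().__init__()
            # Captures init arguments into self.hparams; module-type arguments
            # are commonly excluded so checkpoints stay lean.
            self.save_hyperparameters(ignore=["vae"])
            self.model = vae
            self.lr = lr

    module = ToyModule(vae=ToyVae(latent_dim=32))
    print(module.hparams.lr)  # 0.001

Injecting the constructed module directly, instead of an architecture string plus a config dict, also lets attributes such as latent_dim be read off the model itself, which is what the hunks below switch to.
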
@@ -270,118 +222,8 @@ def validation_step(self, batch: TripletSample, batch_idx: int) -> Tensor: def on_validation_epoch_end(self) -> None: super().on_validation_epoch_end() self._log_samples("val_samples", self.validation_step_outputs) - - # Log UMAP embeddings from validation set every N epochs - if ( - self.log_embeddings - and self.current_epoch % self.embedding_log_frequency == 0 - and self.current_epoch > 0 - ): - self._collect_and_log_embeddings() - self.validation_step_outputs = [] - def _collect_and_log_embeddings(self): - """Collect embeddings from validation dataloader and log UMAP visualization.""" - try: - # Get validation dataloader - val_dataloaders = self.trainer.val_dataloaders - if val_dataloaders is None: - _logger.warning( - "No validation dataloader available for embedding logging" - ) - return - elif isinstance(val_dataloaders, list): - val_dataloader = val_dataloaders[0] if val_dataloaders else None - else: - val_dataloader = val_dataloaders - - if val_dataloader is None: - _logger.warning( - "No validation dataloader available for embedding logging" - ) - return - - _logger.info( - f"Collecting embeddings for visualization at epoch {self.current_epoch}" - ) - - # Collect embeddings, images, and metadata from validation set - embeddings_list = [] - images_list = [] - labels_list = [] - max_samples = 500 # Reduced for memory efficiency with images - sample_count = 0 - - self.eval() - with torch.no_grad(): - for batch in val_dataloader: - if sample_count >= max_samples: - break - - # Move batch to device - anchor = batch["anchor"].to(self.device) - batch_size = anchor.size(0) - - # Get embeddings (features, not projections) - features, _ = self(anchor) - embeddings_list.append(features.cpu()) - - # Collect images for sprite visualization - # Take middle slice for 3D data and first channel if multi-channel - if anchor.ndim == 5: # (B, C, D, H, W) - mid_z = anchor.size(2) // 2 - img_slice = anchor[:, 0, mid_z].cpu() # (B, H, W) - else: # (B, C, H, W) - img_slice = anchor[:, 0].cpu() # (B, H, W) - images_list.append(img_slice) - - # Collect labels from index information - if "index" in batch and batch["index"] is not None: - for i, idx_info in enumerate(batch["index"][:batch_size]): - if isinstance(idx_info, dict): - # Create label from track_id and time info - track_id = idx_info.get("track_id", "unknown") - t = idx_info.get("t", "unknown") - labels_list.append(f"track_{track_id}_t_{t}") - else: - labels_list.append(f"sample_{sample_count + i}") - else: - # Fallback labels - for i in range(batch_size): - labels_list.append(f"sample_{sample_count + i}") - - sample_count += batch_size - - if embeddings_list: - embeddings = torch.cat(embeddings_list, dim=0)[:max_samples] - images = torch.cat(images_list, dim=0)[:max_samples] - labels = labels_list[:max_samples] - - # Normalize images for visualization (0-1 range) - images = (images - images.min()) / (images.max() - images.min() + 1e-8) - - # Log UMAP visualization - self.log_embedding_umap(embeddings, tag="validation") - - # Log to TensorBoard's embedding projector with images and labels - self.logger.experiment.add_embedding( - embeddings, - metadata=labels, - label_img=images.unsqueeze(1), # Add channel dimension - global_step=self.current_epoch, - tag="validation_embeddings", - ) - - _logger.info( - f"Logged {len(embeddings)} embeddings with images and labels" - ) - else: - _logger.warning("No embeddings collected from validation set") - - except Exception as e: - _logger.error(f"Error collecting embeddings: {e}") - def 
configure_optimizers(self): optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr) return optimizer @@ -401,8 +243,7 @@ def predict_step( class BetaVaeModule(LightningModule): def __init__( self, - architecture: Literal["monai_beta", "2.5D"], - model_config: dict = {}, + vae: nn.Module | BetaVae25D | BetaVaeMonai, loss_function: nn.Module | nn.MSELoss = nn.MSELoss(reduction="sum"), beta: float = 1.0, beta_schedule: Literal["linear", "cosine", "warmup"] | None = None, @@ -420,15 +261,7 @@ def __init__( ): super().__init__() - net_class = _VAE_ARCHITECTURE.get(architecture) - if not net_class: - raise ValueError( - f"Architecture {architecture} not in {_VAE_ARCHITECTURE.keys()}" - ) - - self.model = net_class(**model_config) - self.model_config = model_config - self.architecture = architecture + self.model = vae self.loss_function = loss_function self.beta = beta @@ -458,10 +291,12 @@ def __init__( # Handle different parameter names for latent dimensions latent_dim = None - if "latent_dim" in self.model_config: - latent_dim = self.model_config["latent_dim"] - elif "latent_size" in self.model_config: - latent_dim = self.model_config["latent_size"] + if hasattr(self.model, 'latent_dim'): + latent_dim = self.model.latent_dim + elif hasattr(self.model, 'latent_size'): + latent_dim = self.model.latent_size + elif hasattr(self.model, 'encoder') and hasattr(self.model.encoder, 'latent_dim'): + latent_dim = self.model.encoder.latent_dim if latent_dim is not None: self.vae_logger = BetaVaeLogger(latent_dim=latent_dim) @@ -517,14 +352,14 @@ def _get_current_beta(self) -> float: else: return max(self.beta, self._min_beta) - @torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float32) def forward(self, x: Tensor) -> dict: """Forward pass through Beta-VAE.""" original_shape = x.shape is_monai_2d = ( - self.architecture == "monai_beta" - and self.model_config.get("spatial_dims") == 2 + isinstance(self.model, BetaVaeMonai) + and hasattr(self.model, 'spatial_dims') + and self.model.spatial_dims == 2 ) if is_monai_2d and len(x.shape) == 5 and x.shape[2] == 1: x = x.squeeze(2) @@ -676,7 +511,6 @@ def on_validation_epoch_end(self) -> None: self._log_samples("val_reconstructions", self.validation_step_outputs) self.validation_step_outputs = [] - # Compute disentanglement metrics periodically if ( self.compute_disentanglement and self.current_epoch % self.disentanglement_frequency == 0 @@ -694,7 +528,6 @@ def on_validation_epoch_end(self) -> None: def _compute_and_log_disentanglement_metrics(self): """Compute and log disentanglement metrics.""" try: - # Get validation dataloader - handle both single DataLoader and list cases val_dataloaders = self.trainer.val_dataloaders if val_dataloaders is None: val_dataloader = None @@ -709,7 +542,6 @@ def _compute_and_log_disentanglement_metrics(self): ) return - # Use the logger's disentanglement metrics method self.vae_logger.log_disentanglement_metrics( lightning_module=self, dataloader=val_dataloader, @@ -722,7 +554,6 @@ def _compute_and_log_disentanglement_metrics(self): def _log_enhanced_visualizations(self): """Log enhanced β-VAE visualizations.""" try: - # Get validation dataloader - handle both single DataLoader and list cases val_dataloaders = self.trainer.val_dataloaders if val_dataloaders is None: val_dataloader = None @@ -739,17 +570,12 @@ def _log_enhanced_visualizations(self): f"Logging enhanced β-VAE visualizations at epoch {self.current_epoch}" ) - # Log latent traversals -for how recons change when moving along a latent dim 
self.vae_logger.log_latent_traversal( lightning_module=self, n_dims=8, n_steps=11 ) - - # Log latent interpolations - smooth transitions between different data points in the latent space self.vae_logger.log_latent_interpolation( lightning_module=self, n_pairs=3, n_steps=11 ) - - # Log factor traversal matrix - grid visualization how each dim affects the recon self.vae_logger.log_factor_traversal_matrix( lightning_module=self, n_dims=8, n_steps=7 ) From 7a9254863d0e0cc97c230857b3f17b33f1a64ff7 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Wed, 17 Sep 2025 15:24:07 -0700 Subject: [PATCH 064/101] remove the optuna dependency --- pyproject.toml | 4 ---- 1 file changed, 4 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 8c36efbbe..281d4211c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,10 +45,6 @@ metrics = [ phate = [ "phate", ] -optimization = [ - "optuna", - "optuna-dashboard", -] examples = ["napari", "jupyter", "jupytext", "transformers>=4.51.3"] visual = [ From 06c7ea297b5a4c9f52ee8b491ce3705db44c1d86 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Wed, 17 Sep 2025 15:25:28 -0700 Subject: [PATCH 065/101] deleting old msd test --- .../tests/test_distance.py | 958 ------------------ 1 file changed, 958 deletions(-) delete mode 100644 applications/contrastive_phenotyping/tests/test_distance.py diff --git a/applications/contrastive_phenotyping/tests/test_distance.py b/applications/contrastive_phenotyping/tests/test_distance.py deleted file mode 100644 index 448e25fc7..000000000 --- a/applications/contrastive_phenotyping/tests/test_distance.py +++ /dev/null @@ -1,958 +0,0 @@ -# %% -from typing import Literal - -import matplotlib.pyplot as plt -import numpy as np -import pandas as pd -import xarray as xr -from scipy import stats - -from viscy.representation.evaluation.distance import ( - compute_msd, -) - - -def generate_directional_embeddings_corrected( - n_timepoints: int = 100, - embedding_dim: int = 3, - n_tracks: int = 5, - movement_type: Literal[ - "smooth", "mild_chaos", "moderate_chaos", "high_chaos" - ] = "smooth", - target_direction: np.ndarray = None, - noise_std: float = 0.05, - seed: int = 42, - normalize_method: Literal["zscore", "l2"] | None = "zscore", -) -> xr.Dataset: - """ - Generate embeddings with multiple chaos levels. 
- - Parameters - ---------- - movement_type : str - - "smooth": Consistent direction and step size - - "mild_chaos": Slight randomness, similar to smooth - - "moderate_chaos": Moderate randomness and variability - - "high_chaos": High randomness and large jumps - """ - np.random.seed(seed) - - # Default target direction (toward positive x-axis) - if target_direction is None: - target_direction = np.zeros(embedding_dim) - target_direction[0] = 2.0 - - # Normalize target direction - target_direction = target_direction / (np.linalg.norm(target_direction) + 1e-8) - - # Define chaos parameters for each movement type - chaos_params = { - "smooth": { - "random_prob": 0.0, - "noise_scale": 0.15, - "jump_prob": 0.0, - "base_step": 0.12, - "step_std": 0.15, - }, - "mild_chaos": { - "random_prob": 0.1, - "noise_scale": 0.2, - "jump_prob": 0.03, - "exp_scales": [0.15, 0.25], - "jump_range": (1.5, 2.5), - }, - "moderate_chaos": { - "random_prob": 0.25, - "noise_scale": 0.3, - "jump_prob": 0.08, - "exp_scales": [0.12, 0.3, 0.6], - "jump_range": (2.0, 4.0), - }, - "high_chaos": { - "random_prob": 0.4, - "noise_scale": 0.4, - "jump_prob": 0.15, - "exp_scales": [0.1, 0.3, 0.8], - "jump_range": (3.0, 8.0), - }, - } - - params = chaos_params[movement_type] - - all_embeddings = [] - all_indices = [] - fov_name = "000000" - - for track_id in range(n_tracks): - timepoints = np.arange(n_timepoints) - embeddings = np.zeros((n_timepoints, embedding_dim)) - embeddings[0] = np.random.randn(embedding_dim) * 0.5 - - for t in range(1, n_timepoints): - if movement_type == "smooth": - # Smooth movement (original logic) - random_component = ( - np.random.randn(embedding_dim) * params["noise_scale"] - ) - direction = target_direction + random_component - direction = direction / (np.linalg.norm(direction) + 1e-8) - - step_size = params["base_step"] * ( - 1 + np.random.normal(0, params["step_std"]) - ) - step_size = max(0.05, step_size) - - else: - # Chaotic movement with varying levels - # Direction logic - if np.random.random() < params["random_prob"]: - direction = np.random.randn(embedding_dim) - direction = direction / (np.linalg.norm(direction) + 1e-8) - else: - random_component = ( - np.random.randn(embedding_dim) * params["noise_scale"] - ) - direction = target_direction + random_component - direction = direction / (np.linalg.norm(direction) + 1e-8) - - # Step size distribution - exp_scales = params["exp_scales"] - if len(exp_scales) == 2: # mild_chaos - if np.random.random() < 0.5: - step_size = np.random.exponential(exp_scales[0]) - else: - step_size = np.random.exponential(exp_scales[1]) - else: # moderate_chaos, high_chaos - rand_val = np.random.random() - if rand_val < 0.2: - step_size = np.random.exponential(exp_scales[0]) - elif rand_val < 0.5: - step_size = np.random.exponential(exp_scales[1]) - else: - step_size = np.random.exponential(exp_scales[2]) - - # Large jumps - if np.random.random() < params["jump_prob"]: - step_size *= np.random.uniform(*params["jump_range"]) - - # Take step - step = step_size * direction - embeddings[t] = embeddings[t - 1] + step - embeddings[t] += np.random.normal(0, noise_std, embedding_dim) - - # Optional normalization - if normalize_method == "zscore": - embeddings = (embeddings - np.mean(embeddings, axis=0)) / ( - np.std(embeddings, axis=0) + 1e-8 - ) - if normalize_method == "l2": - embeddings = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True) - - all_embeddings.append(embeddings) - - # Create indices - for t in range(n_timepoints): - all_indices.append( - { - 
"fov_name": fov_name, - "track_id": track_id, - "t": timepoints[t], - "id": len(all_indices), - } - ) - - # Combine all tracks - all_embeddings = np.vstack(all_embeddings) - ultrack_indices = pd.DataFrame(all_indices) - index = pd.MultiIndex.from_frame(ultrack_indices) - - dataset_dict = {"features": (("sample", "features"), all_embeddings)} - dataset = xr.Dataset(dataset_dict, coords={"sample": index}).reset_index("sample") - - return dataset - - -def analyze_step_sizes_before_and_after_normalization( - n_tracks: int = 5, - n_timepoints: int = 100, - embedding_dim: int = 3, - target_direction: np.ndarray = None, - seed: int = 42, -) -> tuple[plt.Figure, plt.Axes]: - """ - Compare step size distributions before and after normalization. - - This demonstrates how normalization affects the step size magnitudes. - """ - # Generate datasets with and without normalization - unnormalized_smooth = generate_directional_embeddings_corrected( - n_tracks=n_tracks, - n_timepoints=n_timepoints, - embedding_dim=embedding_dim, - movement_type="smooth", - target_direction=target_direction, - normalize_method=None, # Key difference - seed=seed, - ) - - unnormalized_chaotic = generate_directional_embeddings_corrected( - n_tracks=n_tracks, - n_timepoints=n_timepoints, - embedding_dim=embedding_dim, - movement_type="mild_chaos", - target_direction=target_direction, - normalize_method=None, # Key difference - seed=seed, - ) - - normalized_smooth = generate_directional_embeddings_corrected( - n_tracks=n_tracks, - n_timepoints=n_timepoints, - embedding_dim=embedding_dim, - movement_type="smooth", - target_direction=target_direction, - normalize_method=None, - seed=seed, - ) - - normalized_chaotic = generate_directional_embeddings_corrected( - n_tracks=n_tracks, - n_timepoints=n_timepoints, - embedding_dim=embedding_dim, - movement_type="mild_chaos", - target_direction=target_direction, - normalize_method=None, - seed=seed, - ) - - # Extract step sizes using the debug function logic - def extract_step_sizes_simple(dataset): - all_step_sizes = [] - unique_track_ids = np.unique(dataset["track_id"].values) - - for track_id in unique_track_ids: - track_mask = dataset["track_id"] == track_id - track_times = dataset["t"].values[track_mask] - track_embeddings = dataset["features"].values[track_mask] - - time_order = np.argsort(track_times) - sorted_embeddings = track_embeddings[time_order] - - if len(sorted_embeddings) > 1: - steps = np.diff(sorted_embeddings, axis=0) - step_sizes = np.linalg.norm(steps, axis=1) - all_step_sizes.extend(step_sizes) - - return np.array(all_step_sizes) - - # Extract step sizes - smooth_unnorm_steps = extract_step_sizes_simple(unnormalized_smooth) - chaotic_unnorm_steps = extract_step_sizes_simple(unnormalized_chaotic) - smooth_norm_steps = extract_step_sizes_simple(normalized_smooth) - chaotic_norm_steps = extract_step_sizes_simple(normalized_chaotic) - - # Create comparison plot - fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 10)) - - # Before normalization - ax1.hist( - smooth_unnorm_steps, - bins=50, - alpha=0.7, - color="#2ca02c", - label=f"Smooth (μ={np.mean(smooth_unnorm_steps):.3f}, σ={np.std(smooth_unnorm_steps):.3f})", - ) - ax1.hist( - chaotic_unnorm_steps, - bins=50, - alpha=0.7, - color="#d62728", - label=f"Chaotic (μ={np.mean(chaotic_unnorm_steps):.3f}, σ={np.std(chaotic_unnorm_steps):.3f})", - ) - ax1.set_title("Before Normalization") - ax1.set_xlabel("Step Size") - ax1.set_ylabel("Frequency") - ax1.legend() - - # After normalization - ax2.hist( - 
smooth_norm_steps, - bins=50, - alpha=0.7, - color="#2ca02c", - label=f"Smooth (μ={np.mean(smooth_norm_steps):.3f}, σ={np.std(smooth_norm_steps):.3f})", - ) - ax2.hist( - chaotic_norm_steps, - bins=50, - alpha=0.7, - color="#d62728", - label=f"Chaotic (μ={np.mean(chaotic_norm_steps):.3f}, σ={np.std(chaotic_norm_steps):.3f})", - ) - ax2.set_title("After Normalization") - ax2.set_xlabel("Step Size") - ax2.set_ylabel("Frequency") - ax2.legend() - - # Log-scale comparison (before normalization) - ax3.hist(smooth_unnorm_steps, bins=50, alpha=0.7, color="#2ca02c", label="Smooth") - ax3.hist(chaotic_unnorm_steps, bins=50, alpha=0.7, color="#d62728", label="Chaotic") - ax3.set_yscale("log") - ax3.set_title("Before Normalization (Log Scale)") - ax3.set_xlabel("Step Size") - ax3.set_ylabel("Frequency (log)") - ax3.legend() - - # Coefficient of variation comparison - cv_smooth_unnorm = np.std(smooth_unnorm_steps) / np.mean(smooth_unnorm_steps) - cv_chaotic_unnorm = np.std(chaotic_unnorm_steps) / np.mean(chaotic_unnorm_steps) - cv_smooth_norm = np.std(smooth_norm_steps) / np.mean(smooth_norm_steps) - cv_chaotic_norm = np.std(chaotic_norm_steps) / np.mean(chaotic_norm_steps) - - categories = [ - "Smooth\n(Unnorm)", - "Chaotic\n(Unnorm)", - "Smooth\n(Norm)", - "Chaotic\n(Norm)", - ] - cv_values = [cv_smooth_unnorm, cv_chaotic_unnorm, cv_smooth_norm, cv_chaotic_norm] - colors = ["#2ca02c", "#d62728", "#2ca02c", "#d62728"] - alphas = [1.0, 1.0, 0.5, 0.5] - - # Create individual bars with their own alpha values - bars = [] - for i, (cat, val, color, alpha) in enumerate( - zip(categories, cv_values, colors, alphas) - ): - bar = ax4.bar(cat, val, color=color, alpha=alpha) - bars.extend(bar) - - ax4.set_ylabel("Coefficient of Variation (σ/μ)") - ax4.set_title("Step Size Variability Comparison") - ax4.tick_params(axis="x", rotation=45) - - plt.tight_layout() - return fig, (ax1, ax2, ax3, ax4) - - -def plot_msd_comparison( - msd_data_dict: dict[str, dict[int, list[float]]], - title: str = "MSD: Smooth vs Chaotic Diffusion (Same Direction)", - log_scale: bool = True, - show_power_law_fits: bool = True, -) -> tuple[plt.Figure, plt.Axes]: - """ - Plot MSD curves comparing smooth and chaotic diffusion. 
- - Parameters - ---------- - msd_data_dict : dict[str, dict[int, list[float]]] - Dictionary mapping movement type to MSD data - title : str - Plot title - log_scale : bool - Whether to use log-log scale - show_power_law_fits : bool - Whether to show power law fits - - Returns - ------- - tuple[plt.Figure, plt.Axes] - Figure and axes objects - """ - fig, ax = plt.subplots(figsize=(10, 7)) - - colors = {"smooth": "#2ca02c", "chaotic": "#d62728"} - - for movement_type, msd_data in msd_data_dict.items(): - time_lags = sorted(msd_data.keys()) - msd_means = [] - msd_stds = [] - - for tau in time_lags: - displacements = np.array(msd_data[tau]) - msd_means.append(np.mean(displacements)) - msd_stds.append(np.std(displacements) / np.sqrt(len(displacements))) - - time_lags = np.array(time_lags) - msd_means = np.array(msd_means) - msd_stds = np.array(msd_stds) - - # Plot with error bars - color = colors.get(movement_type, "#1f77b4") - ax.errorbar( - time_lags, - msd_means, - yerr=msd_stds, - marker="o", - label=f"{movement_type.title()} Diffusion", - color=color, - capsize=3, - capthick=1, - linewidth=2, - ) - - # Fit power law if requested - if show_power_law_fits and len(time_lags) > 3: - valid_mask = (time_lags > 0) & (msd_means > 0) - if np.sum(valid_mask) > 3: - log_tau = np.log(time_lags[valid_mask]) - log_msd = np.log(msd_means[valid_mask]) - - slope, intercept, r_value, p_value, std_err = stats.linregress( - log_tau, log_msd - ) - - # Plot fit line - tau_fit = np.linspace( - time_lags[valid_mask][0], time_lags[valid_mask][-1], 50 - ) - msd_fit = np.exp(intercept) * tau_fit**slope - - ax.plot( - tau_fit, - msd_fit, - "--", - color=color, - alpha=0.7, - label=f"{movement_type}: α={slope:.2f} (R²={r_value**2:.3f})", - ) - - ax.set_xlabel("Time Lag (τ)", fontsize=12) - ax.set_ylabel("Mean Squared Displacement", fontsize=12) - ax.set_title(title, fontsize=14) - - if log_scale: - ax.set_xscale("log") - ax.set_yscale("log") - ax.grid(True, alpha=0.3) - - ax.legend() - plt.tight_layout() - return fig, ax - - -def plot_trajectory_comparison_3d( - smooth_dataset: xr.Dataset, - chaotic_dataset: xr.Dataset, - target_direction: np.ndarray = None, - title: str = "3D Trajectory Comparison: Smooth vs Chaotic", -) -> tuple[plt.Figure, tuple[plt.Axes, plt.Axes]]: - """ - Plot 3D trajectories comparing smooth and chaotic diffusion side by side. - - Parameters - ---------- - smooth_dataset : xr.Dataset - Dataset with smooth diffusion trajectories - chaotic_dataset : xr.Dataset - Dataset with chaotic diffusion trajectories - target_direction : np.ndarray - Target direction vector - title : str - Plot title - - Returns - ------- - tuple[plt.Figure, tuple[plt.Axes, plt.Axes]] - Figure and axes objects - """ - fig = plt.figure(figsize=(16, 7)) - - # Default target direction - if target_direction is None: - target_direction = np.array([2.0, 0.0, 0.0]) - - # Smooth diffusion plot - ax1 = fig.add_subplot(121, projection="3d") - plot_single_trajectory_3d(smooth_dataset, ax1, "Smooth Diffusion", target_direction) - - # Chaotic diffusion plot - ax2 = fig.add_subplot(122, projection="3d") - plot_single_trajectory_3d( - chaotic_dataset, ax2, "Chaotic Diffusion", target_direction - ) - - fig.suptitle(title, fontsize=16) - plt.tight_layout() - return fig, (ax1, ax2) - - -def plot_single_trajectory_3d( - dataset: xr.Dataset, - ax: plt.Axes, - subtitle: str, - target_direction: np.ndarray, -): - """ - Plot trajectories for a single dataset in 3D. 
- - Parameters - ---------- - dataset : xr.Dataset - Dataset containing trajectories - ax : plt.Axes - 3D axes object - subtitle : str - Subtitle for the plot - target_direction : np.ndarray - Target direction vector - """ - n_tracks = len(np.unique(dataset["track_id"].values)) - colors = plt.cm.tab10(np.linspace(0, 1, n_tracks)) - - unique_tracks_df = ( - dataset[["fov_name", "track_id"]].to_dataframe().drop_duplicates() - ) - - for i, (fov_name, track_id) in enumerate( - zip(unique_tracks_df["fov_name"], unique_tracks_df["track_id"]) - ): - track_data = dataset.where( - (dataset["fov_name"] == fov_name) & (dataset["track_id"] == track_id), - drop=True, - ) - - # Sort by time - time_order = np.argsort(track_data["t"].values) - embeddings = track_data["features"].values[time_order] - - x, y, z = embeddings[:, 0], embeddings[:, 1], embeddings[:, 2] - color = colors[int(track_id) % len(colors)] - - # Plot trajectory - ax.plot(x, y, z, "-", color=color, alpha=0.7, linewidth=2) - - # Start and end points - ax.scatter( - x[0], - y[0], - z[0], - color=color, - s=100, - marker="o", - edgecolors="black", - linewidth=1, - ) - ax.scatter( - x[-1], - y[-1], - z[-1], - color=color, - s=150, - marker="*", - edgecolors="black", - linewidth=1, - ) - - # Show target direction arrow - origin = np.array([0, 0, 0]) - ax.quiver( - origin[0], - origin[1], - origin[2], - target_direction[0], - target_direction[1], - target_direction[2], - color="red", - arrow_length_ratio=0.1, - linewidth=3, - label="Target Direction", - ) - - ax.set_xlabel("Dimension 1") - ax.set_ylabel("Dimension 2") - ax.set_zlabel("Dimension 3") - ax.set_title(subtitle) - ax.legend() - - -def analyze_step_size_distributions_debug( - smooth_dataset: xr.Dataset, - chaotic_dataset: xr.Dataset, -) -> tuple[plt.Figure, plt.Axes]: - """ - Analyze and plot step size distributions with debugging information. 
- """ - fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5)) - - def extract_step_sizes_simple(dataset, dataset_name): - """Extract step sizes with simple coordinate access.""" - all_step_sizes = [] - - # Get unique track IDs - unique_track_ids = np.unique(dataset["track_id"].values) - - print(f"\n{dataset_name} Dataset:") - print(f"Total samples: {len(dataset['track_id'])}") - print(f"Unique track IDs: {unique_track_ids}") - - for track_id in unique_track_ids: - # Get all data for this track - track_mask = dataset["track_id"] == track_id - track_times = dataset["t"].values[track_mask] - track_embeddings = dataset["features"].values[track_mask] - - # Sort by time - time_order = np.argsort(track_times) - sorted_embeddings = track_embeddings[time_order] - sorted_times = track_times[time_order] - - # Remove duplicates in time (this might be the issue) - unique_times, unique_indices = np.unique(sorted_times, return_index=True) - final_embeddings = sorted_embeddings[unique_indices] - - print( - f"Track {track_id}: {len(sorted_times)} total, {len(unique_times)} unique timepoints" - ) - - # Calculate step sizes - if len(final_embeddings) > 1: - steps = np.diff(final_embeddings, axis=0) - step_sizes = np.linalg.norm(steps, axis=1) - all_step_sizes.extend(step_sizes) - print(f"Track {track_id}: {len(step_sizes)} steps") - - print(f"Total steps in {dataset_name}: {len(all_step_sizes)}") - return np.array(all_step_sizes) - - # Extract step sizes with debug info - smooth_steps = extract_step_sizes_simple(smooth_dataset, "Smooth") - chaotic_steps = extract_step_sizes_simple(chaotic_dataset, "Chaotic") - - # Plot histograms - ax1.hist( - smooth_steps, - bins=50, - alpha=0.7, - color="#2ca02c", - label=f"Smooth (n={len(smooth_steps)}, μ={np.mean(smooth_steps):.3f}, σ={np.std(smooth_steps):.3f})", - ) - ax1.hist( - chaotic_steps, - bins=50, - alpha=0.7, - color="#d62728", - label=f"Chaotic (n={len(chaotic_steps)}, μ={np.mean(chaotic_steps):.3f}, σ={np.std(chaotic_steps):.3f})", - ) - ax1.set_xlabel("Step Size") - ax1.set_ylabel("Frequency") - ax1.set_title("Step Size Distribution") - ax1.legend() - - # Plot coefficient of variation - cv_smooth = np.std(smooth_steps) / np.mean(smooth_steps) - cv_chaotic = np.std(chaotic_steps) / np.mean(chaotic_steps) - - ax2.bar( - ["Smooth", "Chaotic"], - [cv_smooth, cv_chaotic], - color=["#2ca02c", "#d62728"], - alpha=0.7, - ) - ax2.set_ylabel("Coefficient of Variation (σ/μ)") - ax2.set_title("Step Size Variability") - - plt.tight_layout() - return fig, (ax1, ax2) - - -def plot_trajectory_comparison_3d_multi( - datasets: dict[str, xr.Dataset], - target_direction: np.ndarray = None, - title: str = "3D Trajectory Comparison: Multiple Movement Types", -) -> tuple[plt.Figure, list[plt.Axes]]: - """ - Plot 3D trajectories for multiple movement types. 
- - Parameters - ---------- - datasets : dict[str, xr.Dataset] - Dictionary mapping movement type name to dataset - target_direction : np.ndarray - Target direction vector - title : str - Plot title - """ - n_types = len(datasets) - cols = 2 - rows = (n_types + 1) // 2 - - fig = plt.figure(figsize=(12, 6 * rows)) - - # Default target direction - if target_direction is None: - target_direction = np.array([2.0, 0.0, 0.0]) - - axes = [] - for i, (movement_type, dataset) in enumerate(datasets.items()): - ax = fig.add_subplot(rows, cols, i + 1, projection="3d") - plot_single_trajectory_3d( - dataset, - ax, - f"{movement_type.replace('_', ' ').title()} Movement", - target_direction, - ) - axes.append(ax) - - fig.suptitle(title, fontsize=16) - plt.tight_layout() - return fig, axes - - -def plot_msd_comparison_multi( - msd_data_dict: dict[str, dict[int, list[float]]], - title: str = "MSD: Multiple Movement Types Comparison", - log_scale: bool = True, - show_power_law_fits: bool = True, -) -> tuple[plt.Figure, plt.Axes]: - """ - Plot MSD curves for multiple movement types. - """ - fig, ax = plt.subplots(figsize=(12, 8)) - - # Color palette for different movement types - colors = { - "smooth": "#2ca02c", - "mild_chaos": "#ff7f0e", - "moderate_chaos": "#d62728", - "high_chaos": "#9467bd", - } - - for movement_type, msd_data in msd_data_dict.items(): - time_lags = sorted(msd_data.keys()) - msd_means = [] - msd_stds = [] - - for tau in time_lags: - displacements = np.array(msd_data[tau]) - msd_means.append(np.mean(displacements)) - msd_stds.append(np.std(displacements) / np.sqrt(len(displacements))) - - time_lags = np.array(time_lags) - msd_means = np.array(msd_means) - msd_stds = np.array(msd_stds) - - # Plot with error bars - color = colors.get(movement_type, "#1f77b4") - ax.errorbar( - time_lags, - msd_means, - yerr=msd_stds, - marker="o", - label=f"{movement_type.replace('_', ' ').title()}", - color=color, - capsize=3, - capthick=1, - linewidth=2, - markersize=6, - ) - - # Fit power law if requested - if show_power_law_fits and len(time_lags) > 3: - valid_mask = (time_lags > 0) & (msd_means > 0) - if np.sum(valid_mask) > 3: - log_tau = np.log(time_lags[valid_mask]) - log_msd = np.log(msd_means[valid_mask]) - - slope, intercept, r_value, p_value, std_err = stats.linregress( - log_tau, log_msd - ) - - # Plot fit line - tau_fit = np.linspace( - time_lags[valid_mask][0], time_lags[valid_mask][-1], 50 - ) - msd_fit = np.exp(intercept) * tau_fit**slope - - ax.plot( - tau_fit, - msd_fit, - "--", - color=color, - alpha=0.7, - label=f"{movement_type}: α={slope:.2f} (R²={r_value**2:.3f})", - ) - - ax.set_xlabel("Time Lag (τ)", fontsize=12) - ax.set_ylabel("Mean Squared Displacement", fontsize=12) - ax.set_title(title, fontsize=14) - - if log_scale: - ax.set_xscale("log") - ax.set_yscale("log") - ax.grid(True, alpha=0.3) - - ax.legend(bbox_to_anchor=(1.05, 1), loc="upper left") - plt.tight_layout() - return fig, ax - - -def analyze_step_size_distributions_multi( - datasets: dict[str, xr.Dataset], -) -> tuple[plt.Figure, tuple[plt.Axes, plt.Axes]]: - """ - Analyze step size distributions for multiple movement types. 
- """ - fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6)) - - colors = { - "smooth": "#2ca02c", - "mild_chaos": "#ff7f0e", - "moderate_chaos": "#d62728", - "high_chaos": "#9467bd", - } - - def extract_step_sizes_simple(dataset): - """Extract step sizes with simple coordinate access.""" - all_step_sizes = [] - unique_track_ids = np.unique(dataset["track_id"].values) - - for track_id in unique_track_ids: - track_mask = dataset["track_id"] == track_id - track_times = dataset["t"].values[track_mask] - track_embeddings = dataset["features"].values[track_mask] - - time_order = np.argsort(track_times) - sorted_embeddings = track_embeddings[time_order] - sorted_times = track_times[time_order] - - # Remove duplicates in time - unique_times, unique_indices = np.unique(sorted_times, return_index=True) - final_embeddings = sorted_embeddings[unique_indices] - - if len(final_embeddings) > 1: - steps = np.diff(final_embeddings, axis=0) - step_sizes = np.linalg.norm(steps, axis=1) - all_step_sizes.extend(step_sizes) - - return np.array(all_step_sizes) - - # Extract step sizes for all datasets - all_step_data = {} - cv_values = [] - labels = [] - - for movement_type, dataset in datasets.items(): - steps = extract_step_sizes_simple(dataset) - all_step_data[movement_type] = steps - - # Calculate coefficient of variation - cv = np.std(steps) / np.mean(steps) - cv_values.append(cv) - labels.append(movement_type.replace("_", " ").title()) - - # Plot histograms - for movement_type, steps in all_step_data.items(): - color = colors.get(movement_type, "#1f77b4") - ax1.hist( - steps, - bins=50, - alpha=0.7, - color=color, - label=f"{movement_type.replace('_', ' ').title()} (n={len(steps)}, μ={np.mean(steps):.3f}, σ={np.std(steps):.3f})", - ) - - ax1.set_xlabel("Step Size") - ax1.set_ylabel("Frequency") - ax1.set_title("Step Size Distributions") - ax1.legend() - - # Plot coefficient of variation - bar_colors = [ - colors.get(movement_type, "#1f77b4") for movement_type in datasets.keys() - ] - bars = ax2.bar(labels, cv_values, color=bar_colors, alpha=0.7) - ax2.set_ylabel("Coefficient of Variation (σ/μ)") - ax2.set_title("Step Size Variability") - ax2.tick_params(axis="x", rotation=45) - - plt.tight_layout() - return fig, (ax1, ax2) - - -# %% -if __name__ == "__main__": - # Note: direction of the embedding to simulate movement/infection. 
- target_direction = np.array([10.0, 0, 0.0]) - - movement_types = ["smooth", "mild_chaos", "moderate_chaos", "high_chaos"] - - datasets = {} - print("=== Generating Datasets ===") - for movement_type in movement_types: - print(f"Generating {movement_type} dataset...") - datasets[movement_type] = generate_directional_embeddings_corrected( - n_tracks=5, - n_timepoints=100, - movement_type=movement_type, - target_direction=target_direction, - normalize_method=None, - seed=42, - ) - - print("=== Computing MSD for All Movement Types ===") - msd_data_dict = {} - for movement_type, dataset in datasets.items(): - print(f"Computing MSD for {movement_type}...") - msd_data_dict[movement_type] = compute_msd(dataset) - - print("\n=== Normalizing MSD by Embedding Variance ===") - normalized_msd_data_dict = normalize_msd_by_embedding_variance( - msd_data_dict, datasets - ) - - print("=== MSD vs Time Plot (Raw) ===") - fig_msd_raw, ax_msd_raw = plot_msd_comparison_multi( - msd_data_dict, title="MSD: Raw Values (All Movement Types)" - ) - plt.show() - - print("=== MSD vs Time Plot (Normalized by Embedding Variance) ===") - fig_msd_norm, ax_msd_norm = plot_msd_comparison_multi( - normalized_msd_data_dict, - title="MSD: Normalized by Embedding Variance (All Movement Types)", - ) - plt.show() - - print("=== 3D Trajectory Comparison (All Types) ===") - fig_3d, axes_3d = plot_trajectory_comparison_3d_multi(datasets, target_direction) - plt.show() - - print("=== Step Size Distribution Analysis (All Types) ===") - fig_step, (ax_step1, ax_step2) = analyze_step_size_distributions_multi(datasets) - plt.show() - - print("=== Step Size Normalization Analysis ===") - step_stats = normalize_step_sizes_by_embedding_variance(datasets) - - print("=== Summary Statistics ===") - for movement_type, dataset in datasets.items(): - print(f"\n{movement_type.replace('_', ' ').title()} Movement:") - print(f" Dataset shape: {dataset.dims}") - print(f" Total samples: {len(dataset.sample)}") - - # Calculate mean step size and CV - def get_step_stats(dataset): - all_step_sizes = [] - unique_track_ids = np.unique(dataset["track_id"].values) - for track_id in unique_track_ids: - track_mask = dataset["track_id"] == track_id - track_embeddings = dataset["features"].values[track_mask] - if len(track_embeddings) > 1: - steps = np.diff(track_embeddings, axis=0) - step_sizes = np.linalg.norm(steps, axis=1) - all_step_sizes.extend(step_sizes) - return np.array(all_step_sizes) - - steps = get_step_stats(dataset) - mean_step = np.mean(steps) - std_step = np.std(steps) - cv = std_step / mean_step - - print(f" Mean step size: {mean_step:.4f}") - print(f" Step size std: {std_step:.4f}") - print(f" Coefficient of variation: {cv:.4f}") - - -# %% From 66c0100105a222c15a2c9f56fbffb18df462222c Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Wed, 17 Sep 2025 15:26:43 -0700 Subject: [PATCH 066/101] ruff format --- .../evaluation/compare_dtw_embeddings_sam2.py | 2 +- viscy/representation/engine.py | 10 ++++++---- viscy/representation/evaluation/clustering.py | 1 - viscy/representation/evaluation/distance.py | 2 +- viscy/representation/evaluation/smoothness.py | 2 +- viscy/transforms/_redef.py | 3 +-- 6 files changed, 10 insertions(+), 10 deletions(-) diff --git a/applications/pseudotime_analysis/evaluation/compare_dtw_embeddings_sam2.py b/applications/pseudotime_analysis/evaluation/compare_dtw_embeddings_sam2.py index 78a5a719b..c18528f6f 100644 --- a/applications/pseudotime_analysis/evaluation/compare_dtw_embeddings_sam2.py +++ 
b/applications/pseudotime_analysis/evaluation/compare_dtw_embeddings_sam2.py @@ -35,7 +35,7 @@ import os import napari -s + os.environ["DISPLAY"] = ":1" viewer = napari.Viewer() # %% diff --git a/viscy/representation/engine.py b/viscy/representation/engine.py index c2d2d3ec5..963917ae5 100644 --- a/viscy/representation/engine.py +++ b/viscy/representation/engine.py @@ -291,11 +291,13 @@ def __init__( # Handle different parameter names for latent dimensions latent_dim = None - if hasattr(self.model, 'latent_dim'): + if hasattr(self.model, "latent_dim"): latent_dim = self.model.latent_dim - elif hasattr(self.model, 'latent_size'): + elif hasattr(self.model, "latent_size"): latent_dim = self.model.latent_size - elif hasattr(self.model, 'encoder') and hasattr(self.model.encoder, 'latent_dim'): + elif hasattr(self.model, "encoder") and hasattr( + self.model.encoder, "latent_dim" + ): latent_dim = self.model.encoder.latent_dim if latent_dim is not None: @@ -358,7 +360,7 @@ def forward(self, x: Tensor) -> dict: original_shape = x.shape is_monai_2d = ( isinstance(self.model, BetaVaeMonai) - and hasattr(self.model, 'spatial_dims') + and hasattr(self.model, "spatial_dims") and self.model.spatial_dims == 2 ) if is_monai_2d and len(x.shape) == 5 and x.shape[2] == 1: diff --git a/viscy/representation/evaluation/clustering.py b/viscy/representation/evaluation/clustering.py index 66d9bda7b..ebf49455f 100644 --- a/viscy/representation/evaluation/clustering.py +++ b/viscy/representation/evaluation/clustering.py @@ -150,4 +150,3 @@ def clustering_evaluation(embeddings, annotations, method="nmi"): raise ValueError("Invalid method. Choose 'nmi' or 'ari'.") return score - diff --git a/viscy/representation/evaluation/distance.py b/viscy/representation/evaluation/distance.py index e55354be1..27dc19ba8 100644 --- a/viscy/representation/evaluation/distance.py +++ b/viscy/representation/evaluation/distance.py @@ -85,4 +85,4 @@ def compute_track_displacement( tau = int(times[i + time_offset] - times[i]) displacement_per_tau[tau].append(displacement) - return dict(displacement_per_tau) \ No newline at end of file + return dict(displacement_per_tau) diff --git a/viscy/representation/evaluation/smoothness.py b/viscy/representation/evaluation/smoothness.py index 2e9dedaf1..869cd4baa 100644 --- a/viscy/representation/evaluation/smoothness.py +++ b/viscy/representation/evaluation/smoothness.py @@ -133,7 +133,7 @@ def compute_embeddings_smoothness( - random_frame_peak: Peak of random sampling distribution - smoothness_score: Score of smoothness - dynamic_range: Difference between random and adjacent peaks - distributions: dict: Dictionary containing distributions including: + distributions: dict: Dictionary containing distributions including: - adjacent_frame_distribution: Full distribution of adjacent frame dissimilarities - random_frame_distribution: Full distribution of random sampling dissimilarities piecewise_distance_per_track: list[list[float]] diff --git a/viscy/transforms/_redef.py b/viscy/transforms/_redef.py index 8976094a5..258d3a274 100644 --- a/viscy/transforms/_redef.py +++ b/viscy/transforms/_redef.py @@ -188,7 +188,6 @@ def __init__( ): super().__init__(keys=keys, roi_size=roi_size, **kwargs) - class RandFlipd(RandFlipd): def __init__( self, @@ -202,4 +201,4 @@ def __init__( class NormalizeIntensityd(NormalizeIntensityd): def __init__(self, keys: Sequence[str] | str, **kwargs): - super().__init__(keys=keys, **kwargs) \ No newline at end of file + super().__init__(keys=keys, **kwargs) From 
3dae71eca84e8923799d89069aa075dc4dbb5f7b Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Wed, 17 Sep 2025 15:34:44 -0700 Subject: [PATCH 067/101] fix to explicitly stratify on fov level --- viscy/representation/evaluation/lca.py | 25 +++++++++++++++++-------- 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/viscy/representation/evaluation/lca.py b/viscy/representation/evaluation/lca.py index 9965dcbac..93a0d75aa 100644 --- a/viscy/representation/evaluation/lca.py +++ b/viscy/representation/evaluation/lca.py @@ -81,17 +81,26 @@ def fit_logistic_regression( train_selection = fov_selection test_selection = ~fov_selection else: - # Use stratified sampling - n_samples = len(annotations_filtered) - indices = range(n_samples) - train_indices, test_indices = train_test_split( - indices, + unique_fovs = features_filtered["fov_name"].unique() + + fov_class_dist = [] + for fov in unique_fovs: + fov_mask = features_filtered["fov_name"] == fov + fov_classes = annotations_filtered[fov_mask] + # Use majority class for stratification or class distribution + majority_class = pd.Series(fov_classes).mode()[0] + fov_class_dist.append(majority_class) + + # Split FOVs, not individual samples + train_fovs, test_fovs = train_test_split( + unique_fovs, test_size=1 - train_ratio, - stratify=annotations_filtered, + stratify=fov_class_dist, random_state=random_state, ) - train_selection = pd.Series(False, index=range(n_samples)) - train_selection.iloc[train_indices] = True + + # Create selection based on FOV assignment + train_selection = features_filtered["fov_name"].isin(train_fovs) test_selection = ~train_selection train_features = features_filtered.values[train_selection] test_features = features_filtered.values[test_selection] From 57e7811dd01cb7d8e05b671e1f54cb04c155f5fb Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Wed, 17 Sep 2025 15:37:22 -0700 Subject: [PATCH 068/101] adding reference to dataset for rpe1 --- viscy/data/cell_division_triplet.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/viscy/data/cell_division_triplet.py b/viscy/data/cell_division_triplet.py index 90ad2910f..877eead99 100644 --- a/viscy/data/cell_division_triplet.py +++ b/viscy/data/cell_division_triplet.py @@ -19,7 +19,12 @@ class CellDivisionTripletDataset(Dataset): - # Hardcoded channel mapping for .npy files + """Dataset for triplet sampling of cell division data from npy files. 
+ + For the dataset from the paper: + https://arxiv.org/html/2502.02182v1 + """ + #NOTE: Hardcoded channel mapping for .npy files CHANNEL_MAPPING = { # Channel 0 aliases (brightfield) "bf": 0, @@ -140,10 +145,9 @@ def _sample_positive(self, anchor_info: dict) -> Tensor: # Use future timepoint positive_t = anchor_t + self.time_interval - positive_patch = track["data"][positive_t] # Shape: (C, Y, X) - # Add depth dimension only if not output_2d: (C, Y, X) -> (C, D=1, Y, X) + positive_patch = track["data"][positive_t] if not self.output_2d: - positive_patch = positive_patch.unsqueeze(1) # Shape: (C, 1, Y, X) + positive_patch = positive_patch.unsqueeze(1) return positive_patch def _sample_negative(self, anchor_info: dict) -> Tensor: From 4ca292677c30fa3ed1e0d0edfcba45dd1ffa42c5 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Wed, 17 Sep 2025 15:42:11 -0700 Subject: [PATCH 069/101] fix pyproject.toml dev --- pyproject.toml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 281d4211c..39ba15170 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,7 +45,6 @@ metrics = [ phate = [ "phate", ] - examples = ["napari", "jupyter", "jupytext", "transformers>=4.51.3"] visual = [ "ipykernel", @@ -58,7 +57,7 @@ visual = [ "dash", ] dev = [ - "viscy[metrics,phate,examples,visual,optimization]", + "viscy[metrics,phate,examples,visual]", "pytest", "pytest-cov", "hypothesis", From 6bd786f89f6f97ac7756241b0ba45a1df5daee0a Mon Sep 17 00:00:00 2001 From: Ziwen Liu <67518483+ziw-liu@users.noreply.github.com> Date: Wed, 17 Sep 2025 16:30:34 -0700 Subject: [PATCH 070/101] format and lint --- viscy/data/cell_division_triplet.py | 7 ++++--- viscy/data/triplet.py | 1 - viscy/representation/evaluation/distance.py | 2 -- viscy/representation/vae_logging.py | 5 ----- 4 files changed, 4 insertions(+), 11 deletions(-) diff --git a/viscy/data/cell_division_triplet.py b/viscy/data/cell_division_triplet.py index 877eead99..ed3b45929 100644 --- a/viscy/data/cell_division_triplet.py +++ b/viscy/data/cell_division_triplet.py @@ -20,11 +20,12 @@ class CellDivisionTripletDataset(Dataset): """Dataset for triplet sampling of cell division data from npy files. 
- + For the dataset from the paper: https://arxiv.org/html/2502.02182v1 """ - #NOTE: Hardcoded channel mapping for .npy files + + # NOTE: Hardcoded channel mapping for .npy files CHANNEL_MAPPING = { # Channel 0 aliases (brightfield) "bf": 0, @@ -147,7 +148,7 @@ def _sample_positive(self, anchor_info: dict) -> Tensor: positive_patch = track["data"][positive_t] if not self.output_2d: - positive_patch = positive_patch.unsqueeze(1) + positive_patch = positive_patch.unsqueeze(1) return positive_patch def _sample_negative(self, anchor_info: dict) -> Tensor: diff --git a/viscy/data/triplet.py b/viscy/data/triplet.py index 32f2c9a0d..9f495fbcc 100644 --- a/viscy/data/triplet.py +++ b/viscy/data/triplet.py @@ -507,7 +507,6 @@ def _setup_fit(self, dataset_settings: dict): **dataset_settings, ) - self.val_dataset = TripletDataset( positions=val_positions, tracks_tables=val_tracks_tables, diff --git a/viscy/representation/evaluation/distance.py b/viscy/representation/evaluation/distance.py index 27dc19ba8..9ea940384 100644 --- a/viscy/representation/evaluation/distance.py +++ b/viscy/representation/evaluation/distance.py @@ -1,6 +1,4 @@ -import logging from collections import defaultdict -from typing import Literal import numpy as np import xarray as xr diff --git a/viscy/representation/vae_logging.py b/viscy/representation/vae_logging.py index 310f909d3..ff0f462da 100644 --- a/viscy/representation/vae_logging.py +++ b/viscy/representation/vae_logging.py @@ -1,13 +1,8 @@ -import io import logging from typing import Callable, Optional, Tuple -import matplotlib.pyplot as plt import numpy as np import torch -from PIL import Image -from sklearn.decomposition import PCA -from sklearn.manifold import TSNE from torchvision.utils import make_grid from viscy.representation.disentanglement_metrics import DisentanglementMetrics From 4bb7b929d0273c415d5823b46f5d8826250f8fd2 Mon Sep 17 00:00:00 2001 From: Ziwen Liu <67518483+ziw-liu@users.noreply.github.com> Date: Wed, 17 Sep 2025 16:32:57 -0700 Subject: [PATCH 071/101] restore no-augmentation flag effect --- viscy/data/triplet.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/viscy/data/triplet.py b/viscy/data/triplet.py index 9f495fbcc..43dacf3ab 100644 --- a/viscy/data/triplet.py +++ b/viscy/data/triplet.py @@ -590,6 +590,8 @@ def _find_transform(self, key: str): if self.trainer: if self.trainer.predicting: return self._no_augmentation_transform + if self.trainer.validating and not self.augment_validation: + return self._no_augmentation_transform # NOTE: for backwards compatibility if key == "anchor" and self.time_interval in ("any", 0): return self._no_augmentation_transform From b5d71fdf82479093a8f691c9629ab5a986b1169e Mon Sep 17 00:00:00 2001 From: Ziwen Liu <67518483+ziw-liu@users.noreply.github.com> Date: Wed, 17 Sep 2025 16:43:50 -0700 Subject: [PATCH 072/101] format tests --- .../evaluation/test_distance.py | 90 ++++++++++--------- 1 file changed, 50 insertions(+), 40 deletions(-) diff --git a/tests/representation/evaluation/test_distance.py b/tests/representation/evaluation/test_distance.py index bdd75ecd3..6f82d5e54 100644 --- a/tests/representation/evaluation/test_distance.py +++ b/tests/representation/evaluation/test_distance.py @@ -13,12 +13,12 @@ def sample_embedding_dataset(): """Create a sample embedding dataset for testing.""" n_samples = 10 n_features = 5 - + features = np.random.rand(n_samples, n_features) fov_names = ["fov1"] * 5 + ["fov2"] * 5 track_ids = [1, 1, 1, 2, 2, 3, 3, 3, 4, 4] time_points = [0, 1, 2, 0, 1, 0, 1, 2, 0, 1] - + 
dataset = xr.Dataset( { "features": (["sample", "features"], features), @@ -35,20 +35,20 @@ def test_calculate_cosine_similarity_cell(sample_embedding_dataset): time_points, similarities = calculate_cosine_similarity_cell( sample_embedding_dataset, "fov1", 1 ) - + assert len(time_points) == len(similarities) assert len(time_points) == 3 assert np.isclose(similarities[0], 1.0, atol=1e-6) assert all(-1 <= sim <= 1 for sim in similarities) -@pytest.mark.parametrize("distance_metric", ["cosine", "euclidean","sqeuclidean"]) +@pytest.mark.parametrize("distance_metric", ["cosine", "euclidean", "sqeuclidean"]) def test_compute_track_displacement(sample_embedding_dataset, distance_metric): """Test track displacement computation with different metrics.""" result = compute_track_displacement( sample_embedding_dataset, distance_metric=distance_metric ) - + assert isinstance(result, dict) assert all(isinstance(tau, int) for tau in result.keys()) assert all(isinstance(displacements, list) for displacements in result.values()) @@ -57,51 +57,61 @@ def test_compute_track_displacement(sample_embedding_dataset, distance_metric): for displacements in result.values() ) + def test_compute_track_displacement_numerical(): """Test compute_track_displacement with known embeddings and expected results.""" - features = np.array([ - [1.0, 0.0], - [0.0, 1.0], - [1.0, 1.0], - ]) - - dataset = xr.Dataset({ - "features": (["sample", "features"], features), - "fov_name": (["sample"], ["fov1", "fov1", "fov1"]), - "track_id": (["sample"], [1, 1, 1]), - "t": (["sample"], [0, 1, 2]), - }) + features = np.array( + [ + [1.0, 0.0], + [0.0, 1.0], + [1.0, 1.0], + ] + ) + + dataset = xr.Dataset( + { + "features": (["sample", "features"], features), + "fov_name": (["sample"], ["fov1", "fov1", "fov1"]), + "track_id": (["sample"], [1, 1, 1]), + "t": (["sample"], [0, 1, 2]), + } + ) result_euclidean = compute_track_displacement(dataset, distance_metric="euclidean") - - assert 1 in result_euclidean + + assert 1 in result_euclidean assert 2 in result_euclidean - assert len(result_euclidean[1]) == 2 - assert len(result_euclidean[2]) == 1 - - result_sqeuclidean = compute_track_displacement(dataset, distance_metric="sqeuclidean") - expected_tau1 = [2.0, 1.0] - expected_tau2 = [1.0] - + assert len(result_euclidean[1]) == 2 + assert len(result_euclidean[2]) == 1 + + result_sqeuclidean = compute_track_displacement( + dataset, distance_metric="sqeuclidean" + ) + expected_tau1 = [2.0, 1.0] + expected_tau2 = [1.0] + assert np.allclose(sorted(result_sqeuclidean[1]), sorted(expected_tau1), atol=1e-10) assert np.allclose(result_sqeuclidean[2], expected_tau2, atol=1e-10) - - + result_cosine = compute_track_displacement(dataset, distance_metric="cosine") - expected_cosine_tau1 = [1.0, 1 - 1/np.sqrt(2)] - expected_cosine_tau2 = [1 - 1/np.sqrt(2)] - - assert np.allclose(sorted(result_cosine[1]), sorted(expected_cosine_tau1), atol=1e-10) + expected_cosine_tau1 = [1.0, 1 - 1 / np.sqrt(2)] + expected_cosine_tau2 = [1 - 1 / np.sqrt(2)] + + assert np.allclose( + sorted(result_cosine[1]), sorted(expected_cosine_tau1), atol=1e-10 + ) assert np.allclose(result_cosine[2], expected_cosine_tau2, atol=1e-10) def test_compute_track_displacement_empty_dataset(): """Test behavior with empty dataset.""" - empty_dataset = xr.Dataset({ - "features": (["sample", "features"], np.empty((0, 5))), - "fov_name": (["sample"], []), - "track_id": (["sample"], []), - "t": (["sample"], []), - }) - + empty_dataset = xr.Dataset( + { + "features": (["sample", "features"], np.empty((0, 
5))), + "fov_name": (["sample"], []), + "track_id": (["sample"], []), + "t": (["sample"], []), + } + ) + result = compute_track_displacement(empty_dataset) - assert result == {} \ No newline at end of file + assert result == {} From f3108f0b458c88f80aadf3129aac682dbec8bc99 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Wed, 17 Sep 2025 17:21:32 -0700 Subject: [PATCH 073/101] rename the sam2 file --- .../SAM2/{test_sam2_visualization.py => sam2_visualizations.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename applications/benchmarking/DynaCLR/SAM2/{test_sam2_visualization.py => sam2_visualizations.py} (100%) diff --git a/applications/benchmarking/DynaCLR/SAM2/test_sam2_visualization.py b/applications/benchmarking/DynaCLR/SAM2/sam2_visualizations.py similarity index 100% rename from applications/benchmarking/DynaCLR/SAM2/test_sam2_visualization.py rename to applications/benchmarking/DynaCLR/SAM2/sam2_visualizations.py From 01ed038f34971fd7b799e92cd6a4e0807993d0e1 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Thu, 18 Sep 2025 14:48:39 -0700 Subject: [PATCH 074/101] removing unused arguments for logging embeddings. --- viscy/representation/engine.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/viscy/representation/engine.py b/viscy/representation/engine.py index 963917ae5..3043e77ac 100644 --- a/viscy/representation/engine.py +++ b/viscy/representation/engine.py @@ -37,8 +37,6 @@ def __init__( schedule: Literal["WarmupCosine", "Constant"] = "Constant", log_batches_per_epoch: int = 8, log_samples_per_batch: int = 1, - log_embeddings: bool = True, - embedding_log_frequency: int = 20, example_input_array_shape: Sequence[int] = (1, 2, 15, 256, 256), ) -> None: super().__init__() @@ -51,8 +49,6 @@ def __init__( self.example_input_array = torch.rand(*example_input_array_shape) self.training_step_outputs = [] self.validation_step_outputs = [] - self.log_embeddings = log_embeddings - self.embedding_log_frequency = embedding_log_frequency self.save_hyperparameters() From c73ffffb4227ed6ef761f33471425bdf9560787b Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Fri, 19 Sep 2025 09:03:30 -0700 Subject: [PATCH 075/101] removing duplication in the lca --- viscy/representation/evaluation/lca.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/viscy/representation/evaluation/lca.py b/viscy/representation/evaluation/lca.py index 93a0d75aa..eab6c41d6 100644 --- a/viscy/representation/evaluation/lca.py +++ b/viscy/representation/evaluation/lca.py @@ -75,12 +75,8 @@ def fit_logistic_regression( features_filtered = features annotations_filtered = annotations - # Determine train/test split - if train_fovs is not None: - fov_selection = features_filtered["fov_name"].isin(train_fovs) - train_selection = fov_selection - test_selection = ~fov_selection - else: + # Determine train FOVs + if train_fovs is None: unique_fovs = features_filtered["fov_name"].unique() fov_class_dist = [] @@ -99,9 +95,9 @@ def fit_logistic_regression( random_state=random_state, ) - # Create selection based on FOV assignment - train_selection = features_filtered["fov_name"].isin(train_fovs) - test_selection = ~train_selection + # Create train/test selections + train_selection = features_filtered["fov_name"].isin(train_fovs) + test_selection = ~train_selection train_features = features_filtered.values[train_selection] test_features = features_filtered.values[test_selection] train_annotations = annotations_filtered[train_selection] From 
7aa6d8defb370f3db3f70906185b35af66d29b56 Mon Sep 17 00:00:00 2001
From: Eduardo Hirata-Miyasaki
Date: Mon, 22 Sep 2025 13:10:11 -0700
Subject: [PATCH 076/101] remove disentanglement metrics

---
 .../representation/disentanglement_metrics.py | 374 ------------------
 viscy/representation/engine.py | 35 --
 viscy/representation/vae_logging.py | 64 ---
 3 files changed, 473 deletions(-)
 delete mode 100644 viscy/representation/disentanglement_metrics.py

diff --git a/viscy/representation/disentanglement_metrics.py b/viscy/representation/disentanglement_metrics.py
deleted file mode 100644
index 5b721b390..000000000
--- a/viscy/representation/disentanglement_metrics.py
+++ /dev/null
@@ -1,374 +0,0 @@
-import logging
-from typing import Dict, Optional, Tuple
-
-import numpy as np
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from scipy import stats
-from sklearn.ensemble import RandomForestClassifier
-from sklearn.linear_model import LogisticRegression
-from sklearn.metrics import accuracy_score
-
-_logger = logging.getLogger(__name__)
-
-
-class DisentanglementMetrics:
-    """
-    Disentanglement metrics for VAE evaluation on microscopy data.
-
-    Implements MIG, SAP, DCI, and Beta-VAE score metrics for evaluating
-    how well the VAE learns disentangled representations.
-    """
-
-    def __init__(self, device: str = "cuda"):
-        self.device = device
-
-    def compute_all_metrics(
-        self,
-        vae_model: nn.Module,
-        dataloader: torch.utils.data.DataLoader,
-        max_samples: int = 1000,
-        n_factors: Optional[int] = None,
-    ) -> Dict[str, float]:
-        """
-        Compute all disentanglement metrics.
-
-        Args:
-            vae_model: Trained VAE model
-            dataloader: DataLoader with labeled data
-            max_samples: Maximum number of samples to use
-            n_factors: Number of known generative factors (if available)
-
-        Returns:
-            Dictionary of metric scores
-        """
-        latents, factors = self._extract_latents_and_factors(
-            vae_model, dataloader, max_samples
-        )
-
-        metrics = {}
-
-        # MIG Score
-        metrics["MIG"] = self.compute_mig(latents, factors)
-
-        # SAP Score
-        metrics["SAP"] = self.compute_sap(latents, factors)
-
-        # DCI Scores
-        dci_scores = self.compute_dci(latents, factors)
-        metrics.update(dci_scores)
-
-        # Beta-VAE Score (unsupervised)
-        metrics["Beta_VAE_Score"] = self.compute_beta_vae_score(
-            vae_model, dataloader, max_samples
-        )
-
-        return metrics
-
-    def _extract_latents_and_factors(
-        self,
-        vae_model: nn.Module,
-        dataloader: torch.utils.data.DataLoader,
-        max_samples: int,
-    ) -> Tuple[np.ndarray, np.ndarray]:
-        """
-        Extract latent representations and generative factors. 
- - For microscopy data, we'll extract simple visual factors like: - - Cell size (approximated from pixel intensity) - - Cell count (approximated from connected components) - - Brightness (mean intensity) - - Contrast (std of intensity) - """ - vae_model.eval() - latents = [] - factors = [] - - samples_collected = 0 - - with torch.no_grad(): - for batch in dataloader: - if samples_collected >= max_samples: - break - - x = batch["anchor"].to(self.device) - batch_size = x.shape[0] - - # Extract latent representations - model_output = vae_model(x) - # Handle both dict format and object format - if isinstance(model_output, dict): - z = model_output["z"] - else: - z = ( - model_output.z - if hasattr(model_output, "z") - else model_output.embedding - ) - latents.append(z.cpu().numpy()) - - # Extract visual factors from images - batch_factors = self._extract_visual_factors(x.cpu()) - factors.append(batch_factors) - - samples_collected += batch_size - - latents = np.vstack(latents)[:max_samples] - factors = np.vstack(factors)[:max_samples] - - return latents, factors - - def _extract_visual_factors(self, images: torch.Tensor) -> np.ndarray: - """ - Extract visual factors from microscopy images. - - Args: - images: Batch of images (B, C, D, H, W) - - Returns: - Array of shape (B, n_factors) with extracted factors - """ - batch_size = images.shape[0] - factors = [] - - for i in range(batch_size): - img = images[i].numpy() # (C, D, H, W) - - # Take middle z-slice for 2D analysis - mid_z = img.shape[1] // 2 - img_2d = img[:, mid_z, :, :] # (C, H, W) - - # Factor 1: Brightness (mean intensity) - brightness = np.mean(img_2d) - - # Factor 2: Contrast (std of intensity) - contrast = np.std(img_2d) - - # Factor 3: Cell size (approximated by high-intensity regions) - binary_mask = img_2d[0] > np.percentile(img_2d[0], 75) - cell_size = np.sum(binary_mask) / (img_2d.shape[1] * img_2d.shape[2]) - - # Factor 4: Texture complexity (gradient magnitude) - grad_x = np.gradient(img_2d[0], axis=1) - grad_y = np.gradient(img_2d[0], axis=0) - texture = np.mean(np.sqrt(grad_x**2 + grad_y**2)) - - factors.append([brightness, contrast, cell_size, texture]) - - return np.array(factors) - - def compute_mig(self, latents: np.ndarray, factors: np.ndarray) -> float: - """ - Compute Mutual Information Gap (MIG). - - MIG = (1/K) * Σ_k (I(z_j*; v_k) - I(z_j'; v_k)) - where j* = argmax_j I(z_j; v_k) and j' = argmax_{j≠j*} I(z_j; v_k) - """ - - def mutual_info_continuous(x, y): - """Estimate mutual information between continuous variables.""" - # Discretize continuous variables - x_discrete = self._discretize(x) - y_discrete = self._discretize(y) - - # Compute mutual information - return self._mutual_info_discrete(x_discrete, y_discrete) - - n_factors = factors.shape[1] - n_latents = latents.shape[1] - - # Compute mutual information matrix - mi_matrix = np.zeros((n_latents, n_factors)) - - for i in range(n_latents): - for j in range(n_factors): - mi_matrix[i, j] = mutual_info_continuous(latents[:, i], factors[:, j]) - - # Compute MIG - mig = 0.0 - for j in range(n_factors): - mi_values = mi_matrix[:, j] - sorted_indices = np.argsort(mi_values)[::-1] - - if len(sorted_indices) > 1: - gap = mi_values[sorted_indices[0]] - mi_values[sorted_indices[1]] - mig += gap / np.max(mi_values) if np.max(mi_values) > 0 else 0 - - return mig / n_factors - - def compute_sap(self, latents: np.ndarray, factors: np.ndarray) -> float: - """ - Compute Attribute Predictability Score (SAP). 
- - SAP measures how well a simple classifier can predict factors from latents. - """ - n_factors = factors.shape[1] - scores = [] - - for i in range(n_factors): - # Discretize factor for classification - factor_discrete = self._discretize(factors[:, i], n_bins=10) - - # Train classifier - clf = LogisticRegression(random_state=42, max_iter=1000) - clf.fit(latents, factor_discrete) - - # Evaluate - pred = clf.predict(latents) - score = accuracy_score(factor_discrete, pred) - scores.append(score) - - return np.mean(scores) - - def compute_dci(self, latents: np.ndarray, factors: np.ndarray) -> Dict[str, float]: - """ - Compute Disentanglement, Completeness, and Informativeness (DCI). - """ - n_factors = factors.shape[1] - n_latents = latents.shape[1] - - # Train predictors for each factor - importance_matrix = np.zeros((n_factors, n_latents)) - - for i in range(n_factors): - # Discretize factor - factor_discrete = self._discretize(factors[:, i], n_bins=10) - - # Train random forest to get feature importance - rf = RandomForestClassifier(n_estimators=100, random_state=42) - rf.fit(latents, factor_discrete) - - importance_matrix[i, :] = rf.feature_importances_ - - # Normalize importance matrix - importance_matrix = importance_matrix / ( - np.sum(importance_matrix, axis=1, keepdims=True) + 1e-8 - ) - - # Compute DCI metrics - disentanglement = self._compute_disentanglement(importance_matrix) - completeness = self._compute_completeness(importance_matrix) - informativeness = self._compute_informativeness(importance_matrix) - - return { - "DCI_Disentanglement": disentanglement, - "DCI_Completeness": completeness, - "DCI_Informativeness": informativeness, - } - - def compute_beta_vae_score( - self, - vae_model: nn.Module, - dataloader: torch.utils.data.DataLoader, - max_samples: int, - ) -> float: - """ - Compute Beta-VAE score (unsupervised disentanglement metric). - - Measures how well individual latent dimensions affect reconstruction - when perturbed independently. 
- """ - vae_model.eval() - scores = [] - - samples_collected = 0 - - with torch.no_grad(): - for batch in dataloader: - if samples_collected >= max_samples: - break - - x = batch["anchor"].to(self.device) - batch_size = x.shape[0] - - # Get latent representation - model_output = vae_model(x) - # Handle both dict format and object format - if isinstance(model_output, dict): - z = model_output["z"] - else: - z = ( - model_output.z - if hasattr(model_output, "z") - else model_output.embedding - ) - - # Compute baseline reconstruction - baseline_recon = vae_model.decoder(z) - if hasattr(baseline_recon, "reconstruction"): - baseline_recon = baseline_recon.reconstruction - - # Perturb each latent dimension - for dim in range(z.shape[1]): - z_perturbed = z.clone() - z_perturbed[:, dim] += torch.randn_like(z_perturbed[:, dim]) * 0.5 - - # Get perturbed reconstruction - perturbed_recon = vae_model.decoder(z_perturbed) - if hasattr(perturbed_recon, "reconstruction"): - perturbed_recon = perturbed_recon.reconstruction - - # Compute reconstruction difference - diff = F.mse_loss(baseline_recon, perturbed_recon, reduction="none") - diff = diff.mean( - dim=(1, 2, 3, 4) - ) # Average over spatial dimensions - - # Score is inverse of reconstruction change - score = 1.0 / (1.0 + diff.mean().item()) - scores.append(score) - - samples_collected += batch_size - - return np.mean(scores) - - def _discretize(self, x: np.ndarray, n_bins: int = 20) -> np.ndarray: - """Discretize continuous variable into bins.""" - return np.digitize(x, np.linspace(x.min(), x.max(), n_bins)) - - def _mutual_info_discrete(self, x: np.ndarray, y: np.ndarray) -> float: - """Compute mutual information between discrete variables.""" - # Joint histogram - xy = np.stack([x, y], axis=1) - unique_xy, counts_xy = np.unique(xy, axis=0, return_counts=True) - p_xy = counts_xy / counts_xy.sum() - - # Marginal histograms - unique_x, counts_x = np.unique(x, return_counts=True) - p_x = counts_x / counts_x.sum() - - unique_y, counts_y = np.unique(y, return_counts=True) - p_y = counts_y / counts_y.sum() - - # Compute MI - mi = 0.0 - for i, (x_val, y_val) in enumerate(unique_xy): - p_joint = p_xy[i] - p_x_marginal = p_x[unique_x == x_val][0] - p_y_marginal = p_y[unique_y == y_val][0] - - if p_joint > 0 and p_x_marginal > 0 and p_y_marginal > 0: - mi += p_joint * np.log(p_joint / (p_x_marginal * p_y_marginal)) - - return mi - - def _compute_disentanglement(self, importance_matrix: np.ndarray) -> float: - """Compute disentanglement score from importance matrix.""" - disentanglement = 0.0 - for i in range(importance_matrix.shape[0]): - if np.sum(importance_matrix[i]) > 0: - disentanglement += 1.0 - stats.entropy(importance_matrix[i]) - return disentanglement / importance_matrix.shape[0] - - def _compute_completeness(self, importance_matrix: np.ndarray) -> float: - """Compute completeness score from importance matrix.""" - completeness = 0.0 - for j in range(importance_matrix.shape[1]): - if np.sum(importance_matrix[:, j]) > 0: - completeness += 1.0 - stats.entropy(importance_matrix[:, j]) - return completeness / importance_matrix.shape[1] - - def _compute_informativeness(self, importance_matrix: np.ndarray) -> float: - """Compute informativeness score from importance matrix.""" - return np.mean(np.sum(importance_matrix, axis=1)) diff --git a/viscy/representation/engine.py b/viscy/representation/engine.py index 3043e77ac..070ef80d4 100644 --- a/viscy/representation/engine.py +++ b/viscy/representation/engine.py @@ -250,8 +250,6 @@ def __init__( 
log_batches_per_epoch: int = 8, log_samples_per_batch: int = 1, example_input_array_shape: Sequence[int] = (1, 2, 30, 256, 256), - compute_disentanglement: bool = True, - disentanglement_frequency: int = 10, log_enhanced_visualizations: bool = False, log_enhanced_visualizations_frequency: int = 30, ): @@ -272,8 +270,6 @@ def __init__( self.log_samples_per_batch = log_samples_per_batch self.example_input_array = torch.rand(*example_input_array_shape) - self.compute_disentanglement = compute_disentanglement - self.disentanglement_frequency = disentanglement_frequency self.log_enhanced_visualizations = log_enhanced_visualizations self.log_enhanced_visualizations_frequency = ( @@ -509,12 +505,6 @@ def on_validation_epoch_end(self) -> None: self._log_samples("val_reconstructions", self.validation_step_outputs) self.validation_step_outputs = [] - if ( - self.compute_disentanglement - and self.current_epoch % self.disentanglement_frequency == 0 - and self.current_epoch > 0 - ): - self._compute_and_log_disentanglement_metrics() if ( self.log_enhanced_visualizations @@ -523,31 +513,6 @@ def on_validation_epoch_end(self) -> None: ): self._log_enhanced_visualizations() - def _compute_and_log_disentanglement_metrics(self): - """Compute and log disentanglement metrics.""" - try: - val_dataloaders = self.trainer.val_dataloaders - if val_dataloaders is None: - val_dataloader = None - elif isinstance(val_dataloaders, list): - val_dataloader = val_dataloaders[0] if val_dataloaders else None - else: - val_dataloader = val_dataloaders - - if val_dataloader is None: - _logger.warning( - "No validation dataloader available for disentanglement metrics" - ) - return - - self.vae_logger.log_disentanglement_metrics( - lightning_module=self, - dataloader=val_dataloader, - max_samples=200, - ) - - except Exception as e: - _logger.error(f"Error computing disentanglement metrics: {e}") def _log_enhanced_visualizations(self): """Log enhanced β-VAE visualizations.""" diff --git a/viscy/representation/vae_logging.py b/viscy/representation/vae_logging.py index ff0f462da..21b593710 100644 --- a/viscy/representation/vae_logging.py +++ b/viscy/representation/vae_logging.py @@ -5,7 +5,6 @@ import torch from torchvision.utils import make_grid -from viscy.representation.disentanglement_metrics import DisentanglementMetrics _logger = logging.getLogger(__name__) @@ -21,13 +20,10 @@ class BetaVaeLogger: def __init__(self, latent_dim: int = 128): self.latent_dim = latent_dim self.device = None - self.disentanglement_metrics = None def setup(self, device: str): """Initialize device-dependent components.""" self.device = device - if self.disentanglement_metrics is None: - self.disentanglement_metrics = DisentanglementMetrics(device=device) def log_enhanced_metrics( self, lightning_module, model_output: dict, batch: dict, stage: str = "train" @@ -366,63 +362,3 @@ def log_beta_schedule( lightning_module.log("beta_schedule", beta) return beta - def log_disentanglement_metrics( - self, - lightning_module, - dataloader: torch.utils.data.DataLoader, - max_samples: int = 500, - sync_dist: bool = True, - ): - """ - Log disentanglement metrics to TensorBoard every 10 epochs. 
- - Args: - lightning_module: Lightning module instance - dataloader: DataLoader for evaluation - max_samples: Maximum samples to use for evaluation - """ - # Only compute every 10 epochs to save compute - if lightning_module.current_epoch % 10 != 0: - return - - _logger.info( - f"Computing disentanglement metrics at epoch {lightning_module.current_epoch}" - ) - - try: - # Use the lightning module directly (no separate model attribute after refactoring) - vae_model = lightning_module - - # Compute all disentanglement metrics - metrics = self.disentanglement_metrics.compute_all_metrics( - vae_model=vae_model, - dataloader=dataloader, - max_samples=max_samples, - ) - - # Log metrics with organized naming - tensorboard_metrics = {} - for metric_name, value in metrics.items(): - tensorboard_metrics[f"disentanglement_metrics/{metric_name}"] = value - - lightning_module.log_dict( - tensorboard_metrics, - on_step=False, - on_epoch=True, - logger=True, - sync_dist=sync_dist, - ) - - _logger.info(f"Logged disentanglement metrics: {metrics}") - - except Exception as e: - _logger.warning(f"Failed to compute disentanglement metrics: {e}") - # Log a placeholder to indicate the attempt - lightning_module.log( - "disentanglement_metrics/computation_failed", - 1.0, - on_step=False, - on_epoch=True, - logger=True, - sync_dist=sync_dist, - ) From d1b255bf22ad1082a993984ab13399d950ea96a4 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Mon, 22 Sep 2025 13:14:58 -0700 Subject: [PATCH 077/101] vectorized the anchor filtering for celldivisiontriplet dataset --- viscy/data/cell_division_triplet.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/viscy/data/cell_division_triplet.py b/viscy/data/cell_division_triplet.py index ed3b45929..e4dc3b1d8 100644 --- a/viscy/data/cell_division_triplet.py +++ b/viscy/data/cell_division_triplet.py @@ -84,6 +84,10 @@ def __init__( # Load and process all data files self.cell_tracks = self._load_data(data_paths) self.valid_anchors = self._filter_anchors() + + # Create arrays for vectorized operations + self.track_ids = np.array([t["track_id"] for t in self.cell_tracks]) + self.cell_tracks_array = np.array(self.cell_tracks) def _load_data(self, data_paths: list[Path]) -> list[dict]: """Load npy files.""" @@ -155,9 +159,9 @@ def _sample_negative(self, anchor_info: dict) -> Tensor: """Select a negative sample from a different track.""" anchor_track_id = anchor_info["track_id"] - negative_candidates = [ - t for t in self.cell_tracks if t["track_id"] != anchor_track_id - ] + # Vectorized filtering using boolean indexing + mask = self.track_ids != anchor_track_id + negative_candidates = self.cell_tracks_array[mask].tolist() if not negative_candidates: # Fallback: use different timepoint from same track From 9e939af0959f708ada30fad23da640714c37cc50 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Mon, 22 Sep 2025 13:22:41 -0700 Subject: [PATCH 078/101] map the channels to the rpe dataset convention --- viscy/data/cell_division_triplet.py | 32 ++++++++++++++++++++++++++--- 1 file changed, 29 insertions(+), 3 deletions(-) diff --git a/viscy/data/cell_division_triplet.py b/viscy/data/cell_division_triplet.py index e4dc3b1d8..73e707d6c 100644 --- a/viscy/data/cell_division_triplet.py +++ b/viscy/data/cell_division_triplet.py @@ -88,7 +88,28 @@ def __init__( # Create arrays for vectorized operations self.track_ids = np.array([t["track_id"] for t in self.cell_tracks]) self.cell_tracks_array = np.array(self.cell_tracks) - + + # Map channel 
names to indices using CHANNEL_MAPPING
+        self.channel_indices = self._map_channel_indices(channel_names)
+
+    def _map_channel_indices(self, channel_names: list[str]) -> list[int]:
+        """Map channel names to their corresponding indices in the data array."""
+        channel_indices = []
+        for name in channel_names:
+            if name in self.CHANNEL_MAPPING:
+                channel_indices.append(self.CHANNEL_MAPPING[name])
+            else:
+                # Try to parse as integer if not in mapping
+                try:
+                    channel_indices.append(int(name))
+                except ValueError:
+                    raise ValueError(f"Channel '{name}' not found in CHANNEL_MAPPING and is not a valid integer")
+        return channel_indices
+    
+    def _select_channels(self, patch: Tensor) -> Tensor:
+        """Select only the requested channels from the patch."""
+        return patch[self.channel_indices]
+    
     def _load_data(self, data_paths: list[Path]) -> list[dict]:
         """Load npy files."""
         all_tracks = []
@@ -151,6 +172,7 @@ def _sample_positive(self, anchor_info: dict) -> Tensor:
             positive_t = anchor_t + self.time_interval
 
         positive_patch = track["data"][positive_t]
+        positive_patch = self._select_channels(positive_patch)
         if not self.output_2d:
             positive_patch = positive_patch.unsqueeze(1)
         return positive_patch
@@ -173,9 +195,11 @@ def _sample_negative(self, anchor_info: dict) -> Tensor:
             if available_times:
                 neg_t = random.choice(available_times)
                 negative_patch = track["data"][neg_t]
+                negative_patch = self._select_channels(negative_patch)
             else:
                 # Ultimate fallback: use same patch (transforms will differentiate)
                 negative_patch = track["data"][anchor_t]
+                negative_patch = self._select_channels(negative_patch)
         else:
             # Sample from different track
             neg_track = random.choice(negative_candidates)
@@ -192,6 +216,7 @@ def _sample_negative(self, anchor_info: dict) -> Tensor:
             neg_t = random.randint(0, neg_track["num_timepoints"] - 1)
 
         negative_patch = neg_track["data"][neg_t]
+        negative_patch = self._select_channels(negative_patch)
 
         # Add depth dimension only if not output_2d: (C, Y, X) -> (C, D=1, Y, X)
         if not self.output_2d:
@@ -203,10 +228,11 @@ def __getitem__(self, index: int) -> TripletSample:
         track = anchor_info["track"]
         anchor_t = anchor_info["timepoint"]
 
-        # Get anchor patch and add depth dimension only if not output_2d
+        # Get anchor patch and select requested channels
        anchor_patch = track["data"][anchor_t]  # Shape: (C, Y, X)
+        anchor_patch = self._select_channels(anchor_patch)
         if not self.output_2d:
-            anchor_patch = anchor_patch.unsqueeze(1)  # Shape: (C, 1, Y, X)
+            anchor_patch = anchor_patch.unsqueeze(1)
 
         sample = {"anchor": anchor_patch}

From 6307ef8a75e4a763ae7f99bf0e8dd914b063dcb9 Mon Sep 17 00:00:00 2001
From: Eduardo Hirata-Miyasaki
Date: Mon, 22 Sep 2025 13:29:54 -0700
Subject: [PATCH 079/101] fix logistic regression standardization

---
 viscy/representation/evaluation/lca.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/viscy/representation/evaluation/lca.py b/viscy/representation/evaluation/lca.py
index eab6c41d6..d82d324c0 100644
--- a/viscy/representation/evaluation/lca.py
+++ b/viscy/representation/evaluation/lca.py
@@ -104,9 +104,8 @@ def fit_logistic_regression(
     test_annotations = annotations_filtered[test_selection]
 
     if scale_features:
-        scaler = StandardScaler()
-        train_features = scaler.fit_transform(train_features)
-        test_features = scaler.transform(test_features)
+        train_features = StandardScaler().fit_transform(train_features)
+        test_features = StandardScaler().fit_transform(test_features)
 
     logistic_regression = LogisticRegression(
         class_weight=class_weight,
         random_state=random_state, 
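The hunk above replaces a train-fitted scaler with per-split standardization. The two patterns differ in which statistics the test split sees; the minimal sketch below (synthetic stand-in arrays, hypothetical names, only sklearn's StandardScaler API assumed) contrasts them.

import numpy as np
from sklearn.preprocessing import StandardScaler

# Hypothetical stand-ins for the train/test feature matrices in lca.py.
rng = np.random.default_rng(0)
train_features = rng.normal(loc=2.0, scale=3.0, size=(200, 16))
test_features = rng.normal(loc=2.0, scale=3.0, size=(50, 16))

# Pattern removed by the patch: fit on the training split only and reuse
# the train mean/std on the test split, keeping the test data "unseen".
scaler = StandardScaler()
train_scaled = scaler.fit_transform(train_features)
test_scaled = scaler.transform(test_features)

# Pattern introduced by the patch: standardize each split independently,
# so both end up with exactly zero mean and unit variance, at the cost of
# using statistics computed from the test split itself.
train_scaled_indep = StandardScaler().fit_transform(train_features)
test_scaled_indep = StandardScaler().fit_transform(test_features)

Which variant is preferable depends on whether held-out FOVs should be normalized independently of the training statistics; the sketch only illustrates the mechanical difference.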
From 0fe0ef862549aae3053848866ad2e147aab821b8 Mon Sep 17 00:00:00 2001
From: Eduardo Hirata-Miyasaki
Date: Mon, 22 Sep 2025 14:43:01 -0700
Subject: [PATCH 080/101] update rpe classifier to include mitosis

---
 .../evaluation/rpe1_fucci/linear_classifier.py | 16 +++++++---------
 1 file changed, 7 insertions(+), 9 deletions(-)

diff --git a/applications/contrastive_phenotyping/evaluation/rpe1_fucci/linear_classifier.py b/applications/contrastive_phenotyping/evaluation/rpe1_fucci/linear_classifier.py
index d3b984d6e..077e57c87 100644
--- a/applications/contrastive_phenotyping/evaluation/rpe1_fucci/linear_classifier.py
+++ b/applications/contrastive_phenotyping/evaluation/rpe1_fucci/linear_classifier.py
@@ -10,9 +10,9 @@
 from viscy.representation.embedding_writer import read_embedding_dataset
 
 test_data_features_path = Path(
-    "/hpc/projects/intracellular_dashboard/organelle_dynamics/2025_rpe_fucci_leger_weigert/0-phenotyping/rpe_fucci_test_data_ckpt264.zarr"
+    "/hpc/projects/intracellular_dashboard/organelle_dynamics/2025_rpe_fucci_leger_weigert/0-phenotyping/bf_only_timeaware_ntxent_lr2e-5_temp_7e-2_tau1_w_augmentations_2_ckpt306.zarr"
 )
-cell_cycle_labels_path = "/hpc/projects/organelle_phenotyping/models/rpe_fucci/pseudolabels/cell_cycle_labels.csv"
+cell_cycle_labels_path = "/hpc/projects/organelle_phenotyping/models/rpe_fucci/dynaclr/pseudolabels/cell_cycle_labels_w_mitosis.csv"
 
 # %%
 # Load the data
@@ -23,8 +23,6 @@
 features = test_embeddings.features.values
 
 # %%
-# Create a combined identifier for matching
-# The sample coordinate contains (fov_name, id) tuples
 sample_coords = test_embeddings.coords["sample"].values
 fov_names = [coord[0] for coord in sample_coords]
 ids = [coord[1] for coord in sample_coords]
@@ -76,18 +74,18 @@
 # Enhanced evaluation and visualization
 import matplotlib.pyplot as plt
 import seaborn as sns
-from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
+from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix
 
 # 1. Confusion Matrix - shows which classes are confused with each other
 cm = confusion_matrix(y_test, y_test_pred)
 plt.figure(figsize=(8, 6))
-ConfusionMatrixDisplay(cm, display_labels=["G1", "G2", "S"]).plot(cmap="Blues")
+ConfusionMatrixDisplay(cm, display_labels=["G1", "G2", "S", "M"]).plot(cmap="Blues")
 plt.title("Confusion Matrix")
 plt.show()
 
 # 2. 
Per-class errors breakdown
 print("\nDetailed per-class analysis:")
-for class_name in ["G1", "G2", "S"]:
+for class_name in ["G1", "G2", "S", "M"]:
     mask = y_test == class_name
     correct = (y_test_pred[mask] == class_name).sum()
     total = mask.sum()
@@ -105,9 +103,9 @@
 
 plt.figure(figsize=(12, 4))
 for i, class_name in enumerate(class_names):
-    plt.subplot(1, 3, i + 1)
+    plt.subplot(1, 4, i + 1)
     plt.hist(
-        y_test_proba[:, i], bins=20, alpha=0.7, color=["blue", "orange", "green"][i]
+        y_test_proba[:, i], bins=20, alpha=0.7, color=["blue", "orange", "green", "red"][i]
     )
     plt.title(f"Confidence for {class_name}")
     plt.xlabel("Probability")

From cbadef3218870d9366d52ba15b56bd6eff26af27 Mon Sep 17 00:00:00 2001
From: Eduardo Hirata-Miyasaki
Date: Mon, 22 Sep 2025 14:43:06 -0700
Subject: [PATCH 081/101] ruff

---
 viscy/data/cell_division_triplet.py | 12 ++++++-----
 viscy/representation/engine.py | 2 --
 viscy/representation/vae_logging.py | 2 --
 3 files changed, 7 insertions(+), 9 deletions(-)

diff --git a/viscy/data/cell_division_triplet.py b/viscy/data/cell_division_triplet.py
index 73e707d6c..a1ccc3cef 100644
--- a/viscy/data/cell_division_triplet.py
+++ b/viscy/data/cell_division_triplet.py
@@ -84,11 +84,11 @@ def __init__(
         # Load and process all data files
         self.cell_tracks = self._load_data(data_paths)
         self.valid_anchors = self._filter_anchors()
-        
+
         # Create arrays for vectorized operations
         self.track_ids = np.array([t["track_id"] for t in self.cell_tracks])
         self.cell_tracks_array = np.array(self.cell_tracks)
-        
+
         # Map channel names to indices using CHANNEL_MAPPING
         self.channel_indices = self._map_channel_indices(channel_names)
 
@@ -103,13 +103,15 @@ def _map_channel_indices(self, channel_names: list[str]) -> list[int]:
                 try:
                     channel_indices.append(int(name))
                 except ValueError:
-                    raise ValueError(f"Channel '{name}' not found in CHANNEL_MAPPING and is not a valid integer")
+                    raise ValueError(
+                        f"Channel '{name}' not found in CHANNEL_MAPPING and is not a valid integer"
+                    )
         return channel_indices
-    
+
     def _select_channels(self, patch: Tensor) -> Tensor:
         """Select only the requested channels from the patch."""
         return patch[self.channel_indices]
-    
+
     def _load_data(self, data_paths: list[Path]) -> list[dict]:
         """Load npy files."""
         all_tracks = []
diff --git a/viscy/representation/engine.py b/viscy/representation/engine.py
index 070ef80d4..8a48a2532 100644
--- a/viscy/representation/engine.py
+++ b/viscy/representation/engine.py
@@ -505,7 +505,6 @@ def on_validation_epoch_end(self) -> None:
             self._log_samples("val_reconstructions", self.validation_step_outputs)
             self.validation_step_outputs = []
 
-
         if (
             self.log_enhanced_visualizations
             and self.current_epoch % self.log_enhanced_visualizations_frequency == 0
@@ -513,7 +512,6 @@ def on_validation_epoch_end(self) -> None:
         ):
             self._log_enhanced_visualizations()
 
-
     def _log_enhanced_visualizations(self):
         """Log enhanced β-VAE visualizations."""
         try:
diff --git a/viscy/representation/vae_logging.py b/viscy/representation/vae_logging.py
index 21b593710..4974002ef 100644
--- a/viscy/representation/vae_logging.py
+++ b/viscy/representation/vae_logging.py
@@ -5,7 +5,6 @@
 import torch
 from torchvision.utils import make_grid
 
-
 _logger = logging.getLogger(__name__)
 
 
@@ -361,4 +360,3 @@ def log_beta_schedule(
 
         lightning_module.log("beta_schedule", beta)
         return beta
-

From 46ff7a2cb0590a0ab58195b1065a1046b0f249cc Mon Sep 17 00:00:00 2001
From: Eduardo Hirata-Miyasaki
Date: Mon, 22 Sep 2025 14:49:00 -0700
Subject: [PATCH 082/101] remove unused logging

---
viscy/representation/vae_logging.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/viscy/representation/vae_logging.py b/viscy/representation/vae_logging.py index 4974002ef..b06894f9a 100644 --- a/viscy/representation/vae_logging.py +++ b/viscy/representation/vae_logging.py @@ -1,12 +1,9 @@ -import logging from typing import Callable, Optional, Tuple import numpy as np import torch from torchvision.utils import make_grid -_logger = logging.getLogger(__name__) - class BetaVaeLogger: """ From 4536d3cd68b15aad1f9680fc18ffd55a7638b563 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Mon, 22 Sep 2025 17:14:30 -0700 Subject: [PATCH 083/101] datamodule agnostic --- .../DynaCLR/DINOV3/config_dinov3_convnext_tiny.yml | 3 ++- .../benchmarking/DynaCLR/DINOV3/dinov3_embeddings.py | 7 +++++-- .../DynaCLR/OpenPhenom/config_template.yml | 1 + .../DynaCLR/OpenPhenom/openphenom_embeddings.py | 7 ++++++- .../benchmarking/DynaCLR/SAM2/sam2_config.yml | 1 + .../benchmarking/DynaCLR/SAM2/sam2_embeddings.py | 7 ++++++- viscy/representation/vae.py | 11 +++++++++-- 7 files changed, 30 insertions(+), 7 deletions(-) diff --git a/applications/benchmarking/DynaCLR/DINOV3/config_dinov3_convnext_tiny.yml b/applications/benchmarking/DynaCLR/DINOV3/config_dinov3_convnext_tiny.yml index ab6bb52cc..268ee1e6a 100644 --- a/applications/benchmarking/DynaCLR/DINOV3/config_dinov3_convnext_tiny.yml +++ b/applications/benchmarking/DynaCLR/DINOV3/config_dinov3_convnext_tiny.yml @@ -1,3 +1,4 @@ +datamodule_class: viscy.data.triplet.TripletDataModule datamodule: batch_size: 32 final_yx_patch_size: @@ -60,5 +61,5 @@ model: paths: data_path: /hpc/projects/intracellular_dashboard/organelle_dynamics/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/registered_test.zarr + tracks_path: /hpc/projects/intracellular_dashboard/organelle_dynamics/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/track_test.zarr output_path: /hpc/mydata/eduardo.hirata/repos/viscy/applications/benchmarking/DynaCLR/DINOV3/embeddings_convnext_tiny_mean.zarr - tracks_path: /hpc/projects/intracellular_dashboard/organelle_dynamics/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/track_test.zarr \ No newline at end of file diff --git a/applications/benchmarking/DynaCLR/DINOV3/dinov3_embeddings.py b/applications/benchmarking/DynaCLR/DINOV3/dinov3_embeddings.py index 984894c94..15e50b78b 100644 --- a/applications/benchmarking/DynaCLR/DINOV3/dinov3_embeddings.py +++ b/applications/benchmarking/DynaCLR/DINOV3/dinov3_embeddings.py @@ -13,7 +13,6 @@ from skimage.exposure import rescale_intensity from transformers import AutoImageProcessor, AutoModel -from viscy.data.triplet import TripletDataModule from viscy.representation.embedding_writer import EmbeddingWriter from viscy.trainer import VisCyTrainer @@ -314,7 +313,11 @@ def main(config): dm_params[param] = value logger.info("Setting up data module") - dm = TripletDataModule(**dm_params) + class_path = cfg["datamodule_class"] + module_path, class_name = class_path.rsplit(".", 1) + module = importlib.import_module(module_path) + datamodule_class = getattr(module, class_name) + dm = datamodule_class(**dm_params) # Get model parameters model_name = cfg["model"].get("model_name", "facebook/dinov3-vitb16-pretrain-lvd1689m") diff --git a/applications/benchmarking/DynaCLR/OpenPhenom/config_template.yml b/applications/benchmarking/DynaCLR/OpenPhenom/config_template.yml index 4139e0d08..48b0e7c90 100644 --- a/applications/benchmarking/DynaCLR/OpenPhenom/config_template.yml +++ 
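PATCH 083 makes each benchmarking script datamodule-agnostic by resolving the new datamodule_class key from the config at runtime. The same four-line lookup recurs in every script; a small helper capturing the pattern might look like this sketch (the function name is illustrative, not part of the patch):

import importlib


def locate_class(class_path: str) -> type:
    """Resolve a dotted path such as 'viscy.data.triplet.TripletDataModule'."""
    module_path, class_name = class_path.rsplit(".", 1)
    return getattr(importlib.import_module(module_path), class_name)


# Usage mirroring the patched scripts:
# dm = locate_class(cfg["datamodule_class"])(**dm_params)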
b/applications/benchmarking/DynaCLR/OpenPhenom/config_template.yml @@ -16,6 +16,7 @@ model: "raw GFP EX488 EM525-45": "max" # Data module configuration +datamodule_class: viscy.data.triplet.TripletDataModule datamodule: source_channel: - Phase3D diff --git a/applications/benchmarking/DynaCLR/OpenPhenom/openphenom_embeddings.py b/applications/benchmarking/DynaCLR/OpenPhenom/openphenom_embeddings.py index 510d8e132..4e0aa4343 100644 --- a/applications/benchmarking/DynaCLR/OpenPhenom/openphenom_embeddings.py +++ b/applications/benchmarking/DynaCLR/OpenPhenom/openphenom_embeddings.py @@ -226,7 +226,12 @@ def main(config): # Set up the data module logger.info("Setting up data module") - dm = TripletDataModule(**dm_params) + + class_path = cfg["datamodule_class"] + module_path, class_name = class_path.rsplit(".", 1) + module = importlib.import_module(module_path) + datamodule_class = getattr(module, class_name) + dm = datamodule_class(**dm_params) # Get model parameters for handling 5D inputs channel_reduction_methods = {} diff --git a/applications/benchmarking/DynaCLR/SAM2/sam2_config.yml b/applications/benchmarking/DynaCLR/SAM2/sam2_config.yml index 19f3ea51f..89bbb5202 100644 --- a/applications/benchmarking/DynaCLR/SAM2/sam2_config.yml +++ b/applications/benchmarking/DynaCLR/SAM2/sam2_config.yml @@ -1,3 +1,4 @@ +datamodule_class: viscy.data.triplet.TripletDataModule datamodule: batch_size: 32 final_yx_patch_size: diff --git a/applications/benchmarking/DynaCLR/SAM2/sam2_embeddings.py b/applications/benchmarking/DynaCLR/SAM2/sam2_embeddings.py index 66f7dfb83..4a4d6d278 100644 --- a/applications/benchmarking/DynaCLR/SAM2/sam2_embeddings.py +++ b/applications/benchmarking/DynaCLR/SAM2/sam2_embeddings.py @@ -330,8 +330,13 @@ def main(config): # Set up the data module logger.info("Setting up data module") - dm = TripletDataModule(**dm_params) + class_path = cfg["datamodule_class"] + module_path, class_name = class_path.rsplit(".", 1) + module = importlib.import_module(module_path) + datamodule_class = getattr(module, class_name) + dm = datamodule_class(**dm_params) + # Get model parameters for handling 5D inputs channel_reduction_methods = {} middle_slice_index = None diff --git a/viscy/representation/vae.py b/viscy/representation/vae.py index a496ef569..e91f1e479 100644 --- a/viscy/representation/vae.py +++ b/viscy/representation/vae.py @@ -90,8 +90,15 @@ def __init__( def forward(self, inp: Tensor) -> Tensor: """ - :param Tensor inp: Low resolution features - :return Tensor: High resolution features + Parameters + ---------- + inp : Tensor + Low resolution features + + Returns + ------- + Tensor + High resolution features """ inp = self.upsample(inp) return self.conv(inp) From 9f00cfd0903478b63d6d4b7a364e5f3cb963f248 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Mon, 22 Sep 2025 17:32:51 -0700 Subject: [PATCH 084/101] cleaning up duplicated code in the benchmarking --- .../DINOV3/config_dinov3_convnext_tiny.yml | 4 +- .../DynaCLR/DINOV3/dinov3_embeddings.py | 339 ++------------- .../DynaCLR/OpenPhenom/config_template.yml | 4 +- .../OpenPhenom/openphenom_embeddings.py | 328 ++------------ .../benchmarking/DynaCLR/SAM2/sam2_config.yml | 4 +- .../DynaCLR/SAM2/sam2_embeddings.py | 408 ++---------------- 6 files changed, 105 insertions(+), 982 deletions(-) diff --git a/applications/benchmarking/DynaCLR/DINOV3/config_dinov3_convnext_tiny.yml b/applications/benchmarking/DynaCLR/DINOV3/config_dinov3_convnext_tiny.yml index 268ee1e6a..195aa6db9 100644 --- 
a/applications/benchmarking/DynaCLR/DINOV3/config_dinov3_convnext_tiny.yml +++ b/applications/benchmarking/DynaCLR/DINOV3/config_dinov3_convnext_tiny.yml @@ -1,5 +1,7 @@ datamodule_class: viscy.data.triplet.TripletDataModule datamodule: + data_path: /hpc/projects/intracellular_dashboard/organelle_dynamics/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/registered_test.zarr + tracks_path: /hpc/projects/intracellular_dashboard/organelle_dynamics/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/track_test.zarr batch_size: 32 final_yx_patch_size: - 256 @@ -60,6 +62,4 @@ model: - Phase3D paths: - data_path: /hpc/projects/intracellular_dashboard/organelle_dynamics/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/registered_test.zarr - tracks_path: /hpc/projects/intracellular_dashboard/organelle_dynamics/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/track_test.zarr output_path: /hpc/mydata/eduardo.hirata/repos/viscy/applications/benchmarking/DynaCLR/DINOV3/embeddings_convnext_tiny_mean.zarr diff --git a/applications/benchmarking/DynaCLR/DINOV3/dinov3_embeddings.py b/applications/benchmarking/DynaCLR/DINOV3/dinov3_embeddings.py index 15e50b78b..49cdb4064 100644 --- a/applications/benchmarking/DynaCLR/DINOV3/dinov3_embeddings.py +++ b/applications/benchmarking/DynaCLR/DINOV3/dinov3_embeddings.py @@ -1,23 +1,19 @@ -import importlib -import logging -import os +import sys from pathlib import Path from typing import Dict, List, Literal, Optional -import click import numpy as np import torch -import yaml -from lightning.pytorch import LightningModule from PIL import Image from skimage.exposure import rescale_intensity from transformers import AutoImageProcessor, AutoModel -from viscy.representation.embedding_writer import EmbeddingWriter -from viscy.trainer import VisCyTrainer +sys.path.append(str(Path(__file__).parent.parent)) +from base_embedding_module import BaseEmbeddingModule, create_embedding_cli -class DINOv3Module(LightningModule): + +class DINOv3Module(BaseEmbeddingModule): def __init__( self, model_name: str = "facebook/dinov3-vitb16-pretrain-lvd1689m", @@ -25,37 +21,28 @@ def __init__( Dict[str, Literal["middle_slice", "mean", "max"]] ] = None, channel_names: Optional[List[str]] = None, - pooling_method: Literal["mean", "max", "cls_token"] = "mean", + pooling_method: Literal["mean", "max", "cls_token"] = "mean", middle_slice_index: Optional[int] = None, ): - """ - DINOv3 module for feature extraction. - - Parameters - ---------- - model_name : str, optional - DINOv3 model name from HuggingFace Model Hub (default: "facebook/dinov3-vitb16-pretrain-lvd1689m"). - channel_reduction_methods : dict[str, {"middle_slice", "mean", "max"}], optional - Dictionary mapping channel names to reduction methods for 5D inputs (default: None, uses "middle_slice"). - channel_names : list of str, optional - List of channel names corresponding to input channels (default: None). - pooling_method : Literal["mean", "max", "cls_token"], optional - Method to pool spatial tokens from the model output (default: "mean"). - middle_slice_index : int, optional - Specific z-slice index to use for "middle_slice" reduction (default: None, uses D//2). 
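# The shared base class imported above lives in base_embedding_module.py, which
# does not appear in this series. Judging from the code each wrapper deletes in
# this patch, a plausible sketch of its interface (names inferred, not verbatim):
from abc import abstractmethod
from typing import Dict, List, Literal, Optional

import torch
from lightning.pytorch import LightningModule


class BaseEmbeddingModule(LightningModule):
    def __init__(
        self,
        channel_reduction_methods: Optional[
            Dict[str, Literal["middle_slice", "mean", "max"]]
        ] = None,
        channel_names: Optional[List[str]] = None,
        middle_slice_index: Optional[int] = None,
    ):
        super().__init__()
        self.channel_reduction_methods = channel_reduction_methods or {}
        self.channel_names = channel_names or []
        self.middle_slice_index = middle_slice_index
        torch.set_float32_matmul_precision("high")

    def _reduce_5d_input(self, x: torch.Tensor) -> torch.Tensor:
        # Shared (B, C, D, H, W) -> (B, C, H, W) reduction; the per-channel
        # middle_slice/mean/max logic removed from each subclass moves here.
        raise NotImplementedError  # body omitted in this sketch

    @abstractmethod
    def _process_input(self, x: torch.Tensor): ...

    @abstractmethod
    def _extract_features(self, processed_input): ...

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        x = batch["anchor"]
        if x.dim() == 5:
            x = self._reduce_5d_input(x)
        features = self._extract_features(self._process_input(x))
        return {
            "features": features,
            "projections": torch.zeros((features.shape[0], 0), device=features.device),
            "index": batch["index"],
        }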
- - """ - super().__init__() + super().__init__(channel_reduction_methods, channel_names, middle_slice_index) self.model_name = model_name - self.channel_reduction_methods = channel_reduction_methods or {} - self.channel_names = channel_names or [] self.pooling_method = pooling_method - self.middle_slice_index = middle_slice_index - - torch.set_float32_matmul_precision("high") + self.model = None self.processor = None + @classmethod + def from_config(cls, cfg): + """Create model instance from configuration.""" + model_config = cfg.get("model", {}) + return cls( + model_name=model_config.get("model_name", "facebook/dinov3-vitb16-pretrain-lvd1689m"), + pooling_method=model_config.get("pooling_method", "mean"), + channel_reduction_methods=model_config.get("channel_reduction_methods", {}), + channel_names=model_config.get("channel_names", []), + middle_slice_index=model_config.get("middle_slice_index", None), + ) + def on_predict_start(self): if self.model is None: self.processor = AutoImageProcessor.from_pretrained(self.model_name) @@ -63,59 +50,21 @@ def on_predict_start(self): self.model.eval() self.model.to(self.device) - def _reduce_5d_input(self, x: torch.Tensor) -> torch.Tensor: - """ - Reduce 5D input (B, C, D, H, W) to 4D (B, C, H, W) using specified methods. - - Parameters - ---------- - x : torch.Tensor - 5D input tensor with shape (B, C, D, H, W). - - Returns - ------- - torch.Tensor - 4D tensor after applying reduction methods with shape (B, C, H, W). - """ - if x.dim() != 5: - return x - - B, C, D, H, W = x.shape - result = torch.zeros((B, C, H, W), device=x.device) + def _process_input(self, x: torch.Tensor): + """Convert tensor to PIL Images for DINOv3 processing.""" + return self._convert_to_pil_images(x) - # Group channels by reduction method - middle_slice_indices = [] - mean_indices = [] - max_indices = [] - - for c in range(C): - channel_name = ( - self.channel_names[c] if c < len(self.channel_names) else f"channel_{c}" - ) - method = self.channel_reduction_methods.get(channel_name, "middle_slice") - - if method == "mean": - mean_indices.append(c) - elif method == "max": - max_indices.append(c) - else: # Default to middle_slice - middle_slice_indices.append(c) - - # Apply reductions - if middle_slice_indices: - indices = torch.tensor(middle_slice_indices, device=x.device) - slice_idx = self.middle_slice_index if self.middle_slice_index is not None else D // 2 - result[:, indices] = x[:, indices, slice_idx] - - if mean_indices: - indices = torch.tensor(mean_indices, device=x.device) - result[:, indices] = x[:, indices].mean(dim=2) - - if max_indices: - indices = torch.tensor(max_indices, device=x.device) - result[:, indices] = x[:, indices].max(dim=2)[0] - - return result + def _extract_features(self, pil_images): + """Extract features using DINOv3 model.""" + inputs = self.processor(pil_images, return_tensors="pt") + inputs = {k: v.to(self.device) for k, v in inputs.items()} + + with torch.no_grad(): + outputs = self.model(**inputs) + token_features = outputs.last_hidden_state + features = self._pool_features(token_features) + + return features def _convert_to_pil_images(self, x: torch.Tensor) -> List[Image.Image]: """ @@ -145,17 +94,15 @@ def _convert_to_pil_images(self, x: torch.Tensor) -> List[Image.Image]: pil_img = Image.fromarray(img_normalized, mode='L') elif img_tensor.shape[0] == 2: - # Two channels - create RGB with blend in blue img_array = img_tensor.cpu().numpy() rgb_array = np.zeros((img_array.shape[1], img_array.shape[2], 3), dtype=np.uint8) - # Normalize 
each channel to 0-255 ch0_norm = rescale_intensity(img_array[0], out_range=(0, 255)).astype(np.uint8) ch1_norm = rescale_intensity(img_array[1], out_range=(0, 255)).astype(np.uint8) rgb_array[:, :, 0] = ch0_norm # Red rgb_array[:, :, 1] = ch1_norm # Green - rgb_array[:, :, 2] = (ch0_norm + ch1_norm) // 2 # Blue as blend + rgb_array[:, :, 2] = (ch0_norm + ch1_norm) // 2 # Blue pil_img = Image.fromarray(rgb_array, mode='RGB') @@ -202,221 +149,7 @@ def _pool_features(self, features: torch.Tensor) -> torch.Tensor: else: # mean pooling return features.mean(dim=1) - def predict_step(self, batch, batch_idx, dataloader_idx=0): - x = batch["anchor"] - - # Handle 5D input (B, C, D, H, W) - if x.dim() == 5: - x = self._reduce_5d_input(x) - - # Convert to PIL Images for DINOv3 processing - pil_images = self._convert_to_pil_images(x) - - # Batch process all images at once for better GPU utilization - inputs = self.processor(pil_images, return_tensors="pt") - inputs = {k: v.to(self.device) for k, v in inputs.items()} - - with torch.no_grad(): - outputs = self.model(**inputs) - token_features = outputs.last_hidden_state # (B, num_tokens, hidden_dim) - features = self._pool_features(token_features) # (B, hidden_dim) - - return { - "features": features, - "projections": torch.zeros((features.shape[0], 0), device=features.device), - "index": batch["index"], - } - - -def load_config(config_file): - with open(config_file, "r") as f: - config = yaml.safe_load(f) - return config - - -def load_normalization_from_config(norm_config): - class_path = norm_config["class_path"] - init_args = norm_config.get("init_args", {}) - - module_path, class_name = class_path.rsplit(".", 1) - - module = importlib.import_module(module_path) - - transform_class = getattr(module, class_name) - - return transform_class(**init_args) - - -@click.command() -@click.option( - "--config", - "-c", - type=click.Path(exists=True), - required=True, - help="Path to YAML configuration file", -) -def main(config): - """ - Extract DINOv3 embeddings and save to zarr format using VisCy Trainer. - - Parameters - ---------- - config : str or Path - Path to the YAML configuration file containing all parameters for - data loading, model configuration, and output settings. 
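# create_embedding_cli replaces the per-script main() being deleted here. Its
# body is not shown in this excerpt; a sketch of what it plausibly wraps,
# reassembled from the removed code (helper structure and writer defaults are
# assumptions, and normalization-transform instantiation is elided):
def create_embedding_cli(module_cls, model_label: str):
    import importlib
    from pathlib import Path

    import click
    import torch
    import yaml

    from viscy.representation.embedding_writer import EmbeddingWriter
    from viscy.trainer import VisCyTrainer

    @click.command()
    @click.option("--config", "-c", type=click.Path(exists=True), required=True)
    def main(config):
        with open(config) as f:
            cfg = yaml.safe_load(f)
        module_path, class_name = cfg["datamodule_class"].rsplit(".", 1)
        dm_cls = getattr(importlib.import_module(module_path), class_name)
        dm = dm_cls(**cfg["datamodule"])  # normalizations handling elided here
        model = module_cls.from_config(cfg)  # per-model construction hook
        output_path = Path(cfg["paths"]["output_path"])
        trainer = VisCyTrainer(
            accelerator="gpu" if torch.cuda.is_available() else "cpu",
            devices=1,
            callbacks=[EmbeddingWriter(output_path=output_path)],
            inference_mode=True,
        )
        trainer.predict(model, datamodule=dm)

    return main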
- """ - logging.basicConfig(level=logging.INFO) - logger = logging.getLogger(__name__) - - cfg = load_config(config) - logger.info(f"Loaded configuration from {config}") - - dm_params = {} - - if "paths" not in cfg: - raise ValueError("Configuration must contain a 'paths' section") - - if "data_path" not in cfg["paths"]: - raise ValueError( - "Data path is required in the configuration file (paths.data_path)" - ) - dm_params["data_path"] = cfg["paths"]["data_path"] - - if "tracks_path" not in cfg["paths"]: - raise ValueError( - "Tracks path is required in the configuration file (paths.tracks_path)" - ) - dm_params["tracks_path"] = cfg["paths"]["tracks_path"] - - if "datamodule" not in cfg: - raise ValueError("Configuration must contain a 'datamodule' section") - - if ( - "normalizations" not in cfg["datamodule"] - or not cfg["datamodule"]["normalizations"] - ): - raise ValueError( - "Normalizations are required in the configuration file (datamodule.normalizations)" - ) - - norm_configs = cfg["datamodule"]["normalizations"] - normalizations = [load_normalization_from_config(norm) for norm in norm_configs] - dm_params["normalizations"] = normalizations - - for param, value in cfg["datamodule"].items(): - if param != "normalizations": - # Handle patch sizes - if param == "patch_size": - dm_params["initial_yx_patch_size"] = value - dm_params["final_yx_patch_size"] = value - else: - dm_params[param] = value - - logger.info("Setting up data module") - class_path = cfg["datamodule_class"] - module_path, class_name = class_path.rsplit(".", 1) - module = importlib.import_module(module_path) - datamodule_class = getattr(module, class_name) - dm = datamodule_class(**dm_params) - - # Get model parameters - model_name = cfg["model"].get("model_name", "facebook/dinov3-vitb16-pretrain-lvd1689m") - pooling_method = cfg["model"].get("pooling_method", "mean") - channel_reduction_methods = cfg["model"].get("channel_reduction_methods", {}) - channel_names = cfg["model"].get("channel_names", []) - middle_slice_index = cfg["model"].get("middle_slice_index", None) - - # Initialize DINOv3 model - logger.info(f"Loading DINOv3 model: {model_name}") - model = DINOv3Module( - model_name=model_name, - pooling_method=pooling_method, - channel_reduction_methods=channel_reduction_methods, - channel_names=channel_names, - middle_slice_index=middle_slice_index, - ) - - phate_kwargs = None - pca_kwargs = None - - if "embedding" in cfg: - if "phate_kwargs" in cfg["embedding"]: - phate_kwargs = cfg["embedding"]["phate_kwargs"] - if "pca_kwargs" in cfg["embedding"]: - pca_kwargs = cfg["embedding"]["pca_kwargs"] - - if "output_path" not in cfg["paths"]: - raise ValueError( - "Output path is required in the configuration file (paths.output_path)" - ) - - output_path = Path(cfg["paths"]["output_path"]) - output_dir = output_path.parent - output_dir.mkdir(parents=True, exist_ok=True) - - overwrite = False - if "execution" in cfg and "overwrite" in cfg["execution"]: - overwrite = cfg["execution"]["overwrite"] - elif output_path.exists(): - logger.warning(f"Output path {output_path} already exists, will overwrite") - overwrite = True - - embedding_writer = EmbeddingWriter( - output_path=output_path, - phate_kwargs=phate_kwargs, - pca_kwargs=pca_kwargs, - overwrite=overwrite, - ) - - logger.info("Setting up VisCy trainer") - trainer = VisCyTrainer( - accelerator="gpu" if torch.cuda.is_available() else "cpu", - devices=1, - callbacks=[embedding_writer], - inference_mode=True, - ) - - logger.info(f"Running prediction and saving to 
{output_path}") - trainer.predict(model, datamodule=dm) - - # Save configuration if requested - save_config_flag = True - show_config_flag = True - - if "execution" in cfg: - if "save_config" in cfg["execution"]: - save_config_flag = cfg["execution"]["save_config"] - if "show_config" in cfg["execution"]: - show_config_flag = cfg["execution"]["show_config"] - - # Save configuration if requested - if save_config_flag: - config_path = os.path.join(output_dir, "config.yml") - with open(config_path, "w") as f: - yaml.dump(cfg, f, default_flow_style=False) - logger.info(f"Configuration saved to {config_path}") - - # Display configuration if requested - if show_config_flag: - click.echo("\nConfiguration used:") - click.echo("-" * 40) - for key, value in cfg.items(): - click.echo(f"{key}:") - if isinstance(value, dict): - for subkey, subvalue in value.items(): - if isinstance(subvalue, list) and subkey == "normalizations": - click.echo(f" {subkey}:") - for norm in subvalue: - click.echo(f" - class_path: {norm['class_path']}") - click.echo(f" init_args: {norm['init_args']}") - else: - click.echo(f" {subkey}: {subvalue}") - else: - click.echo(f" {value}") - click.echo("-" * 40) - - logger.info("Done!") - if __name__ == "__main__": + main = create_embedding_cli(DINOv3Module, "DINOv3") main() \ No newline at end of file diff --git a/applications/benchmarking/DynaCLR/OpenPhenom/config_template.yml b/applications/benchmarking/DynaCLR/OpenPhenom/config_template.yml index 48b0e7c90..236381f38 100644 --- a/applications/benchmarking/DynaCLR/OpenPhenom/config_template.yml +++ b/applications/benchmarking/DynaCLR/OpenPhenom/config_template.yml @@ -2,8 +2,6 @@ # Paths section paths: - data_path: /hpc/projects/intracellular_dashboard/organelle_dynamics/2024_11_07_A549_SEC61_ZIKV_DENV/2-assemble/2024_11_07_A549_SEC61_DENV.zarr - tracks_path: /hpc/projects/intracellular_dashboard/organelle_dynamics/2024_11_07_A549_SEC61_ZIKV_DENV/1-preprocess/label-free/4-track-gt/2024_11_07_A549_SEC61_ZIKV_DENV_2_cropped.zarr output_path: "/home/eduardo.hirata/repos/viscy/applications/benchmarking/DynaCLR/OpenPhenom/openphenom_sec61b_n_phase_3.zarr" # Model configuration @@ -18,6 +16,8 @@ model: # Data module configuration datamodule_class: viscy.data.triplet.TripletDataModule datamodule: + data_path: /hpc/projects/intracellular_dashboard/organelle_dynamics/2024_11_07_A549_SEC61_ZIKV_DENV/2-assemble/2024_11_07_A549_SEC61_DENV.zarr + tracks_path: /hpc/projects/intracellular_dashboard/organelle_dynamics/2024_11_07_A549_SEC61_ZIKV_DENV/1-preprocess/label-free/4-track-gt/2024_11_07_A549_SEC61_ZIKV_DENV_2_cropped.zarr source_channel: - Phase3D - "raw GFP EX488 EM525-45" diff --git a/applications/benchmarking/DynaCLR/OpenPhenom/openphenom_embeddings.py b/applications/benchmarking/DynaCLR/OpenPhenom/openphenom_embeddings.py index 4e0aa4343..8cabfa931 100644 --- a/applications/benchmarking/DynaCLR/OpenPhenom/openphenom_embeddings.py +++ b/applications/benchmarking/DynaCLR/OpenPhenom/openphenom_embeddings.py @@ -1,56 +1,26 @@ -""" -Generate embeddings using the OpenPhenom model and save them to a zarr store -using VisCy Trainer and EmbeddingWriter callback. 
-""" - -import importlib -import logging -import os -from pathlib import Path -from typing import Dict, List, Literal, Optional - -import click import torch -import yaml -from lightning.pytorch import LightningModule from transformers import AutoModel +from typing import Dict, List, Literal, Optional -from viscy.data.triplet import TripletDataModule -from viscy.representation.embedding_writer import EmbeddingWriter -from viscy.trainer import VisCyTrainer +import sys +from pathlib import Path +sys.path.append(str(Path(__file__).parent.parent)) + +from base_embedding_module import BaseEmbeddingModule, create_embedding_cli -class OpenPhenomModule(LightningModule): +class OpenPhenomModule(BaseEmbeddingModule): def __init__( self, channel_reduction_methods: Optional[ Dict[str, Literal["middle_slice", "mean", "max"]] ] = None, channel_names: Optional[List[str]] = None, + middle_slice_index: Optional[int] = None, ): - """Initialize the OpenPhenom module. - - Parameters - ---------- - channel_reduction_methods : dict, optional - Dictionary mapping channel names to reduction methods: - - "middle_slice": Take the middle slice along the depth dimension - - "mean": Average across the depth dimension - - "max": Take the maximum value across the depth dimension - channel_names : list of str, optional - List of channel names corresponding to the input channels - - Notes - ----- - The module uses the OpenPhenom model from HuggingFace for generating embeddings. - """ - super().__init__() - - self.channel_reduction_methods = channel_reduction_methods or {} - self.channel_names = channel_names or [] - + super().__init__(channel_reduction_methods, channel_names, middle_slice_index) + try: - torch.set_float32_matmul_precision("high") self.model = AutoModel.from_pretrained( "recursionpharma/OpenPhenom", trust_remote_code=True ) @@ -61,275 +31,39 @@ def __init__( "pip install transformers" ) + @classmethod + def from_config(cls, cfg): + """Create model instance from configuration.""" + model_config = cfg.get("model", {}) + dm_config = cfg.get("datamodule", {}) + + return cls( + channel_reduction_methods=model_config.get("channel_reduction_methods", {}), + channel_names=dm_config.get("source_channel", []), + ) + def on_predict_start(self): - # Move model to GPU when prediction starts + """Move model to GPU when prediction starts.""" self.model.to(self.device) - def _reduce_5d_input(self, x: torch.Tensor) -> torch.Tensor: - """Reduce 5D input (B, C, D, H, W) to 4D (B, C, H, W) using specified methods. - - Args: - x: 5D input tensor - - Returns: - 4D tensor after applying reduction methods - """ - if x.dim() != 5: - return x - - B, C, D, H, W = x.shape - result = torch.zeros((B, C, H, W), device=x.device) - - # Apply reduction method for each channel - for c in range(C): - channel_name = ( - self.channel_names[c] if c < len(self.channel_names) else f"channel_{c}" - ) - # Default to middle slice if not specified - method = self.channel_reduction_methods.get(channel_name, "middle_slice") - - if method == "middle_slice": - result[:, c] = x[:, c, D // 2] - elif method == "mean": - result[:, c] = x[:, c].mean(dim=1) - elif method == "max": - result[:, c] = x[:, c].max(dim=1)[0] - else: - # Fallback to middle slice for unknown methods - result[:, c] = x[:, c, D // 2] - - return result - - def predict_step(self, batch, batch_idx, dataloader_idx=0): - """Extract features from the input images. 
- - Returns: - Dictionary with features, projections (None), and index information - """ - x = batch["anchor"] - - # OpenPhenom expects [B, C, H, W] but our data might be [B, C, D, H, W] - # If 5D input, handle according to specified reduction methods - if x.dim() == 5: - x = self._reduce_5d_input(x) - - # Convert to uint8 as OpenPhenom expects uint8 inputs + def _process_input(self, x: torch.Tensor): + """Convert to uint8 as OpenPhenom expects uint8 inputs.""" if x.dtype != torch.uint8: x = ( ((x - x.min()) / (x.max() - x.min()) * 255) .clamp(0, 255) .to(torch.uint8) ) + return x + def _extract_features(self, processed_input): + """Extract features using OpenPhenom model.""" # Get embeddings self.model.return_channelwise_embeddings = False - features = self.model.predict(x) - # Create empty projections tensor with same batch size as features - # This ensures the EmbeddingWriter can process it - projections = torch.zeros((features.shape[0], 0), device=features.device) - - return { - "features": features, - "projections": projections, - "index": batch["index"], - } - - -def load_config(config_file): - """Load configuration from a YAML file.""" - with open(config_file, "r") as f: - config = yaml.safe_load(f) - return config - - -def load_normalization_from_config(norm_config): - """Load a normalization transform from a configuration dictionary.""" - class_path = norm_config["class_path"] - init_args = norm_config.get("init_args", {}) - - # Split module and class name - module_path, class_name = class_path.rsplit(".", 1) - - # Import the module - module = importlib.import_module(module_path) - - # Get the class - transform_class = getattr(module, class_name) - - # Instantiate the transform - return transform_class(**init_args) - - -@click.command() -@click.option( - "--config", - "-c", - type=click.Path(exists=True), - required=True, - help="Path to YAML configuration file", -) -def main(config): - """Extract OpenPhenom embeddings and save to zarr format using VisCy Trainer.""" - # Configure logging - logging.basicConfig(level=logging.INFO) - logger = logging.getLogger(__name__) - - # Load config file - cfg = load_config(config) - logger.info(f"Loaded configuration from {config}") - - # Prepare datamodule parameters - dm_params = {} - - # Add data and tracks paths from the paths section - if "paths" not in cfg: - raise ValueError("Configuration must contain a 'paths' section") - - if "data_path" not in cfg["paths"]: - raise ValueError( - "Data path is required in the configuration file (paths.data_path)" - ) - dm_params["data_path"] = cfg["paths"]["data_path"] - - if "tracks_path" not in cfg["paths"]: - raise ValueError( - "Tracks path is required in the configuration file (paths.tracks_path)" - ) - dm_params["tracks_path"] = cfg["paths"]["tracks_path"] - - # Add datamodule parameters - if "datamodule" not in cfg: - raise ValueError("Configuration must contain a 'datamodule' section") - - # Prepare normalizations - if ( - "normalizations" not in cfg["datamodule"] - or not cfg["datamodule"]["normalizations"] - ): - raise ValueError( - "Normalizations are required in the configuration file (datamodule.normalizations)" - ) - - norm_configs = cfg["datamodule"]["normalizations"] - normalizations = [load_normalization_from_config(norm) for norm in norm_configs] - dm_params["normalizations"] = normalizations - - # Copy all other datamodule parameters - for param, value in cfg["datamodule"].items(): - if param != "normalizations": - # Handle patch sizes - if param == "patch_size": - 
dm_params["initial_yx_patch_size"] = value - dm_params["final_yx_patch_size"] = value - else: - dm_params[param] = value - - # Set up the data module - logger.info("Setting up data module") - - class_path = cfg["datamodule_class"] - module_path, class_name = class_path.rsplit(".", 1) - module = importlib.import_module(module_path) - datamodule_class = getattr(module, class_name) - dm = datamodule_class(**dm_params) - - # Get model parameters for handling 5D inputs - channel_reduction_methods = {} - - if "model" in cfg and "channel_reduction_methods" in cfg["model"]: - channel_reduction_methods = cfg["model"]["channel_reduction_methods"] - - # Initialize OpenPhenom model with reduction settings - logger.info("Loading OpenPhenom model") - model = OpenPhenomModule( - channel_reduction_methods=channel_reduction_methods, - channel_names=dm_params.get("source_channel", []), - ) - - # Get dimensionality reduction parameters from config - phate_kwargs = None - pca_kwargs = None - - if "embedding" in cfg: - if "phate_kwargs" in cfg["embedding"]: - phate_kwargs = cfg["embedding"]["phate_kwargs"] - if "pca_kwargs" in cfg["embedding"]: - pca_kwargs = cfg["embedding"]["pca_kwargs"] - # Check if output path exists and should be overwritten - if "output_path" not in cfg["paths"]: - raise ValueError( - "Output path is required in the configuration file (paths.output_path)" - ) - - output_path = Path(cfg["paths"]["output_path"]) - output_dir = output_path.parent - output_dir.mkdir(parents=True, exist_ok=True) - - overwrite = False - if "execution" in cfg and "overwrite" in cfg["execution"]: - overwrite = cfg["execution"]["overwrite"] - elif output_path.exists(): - logger.warning(f"Output path {output_path} already exists, will overwrite") - overwrite = True - - # Set up EmbeddingWriter callback - embedding_writer = EmbeddingWriter( - output_path=output_path, - phate_kwargs=phate_kwargs, - pca_kwargs=pca_kwargs, - overwrite=overwrite, - ) - - # Set up and run VisCy trainer - logger.info("Setting up VisCy trainer") - trainer = VisCyTrainer( - accelerator="gpu" if torch.cuda.is_available() else "cpu", - devices=1, - callbacks=[embedding_writer], - inference_mode=True, - ) - - logger.info(f"Running prediction and saving to {output_path}") - trainer.predict(model, datamodule=dm) - - # Save configuration if requested - save_config_flag = True - show_config_flag = True - - if "execution" in cfg: - if "save_config" in cfg["execution"]: - save_config_flag = cfg["execution"]["save_config"] - if "show_config" in cfg["execution"]: - show_config_flag = cfg["execution"]["show_config"] - - # Save configuration if requested - if save_config_flag: - config_path = os.path.join(output_dir, "config.yml") - with open(config_path, "w") as f: - yaml.dump(cfg, f, default_flow_style=False) - logger.info(f"Configuration saved to {config_path}") - - # Display configuration if requested - if show_config_flag: - click.echo("\nConfiguration used:") - click.echo("-" * 40) - for key, value in cfg.items(): - click.echo(f"{key}:") - if isinstance(value, dict): - for subkey, subvalue in value.items(): - if isinstance(subvalue, list) and subkey == "normalizations": - click.echo(f" {subkey}:") - for norm in subvalue: - click.echo(f" - class_path: {norm['class_path']}") - click.echo(f" init_args: {norm['init_args']}") - else: - click.echo(f" {subkey}: {subvalue}") - else: - click.echo(f" {value}") - click.echo("-" * 40) - - logger.info("Done!") + features = self.model.predict(processed_input) + return features if __name__ == "__main__": - 
main() + main = create_embedding_cli(OpenPhenomModule, "OpenPhenom") + main() \ No newline at end of file diff --git a/applications/benchmarking/DynaCLR/SAM2/sam2_config.yml b/applications/benchmarking/DynaCLR/SAM2/sam2_config.yml index 89bbb5202..5e3771d25 100644 --- a/applications/benchmarking/DynaCLR/SAM2/sam2_config.yml +++ b/applications/benchmarking/DynaCLR/SAM2/sam2_config.yml @@ -1,5 +1,7 @@ datamodule_class: viscy.data.triplet.TripletDataModule datamodule: + data_path: /hpc/projects/intracellular_dashboard/organelle_dynamics/2024_11_07_A549_SEC61_ZIKV_DENV/2-assemble/2024_11_07_A549_SEC61_DENV.zarr + tracks_path: /hpc/projects/intracellular_dashboard/organelle_dynamics/2024_11_07_A549_SEC61_ZIKV_DENV/1-preprocess/label-free/4-track-gt/2024_11_07_A549_SEC61_ZIKV_DENV_2_cropped.zarr batch_size: 32 final_yx_patch_size: - 192 @@ -55,6 +57,4 @@ model: Phase3D: middle_slice raw GFP EX488 EM525-45: max paths: - data_path: /hpc/projects/intracellular_dashboard/organelle_dynamics/2024_11_07_A549_SEC61_ZIKV_DENV/2-assemble/2024_11_07_A549_SEC61_DENV.zarr output_path: /home/eduardo.hirata/repos/viscy/applications/benchmarking/DynaCLR/SAM2/sam2_sec61b_n_phase_all_highresfeats0.zarr - tracks_path: /hpc/projects/intracellular_dashboard/organelle_dynamics/2024_11_07_A549_SEC61_ZIKV_DENV/1-preprocess/label-free/4-track-gt/2024_11_07_A549_SEC61_ZIKV_DENV_2_cropped.zarr diff --git a/applications/benchmarking/DynaCLR/SAM2/sam2_embeddings.py b/applications/benchmarking/DynaCLR/SAM2/sam2_embeddings.py index 4a4d6d278..07f2dda70 100644 --- a/applications/benchmarking/DynaCLR/SAM2/sam2_embeddings.py +++ b/applications/benchmarking/DynaCLR/SAM2/sam2_embeddings.py @@ -1,22 +1,16 @@ -import importlib -import logging -import os -from pathlib import Path -from typing import Dict, List, Literal, Optional - -import click import torch -import yaml -from lightning.pytorch import LightningModule from sam2.sam2_image_predictor import SAM2ImagePredictor from skimage.exposure import rescale_intensity +from typing import Dict, List, Literal, Optional + +import sys +from pathlib import Path +sys.path.append(str(Path(__file__).parent.parent)) -from viscy.data.triplet import TripletDataModule -from viscy.representation.embedding_writer import EmbeddingWriter -from viscy.trainer import VisCyTrainer +from base_embedding_module import BaseEmbeddingModule, create_embedding_cli -class SAM2Module(LightningModule): +class SAM2Module(BaseEmbeddingModule): def __init__( self, model_name: str = "facebook/sam2-hiera-base-plus", @@ -26,100 +20,38 @@ def __init__( channel_names: Optional[List[str]] = None, middle_slice_index: Optional[int] = None, ): - """ - SAM2 module for feature extraction. - - Parameters - ---------- - model_name : str, optional - SAM2 model name from HuggingFace Model Hub (default: "facebook/sam2-hiera-base-plus"). - channel_reduction_methods : dict[str, {"middle_slice", "mean", "max"}], optional - Dictionary mapping channel names to reduction methods for 5D inputs (default: None, uses "middle_slice"). - channel_names : list of str, optional - List of channel names corresponding to input channels (default: None). - middle_slice_index : int, optional - Specific z-slice index to use for "middle_slice" reduction (default: None, uses D//2). 
- - """ - super().__init__() + super().__init__(channel_reduction_methods, channel_names, middle_slice_index) self.model_name = model_name - self.channel_reduction_methods = channel_reduction_methods or {} - self.channel_names = channel_names or [] - self.middle_slice_index = middle_slice_index - - torch.set_float32_matmul_precision("high") self.model = None # Initialize in on_predict_start when device is set - def on_predict_start(self): - """ - Initialize model with proper device when prediction starts. + @classmethod + def from_config(cls, cfg): + """Create model instance from configuration.""" + model_config = cfg.get("model", {}) - Notes - ----- - This method is called automatically by Lightning when prediction begins. - It ensures the SAM2 model is properly initialized on the correct device. - """ + return cls( + model_name=model_config.get("model_name", "facebook/sam2-hiera-base-plus"), + channel_reduction_methods=model_config.get("channel_reduction_methods", {}), + middle_slice_index=model_config.get("middle_slice_index", None), + ) + + def on_predict_start(self): + """Initialize model with proper device when prediction starts.""" if self.model is None: self.model = SAM2ImagePredictor.from_pretrained( self.model_name, device=self.device ) - def _reduce_5d_input(self, x: torch.Tensor) -> torch.Tensor: - """ - Reduce 5D input (B, C, D, H, W) to 4D (B, C, H, W) using specified methods. - - Parameters - ---------- - x : torch.Tensor - 5D input tensor with shape (B, C, D, H, W). - - Returns - ------- - torch.Tensor - 4D tensor after applying reduction methods with shape (B, C, H, W). - """ - if x.dim() != 5: - return x - - B, C, D, H, W = x.shape - result = torch.zeros((B, C, H, W), device=x.device) - - # Process all channels at once for each reduction method to minimize loops - middle_slice_indices = [] - mean_indices = [] - max_indices = [] - - # Group channels by reduction method - for c in range(C): - channel_name = ( - self.channel_names[c] if c < len(self.channel_names) else f"channel_{c}" - ) - method = self.channel_reduction_methods.get(channel_name, "middle_slice") - - if method == "mean": - mean_indices.append(c) - elif method == "max": - max_indices.append(c) - else: # Default to middle_slice for any unknown method - middle_slice_indices.append(c) - - # Apply middle_slice reduction to all relevant channels at once - if middle_slice_indices: - indices = torch.tensor(middle_slice_indices, device=x.device) - slice_idx = self.middle_slice_index if self.middle_slice_index is not None else D // 2 - result[:, indices] = x[:, indices, slice_idx] - - # Apply mean reduction to all relevant channels at once - if mean_indices: - indices = torch.tensor(mean_indices, device=x.device) - result[:, indices] = x[:, indices].mean(dim=2) - - # Apply max reduction to all relevant channels at once - if max_indices: - indices = torch.tensor(max_indices, device=x.device) - result[:, indices] = x[:, indices].max(dim=2)[0] + def _process_input(self, x: torch.Tensor): + """Convert input tensor to 3-channel RGB format as needed for SAM2.""" + return self._convert_to_rgb(x) - return result + def _extract_features(self, image_list): + """Extract features using SAM2 model.""" + self.model.set_image_batch(image_list) + # Extract high-resolution features and apply global average pooling + features = self.model._features["high_res_feats"][0].mean(dim=(2, 3)) + return features def _convert_to_rgb(self, x: torch.Tensor) -> list: """ @@ -148,6 +80,7 @@ def _convert_to_rgb(self, x: torch.Tensor) -> list: x_3ch[:, 
0] = x[:, 0] x_3ch[:, 1] = x[:, 1] x_3ch[:, 2] = 0.5 * (x[:, 0] + x[:, 1]) # B channel as blend + x_rgb = x_3ch elif x.shape[1] == 3: x_rgb = rescale_intensity(x, out_range="uint8") @@ -161,284 +94,7 @@ def _convert_to_rgb(self, x: torch.Tensor) -> list: x_rgb[i].cpu().numpy().transpose(1, 2, 0) for i in range(x_rgb.shape[0]) ] - def predict_step(self, batch, batch_idx, dataloader_idx=0): - """ - Extract features from the input images. - - Parameters - ---------- - batch : dict - Batch dictionary containing "anchor" key with input tensors. - batch_idx : int - Index of the current batch. - dataloader_idx : int, optional - Index of the dataloader (default: 0). - - Returns - ------- - dict - Dictionary containing: - - "features": Extracted features tensor - - "projections": Empty tensor for compatibility (B, 0) - - "index": Batch index information - """ - x = batch["anchor"] - - # Handle 5D input (B, C, D, H, W) using configured reduction methods - if x.dim() == 5: - x = self._reduce_5d_input(x) - - # Convert input to RGB format and get list of numpy arrays in HWC format for SAM2 - image_list = self._convert_to_rgb(x) - self.model.set_image_batch(image_list) - - # Extract features - # features_0 = self.model._features["image_embed"].mean(dim=(2, 3)) - # features_1 = self.model._features["high_res_feats"][0].mean(dim=(2, 3)) - # features_2 = self.model._features["high_res_feats"][1].mean(dim=(2, 3)) - # features = torch.concat([features_0, features_1, features_2], dim=1) - features = self.model._features["high_res_feats"][0].mean(dim=(2, 3)) - - # Return features and empty projections with correct batch dimension - return { - "features": features, - "projections": torch.zeros((features.shape[0], 0), device=features.device), - "index": batch["index"], - } - - -def load_config(config_file): - """ - Load configuration from a YAML file. - - Parameters - ---------- - config_file : str or Path - Path to the YAML configuration file. - - Returns - ------- - dict - Configuration dictionary loaded from the YAML file. - """ - with open(config_file, "r") as f: - config = yaml.safe_load(f) - return config - - -def load_normalization_from_config(norm_config): - """ - Load a normalization transform from a configuration dictionary. - - Parameters - ---------- - norm_config : dict - Configuration dictionary containing "class_path" and optional "init_args". - - Returns - ------- - object - Instantiated normalization transform object. - """ - class_path = norm_config["class_path"] - init_args = norm_config.get("init_args", {}) - - # Split module and class name - module_path, class_name = class_path.rsplit(".", 1) - - # Import the module - module = importlib.import_module(module_path) - - # Get the class - transform_class = getattr(module, class_name) - - # Instantiate the transform - return transform_class(**init_args) - - -@click.command() -@click.option( - "--config", - "-c", - type=click.Path(exists=True), - required=True, - help="Path to YAML configuration file", -) -def main(config): - """ - Extract SAM2 embeddings and save to zarr format using VisCy Trainer. - - Parameters - ---------- - config : str or Path - Path to the YAML configuration file containing all parameters for - data loading, model configuration, and output settings. 
- """ - # Configure logging - logging.basicConfig(level=logging.INFO) - logger = logging.getLogger(__name__) - - # Load config file - cfg = load_config(config) - logger.info(f"Loaded configuration from {config}") - - # Prepare datamodule parameters - dm_params = {} - - # Add data and tracks paths from the paths section - if "paths" not in cfg: - raise ValueError("Configuration must contain a 'paths' section") - - if "data_path" not in cfg["paths"]: - raise ValueError( - "Data path is required in the configuration file (paths.data_path)" - ) - dm_params["data_path"] = cfg["paths"]["data_path"] - - if "tracks_path" not in cfg["paths"]: - raise ValueError( - "Tracks path is required in the configuration file (paths.tracks_path)" - ) - dm_params["tracks_path"] = cfg["paths"]["tracks_path"] - - # Add datamodule parameters - if "datamodule" not in cfg: - raise ValueError("Configuration must contain a 'datamodule' section") - - # Prepare normalizations - if ( - "normalizations" not in cfg["datamodule"] - or not cfg["datamodule"]["normalizations"] - ): - raise ValueError( - "Normalizations are required in the configuration file (datamodule.normalizations)" - ) - - norm_configs = cfg["datamodule"]["normalizations"] - normalizations = [load_normalization_from_config(norm) for norm in norm_configs] - dm_params["normalizations"] = normalizations - - # Copy all other datamodule parameters - for param, value in cfg["datamodule"].items(): - if param != "normalizations": - # Handle patch sizes - if param == "patch_size": - dm_params["initial_yx_patch_size"] = value - dm_params["final_yx_patch_size"] = value - else: - dm_params[param] = value - - # Set up the data module - logger.info("Setting up data module") - - class_path = cfg["datamodule_class"] - module_path, class_name = class_path.rsplit(".", 1) - module = importlib.import_module(module_path) - datamodule_class = getattr(module, class_name) - dm = datamodule_class(**dm_params) - - # Get model parameters for handling 5D inputs - channel_reduction_methods = {} - middle_slice_index = None - - if "model" in cfg: - if "channel_reduction_methods" in cfg["model"]: - channel_reduction_methods = cfg["model"]["channel_reduction_methods"] - if "middle_slice_index" in cfg["model"]: - middle_slice_index = cfg["model"]["middle_slice_index"] - - # Initialize SAM2 model with reduction settings - logger.info("Loading SAM2 model") - model = SAM2Module( - model_name=cfg["model"]["model_name"], - channel_reduction_methods=channel_reduction_methods, - middle_slice_index=middle_slice_index, - ) - - # Get dimensionality reduction parameters from config - phate_kwargs = None - pca_kwargs = None - - if "embedding" in cfg: - if "phate_kwargs" in cfg["embedding"]: - phate_kwargs = cfg["embedding"]["phate_kwargs"] - if "pca_kwargs" in cfg["embedding"]: - pca_kwargs = cfg["embedding"]["pca_kwargs"] - # Check if output path exists and should be overwritten - if "output_path" not in cfg["paths"]: - raise ValueError( - "Output path is required in the configuration file (paths.output_path)" - ) - - output_path = Path(cfg["paths"]["output_path"]) - output_dir = output_path.parent - output_dir.mkdir(parents=True, exist_ok=True) - - overwrite = False - if "execution" in cfg and "overwrite" in cfg["execution"]: - overwrite = cfg["execution"]["overwrite"] - elif output_path.exists(): - logger.warning(f"Output path {output_path} already exists, will overwrite") - overwrite = True - - # Set up EmbeddingWriter callback - embedding_writer = EmbeddingWriter( - output_path=output_path, - 
phate_kwargs=phate_kwargs, - pca_kwargs=pca_kwargs, - overwrite=overwrite, - ) - - # Set up and run VisCy trainer - logger.info("Setting up VisCy trainer") - trainer = VisCyTrainer( - accelerator="gpu" if torch.cuda.is_available() else "cpu", - devices=1, - callbacks=[embedding_writer], - inference_mode=True, - ) - - logger.info(f"Running prediction and saving to {output_path}") - trainer.predict(model, datamodule=dm) - - # Save configuration if requested - save_config_flag = True - show_config_flag = True - - if "execution" in cfg: - if "save_config" in cfg["execution"]: - save_config_flag = cfg["execution"]["save_config"] - if "show_config" in cfg["execution"]: - show_config_flag = cfg["execution"]["show_config"] - - # Save configuration if requested - if save_config_flag: - config_path = os.path.join(output_dir, "config.yml") - with open(config_path, "w") as f: - yaml.dump(cfg, f, default_flow_style=False) - logger.info(f"Configuration saved to {config_path}") - - # Display configuration if requested - if show_config_flag: - click.echo("\nConfiguration used:") - click.echo("-" * 40) - for key, value in cfg.items(): - click.echo(f"{key}:") - if isinstance(value, dict): - for subkey, subvalue in value.items(): - if isinstance(subvalue, list) and subkey == "normalizations": - click.echo(f" {subkey}:") - for norm in subvalue: - click.echo(f" - class_path: {norm['class_path']}") - click.echo(f" init_args: {norm['init_args']}") - else: - click.echo(f" {subkey}: {subvalue}") - else: - click.echo(f" {value}") - click.echo("-" * 40) - - logger.info("Done!") - if __name__ == "__main__": - main() + main = create_embedding_cli(SAM2Module, "SAM2") + main() \ No newline at end of file From 7807d9ab556e8fd98641efa9497846aae332c7f5 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Wed, 22 Oct 2025 16:49:09 -0700 Subject: [PATCH 085/101] cleanup vae --- viscy/representation/vae.py | 26 ++++++++++++++++---------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/viscy/representation/vae.py b/viscy/representation/vae.py index e91f1e479..4161aed51 100644 --- a/viscy/representation/vae.py +++ b/viscy/representation/vae.py @@ -129,15 +129,14 @@ def __init__( features_only=True, drop_path_rate=drop_path_rate, ) + num_channels = encoder.feature_info.channels() + in_channels_encoder = num_channels[0] + out_channels_encoder = num_channels[-1] if "convnext" in backbone: num_channels = encoder.feature_info.channels() - in_channels_encoder = num_channels[0] encoder.stem_0 = nn.Identity() - out_channels_encoder = num_channels[-1] elif "resnet" in backbone: - num_channels = encoder.feature_info.channels() - in_channels_encoder = num_channels[0] encoder.conv1 = nn.Identity() out_channels_encoder = num_channels[-1] else: @@ -230,7 +229,6 @@ def __init__( self.latent_dim = latent_dim self.out_channels = out_channels self.out_stack_depth = out_stack_depth - self.decoder_channels = decoder_channels self.spatial_size = encoder_spatial_size self.spatial_channels = latent_dim // (self.spatial_size * self.spatial_size) @@ -325,13 +323,15 @@ def __init__( drop_path_rate=drop_path_rate, ) - decoder_channels = self.encoder.num_channels.copy() - decoder_channels.reverse() - decoder_channels[-1] = ( + base_channels = self.encoder.num_channels[-1] + decoder_channels = [base_channels] + for i in range(decoder_stages - 1): + decoder_channels.append(base_channels // (2 ** (i + 1))) + decoder_channels.append( (out_stack_depth + 2) * in_channels * 2**2 * head_expansion_ratio ) - strides = [2] * 
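# Worked example for the decoder_channels construction above (numbers are
# illustrative; base_channels depends on the chosen backbone): with
# base_channels = 768 and decoder_stages = 4, the halving loop yields
# [768, 384, 192, 96], and the appended head width is
# (out_stack_depth + 2) * in_channels * 2**2 * head_expansion_ratio,
# e.g. (5 + 2) * 1 * 4 * 2 = 56, giving
# decoder_channels = [768, 384, 192, 96, 56] and strides = [2, 2, 2, 2, 1].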
(len(decoder_channels) - 1) + [1] + strides = [2] * decoder_stages + [1] self.decoder = VaeDecoder( decoder_channels=decoder_channels, @@ -376,7 +376,7 @@ def __init__( up_kernel_size: Sequence[int] | int = 3, num_res_units: int = 0, use_sigmoid: bool = False, - norm: Literal[Norm.BATCH, Norm.INSTANCE] = Norm.INSTANCE, + norm: Literal["batch", "instance"] = "instance", **kwargs, ): super().__init__() @@ -392,6 +392,12 @@ def __init__( self.num_res_units = num_res_units self.use_sigmoid = use_sigmoid self.norm = norm + if self.norm not in ["batch", "instance"]: + raise ValueError("norm must be 'batch' or 'instance'") + if self.norm == "batch": + self.norm = Norm.BATCH + else: + self.norm = Norm.INSTANCE self.model = VarAutoEncoder( spatial_dims=self.spatial_dims, From f06e0a3becdc730914164c954dd0db9cdc3c709b Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Wed, 22 Oct 2025 16:52:58 -0700 Subject: [PATCH 086/101] keeping it consistent and using residual units --- viscy/representation/vae.py | 22 +++++++--------------- 1 file changed, 7 insertions(+), 15 deletions(-) diff --git a/viscy/representation/vae.py b/viscy/representation/vae.py index 4161aed51..0712c55ee 100644 --- a/viscy/representation/vae.py +++ b/viscy/representation/vae.py @@ -68,21 +68,13 @@ def __init__( for i in range(conv_blocks): block_out_channels = out_channels - conv_layers.extend( - [ - nn.Conv2d( - current_channels, - block_out_channels, - kernel_size=3, - padding=1, - ), - ( - nn.BatchNorm2d(block_out_channels) - if norm_name == "batch" - else nn.InstanceNorm2d(block_out_channels) - ), - nn.ReLU(inplace=True), - ] + conv_layers.append( + ResidualUnit( + spatial_dims=spatial_dims, + in_channels=current_channels, + out_channels=block_out_channels, + norm=norm_name, + ) ) current_channels = block_out_channels From 35c9f756448867b576c4c36fc5b493d74f834c37 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Wed, 22 Oct 2025 17:15:14 -0700 Subject: [PATCH 087/101] fix typings betavaemonai --- viscy/representation/vae.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/viscy/representation/vae.py b/viscy/representation/vae.py index 0712c55ee..a7cd1a51e 100644 --- a/viscy/representation/vae.py +++ b/viscy/representation/vae.py @@ -363,7 +363,7 @@ def __init__( out_channels: int, latent_size: int, channels: Sequence[int], - strides: Sequence[int], + strides: Sequence[int] | Sequence[Sequence[int]], kernel_size: Sequence[int] | int = 3, up_kernel_size: Sequence[int] | int = 3, num_res_units: int = 0, From eed55f3e30cad55943d3138e7aef087081d26d79 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Mon, 27 Oct 2025 17:36:18 -0700 Subject: [PATCH 088/101] update smoothness to handle adata --- viscy/representation/evaluation/smoothness.py | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/viscy/representation/evaluation/smoothness.py b/viscy/representation/evaluation/smoothness.py index 869cd4baa..bef94fddc 100644 --- a/viscy/representation/evaluation/smoothness.py +++ b/viscy/representation/evaluation/smoothness.py @@ -1,6 +1,7 @@ from pathlib import Path from typing import Literal +import anndata as ad import numpy as np import pandas as pd from numpy.typing import NDArray @@ -8,7 +9,6 @@ from scipy.stats import gaussian_kde from sklearn.preprocessing import StandardScaler -from viscy.representation.embedding_writer import read_embedding_dataset from viscy.representation.evaluation.clustering import ( compare_time_offset, pairwise_distance_matrix, @@ 
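PATCH 088 (its diff continues below) switches compute_embeddings_smoothness from reading a zarr store to taking an in-memory AnnData. A minimal construction sketch; the shapes and obs columns are illustrative, and the unpacked names are descriptive guesses for the documented (dict, dict, list[list[float]]) return:

import anndata as ad
import numpy as np
import pandas as pd

from viscy.representation.evaluation.smoothness import compute_embeddings_smoothness

n_tracks, n_timepoints, n_features = 10, 20, 128
features_ad = ad.AnnData(
    X=np.random.randn(n_tracks * n_timepoints, n_features),
    obs=pd.DataFrame(
        {
            "fov_name": "000000",
            "track_id": np.repeat(np.arange(n_tracks), n_timepoints),
            "t": np.tile(np.arange(n_timepoints), n_tracks),
        }
    ),
)
smoothness_stats, dynamic_range_stats, piecewise_per_track = (
    compute_embeddings_smoothness(features_ad, distance_metric="cosine")
)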
-108,7 +108,7 @@ def find_distribution_peak def compute_embeddings_smoothness( - prediction_path: Path, + features_ad: ad.AnnData, distance_metric: Literal["cosine", "euclidean"] = "cosine", verbose: bool = False, ) -> tuple[dict, dict, list[list[float]]]: @@ -116,8 +116,8 @@ Compute the smoothness statistics of embeddings Parameters - ---------- - prediction_path: Path to the embedding dataset + ---------- + features_ad: ad.AnnData with embeddings in .X and per-sample track metadata in .obs distance_metric: Distance metric to use, by default "cosine" Returns: @@ -139,18 +139,15 @@ piecewise_distance_per_track: list[list[float]] Piece-wise distance per track """ - - # Read the dataset - embeddings = read_embedding_dataset(prediction_path) - features = embeddings["features"] - scaled_features = StandardScaler().fit_transform(features.values) + features = features_ad.X + scaled_features = StandardScaler().fit_transform(features) # Compute the distance matrix cross_dist = pairwise_distance_matrix(scaled_features, metric=distance_metric) rank_fractions = rank_nearest_neighbors(cross_dist, normalize=True) # Compute piece-wise distance and rank difference - features_df = features["sample"].to_dataframe().reset_index(drop=True) + features_df = features_ad.obs.reset_index(drop=True) piecewise_distance_per_track, _ = compute_piece_wise_distance( features_df, cross_dist, rank_fractions ) From 4674e91e9c37082239eacecb52492321102fafbd Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Mon, 27 Oct 2025 21:27:32 -0700 Subject: [PATCH 089/101] update clustering method and add test --- .../evaluation/test_clustering.py | 181 ++++++++++++++++++ viscy/representation/evaluation/clustering.py | 54 +++++- 2 files changed, 233 insertions(+), 2 deletions(-) create mode 100644 tests/representation/evaluation/test_clustering.py diff --git a/tests/representation/evaluation/test_clustering.py b/tests/representation/evaluation/test_clustering.py new file mode 100644 index 000000000..9b91a065b --- /dev/null +++ b/tests/representation/evaluation/test_clustering.py @@ -0,0 +1,181 @@ +import numpy as np +import pytest +from numpy.typing import NDArray + +from viscy.representation.evaluation.clustering import pairwise_distance_matrix + + +@pytest.fixture +def sample_features(): + """Create sample features for testing.""" + np.random.seed(42) + return np.random.randn(50, 128).astype(np.float64) + + +@pytest.fixture +def small_features(): + """Create small sample with known values for numerical testing.""" + return np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, 0.5]]) + + +class TestPairwiseDistanceMatrix: + """Tests for pairwise_distance_matrix function.""" + + @pytest.mark.parametrize("metric", ["cosine", "euclidean"]) + def test_scipy_baseline(self, sample_features: NDArray, metric: str): + """Test that scipy backend produces valid distance matrices.""" + dist_matrix = pairwise_distance_matrix( + sample_features, metric=metric, device="scipy" + ) + + # Check shape + n = len(sample_features) + assert dist_matrix.shape == (n, n) + + # Check symmetry + assert np.allclose(dist_matrix, dist_matrix.T) + + # Check diagonal is zero (or near zero for numerical precision) + assert np.allclose(np.diag(dist_matrix), 0, atol=1e-10) + + # Check all distances are non-negative + assert np.all(dist_matrix >= 0) + + @pytest.mark.parametrize("metric", ["cosine", "euclidean"]) + @pytest.mark.parametrize("device", ["cpu", "auto"]) + def test_torch_vs_scipy(self, sample_features: NDArray, metric: str, device: str):
"""Test that PyTorch implementation matches scipy results.""" + pytest.importorskip("torch") + + dist_scipy = pairwise_distance_matrix( + sample_features, metric=metric, device="scipy" + ) + dist_torch = pairwise_distance_matrix( + sample_features, metric=metric, device=device + ) + + # Check numerical agreement + assert np.allclose(dist_scipy, dist_torch, rtol=1e-5, atol=1e-6) + + @pytest.mark.skipif( + not pytest.importorskip("torch").cuda.is_available(), + reason="CUDA not available", + ) + @pytest.mark.parametrize("metric", ["cosine", "euclidean"]) + def test_gpu_vs_scipy(self, sample_features: NDArray, metric: str): + """Test that GPU implementation matches scipy results.""" + dist_scipy = pairwise_distance_matrix( + sample_features, metric=metric, device="scipy" + ) + dist_gpu = pairwise_distance_matrix( + sample_features, metric=metric, device="cuda" + ) + + # Check numerical agreement + assert np.allclose(dist_scipy, dist_gpu, rtol=1e-5, atol=1e-6) + + def test_cosine_distance_known_values(self, small_features: NDArray): + """Test cosine distance with known values.""" + dist_matrix = pairwise_distance_matrix( + small_features, metric="cosine", device="scipy" + ) + + # [1,0] and [0,1] are orthogonal: cosine distance = 1 + assert np.isclose(dist_matrix[0, 1], 1.0, atol=1e-10) + + # [1,1] and [0.5, 0.5] are parallel: cosine distance = 0 + assert np.isclose(dist_matrix[2, 3], 0.0, atol=1e-10) + + # [1,0] and [1,1]: cosine similarity = 1/sqrt(2), distance = 1 - 1/sqrt(2) + expected = 1 - 1 / np.sqrt(2) + assert np.isclose(dist_matrix[0, 2], expected, atol=1e-10) + + def test_euclidean_distance_known_values(self, small_features: NDArray): + """Test euclidean distance with known values.""" + dist_matrix = pairwise_distance_matrix( + small_features, metric="euclidean", device="scipy" + ) + + # Distance between [1,0] and [0,1] is sqrt(2) + assert np.isclose(dist_matrix[0, 1], np.sqrt(2), atol=1e-10) + + # Distance between [1,1] and [0.5, 0.5] is sqrt(0.5) + assert np.isclose(dist_matrix[2, 3], np.sqrt(0.5), atol=1e-10) + + def test_unsupported_metric_falls_back_to_scipy(self, sample_features: NDArray): + """Test that unsupported metrics fall back to scipy.""" + # These metrics are only supported by scipy, not PyTorch + dist_matrix = pairwise_distance_matrix( + sample_features, metric="cityblock", device="auto" + ) + + # Should still produce valid results via scipy fallback + n = len(sample_features) + assert dist_matrix.shape == (n, n) + assert np.allclose(dist_matrix, dist_matrix.T) + + def test_device_options(self, sample_features: NDArray): + """Test various device options.""" + # Test scipy explicitly + dist_scipy = pairwise_distance_matrix( + sample_features, metric="cosine", device="scipy" + ) + assert dist_scipy is not None + + # Test None as scipy + dist_none = pairwise_distance_matrix( + sample_features, metric="cosine", device=None + ) + assert np.allclose(dist_scipy, dist_none) + + @pytest.mark.skipif( + not pytest.importorskip("torch").cuda.is_available(), + reason="CUDA not available", + ) + def test_cuda_aliases(self, sample_features: NDArray): + """Test that cuda and gpu device names work.""" + dist_cuda = pairwise_distance_matrix( + sample_features, metric="cosine", device="cuda" + ) + dist_gpu = pairwise_distance_matrix( + sample_features, metric="cosine", device="gpu" + ) + + assert np.allclose(dist_cuda, dist_gpu) + + def test_invalid_device_raises_error(self, sample_features: NDArray): + """Test that invalid device names raise appropriate errors.""" + 
pytest.importorskip("torch") + + with pytest.raises(ValueError, match="Invalid device"): + pairwise_distance_matrix( + sample_features, metric="cosine", device="invalid_device" + ) + + def test_float32_input_preserves_precision(self): + """Test that float32 input is converted to float64 for precision.""" + torch = pytest.importorskip("torch") + + features_f32 = np.random.randn(10, 32).astype(np.float32) + + dist_scipy = pairwise_distance_matrix( + features_f32, metric="cosine", device="scipy" + ) + dist_torch = pairwise_distance_matrix( + features_f32, metric="cosine", device="cpu" + ) + + # Should still have good agreement despite float32 input + assert np.allclose(dist_scipy, dist_torch, rtol=1e-5, atol=1e-6) + + def test_large_matrix_shape(self): + """Test with larger matrix to ensure it works at scale.""" + large_features = np.random.randn(500, 64).astype(np.float64) + + dist_matrix = pairwise_distance_matrix( + large_features, metric="cosine", device="auto" + ) + + assert dist_matrix.shape == (500, 500) + assert np.allclose(dist_matrix, dist_matrix.T) + assert np.allclose(np.diag(dist_matrix), 0, atol=1e-6) diff --git a/viscy/representation/evaluation/clustering.py b/viscy/representation/evaluation/clustering.py index ebf49455f..8f58ef0b5 100644 --- a/viscy/representation/evaluation/clustering.py +++ b/viscy/representation/evaluation/clustering.py @@ -33,22 +33,72 @@ def knn_accuracy(embeddings, annotations, k=5): return accuracy -def pairwise_distance_matrix(features: ArrayLike, metric: str = "cosine") -> NDArray: +def pairwise_distance_matrix( + features: ArrayLike, metric: str = "cosine", device: str = "auto" +) -> NDArray: """Compute pairwise distances between all samples in the feature matrix. + Uses PyTorch with GPU acceleration when available for significant speedup. + Falls back to scipy for unsupported metrics or when PyTorch is unavailable. + Parameters ---------- features : ArrayLike Feature matrix (n_samples, n_features) metric : str, optional Distance metric to use, by default "cosine" + Supports "cosine" and "euclidean" with PyTorch acceleration. + Other scipy metrics will use scipy fallback. + device : str, optional + Device to use for computation, by default "auto" + - "auto": automatically use GPU if available, otherwise CPU + - "cuda" or "gpu": force GPU usage + - "cpu": force CPU usage + - None or "scipy": force scipy fallback Returns ------- NDArray Distance matrix of shape (n_samples, n_samples) """ - return cdist(features, features, metric=metric) + if device in (None, "scipy") or metric not in ("cosine", "euclidean"): + return cdist(features, features, metric=metric) + + try: + import torch + + if device == "auto": + device_torch = torch.device("cuda" if torch.cuda.is_available() else "cpu") + elif device in ("cuda", "gpu"): + if not torch.cuda.is_available(): + raise RuntimeError("CUDA requested but not available") + device_torch = torch.device("cuda") + elif device == "cpu": + device_torch = torch.device("cpu") + else: + raise ValueError( + f"Invalid device: {device}. 
Use 'auto', 'cuda', 'cpu', or 'scipy'" + ) + features_array = np.asarray(features) + if features_array.dtype == np.float32: + features_tensor = torch.from_numpy(features_array).double().to(device_torch) + else: + features_tensor = torch.from_numpy(features_array).to(device_torch) + if features_tensor.dtype not in (torch.float32, torch.float64): + features_tensor = features_tensor.double() + + if metric == "cosine": + features_norm = torch.nn.functional.normalize(features_tensor, p=2, dim=1) + similarity = features_norm @ features_norm.T + distances = 1 - similarity + elif metric == "euclidean": + distances = torch.cdist(features_tensor, features_tensor, p=2) + return distances.cpu().numpy() + + except ImportError: + return cdist(features, features, metric=metric) + except (RuntimeError, torch.cuda.OutOfMemoryError): + return cdist(features, features, metric=metric) def rank_nearest_neighbors( From b4a639892018a1b91ff9f33101b65383cb382bd3 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Tue, 28 Oct 2025 14:51:06 -0700 Subject: [PATCH 090/101] pre-commit --- .../DynaCLR/DINOV3/dinov3_embeddings.py | 75 +++++++++++-------- .../benchmarking/DynaCLR/SAM2/run_sam2.sh | 2 +- .../DynaCLR/SAM2/sam2_embeddings.py | 11 +-- .../DynaCLR/SAM2/sam2_visualizations.py | 31 +++++--- .../rpe1_fucci/linear_classifier.py | 12 +-- .../evaluation/rpe1_fucci/phate_plot.py | 9 +-- .../smoothness/compute_smoothness.py | 46 ++++++------ .../figures/grad_attr.py | 30 +++----- .../evaluation/compare_dtw_embeddings_sam2.py | 4 - .../evaluation/test_clustering.py | 2 +- viscy/representation/evaluation/smoothness.py | 1 - 11 files changed, 120 insertions(+), 103 deletions(-) diff --git a/applications/benchmarking/DynaCLR/DINOV3/dinov3_embeddings.py b/applications/benchmarking/DynaCLR/DINOV3/dinov3_embeddings.py index 49cdb4064..38164e523 100644 --- a/applications/benchmarking/DynaCLR/DINOV3/dinov3_embeddings.py +++ b/applications/benchmarking/DynaCLR/DINOV3/dinov3_embeddings.py @@ -27,7 +27,7 @@ def __init__( super().__init__(channel_reduction_methods, channel_names, middle_slice_index) self.model_name = model_name self.pooling_method = pooling_method - + self.model = None self.processor = None @@ -36,7 +36,9 @@ def from_config(cls, cfg): """Create model instance from configuration.""" model_config = cfg.get("model", {}) return cls( - model_name=model_config.get("model_name", "facebook/dinov3-vitb16-pretrain-lvd1689m"), + model_name=model_config.get( + "model_name", "facebook/dinov3-vitb16-pretrain-lvd1689m" + ), pooling_method=model_config.get("pooling_method", "mean"), channel_reduction_methods=model_config.get("channel_reduction_methods", {}), channel_names=model_config.get("channel_names", []), @@ -58,12 +60,12 @@ def _extract_features(self, pil_images): """Extract features using DINOv3 model.""" inputs = self.processor(pil_images, return_tensors="pt") inputs = {k: v.to(self.device) for k, v in inputs.items()} - + with torch.no_grad(): outputs = self.model(**inputs) token_features = outputs.last_hidden_state features = self._pool_features(token_features) - + return features def _convert_to_pil_images(self, x: torch.Tensor) -> List[Image.Image]: @@ -81,56 +83,69 @@ def _convert_to_pil_images(self, x: torch.Tensor) -> List[Image.Image]: List of PIL Images ready for DINOv3 processing. 
""" images = [] - + for b in range(x.shape[0]): img_tensor = x[b] # (C, H, W) - + if img_tensor.shape[0] == 1: # Single channel - convert to grayscale PIL img_array = img_tensor[0].cpu().numpy() # Normalize to 0-255 - img_normalized = ((img_array - img_array.min()) / - (img_array.max() - img_array.min()) * 255).astype(np.uint8) - pil_img = Image.fromarray(img_normalized, mode='L') - + img_normalized = ( + (img_array - img_array.min()) + / (img_array.max() - img_array.min()) + * 255 + ).astype(np.uint8) + pil_img = Image.fromarray(img_normalized, mode="L") + elif img_tensor.shape[0] == 2: img_array = img_tensor.cpu().numpy() - rgb_array = np.zeros((img_array.shape[1], img_array.shape[2], 3), dtype=np.uint8) - - ch0_norm = rescale_intensity(img_array[0], out_range=(0, 255)).astype(np.uint8) - ch1_norm = rescale_intensity(img_array[1], out_range=(0, 255)).astype(np.uint8) - + rgb_array = np.zeros( + (img_array.shape[1], img_array.shape[2], 3), dtype=np.uint8 + ) + + ch0_norm = rescale_intensity(img_array[0], out_range=(0, 255)).astype( + np.uint8 + ) + ch1_norm = rescale_intensity(img_array[1], out_range=(0, 255)).astype( + np.uint8 + ) + rgb_array[:, :, 0] = ch0_norm # Red - rgb_array[:, :, 1] = ch1_norm # Green + rgb_array[:, :, 1] = ch1_norm # Green rgb_array[:, :, 2] = (ch0_norm + ch1_norm) // 2 # Blue - - pil_img = Image.fromarray(rgb_array, mode='RGB') - + + pil_img = Image.fromarray(rgb_array, mode="RGB") + elif img_tensor.shape[0] == 3: # Three channels - direct RGB img_array = img_tensor.cpu().numpy().transpose(1, 2, 0) # HWC - img_normalized = rescale_intensity(img_array, out_range=(0, 255)).astype(np.uint8) - pil_img = Image.fromarray(img_normalized, mode='RGB') - + img_normalized = rescale_intensity( + img_array, out_range=(0, 255) + ).astype(np.uint8) + pil_img = Image.fromarray(img_normalized, mode="RGB") + else: # More than 3 channels - use first 3 img_array = img_tensor[:3].cpu().numpy().transpose(1, 2, 0) # HWC - img_normalized = rescale_intensity(img_array, out_range=(0, 255)).astype(np.uint8) - pil_img = Image.fromarray(img_normalized, mode='RGB') - + img_normalized = rescale_intensity( + img_array, out_range=(0, 255) + ).astype(np.uint8) + pil_img = Image.fromarray(img_normalized, mode="RGB") + images.append(pil_img) - + return images def _pool_features(self, features: torch.Tensor) -> torch.Tensor: """ Pool spatial features from DINOv3 tokens. - + Parameters ---------- features : torch.Tensor Token features with shape (B, num_tokens, hidden_dim). 
- + Returns ------- torch.Tensor @@ -143,7 +158,7 @@ def _pool_features(self, features: torch.Tensor) -> torch.Tensor: else: # For ConvNeXt, no CLS token, fall back to mean return features.mean(dim=1) - + elif self.pooling_method == "max": return features.max(dim=1)[0] else: # mean pooling @@ -152,4 +167,4 @@ def _pool_features(self, features: torch.Tensor) -> torch.Tensor: if __name__ == "__main__": main = create_embedding_cli(DINOv3Module, "DINOv3") - main() \ No newline at end of file + main() diff --git a/applications/benchmarking/DynaCLR/SAM2/run_sam2.sh b/applications/benchmarking/DynaCLR/SAM2/run_sam2.sh index 405499ec8..9bf781bc0 100644 --- a/applications/benchmarking/DynaCLR/SAM2/run_sam2.sh +++ b/applications/benchmarking/DynaCLR/SAM2/run_sam2.sh @@ -15,4 +15,4 @@ module load anaconda/latest conda activate viscy CONFIG_PATH=/home/eduardo.hirata/repos/viscy/applications/benchmarking/DynaCLR/SAM2/sam2_sensor_only.yml -python /home/eduardo.hirata/repos/viscy/applications/benchmarking/DynaCLR/SAM2/sam2_embeddings.py -c $CONFIG_PATH \ No newline at end of file +python /home/eduardo.hirata/repos/viscy/applications/benchmarking/DynaCLR/SAM2/sam2_embeddings.py -c $CONFIG_PATH diff --git a/applications/benchmarking/DynaCLR/SAM2/sam2_embeddings.py b/applications/benchmarking/DynaCLR/SAM2/sam2_embeddings.py index 07f2dda70..c664ca5cd 100644 --- a/applications/benchmarking/DynaCLR/SAM2/sam2_embeddings.py +++ b/applications/benchmarking/DynaCLR/SAM2/sam2_embeddings.py @@ -1,10 +1,11 @@ +import sys +from pathlib import Path +from typing import Dict, List, Literal, Optional + import torch from sam2.sam2_image_predictor import SAM2ImagePredictor from skimage.exposure import rescale_intensity -from typing import Dict, List, Literal, Optional -import sys -from pathlib import Path sys.path.append(str(Path(__file__).parent.parent)) from base_embedding_module import BaseEmbeddingModule, create_embedding_cli @@ -28,7 +29,7 @@ def __init__( def from_config(cls, cfg): """Create model instance from configuration.""" model_config = cfg.get("model", {}) - + return cls( model_name=model_config.get("model_name", "facebook/sam2-hiera-base-plus"), channel_reduction_methods=model_config.get("channel_reduction_methods", {}), @@ -97,4 +98,4 @@ def _convert_to_rgb(self, x: torch.Tensor) -> list: if __name__ == "__main__": main = create_embedding_cli(SAM2Module, "SAM2") - main() \ No newline at end of file + main() diff --git a/applications/benchmarking/DynaCLR/SAM2/sam2_visualizations.py b/applications/benchmarking/DynaCLR/SAM2/sam2_visualizations.py index 1ab3b1082..ef389e044 100644 --- a/applications/benchmarking/DynaCLR/SAM2/sam2_visualizations.py +++ b/applications/benchmarking/DynaCLR/SAM2/sam2_visualizations.py @@ -4,15 +4,12 @@ This script helps debug what images are being passed to SAM2 and how they're processed. 
""" -import matplotlib.pyplot as plt -import numpy as np -import torch -import yaml -from pathlib import Path -import sys import os +from pathlib import Path +import matplotlib.pyplot as plt from sam2_embeddings import SAM2Module, load_config, load_normalization_from_config + from viscy.data.triplet import TripletDataModule @@ -52,16 +49,28 @@ def visualize_rgb_conversion(x_original, x_rgb_list, save_dir="./debug_images"): ax = axes[2, 0] # Normalize to 0-1 for display rgb_display = rgb_img.copy() - rgb_display = (rgb_display - rgb_display.min()) / (rgb_display.max() - rgb_display.min()) + rgb_display = (rgb_display - rgb_display.min()) / ( + rgb_display.max() - rgb_display.min() + ) im = ax.imshow(rgb_display) ax.set_title("Merged RGB Image") ax.axis("off") # Check if RGB is properly scaled to 0-255 ax = axes[2, 1] - ax.text(0.1, 0.8, f"RGB Range: [{rgb_img.min():.1f}, {rgb_img.max():.1f}]", transform=ax.transAxes) - ax.text(0.1, 0.6, f"Expected: [0, 255]", transform=ax.transAxes) - ax.text(0.1, 0.4, f"Properly scaled: {rgb_img.min() >= 0 and rgb_img.max() <= 255}", transform=ax.transAxes) + ax.text( + 0.1, + 0.8, + f"RGB Range: [{rgb_img.min():.1f}, {rgb_img.max():.1f}]", + transform=ax.transAxes, + ) + ax.text(0.1, 0.6, "Expected: [0, 255]", transform=ax.transAxes) + ax.text( + 0.1, + 0.4, + f"Properly scaled: {rgb_img.min() >= 0 and rgb_img.max() <= 255}", + transform=ax.transAxes, + ) ax.text(0.1, 0.2, f"Mean: {rgb_img.mean():.1f}", transform=ax.transAxes) ax.set_title("RGB Scaling Check") ax.axis("off") @@ -123,7 +132,7 @@ def test_sam2_processing(config_path, num_samples=3): if i >= num_samples: break - print(f"\n--- Sample {i+1} ---") + print(f"\n--- Sample {i + 1} ---") x = batch["anchor"] print(f"Input tensor shape: {x.shape}") print(f"Input tensor range: [{x.min():.3f}, {x.max():.3f}]") diff --git a/applications/contrastive_phenotyping/evaluation/rpe1_fucci/linear_classifier.py b/applications/contrastive_phenotyping/evaluation/rpe1_fucci/linear_classifier.py index 077e57c87..89d8bb6fc 100644 --- a/applications/contrastive_phenotyping/evaluation/rpe1_fucci/linear_classifier.py +++ b/applications/contrastive_phenotyping/evaluation/rpe1_fucci/linear_classifier.py @@ -73,23 +73,22 @@ # %% # Enhanced evaluation and visualization import matplotlib.pyplot as plt -import seaborn as sns from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix # 1. Confusion Matrix - shows which classes are confused with each other cm = confusion_matrix(y_test, y_test_pred) plt.figure(figsize=(8, 6)) -ConfusionMatrixDisplay(cm, display_labels=["G1", "G2", "S","M"]).plot(cmap="Blues") +ConfusionMatrixDisplay(cm, display_labels=["G1", "G2", "S", "M"]).plot(cmap="Blues") plt.title("Confusion Matrix") plt.show() # 2. 
Per-class errors breakdown print("\nDetailed per-class analysis:") -for class_name in ["G1", "G2", "S","M"]: +for class_name in ["G1", "G2", "S", "M"]: mask = y_test == class_name correct = (y_test_pred[mask] == class_name).sum() total = mask.sum() - print(f"{class_name}: {correct}/{total} correct ({correct/total:.3f})") + print(f"{class_name}: {correct}/{total} correct ({correct / total:.3f})") # Show what this class was misclassified as if total > correct: @@ -105,7 +104,10 @@ for i, class_name in enumerate(class_names): plt.subplot(1, 4, i + 1) plt.hist( - y_test_proba[:, i], bins=20, alpha=0.7, color=["blue", "orange", "green",'red'][i] + y_test_proba[:, i], + bins=20, + alpha=0.7, + color=["blue", "orange", "green", "red"][i], ) plt.title(f"Confidence for {class_name}") plt.xlabel("Probability") diff --git a/applications/contrastive_phenotyping/evaluation/rpe1_fucci/phate_plot.py b/applications/contrastive_phenotyping/evaluation/rpe1_fucci/phate_plot.py index e7c3b54cf..e4a56ab88 100644 --- a/applications/contrastive_phenotyping/evaluation/rpe1_fucci/phate_plot.py +++ b/applications/contrastive_phenotyping/evaluation/rpe1_fucci/phate_plot.py @@ -2,7 +2,6 @@ from pathlib import Path import matplotlib.pyplot as plt -import numpy as np import pandas as pd import seaborn as sns @@ -99,21 +98,21 @@ # plot the 3D PHATE embedding (Note: seaborn scatterplot doesn't support 3D, using matplotlib) fig = plt.figure(figsize=(10, 10)) -ax = fig.add_subplot(111, projection='3d') +ax = fig.add_subplot(111, projection="3d") for state in ["G1", "G2", "S"]: mask = cell_cycle_states == state ax.scatter( test_features["PHATE1"][merged_indices][mask], - test_features["PHATE2"][merged_indices][mask], + test_features["PHATE2"][merged_indices][mask], test_features["PHATE3"][merged_indices][mask], c=cycle_colors[state], alpha=0.6, - label=state + label=state, ) ax.set_xlabel("PHATE1") -ax.set_ylabel("PHATE2") +ax.set_ylabel("PHATE2") ax.set_zlabel("PHATE3") ax.set_title("3D PHATE Embedding Colored by Cell Cycle State") ax.legend() diff --git a/applications/contrastive_phenotyping/evaluation/smoothness/compute_smoothness.py b/applications/contrastive_phenotyping/evaluation/smoothness/compute_smoothness.py index 9c0bf5111..8a9a83cb4 100644 --- a/applications/contrastive_phenotyping/evaluation/smoothness/compute_smoothness.py +++ b/applications/contrastive_phenotyping/evaluation/smoothness/compute_smoothness.py @@ -8,23 +8,27 @@ from viscy.representation.embedding_writer import read_embedding_dataset from viscy.representation.evaluation.smoothness import compute_embeddings_smoothness -#%% +# %% # FEATURES # openphenom_features_path = Path("/home/jason/projects/contrastive_phenotyping/data/open_phenom/features/open_phenom_features.csv") # imagenet_features_path = Path("/home/jason/projects/contrastive_phenotyping/data/imagenet/features/imagenet_features.csv") -dynaclr_features_path = Path("/hpc/projects/intracellular_dashboard/organelle_dynamics/rerun/2024_11_07_A549_SEC61_DENV/4-phenotyping/dtw_evaluation/SAM2/sam2_sensor_only.zarr") -dinov3_features_path = Path("/home/eduardo.hirata/repos/viscy/applications/benchmarking/DynaCLR/DINOV3/embeddings_convnext_tiny_phase_only_2.zarr") +dynaclr_features_path = Path( + "/hpc/projects/intracellular_dashboard/organelle_dynamics/rerun/2024_11_07_A549_SEC61_DENV/4-phenotyping/dtw_evaluation/SAM2/sam2_sensor_only.zarr" +) +dinov3_features_path = Path( + "/home/eduardo.hirata/repos/viscy/applications/benchmarking/DynaCLR/DINOV3/embeddings_convnext_tiny_phase_only_2.zarr" +) # 
LOADING DATASETS # openphenom_features = read_embedding_dataset(openphenom_features_path) # imagenet_features = read_embedding_dataset(imagenet_features_path) dynaclr_embedding_dataset = read_embedding_dataset(dynaclr_features_path) dinov3_embedding_dataset = read_embedding_dataset(dinov3_features_path) -#%% +# %% # Compute the smoothness of the features DISTANCE_METRIC = "cosine" -feature_paths ={ +feature_paths = { # "dynaclr": dynaclr_features_path, "dinov3": dinov3_features_path, } @@ -73,29 +77,27 @@ plt.xlabel(f"{DISTANCE_METRIC} Distance") plt.ylabel("Density") # Add vertical lines for the peaks - plt.axvline( - x=stats["adjacent_frame_peak"], color="cyan", linestyle="--", alpha=0.8 - ) + plt.axvline(x=stats["adjacent_frame_peak"], color="cyan", linestyle="--", alpha=0.8) plt.axvline(x=stats["random_frame_peak"], color="red", linestyle="--", alpha=0.8) plt.tight_layout() plt.legend(["Adjacent Frame", "Random Sample", "Adjacent Peak", "Random Peak"]) - plt.savefig(output_dir/f"{label}_smoothness.pdf", dpi=300) - plt.savefig(output_dir/f"{label}_smoothness.png", dpi=300) + plt.savefig(output_dir / f"{label}_smoothness.pdf", dpi=300) + plt.savefig(output_dir / f"{label}_smoothness.png", dpi=300) plt.close() - #metrics to csv + # metrics to csv scalar_metrics = { - "adjacent_frame_mean": stats["adjacent_frame_mean"], - "adjacent_frame_std": stats["adjacent_frame_std"], - "adjacent_frame_median": stats["adjacent_frame_median"], - "adjacent_frame_peak": stats["adjacent_frame_peak"], - "random_frame_mean": stats["random_frame_mean"], - "random_frame_std": stats["random_frame_std"], - "random_frame_median": stats["random_frame_median"], - "random_frame_peak": stats["random_frame_peak"], - "smoothness_score": stats["smoothness_score"], - "dynamic_range": stats["dynamic_range"] + "adjacent_frame_mean": stats["adjacent_frame_mean"], + "adjacent_frame_std": stats["adjacent_frame_std"], + "adjacent_frame_median": stats["adjacent_frame_median"], + "adjacent_frame_peak": stats["adjacent_frame_peak"], + "random_frame_mean": stats["random_frame_mean"], + "random_frame_std": stats["random_frame_std"], + "random_frame_median": stats["random_frame_median"], + "random_frame_peak": stats["random_frame_peak"], + "smoothness_score": stats["smoothness_score"], + "dynamic_range": stats["dynamic_range"], } # Create DataFrame with single row stats_df = pd.DataFrame(scalar_metrics, index=[0]) - stats_df.to_csv(output_dir/f"{label}_smoothness_stats.csv", index=False) + stats_df.to_csv(output_dir / f"{label}_smoothness_stats.csv", index=False) diff --git a/applications/contrastive_phenotyping/figures/grad_attr.py b/applications/contrastive_phenotyping/figures/grad_attr.py index 9d9ce7246..038cc5c96 100644 --- a/applications/contrastive_phenotyping/figures/grad_attr.py +++ b/applications/contrastive_phenotyping/figures/grad_attr.py @@ -9,16 +9,10 @@ import numpy as np import pandas as pd import torch +import xarray as xr from cmap import Colormap from lightning.pytorch import seed_everything from skimage.exposure import rescale_intensity -from sklearn.metrics import ( - accuracy_score, - auc, - f1_score, - precision_recall_curve, - roc_auc_score, -) from viscy.data.triplet import TripletDataModule from viscy.representation.embedding_writer import read_embedding_dataset @@ -68,8 +62,8 @@ keys=["RFP"], lower=50, upper=99, b_min=0.0, b_max=1.0 ), Decollated( - keys=["Phase3D","RFP"], - ) + keys=["Phase3D", "RFP"], + ), ], predict_cells=True, include_fov_names=[fov] * len(track), @@ -93,6 +87,7 @@ ), ).eval() + # %% 
def load_and_combine_datasets( datasets, @@ -145,8 +140,6 @@ def load_and_combine_datasets( 1.0: 1, 0.0: 0, 2: 2, - 1: 1, - 0: 0, } elif target_type == "division": standardization_mapping = { @@ -159,8 +152,6 @@ def load_and_combine_datasets( 1.0: 1, 0.0: 0, 2: 2, - 1: 1, - 0: 0, } for emb_path, ann_path, train_fovs in datasets: @@ -505,7 +496,9 @@ def clim_percentile(heatmap, low=1, high=99): icefire = Colormap("icefire").to_mpl() -f, ax = plt.subplots(3, len(selected_time_points), figsize=(5.5, 3), layout="compressed") +f, ax = plt.subplots( + 3, len(selected_time_points), figsize=(5.5, 3), layout="compressed" +) for i, time in enumerate(selected_time_points): hpi = 3 + 0.5 * time prob = infection_probs[time].item() @@ -515,11 +508,11 @@ def clim_percentile(heatmap, low=1, high=99): ax[0, i].set_title(f"{hpi} HPI") ax[1, i].imshow(inf_render[time], cmap=icefire, vmin=0, vmax=1) ax[1, i].set_title( - f"infected: {prob:.3f}\n" f"label: {inf_binary}", + f"infected: {prob:.3f}\nlabel: {inf_binary}", ) ax[2, i].imshow(div_render[time], cmap=icefire, vmin=0, vmax=1) ax[2, i].set_title( - f"dividing: {division_probs[time].item():.3f}\n" f"label: {div_binary}", + f"dividing: {division_probs[time].item():.3f}\nlabel: {div_binary}", ) for a in ax.ravel(): a.axis("off") @@ -627,7 +620,8 @@ def animate(frame): return [im1, im2, im3] -#%% + +# %% # Create animation anim = animation.FuncAnimation( @@ -647,4 +641,4 @@ def animate(frame): writer = Writer(fps=5, metadata=dict(artist="VisCy"), bitrate=1800) anim.save(str(video_path), writer=writer) -print(f"Video saved to: {video_path}") \ No newline at end of file +print(f"Video saved to: {video_path}") diff --git a/applications/pseudotime_analysis/evaluation/compare_dtw_embeddings_sam2.py b/applications/pseudotime_analysis/evaluation/compare_dtw_embeddings_sam2.py index c18528f6f..17c15d1df 100644 --- a/applications/pseudotime_analysis/evaluation/compare_dtw_embeddings_sam2.py +++ b/applications/pseudotime_analysis/evaluation/compare_dtw_embeddings_sam2.py @@ -11,15 +11,11 @@ find_pattern_matches, identify_lineages, plot_pc_trajectories, - plot_reference_vs_full_lineages, ) -from sklearn.decomposition import PCA -from sklearn.preprocessing import StandardScaler from tqdm import tqdm from viscy.data.triplet import TripletDataModule from viscy.representation.embedding_writer import read_embedding_dataset -from viscy.representation.evaluation.dimensionality_reduction import compute_pca logger = logging.getLogger("viscy") logger.setLevel(logging.INFO) diff --git a/tests/representation/evaluation/test_clustering.py b/tests/representation/evaluation/test_clustering.py index 9b91a065b..cfff09694 100644 --- a/tests/representation/evaluation/test_clustering.py +++ b/tests/representation/evaluation/test_clustering.py @@ -154,7 +154,7 @@ def test_invalid_device_raises_error(self, sample_features: NDArray): def test_float32_input_preserves_precision(self): """Test that float32 input is converted to float64 for precision.""" - torch = pytest.importorskip("torch") + pytest.importorskip("torch") features_f32 = np.random.randn(10, 32).astype(np.float32) diff --git a/viscy/representation/evaluation/smoothness.py b/viscy/representation/evaluation/smoothness.py index bef94fddc..62c74d424 100644 --- a/viscy/representation/evaluation/smoothness.py +++ b/viscy/representation/evaluation/smoothness.py @@ -1,4 +1,3 @@ -from pathlib import Path from typing import Literal import anndata as ad From 69d928fef1e578e4b8faa589605e22a1b2038ec9 Mon Sep 17 00:00:00 2001 From: Eduardo 
Hirata-Miyasaki Date: Wed, 29 Oct 2025 17:19:59 -0700 Subject: [PATCH 091/101] Update viscy/data/cell_division_triplet.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- viscy/data/cell_division_triplet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/viscy/data/cell_division_triplet.py b/viscy/data/cell_division_triplet.py index a1ccc3cef..4547adaa1 100644 --- a/viscy/data/cell_division_triplet.py +++ b/viscy/data/cell_division_triplet.py @@ -345,7 +345,7 @@ def __init__( split_ratio=split_ratio, batch_size=batch_size, num_workers=num_workers, - target_2d=False, # Set to False since we're adding depth dimension + target_2d=False, # Set to False since we're adding depth dimension yx_patch_size=final_yx_patch_size, normalizations=normalizations, augmentations=augmentations, From d42d6da4529359b809637bdd5c10102619a3d7f8 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Wed, 29 Oct 2025 17:20:26 -0700 Subject: [PATCH 092/101] Update applications/benchmarking/DynaCLR/SAM2/sam2_visualizations.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- applications/benchmarking/DynaCLR/SAM2/sam2_visualizations.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/applications/benchmarking/DynaCLR/SAM2/sam2_visualizations.py b/applications/benchmarking/DynaCLR/SAM2/sam2_visualizations.py index ef389e044..46a9b2eac 100644 --- a/applications/benchmarking/DynaCLR/SAM2/sam2_visualizations.py +++ b/applications/benchmarking/DynaCLR/SAM2/sam2_visualizations.py @@ -52,7 +52,7 @@ def visualize_rgb_conversion(x_original, x_rgb_list, save_dir="./debug_images"): rgb_display = (rgb_display - rgb_display.min()) / ( rgb_display.max() - rgb_display.min() ) - im = ax.imshow(rgb_display) + ax.imshow(rgb_display) ax.set_title("Merged RGB Image") ax.axis("off") From 84634e8f7f497a7daa422fd4cfa718b9a8ef58bd Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Wed, 29 Oct 2025 17:20:53 -0700 Subject: [PATCH 093/101] Update applications/pseudotime_analysis/evaluation/compare_dtw_embeddings_sam2.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- .../evaluation/compare_dtw_embeddings_sam2.py | 1 - 1 file changed, 1 deletion(-) diff --git a/applications/pseudotime_analysis/evaluation/compare_dtw_embeddings_sam2.py b/applications/pseudotime_analysis/evaluation/compare_dtw_embeddings_sam2.py index 17c15d1df..841f1c365 100644 --- a/applications/pseudotime_analysis/evaluation/compare_dtw_embeddings_sam2.py +++ b/applications/pseudotime_analysis/evaluation/compare_dtw_embeddings_sam2.py @@ -114,7 +114,6 @@ # OPTION 1: Use the infection annotations to find the reference lineage reference_lineage_fov = "/C/2/001000" reference_lineage_track_id = [129] -reference_timepoints = [8, 70] # sensor rellocalization and partial remodelling # Option 2: from the filtered lineages find one from FOV C/2/000001 reference_lineage_fov = "/C/2/000001" From 817be47feeb31cf054d2feeba5949b9a861b1ecc Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Wed, 29 Oct 2025 17:21:08 -0700 Subject: [PATCH 094/101] Update applications/pseudotime_analysis/evaluation/compare_dtw_embeddings_sam2.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- .../evaluation/compare_dtw_embeddings_sam2.py | 1 - 1 file changed, 1 deletion(-) diff --git a/applications/pseudotime_analysis/evaluation/compare_dtw_embeddings_sam2.py b/applications/pseudotime_analysis/evaluation/compare_dtw_embeddings_sam2.py index 
841f1c365..d5ba59d46 100644
--- a/applications/pseudotime_analysis/evaluation/compare_dtw_embeddings_sam2.py
+++ b/applications/pseudotime_analysis/evaluation/compare_dtw_embeddings_sam2.py
@@ -113,7 +113,6 @@
 # Aligning condition embeddings to infection
 # OPTION 1: Use the infection annotations to find the reference lineage
 reference_lineage_fov = "/C/2/001000"
-reference_lineage_track_id = [129]
 
 # Option 2: from the filtered lineages find one from FOV C/2/000001
 reference_lineage_fov = "/C/2/000001"

From 817be47feeb31cf054d2feeba5949b9a861b1ecc Mon Sep 17 00:00:00 2001
From: Eduardo Hirata-Miyasaki
Date: Wed, 29 Oct 2025 17:21:08 -0700
Subject: [PATCH 094/101] Update applications/pseudotime_analysis/evaluation/compare_dtw_embeddings_sam2.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
---
 .../evaluation/compare_dtw_embeddings_sam2.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/applications/pseudotime_analysis/evaluation/compare_dtw_embeddings_sam2.py b/applications/pseudotime_analysis/evaluation/compare_dtw_embeddings_sam2.py
index 841f1c365..d5ba59d46 100644
--- a/applications/pseudotime_analysis/evaluation/compare_dtw_embeddings_sam2.py
+++ b/applications/pseudotime_analysis/evaluation/compare_dtw_embeddings_sam2.py
@@ -114,7 +114,6 @@
 # OPTION 1: Use the infection annotations to find the reference lineage
 reference_lineage_fov = "/C/2/001000"
-reference_lineage_track_id = [129]
 
 # Option 2: from the filtered lineages find one from FOV C/2/000001
 reference_lineage_fov = "/C/2/000001"

From 2cc16e9e7a32800c6d7c5da168e7690dc5bb08a6 Mon Sep 17 00:00:00 2001
From: Eduardo Hirata-Miyasaki
Date: Wed, 29 Oct 2025 17:21:20 -0700
Subject: [PATCH 095/101] Update applications/contrastive_phenotyping/evaluation/smoothness/compute_smoothness.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
---
 .../evaluation/smoothness/compute_smoothness.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/applications/contrastive_phenotyping/evaluation/smoothness/compute_smoothness.py b/applications/contrastive_phenotyping/evaluation/smoothness/compute_smoothness.py
index 8a9a83cb4..7999ac0d4 100644
--- a/applications/contrastive_phenotyping/evaluation/smoothness/compute_smoothness.py
+++ b/applications/contrastive_phenotyping/evaluation/smoothness/compute_smoothness.py
@@ -57,7 +57,7 @@
     )
 
     # Plot the piecewise distances
-    fig = plt.figure()
+    plt.figure()
     sns.histplot(
         distributions["adjacent_frame_distribution"],
         bins=30,

From 3acad6f7884ed695fa212aab9900a287afa80ed0 Mon Sep 17 00:00:00 2001
From: Eduardo Hirata-Miyasaki
Date: Wed, 29 Oct 2025 17:21:34 -0700
Subject: [PATCH 096/101] Update applications/pseudotime_analysis/evaluation/compare_dtw_embeddings_sam2.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
---
 .../evaluation/compare_dtw_embeddings_sam2.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/applications/pseudotime_analysis/evaluation/compare_dtw_embeddings_sam2.py b/applications/pseudotime_analysis/evaluation/compare_dtw_embeddings_sam2.py
index d5ba59d46..0381c12c4 100644
--- a/applications/pseudotime_analysis/evaluation/compare_dtw_embeddings_sam2.py
+++ b/applications/pseudotime_analysis/evaluation/compare_dtw_embeddings_sam2.py
@@ -112,7 +112,6 @@
 # %%
 # Aligning condition embeddings to infection
 # OPTION 1: Use the infection annotations to find the reference lineage
-reference_lineage_fov = "/C/2/001000"
 
 # Option 2: from the filtered lineages find one from FOV C/2/000001
 reference_lineage_fov = "/C/2/000001"

From 2a65bf054d3c55b88ff0fc80e9c5eeed03df29c5 Mon Sep 17 00:00:00 2001
From: Eduardo Hirata-Miyasaki
Date: Wed, 29 Oct 2025 17:21:47 -0700
Subject: [PATCH 097/101] Update applications/contrastive_phenotyping/evaluation/smoothness/compute_smoothness.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
---
 .../evaluation/smoothness/compute_smoothness.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/applications/contrastive_phenotyping/evaluation/smoothness/compute_smoothness.py b/applications/contrastive_phenotyping/evaluation/smoothness/compute_smoothness.py
index 7999ac0d4..8047c3d63 100644
--- a/applications/contrastive_phenotyping/evaluation/smoothness/compute_smoothness.py
+++ b/applications/contrastive_phenotyping/evaluation/smoothness/compute_smoothness.py
@@ -51,7 +51,7 @@
 
     # Compute displacements
     stats, distributions, _ = compute_embeddings_smoothness(
-        prediction_path=Path(path),
+        features_ad=embedding_dataset,
         distance_metric=DISTANCE_METRIC,
         verbose=True,
     )
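Note on the call-site change above: a minimal sketch of how the AnnData-based smoothness API reads after patches 088 and 097 together. The synthetic AnnData built here (random features, fov_name/track_id/t observation columns mirroring the conventions used elsewhere in this series) is an illustrative assumption, not code from the patches:

    import anndata as ad
    import numpy as np
    import pandas as pd

    from viscy.representation.evaluation.smoothness import (
        compute_embeddings_smoothness,
    )

    # Hypothetical embeddings: 3 tracks x 50 timepoints, 32 features each.
    rng = np.random.default_rng(0)
    n_tracks, n_timepoints = 3, 50
    features = rng.standard_normal((n_tracks * n_timepoints, 32))
    obs = pd.DataFrame(
        {
            "fov_name": ["000000"] * (n_tracks * n_timepoints),
            "track_id": np.repeat(np.arange(n_tracks), n_timepoints),
            "t": np.tile(np.arange(n_timepoints), n_tracks),
        }
    )
    adata = ad.AnnData(X=features, obs=obs)

    # Keyword matches the signature introduced in patch 088.
    stats, distributions, per_track = compute_embeddings_smoothness(
        features_ad=adata,
        distance_metric="cosine",
        verbose=False,
    )
    print(stats["smoothness_score"], stats["dynamic_range"])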
From 9421b03efa4a98e34f31a209dd33bda2203d961b Mon Sep 17 00:00:00 2001
From: Eduardo Hirata-Miyasaki
Date: Wed, 29 Oct 2025 17:22:16 -0700
Subject: [PATCH 098/101] Update applications/contrastive_phenotyping/evaluation/archive/ALFI_MSD_v2.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
---
 .../contrastive_phenotyping/evaluation/archive/ALFI_MSD_v2.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/applications/contrastive_phenotyping/evaluation/archive/ALFI_MSD_v2.py b/applications/contrastive_phenotyping/evaluation/archive/ALFI_MSD_v2.py
index 98014a4f3..ed86e0b4c 100644
--- a/applications/contrastive_phenotyping/evaluation/archive/ALFI_MSD_v2.py
+++ b/applications/contrastive_phenotyping/evaluation/archive/ALFI_MSD_v2.py
@@ -183,7 +183,7 @@ def extract_step_sizes(embedding_dataset: xr.Dataset):
 
 # %%
 # Plot histograms
-fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))
+ax1, ax2 = plt.subplots(1, 2, figsize=(15, 6))[1]
 
 for model_type, steps in all_step_data.items():
     color = interval_colors.get(model_type, "#1f77b4")

From a3f015efbff628848f16a65b456e5a6fb4f4c6a2 Mon Sep 17 00:00:00 2001
From: Eduardo Hirata-Miyasaki
Date: Mon, 3 Nov 2025 08:59:18 -0800
Subject: [PATCH 099/101] raise ValueError on unknown method in find_distribution_peak

---
 viscy/representation/evaluation/smoothness.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/viscy/representation/evaluation/smoothness.py b/viscy/representation/evaluation/smoothness.py
index 62c74d424..117aa3d15 100644
--- a/viscy/representation/evaluation/smoothness.py
+++ b/viscy/representation/evaluation/smoothness.py
@@ -105,6 +105,9 @@ def find_distribution_peak(
         peak_heights = properties["peak_heights"]
         return x_range[peaks[np.argmax(peak_heights)]]
 
+    else:
+        raise ValueError(f"Unknown method: {method}. Use 'histogram' or 'kde_robust'.")
+
 
 def compute_embeddings_smoothness(
     features_ad: ad.AnnData,

From cf7450acd0bd5fca649dd33bda2203d961b00000 Mon Sep 17 00:00:00 2001
From: Eduardo Hirata-Miyasaki
Date: Mon, 3 Nov 2025 09:00:21 -0800
Subject: [PATCH 100/101] add literal to the betavae25d normalization

---
 viscy/representation/vae.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/viscy/representation/vae.py b/viscy/representation/vae.py
index a7cd1a51e..f34c2828d 100644
--- a/viscy/representation/vae.py
+++ b/viscy/representation/vae.py
@@ -299,7 +299,7 @@ def __init__(
         head_pool: bool = False,
         upsample_mode: Literal["deconv", "pixelshuffle"] = "pixelshuffle",
         conv_blocks: int = 2,
-        norm_name: str = "batch",
+        norm_name: Literal["batch", "instance"] = "batch",
         upsample_pre_conv: Literal["default"] | Callable | None = None,
     ):
         super().__init__()

From 76f8ddbde7dd7e57434fdbabe1d195b51d640c7d Mon Sep 17 00:00:00 2001
From: Eduardo Hirata-Miyasaki
Date: Mon, 3 Nov 2025 10:19:01 -0800
Subject: [PATCH 101/101] clipping similarity that was breaking the tests

---
 viscy/representation/evaluation/distance.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/viscy/representation/evaluation/distance.py b/viscy/representation/evaluation/distance.py
index 9ea940384..f5fea2f2b 100644
--- a/viscy/representation/evaluation/distance.py
+++ b/viscy/representation/evaluation/distance.py
@@ -23,6 +23,7 @@ def calculate_cosine_similarity_cell(embedding_dataset, fov_name, track_id):
     cosine_similarities = cosine_similarity(
         first_time_point_embedding, features
     ).flatten()
+    cosine_similarities = np.clip(cosine_similarities, -1.0, 1.0)
 
     return time_points, cosine_similarities.tolist()
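Closing sketch of the device dispatch introduced in patch 089, exercised end to end. The feature matrix is synthetic and the tolerances are borrowed from the new tests rather than guaranteed by the implementation:

    import numpy as np

    from viscy.representation.evaluation.clustering import pairwise_distance_matrix

    features = np.random.randn(100, 16).astype(np.float64)

    # scipy baseline vs. the torch-backed path ("auto" picks GPU when available).
    d_scipy = pairwise_distance_matrix(features, metric="cosine", device="scipy")
    d_auto = pairwise_distance_matrix(features, metric="cosine", device="auto")
    assert np.allclose(d_scipy, d_auto, rtol=1e-5, atol=1e-6)

    # Metrics outside {"cosine", "euclidean"} silently fall back to scipy.
    d_city = pairwise_distance_matrix(features, metric="cityblock", device="auto")
    assert d_city.shape == (100, 100)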