From 25322e20aff70fdefa1a0c78c30311254d986fce Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Tue, 7 Jan 2025 18:10:15 -0800 Subject: [PATCH 01/38] updating the MSD calculation --- .../evaluation/ALFI_MSD_v2.py | 110 +++++++++ viscy/representation/evaluation/distance.py | 219 ++++++++++++++---- 2 files changed, 282 insertions(+), 47 deletions(-) create mode 100644 applications/contrastive_phenotyping/evaluation/ALFI_MSD_v2.py diff --git a/applications/contrastive_phenotyping/evaluation/ALFI_MSD_v2.py b/applications/contrastive_phenotyping/evaluation/ALFI_MSD_v2.py new file mode 100644 index 000000000..a73d6a5e8 --- /dev/null +++ b/applications/contrastive_phenotyping/evaluation/ALFI_MSD_v2.py @@ -0,0 +1,110 @@ +# %% +from pathlib import Path +import matplotlib.pyplot as plt +import numpy as np +from viscy.representation.embedding_writer import read_embedding_dataset +from viscy.representation.evaluation.distance import ( + compute_displacement, + compute_displacement_statistics, +) + +# Paths to datasets +feature_paths = { + "7 min interval": "/hpc/projects/organelle_phenotyping/ALFI_benchmarking/predictions_final/ALFI_opp_7mins.zarr", +} + +# %% Compute MSD for each dataset +results = {} +raw_displacements = {} +max_tau = 200 + +# Different normalization strategies to test +norm_strategies = [None, "per_feature", "per_embedding", "per_dataset"] +colors = { + None: "blue", + "per_feature": "red", + "per_embedding": "green", + "per_dataset": "purple", +} +labels = { + None: "Raw", + "per_feature": "Per-feature z-score", + "per_embedding": "Unit norm", + "per_dataset": "Dataset z-score", +} + +for label, path in feature_paths.items(): + print(f"\nProcessing {label}...") + embedding_dataset = read_embedding_dataset(Path(path)) + + for norm in norm_strategies: + # Compute displacements with different normalization strategies + displacements = compute_displacement( + embedding_dataset=embedding_dataset, + max_tau=max_tau, + distance_metric="euclidean_squared", + normalize=norm, + ) + means, stds = compute_displacement_statistics(displacements) + results[f"{label} ({labels[norm]})"] = (means, stds) + raw_displacements[f"{label} ({labels[norm]})"] = displacements + + print(f"{labels[norm]} MSD at tau=1: {means[1]:.4f} ± {stds[1]:.4f}") + +# %% Plot results with sample sizes +plt.figure(figsize=(12, 8)) + +for label, (means, stds) in results.items(): + taus = list(means.keys()) + mean_values = list(means.values()) + std_values = list(stds.values()) + + # Get the normalization strategy from the label + norm = next(n for n in norm_strategies if labels[n] in label) + color = colors[norm] + + # Plot MSD with confidence band + plt.plot(taus, mean_values, "o-", color=color, label=f"{label} (mean)") + plt.fill_between( + taus, + np.array(mean_values) - np.array(std_values), + np.array(mean_values) + np.array(std_values), + alpha=0.3, + color=color, + label=f"{label} (±1σ)", + ) + +plt.xlabel("Time Shift (τ)") +plt.ylabel("Mean Square Displacement") +plt.title( + "Mean Square Displacement vs Time Shift\n(Comparing Normalization Strategies)" +) +plt.grid(True) +plt.legend() +plt.tight_layout() +plt.show() + +# %% Plot displacement distributions for different taus +fig, axes = plt.subplots(2, 2, figsize=(15, 12)) +axes = axes.ravel() + +for i, norm in enumerate(norm_strategies): + label = f"7 min interval ({labels[norm]})" + displacements = raw_displacements[label] + + # Plot distributions for a few selected taus + selected_taus = [1, 5, max_tau] + for tau in selected_taus: + values = 
displacements[tau] + axes[i].hist(values, bins=50, alpha=0.3, density=True, label=f"τ = {tau}") + + axes[i].set_xlabel("Square Displacement") + axes[i].set_ylabel("Density") + axes[i].set_title(f"Distribution of Square Displacements\n({labels[norm]})") + axes[i].legend() + axes[i].grid(True) + +plt.tight_layout() +plt.show() + +# %% diff --git a/viscy/representation/evaluation/distance.py b/viscy/representation/evaluation/distance.py index 9a1c72ef3..7b532ac3d 100644 --- a/viscy/representation/evaluation/distance.py +++ b/viscy/representation/evaluation/distance.py @@ -1,4 +1,5 @@ from collections import defaultdict +from typing import Dict, List, Literal, Tuple, Union import numpy as np from sklearn.metrics.pairwise import cosine_similarity @@ -20,69 +21,193 @@ def calculate_cosine_similarity_cell(embedding_dataset, fov_name, track_id): return time_points, cosine_similarities.tolist() +def normalize_embeddings( + embeddings: np.ndarray, + strategy: Literal["per_feature", "per_embedding", "per_dataset"] = "per_feature", +) -> np.ndarray: + """Normalize embeddings using different strategies. + + Parameters + ---------- + embeddings : np.ndarray + Array of shape (n_samples, n_features) containing embeddings + strategy : str + Normalization strategy: + - "per_feature": z-score each feature across all samples + - "per_embedding": normalize each embedding vector to unit norm + - "per_dataset": z-score entire dataset (across all features and samples) + + Returns + ------- + np.ndarray + Normalized embeddings with same shape as input + """ + if strategy == "per_feature": + # Normalize each feature independently + return (embeddings - np.mean(embeddings, axis=0)) / np.std(embeddings, axis=0) + elif strategy == "per_embedding": + # Normalize each embedding to unit norm + norms = np.linalg.norm(embeddings, axis=1, keepdims=True) + return embeddings / norms + elif strategy == "per_dataset": + # Normalize entire dataset + return (embeddings - np.mean(embeddings)) / np.std(embeddings) + else: + raise ValueError(f"Unknown normalization strategy: {strategy}") + + def compute_displacement( embedding_dataset, - max_tau=10, - use_cosine=False, - use_dissimilarity=False, - use_umap=False, - return_mean_std=False, -): - """Compute the norm of differences between embeddings at t and t + tau""" + max_tau: int = 10, + distance_metric: Literal[ + "euclidean", "euclidean_squared", "cosine", "cosine_dissimilarity" + ] = "euclidean_squared", + embedding_coords: Literal["UMAP", "PHATE", None] = None, + normalize: Literal["per_feature", "per_embedding", "per_dataset", None] = None, +) -> Dict[int, List[float]]: + """Compute the displacement or mean square displacement (MSD) of embeddings. + + For each time difference τ, computes either: + - |r(t + τ) - r(t)|² for squared Euclidean (MSD) + - |r(t + τ) - r(t)| for Euclidean + - cos_sim(r(t + τ), r(t)) for cosine + for all particles and initial times t. + + Parameters + ---------- + embedding_dataset : xarray.Dataset + Dataset containing embeddings and metadata + max_tau : int + Maximum time difference to compute displacement for + distance_metric : str + The metric to use for computing distances between embeddings. 
+ Valid options are: + - "euclidean": Euclidean distance (L2 norm) + - "euclidean_squared": Squared Euclidean distance (for MSD, default) + - "cosine": Cosine similarity + - "cosine_dissimilarity": 1 - cosine similarity + embedding_coords : str or None + Which embedding coordinates to use for distance computation: + - None: Use original features from dataset (default) + - "UMAP": Use UMAP coordinates (UMAP1, UMAP2) + - "PHATE": Use PHATE coordinates (PHATE1, PHATE2) + normalize : str or None + Normalization strategy to apply to embeddings before computing distances: + - None: No normalization (default) + - "per_feature": z-score each feature across all samples + - "per_embedding": normalize each embedding vector to unit norm + - "per_dataset": z-score entire dataset (across all features and samples) + + Returns + ------- + Dict[int, List[float]] + Dictionary mapping τ to list of displacements for all particles and initial times + """ + # Get data from dataset fov_names = embedding_dataset["fov_name"].values track_ids = embedding_dataset["track_id"].values timepoints = embedding_dataset["t"].values - if use_umap: + # Get embeddings based on specified coordinates + if embedding_coords == "UMAP": embeddings = np.vstack( (embedding_dataset["UMAP1"].values, embedding_dataset["UMAP2"].values) ).T + elif embedding_coords == "PHATE": + embeddings = np.vstack( + (embedding_dataset["PHATE1"].values, embedding_dataset["PHATE2"].values) + ).T else: embeddings = embedding_dataset["features"].values - displacement_per_tau = defaultdict(list) - - for i in range(len(fov_names)): - fov_name = fov_names[i] - track_id = track_ids[i] - current_time = timepoints[i] - current_embedding = embeddings[i].reshape(1, -1) - - for tau in range(1, max_tau + 1): - future_time = current_time + tau - matching_indices = np.where( - (fov_names == fov_name) - & (track_ids == track_id) - & (timepoints == future_time) - )[0] - - if len(matching_indices) == 1: - future_embedding = embeddings[matching_indices[0]].reshape(1, -1) - - if use_cosine: - similarity = cosine_similarity(current_embedding, future_embedding)[ - 0 - ][0] - displacement = 1 - similarity if use_dissimilarity else similarity - else: - displacement = np.sum((current_embedding - future_embedding) ** 2) - - displacement_per_tau[tau].append(displacement) - - if return_mean_std: - mean_displacement_per_tau = { - tau: np.mean(displacements) - for tau, displacements in displacement_per_tau.items() - } - std_displacement_per_tau = { - tau: np.std(displacements) - for tau, displacements in displacement_per_tau.items() - } - return mean_displacement_per_tau, std_displacement_per_tau + # Normalize embeddings if requested + if normalize is not None: + embeddings = normalize_embeddings(embeddings, strategy=normalize) + + # Initialize results dictionary + displacement_per_tau = {tau: [] for tau in range(1, max_tau + 1)} + + # Get unique tracks using a set of tuples + unique_tracks = set(zip(fov_names, track_ids)) + + # Process each track + for fov_name, track_id in unique_tracks: + # Get sorted track data + mask = (fov_names == fov_name) & (track_ids == track_id) + times = timepoints[mask] + track_embeddings = embeddings[mask] + + # Sort by time + time_order = np.argsort(times) + times = times[time_order] + track_embeddings = track_embeddings[time_order] + + # Process each time point + for t_idx, t in enumerate(times[:-1]): + current_embedding = track_embeddings[t_idx] + + # Check each tau + for tau in range(1, max_tau + 1): + future_time = t + tau + # Since times are 
sorted, we can use searchsorted + future_idx = np.searchsorted(times, future_time) + + if future_idx < len(times) and times[future_idx] == future_time: + future_embedding = track_embeddings[future_idx] + + if distance_metric in ["cosine", "cosine_dissimilarity"]: + dot_product = np.dot(current_embedding, future_embedding) + norms = np.linalg.norm(current_embedding) * np.linalg.norm( + future_embedding + ) + similarity = dot_product / norms + displacement = ( + 1 - similarity + if distance_metric == "cosine_dissimilarity" + else similarity + ) + else: # Euclidean metrics + diff_squared = np.sum( + (current_embedding - future_embedding) ** 2 + ) + displacement = ( + diff_squared + if distance_metric == "euclidean_squared" + else np.sqrt(diff_squared) + ) + + displacement_per_tau[tau].append(displacement) return displacement_per_tau +def compute_displacement_statistics( + displacement_per_tau: Dict[int, List[float]] +) -> Tuple[Dict[int, float], Dict[int, float]]: + """Compute mean and standard deviation of displacements for each tau. + + Parameters + ---------- + displacement_per_tau : Dict[int, List[float]] + Dictionary mapping τ to list of displacements + + Returns + ------- + Tuple[Dict[int, float], Dict[int, float]] + Tuple of (mean_displacements, std_displacements) where each is a + dictionary mapping τ to the statistic + """ + mean_displacement_per_tau = { + tau: np.mean(displacements) + for tau, displacements in displacement_per_tau.items() + } + std_displacement_per_tau = { + tau: np.std(displacements) + for tau, displacements in displacement_per_tau.items() + } + return mean_displacement_per_tau, std_displacement_per_tau + + def compute_dynamic_range(mean_displacement_per_tau): """ Compute the dynamic range as the difference between the maximum From 0fcb0e61c61d97843678eaab6a1fc1161b7ae9f8 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Tue, 7 Jan 2025 18:24:11 -0800 Subject: [PATCH 02/38] MSD with different normalizations and removing unecessary arguments --- .../evaluation/ALFI_MSD_v2.py | 111 ++++++++---------- viscy/representation/evaluation/distance.py | 75 ++++++------ 2 files changed, 84 insertions(+), 102 deletions(-) diff --git a/applications/contrastive_phenotyping/evaluation/ALFI_MSD_v2.py b/applications/contrastive_phenotyping/evaluation/ALFI_MSD_v2.py index a73d6a5e8..bc9324cb7 100644 --- a/applications/contrastive_phenotyping/evaluation/ALFI_MSD_v2.py +++ b/applications/contrastive_phenotyping/evaluation/ALFI_MSD_v2.py @@ -11,21 +11,11 @@ # Paths to datasets feature_paths = { "7 min interval": "/hpc/projects/organelle_phenotyping/ALFI_benchmarking/predictions_final/ALFI_opp_7mins.zarr", + "21 min interval": "/hpc/projects/organelle_phenotyping/ALFI_benchmarking/predictions_final/ALFI_opp_21mins.zarr", } -# %% Compute MSD for each dataset -results = {} -raw_displacements = {} -max_tau = 200 - # Different normalization strategies to test norm_strategies = [None, "per_feature", "per_embedding", "per_dataset"] -colors = { - None: "blue", - "per_feature": "red", - "per_embedding": "green", - "per_dataset": "purple", -} labels = { None: "Raw", "per_feature": "Per-feature z-score", @@ -33,6 +23,16 @@ "per_dataset": "Dataset z-score", } +# Colors for different time intervals +interval_colors = { + "7 min interval": "blue", + "21 min interval": "red", +} + +# %% Compute MSD for each dataset +results = {} +raw_displacements = {} + for label, path in feature_paths.items(): print(f"\nProcessing {label}...") embedding_dataset = read_embedding_dataset(Path(path)) @@ 
-41,7 +41,6 @@ # Compute displacements with different normalization strategies displacements = compute_displacement( embedding_dataset=embedding_dataset, - max_tau=max_tau, distance_metric="euclidean_squared", normalize=norm, ) @@ -49,62 +48,52 @@ results[f"{label} ({labels[norm]})"] = (means, stds) raw_displacements[f"{label} ({labels[norm]})"] = displacements - print(f"{labels[norm]} MSD at tau=1: {means[1]:.4f} ± {stds[1]:.4f}") - -# %% Plot results with sample sizes -plt.figure(figsize=(12, 8)) + # Print some statistics + taus = sorted(means.keys()) + print(f"\n{labels[norm]}:") + print(f" Number of different τ values: {len(taus)}") + print(f" τ range: {min(taus)} to {max(taus)}") + print(f" MSD at τ=1: {means[1]:.4f} ± {stds[1]:.4f}") -for label, (means, stds) in results.items(): - taus = list(means.keys()) - mean_values = list(means.values()) - std_values = list(stds.values()) - - # Get the normalization strategy from the label - norm = next(n for n in norm_strategies if labels[n] in label) - color = colors[norm] - - # Plot MSD with confidence band - plt.plot(taus, mean_values, "o-", color=color, label=f"{label} (mean)") - plt.fill_between( - taus, - np.array(mean_values) - np.array(std_values), - np.array(mean_values) + np.array(std_values), - alpha=0.3, - color=color, - label=f"{label} (±1σ)", - ) - -plt.xlabel("Time Shift (τ)") -plt.ylabel("Mean Square Displacement") -plt.title( - "Mean Square Displacement vs Time Shift\n(Comparing Normalization Strategies)" -) -plt.grid(True) -plt.legend() -plt.tight_layout() -plt.show() - -# %% Plot displacement distributions for different taus +# %% Plot MSD vs time - one plot per normalization strategy fig, axes = plt.subplots(2, 2, figsize=(15, 12)) axes = axes.ravel() for i, norm in enumerate(norm_strategies): - label = f"7 min interval ({labels[norm]})" - displacements = raw_displacements[label] - - # Plot distributions for a few selected taus - selected_taus = [1, 5, max_tau] - for tau in selected_taus: - values = displacements[tau] - axes[i].hist(values, bins=50, alpha=0.3, density=True, label=f"τ = {tau}") + ax = axes[i] + + # Plot each time interval for this normalization strategy + for interval_label, path in feature_paths.items(): + result_label = f"{interval_label} ({labels[norm]})" + means, stds = results[result_label] + + # Sort by tau for plotting + taus = sorted(means.keys()) + mean_values = [means[tau] for tau in taus] + std_values = [stds[tau] for tau in taus] + + # Plot MSD with confidence band + ax.plot( + taus, + mean_values, + "-", + color=interval_colors[interval_label], + label=f"{interval_label}", + ) + ax.fill_between( + taus, + np.array(mean_values) - np.array(std_values), + np.array(mean_values) + np.array(std_values), + alpha=0.3, + color=interval_colors[interval_label], + ) - axes[i].set_xlabel("Square Displacement") - axes[i].set_ylabel("Density") - axes[i].set_title(f"Distribution of Square Displacements\n({labels[norm]})") - axes[i].legend() - axes[i].grid(True) + ax.set_xlabel("Time Shift (τ)") + ax.set_ylabel("Mean Square Displacement") + ax.set_title(f"MSD vs Time Shift\n({labels[norm]})") + ax.grid(True) + ax.legend() plt.tight_layout() plt.show() - # %% diff --git a/viscy/representation/evaluation/distance.py b/viscy/representation/evaluation/distance.py index 7b532ac3d..51caecc87 100644 --- a/viscy/representation/evaluation/distance.py +++ b/viscy/representation/evaluation/distance.py @@ -1,5 +1,5 @@ from collections import defaultdict -from typing import Dict, List, Literal, Tuple, Union +from typing 
import Dict, List, Literal, Tuple, Union, Optional import numpy as np from sklearn.metrics.pairwise import cosine_similarity @@ -58,7 +58,6 @@ def normalize_embeddings( def compute_displacement( embedding_dataset, - max_tau: int = 10, distance_metric: Literal[ "euclidean", "euclidean_squared", "cosine", "cosine_dissimilarity" ] = "euclidean_squared", @@ -77,8 +76,6 @@ def compute_displacement( ---------- embedding_dataset : xarray.Dataset Dataset containing embeddings and metadata - max_tau : int - Maximum time difference to compute displacement for distance_metric : str The metric to use for computing distances between embeddings. Valid options are: @@ -124,12 +121,12 @@ def compute_displacement( if normalize is not None: embeddings = normalize_embeddings(embeddings, strategy=normalize) - # Initialize results dictionary - displacement_per_tau = {tau: [] for tau in range(1, max_tau + 1)} - - # Get unique tracks using a set of tuples + # Get unique tracks unique_tracks = set(zip(fov_names, track_ids)) + # Initialize results dictionary with empty lists + displacement_per_tau = defaultdict(list) + # Process each track for fov_name, track_id in unique_tracks: # Get sorted track data @@ -146,39 +143,35 @@ def compute_displacement( for t_idx, t in enumerate(times[:-1]): current_embedding = track_embeddings[t_idx] - # Check each tau - for tau in range(1, max_tau + 1): - future_time = t + tau - # Since times are sorted, we can use searchsorted - future_idx = np.searchsorted(times, future_time) - - if future_idx < len(times) and times[future_idx] == future_time: - future_embedding = track_embeddings[future_idx] - - if distance_metric in ["cosine", "cosine_dissimilarity"]: - dot_product = np.dot(current_embedding, future_embedding) - norms = np.linalg.norm(current_embedding) * np.linalg.norm( - future_embedding - ) - similarity = dot_product / norms - displacement = ( - 1 - similarity - if distance_metric == "cosine_dissimilarity" - else similarity - ) - else: # Euclidean metrics - diff_squared = np.sum( - (current_embedding - future_embedding) ** 2 - ) - displacement = ( - diff_squared - if distance_metric == "euclidean_squared" - else np.sqrt(diff_squared) - ) - - displacement_per_tau[tau].append(displacement) - - return displacement_per_tau + # Check all possible future time points + for future_idx, future_time in enumerate( + times[t_idx + 1 :], start=t_idx + 1 + ): + tau = future_time - t + future_embedding = track_embeddings[future_idx] + + if distance_metric in ["cosine", "cosine_dissimilarity"]: + dot_product = np.dot(current_embedding, future_embedding) + norms = np.linalg.norm(current_embedding) * np.linalg.norm( + future_embedding + ) + similarity = dot_product / norms + displacement = ( + 1 - similarity + if distance_metric == "cosine_dissimilarity" + else similarity + ) + else: # Euclidean metrics + diff_squared = np.sum((current_embedding - future_embedding) ** 2) + displacement = ( + diff_squared + if distance_metric == "euclidean_squared" + else np.sqrt(diff_squared) + ) + + displacement_per_tau[int(tau)].append(displacement) + + return dict(displacement_per_tau) def compute_displacement_statistics( From 091fabbbca5deb08461f60edb808572daeeb996e Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Tue, 7 Jan 2025 18:31:53 -0800 Subject: [PATCH 03/38] adding more paths from old code --- .../evaluation/ALFI_MSD_v2.py | 31 ++++++++++++++----- 1 file changed, 24 insertions(+), 7 deletions(-) diff --git a/applications/contrastive_phenotyping/evaluation/ALFI_MSD_v2.py 
b/applications/contrastive_phenotyping/evaluation/ALFI_MSD_v2.py index bc9324cb7..a2a293ed6 100644 --- a/applications/contrastive_phenotyping/evaluation/ALFI_MSD_v2.py +++ b/applications/contrastive_phenotyping/evaluation/ALFI_MSD_v2.py @@ -12,6 +12,10 @@ feature_paths = { "7 min interval": "/hpc/projects/organelle_phenotyping/ALFI_benchmarking/predictions_final/ALFI_opp_7mins.zarr", "21 min interval": "/hpc/projects/organelle_phenotyping/ALFI_benchmarking/predictions_final/ALFI_opp_21mins.zarr", + "28 min interval": "/hpc/projects/organelle_phenotyping/ALFI_benchmarking/predictions_final/ALFI_updated_28mins.zarr", + "56 min interval": "/hpc/projects/organelle_phenotyping/ALFI_benchmarking/predictions_final/ALFI_opp_56mins.zarr", + "Cell Aware": "/hpc/projects/organelle_phenotyping/ALFI_benchmarking/predictions_final/ALFI_opp_cellaware.zarr", + "Classical": "/hpc/projects/organelle_phenotyping/ALFI_benchmarking/predictions_final/ALFI_opp_classical.zarr", } # Different normalization strategies to test @@ -27,6 +31,10 @@ interval_colors = { "7 min interval": "blue", "21 min interval": "red", + "28 min interval": "green", + "56 min interval": "purple", + "Cell Aware": "orange", + "Classical": "gray", } # %% Compute MSD for each dataset @@ -72,26 +80,35 @@ mean_values = [means[tau] for tau in taus] std_values = [stds[tau] for tau in taus] - # Plot MSD with confidence band + # Plot MSD with confidence band and scatter points + # ax.fill_between( + # taus, + # np.array(mean_values) - np.array(std_values), + # np.array(mean_values) + np.array(std_values), + # alpha=0.2, + # color=interval_colors[interval_label], + # ) ax.plot( taus, mean_values, "-", color=interval_colors[interval_label], - label=f"{interval_label}", + alpha=0.5, + zorder=1, ) - ax.fill_between( + ax.scatter( taus, - np.array(mean_values) - np.array(std_values), - np.array(mean_values) + np.array(std_values), - alpha=0.3, + mean_values, color=interval_colors[interval_label], + s=20, + label=f"{interval_label}", + zorder=2, ) ax.set_xlabel("Time Shift (τ)") ax.set_ylabel("Mean Square Displacement") ax.set_title(f"MSD vs Time Shift\n({labels[norm]})") - ax.grid(True) + ax.grid(True, alpha=0.3) ax.legend() plt.tight_layout() From 823c971c840422e5d469c390930b60cccd5c8d64 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Tue, 7 Jan 2025 21:14:18 -0800 Subject: [PATCH 04/38] add log log plot and the slope --- .../evaluation/ALFI_MSD_v2.py | 171 +++++++++++++++++- 1 file changed, 162 insertions(+), 9 deletions(-) diff --git a/applications/contrastive_phenotyping/evaluation/ALFI_MSD_v2.py b/applications/contrastive_phenotyping/evaluation/ALFI_MSD_v2.py index a2a293ed6..6053b7245 100644 --- a/applications/contrastive_phenotyping/evaluation/ALFI_MSD_v2.py +++ b/applications/contrastive_phenotyping/evaluation/ALFI_MSD_v2.py @@ -63,7 +63,7 @@ print(f" τ range: {min(taus)} to {max(taus)}") print(f" MSD at τ=1: {means[1]:.4f} ± {stds[1]:.4f}") -# %% Plot MSD vs time - one plot per normalization strategy +# %% Plot MSD vs time - one plot per normalization strategy (linear scale) fig, axes = plt.subplots(2, 2, figsize=(15, 12)) axes = axes.ravel() @@ -80,14 +80,6 @@ mean_values = [means[tau] for tau in taus] std_values = [stds[tau] for tau in taus] - # Plot MSD with confidence band and scatter points - # ax.fill_between( - # taus, - # np.array(mean_values) - np.array(std_values), - # np.array(mean_values) + np.array(std_values), - # alpha=0.2, - # color=interval_colors[interval_label], - # ) ax.plot( taus, mean_values, @@ -113,4 
+105,165 @@ plt.tight_layout() plt.show() + +# %% Plot MSD vs time - one plot per normalization strategy (log-log scale with slopes) +fig, axes = plt.subplots(2, 2, figsize=(15, 12)) +axes = axes.ravel() + +for i, norm in enumerate(norm_strategies): + ax = axes[i] + + # Plot each time interval for this normalization strategy + for interval_label, path in feature_paths.items(): + result_label = f"{interval_label} ({labels[norm]})" + means, stds = results[result_label] + + # Sort by tau for plotting + taus = sorted(means.keys()) + mean_values = [means[tau] for tau in taus] + std_values = [stds[tau] for tau in taus] + + # Filter out non-positive values for log scale + valid_mask = np.array(mean_values) > 0 + valid_taus = np.array(taus)[valid_mask] + valid_means = np.array(mean_values)[valid_mask] + + # Calculate slope using linear regression on log-log values + log_taus = np.log(valid_taus) + log_means = np.log(valid_means) + slope, intercept = np.polyfit(log_taus, log_means, 1) + + ax.plot( + valid_taus, + valid_means, + "-", + color=interval_colors[interval_label], + alpha=0.5, + zorder=1, + ) + ax.scatter( + valid_taus, + valid_means, + color=interval_colors[interval_label], + s=20, + label=f"{interval_label} (α={slope:.2f})", + zorder=2, + ) + + # Plot fitted line + fit_line = np.exp(intercept + slope * log_taus) + ax.plot( + valid_taus, + fit_line, + "--", + color=interval_colors[interval_label], + alpha=0.3, + zorder=1, + ) + + ax.set_xscale("log") + ax.set_yscale("log") + ax.set_xlabel("Time Shift (τ)") + ax.set_ylabel("Mean Square Displacement") + ax.set_title(f"MSD vs Time Shift (log-log)\n({labels[norm]})") + ax.grid(True, alpha=0.3, which="both") + ax.legend(title="α = slope in log-log space") + +plt.tight_layout() +plt.show() + +# %% Print detailed slope analysis +print("\nSlope Analysis (α):") +print("α = 1: Normal diffusion") +print("α < 1: Sub-diffusion (confined/hindered)") +print("α > 1: Super-diffusion (directed/active)\n") + +for norm in norm_strategies: + print(f"\n{labels[norm]}:") + for interval_label in feature_paths.keys(): + result_label = f"{interval_label} ({labels[norm]})" + means, _ = results[result_label] + + # Calculate slope + taus = np.array(sorted(means.keys())) + mean_values = np.array([means[tau] for tau in taus]) + valid_mask = mean_values > 0 + + if np.sum(valid_mask) > 1: + log_taus = np.log(taus[valid_mask]) + log_means = np.log(mean_values[valid_mask]) + slope, _ = np.polyfit(log_taus, log_means, 1) + + motion_type = ( + "normal diffusion" + if abs(slope - 1) < 0.1 + else "sub-diffusion" if slope < 1 else "super-diffusion" + ) + + print(f" {interval_label}: α = {slope:.2f} ({motion_type})") + +# %% Plot slopes analysis +slopes_data = [] +intervals = [] +norm_types = [] + +for norm in norm_strategies: + for interval_label in feature_paths.keys(): + result_label = f"{interval_label} ({labels[norm]})" + means, _ = results[result_label] + + # Calculate slope + taus = np.array(sorted(means.keys())) + mean_values = np.array([means[tau] for tau in taus]) + valid_mask = mean_values > 0 + + if np.sum(valid_mask) > 1: # Need at least 2 points for slope + log_taus = np.log(taus[valid_mask]) + log_means = np.log(mean_values[valid_mask]) + slope, _ = np.polyfit(log_taus, log_means, 1) + + slopes_data.append(slope) + intervals.append(interval_label) + norm_types.append(labels[norm]) + +# Create bar plot +plt.figure(figsize=(12, 6)) + +# Set up positions for grouped bars +unique_intervals = list(feature_paths.keys()) +unique_norms = [labels[n] for n in 
norm_strategies] +x = np.arange(len(unique_intervals)) +width = 0.8 / len(norm_strategies) # Width of bars + +for i, norm_label in enumerate(unique_norms): + mask = np.array(norm_types) == norm_label + norm_slopes = np.array(slopes_data)[mask] + norm_intervals = np.array(intervals)[mask] + + positions = x + (i - len(norm_strategies) / 2 + 0.5) * width + + plt.bar(positions, norm_slopes, width, label=norm_label, alpha=0.7) + +# Add reference lines +plt.axhline(y=1, color="k", linestyle="--", alpha=0.3, label="Normal diffusion (α=1)") +plt.axhline(y=0, color="k", linestyle="-", alpha=0.2) + +plt.xlabel("Time Interval") +plt.ylabel("Slope (α)") +plt.title("MSD Slopes by Time Interval and Normalization Strategy") +plt.xticks(x, unique_intervals, rotation=45) +plt.legend(title="Normalization", bbox_to_anchor=(1.05, 1), loc="upper left") + +# Add annotations for diffusion regimes +plt.text( + plt.xlim()[1] * 1.2, 1.5, "Super-diffusion", rotation=90, verticalalignment="center" +) +plt.text( + plt.xlim()[1] * 1.2, 0.5, "Sub-diffusion", rotation=90, verticalalignment="center" +) + +plt.grid(True, alpha=0.3) +plt.tight_layout() +plt.show() + # %% From 1ab24d2b780b46978642cfa6a20a100c42b4b23f Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Tue, 7 Jan 2025 21:50:37 -0800 Subject: [PATCH 05/38] making a dictionary to plot easier --- .../evaluation/ALFI_MSD_v2.py | 63 +++++-------------- 1 file changed, 16 insertions(+), 47 deletions(-) diff --git a/applications/contrastive_phenotyping/evaluation/ALFI_MSD_v2.py b/applications/contrastive_phenotyping/evaluation/ALFI_MSD_v2.py index 6053b7245..a13ce2c8d 100644 --- a/applications/contrastive_phenotyping/evaluation/ALFI_MSD_v2.py +++ b/applications/contrastive_phenotyping/evaluation/ALFI_MSD_v2.py @@ -18,9 +18,8 @@ "Classical": "/hpc/projects/organelle_phenotyping/ALFI_benchmarking/predictions_final/ALFI_opp_classical.zarr", } -# Different normalization strategies to test -norm_strategies = [None, "per_feature", "per_embedding", "per_dataset"] -labels = { +# Different normalization strategies and their labels +norm_strategies = { None: "Raw", "per_feature": "Per-feature z-score", "per_embedding": "Unit norm", @@ -45,7 +44,7 @@ print(f"\nProcessing {label}...") embedding_dataset = read_embedding_dataset(Path(path)) - for norm in norm_strategies: + for norm, norm_label in norm_strategies.items(): # Compute displacements with different normalization strategies displacements = compute_displacement( embedding_dataset=embedding_dataset, @@ -53,12 +52,12 @@ normalize=norm, ) means, stds = compute_displacement_statistics(displacements) - results[f"{label} ({labels[norm]})"] = (means, stds) - raw_displacements[f"{label} ({labels[norm]})"] = displacements + results[f"{label} ({norm_label})"] = (means, stds) + raw_displacements[f"{label} ({norm_label})"] = displacements # Print some statistics taus = sorted(means.keys()) - print(f"\n{labels[norm]}:") + print(f"\n{norm_label}:") print(f" Number of different τ values: {len(taus)}") print(f" τ range: {min(taus)} to {max(taus)}") print(f" MSD at τ=1: {means[1]:.4f} ± {stds[1]:.4f}") @@ -67,12 +66,12 @@ fig, axes = plt.subplots(2, 2, figsize=(15, 12)) axes = axes.ravel() -for i, norm in enumerate(norm_strategies): +for i, (norm, norm_label) in enumerate(norm_strategies.items()): ax = axes[i] # Plot each time interval for this normalization strategy for interval_label, path in feature_paths.items(): - result_label = f"{interval_label} ({labels[norm]})" + result_label = f"{interval_label} ({norm_label})" means, 
stds = results[result_label] # Sort by tau for plotting @@ -99,7 +98,7 @@ ax.set_xlabel("Time Shift (τ)") ax.set_ylabel("Mean Square Displacement") - ax.set_title(f"MSD vs Time Shift\n({labels[norm]})") + ax.set_title(f"MSD vs Time Shift\n({norm_label})") ax.grid(True, alpha=0.3) ax.legend() @@ -110,12 +109,12 @@ fig, axes = plt.subplots(2, 2, figsize=(15, 12)) axes = axes.ravel() -for i, norm in enumerate(norm_strategies): +for i, (norm, norm_label) in enumerate(norm_strategies.items()): ax = axes[i] # Plot each time interval for this normalization strategy for interval_label, path in feature_paths.items(): - result_label = f"{interval_label} ({labels[norm]})" + result_label = f"{interval_label} ({norm_label})" means, stds = results[result_label] # Sort by tau for plotting @@ -165,51 +164,21 @@ ax.set_yscale("log") ax.set_xlabel("Time Shift (τ)") ax.set_ylabel("Mean Square Displacement") - ax.set_title(f"MSD vs Time Shift (log-log)\n({labels[norm]})") + ax.set_title(f"MSD vs Time Shift (log-log)\n({norm_label})") ax.grid(True, alpha=0.3, which="both") ax.legend(title="α = slope in log-log space") plt.tight_layout() plt.show() -# %% Print detailed slope analysis -print("\nSlope Analysis (α):") -print("α = 1: Normal diffusion") -print("α < 1: Sub-diffusion (confined/hindered)") -print("α > 1: Super-diffusion (directed/active)\n") - -for norm in norm_strategies: - print(f"\n{labels[norm]}:") - for interval_label in feature_paths.keys(): - result_label = f"{interval_label} ({labels[norm]})" - means, _ = results[result_label] - - # Calculate slope - taus = np.array(sorted(means.keys())) - mean_values = np.array([means[tau] for tau in taus]) - valid_mask = mean_values > 0 - - if np.sum(valid_mask) > 1: - log_taus = np.log(taus[valid_mask]) - log_means = np.log(mean_values[valid_mask]) - slope, _ = np.polyfit(log_taus, log_means, 1) - - motion_type = ( - "normal diffusion" - if abs(slope - 1) < 0.1 - else "sub-diffusion" if slope < 1 else "super-diffusion" - ) - - print(f" {interval_label}: α = {slope:.2f} ({motion_type})") - # %% Plot slopes analysis slopes_data = [] intervals = [] norm_types = [] -for norm in norm_strategies: +for norm, norm_label in norm_strategies.items(): for interval_label in feature_paths.keys(): - result_label = f"{interval_label} ({labels[norm]})" + result_label = f"{interval_label} ({norm_label})" means, _ = results[result_label] # Calculate slope @@ -224,14 +193,14 @@ slopes_data.append(slope) intervals.append(interval_label) - norm_types.append(labels[norm]) + norm_types.append(norm_label) # Create bar plot plt.figure(figsize=(12, 6)) # Set up positions for grouped bars unique_intervals = list(feature_paths.keys()) -unique_norms = [labels[n] for n in norm_strategies] +unique_norms = list(norm_strategies.values()) x = np.arange(len(unique_intervals)) width = 0.8 / len(norm_strategies) # Width of bars From 310688e61d110661cd1fdc41149a4fcc1aeceec7 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Wed, 8 Jan 2025 20:19:45 -0800 Subject: [PATCH 06/38] simplify MSD. 
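

The diff below keeps the track-based accumulation of displacements: embeddings are grouped by (fov_name, track_id), each track is sorted by time, and every pair of timepoints separated by delta_t contributes one squared Euclidean distance, which is then reduced to a mean and standard deviation per delta_t. For reference, here is a minimal, self-contained sketch of that idea on plain NumPy arrays; the function and variable names are illustrative only and are not the viscy API, and the actual change is the diff that follows.

# Minimal MSD sketch, assuming parallel NumPy arrays: fov_names and track_ids
# of shape (N,), integer timepoints of shape (N,), and embeddings of shape (N, D).
# Names are hypothetical, not the repository's API.
from collections import defaultdict
import numpy as np


def msd_per_delta_t(fov_names, track_ids, timepoints, embeddings):
    displacements = defaultdict(list)
    for fov, track in set(zip(fov_names, track_ids)):
        # select and time-sort one track
        mask = (fov_names == fov) & (track_ids == track)
        order = np.argsort(timepoints[mask])
        times = timepoints[mask][order]
        feats = embeddings[mask][order]
        # every ordered pair (i, j) with j > i contributes to delta_t = t_j - t_i
        for i in range(len(times)):
            for j in range(i + 1, len(times)):
                delta_t = int(times[j] - times[i])
                displacements[delta_t].append(float(np.sum((feats[j] - feats[i]) ** 2)))
    means = {dt: float(np.mean(v)) for dt, v in displacements.items()}
    stds = {dt: float(np.std(v)) for dt, v in displacements.items()}
    return means, stds


# toy usage: one random-walk-like track of 10 timepoints in 8 dimensions
rng = np.random.default_rng(0)
means, stds = msd_per_delta_t(
    np.array(["A"] * 10),
    np.array([1] * 10),
    np.arange(10),
    rng.normal(size=(10, 8)).cumsum(axis=0),
)
print({dt: round(m, 3) for dt, m in sorted(means.items())})
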
--- .../evaluation/ALFI_MSD_v2.py | 321 +++++++++--------- viscy/representation/evaluation/distance.py | 86 +---- 2 files changed, 164 insertions(+), 243 deletions(-) diff --git a/applications/contrastive_phenotyping/evaluation/ALFI_MSD_v2.py b/applications/contrastive_phenotyping/evaluation/ALFI_MSD_v2.py index a13ce2c8d..00e9effce 100644 --- a/applications/contrastive_phenotyping/evaluation/ALFI_MSD_v2.py +++ b/applications/contrastive_phenotyping/evaluation/ALFI_MSD_v2.py @@ -18,14 +18,6 @@ "Classical": "/hpc/projects/organelle_phenotyping/ALFI_benchmarking/predictions_final/ALFI_opp_classical.zarr", } -# Different normalization strategies and their labels -norm_strategies = { - None: "Raw", - "per_feature": "Per-feature z-score", - "per_embedding": "Unit norm", - "per_dataset": "Dataset z-score", -} - # Colors for different time intervals interval_colors = { "7 min interval": "blue", @@ -44,174 +36,179 @@ print(f"\nProcessing {label}...") embedding_dataset = read_embedding_dataset(Path(path)) - for norm, norm_label in norm_strategies.items(): - # Compute displacements with different normalization strategies - displacements = compute_displacement( - embedding_dataset=embedding_dataset, - distance_metric="euclidean_squared", - normalize=norm, - ) - means, stds = compute_displacement_statistics(displacements) - results[f"{label} ({norm_label})"] = (means, stds) - raw_displacements[f"{label} ({norm_label})"] = displacements - - # Print some statistics - taus = sorted(means.keys()) - print(f"\n{norm_label}:") - print(f" Number of different τ values: {len(taus)}") - print(f" τ range: {min(taus)} to {max(taus)}") - print(f" MSD at τ=1: {means[1]:.4f} ± {stds[1]:.4f}") - -# %% Plot MSD vs time - one plot per normalization strategy (linear scale) -fig, axes = plt.subplots(2, 2, figsize=(15, 12)) -axes = axes.ravel() - -for i, (norm, norm_label) in enumerate(norm_strategies.items()): - ax = axes[i] - - # Plot each time interval for this normalization strategy - for interval_label, path in feature_paths.items(): - result_label = f"{interval_label} ({norm_label})" - means, stds = results[result_label] - - # Sort by tau for plotting - taus = sorted(means.keys()) - mean_values = [means[tau] for tau in taus] - std_values = [stds[tau] for tau in taus] - - ax.plot( - taus, - mean_values, - "-", - color=interval_colors[interval_label], - alpha=0.5, - zorder=1, - ) - ax.scatter( - taus, - mean_values, - color=interval_colors[interval_label], - s=20, - label=f"{interval_label}", - zorder=2, - ) - - ax.set_xlabel("Time Shift (τ)") - ax.set_ylabel("Mean Square Displacement") - ax.set_title(f"MSD vs Time Shift\n({norm_label})") - ax.grid(True, alpha=0.3) - ax.legend() - + # Compute displacements + displacements = compute_displacement( + embedding_dataset=embedding_dataset, + distance_metric="euclidean_squared", + ) + means, stds = compute_displacement_statistics(displacements) + results[label] = (means, stds) + raw_displacements[label] = displacements + + # Print some statistics + taus = sorted(means.keys()) + print(f" Number of different τ values: {len(taus)}") + print(f" τ range: {min(taus)} to {max(taus)}") + print(f" MSD at τ=1: {means[1]:.4f} ± {stds[1]:.4f}") + +# %% Plot MSD vs time (linear scale) +plt.figure(figsize=(10, 6)) + +# Plot each time interval +for interval_label, path in feature_paths.items(): + means, stds = results[interval_label] + + # Sort by tau for plotting + taus = sorted(means.keys()) + mean_values = [means[tau] for tau in taus] + std_values = [stds[tau] for tau in taus] + + 
plt.plot( + taus, + mean_values, + "-", + color=interval_colors[interval_label], + alpha=0.5, + zorder=1, + ) + plt.scatter( + taus, + mean_values, + color=interval_colors[interval_label], + s=20, + label=interval_label, + zorder=2, + ) + +plt.xlabel("Time Shift (τ)") +plt.ylabel("Mean Square Displacement") +plt.title("MSD vs Time Shift") +plt.grid(True, alpha=0.3) +plt.legend() plt.tight_layout() plt.show() -# %% Plot MSD vs time - one plot per normalization strategy (log-log scale with slopes) -fig, axes = plt.subplots(2, 2, figsize=(15, 12)) -axes = axes.ravel() - -for i, (norm, norm_label) in enumerate(norm_strategies.items()): - ax = axes[i] - - # Plot each time interval for this normalization strategy - for interval_label, path in feature_paths.items(): - result_label = f"{interval_label} ({norm_label})" - means, stds = results[result_label] - - # Sort by tau for plotting - taus = sorted(means.keys()) - mean_values = [means[tau] for tau in taus] - std_values = [stds[tau] for tau in taus] - - # Filter out non-positive values for log scale - valid_mask = np.array(mean_values) > 0 - valid_taus = np.array(taus)[valid_mask] - valid_means = np.array(mean_values)[valid_mask] - - # Calculate slope using linear regression on log-log values - log_taus = np.log(valid_taus) - log_means = np.log(valid_means) - slope, intercept = np.polyfit(log_taus, log_means, 1) - - ax.plot( - valid_taus, - valid_means, - "-", - color=interval_colors[interval_label], - alpha=0.5, - zorder=1, - ) - ax.scatter( - valid_taus, - valid_means, - color=interval_colors[interval_label], - s=20, - label=f"{interval_label} (α={slope:.2f})", - zorder=2, - ) - - # Plot fitted line - fit_line = np.exp(intercept + slope * log_taus) - ax.plot( - valid_taus, - fit_line, - "--", - color=interval_colors[interval_label], - alpha=0.3, - zorder=1, - ) - - ax.set_xscale("log") - ax.set_yscale("log") - ax.set_xlabel("Time Shift (τ)") - ax.set_ylabel("Mean Square Displacement") - ax.set_title(f"MSD vs Time Shift (log-log)\n({norm_label})") - ax.grid(True, alpha=0.3, which="both") - ax.legend(title="α = slope in log-log space") - +# %% Plot MSD vs time (log-log scale with slopes) +plt.figure(figsize=(10, 6)) + +# Plot each time interval +for interval_label, path in feature_paths.items(): + means, stds = results[interval_label] + + # Sort by tau for plotting + taus = sorted(means.keys()) + mean_values = [means[tau] for tau in taus] + std_values = [stds[tau] for tau in taus] + + # Filter out non-positive values for log scale + valid_mask = np.array(mean_values) > 0 + valid_taus = np.array(taus)[valid_mask] + valid_means = np.array(mean_values)[valid_mask] + + # Calculate slopes for different regions + log_taus = np.log(valid_taus) + log_means = np.log(valid_means) + + # Early slope (first third of points) + n_points = len(log_taus) + early_end = n_points // 3 + early_slope, early_intercept = np.polyfit( + log_taus[:early_end], log_means[:early_end], 1 + ) + + # Late slope (last third of points) + late_start = 2 * (n_points // 3) + late_slope, late_intercept = np.polyfit( + log_taus[late_start:], log_means[late_start:], 1 + ) + + plt.plot( + valid_taus, + valid_means, + "-", + color=interval_colors[interval_label], + alpha=0.5, + zorder=1, + ) + plt.scatter( + valid_taus, + valid_means, + color=interval_colors[interval_label], + s=20, + label=f"{interval_label} (α_early={early_slope:.2f}, α_late={late_slope:.2f})", + zorder=2, + ) + + # Plot fitted lines for early and late regions + early_fit = np.exp(early_intercept + early_slope * 
log_taus[:early_end]) + late_fit = np.exp(late_intercept + late_slope * log_taus[late_start:]) + + plt.plot( + valid_taus[:early_end], + early_fit, + "--", + color=interval_colors[interval_label], + alpha=0.3, + zorder=1, + ) + plt.plot( + valid_taus[late_start:], + late_fit, + "--", + color=interval_colors[interval_label], + alpha=0.3, + zorder=1, + ) + +plt.xscale("log") +plt.yscale("log") +plt.xlabel("Time Shift (τ)") +plt.ylabel("Mean Square Displacement") +plt.title("MSD vs Time Shift (log-log)") +plt.grid(True, alpha=0.3, which="both") +plt.legend( + title="α = slope in log-log space", bbox_to_anchor=(1.05, 1), loc="upper left" +) plt.tight_layout() plt.show() # %% Plot slopes analysis -slopes_data = [] +early_slopes = [] +late_slopes = [] intervals = [] -norm_types = [] -for norm, norm_label in norm_strategies.items(): - for interval_label in feature_paths.keys(): - result_label = f"{interval_label} ({norm_label})" - means, _ = results[result_label] +for interval_label in feature_paths.keys(): + means, _ = results[interval_label] - # Calculate slope - taus = np.array(sorted(means.keys())) - mean_values = np.array([means[tau] for tau in taus]) - valid_mask = mean_values > 0 + # Calculate slopes + taus = np.array(sorted(means.keys())) + mean_values = np.array([means[tau] for tau in taus]) + valid_mask = mean_values > 0 - if np.sum(valid_mask) > 1: # Need at least 2 points for slope - log_taus = np.log(taus[valid_mask]) - log_means = np.log(mean_values[valid_mask]) - slope, _ = np.polyfit(log_taus, log_means, 1) + if np.sum(valid_mask) > 3: # Need at least 4 points to calculate both slopes + log_taus = np.log(taus[valid_mask]) + log_means = np.log(mean_values[valid_mask]) - slopes_data.append(slope) - intervals.append(interval_label) - norm_types.append(norm_label) + # Calculate early and late slopes + n_points = len(log_taus) + early_end = n_points // 3 + late_start = 2 * (n_points // 3) -# Create bar plot -plt.figure(figsize=(12, 6)) + early_slope, _ = np.polyfit(log_taus[:early_end], log_means[:early_end], 1) + late_slope, _ = np.polyfit(log_taus[late_start:], log_means[late_start:], 1) -# Set up positions for grouped bars -unique_intervals = list(feature_paths.keys()) -unique_norms = list(norm_strategies.values()) -x = np.arange(len(unique_intervals)) -width = 0.8 / len(norm_strategies) # Width of bars + early_slopes.append(early_slope) + late_slopes.append(late_slope) + intervals.append(interval_label) -for i, norm_label in enumerate(unique_norms): - mask = np.array(norm_types) == norm_label - norm_slopes = np.array(slopes_data)[mask] - norm_intervals = np.array(intervals)[mask] +# Create bar plot +plt.figure(figsize=(12, 6)) - positions = x + (i - len(norm_strategies) / 2 + 0.5) * width +x = np.arange(len(intervals)) +width = 0.35 - plt.bar(positions, norm_slopes, width, label=norm_label, alpha=0.7) +plt.bar(x - width / 2, early_slopes, width, label="Early slope", alpha=0.7) +plt.bar(x + width / 2, late_slopes, width, label="Late slope", alpha=0.7) # Add reference lines plt.axhline(y=1, color="k", linestyle="--", alpha=0.3, label="Normal diffusion (α=1)") @@ -219,9 +216,9 @@ plt.xlabel("Time Interval") plt.ylabel("Slope (α)") -plt.title("MSD Slopes by Time Interval and Normalization Strategy") -plt.xticks(x, unique_intervals, rotation=45) -plt.legend(title="Normalization", bbox_to_anchor=(1.05, 1), loc="upper left") +plt.title("MSD Slopes by Time Interval") +plt.xticks(x, intervals, rotation=45) +plt.legend() # Add annotations for diffusion regimes plt.text( diff --git 
a/viscy/representation/evaluation/distance.py b/viscy/representation/evaluation/distance.py index 51caecc87..35977c9ce 100644 --- a/viscy/representation/evaluation/distance.py +++ b/viscy/representation/evaluation/distance.py @@ -21,54 +21,14 @@ def calculate_cosine_similarity_cell(embedding_dataset, fov_name, track_id): return time_points, cosine_similarities.tolist() -def normalize_embeddings( - embeddings: np.ndarray, - strategy: Literal["per_feature", "per_embedding", "per_dataset"] = "per_feature", -) -> np.ndarray: - """Normalize embeddings using different strategies. - - Parameters - ---------- - embeddings : np.ndarray - Array of shape (n_samples, n_features) containing embeddings - strategy : str - Normalization strategy: - - "per_feature": z-score each feature across all samples - - "per_embedding": normalize each embedding vector to unit norm - - "per_dataset": z-score entire dataset (across all features and samples) - - Returns - ------- - np.ndarray - Normalized embeddings with same shape as input - """ - if strategy == "per_feature": - # Normalize each feature independently - return (embeddings - np.mean(embeddings, axis=0)) / np.std(embeddings, axis=0) - elif strategy == "per_embedding": - # Normalize each embedding to unit norm - norms = np.linalg.norm(embeddings, axis=1, keepdims=True) - return embeddings / norms - elif strategy == "per_dataset": - # Normalize entire dataset - return (embeddings - np.mean(embeddings)) / np.std(embeddings) - else: - raise ValueError(f"Unknown normalization strategy: {strategy}") - - def compute_displacement( embedding_dataset, - distance_metric: Literal[ - "euclidean", "euclidean_squared", "cosine", "cosine_dissimilarity" - ] = "euclidean_squared", - embedding_coords: Literal["UMAP", "PHATE", None] = None, - normalize: Literal["per_feature", "per_embedding", "per_dataset", None] = None, + distance_metric: Literal["euclidean_squared", "cosine"] = "euclidean_squared", ) -> Dict[int, List[float]]: """Compute the displacement or mean square displacement (MSD) of embeddings. For each time difference τ, computes either: - |r(t + τ) - r(t)|² for squared Euclidean (MSD) - - |r(t + τ) - r(t)| for Euclidean - cos_sim(r(t + τ), r(t)) for cosine for all particles and initial times t. 
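

The alpha values reported by ALFI_MSD_v2.py come from a straight-line fit of log MSD against log delta_t (np.polyfit on the logged values). The short sketch below reproduces that fit on a synthetic MSD dictionary so the slope interpretation can be checked in isolation; the dictionary and names are made up for illustration and stand in for the output of compute_displacement_statistics.

# Sketch: estimate the diffusion exponent alpha from MSD(delta_t) with a
# log-log linear fit. The msd dict is synthetic (sub-diffusive, alpha = 0.8).
import numpy as np

msd = {dt: 0.5 * dt ** 0.8 for dt in range(1, 51)}

delta_ts = np.array(sorted(msd))
values = np.array([msd[dt] for dt in delta_ts])
valid = values > 0  # a log-log fit only makes sense for positive MSD values

alpha, _ = np.polyfit(np.log(delta_ts[valid]), np.log(values[valid]), 1)
print(f"alpha = {alpha:.2f}")  # ~0.80 here; alpha near 1 indicates normal diffusion
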
@@ -83,17 +43,6 @@ def compute_displacement( - "euclidean_squared": Squared Euclidean distance (for MSD, default) - "cosine": Cosine similarity - "cosine_dissimilarity": 1 - cosine similarity - embedding_coords : str or None - Which embedding coordinates to use for distance computation: - - None: Use original features from dataset (default) - - "UMAP": Use UMAP coordinates (UMAP1, UMAP2) - - "PHATE": Use PHATE coordinates (PHATE1, PHATE2) - normalize : str or None - Normalization strategy to apply to embeddings before computing distances: - - None: No normalization (default) - - "per_feature": z-score each feature across all samples - - "per_embedding": normalize each embedding vector to unit norm - - "per_dataset": z-score entire dataset (across all features and samples) Returns ------- @@ -104,23 +53,7 @@ def compute_displacement( fov_names = embedding_dataset["fov_name"].values track_ids = embedding_dataset["track_id"].values timepoints = embedding_dataset["t"].values - - # Get embeddings based on specified coordinates - if embedding_coords == "UMAP": - embeddings = np.vstack( - (embedding_dataset["UMAP1"].values, embedding_dataset["UMAP2"].values) - ).T - elif embedding_coords == "PHATE": - embeddings = np.vstack( - (embedding_dataset["PHATE1"].values, embedding_dataset["PHATE2"].values) - ).T - else: - embeddings = embedding_dataset["features"].values - - # Normalize embeddings if requested - if normalize is not None: - embeddings = normalize_embeddings(embeddings, strategy=normalize) - + embeddings = embedding_dataset["features"].values # Get unique tracks unique_tracks = set(zip(fov_names, track_ids)) @@ -150,25 +83,16 @@ def compute_displacement( tau = future_time - t future_embedding = track_embeddings[future_idx] - if distance_metric in ["cosine", "cosine_dissimilarity"]: + if distance_metric in ["cosine"]: dot_product = np.dot(current_embedding, future_embedding) norms = np.linalg.norm(current_embedding) * np.linalg.norm( future_embedding ) similarity = dot_product / norms - displacement = ( - 1 - similarity - if distance_metric == "cosine_dissimilarity" - else similarity - ) + displacement = similarity else: # Euclidean metrics diff_squared = np.sum((current_embedding - future_embedding) ** 2) - displacement = ( - diff_squared - if distance_metric == "euclidean_squared" - else np.sqrt(diff_squared) - ) - + displacement = diff_squared displacement_per_tau[int(tau)].append(displacement) return dict(displacement_per_tau) From 2e8d2431a4eb0b8b66c0c9768f9458ba22ca2c67 Mon Sep 17 00:00:00 2001 From: Soorya Pradeep Date: Thu, 9 Jan 2025 12:25:39 -0800 Subject: [PATCH 07/38] add scaling and change to log plot --- .../evaluation/ALFI_MSD_v2.py | 104 ++++++++---------- 1 file changed, 48 insertions(+), 56 deletions(-) diff --git a/applications/contrastive_phenotyping/evaluation/ALFI_MSD_v2.py b/applications/contrastive_phenotyping/evaluation/ALFI_MSD_v2.py index 00e9effce..18f537759 100644 --- a/applications/contrastive_phenotyping/evaluation/ALFI_MSD_v2.py +++ b/applications/contrastive_phenotyping/evaluation/ALFI_MSD_v2.py @@ -10,21 +10,21 @@ # Paths to datasets feature_paths = { - "7 min interval": "/hpc/projects/organelle_phenotyping/ALFI_benchmarking/predictions_final/ALFI_opp_7mins.zarr", - "21 min interval": "/hpc/projects/organelle_phenotyping/ALFI_benchmarking/predictions_final/ALFI_opp_21mins.zarr", - "28 min interval": "/hpc/projects/organelle_phenotyping/ALFI_benchmarking/predictions_final/ALFI_updated_28mins.zarr", - "56 min interval": 
"/hpc/projects/organelle_phenotyping/ALFI_benchmarking/predictions_final/ALFI_opp_56mins.zarr", - "Cell Aware": "/hpc/projects/organelle_phenotyping/ALFI_benchmarking/predictions_final/ALFI_opp_cellaware.zarr", - "Classical": "/hpc/projects/organelle_phenotyping/ALFI_benchmarking/predictions_final/ALFI_opp_classical.zarr", + "7 min interval": "/hpc/projects/organelle_phenotyping/ALFI_ntxent_loss/logs_alfi_ntxent_time_intervals/predictions/ALFI_7mins.zarr", + "14 min interval": "/hpc/projects/organelle_phenotyping/ALFI_ntxent_loss/logs_alfi_ntxent_time_intervals/predictions/ALFI_14mins.zarr", + "28 min interval": "/hpc/projects/organelle_phenotyping/ALFI_ntxent_loss/logs_alfi_ntxent_time_intervals/predictions/ALFI_28mins.zarr", + "56 min interval": "/hpc/projects/organelle_phenotyping/ALFI_ntxent_loss/logs_alfi_ntxent_time_intervals/predictions/ALFI_56mins.zarr", + "91 min interval": "/hpc/projects/organelle_phenotyping/ALFI_ntxent_loss/logs_alfi_ntxent_time_intervals/predictions/ALFI_91mins.zarr", + "Classical": "/hpc/projects/organelle_phenotyping/ALFI_ntxent_loss/logs_alfi_ntxent_time_intervals/predictions/ALFI_classical.zarr", } # Colors for different time intervals interval_colors = { "7 min interval": "blue", - "21 min interval": "red", + "14 min interval": "red", "28 min interval": "green", "56 min interval": "purple", - "Cell Aware": "orange", + "91 min interval": "orange", "Classical": "gray", } @@ -89,6 +89,7 @@ plt.show() # %% Plot MSD vs time (log-log scale with slopes) +n_dimensions = 768 plt.figure(figsize=(10, 6)) # Plot each time interval @@ -115,50 +116,52 @@ early_slope, early_intercept = np.polyfit( log_taus[:early_end], log_means[:early_end], 1 ) + early_slope /= 2 * n_dimensions - # Late slope (last third of points) + # middle slope (mid portions of points) late_start = 2 * (n_points // 3) - late_slope, late_intercept = np.polyfit( - log_taus[late_start:], log_means[late_start:], 1 + mid_slope, mid_intercept = np.polyfit( + log_taus[early_end:late_start], log_means[early_end:late_start], 1 ) + mid_slope /= 2 * n_dimensions plt.plot( - valid_taus, - valid_means, + log_taus, + log_means, "-", color=interval_colors[interval_label], alpha=0.5, zorder=1, ) plt.scatter( - valid_taus, - valid_means, + log_taus, + log_means, color=interval_colors[interval_label], s=20, - label=f"{interval_label} (α_early={early_slope:.2f}, α_late={late_slope:.2f})", + label=f"{interval_label} (α_early={early_slope:.2f}, α_mid={mid_slope:.2f})", zorder=2, ) - # Plot fitted lines for early and late regions - early_fit = np.exp(early_intercept + early_slope * log_taus[:early_end]) - late_fit = np.exp(late_intercept + late_slope * log_taus[late_start:]) - - plt.plot( - valid_taus[:early_end], - early_fit, - "--", - color=interval_colors[interval_label], - alpha=0.3, - zorder=1, - ) - plt.plot( - valid_taus[late_start:], - late_fit, - "--", - color=interval_colors[interval_label], - alpha=0.3, - zorder=1, - ) + # # Plot fitted lines for early and late regions + # early_fit = np.exp(early_intercept + early_slope * log_taus[:early_end]) + # mid_fit = np.exp(mid_intercept + mid_slope * log_taus[early_end:late_start]) + + # plt.plot( + # early_fit, + # log_taus[:early_end], + # "--", + # color=interval_colors[interval_label], + # alpha=0.3, + # zorder=1, + # ) + # plt.plot( + # mid_fit, + # log_taus[early_end:late_start], + # "--", + # color=interval_colors[interval_label], + # alpha=0.3, + # zorder=1, + # ) plt.xscale("log") plt.yscale("log") @@ -174,7 +177,7 @@ # %% Plot slopes analysis early_slopes = 
[] -late_slopes = [] +mid_slopes = [] intervals = [] for interval_label in feature_paths.keys(): @@ -189,16 +192,16 @@ log_taus = np.log(taus[valid_mask]) log_means = np.log(mean_values[valid_mask]) - # Calculate early and late slopes + # Calculate early and mid slopes n_points = len(log_taus) early_end = n_points // 3 late_start = 2 * (n_points // 3) early_slope, _ = np.polyfit(log_taus[:early_end], log_means[:early_end], 1) - late_slope, _ = np.polyfit(log_taus[late_start:], log_means[late_start:], 1) + mid_slope, _ = np.polyfit(log_taus[early_end:late_start], log_means[early_end:late_start], 1) - early_slopes.append(early_slope) - late_slopes.append(late_slope) + early_slopes.append(early_slope/(2*n_dimensions)) + mid_slopes.append(mid_slope/(2*n_dimensions)) intervals.append(interval_label) # Create bar plot @@ -208,28 +211,17 @@ width = 0.35 plt.bar(x - width / 2, early_slopes, width, label="Early slope", alpha=0.7) -plt.bar(x + width / 2, late_slopes, width, label="Late slope", alpha=0.7) +plt.bar(x + width / 2, mid_slopes, width, label="Mid slope", alpha=0.7) -# Add reference lines -plt.axhline(y=1, color="k", linestyle="--", alpha=0.3, label="Normal diffusion (α=1)") -plt.axhline(y=0, color="k", linestyle="-", alpha=0.2) +# # Add reference lines +# plt.axhline(y=0.001, color="k", linestyle="--", alpha=0.3, label="Normal diffusion (α=1)") +# plt.axhline(y=0, color="k", linestyle="-", alpha=0.2) plt.xlabel("Time Interval") plt.ylabel("Slope (α)") plt.title("MSD Slopes by Time Interval") plt.xticks(x, intervals, rotation=45) plt.legend() - -# Add annotations for diffusion regimes -plt.text( - plt.xlim()[1] * 1.2, 1.5, "Super-diffusion", rotation=90, verticalalignment="center" -) -plt.text( - plt.xlim()[1] * 1.2, 0.5, "Sub-diffusion", rotation=90, verticalalignment="center" -) - -plt.grid(True, alpha=0.3) -plt.tight_layout() plt.show() # %% From a63fe23fe528ce63a524293e5cfb0dbd38493a4a Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Thu, 9 Jan 2025 13:37:30 -0800 Subject: [PATCH 08/38] reverting and restricting to only euclidean --- viscy/representation/evaluation/distance.py | 127 +++++++++----------- 1 file changed, 60 insertions(+), 67 deletions(-) diff --git a/viscy/representation/evaluation/distance.py b/viscy/representation/evaluation/distance.py index 35977c9ce..b7b563854 100644 --- a/viscy/representation/evaluation/distance.py +++ b/viscy/representation/evaluation/distance.py @@ -1,8 +1,9 @@ from collections import defaultdict -from typing import Dict, List, Literal, Tuple, Union, Optional +from typing import Dict, List, Literal, Tuple import numpy as np from sklearn.metrics.pairwise import cosine_similarity +from tqdm import tqdm def calculate_cosine_similarity_cell(embedding_dataset, fov_name, track_id): @@ -24,88 +25,80 @@ def calculate_cosine_similarity_cell(embedding_dataset, fov_name, track_id): def compute_displacement( embedding_dataset, distance_metric: Literal["euclidean_squared", "cosine"] = "euclidean_squared", + max_delta_t: int = None, ) -> Dict[int, List[float]]: - """Compute the displacement or mean square displacement (MSD) of embeddings. + """Compute displacements between embeddings at different time differences. - For each time difference τ, computes either: - - |r(t + τ) - r(t)|² for squared Euclidean (MSD) - - cos_sim(r(t + τ), r(t)) for cosine - for all particles and initial times t. + For each time difference τ, computes distances between embeddings of the same cell + separated by τ timepoints. Supports multiple distance metrics. 
Parameters ---------- embedding_dataset : xarray.Dataset - Dataset containing embeddings and metadata - distance_metric : str + Dataset containing embeddings and metadata with the following variables: + - features: (N, D) array of embeddings + - fov_name: (N,) array of field of view names + - track_id: (N,) array of cell track IDs + - t: (N,) array of timepoints + distance_metric : str, optional The metric to use for computing distances between embeddings. Valid options are: - - "euclidean": Euclidean distance (L2 norm) - - "euclidean_squared": Squared Euclidean distance (for MSD, default) + - "euclidean_squared": Squared Euclidean distance (default) - "cosine": Cosine similarity - - "cosine_dissimilarity": 1 - cosine similarity + max_delta_t : int, optional + Maximum time difference τ to compute displacements for. + If None, uses the maximum possible time difference in the dataset. Returns ------- Dict[int, List[float]] - Dictionary mapping τ to list of displacements for all particles and initial times + Dictionary mapping time difference τ to list of displacements. + Each displacement value represents the distance between a pair of + embeddings from the same cell separated by τ timepoints. """ + # Get data from dataset fov_names = embedding_dataset["fov_name"].values track_ids = embedding_dataset["track_id"].values timepoints = embedding_dataset["t"].values embeddings = embedding_dataset["features"].values - # Get unique tracks - unique_tracks = set(zip(fov_names, track_ids)) - - # Initialize results dictionary with empty lists - displacement_per_tau = defaultdict(list) - - # Process each track - for fov_name, track_id in unique_tracks: - # Get sorted track data - mask = (fov_names == fov_name) & (track_ids == track_id) - times = timepoints[mask] - track_embeddings = embeddings[mask] - - # Sort by time - time_order = np.argsort(times) - times = times[time_order] - track_embeddings = track_embeddings[time_order] - - # Process each time point - for t_idx, t in enumerate(times[:-1]): - current_embedding = track_embeddings[t_idx] - - # Check all possible future time points - for future_idx, future_time in enumerate( - times[t_idx + 1 :], start=t_idx + 1 - ): - tau = future_time - t - future_embedding = track_embeddings[future_idx] - - if distance_metric in ["cosine"]: - dot_product = np.dot(current_embedding, future_embedding) - norms = np.linalg.norm(current_embedding) * np.linalg.norm( - future_embedding - ) - similarity = dot_product / norms - displacement = similarity - else: # Euclidean metrics - diff_squared = np.sum((current_embedding - future_embedding) ** 2) - displacement = diff_squared - displacement_per_tau[int(tau)].append(displacement) - - return dict(displacement_per_tau) + + # Check if max_delta_t is provided, otherwise use the maximum timepoint + if max_delta_t is None: + max_delta_t = timepoints.max() + + displacement_per_delta_t = defaultdict(list) + # Process each sample + for i in tqdm(range(len(fov_names)), desc="Processing FOVs"): + fov_name = fov_names[i] + track_id = track_ids[i] + current_time = timepoints[i] + current_embedding = embeddings[i].reshape(1, -1) + + # Compute displacements for each delta t + for delta_t in range(1, max_delta_t + 1): + future_time = current_time + delta_t + matching_indices = np.where( + (fov_names == fov_name) + & (track_ids == track_id) + & (timepoints == future_time) + )[0] + + if len(matching_indices) == 1: + future_embedding = embeddings[matching_indices[0]].reshape(1, -1) + displacement = np.sum((current_embedding - 
future_embedding) ** 2) + displacement_per_delta_t[delta_t].append(displacement) + return dict(displacement_per_delta_t) def compute_displacement_statistics( - displacement_per_tau: Dict[int, List[float]] + displacement_per_delta_t: Dict[int, List[float]] ) -> Tuple[Dict[int, float], Dict[int, float]]: - """Compute mean and standard deviation of displacements for each tau. + """Compute mean and standard deviation of displacements for each delta_t. Parameters ---------- - displacement_per_tau : Dict[int, List[float]] + displacement_per_delta_t : Dict[int, List[float]] Dictionary mapping τ to list of displacements Returns @@ -114,29 +107,29 @@ def compute_displacement_statistics( Tuple of (mean_displacements, std_displacements) where each is a dictionary mapping τ to the statistic """ - mean_displacement_per_tau = { - tau: np.mean(displacements) - for tau, displacements in displacement_per_tau.items() + mean_displacement_per_delta_t = { + delta_t: np.mean(displacements) + for delta_t, displacements in displacement_per_delta_t.items() } - std_displacement_per_tau = { - tau: np.std(displacements) - for tau, displacements in displacement_per_tau.items() + std_displacement_per_delta_t = { + delta_t: np.std(displacements) + for delta_t, displacements in displacement_per_delta_t.items() } - return mean_displacement_per_tau, std_displacement_per_tau + return mean_displacement_per_delta_t, std_displacement_per_delta_t -def compute_dynamic_range(mean_displacement_per_tau): +def compute_dynamic_range(mean_displacement_per_delta_t): """ Compute the dynamic range as the difference between the maximum and minimum mean displacement per τ. Parameters: - mean_displacement_per_tau: dict with τ as key and mean displacement as value + mean_displacement_per_delta_t: dict with τ as key and mean displacement as value Returns: float: dynamic range (max displacement - min displacement) """ - displacements = list(mean_displacement_per_tau.values()) + displacements = list(mean_displacement_per_delta_t.values()) return max(displacements) - min(displacements) From b57b2cf72140f351ea88629c49b72906e59cdbbc Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Thu, 9 Jan 2025 14:08:07 -0800 Subject: [PATCH 09/38] changing tau to delta_t and frames to seconds --- .../evaluation/ALFI_MSD_v2.py | 100 +++++++----------- 1 file changed, 40 insertions(+), 60 deletions(-) diff --git a/applications/contrastive_phenotyping/evaluation/ALFI_MSD_v2.py b/applications/contrastive_phenotyping/evaluation/ALFI_MSD_v2.py index 18f537759..fd7e0ad84 100644 --- a/applications/contrastive_phenotyping/evaluation/ALFI_MSD_v2.py +++ b/applications/contrastive_phenotyping/evaluation/ALFI_MSD_v2.py @@ -35,7 +35,7 @@ for label, path in feature_paths.items(): print(f"\nProcessing {label}...") embedding_dataset = read_embedding_dataset(Path(path)) - + embedding_dimension = embedding_dataset["features"].shape[1] # Compute displacements displacements = compute_displacement( embedding_dataset=embedding_dataset, @@ -46,25 +46,27 @@ raw_displacements[label] = displacements # Print some statistics - taus = sorted(means.keys()) - print(f" Number of different τ values: {len(taus)}") - print(f" τ range: {min(taus)} to {max(taus)}") + delta_t = sorted(means.keys()) + print(f" Number of different τ values: {len(delta_t)}") + print(f" τ range: {min(delta_t)} to {max(delta_t)}") print(f" MSD at τ=1: {means[1]:.4f} ± {stds[1]:.4f}") # %% Plot MSD vs time (linear scale) plt.figure(figsize=(10, 6)) +SECONDS_PER_FRAME = 7 * 60 # seconds # Plot each time 
interval for interval_label, path in feature_paths.items(): means, stds = results[interval_label] - # Sort by tau for plotting - taus = sorted(means.keys()) - mean_values = [means[tau] for tau in taus] - std_values = [stds[tau] for tau in taus] + # Sort by delta_t for plotting + delta_t = sorted(means.keys()) + mean_values = [means[delta_t] for delta_t in delta_t] + std_values = [stds[delta_t] for delta_t in delta_t] + delta_t_seconds = [i * SECONDS_PER_FRAME for i in delta_t] plt.plot( - taus, + delta_t_seconds, mean_values, "-", color=interval_colors[interval_label], @@ -72,7 +74,7 @@ zorder=1, ) plt.scatter( - taus, + delta_t_seconds, mean_values, color=interval_colors[interval_label], s=20, @@ -80,7 +82,7 @@ zorder=2, ) -plt.xlabel("Time Shift (τ)") +plt.xlabel("Time Shift (seconds)") plt.ylabel("Mean Square Displacement") plt.title("MSD vs Time Shift") plt.grid(True, alpha=0.3) @@ -89,44 +91,44 @@ plt.show() # %% Plot MSD vs time (log-log scale with slopes) -n_dimensions = 768 plt.figure(figsize=(10, 6)) # Plot each time interval for interval_label, path in feature_paths.items(): means, stds = results[interval_label] - # Sort by tau for plotting - taus = sorted(means.keys()) - mean_values = [means[tau] for tau in taus] - std_values = [stds[tau] for tau in taus] + # Sort by delta_t for plotting + delta_t = sorted(means.keys()) + mean_values = [means[delta_t] for delta_t in delta_t] + std_values = [stds[delta_t] for delta_t in delta_t] + delta_t_seconds = [i * SECONDS_PER_FRAME for i in delta_t] # Filter out non-positive values for log scale valid_mask = np.array(mean_values) > 0 - valid_taus = np.array(taus)[valid_mask] + valid_delta_t = np.array(delta_t_seconds)[valid_mask] valid_means = np.array(mean_values)[valid_mask] # Calculate slopes for different regions - log_taus = np.log(valid_taus) + log_delta_t = np.log(valid_delta_t) log_means = np.log(valid_means) # Early slope (first third of points) - n_points = len(log_taus) + n_points = len(log_delta_t) early_end = n_points // 3 early_slope, early_intercept = np.polyfit( - log_taus[:early_end], log_means[:early_end], 1 + log_delta_t[:early_end], log_means[:early_end], 1 ) - early_slope /= 2 * n_dimensions + early_slope /= 2 * embedding_dimension # middle slope (mid portions of points) late_start = 2 * (n_points // 3) mid_slope, mid_intercept = np.polyfit( - log_taus[early_end:late_start], log_means[early_end:late_start], 1 + log_delta_t[early_end:late_start], log_means[early_end:late_start], 1 ) - mid_slope /= 2 * n_dimensions + mid_slope /= 2 * embedding_dimension plt.plot( - log_taus, + log_delta_t, log_means, "-", color=interval_colors[interval_label], @@ -134,38 +136,17 @@ zorder=1, ) plt.scatter( - log_taus, + log_delta_t, log_means, color=interval_colors[interval_label], s=20, - label=f"{interval_label} (α_early={early_slope:.2f}, α_mid={mid_slope:.2f})", + label=f"{interval_label} (α_early={early_slope:.2e}, α_mid={mid_slope:.2e})", zorder=2, ) - # # Plot fitted lines for early and late regions - # early_fit = np.exp(early_intercept + early_slope * log_taus[:early_end]) - # mid_fit = np.exp(mid_intercept + mid_slope * log_taus[early_end:late_start]) - - # plt.plot( - # early_fit, - # log_taus[:early_end], - # "--", - # color=interval_colors[interval_label], - # alpha=0.3, - # zorder=1, - # ) - # plt.plot( - # mid_fit, - # log_taus[early_end:late_start], - # "--", - # color=interval_colors[interval_label], - # alpha=0.3, - # zorder=1, - # ) - plt.xscale("log") plt.yscale("log") -plt.xlabel("Time Shift (τ)") +plt.xlabel("Time 
Shift (seconds)") plt.ylabel("Mean Square Displacement") plt.title("MSD vs Time Shift (log-log)") plt.grid(True, alpha=0.3, which="both") @@ -184,24 +165,27 @@ means, _ = results[interval_label] # Calculate slopes - taus = np.array(sorted(means.keys())) - mean_values = np.array([means[tau] for tau in taus]) + delta_t = np.array(sorted(means.keys())) + mean_values = np.array([means[delta_t] for delta_t in delta_t]) valid_mask = mean_values > 0 + delta_t_seconds = [i * SECONDS_PER_FRAME for i in delta_t] if np.sum(valid_mask) > 3: # Need at least 4 points to calculate both slopes - log_taus = np.log(taus[valid_mask]) + log_delta_t = np.log(delta_t[valid_mask]) log_means = np.log(mean_values[valid_mask]) # Calculate early and mid slopes - n_points = len(log_taus) + n_points = len(log_delta_t) early_end = n_points // 3 late_start = 2 * (n_points // 3) - early_slope, _ = np.polyfit(log_taus[:early_end], log_means[:early_end], 1) - mid_slope, _ = np.polyfit(log_taus[early_end:late_start], log_means[early_end:late_start], 1) + early_slope, _ = np.polyfit(log_delta_t[:early_end], log_means[:early_end], 1) + mid_slope, _ = np.polyfit( + log_delta_t[early_end:late_start], log_means[early_end:late_start], 1 + ) - early_slopes.append(early_slope/(2*n_dimensions)) - mid_slopes.append(mid_slope/(2*n_dimensions)) + early_slopes.append(early_slope / (2 * embedding_dimension)) + mid_slopes.append(mid_slope / (2 * embedding_dimension)) intervals.append(interval_label) # Create bar plot @@ -213,10 +197,6 @@ plt.bar(x - width / 2, early_slopes, width, label="Early slope", alpha=0.7) plt.bar(x + width / 2, mid_slopes, width, label="Mid slope", alpha=0.7) -# # Add reference lines -# plt.axhline(y=0.001, color="k", linestyle="--", alpha=0.3, label="Normal diffusion (α=1)") -# plt.axhline(y=0, color="k", linestyle="-", alpha=0.2) - plt.xlabel("Time Interval") plt.ylabel("Slope (α)") plt.title("MSD Slopes by Time Interval") From 915cc24dbea2220b004da98150529e4aed825609 Mon Sep 17 00:00:00 2001 From: Soorya Pradeep Date: Thu, 9 Jan 2025 14:55:24 -0800 Subject: [PATCH 10/38] delete scratch scripts --- .../predict_infection_score_supervised.py | 166 ----------------- .../figures/classify_feb.py | 100 ----------- .../figures/classify_feb_embeddings.py | 94 ---------- .../figures/classify_june.py | 121 ------------- .../figures/figure_4a_1.py | 167 ------------------ .../figures/figure_4e_2_feb.py | 87 --------- .../figures/figure_4e_2_june.py | 85 --------- .../figures/save_patches.py | 67 ------- 8 files changed, 887 deletions(-) delete mode 100644 applications/contrastive_phenotyping/evaluation/predict_infection_score_supervised.py delete mode 100644 applications/contrastive_phenotyping/figures/classify_feb.py delete mode 100644 applications/contrastive_phenotyping/figures/classify_feb_embeddings.py delete mode 100644 applications/contrastive_phenotyping/figures/classify_june.py delete mode 100644 applications/contrastive_phenotyping/figures/figure_4a_1.py delete mode 100644 applications/contrastive_phenotyping/figures/figure_4e_2_feb.py delete mode 100644 applications/contrastive_phenotyping/figures/figure_4e_2_june.py delete mode 100644 applications/contrastive_phenotyping/figures/save_patches.py diff --git a/applications/contrastive_phenotyping/evaluation/predict_infection_score_supervised.py b/applications/contrastive_phenotyping/evaluation/predict_infection_score_supervised.py deleted file mode 100644 index fb64b9f07..000000000 --- 
a/applications/contrastive_phenotyping/evaluation/predict_infection_score_supervised.py +++ /dev/null @@ -1,166 +0,0 @@ -import os -import warnings -from argparse import ArgumentParser - -import numpy as np -import pandas as pd -from torch.utils.data import DataLoader -from tqdm import tqdm - -from viscy.data.triplet import TripletDataModule - -warnings.filterwarnings( - "ignore", - category=UserWarning, - message="To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).", -) - -# %% Paths and constants -save_dir = ( - "/hpc/mydata/alishba.imran/VisCy/applications/contrastive_phenotyping/embeddings4" -) - -# rechunked data -data_path = "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/2.2-register_annotations/updated_all_annotations.zarr" - -# updated tracking data -tracks_path = "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/7.1-seg_track/tracking_v1.zarr" - -source_channel = ["background_mask", "uninfected_mask", "infected_mask"] -z_range = (0, 1) -batch_size = 1 # match the number of fovs being processed such that no data is left -# set to 15 for full, 12 for infected, and 8 for uninfected - -# non-rechunked data -data_path_1 = "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/7.1-seg_track/tracking_v1.zarr" - -# updated tracking data -tracks_path_1 = "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/7.1-seg_track/tracking_v1.zarr" - -source_channel_1 = ["Nuclei_prediction_labels"] - - -# %% Define the main function for training -def main(hparams): - # Initialize the data module for prediction, re-do embeddings but with size 224 by 224 - data_module = TripletDataModule( - data_path=data_path, - tracks_path=tracks_path, - source_channel=source_channel, - z_range=z_range, - initial_yx_patch_size=(224, 224), - final_yx_patch_size=(224, 224), - batch_size=batch_size, - num_workers=hparams.num_workers, - ) - - data_module.setup(stage="predict") - - print(f"Total prediction dataset size: {len(data_module.predict_dataset)}") - - dataloader = DataLoader( - data_module.predict_dataset, - batch_size=batch_size, - num_workers=hparams.num_workers, - ) - - # Initialize the second data module for segmentation masks - seg_data_module = TripletDataModule( - data_path=data_path_1, - tracks_path=tracks_path_1, - source_channel=source_channel_1, - z_range=z_range, - initial_yx_patch_size=(224, 224), - final_yx_patch_size=(224, 224), - batch_size=batch_size, - num_workers=hparams.num_workers, - ) - - seg_data_module.setup(stage="predict") - - seg_dataloader = DataLoader( - seg_data_module.predict_dataset, - batch_size=batch_size, - num_workers=hparams.num_workers, - ) - - # Initialize lists to store average values - background_avg = [] - uninfected_avg = [] - infected_avg = [] - - for batch, seg_batch in tqdm( - zip(dataloader, seg_dataloader), - desc="Processing batches", - total=len(data_module.predict_dataset), - ): - anchor = batch["anchor"] - seg_anchor = seg_batch["anchor"].int() - - # Extract the fov_name and id from the batch - fov_name = batch["index"]["fov_name"][0] - cell_id = batch["index"]["id"].item() - - fov_dirs = fov_name.split("/") - # Construct the path to the CSV file - csv_path = os.path.join( - tracks_path, *fov_dirs, f"tracks{fov_name.replace('/', '_')}.csv" - ) - - # Read the CSV file - df = 
pd.read_csv(csv_path) - - # Find the row with the specified id and extract the track_id - track_id = df.loc[df["id"] == cell_id, "track_id"].values[0] - - # Create a boolean mask where segmentation values are equal to the track_id - mask = seg_anchor == track_id - # mask = (seg_anchor > 0) - - # Find the most frequent non-zero value in seg_anchor - # unique, counts = np.unique(seg_anchor[seg_anchor > 0], return_counts=True) - # most_frequent_value = unique[np.argmax(counts)] - - # # Create a boolean mask where segmentation values are equal to the most frequent value - # mask = (seg_anchor == most_frequent_value) - - # Expand the mask to match the anchor tensor shape - mask = mask.expand(1, 3, 1, 224, 224) - - # Calculate average values for each channel (background, uninfected, infected) using the mask - background_avg.append(anchor[:, 0, :, :, :][mask[:, 0]].mean().item()) - uninfected_avg.append(anchor[:, 1, :, :, :][mask[:, 1]].mean().item()) - infected_avg.append(anchor[:, 2, :, :, :][mask[:, 2]].mean().item()) - - # Convert lists to numpy arrays - background_avg = np.array(background_avg) - uninfected_avg = np.array(uninfected_avg) - infected_avg = np.array(infected_avg) - - print("Average values per cell for each mask calculated.") - print("Background average shape:", background_avg.shape) - print("Uninfected average shape:", uninfected_avg.shape) - print("Infected average shape:", infected_avg.shape) - - # Save the averages as .npy files - np.save(os.path.join(save_dir, "background_avg.npy"), background_avg) - np.save(os.path.join(save_dir, "uninfected_avg.npy"), uninfected_avg) - np.save(os.path.join(save_dir, "infected_avg.npy"), infected_avg) - - -if __name__ == "__main__": - parser = ArgumentParser() - parser.add_argument("--backbone", type=str, default="resnet50") - parser.add_argument("--margin", type=float, default=0.5) - parser.add_argument("--lr", type=float, default=1e-3) - parser.add_argument("--schedule", type=str, default="Constant") - parser.add_argument("--log_steps_per_epoch", type=int, default=10) - parser.add_argument("--embedding_len", type=int, default=256) - parser.add_argument("--max_epochs", type=int, default=100) - parser.add_argument("--accelerator", type=str, default="gpu") - parser.add_argument("--devices", type=int, default=1) - parser.add_argument("--num_nodes", type=int, default=1) - parser.add_argument("--log_every_n_steps", type=int, default=1) - parser.add_argument("--num_workers", type=int, default=8) - args = parser.parse_args() - main(args) diff --git a/applications/contrastive_phenotyping/figures/classify_feb.py b/applications/contrastive_phenotyping/figures/classify_feb.py deleted file mode 100644 index b9dd81b8e..000000000 --- a/applications/contrastive_phenotyping/figures/classify_feb.py +++ /dev/null @@ -1,100 +0,0 @@ -# %% Importing Necessary Libraries -from pathlib import Path - -import matplotlib.pyplot as plt -import pandas as pd -import seaborn as sns -from imblearn.over_sampling import SMOTE -from sklearn.linear_model import LogisticRegression -from sklearn.metrics import classification_report, confusion_matrix -from tqdm import tqdm - -from viscy.representation.embedding_writer import read_embedding_dataset -from viscy.representation.evaluation import load_annotation -from viscy.representation.evaluation.dimensionality_reduction import compute_pca - -# %% Defining Paths for February Dataset -feb_features_path = Path( - 
"/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/code_testing_soorya/output/June_140Patch_2chan/phaseRFP_140patch_99ckpt_Feb.zarr" -) - - -# %% Load and Process February Dataset -feb_embedding_dataset = read_embedding_dataset(feb_features_path) -print(feb_embedding_dataset) -pca_df = compute_pca(feb_embedding_dataset, n_components=6) - -# Print shape before merge -print("Shape of pca_df before merge:", pca_df.shape) - -# Load the ground truth infection labels -feb_ann_root = Path( - "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/7.1-seg_track" -) -feb_infection = load_annotation( - feb_embedding_dataset, - feb_ann_root / "tracking_v1_infection.csv", - "infection class", - {0.0: "background", 1.0: "uninfected", 2.0: "infected"}, -) - -# Print shape of feb_infection -print("Shape of feb_infection:", feb_infection.shape) - -# Merge PCA results with ground truth labels on both 'fov_name' and 'id' -pca_df = pd.merge(pca_df, feb_infection.reset_index(), on=["fov_name", "id"]) - -# Print shape after merge -print("Shape of pca_df after merge:", pca_df.shape) - -# Prepare the full dataset -X = pca_df[["PCA1", "PCA2", "PCA3", "PCA4", "PCA5", "PCA6"]] -y = pca_df["infection class"] - -# Apply SMOTE to balance the classes in the full dataset -smote = SMOTE(random_state=42) -X_resampled, y_resampled = smote.fit_resample(X, y) - -# Print shape after SMOTE -print( - f"Shape after SMOTE - X_resampled: {X_resampled.shape}, y_resampled: {y_resampled.shape}" -) - -# %% Train Logistic Regression Classifier with Progress Bar -model = LogisticRegression(max_iter=1000, random_state=42) - -# Wrap the training with tqdm to show a progress bar -for _ in tqdm(range(1)): - model.fit(X_resampled, y_resampled) - -# %% Predict Labels for the Entire Dataset -pca_df["Predicted_Label"] = model.predict(X) - -# Compute metrics based on the entire original dataset -print("Classification Report for Entire Dataset:") -print(classification_report(pca_df["infection class"], pca_df["Predicted_Label"])) - -print("Confusion Matrix for Entire Dataset:") -print(confusion_matrix(pca_df["infection class"], pca_df["Predicted_Label"])) - -# %% Plotting the Results -plt.figure(figsize=(10, 8)) -sns.scatterplot( - x=pca_df["PCA1"], y=pca_df["PCA2"], hue=pca_df["infection class"], s=7, alpha=0.8 -) -plt.title("PCA with Ground Truth Labels") -plt.savefig("up_pca_ground_truth_labels.png", format="png", dpi=300) -plt.show() - -plt.figure(figsize=(10, 8)) -sns.scatterplot( - x=pca_df["PCA1"], y=pca_df["PCA2"], hue=pca_df["Predicted_Label"], s=7, alpha=0.8 -) -plt.title("PCA with Logistic Regression Predicted Labels") -plt.savefig("up_pca_predicted_labels.png", format="png", dpi=300) -plt.show() - -# %% Save Predicted Labels to CSV -save_path_csv = "up_logistic_regression_predicted_labels_feb_pca.csv" -pca_df[["id", "fov_name", "Predicted_Label"]].to_csv(save_path_csv, index=False) -print(f"Predicted labels saved to {save_path_csv}") diff --git a/applications/contrastive_phenotyping/figures/classify_feb_embeddings.py b/applications/contrastive_phenotyping/figures/classify_feb_embeddings.py deleted file mode 100644 index da63c52a8..000000000 --- a/applications/contrastive_phenotyping/figures/classify_feb_embeddings.py +++ /dev/null @@ -1,94 +0,0 @@ -# %% Importing Necessary Libraries -from pathlib import Path - -import pandas as pd -from imblearn.over_sampling import SMOTE -from sklearn.linear_model import LogisticRegression -from sklearn.metrics import 
classification_report, confusion_matrix - -from viscy.representation.embedding_writer import read_embedding_dataset -from viscy.representation.evaluation import load_annotation - -# %% Defining Paths for February Dataset -feb_features_path = Path( - "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/negpair_random_sampling2/" -) - - -# %% Load and Process February Dataset (Embedding Features) -feb_embedding_dataset = read_embedding_dataset( - feb_features_path / "febtest_predict.zarr" -) -print(feb_embedding_dataset) - -# Extract the embedding feature values as the input matrix (X) -X = feb_embedding_dataset["features"].values - -# Prepare a DataFrame for the embeddings with id and fov_name -embedding_df = pd.DataFrame(X, columns=[f"feature_{i+1}" for i in range(X.shape[1])]) -embedding_df["id"] = feb_embedding_dataset["id"].values -embedding_df["fov_name"] = feb_embedding_dataset["fov_name"].values -print(embedding_df.head()) - -# %% Load the ground truth infection labels -feb_ann_root = Path( - "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/supervised_inf_pred" -) -feb_infection = load_annotation( - feb_embedding_dataset, - feb_ann_root / "extracted_inf_state.csv", - "infection_state", - {0.0: "background", 1.0: "uninfected", 2.0: "infected"}, -) - -# %% Merge embedding features with infection labels on 'fov_name' and 'id' -merged_df = pd.merge(embedding_df, feb_infection.reset_index(), on=["fov_name", "id"]) -print(merged_df.head()) -# %% Prepare the full dataset for training -X = merged_df.drop( - columns=["id", "fov_name", "infection_state"] -).values # Use embeddings as features -y = merged_df["infection_state"] # Use infection state as labels -print(X.shape) -print(y.shape) -# %% Print class distribution before applying SMOTE -print("Class distribution before SMOTE:") -print(y.value_counts()) - -# Apply SMOTE to balance the classes -smote = SMOTE(random_state=42) -X_resampled, y_resampled = smote.fit_resample(X, y) - -# Print class distribution after applying SMOTE -print("Class distribution after SMOTE:") -print(pd.Series(y_resampled).value_counts()) - -# Train Logistic Regression Classifier -model = LogisticRegression(max_iter=1000, random_state=42) -model.fit(X_resampled, y_resampled) - -# Predict Labels for the Entire Dataset -y_pred = model.predict(X) - -# Compute metrics based on the entire original dataset -print("Classification Report for Entire Dataset:") -print(classification_report(y, y_pred)) - -print("Confusion Matrix for Entire Dataset:") -print(confusion_matrix(y, y_pred)) - -# %% -# Save the predicted labels to a CSV -save_path_csv = feb_features_path / "feb_test_regression_predicted_labels_embedding.csv" -predicted_labels_df = pd.DataFrame( - { - "id": merged_df["id"].values, - "fov_name": merged_df["fov_name"].values, - "Predicted_Label": y_pred, - } -) - -predicted_labels_df.to_csv(save_path_csv, index=False) -print(f"Predicted labels saved to {save_path_csv}") - -# %% diff --git a/applications/contrastive_phenotyping/figures/classify_june.py b/applications/contrastive_phenotyping/figures/classify_june.py deleted file mode 100644 index ca51f2b17..000000000 --- a/applications/contrastive_phenotyping/figures/classify_june.py +++ /dev/null @@ -1,121 +0,0 @@ -# %% Importing Necessary Libraries -from pathlib import Path - -import matplotlib.pyplot as plt -import pandas as pd -import seaborn as sns -from imblearn.over_sampling import SMOTE -from 
sklearn.decomposition import PCA -from sklearn.linear_model import LogisticRegression -from sklearn.metrics import classification_report, confusion_matrix -from sklearn.preprocessing import StandardScaler -from tqdm import tqdm - -from viscy.representation.embedding_writer import read_embedding_dataset - -# %% Defining Paths for June Dataset -june_features_path = Path("/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/code_testing_soorya/output/Phase_RFP_smallPatch_June/phaseRFP_36patch_June.zarr") - -# %% Function to Load Annotations -def load_annotation(da, path, name, categories: dict | None = None): - annotation = pd.read_csv(path) - annotation["fov_name"] = "/" + annotation["fov ID"] - annotation = annotation.set_index(["fov_name", "id"]) - mi = pd.MultiIndex.from_arrays( - [da["fov_name"].values, da["id"].values], names=["fov_name", "id"] - ) - selected = annotation.loc[mi][name] - if categories: - selected = selected.astype("category").cat.rename_categories(categories) - return selected - -# %% Function to Compute PCA -def compute_pca(embedding_dataset, n_components=6): - features = embedding_dataset["features"] - scaled_features = StandardScaler().fit_transform(features.values) - - # Compute PCA with specified number of components - pca = PCA(n_components=n_components, random_state=42) - pca_embedding = pca.fit_transform(scaled_features) - - # Prepare DataFrame with id and PCA coordinates - pca_df = pd.DataFrame({ - "id": embedding_dataset["id"].values, - "fov_name": embedding_dataset["fov_name"].values, - "PCA1": pca_embedding[:, 0], - "PCA2": pca_embedding[:, 1], - "PCA3": pca_embedding[:, 2], - "PCA4": pca_embedding[:, 3], - "PCA5": pca_embedding[:, 4], - "PCA6": pca_embedding[:, 5] - }) - - return pca_df - -# %% Load and Process June Dataset -june_embedding_dataset = read_embedding_dataset(june_features_path) -print(june_embedding_dataset) -pca_df = compute_pca(june_embedding_dataset, n_components=6) - -# Print shape before merge -print("Shape of pca_df before merge:", pca_df.shape) - -# Load the ground truth infection labels -june_ann_root = Path("/hpc/projects/intracellular_dashboard/viral-sensor/2024_06_13_SEC61_TOMM20_ZIKV_DENGUE_1/4.1-tracking") -june_infection = load_annotation(june_embedding_dataset, june_ann_root / "tracking_v1_infection.csv", "infection class", - {0.0: "background", 1.0: "uninfected", 2.0: "infected"}) - -# Print shape of june_infection -print("Shape of june_infection:", june_infection.shape) - -# Merge PCA results with ground truth labels on both 'fov_name' and 'id' -pca_df = pd.merge(pca_df, june_infection.reset_index(), on=['fov_name', 'id']) - -# Print shape after merge -print("Shape of pca_df after merge:", pca_df.shape) - -# Prepare the full dataset -X = pca_df[["PCA1", "PCA2", "PCA3", "PCA4", "PCA5", "PCA6"]] -y = pca_df["infection class"] - -# Apply SMOTE to balance the classes in the full dataset -smote = SMOTE(random_state=42) -X_resampled, y_resampled = smote.fit_resample(X, y) - -# Print shape after SMOTE -print(f"Shape after SMOTE - X_resampled: {X_resampled.shape}, y_resampled: {y_resampled.shape}") - -# %% Train Logistic Regression Classifier with Progress Bar -model = LogisticRegression(max_iter=1000, random_state=42) - -# Wrap the training with tqdm to show a progress bar -for _ in tqdm(range(1)): - model.fit(X_resampled, y_resampled) - -# %% Predict Labels for the Entire Dataset -pca_df["Predicted_Label"] = model.predict(X) - -# Compute metrics based on the entire original dataset 
-print("Classification Report for Entire Dataset:") -print(classification_report(pca_df["infection class"], pca_df["Predicted_Label"])) - -print("Confusion Matrix for Entire Dataset:") -print(confusion_matrix(pca_df["infection class"], pca_df["Predicted_Label"])) - -# %% Plotting the Results -plt.figure(figsize=(10, 8)) -sns.scatterplot(x=pca_df["PCA1"], y=pca_df["PCA2"], hue=pca_df["infection class"], s=7, alpha=0.8) -plt.title("PCA with Ground Truth Labels") -plt.savefig("june_pca_ground_truth_labels.png", format='png', dpi=300) -plt.show() - -plt.figure(figsize=(10, 8)) -sns.scatterplot(x=pca_df["PCA1"], y=pca_df["PCA2"], hue=pca_df["Predicted_Label"], s=7, alpha=0.8) -plt.title("PCA with Logistic Regression Predicted Labels") -plt.savefig("june_pca_predicted_labels.png", format='png', dpi=300) -plt.show() - -# %% Save Predicted Labels to CSV -save_path_csv = "june_logistic_regression_predicted_labels_feb_pca.csv" -pca_df[['id', 'fov_name', 'Predicted_Label']].to_csv(save_path_csv, index=False) -print(f"Predicted labels saved to {save_path_csv}") diff --git a/applications/contrastive_phenotyping/figures/figure_4a_1.py b/applications/contrastive_phenotyping/figures/figure_4a_1.py deleted file mode 100644 index a670db0d0..000000000 --- a/applications/contrastive_phenotyping/figures/figure_4a_1.py +++ /dev/null @@ -1,167 +0,0 @@ -# %% Importing Necessary Libraries -from pathlib import Path - -import matplotlib.pyplot as plt -import pandas as pd -import seaborn as sns -from sklearn.preprocessing import StandardScaler -from umap import UMAP - -from viscy.representation.embedding_writer import read_embedding_dataset - -# %% Defining Paths for February and June Datasets -feb_features_path = Path("/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/code_testing_soorya/output/June_140Patch_2chan/phaseRFP_140patch_99ckpt_Feb.zarr") -feb_data_path = Path("/hpc/projects/virtual_staining/2024_02_04_A549_DENV_ZIKV_timelapse/registered_chunked.zarr") -feb_tracks_path = Path("/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/7.1-seg_track/tracking_v1.zarr") - -# %% Function to Load and Process the Embedding Dataset -def compute_umap(embedding_dataset): - features = embedding_dataset["features"] - scaled_features = StandardScaler().fit_transform(features.values) - umap = UMAP() - embedding = umap.fit_transform(scaled_features) - - features = ( - features.assign_coords(UMAP1=("sample", embedding[:, 0])) - .assign_coords(UMAP2=("sample", embedding[:, 1])) - .set_index(sample=["UMAP1", "UMAP2"], append=True) - ) - return features - -# %% Function to Load Annotations -def load_annotation(da, path, name, categories: dict | None = None): - annotation = pd.read_csv(path) - annotation["fov_name"] = "/" + annotation["fov ID"] - annotation = annotation.set_index(["fov_name", "id"]) - mi = pd.MultiIndex.from_arrays( - [da["fov_name"].values, da["id"].values], names=["fov_name", "id"] - ) - selected = annotation.loc[mi][name] - if categories: - selected = selected.astype("category").cat.rename_categories(categories) - return selected - -# %% Function to Plot UMAP with Infection Annotations -def plot_umap_infection(features, infection, title): - plt.figure(figsize=(10, 8)) - sns.scatterplot(x=features["UMAP1"], y=features["UMAP2"], hue=infection, s=7, alpha=0.8) - plt.title(f"UMAP Plot - {title}") - plt.show() - -# %% Load and Process February Dataset -feb_embedding_dataset = read_embedding_dataset(feb_features_path) -feb_features = 
compute_umap(feb_embedding_dataset) - -feb_ann_root = Path("/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/7.1-seg_track") -feb_infection = load_annotation(feb_features, feb_ann_root / "tracking_v1_infection.csv", "infection class", {0.0: "background", 1.0: "uninfected", 2.0: "infected"}) - -# %% Plot UMAP with Infection Status for February Dataset -plot_umap_infection(feb_features, feb_infection, "February Dataset") - -# %% -print(feb_embedding_dataset) -print(feb_infection) -print(feb_features) -# %% - - -# %% Identify cells by infection type using fov_name -mock_cells = feb_features.sel(sample=feb_features['fov_name'].str.contains('/A/3') | feb_features['fov_name'].str.contains('/B/3')) -zika_cells = feb_features.sel(sample=feb_features['fov_name'].str.contains('/A/4')) -dengue_cells = feb_features.sel(sample=feb_features['fov_name'].str.contains('/B/4')) - -# %% Plot UMAP with Infection Status -plt.figure(figsize=(10, 8)) -sns.scatterplot(x=feb_features["UMAP1"], y=feb_features["UMAP2"], hue=feb_infection, s=7, alpha=0.8) - -# Overlay with circled cells -plt.scatter(mock_cells["UMAP1"], mock_cells["UMAP2"], facecolors='none', edgecolors='blue', s=20, label='Mock Cells') -plt.scatter(zika_cells["UMAP1"], zika_cells["UMAP2"], facecolors='none', edgecolors='green', s=20, label='Zika MOI 5') -plt.scatter(dengue_cells["UMAP1"], dengue_cells["UMAP2"], facecolors='none', edgecolors='red', s=20, label='Dengue MOI 5') - -# Add legend and show plot -plt.legend(loc='best') -plt.title("UMAP Plot - February Dataset with Mock, Zika, and Dengue Highlighted") -plt.show() - -# %% -# %% Create a 1x3 grid of heatmaps -fig, axs = plt.subplots(1, 3, figsize=(18, 6), sharex=True, sharey=True) - -# Mock Cells Heatmap -sns.histplot(x=mock_cells["UMAP1"], y=mock_cells["UMAP2"], bins=50, pmax=1, cmap="Blues", ax=axs[0]) -axs[0].set_title('Mock Cells') -axs[0].set_xlim(feb_features["UMAP1"].min(), feb_features["UMAP1"].max()) -axs[0].set_ylim(feb_features["UMAP2"].min(), feb_features["UMAP2"].max()) - -# Zika Cells Heatmap -sns.histplot(x=zika_cells["UMAP1"], y=zika_cells["UMAP2"], bins=50, pmax=1, cmap="Greens", ax=axs[1]) -axs[1].set_title('Zika MOI 5') -axs[1].set_xlim(feb_features["UMAP1"].min(), feb_features["UMAP1"].max()) -axs[1].set_ylim(feb_features["UMAP2"].min(), feb_features["UMAP2"].max()) - -# Dengue Cells Heatmap -sns.histplot(x=dengue_cells["UMAP1"], y=dengue_cells["UMAP2"], bins=50, pmax=1, cmap="Reds", ax=axs[2]) -axs[2].set_title('Dengue MOI 5') -axs[2].set_xlim(feb_features["UMAP1"].min(), feb_features["UMAP1"].max()) -axs[2].set_ylim(feb_features["UMAP2"].min(), feb_features["UMAP2"].max()) - -# Set labels and adjust layout -for ax in axs: - ax.set_xlabel('UMAP1') - ax.set_ylabel('UMAP2') - -plt.tight_layout() -plt.show() - -# %% -import matplotlib.pyplot as plt -import seaborn as sns - -# %% Create a 2x3 grid of heatmaps (1 row for each heatmap, splitting infected and uninfected in the second row) -fig, axs = plt.subplots(2, 3, figsize=(24, 12), sharex=True, sharey=True) - -# Mock Cells Heatmap -sns.histplot(x=mock_cells["UMAP1"], y=mock_cells["UMAP2"], bins=50, pmax=1, cmap="Blues", ax=axs[0, 0]) -axs[0, 0].set_title('Mock Cells') -axs[0, 0].set_xlim(feb_features["UMAP1"].min(), feb_features["UMAP1"].max()) -axs[0, 0].set_ylim(feb_features["UMAP2"].min(), feb_features["UMAP2"].max()) - -# Zika Cells Heatmap -sns.histplot(x=zika_cells["UMAP1"], y=zika_cells["UMAP2"], bins=50, pmax=1, cmap="Greens", ax=axs[0, 1]) -axs[0, 1].set_title('Zika MOI 5') 
-axs[0, 1].set_xlim(feb_features["UMAP1"].min(), feb_features["UMAP1"].max()) -axs[0, 1].set_ylim(feb_features["UMAP2"].min(), feb_features["UMAP2"].max()) - -# Dengue Cells Heatmap -sns.histplot(x=dengue_cells["UMAP1"], y=dengue_cells["UMAP2"], bins=50, pmax=1, cmap="Reds", ax=axs[0, 2]) -axs[0, 2].set_title('Dengue MOI 5') -axs[0, 2].set_xlim(feb_features["UMAP1"].min(), feb_features["UMAP1"].max()) -axs[0, 2].set_ylim(feb_features["UMAP2"].min(), feb_features["UMAP2"].max()) - -# Infected Cells Heatmap -sns.histplot(x=infected_cells["UMAP1"], y=infected_cells["UMAP2"], bins=50, pmax=1, cmap="Reds", ax=axs[1, 0]) -axs[1, 0].set_title('Infected Cells') -axs[1, 0].set_xlim(feb_features["UMAP1"].min(), feb_features["UMAP1"].max()) -axs[1, 0].set_ylim(feb_features["UMAP2"].min(), feb_features["UMAP2"].max()) - -# Uninfected Cells Heatmap -sns.histplot(x=uninfected_cells["UMAP1"], y=uninfected_cells["UMAP2"], bins=50, pmax=1, cmap="Greens", ax=axs[1, 1]) -axs[1, 1].set_title('Uninfected Cells') -axs[1, 1].set_xlim(feb_features["UMAP1"].min(), feb_features["UMAP1"].max()) -axs[1, 1].set_ylim(feb_features["UMAP2"].min(), feb_features["UMAP2"].max()) - -# Remove the last subplot (bottom right corner) -fig.delaxes(axs[1, 2]) - -# Set labels and adjust layout -for ax in axs.flat: - ax.set_xlabel('UMAP1') - ax.set_ylabel('UMAP2') - -plt.tight_layout() -plt.show() - - - -# %% diff --git a/applications/contrastive_phenotyping/figures/figure_4e_2_feb.py b/applications/contrastive_phenotyping/figures/figure_4e_2_feb.py deleted file mode 100644 index d3052018a..000000000 --- a/applications/contrastive_phenotyping/figures/figure_4e_2_feb.py +++ /dev/null @@ -1,87 +0,0 @@ -# %% Importing Necessary Libraries -from pathlib import Path - -import matplotlib.pyplot as plt -import pandas as pd - -from viscy.representation.embedding_writer import read_embedding_dataset - - -# %% Function to Load Annotations from GMM CSV -def load_gmm_annotation(gmm_csv_path): - gmm_df = pd.read_csv(gmm_csv_path) - return gmm_df - -# %% Function to Count and Calculate Percentage of Infected Cells Over Time Based on GMM Labels -def count_infected_cell_states_over_time(embedding_dataset, gmm_df): - # Convert the embedding dataset to a DataFrame - df = pd.DataFrame({ - "fov_name": embedding_dataset["fov_name"].values, - "track_id": embedding_dataset["track_id"].values, - "t": embedding_dataset["t"].values, - "id": embedding_dataset["id"].values - }) - - # Merge with GMM data to add GMM labels - df = pd.merge(df, gmm_df[['id', 'fov_name', 'Predicted_Label']], on=['fov_name', 'id'], how='left') - - # Filter by time range (3 HPI to 30 HPI) - df = df[(df['t'] >= 3) & (df['t'] <= 27)] - - # Determine the well type (Mock, Zika, Dengue) based on fov_name - df['well_type'] = df['fov_name'].apply(lambda x: 'Mock' if '/A/3' in x or '/B/3' in x else - ('Zika' if '/A/4' in x else 'Dengue')) - - # Group by time, well type, and GMM label to count the number of infected cells - state_counts = df.groupby(['t', 'well_type', 'Predicted_Label']).size().unstack(fill_value=0) - - # Ensure that 'infected' column exists - if 'infected' not in state_counts.columns: - state_counts['infected'] = 0 - - # Calculate the percentage of infected cells - state_counts['total'] = state_counts.sum(axis=1) - state_counts['infected'] = (state_counts['infected'] / state_counts['total']) * 100 - - return state_counts - -# %% Function to Plot Percentage of Infected Cells Over Time -def plot_infected_cell_states(state_counts): - plt.figure(figsize=(12, 8)) - - # Loop 
through each well type - for well_type in ['Mock', 'Zika', 'Dengue']: - # Select the data for the current well type - if well_type in state_counts.index.get_level_values('well_type'): - well_data = state_counts.xs(well_type, level='well_type') - - # Plot only the percentage of infected cells - if 'infected' in well_data.columns: - plt.plot(well_data.index, well_data['infected'], label=f'{well_type} - Infected') - - plt.title("Percentage of Infected Cells Over Time - February") - plt.xlabel("Hours Post Perturbation") - plt.ylabel("Percentage of Infected Cells") - plt.legend(title="Well Type") - plt.grid(True) - plt.show() - -# %% Load and process Feb Dataset -feb_features_path = Path("/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/code_testing_soorya/output/June_140Patch_2chan/phaseRFP_140patch_99ckpt_Feb.zarr") -feb_embedding_dataset = read_embedding_dataset(feb_features_path) - -# Load the GMM annotation CSV -gmm_csv_path = "june_logistic_regression_predicted_labels_feb_pca.csv" # Path to CSV file -gmm_df = load_gmm_annotation(gmm_csv_path) - -# %% Count Infected Cell States Over Time as Percentage using GMM labels -state_counts = count_infected_cell_states_over_time(feb_embedding_dataset, gmm_df) -print(state_counts.head()) -state_counts.info() - -# %% Plot Infected Cell States Over Time as Percentage -plot_infected_cell_states(state_counts) - -# %% - - diff --git a/applications/contrastive_phenotyping/figures/figure_4e_2_june.py b/applications/contrastive_phenotyping/figures/figure_4e_2_june.py deleted file mode 100644 index 1605ba278..000000000 --- a/applications/contrastive_phenotyping/figures/figure_4e_2_june.py +++ /dev/null @@ -1,85 +0,0 @@ -# %% Importing Necessary Libraries -from pathlib import Path - -import matplotlib.pyplot as plt -import pandas as pd - -from viscy.representation.embedding_writer import read_embedding_dataset - - -# %% Function to Load Annotations from CSV -def load_annotation(csv_path): - return pd.read_csv(csv_path) - -# %% Function to Count and Calculate Percentage of Infected Cells Over Time Based on Predicted Labels -def count_infected_cell_states_over_time(embedding_dataset, prediction_df): - # Convert the embedding dataset to a DataFrame - df = pd.DataFrame({ - "fov_name": embedding_dataset["fov_name"].values, - "track_id": embedding_dataset["track_id"].values, - "t": embedding_dataset["t"].values, - "id": embedding_dataset["id"].values - }) - - # Merge with the prediction data to add Predicted Labels - df = pd.merge(df, prediction_df[['id', 'fov_name', 'Infection_Class']], on=['fov_name', 'id'], how='left') - - # Filter by time range (2 HPI to 50 HPI) - df = df[(df['t'] >= 2) & (df['t'] <= 50)] - - # Determine the well type (Mock, Dengue, Zika) based on fov_name - df['well_type'] = df['fov_name'].apply( - lambda x: 'Mock' if '/0/1' in x or '/0/2' in x or '/0/3' in x or '/0/4' in x else - ('Dengue' if '/0/5' in x or '/0/6' in x else 'Zika')) - - # Group by time, well type, and Predicted_Label to count the number of infected cells - state_counts = df.groupby(['t', 'well_type', 'Infection_Class']).size().unstack(fill_value=0) - - # Ensure that 'infected' column exists - if 'infected' not in state_counts.columns: - state_counts['infected'] = 0 - - # Calculate the percentage of infected cells - state_counts['total'] = state_counts.sum(axis=1) - state_counts['infected'] = (state_counts['infected'] / state_counts['total']) * 100 - - return state_counts - -# %% Function to Plot Percentage of Infected Cells Over Time -def 
plot_infected_cell_states(state_counts): - plt.figure(figsize=(12, 8)) - - # Loop through each well type - for well_type in ['Mock', 'Dengue', 'Zika']: - # Select the data for the current well type - if well_type in state_counts.index.get_level_values('well_type'): - well_data = state_counts.xs(well_type, level='well_type') - - # Plot only the percentage of infected cells - if 'infected' in well_data.columns: - plt.plot(well_data.index, well_data['infected'], label=f'{well_type} - Infected') - - plt.title("Percentage of Infected Cells Over Time - June") - plt.xlabel("Hours Post Perturbation") - plt.ylabel("Percentage of Infected Cells") - plt.legend(title="Well Type") - plt.grid(True) - plt.show() - -# %% Load and process June Dataset -june_features_path = Path("/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/code_testing_soorya/output/Phase_RFP_smallPatch_June/phaseRFP_36patch_June.zarr") -june_embedding_dataset = read_embedding_dataset(june_features_path) - -# Load the predicted labels from CSV -prediction_csv_path = "3up_gmm_clustering_results_june_pca_6components.csv" # Path to predicted labels CSV file -prediction_df = load_annotation(prediction_csv_path) - -# %% Count Infected Cell States Over Time as Percentage using Predicted labels -state_counts = count_infected_cell_states_over_time(june_embedding_dataset, prediction_df) -print(state_counts.head()) -state_counts.info() - -# %% Plot Infected Cell States Over Time as Percentage -plot_infected_cell_states(state_counts) - -# %% diff --git a/applications/contrastive_phenotyping/figures/save_patches.py b/applications/contrastive_phenotyping/figures/save_patches.py deleted file mode 100644 index ebba6c320..000000000 --- a/applications/contrastive_phenotyping/figures/save_patches.py +++ /dev/null @@ -1,67 +0,0 @@ -# %% script to save 128 by 128 image patches from napari viewer - -import os -import sys -from pathlib import Path - -import numpy as np - -sys.path.append("/hpc/mydata/soorya.pradeep/scratch/viscy_infection_phenotyping/VisCy") -# from viscy.data.triplet import TripletDataModule -from viscy.representation.evaluation import dataset_of_tracks - -# %% input parameters - -data_path = Path( - "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/registered_test.zarr" -) -tracks_path = Path( - "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/track_test.zarr" -) - -fov_name = "/B/4/6" -track_id = 52 -source_channel = ["Phase3D", "RFP"] - -# %% load dataset - -prediction_dataset = dataset_of_tracks( - data_path, - tracks_path, - [fov_name], - [track_id], - source_channel=source_channel, -) -whole = np.stack([p["anchor"] for p in prediction_dataset]) -phase = whole[:, 0] -fluor = whole[:, 1] - -# use the following if you want to visualize a specific phase slice with max projected fluor -# phase = whole[:, 0, 3] # 3 is the slice number -# fluor = np.max(whole[:, 1], axis=1) - -# load image -# v = napari.Viewer() -# v.add_image(phase) -# v.add_image(fluor) - -# %% save patches as png images - -# use sliders on napari to get the deisred contrast and make other adjustments -# then use save screenshot if saving the image patch manually -# you can add code to automate the process if desired - -# %% save as numpy files - -out_dir = "/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/Figure_panels/data/" -fov_name_out = fov_name.replace("/", "_") -np.save( - 
(os.path.join(out_dir, "phase" + fov_name_out + "_" + str(track_id) + ".npy")), - phase, -) -np.save( - (os.path.join(out_dir, "fluor" + fov_name_out + "_" + str(track_id) + ".npy")), - fluor, -) - -# %% From efdcec0aa171e892447f1727df4b0db297daebcc Mon Sep 17 00:00:00 2001 From: Soorya Pradeep Date: Thu, 9 Jan 2025 15:06:25 -0800 Subject: [PATCH 11/38] deleted duplicate scripts --- .../evaluation/PC_vs_CF_singleChannel.py | 245 ------------ .../Infection_classification_25DModel.py | 106 ----- .../Infection_classification_covnextModel.py | 107 ------ .../classify_infection_25D.py | 356 ----------------- .../classify_infection_covnext.py | 363 ------------------ 5 files changed, 1177 deletions(-) delete mode 100644 applications/contrastive_phenotyping/evaluation/PC_vs_CF_singleChannel.py delete mode 100644 applications/infection_classification/Infection_classification_25DModel.py delete mode 100644 applications/infection_classification/Infection_classification_covnextModel.py delete mode 100644 applications/infection_classification/classify_infection_25D.py delete mode 100644 applications/infection_classification/classify_infection_covnext.py diff --git a/applications/contrastive_phenotyping/evaluation/PC_vs_CF_singleChannel.py b/applications/contrastive_phenotyping/evaluation/PC_vs_CF_singleChannel.py deleted file mode 100644 index 3d5049166..000000000 --- a/applications/contrastive_phenotyping/evaluation/PC_vs_CF_singleChannel.py +++ /dev/null @@ -1,245 +0,0 @@ -""" Script to compute the correlation between PCA and UMAP features and computed features -* finds the computed features best representing the PCA and UMAP components -* outputs a heatmap of the correlation between PCA and UMAP features and computed features -""" - -# %% -import sys -from pathlib import Path - -sys.path.append("/hpc/mydata/soorya.pradeep/scratch/viscy_infection_phenotyping/VisCy") - -import numpy as np -import pandas as pd -import plotly.express as px -from scipy.stats import spearmanr -from sklearn.decomposition import PCA - -from viscy.representation.embedding_writer import read_embedding_dataset -from viscy.representation.evaluation import dataset_of_tracks -from viscy.representation.evaluation.feature import ( - FeatureExtractor as FE, -) - -# %% -features_path = Path( - "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/time_interval_phase/predictions/epoch_186/1chan_128patch_186ckpt_Febtest.zarr" -) -data_path = Path( - "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/registered_test.zarr" -) -tracks_path = Path( - "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/9-lineage-cell-division/lineages_gt/track.zarr" -) - -# %% - -source_channel = ["Phase3D"] -z_range = (28, 43) -normalizations = None -# fov_name = "/B/4/5" -# track_id = 11 - -embedding_dataset = read_embedding_dataset(features_path) -embedding_dataset - -# load all unprojected features: -features = embedding_dataset["features"] - -# %% PCA analysis of the features - -pca = PCA(n_components=3) -embedding = pca.fit_transform(features.values) -features = ( - features.assign_coords(PCA1=("sample", embedding[:, 0])) - .assign_coords(PCA2=("sample", embedding[:, 1])) - .assign_coords(PCA3=("sample", embedding[:, 2])) - .set_index(sample=["PCA1", "PCA2", "PCA3"], append=True) -) - -# %% convert the xarray to dataframe structure and add columns for computed features -features_df = 
features.to_dataframe() -features_df = features_df.drop(columns=["features"]) -df = features_df.drop_duplicates() -features = df.reset_index(drop=True) - -features = features[features["fov_name"].str.startswith("/B/")] - -features["Phase Symmetry Score"] = np.nan -features["Entropy Phase"] = np.nan -features["Contrast Phase"] = np.nan -features["Dissimilarity Phase"] = np.nan -features["Homogeneity Phase"] = np.nan -features["Phase IQR"] = np.nan -features["Phase Standard Deviation"] = np.nan -features["Phase radial profile"] = np.nan - -# %% compute the computed features and add them to the dataset - -fov_names_list = features["fov_name"].unique() -unique_fov_names = sorted(list(set(fov_names_list))) - -for fov_name in unique_fov_names: - - unique_track_ids = features[features["fov_name"] == fov_name]["track_id"].unique() - unique_track_ids = list(set(unique_track_ids)) - - for track_id in unique_track_ids: - - # load the image patches - - prediction_dataset = dataset_of_tracks( - data_path, - tracks_path, - [fov_name], - [track_id], - source_channel=source_channel, - ) - - whole = np.stack([p["anchor"] for p in prediction_dataset]) - phase = whole[:, 0, 3] - - for t in range(phase.shape[0]): - # Compute Fourier descriptors for phase image - phase_descriptors = FE.compute_fourier_descriptors(phase[t]) - # Analyze symmetry of phase image - phase_symmetry_score = FE.analyze_symmetry(phase_descriptors) - - # Compute higher frequency features using spectral entropy - entropy_phase = FE.compute_spectral_entropy(phase[t]) - - # Compute texture analysis using GLCM - contrast_phase, dissimilarity_phase, homogeneity_phase = ( - FE.compute_glcm_features(phase[t]) - ) - - # Compute interqualtile range of pixel intensities - iqr = FE.compute_iqr(phase[t]) - - # Compute standard deviation of pixel intensities - phase_std_dev = FE.compute_std_dev(phase[t]) - - # Compute radial intensity gradient - phase_radial_profile = FE.compute_radial_intensity_gradient(phase[t]) - - # update the features dataframe with the computed features - - features.loc[ - (features["fov_name"] == fov_name) - & (features["track_id"] == track_id) - & (features["t"] == t), - "Phase Symmetry Score", - ] = phase_symmetry_score - features.loc[ - (features["fov_name"] == fov_name) - & (features["track_id"] == track_id) - & (features["t"] == t), - "Entropy Phase", - ] = entropy_phase - features.loc[ - (features["fov_name"] == fov_name) - & (features["track_id"] == track_id) - & (features["t"] == t), - "Contrast Phase", - ] = contrast_phase - features.loc[ - (features["fov_name"] == fov_name) - & (features["track_id"] == track_id) - & (features["t"] == t), - "Dissimilarity Phase", - ] = dissimilarity_phase - features.loc[ - (features["fov_name"] == fov_name) - & (features["track_id"] == track_id) - & (features["t"] == t), - "Homogeneity Phase", - ] = homogeneity_phase - features.loc[ - (features["fov_name"] == fov_name) - & (features["track_id"] == track_id) - & (features["t"] == t), - "Phase IQR", - ] = iqr - features.loc[ - (features["fov_name"] == fov_name) - & (features["track_id"] == track_id) - & (features["t"] == t), - "Phase Standard Deviation", - ] = phase_std_dev - features.loc[ - (features["fov_name"] == fov_name) - & (features["track_id"] == track_id) - & (features["t"] == t), - "Phase radial profile", - ] = phase_radial_profile - -# %% -# Save the features dataframe to a CSV file -features.to_csv( - 
"/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/Figure_panels/cell_division/features_oneChan.csv", - index=False, -) - -# read the csv file -# features = pd.read_csv( -# "/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/Figure_panels/cell_division/features_oneChan.csv" -# ) - -# remove the rows with missing values -features = features.dropna() - -# sub_features = features[features["Time"] == 20] -feature_df_removed = features.drop( - columns=["fov_name", "track_id", "t", "id", "parent_track_id", "parent_id"] -) - -# Compute correlation between PCA features and computed features -correlation = feature_df_removed.corr(method="spearman") - -# %% calculate the p-value and draw volcano plot to show the significance of the correlation - -p_values = pd.DataFrame(index=correlation.index, columns=correlation.columns) - -for i in correlation.index: - for j in correlation.columns: - if i != j: - p_values.loc[i, j] = spearmanr( - feature_df_removed[i], feature_df_removed[j] - )[1] - -p_values = p_values.astype(float) - -# %% draw an interactive volcano plot showing -log10(p-value) vs fold change - -# Flatten the correlation and p-values matrices and create a DataFrame -correlation_flat = correlation.values.flatten() -p_values_flat = p_values.values.flatten() -# Create a list of feature names for the flattened correlation and p-values -feature_names = [f"{i}_{j}" for i in correlation.index for j in correlation.columns] - -data = pd.DataFrame( - { - "Correlation": correlation_flat, - "-log10(p-value)": -np.log10(p_values_flat), - "feature_names": feature_names, - } -) - -# Create an interactive scatter plot using Plotly -fig = px.scatter( - data, - x="Correlation", - y="-log10(p-value)", - title="Volcano plot showing significance of correlation", - labels={"Correlation": "Correlation", "-log10(p-value)": "-log10(p-value)"}, - opacity=0.5, - hover_data=["feature_names"], -) - -fig.show() -# Save the interactive volcano plot as an HTML file -fig.write_html( - "/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/Figure_panels/cell_division/volcano_plot_1chan.html" -) - -# %% diff --git a/applications/infection_classification/Infection_classification_25DModel.py b/applications/infection_classification/Infection_classification_25DModel.py deleted file mode 100644 index a4e712f5b..000000000 --- a/applications/infection_classification/Infection_classification_25DModel.py +++ /dev/null @@ -1,106 +0,0 @@ -# %% -import lightning.pytorch as pl -import torch -import torch.nn as nn -from applications.infection_classification.classify_infection_25D import ( - SemanticSegUNet25D, -) -from pytorch_lightning.callbacks import ModelCheckpoint -from pytorch_lightning.loggers import TensorBoardLogger - -from viscy.data.hcs import HCSDataModule -from viscy.preprocessing.pixel_ratio import sematic_class_weights -from viscy.transforms import NormalizeSampled, RandWeightedCropd - -# %% Create a dataloader and visualize the batches. 
- -# Set the path to the dataset -dataset_path = "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/datasets/Exp_2023_11_08_Opencell_infection/OC43_infection_timelapse_trainVal.zarr" - -# %% create data module - -# Create an instance of HCSDataModule -data_module = HCSDataModule( - dataset_path, - source_channel=["Phase", "HSP90"], - target_channel=["Inf_mask"], - yx_patch_size=[512, 512], - split_ratio=0.8, - z_window_size=5, - architecture="2.5D", - num_workers=3, - batch_size=32, - normalizations=[ - NormalizeSampled( - keys=["Phase", "HSP90"], - level="fov_statistics", - subtrahend="median", - divisor="iqr", - ) - ], - augmentations=[ - RandWeightedCropd( - num_samples=4, - spatial_size=[-1, 512, 512], - keys=["Phase", "HSP90"], - w_key="Inf_mask", - ) - ], -) - -pixel_ratio = sematic_class_weights(dataset_path, target_channel="Inf_mask") - -# Prepare the data -data_module.prepare_data() - -# Setup the data -data_module.setup(stage="fit") - -# Create a dataloader -train_dm = data_module.train_dataloader() - -val_dm = data_module.val_dataloader() - - -# %% Define the logger -logger = TensorBoardLogger( - "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/mantis_phase_hsp90/", - name="logs", -) - -# Pass the logger to the Trainer -trainer = pl.Trainer( - logger=logger, - max_epochs=200, - default_root_dir="/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/mantis_phase_hsp90/logs/", - log_every_n_steps=1, - devices=1, # Set the number of GPUs to use. This avoids run-time exception from distributed training when the node has multiple GPUs -) - -# Define the checkpoint callback -checkpoint_callback = ModelCheckpoint( - dirpath="/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/mantis_phase_hsp90/logs/", - filename="checkpoint_{epoch:02d}", - save_top_k=-1, - verbose=True, - monitor="loss/validate", - mode="min", -) - -# Add the checkpoint callback to the trainer -trainer.callbacks.append(checkpoint_callback) - -# Fit the model -model = SemanticSegUNet25D( - in_channels=2, - out_channels=3, - loss_function=nn.CrossEntropyLoss(weight=torch.tensor(pixel_ratio)), -) - -print(model) - -# %% Run training. - -trainer.fit(model, data_module) - -# %% diff --git a/applications/infection_classification/Infection_classification_covnextModel.py b/applications/infection_classification/Infection_classification_covnextModel.py deleted file mode 100644 index bfe203625..000000000 --- a/applications/infection_classification/Infection_classification_covnextModel.py +++ /dev/null @@ -1,107 +0,0 @@ -# %% -# import sys -# sys.path.append("/hpc/mydata/soorya.pradeep/viscy_infection_phenotyping/Viscy/") -import lightning.pytorch as pl -import torch -import torch.nn as nn -from applications.infection_classification.classify_infection_covnext import ( - SemanticSegUNet22D, -) -from pytorch_lightning.callbacks import ModelCheckpoint -from pytorch_lightning.loggers import TensorBoardLogger - -from viscy.data.hcs import HCSDataModule -from viscy.preprocessing.pixel_ratio import sematic_class_weights -from viscy.transforms import NormalizeSampled, RandWeightedCropd - -# %% Create a dataloader and visualize the batches. 
- -# Set the path to the dataset -dataset_path = "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/datasets/Exp_2023_11_08_Opencell_infection/OC43_infection_timelapse_all_curated_train.zarr" - -# %% craete data module - -# Create an instance of HCSDataModule -data_module = HCSDataModule( - dataset_path, - source_channel=["Phase", "HSP90", "phase_nucl_iqr", "hsp90_skew"], - target_channel=["Inf_mask"], - yx_patch_size=[256, 256], - split_ratio=0.8, - z_window_size=5, - architecture="2.2D", - num_workers=3, - batch_size=16, - normalizations=[ - NormalizeSampled( - keys=["Phase", "HSP90", "phase_nucl_iqr", "hsp90_skew"], - level="fov_statistics", - subtrahend="median", - divisor="iqr", - ) - ], - augmentations=[ - RandWeightedCropd( - num_samples=4, - spatial_size=[-1, 256, 256], - keys=["Phase", "HSP90", "phase_nucl_iqr", "hsp90_skew"], - w_key="Inf_mask", - ) - ], -) -pixel_ratio = sematic_class_weights(dataset_path, target_channel="Inf_mask") - -# Prepare the data -data_module.prepare_data() - -# Setup the data -data_module.setup(stage="fit") - -# Create a dataloader -train_dm = data_module.train_dataloader() - -val_dm = data_module.val_dataloader() - - -# %% Define the logger -logger = TensorBoardLogger( - "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/mantis_phase_hsp90/", - name="logs", -) - -# Pass the logger to the Trainer -trainer = pl.Trainer( - logger=logger, - max_epochs=200, - default_root_dir="/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/mantis_phase_hsp90/logs/", - log_every_n_steps=1, - devices=1, # Set the number of GPUs to use. This avoids run-time exception from distributed training when the node has multiple GPUs -) - -# Define the checkpoint callback -checkpoint_callback = ModelCheckpoint( - dirpath="/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/mantis_phase_hsp90/logs/", - filename="checkpoint_{epoch:02d}", - save_top_k=-1, - verbose=True, - monitor="loss/validate", - mode="min", -) - -# Add the checkpoint callback to the trainer`` -trainer.callbacks.append(checkpoint_callback) - -# Fit the model -model = SemanticSegUNet22D( - in_channels=4, - out_channels=3, - loss_function=nn.CrossEntropyLoss(weight=torch.tensor(pixel_ratio)), -) - -print(model) - -# %% Run training. - -trainer.fit(model, data_module) - -# %% diff --git a/applications/infection_classification/classify_infection_25D.py b/applications/infection_classification/classify_infection_25D.py deleted file mode 100644 index e16f56f42..000000000 --- a/applications/infection_classification/classify_infection_25D.py +++ /dev/null @@ -1,356 +0,0 @@ -# import torchview -from typing import Literal, Sequence - -import cv2 -import lightning.pytorch as pl -import matplotlib.pyplot as plt -import numpy as np -import torch -import torch.nn as nn -import torch.nn.functional as F -from matplotlib.cm import get_cmap -from monai.transforms import DivisiblePad -from skimage.exposure import rescale_intensity -from skimage.measure import label, regionprops -from torch import Tensor - -from viscy.data.hcs import Sample -from viscy.unet.networks.Unet25D import Unet25d - -# %% Methods to compute confusion matrix per cell using torchmetrics - - -# The confusion matrix at the single-cell resolution. 
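# [Editor's note] Illustrative sketch, not part of the original patch: the
# "per-cell" confusion matrix below is built by labelling nuclei in the ground
# truth and comparing the class found at each nucleus centroid in the
# prediction against the truth. A toy 2D version of that idea, assuming integer
# class maps with classes {1, 2} (the index convention here is simplified and
# may differ from the function below):
import numpy as np
from skimage.measure import label, regionprops

y_true = np.array([[1, 1, 0], [0, 0, 0], [0, 2, 2]])
y_pred = np.array([[1, 1, 0], [0, 0, 0], [0, 1, 1]])
conf = np.zeros((2, 2), dtype=int)
for region in regionprops(label(y_true > 0)):
    r, c = (int(x) for x in region.centroid)
    # one vote per nucleus: row = true class, column = predicted class
    conf[y_true[r, c] - 1, y_pred[r, c] - 1] += 1
print(conf)  # one nucleus correct (true 1 -> pred 1), one wrong (true 2 -> pred 1)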
-def confusion_matrix_per_cell( - y_true: torch.Tensor, y_pred: torch.Tensor, num_classes: int -): - """Compute confusion matrix per cell. - - Args: - y_true (torch.Tensor): Ground truth label image (BXHXW). - y_pred (torch.Tensor): Predicted label image (BXHXW). - num_classes (int): Number of classes. - - Returns: - torch.Tensor: Confusion matrix per cell (BXCXC). - """ - # Convert the image class to the nuclei class - confusion_matrix_per_cell = compute_confusion_matrix(y_true, y_pred, num_classes) - confusion_matrix_per_cell = torch.tensor(confusion_matrix_per_cell) - return confusion_matrix_per_cell - - -# These images can be logged with prediction. -def compute_confusion_matrix( - y_true: torch.Tensor, y_pred: torch.Tensor, num_classes: int -): - """Convert the class of the image to the class of the nuclei. - - Args: - label_image (torch.Tensor): Label image (BXHXW). Values of tensor are integers that represent semantic segmentation. - num_classes (int): Number of classes. - - Returns: - torch.Tensor: Label images with a consensus class at the centroid of nuclei. - """ - - batch_size = y_true.size(0) - # find centroids of nuclei from y_true - conf_mat = np.zeros((num_classes, num_classes)) - for i in range(batch_size): - y_true_cpu = y_true[i].cpu().numpy() - y_pred_cpu = y_pred[i].cpu().numpy() - y_true_reshaped = y_true_cpu.reshape(y_true_cpu.shape[-2:]) - y_pred_reshaped = y_pred_cpu.reshape(y_pred_cpu.shape[-2:]) - y_pred_resized = cv2.resize( - y_pred_reshaped, - dsize=y_true_reshaped.shape[::-1], - interpolation=cv2.INTER_NEAREST, - ) - y_pred_resized = np.where(y_true_reshaped > 0, y_pred_resized, 0) - - # find objects in every image - label_img = label(y_true_reshaped) - regions = regionprops(label_img) - - # Find centroids, pixel coordinates from the ground truth. - for region in regions: - if region.area > 0: - row, col = region.centroid - pred_id = y_pred_resized[int(row), int(col)] - test_id = y_true_reshaped[int(row), int(col)] - - if pred_id == 1 and test_id == 1: - conf_mat[1, 1] += 1 - if pred_id == 1 and test_id == 2: - conf_mat[0, 1] += 1 - if pred_id == 2 and test_id == 1: - conf_mat[1, 0] += 1 - if pred_id == 2 and test_id == 2: - conf_mat[0, 0] += 1 - # Find all instances of nuclei in ground truth and compute the class of the nuclei in both ground truth and prediction. - # Find all instances of nuclei in ground truth and compute the class of the nuclei in both ground truth and prediction. 
- return conf_mat - - -def plot_confusion_matrix(confusion_matrix, index_to_label_dict): - # Create a figure and axis to plot the confusion matrix - fig, ax = plt.subplots() - - # Create a color heatmap for the confusion matrix - cax = ax.matshow(confusion_matrix, cmap="viridis") - - # Create a colorbar and set the label - index_to_label_dict = dict( - enumerate(index_to_label_dict) - ) # Convert list to dictionary - fig.colorbar(cax, label="Frequency") - - # Set labels for the classes - ax.set_xticks(np.arange(len(index_to_label_dict))) - ax.set_yticks(np.arange(len(index_to_label_dict))) - ax.set_xticklabels(index_to_label_dict.values(), rotation=45) - ax.set_yticklabels(index_to_label_dict.values()) - - # Set labels for the axes - ax.set_xlabel("Predicted") - ax.set_ylabel("True") - - # Add text annotations to the confusion matrix - for i in range(len(index_to_label_dict)): - for j in range(len(index_to_label_dict)): - ax.text( - j, - i, - str(int(confusion_matrix[i, j])), - ha="center", - va="center", - color="white", - ) - - # plt.show(fig) # Show the figure - return fig - - -# Define a 25d unet model for infection classification as a lightning module. - - -class SemanticSegUNet25D(pl.LightningModule): - # Model for semantic segmentation. - def __init__( - self, - in_channels: int, # Number of input channels - out_channels: int, # Number of output channels - lr: float = 1e-3, # Learning rate - loss_function: nn.Module = nn.CrossEntropyLoss(), # Loss function - schedule: Literal[ - "WarmupCosine", "Constant" - ] = "Constant", # Learning rate schedule - log_batches_per_epoch: int = 2, # Number of batches to log per epoch - log_samples_per_batch: int = 2, # Number of samples to log per batch - ckpt_path: str = None, # Path to the checkpoint - ): - super(SemanticSegUNet25D, self).__init__() # Call the superclass initializer - # Initialize the UNet model - self.unet_model = Unet25d( - in_channels=in_channels, - out_channels=out_channels, - num_blocks=4, - num_block_layers=4, - ) - if ckpt_path is not None: - state_dict = torch.load(ckpt_path, map_location=torch.device("cpu"))[ - "state_dict" - ] - state_dict.pop("loss_function.weight", None) # Remove the unexpected key - self.load_state_dict(state_dict) # loading only weights - self.lr = lr # Set the learning rate - # Set the loss function to CrossEntropyLoss if none is provided - self.loss_function = loss_function if loss_function else nn.CrossEntropyLoss() - self.schedule = schedule # Set the learning rate schedule - self.log_batches_per_epoch = ( - log_batches_per_epoch # Set the number of batches to log per epoch - ) - self.log_samples_per_batch = ( - log_samples_per_batch # Set the number of samples to log per batch - ) - self.training_step_outputs = [] # Initialize the list of training step outputs - self.validation_step_outputs = ( - [] - ) # Initialize the list of validation step outputs - - self.pred_cm = None # Initialize the confusion matrix - self.index_to_label_dict = ["Infected", "Uninfected"] - - # Define the forward pass - def forward(self, x): - return self.unet_model(x) # Pass the input through the UNet model - - # Define the optimizer - def configure_optimizers(self): - optimizer = torch.optim.Adam( - self.parameters(), lr=self.lr - ) # Use the Adam optimizer - return optimizer - - # Define the training step - def training_step(self, batch: Sample, batch_idx: int): - source = batch["source"] # Extract the source from the batch - target = batch["target"] # Extract the target from the batch - pred = self.forward(source) # 
Make a prediction using the source - # Convert the target to one-hot encoding - target_one_hot = F.one_hot(target.squeeze(1).long(), num_classes=3).permute( - 0, 4, 1, 2, 3 - ) - target_one_hot = target_one_hot.float() # Convert the target to float type - train_loss = self.loss_function(pred, target_one_hot) # Calculate the loss - # Log the training step outputs if the batch index is less than the number of batches to log per epoch - if batch_idx < self.log_batches_per_epoch: - self.training_step_outputs.extend( - self._detach_sample((source, target_one_hot, pred)) - ) - # Log the training loss - self.log( - "loss/train", - train_loss, - on_step=True, - on_epoch=True, - prog_bar=True, - logger=True, - sync_dist=True, - ) - return train_loss # Return the training loss - - def validation_step(self, batch: Sample, batch_idx: int): - source = batch["source"] # Extract the source from the batch - target = batch["target"] # Extract the target from the batch - pred = self.forward(source) # Make a prediction using the source - # Convert the target to one-hot encoding - target_one_hot = F.one_hot(target.squeeze(1).long(), num_classes=3).permute( - 0, 4, 1, 2, 3 - ) - target_one_hot = target_one_hot.float() # Convert the target to float type - loss = self.loss_function(pred, target_one_hot) # Calculate the loss - # Log the validation step outputs if the batch index is less than the number of batches to log per epoch - if batch_idx < self.log_batches_per_epoch: - self.validation_step_outputs.extend( - self._detach_sample((source, target_one_hot, pred)) - ) - # Log the validation loss - self.log( - "loss/validate", loss, sync_dist=True, add_dataloader_idx=False, logger=True - ) - return loss # Return the validation loss - - def on_predict_start(self): - """Pad the input shape to be divisible by the downsampling factor. - The inverse of this transform crops the prediction to original shape. - """ - down_factor = 2**self.unet_model.num_blocks - self._predict_pad = DivisiblePad((0, 0, down_factor, down_factor)) - - # Define the prediction step - def predict_step(self, batch: Sample, batch_idx: int, dataloader_idx: int = 0): - source = self._predict_pad(batch["source"]) # Pad the source - logits = self._predict_pad.inverse( - self.forward(source) - ) # Predict and remove padding. - prob_pred = F.softmax(logits, dim=1) # Calculate the probabilities - # Go from probabilities/one-hot encoded data to class labels. 
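        # [Editor's note] e.g. with out_channels=3, prob_pred has shape
        # (B, 3, Z, Y, X); argmax over dim=1 with keepdim=True yields integer
        # class labels of shape (B, 1, Z, Y, X).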
- labels_pred = torch.argmax( - prob_pred, dim=1, keepdim=True - ) # Calculate the predicted labels - # prob_chan = prob_pred[:, 2, :, :] - # prob_chan = prob_chan.unsqueeze(1) - return labels_pred # log the class predicted image - # return prob_chan # log the probability predicted image - - def on_test_start(self): - self.pred_cm = torch.zeros((2, 2)) - down_factor = 2**self.unet_model.num_blocks - self._predict_pad = DivisiblePad((0, 0, down_factor, down_factor)) - - def test_step(self, batch: Sample): - source = self._predict_pad(batch["source"]) # Pad the source - logits = self._predict_pad.inverse(self.forward(source)) - prob_pred = F.softmax(logits, dim=1) # Calculate the probabilities - labels_pred = torch.argmax( - prob_pred, dim=1, keepdim=True - ) # Calculate the predicted labels - - target = self._predict_pad(batch["target"]) # Extract the target from the batch - pred_cm = confusion_matrix_per_cell( - target, labels_pred, num_classes=2 - ) # Calculate the confusion matrix per cell - self.pred_cm += pred_cm # Append the confusion matrix to pred_cm - - self.logger.experiment.add_figure( - "Confusion Matrix per Cell", - plot_confusion_matrix(pred_cm, self.index_to_label_dict), - self.current_epoch, - ) - - # Accumulate the confusion matrix at the end of test epoch and log. - def on_test_end(self): - confusion_matrix_sum = self.pred_cm - self.logger.experiment.add_figure( - "Confusion Matrix", - plot_confusion_matrix(confusion_matrix_sum, self.index_to_label_dict), - self.current_epoch, - ) - - # Define what happens at the end of a training epoch - def on_train_epoch_end(self): - self._log_samples( - "train_samples", self.training_step_outputs - ) # Log the training samples - self.training_step_outputs = [] # Reset the list of training step outputs - - # Define what happens at the end of a validation epoch - def on_validation_epoch_end(self): - self._log_samples( - "val_samples", self.validation_step_outputs - ) # Log the validation samples - self.validation_step_outputs = [] # Reset the list of validation step outputs - - # Define a method to detach a sample - def _detach_sample(self, imgs: Sequence[Tensor]): - # Detach the images and convert them to numpy arrays - num_samples = 3 - return [ - [np.squeeze(img[i].detach().cpu().numpy().max(axis=1)) for img in imgs] - for i in range(num_samples) - ] - - # Define a method to log samples - def _log_samples(self, key: str, imgs: Sequence[Sequence[np.ndarray]]): - images_grid = [] # Initialize the list of image grids - for sample_images in imgs: # For each sample image - images_row = [] # Initialize the list of image rows - for i, image in enumerate( - sample_images - ): # For each image in the sample images - cm_name = "gray" if i == 0 else "inferno" # Set the colormap name - if image.ndim == 2: # If the image is 2D - image = image[np.newaxis] # Add a new axis - for channel in image: # For each channel in the image - channel = rescale_intensity( - channel, out_range=(0, 1) - ) # Rescale the intensity of the channel - render = get_cmap(cm_name)(channel, bytes=True)[ - ..., :3 - ] # Render the channel - images_row.append( - render - ) # Append the render to the list of image rows - images_grid.append( - np.concatenate(images_row, axis=1) - ) # Append the concatenated image rows to the list of image grids - grid = np.concatenate(images_grid, axis=0) # Concatenate the image grids - # Log the image grid - self.logger.experiment.add_image( - key, grid, self.current_epoch, dataformats="HWC" - ) - - -# %% diff --git 
a/applications/infection_classification/classify_infection_covnext.py b/applications/infection_classification/classify_infection_covnext.py deleted file mode 100644 index 397e822db..000000000 --- a/applications/infection_classification/classify_infection_covnext.py +++ /dev/null @@ -1,363 +0,0 @@ -# import torchview -from typing import Literal, Sequence - -import cv2 -import lightning.pytorch as pl -import matplotlib.pyplot as plt -import numpy as np -import torch -import torch.nn as nn -import torch.nn.functional as F -from matplotlib.cm import get_cmap -from monai.transforms import DivisiblePad -from skimage.exposure import rescale_intensity -from skimage.measure import label, regionprops -from torch import Tensor - -from viscy.data.hcs import Sample -from viscy.translation.engine import VSUNet - -# -# %% Methods to compute confusion matrix per cell using torchmetrics - - -# The confusion matrix at the single-cell resolution. -def confusion_matrix_per_cell( - y_true: torch.Tensor, y_pred: torch.Tensor, num_classes: int -): - """Compute confusion matrix per cell. - - Args: - y_true (torch.Tensor): Ground truth label image (BXHXW). - y_pred (torch.Tensor): Predicted label image (BXHXW). - num_classes (int): Number of classes. - - Returns: - torch.Tensor: Confusion matrix per cell (BXCXC). - """ - # Convert the image class to the nuclei class - confusion_matrix_per_cell = compute_confusion_matrix(y_true, y_pred, num_classes) - confusion_matrix_per_cell = torch.tensor(confusion_matrix_per_cell) - return confusion_matrix_per_cell - - -# These images can be logged with prediction. -def compute_confusion_matrix( - y_true: torch.Tensor, y_pred: torch.Tensor, num_classes: int -): - """Convert the class of the image to the class of the nuclei. - - Args: - label_image (torch.Tensor): Label image (BXHXW). Values of tensor are integers that represent semantic segmentation. - num_classes (int): Number of classes. - - Returns: - torch.Tensor: Label images with a consensus class at the centroid of nuclei. - """ - - batch_size = y_true.size(0) - # find centroids of nuclei from y_true - conf_mat = np.zeros((num_classes, num_classes)) - for i in range(batch_size): - y_true_cpu = y_true[i].cpu().numpy() - y_pred_cpu = y_pred[i].cpu().numpy() - y_true_reshaped = y_true_cpu.reshape(y_true_cpu.shape[-2:]) - y_pred_reshaped = y_pred_cpu.reshape(y_pred_cpu.shape[-2:]) - y_pred_resized = cv2.resize( - y_pred_reshaped, - dsize=y_true_reshaped.shape[::-1], - interpolation=cv2.INTER_NEAREST, - ) - y_pred_resized = np.where(y_true_reshaped > 0, y_pred_resized, 0) - - # find objects in every image - label_img = label(y_true_reshaped) - regions = regionprops(label_img) - - # Find centroids, pixel coordinates from the ground truth. - for region in regions: - if region.area > 0: - row, col = region.centroid - pred_id = y_pred_resized[int(row), int(col)] - test_id = y_true_reshaped[int(row), int(col)] - - if pred_id == 1 and test_id == 1: - conf_mat[1, 1] += 1 - if pred_id == 1 and test_id == 2: - conf_mat[0, 1] += 1 - if pred_id == 2 and test_id == 1: - conf_mat[1, 0] += 1 - if pred_id == 2 and test_id == 2: - conf_mat[0, 0] += 1 - # Find all instances of nuclei in ground truth and compute the class of the nuclei in both ground truth and prediction. - # Find all instances of nuclei in ground truth and compute the class of the nuclei in both ground truth and prediction. 
- return conf_mat - - -def plot_confusion_matrix(confusion_matrix, index_to_label_dict): - # Create a figure and axis to plot the confusion matrix - fig, ax = plt.subplots() - - # Create a color heatmap for the confusion matrix - cax = ax.matshow(confusion_matrix, cmap="viridis") - - # Create a colorbar and set the label - index_to_label_dict = dict( - enumerate(index_to_label_dict) - ) # Convert list to dictionary - fig.colorbar(cax, label="Frequency") - - # Set labels for the classes - ax.set_xticks(np.arange(len(index_to_label_dict))) - ax.set_yticks(np.arange(len(index_to_label_dict))) - ax.set_xticklabels(index_to_label_dict.values(), rotation=45) - ax.set_yticklabels(index_to_label_dict.values()) - - # Set labels for the axes - ax.set_xlabel("Predicted") - ax.set_ylabel("True") - - # Add text annotations to the confusion matrix - for i in range(len(index_to_label_dict)): - for j in range(len(index_to_label_dict)): - ax.text( - j, - i, - str(int(confusion_matrix[i, j])), - ha="center", - va="center", - color="white", - ) - - # plt.show(fig) # Show the figure - return fig - - -# Define a 25d unet model for infection classification as a lightning module. - - -class SemanticSegUNet22D(pl.LightningModule): - # Model for semantic segmentation. - def __init__( - self, - in_channels: int, # Number of input channels - out_channels: int, # Number of output channels - lr: float = 1e-3, # Learning rate - loss_function: nn.Module = nn.CrossEntropyLoss(), # Loss function - schedule: Literal[ - "WarmupCosine", "Constant" - ] = "Constant", # Learning rate schedule - log_batches_per_epoch: int = 2, # Number of batches to log per epoch - log_samples_per_batch: int = 2, # Number of samples to log per batch - ckpt_path: str = None, # Path to the checkpoint - ): - super(SemanticSegUNet22D, self).__init__() # Call the superclass initializer - # Initialize the UNet model - self.unet_model = VSUNet( - architecture="2.2D", - model_config={ - "in_channels": in_channels, - "out_channels": out_channels, - "in_stack_depth": 5, - "backbone": "convnextv2_tiny", - "stem_kernel_size": (5, 4, 4), - "decoder_mode": "pixelshuffle", - "head_expansion_ratio": 4, - }, - ) - if ckpt_path is not None: - state_dict = torch.load(ckpt_path, map_location=torch.device("cpu"))[ - "state_dict" - ] - state_dict.pop("loss_function.weight", None) # Remove the unexpected key - self.load_state_dict(state_dict) # loading only weights - self.lr = lr # Set the learning rate - # Set the loss function to CrossEntropyLoss if none is provided - self.loss_function = loss_function if loss_function else nn.CrossEntropyLoss() - self.schedule = schedule # Set the learning rate schedule - self.log_batches_per_epoch = ( - log_batches_per_epoch # Set the number of batches to log per epoch - ) - self.log_samples_per_batch = ( - log_samples_per_batch # Set the number of samples to log per batch - ) - self.training_step_outputs = [] # Initialize the list of training step outputs - self.validation_step_outputs = ( - [] - ) # Initialize the list of validation step outputs - - self.pred_cm = None # Initialize the confusion matrix - self.index_to_label_dict = ["Infected", "Uninfected"] - - # Define the forward pass - def forward(self, x): - return self.unet_model(x) # Pass the input through the UNet model - - # Define the optimizer - def configure_optimizers(self): - optimizer = torch.optim.Adam( - self.parameters(), lr=self.lr - ) # Use the Adam optimizer - return optimizer - - # Define the training step - def training_step(self, batch: Sample, batch_idx: 
int): - source = batch["source"] # Extract the source from the batch - target = batch["target"] # Extract the target from the batch - pred = self.forward(source) # Make a prediction using the source - # Convert the target to one-hot encoding - target_one_hot = F.one_hot(target.squeeze(1).long(), num_classes=3).permute( - 0, 4, 1, 2, 3 - ) - target_one_hot = target_one_hot.float() # Convert the target to float type - train_loss = self.loss_function(pred, target_one_hot) # Calculate the loss - # Log the training step outputs if the batch index is less than the number of batches to log per epoch - if batch_idx < self.log_batches_per_epoch: - self.training_step_outputs.extend( - self._detach_sample((source, target_one_hot, pred)) - ) - # Log the training loss - self.log( - "loss/train", - train_loss, - on_step=True, - on_epoch=True, - prog_bar=True, - logger=True, - sync_dist=True, - ) - return train_loss # Return the training loss - - def validation_step(self, batch: Sample, batch_idx: int): - source = batch["source"] # Extract the source from the batch - target = batch["target"] # Extract the target from the batch - pred = self.forward(source) # Make a prediction using the source - # Convert the target to one-hot encoding - target_one_hot = F.one_hot(target.squeeze(1).long(), num_classes=3).permute( - 0, 4, 1, 2, 3 - ) - target_one_hot = target_one_hot.float() # Convert the target to float type - loss = self.loss_function(pred, target_one_hot) # Calculate the loss - # Log the validation step outputs if the batch index is less than the number of batches to log per epoch - if batch_idx < self.log_batches_per_epoch: - self.validation_step_outputs.extend( - self._detach_sample((source, target_one_hot, pred)) - ) - # Log the validation loss - self.log( - "loss/validate", loss, sync_dist=True, add_dataloader_idx=False, logger=True - ) - return loss # Return the validation loss - - def on_predict_start(self): - """Pad the input shape to be divisible by the downsampling factor. - The inverse of this transform crops the prediction to original shape. - """ - down_factor = 2**self.unet_model.num_blocks - self._predict_pad = DivisiblePad((0, 0, down_factor, down_factor)) - - # Define the prediction step - def predict_step(self, batch: Sample, batch_idx: int, dataloader_idx: int = 0): - source = self._predict_pad(batch["source"]) # Pad the source - logits = self._predict_pad.inverse( - self.forward(source) - ) # Predict and remove padding. - prob_pred = F.softmax(logits, dim=1) # Calculate the probabilities - # Go from probabilities/one-hot encoded data to class labels. 
- labels_pred = torch.argmax( - prob_pred, dim=1, keepdim=True - ) # Calculate the predicted labels - # prob_chan = prob_pred[:, 2, :, :] - # prob_chan = prob_chan.unsqueeze(1) - return labels_pred # log the class predicted image - # return prob_chan # log the probability predicted image - - def on_test_start(self): - self.pred_cm = torch.zeros((2, 2)) - down_factor = 2**self.unet_model.num_blocks - self._predict_pad = DivisiblePad((0, 0, down_factor, down_factor)) - - def test_step(self, batch: Sample): - source = self._predict_pad(batch["source"]) # Pad the source - logits = self._predict_pad.inverse(self.forward(source)) - prob_pred = F.softmax(logits, dim=1) # Calculate the probabilities - labels_pred = torch.argmax( - prob_pred, dim=1, keepdim=True - ) # Calculate the predicted labels - - target = self._predict_pad(batch["target"]) # Extract the target from the batch - pred_cm = confusion_matrix_per_cell( - target, labels_pred, num_classes=2 - ) # Calculate the confusion matrix per cell - self.pred_cm += pred_cm # Append the confusion matrix to pred_cm - - self.logger.experiment.add_figure( - "Confusion Matrix per Cell", - plot_confusion_matrix(pred_cm, self.index_to_label_dict), - self.current_epoch, - ) - - # Accumulate the confusion matrix at the end of test epoch and log. - def on_test_end(self): - confusion_matrix_sum = self.pred_cm - self.logger.experiment.add_figure( - "Confusion Matrix", - plot_confusion_matrix(confusion_matrix_sum, self.index_to_label_dict), - self.current_epoch, - ) - - # Define what happens at the end of a training epoch - def on_train_epoch_end(self): - self._log_samples( - "train_samples", self.training_step_outputs - ) # Log the training samples - self.training_step_outputs = [] # Reset the list of training step outputs - - # Define what happens at the end of a validation epoch - def on_validation_epoch_end(self): - self._log_samples( - "val_samples", self.validation_step_outputs - ) # Log the validation samples - self.validation_step_outputs = [] # Reset the list of validation step outputs - - # Define a method to detach a sample - def _detach_sample(self, imgs: Sequence[Tensor]): - # Detach the images and convert them to numpy arrays - num_samples = 3 - return [ - [np.squeeze(img[i].detach().cpu().numpy().max(axis=1)) for img in imgs] - for i in range(num_samples) - ] - - # Define a method to log samples - def _log_samples(self, key: str, imgs: Sequence[Sequence[np.ndarray]]): - images_grid = [] # Initialize the list of image grids - for sample_images in imgs: # For each sample image - images_row = [] # Initialize the list of image rows - for i, image in enumerate( - sample_images - ): # For each image in the sample images - cm_name = "gray" if i == 0 else "inferno" # Set the colormap name - if image.ndim == 2: # If the image is 2D - image = image[np.newaxis] # Add a new axis - for channel in image: # For each channel in the image - channel = rescale_intensity( - channel, out_range=(0, 1) - ) # Rescale the intensity of the channel - render = get_cmap(cm_name)(channel, bytes=True)[ - ..., :3 - ] # Render the channel - images_row.append( - render - ) # Append the render to the list of image rows - images_grid.append( - np.concatenate(images_row, axis=1) - ) # Append the concatenated image rows to the list of image grids - grid = np.concatenate(images_grid, axis=0) # Concatenate the image grids - # Log the image grid - self.logger.experiment.add_image( - key, grid, self.current_epoch, dataformats="HWC" - ) - - -# %% From 
c2d1b3e6d7afd713e108dea0080285eea7c24fb6 Mon Sep 17 00:00:00 2001 From: Soorya Pradeep Date: Thu, 9 Jan 2025 15:27:43 -0800 Subject: [PATCH 12/38] deleted duplicate script --- .../evaluation/pca_umap_embeddings_time.py | 220 ------------------ 1 file changed, 220 deletions(-) delete mode 100644 applications/contrastive_phenotyping/evaluation/pca_umap_embeddings_time.py diff --git a/applications/contrastive_phenotyping/evaluation/pca_umap_embeddings_time.py b/applications/contrastive_phenotyping/evaluation/pca_umap_embeddings_time.py deleted file mode 100644 index 5f59da3e0..000000000 --- a/applications/contrastive_phenotyping/evaluation/pca_umap_embeddings_time.py +++ /dev/null @@ -1,220 +0,0 @@ -# %% -from pathlib import Path - -import matplotlib.pyplot as plt -import seaborn as sns -from sklearn.decomposition import PCA -from sklearn.preprocessing import StandardScaler -from umap import UMAP - -from viscy.representation.embedding_writer import read_embedding_dataset -from viscy.representation.evaluation import load_annotation - -# %% Paths and parameters. - - -features_path = Path( - "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/time_interval/predict/feb_test_time_interval_1_epoch_178.zarr" -) -data_path = Path( - "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/registered_test.zarr" -) -tracks_path = Path( - "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/track_test.zarr" -) - - -# %% -embedding_dataset = read_embedding_dataset(features_path) -embedding_dataset - - -# %% -# Compute UMAP over all features -features = embedding_dataset["features"] -# or select a well: -# features = features[features["fov_name"].str.contains("B/4")] - - -scaled_features = StandardScaler().fit_transform(features.values) -umap = UMAP() -# Fit UMAP on all features -embedding = umap.fit_transform(scaled_features) - - -# %% -# Add UMAP coordinates to the dataset and plot w/ time - - -features = ( - features.assign_coords(UMAP1=("sample", embedding[:, 0])) - .assign_coords(UMAP2=("sample", embedding[:, 1])) - .set_index(sample=["UMAP1", "UMAP2"], append=True) -) -features - - -sns.scatterplot( - x=features["UMAP1"], y=features["UMAP2"], hue=features["t"], s=7, alpha=0.8 -) - - -# Add the title to the plot -plt.title("Cell & Time Aware Sampling (30 min interval)") -plt.xlim(-10, 20) -plt.ylim(-10, 20) -# plt.savefig('umap_cell_time_aware_time.svg', format='svg') -plt.savefig("updated_cell_time_aware_time.png", format="png") -# Show the plot -plt.show() - - -# %% - - -any_features_path = Path( - "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/negpair_difcell_randomtime_sampling/Ver2_updateTracking_refineModel/predictions/Feb_2chan_128patch_32projDim/2chan_128patch_56ckpt_FebTest.zarr" -) -embedding_dataset = read_embedding_dataset(any_features_path) -embedding_dataset - - -# %% -# Compute UMAP over all features -features = embedding_dataset["features"] -# or select a well: -# features = features[features["fov_name"].str.contains("B/4")] - - -scaled_features = StandardScaler().fit_transform(features.values) -umap = UMAP() -# Fit UMAP on all features -embedding = umap.fit_transform(scaled_features) - - -# %% Any time sampling plot - - -features = ( - features.assign_coords(UMAP1=("sample", embedding[:, 0])) - .assign_coords(UMAP2=("sample", embedding[:, 1])) - 
.set_index(sample=["UMAP1", "UMAP2"], append=True) -) -features - - -sns.scatterplot( - x=features["UMAP1"], y=features["UMAP2"], hue=features["t"], s=7, alpha=0.8 -) - - -# Add the title to the plot -plt.title("Cell Aware Sampling") - -plt.xlim(-10, 20) -plt.ylim(-10, 20) - -plt.savefig("1_updated_cell_aware_time.png", format="png") -# plt.savefig('umap_cell_aware_time.pdf', format='pdf') -# Show the plot -plt.show() - - -# %% - - -contrastive_learning_path = Path( - "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/negpair_random_sampling2/feb_fixed_test_predict.zarr" -) -embedding_dataset = read_embedding_dataset(contrastive_learning_path) -embedding_dataset - - -# %% -# Compute UMAP over all features -features = embedding_dataset["features"] -# or select a well: -# features = features[features["fov_name"].str.contains("B/4")] - - -scaled_features = StandardScaler().fit_transform(features.values) -umap = UMAP() -# Fit UMAP on all features -embedding = umap.fit_transform(scaled_features) - - -# %% Any time sampling plot - - -features = ( - features.assign_coords(UMAP1=("sample", embedding[:, 0])) - .assign_coords(UMAP2=("sample", embedding[:, 1])) - .set_index(sample=["UMAP1", "UMAP2"], append=True) -) -features - -sns.scatterplot( - x=features["UMAP1"], y=features["UMAP2"], hue=features["t"], s=7, alpha=0.8 -) - -# Add the title to the plot -plt.title("Classical Contrastive Learning Sampling") -plt.xlim(-10, 20) -plt.ylim(-10, 20) -plt.savefig("updated_classical_time.png", format="png") -# plt.savefig('classical_time.pdf', format='pdf') - -# Show the plot -plt.show() - - -# %% PCA - - -pca = PCA(n_components=4) -# scaled_features = StandardScaler().fit_transform(features.values) -# pca_features = pca.fit_transform(scaled_features) -pca_features = pca.fit_transform(features.values) - - -features = ( - features.assign_coords(PCA1=("sample", pca_features[:, 0])) - .assign_coords(PCA2=("sample", pca_features[:, 1])) - .assign_coords(PCA3=("sample", pca_features[:, 2])) - .assign_coords(PCA4=("sample", pca_features[:, 3])) - .set_index(sample=["PCA1", "PCA2", "PCA3", "PCA4"], append=True) -) - - -# %% plot PCA components w/ time - - -plt.figure(figsize=(10, 10)) -sns.scatterplot( - x=features["PCA1"], y=features["PCA2"], hue=features["t"], s=7, alpha=0.8 -) - - -# %% OVERLAY INFECTION ANNOTATION -ann_root = Path( - "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/supervised_inf_pred" -) - - -infection = load_annotation( - features, - ann_root / "extracted_inf_state.csv", - "infection_state", - {0.0: "background", 1.0: "uninfected", 2.0: "infected"}, -) - - -# %% -sns.scatterplot(x=features["UMAP1"], y=features["UMAP2"], hue=infection, s=7, alpha=0.8) - - -# %% plot PCA components with infection hue -sns.scatterplot(x=features["PCA1"], y=features["PCA2"], hue=infection, s=7, alpha=0.8) - - -# %% From a5305232f6d2acfe42b261dc997dbc01db1cd689 Mon Sep 17 00:00:00 2001 From: Soorya Pradeep Date: Thu, 9 Jan 2025 22:03:05 -0800 Subject: [PATCH 13/38] add accuracy metric plot --- .../figures/ALFI_accuracy_metrics.py | 137 ++++++++++++++++++ 1 file changed, 137 insertions(+) create mode 100644 applications/contrastive_phenotyping/figures/ALFI_accuracy_metrics.py diff --git a/applications/contrastive_phenotyping/figures/ALFI_accuracy_metrics.py b/applications/contrastive_phenotyping/figures/ALFI_accuracy_metrics.py new file mode 100644 index 000000000..4aaff20e6 --- /dev/null +++ 
b/applications/contrastive_phenotyping/figures/ALFI_accuracy_metrics.py @@ -0,0 +1,137 @@ + +# %% compute accuracy of model from ALFI data using cell division state classification + +from pathlib import Path +import matplotlib.pyplot as plt +import pandas as pd +import numpy as np +from sklearn.linear_model import LogisticRegression + +from viscy.representation.embedding_writer import read_embedding_dataset + +# %% +accuracies = [] + +features_paths = { + '7 min interval': '/hpc/projects/organelle_phenotyping/ALFI_ntxent_loss/logs_alfi_ntxent_time_intervals/predictions/ALFI_7mins.zarr', + '14 min interval': '/hpc/projects/organelle_phenotyping/ALFI_ntxent_loss/logs_alfi_ntxent_time_intervals/predictions/ALFI_14mins.zarr', + '28 min interval': '/hpc/projects/organelle_phenotyping/ALFI_ntxent_loss/logs_alfi_ntxent_time_intervals/predictions/ALFI_28mins.zarr', + '56 min interval': '/hpc/projects/organelle_phenotyping/ALFI_ntxent_loss/logs_alfi_ntxent_time_intervals/predictions/ALFI_56mins.zarr', + '91 min interval': '/hpc/projects/organelle_phenotyping/ALFI_ntxent_loss/logs_alfi_ntxent_time_intervals/predictions/ALFI_91mins.zarr', +} + +for interval_name, path in features_paths.items(): + features_path = Path(path) + embedding_dataset = read_embedding_dataset(features_path) + embedding_dataset + features = embedding_dataset["features"] + + # load the cell cycle state annotation + + def load_annotation(da, path, name, categories: dict | None = None): + annotation = pd.read_csv(path) + # annotation_columns = annotation.columns.tolist() + # print(annotation_columns) + annotation["fov_name"] = "/" + annotation["fov ID"] + annotation = annotation.set_index(["fov_name", "id"]) + mi = pd.MultiIndex.from_arrays( + [da["fov_name"].values, da["id"].values], names=["fov_name", "id"] + ) + selected = annotation.reindex(mi)[name] + if categories: + selected = selected.astype("category").cat.rename_categories(categories) + return selected + + ann_root = Path( + "/hpc/projects/organelle_phenotyping/ALFI_models_data/datasets/zarr_datasets" + ) + + division = load_annotation( + embedding_dataset, + ann_root / "test_annotations.csv", + "division", + {0: "interphase", 1: "mitosis"}, + ) + + # train a linear classifier on half the data + + division_npy = division.cat.codes.values + division_npy_filtered = division_npy[division_npy != -1] + + feature_npy = features.values + feature_npy_filtered = feature_npy[division_npy != -1] + + # add time and well info into dataframe + time_npy = features["t"].values + time_npy_filtered = time_npy[division_npy != -1] + + + fov_name_list = features["fov_name"].values + fov_name_list_filtered = fov_name_list[division_npy != -1] + + data = pd.DataFrame( + { + "division": division_npy_filtered, + "time": time_npy_filtered, + "fov_name": fov_name_list_filtered, + } + ) + # Add all 768 features to the dataframe + feature_columns = pd.DataFrame(feature_npy_filtered, columns=[f"feature_{i+1}" for i in range(768)]) + data = pd.concat([data, feature_columns], axis=1) + + # dataframe for training set, fov names starts with "/B/4/6" or "/B/4/7" or "/A/3/" + data_train_val = data[ + data["fov_name"].str.contains("/0/0/0") + | data["fov_name"].str.contains("/0/1/0") + | data["fov_name"].str.contains("/0/2/0") + ] + + data_test = data[ + data["fov_name"].str.contains("/0/3/0") + | data["fov_name"].str.contains("/0/4/0") + ] + + x_train = data_train_val.drop( + columns=[ + "division", + "fov_name", + "time", + ] + ) + y_train = data_train_val["division"] + + # train a logistic regression 
model
+    clf = LogisticRegression(random_state=0).fit(x_train, y_train)
+
+    # test the trained classifier on the other half of the data
+
+    x_test = data_test.drop(
+        columns=[
+            "division",
+            "fov_name",
+            "time",
+        ]
+    )
+    y_test = data_test["division"]
+
+    # predict the division state for the testing set
+    y_pred = clf.predict(x_test)
+
+    # compute the accuracy of the classifier
+
+    accuracy = np.mean(y_pred == y_test)
+    # save the accuracy for final plotting
+    print(f"Accuracy of model trained on {interval_name} data: {accuracy}")
+    accuracies.append(accuracy)
+
+# %% plot the accuracy of the model trained on different time intervals
+
+plt.figure(figsize=(8, 6))
+plt.bar(features_paths.keys(), accuracies)
+plt.ylabel("Accuracy")
+plt.xlabel("Time interval")
+plt.title("Accuracy")
+plt.ylim(0.9, 1)
+plt.show()
+# %%

From 83a45834848392387a8bccfcfb15511d7f9607c8 Mon Sep 17 00:00:00 2001
From: Soorya Pradeep
Date: Thu, 9 Jan 2025 22:03:51 -0800
Subject: [PATCH 14/38] plot ALFI phatemap with annotation

---
 .../figures/plot_phatemap_ALFI.py | 93 +++++++++++++++
 1 file changed, 93 insertions(+)
 create mode 100644 applications/contrastive_phenotyping/figures/plot_phatemap_ALFI.py

diff --git a/applications/contrastive_phenotyping/figures/plot_phatemap_ALFI.py b/applications/contrastive_phenotyping/figures/plot_phatemap_ALFI.py
new file mode 100644
index 000000000..9d2d78f58
--- /dev/null
+++ b/applications/contrastive_phenotyping/figures/plot_phatemap_ALFI.py
@@ -0,0 +1,93 @@
+
+# %%
+
+from pathlib import Path
+import matplotlib.pyplot as plt
+import seaborn as sns
+import pandas as pd
+
+from viscy.representation.embedding_writer import read_embedding_dataset
+
+# %%
+
+features_path = Path(
+    "/hpc/projects/organelle_phenotyping/ALFI_ntxent_loss/log_alfi_triplet_time_intervals/prediction/ALFI_91mins.zarr"
+)
+# data_path = Path(
+#     "/hpc/projects/intracellular_dashboard/organelle_dynamics/2024_11_07_A549_SEC61_ZIKV_DENV/2-assemble/2024_11_07_A549_SEC61_ZIKV_DENV.zarr"
+# )
+# tracks_path = Path(
+#     "/hpc/projects/intracellular_dashboard/organelle_dynamics/2024_11_07_A549_SEC61_ZIKV_DENV/1-preprocess/label-free/3-track/2024_11_07_A549_SEC61_ZIKV_DENV_2_cropped.zarr"
+# )
+
+# %%
+
+embedding_dataset = read_embedding_dataset(features_path)
+embedding_dataset
+
+PHATE1 = embedding_dataset["PHATE1"].values
+PHATE2 = embedding_dataset["PHATE2"].values
+
+# %% plot PHATE map based on the embedding dataset time points
+
+sns.scatterplot(
+    x=embedding_dataset["PHATE1"], y=embedding_dataset["PHATE2"], hue=embedding_dataset["t"], s=7, alpha=0.8
+)
+
+# %% color using human annotation for cell cycle state
+
+def load_annotation(da, path, name, categories: dict | None = None):
+    annotation = pd.read_csv(path)
+    # annotation_columns = annotation.columns.tolist()
+    # print(annotation_columns)
+    annotation["fov_name"] = "/" + annotation["fov ID"]
+    annotation = annotation.set_index(["fov_name", "id"])
+    mi = pd.MultiIndex.from_arrays(
+        [da["fov_name"].values, da["id"].values], names=["fov_name", "id"]
+    )
+    selected = annotation.reindex(mi)[name]
+    if categories:
+        selected = selected.astype("category").cat.rename_categories(categories)
+    return selected
+
+
+# %% load the cell cycle state annotation
+
+ann_root = Path(
+    "/hpc/projects/organelle_phenotyping/ALFI_models_data/datasets/zarr_datasets"
+)
+
+division = load_annotation(
+    embedding_dataset,
+    ann_root / "test_annotations.csv",
+    "division",
+    {0: "interphase", 1: "mitosis"},
+)
+
+# %% plot PHATE map based on the cell cycle annotation
+
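+# Note: cells annotated as -1 (no division label) are not renamed by the
+# mapping above, so they appear as a separate -1 hue in this plot; these
+# unannotated points are inspected at the end of the script.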
+sns.scatterplot( + x=embedding_dataset["PHATE1"], y=embedding_dataset["PHATE2"], hue=division, s=7, alpha=0.8 +) + +# %% plot intercative plot to hover over the points on scatter plot and see the fov_name and track_id + +import plotly.express as px + +fig = px.scatter( + embedding_dataset.to_dataframe(), + x="PHATE1", + y="PHATE2", + color=division, + hover_data=["fov_name", "id"], +) + +# %% +# find row index in 'division' where the value is -1 +division[division == -1].index +# find the track_id and 't' value of cell instance where 'fov_name' is '/0/0/0' and 'id' is 1000941 +embedding_dataset.where(embedding_dataset["fov_name"] == "/0/0/0", drop=True).where( + embedding_dataset["id"] == 1000942, drop=True +) + +# %% From a090eb97de5d36b284cd8624830f57eafced7240 Mon Sep 17 00:00:00 2001 From: Soorya Pradeep Date: Thu, 9 Jan 2025 22:07:59 -0800 Subject: [PATCH 15/38] style plot --- .../figures/ALFI_accuracy_metrics.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/applications/contrastive_phenotyping/figures/ALFI_accuracy_metrics.py b/applications/contrastive_phenotyping/figures/ALFI_accuracy_metrics.py index 4aaff20e6..030673a64 100644 --- a/applications/contrastive_phenotyping/figures/ALFI_accuracy_metrics.py +++ b/applications/contrastive_phenotyping/figures/ALFI_accuracy_metrics.py @@ -18,6 +18,7 @@ '28 min interval': '/hpc/projects/organelle_phenotyping/ALFI_ntxent_loss/logs_alfi_ntxent_time_intervals/predictions/ALFI_28mins.zarr', '56 min interval': '/hpc/projects/organelle_phenotyping/ALFI_ntxent_loss/logs_alfi_ntxent_time_intervals/predictions/ALFI_56mins.zarr', '91 min interval': '/hpc/projects/organelle_phenotyping/ALFI_ntxent_loss/logs_alfi_ntxent_time_intervals/predictions/ALFI_91mins.zarr', + 'classical': '/hpc/projects/organelle_phenotyping/ALFI_ntxent_loss/logs_alfi_ntxent_time_intervals/predictions/ALFI_classical.zarr', } for interval_name, path in features_paths.items(): @@ -129,9 +130,9 @@ def load_annotation(da, path, name, categories: dict | None = None): plt.figure(figsize=(8, 6)) plt.bar(features_paths.keys(), accuracies) -plt.ylabel("Accuracy") -plt.xlabel("Time interval") -plt.title("Accuracy") +plt.xticks(rotation=45, ha='right', fontsize=12) +plt.ylabel("Accuracy", fontsize=14) +plt.xlabel("Time interval", fontsize=14) plt.ylim(0.9, 1) plt.show() # %% From e397de16c660ba4f5095a9af004a3875581d1ec3 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Fri, 10 Jan 2025 15:03:10 -0800 Subject: [PATCH 16/38] adding back cosine --- viscy/representation/evaluation/distance.py | 24 +++++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/viscy/representation/evaluation/distance.py b/viscy/representation/evaluation/distance.py index b7b563854..b8552fa24 100644 --- a/viscy/representation/evaluation/distance.py +++ b/viscy/representation/evaluation/distance.py @@ -22,6 +22,20 @@ def calculate_cosine_similarity_cell(embedding_dataset, fov_name, track_id): return time_points, cosine_similarities.tolist() +def calculate_euclidian_distance_cell(embedding_dataset, fov_name, track_id): + """Extract embeddings and calculate euclidean distances for a specific cell""" + filtered_data = embedding_dataset.where( + (embedding_dataset["fov_name"] == fov_name) + & (embedding_dataset["track_id"] == track_id), + drop=True, + ) + features = filtered_data["features"].values # (sample, features) + time_points = filtered_data["t"].values # (sample,) + first_time_point_embedding = features[0].reshape(1, -1) + euclidean_distances = 
np.linalg.norm(first_time_point_embedding - features, axis=1) + return time_points, euclidean_distances.tolist() + + def compute_displacement( embedding_dataset, distance_metric: Literal["euclidean_squared", "cosine"] = "euclidean_squared", @@ -85,8 +99,14 @@ def compute_displacement( )[0] if len(matching_indices) == 1: - future_embedding = embeddings[matching_indices[0]].reshape(1, -1) - displacement = np.sum((current_embedding - future_embedding) ** 2) + if distance_metric == "euclidean_squared": + future_embedding = embeddings[matching_indices[0]].reshape(1, -1) + displacement = np.sum((current_embedding - future_embedding) ** 2) + elif distance_metric == "cosine": + future_embedding = embeddings[matching_indices[0]].reshape(1, -1) + displacement = cosine_similarity( + current_embedding, future_embedding + ) displacement_per_delta_t[delta_t].append(displacement) return dict(displacement_per_delta_t) From fb93f606d377c40503ff1b7d7ee1cb47641e865c Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Fri, 10 Jan 2025 16:57:00 -0800 Subject: [PATCH 17/38] adding the normalization of the euclidian distance and the code for plotting --- .../cosine_dissimilarity_dataset.py | 60 ++-- .../evaluation/euclidean_distance_dataset.py | 286 ++++++++++++++++++ viscy/representation/evaluation/clustering.py | 7 +- 3 files changed, 324 insertions(+), 29 deletions(-) create mode 100644 applications/contrastive_phenotyping/evaluation/euclidean_distance_dataset.py diff --git a/applications/contrastive_phenotyping/evaluation/cosine_dissimilarity_dataset.py b/applications/contrastive_phenotyping/evaluation/cosine_dissimilarity_dataset.py index fb8ce2f84..59f375afc 100644 --- a/applications/contrastive_phenotyping/evaluation/cosine_dissimilarity_dataset.py +++ b/applications/contrastive_phenotyping/evaluation/cosine_dissimilarity_dataset.py @@ -166,18 +166,6 @@ def analyze_embedding_smoothness( } if verbose: - # Plot cross distance matrix - plt.figure() - plt.imshow(cross_dist) - plt.show() - - # Plot histograms - plot_histogram( - piece_wise_dissimilarity_per_track, - "Adjacent Frame Dissimilarity per Track", - "Cosine Dissimilarity", - "Frequency", - ) # Plot the comparison histogram and save if output_path is provided fig = plt.figure() @@ -228,36 +216,52 @@ def analyze_embedding_smoothness( if __name__ == "__main__": # plotting VERBOSE = True - PATH_TO_GDRIVE_FIGUE = "./" - prediction_path_1 = Path( - "/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/trainng_logs/SEC61/rev6_NTXent_sensorPhase_infection/2chan_160patch_98ckpt_rev6_2.zarr" - ) - prediction_path_2 = Path( - "/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/trainng_logs/SEC61/rev5_sensorPhase_infection/2chan_160patch_97ckpt_rev5_2.zarr" - ) + # Define models as a dictionary with meaningful keys + prediction_paths = { + "ntxent_sensor_phase": Path( + "/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/trainng_logs/SEC61/rev6_NTXent_sensorPhase_infection/2chan_160patch_98ckpt_rev6_2.zarr" + ), + "triplet_sensor_phase": Path( + "/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/trainng_logs/SEC61/rev5_sensorPhase_infection/2chan_160patch_97ckpt_rev5_2.zarr" + ), + } - # Create a list of models to evaluate - models = [ - (prediction_path_1, "ntxent"), - (prediction_path_2, "triplet"), - ] + # output_folder to save the distributions as .csv + output_folder = Path( + 
"/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/trainng_logs/SEC61/cosine_dissimilarity_distributions" + ) + output_folder.mkdir(parents=True, exist_ok=True) # Evaluate each model - for prediction_path, loss_name in tqdm(models, desc="Evaluating models"): - print(f"\nAnalyzing model: {prediction_path.stem} (Loss: {loss_name})") + for model_name, prediction_path in tqdm( + prediction_paths.items(), desc="Evaluating models" + ): + print(f"\nAnalyzing model: {prediction_path.stem} (Loss: {model_name})") print("-" * 80) metrics = analyze_embedding_smoothness( prediction_path, verbose=VERBOSE, output_path=PATH_TO_GDRIVE_FIGUE, - loss_name=loss_name, + loss_name=model_name, overwrite=True, ) - # Print adjacent frame dissimilarity statistics + # Save distributions to CSV + distributions_df = pd.DataFrame( + { + "adjacent_frame": pd.Series(metrics["dissimilarity_distribution"]), + "random_sampling": pd.Series(metrics["random_distribution"]), + } + ) + csv_path = ( + output_folder / f"{prediction_path.stem}_{model_name}_distributions.csv" + ) + distributions_df.to_csv(csv_path, index=False) + + # Print statistics (rest of the printing code remains the same) print("\nAdjacent Frame Dissimilarity Statistics:") print(f"{'Mean:':<15} {metrics['dissimilarity_mean']:.3f}") print(f"{'Std:':<15} {metrics['dissimilarity_std']:.3f}") diff --git a/applications/contrastive_phenotyping/evaluation/euclidean_distance_dataset.py b/applications/contrastive_phenotyping/evaluation/euclidean_distance_dataset.py new file mode 100644 index 000000000..ec2965533 --- /dev/null +++ b/applications/contrastive_phenotyping/evaluation/euclidean_distance_dataset.py @@ -0,0 +1,286 @@ +# %% +from pathlib import Path +from typing import Optional + +import matplotlib.pyplot as plt +import seaborn as sns +from numpy.typing import NDArray + +from viscy.representation.embedding_writer import read_embedding_dataset +from viscy.representation.evaluation.clustering import ( + compare_time_offset, + pairwise_distance_matrix, + rank_nearest_neighbors, + select_block, +) +import numpy as np +from tqdm import tqdm +import pandas as pd + +from scipy.stats import gaussian_kde +from scipy.optimize import minimize_scalar + + +plt.style.use("../evaluation/figure.mplstyle") + + +def compute_piece_wise_distance( + features_df: pd.DataFrame, cross_dist: NDArray, rank_fractions: NDArray +): + """ + Computing the smoothness and dynamic range + - Get the off diagonal per block and compute the mode + - The blocks are not square, so we need to get the off diagonal elements + - Get the 1 and 99 percentile of the off diagonal per block + """ + piece_wise_distance_per_track = [] + piece_wise_rank_difference_per_track = [] + for name, subdata in features_df.groupby(["fov_name", "track_id"]): + if len(subdata) > 1: + indices = subdata.index.values + single_track_distance = select_block(cross_dist, indices) + single_track_rank_fraction = select_block(rank_fractions, indices) + piece_wise_distance = compare_time_offset( + single_track_distance, time_offset=1 + ) + piece_wise_rank_difference = compare_time_offset( + single_track_rank_fraction, time_offset=1 + ) + piece_wise_distance_per_track.append(piece_wise_distance) + piece_wise_rank_difference_per_track.append(piece_wise_rank_difference) + return piece_wise_distance_per_track, piece_wise_rank_difference_per_track + + +def plot_histogram( + data, title, xlabel, ylabel, color="blue", alpha=0.5, stat="frequency" +): + plt.figure() + plt.title(title) + 
sns.histplot(data, bins=30, kde=True, color=color, alpha=alpha, stat=stat) + plt.xlabel(xlabel) + plt.ylabel(ylabel) + plt.tight_layout() + plt.show() + + +def find_distribution_peak(data: np.ndarray) -> float: + """ + Find the peak (mode) of a distribution using kernel density estimation. + + Args: + data: Array of values to find the peak for + + Returns: + float: The x-value where the peak occurs + """ + kde = gaussian_kde(data) + # Find the peak (maximum) of the KDE + result = minimize_scalar( + lambda x: -kde(x), bounds=(np.min(data), np.max(data)), method="bounded" + ) + return result.x + + +def analyze_embedding_smoothness( + prediction_path: Path, + verbose: bool = False, + output_path: Optional[str] = None, + loss_name: Optional[str] = None, + overwrite: bool = False, +) -> dict: + """ + Analyze the smoothness and dynamic range of embeddings using Euclidean distance. + + Args: + prediction_path: Path to the embedding dataset + verbose: If True, generates additional plots + output_path: Path to save the final plot (optional) + loss_name: Name of the loss function used (optional) + overwrite: If True, overwrites existing files. If False, raises error if file exists (default: False) + + Returns: + dict: Dictionary containing metrics including: + - distance_mean: Mean of adjacent frame distance + - distance_std: Standard deviation of adjacent frame distance + - distance_median: Median of adjacent frame distance + - distance_peak: Peak of adjacent frame distribution + - distance_p99: 99th percentile of adjacent frame distance + - distance_p1: 1st percentile of adjacent frame distance + - distance_distribution: Full distribution of adjacent frame distances + - random_mean: Mean of random sampling distance + - random_std: Standard deviation of random sampling distance + - random_median: Median of random sampling distance + - random_peak: Peak of random sampling distribution + - random_distribution: Full distribution of random sampling distances + - dynamic_range: Difference between random and adjacent peaks + """ + # Read the dataset + embeddings = read_embedding_dataset(prediction_path) + features = embeddings["features"] + + # Compute the Euclidean distance + cross_dist = pairwise_distance_matrix(features, metric="euclidean") + rank_fractions = rank_nearest_neighbors(cross_dist, normalize=True) + + # Compute piece-wise distance and rank difference + features_df = features["sample"].to_dataframe().reset_index(drop=True) + piece_wise_distance_per_track, piece_wise_rank_difference_per_track = ( + compute_piece_wise_distance(features_df, cross_dist, rank_fractions) + ) + + all_distance = np.concatenate(piece_wise_distance_per_track) + + p99_piece_wise_distance = np.array( + [np.percentile(track, 99) for track in piece_wise_distance_per_track] + ) + p1_percentile_piece_wise_distance = np.array( + [np.percentile(track, 1) for track in piece_wise_distance_per_track] + ) + + # Random sampling values in the distance matrix with same size as adjacent frame measurements + n_samples = len(all_distance) + random_indices = np.random.randint(0, len(cross_dist), size=(n_samples, 2)) + sampled_values = cross_dist[random_indices[:, 0], random_indices[:, 1]] + + # Compute the peaks of both distributions using KDE + adjacent_peak = float(find_distribution_peak(all_distance)) + random_peak = float(find_distribution_peak(sampled_values)) + dynamic_range = float(random_peak - adjacent_peak) + + metrics = { + "distance_mean": float(np.mean(all_distance)), + "distance_std": float(np.std(all_distance)), + 
"distance_median": float(np.median(all_distance)), + "distance_peak": adjacent_peak, + "distance_p99": p99_piece_wise_distance, + "distance_p1": p1_percentile_piece_wise_distance, + "distance_distribution": all_distance, + "random_mean": float(np.mean(sampled_values)), + "random_std": float(np.std(sampled_values)), + "random_median": float(np.median(sampled_values)), + "random_peak": random_peak, + "random_distribution": sampled_values, + "dynamic_range": dynamic_range, + } + + if verbose: + # Plot the comparison histogram and save if output_path is provided + fig = plt.figure() + sns.histplot( + metrics["distance_distribution"], + bins=30, + kde=True, + color="cyan", + alpha=0.5, + stat="density", + ) + sns.histplot( + metrics["random_distribution"], + bins=30, + kde=True, + color="red", + alpha=0.5, + stat="density", + ) + plt.xlabel("Euclidean Distance") + plt.ylabel("Density") + # Add vertical lines for the peaks + plt.axvline(x=metrics["distance_peak"], color="cyan", linestyle="--", alpha=0.8) + plt.axvline(x=metrics["random_peak"], color="red", linestyle="--", alpha=0.8) + plt.tight_layout() + plt.legend(["Adjacent Frame", "Random Sample", "Adjacent Peak", "Random Peak"]) + + if output_path and loss_name: + output_file = Path( + f"{output_path}/euclidean_distance_smoothness_{prediction_path.stem}_{loss_name}.pdf" + ) + if output_file.exists() and not overwrite: + raise FileExistsError( + f"File {output_file} already exists and overwrite=False" + ) + fig.savefig( + output_file, + dpi=600, + ) + plt.show() + + return metrics + + +# Example usage: +if __name__ == "__main__": + # plotting + VERBOSE = True + PATH_TO_GDRIVE_FIGUE = "./" + + # Define models as a dictionary with meaningful keys + prediction_paths = { + "ntxent_sensor_phase": Path( + "/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/trainng_logs/SEC61/rev6_NTXent_sensorPhase_infection/2chan_160patch_98ckpt_rev6_2.zarr" + ), + "triplet_sensor_phase": Path( + "/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/trainng_logs/SEC61/rev5_sensorPhase_infection/2chan_160patch_97ckpt_rev5_2.zarr" + ), + } + + # output_folder to save the distributions as .csv + output_folder = Path( + "/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/trainng_logs/SEC61/euclidean_distance_distributions" + ) + output_folder.mkdir(parents=True, exist_ok=True) + + # Evaluate each model + for model_name, prediction_path in tqdm( + prediction_paths.items(), desc="Evaluating models" + ): + print(f"\nAnalyzing model: {prediction_path.stem} (Loss: {model_name})") + print("-" * 80) + + metrics = analyze_embedding_smoothness( + prediction_path, + verbose=VERBOSE, + output_path=PATH_TO_GDRIVE_FIGUE, + loss_name=model_name, + overwrite=True, + ) + + # Save distributions to CSV + distributions_df = pd.DataFrame( + { + "adjacent_frame": pd.Series(metrics["distance_distribution"]), + "random_sampling": pd.Series(metrics["random_distribution"]), + } + ) + csv_path = ( + output_folder / f"{prediction_path.stem}_{model_name}_distributions.csv" + ) + distributions_df.to_csv(csv_path, index=False) + + # Print statistics (existing code) + print("\nAdjacent Frame Distance Statistics:") + print(f"{'Mean:':<15} {metrics['distance_mean']:.3f}") + print(f"{'Std:':<15} {metrics['distance_std']:.3f}") + print(f"{'Median:':<15} {metrics['distance_median']:.3f}") + print(f"{'Peak:':<15} {metrics['distance_peak']:.3f}") + print(f"{'P1:':<15} 
{np.mean(metrics['distance_p1']):.3f}") + print(f"{'P99:':<15} {np.mean(metrics['distance_p99']):.3f}") + + # Print random sampling statistics + print("\nRandom Sampling Statistics:") + print(f"{'Mean:':<15} {metrics['random_mean']:.3f}") + print(f"{'Std:':<15} {metrics['random_std']:.3f}") + print(f"{'Median:':<15} {metrics['random_median']:.3f}") + print(f"{'Peak:':<15} {metrics['random_peak']:.3f}") + + # Print dynamic range + print("\nComparison Metrics:") + print(f"{'Dynamic Range:':<15} {metrics['dynamic_range']:.3f}") + + # Print distribution sizes + print("\nDistribution Sizes:") + print( + f"{'Adjacent Frame:':<15} {len(metrics['distance_distribution']):,d} samples" + ) + print(f"{'Random:':<15} {len(metrics['random_distribution']):,d} samples") + +# %% diff --git a/viscy/representation/evaluation/clustering.py b/viscy/representation/evaluation/clustering.py index ebf49455f..f94643aac 100644 --- a/viscy/representation/evaluation/clustering.py +++ b/viscy/representation/evaluation/clustering.py @@ -48,7 +48,12 @@ def pairwise_distance_matrix(features: ArrayLike, metric: str = "cosine") -> NDA NDArray Distance matrix of shape (n_samples, n_samples) """ - return cdist(features, features, metric=metric) + distances = cdist(features, features, metric=metric) + if metric == "euclidean": + # Normalize by sqrt of embedding dimension + print(f"features.shape: {features.shape}") + distances /= np.sqrt(features.shape[1]) + return distances def rank_nearest_neighbors( From 57851bcdb2bbc7ea2839b5ade02869aad338d58f Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Fri, 10 Jan 2025 18:07:36 -0800 Subject: [PATCH 18/38] cleanup and refactoring to move usefult functions into viscy.representation.evaluation.distance and clustering files --- .../evaluation/ALFI_displacement.py | 312 ------------------ .../cosine_dissimilarity_dataset.py | 254 ++------------ ...imilarity.py => cosine_similarity_demo.py} | 0 .../evaluation/euclidean_distance_dataset.py | 255 ++------------ viscy/representation/evaluation/clustering.py | 7 +- viscy/representation/evaluation/distance.py | 305 +++++++++++++++-- 6 files changed, 320 insertions(+), 813 deletions(-) delete mode 100644 applications/contrastive_phenotyping/evaluation/ALFI_displacement.py rename applications/contrastive_phenotyping/evaluation/{cosine_similarity.py => cosine_similarity_demo.py} (100%) diff --git a/applications/contrastive_phenotyping/evaluation/ALFI_displacement.py b/applications/contrastive_phenotyping/evaluation/ALFI_displacement.py deleted file mode 100644 index 595f283f7..000000000 --- a/applications/contrastive_phenotyping/evaluation/ALFI_displacement.py +++ /dev/null @@ -1,312 +0,0 @@ -# %% -from pathlib import Path -import matplotlib.pyplot as plt -import numpy as np -from viscy.representation.embedding_writer import read_embedding_dataset -from viscy.representation.evaluation.distance import ( - calculate_normalized_euclidean_distance_cell, - compute_displacement, - compute_dynamic_range, - compute_rms_per_track, -) -from collections import defaultdict -from tabulate import tabulate - -import numpy as np -from sklearn.metrics.pairwise import cosine_similarity -from collections import OrderedDict - -# %% function - -# Removed redundant compute_displacement_mean_std_full function -# Removed redundant compute_dynamic_range and compute_rms_per_track functions - - -def plot_rms_histogram(rms_values, label, bins=30): - """ - Plot histogram of RMS values across tracks. 
- - Parameters: - rms_values : list - List of RMS values, one for each track. - label : str - Label for the dataset (used in the title). - bins : int, optional - Number of bins for the histogram. Default is 30. - - Returns: - None: Displays the histogram. - """ - plt.figure(figsize=(10, 6)) - plt.hist(rms_values, bins=bins, alpha=0.7, color="blue", edgecolor="black") - plt.title(f"Histogram of RMS Values Across Tracks ({label})", fontsize=16) - plt.xlabel("RMS of Time Derivative", fontsize=14) - plt.ylabel("Frequency", fontsize=14) - plt.grid(True) - plt.show() - - -def plot_displacement( - mean_displacement, std_displacement, label, metrics_no_track=None -): - """ - Plot embedding displacement over time with mean and standard deviation. - - Parameters: - mean_displacement : dict - Mean displacement for each tau. - std_displacement : dict - Standard deviation of displacement for each tau. - label : str - Label for the dataset. - metrics_no_track : dict, optional - Metrics for the "Classical Contrastive (No Tracking)" dataset to compare against. - - Returns: - None: Displays the plot. - """ - plt.figure(figsize=(10, 6)) - taus = list(mean_displacement.keys()) - mean_values = list(mean_displacement.values()) - std_values = list(std_displacement.values()) - - plt.plot(taus, mean_values, marker="o", label=f"{label}", color="green") - plt.fill_between( - taus, - np.array(mean_values) - np.array(std_values), - np.array(mean_values) + np.array(std_values), - color="green", - alpha=0.3, - label=f"Std Dev ({label})", - ) - - if metrics_no_track: - mean_values_no_track = list(metrics_no_track["mean_displacement"].values()) - std_values_no_track = list(metrics_no_track["std_displacement"].values()) - - plt.plot( - taus, - mean_values_no_track, - marker="o", - label="Classical Contrastive (No Tracking)", - color="blue", - ) - plt.fill_between( - taus, - np.array(mean_values_no_track) - np.array(std_values_no_track), - np.array(mean_values_no_track) + np.array(std_values_no_track), - color="blue", - alpha=0.3, - label="Std Dev (No Tracking)", - ) - - plt.xlabel("Time Shift (τ)", fontsize=14) - plt.ylabel("Euclidean Distance", fontsize=14) - plt.title(f"Embedding Displacement Over Time ({label})", fontsize=16) - plt.grid(True) - plt.legend(fontsize=12) - plt.show() - - -def plot_overlay_displacement(overlay_displacement_data): - """ - Plot embedding displacement over time for all datasets in one plot. - - Parameters: - overlay_displacement_data : dict - A dictionary containing mean displacement per tau for all datasets. - - Returns: - None: Displays the plot. - """ - plt.figure(figsize=(12, 8)) - for label, mean_displacement in overlay_displacement_data.items(): - taus = list(mean_displacement.keys()) - mean_values = list(mean_displacement.values()) - plt.plot(taus, mean_values, marker="o", label=label) - - plt.xlabel("Time Shift (τ)", fontsize=14) - plt.ylabel("Euclidean Distance", fontsize=14) - plt.title("Overlayed Embedding Displacement Over Time", fontsize=16) - plt.grid(True) - plt.legend(fontsize=12) - plt.show() - - -# %% hist stats -def plot_boxplot_rms_across_models(datasets_rms): - """ - Plot a boxplot for the distribution of RMS values across models. - - Parameters: - datasets_rms : dict - A dictionary where keys are dataset names and values are lists of RMS values. - - Returns: - None: Displays the boxplot. 
- """ - plt.figure(figsize=(12, 6)) - labels = list(datasets_rms.keys()) - data = list(datasets_rms.values()) - print(labels) - print(data) - # Plot the boxplot - plt.boxplot(data, tick_labels=labels, patch_artist=True, showmeans=True) - - plt.title( - "Distribution of RMS of Rate of Change of Embedding Across Models", fontsize=16 - ) - plt.ylabel("RMS of Time Derivative", fontsize=14) - plt.xticks(rotation=45, fontsize=12) - plt.grid(axis="y", linestyle="--", alpha=0.7) - plt.tight_layout() - plt.show() - - -def plot_histogram_absolute_differences(datasets_abs_diff): - """ - Plot histograms of absolute differences across embeddings for all models. - - Parameters: - datasets_abs_diff : dict - A dictionary where keys are dataset names and values are lists of absolute differences. - - Returns: - None: Displays the histograms. - """ - plt.figure(figsize=(12, 6)) - for label, abs_diff in datasets_abs_diff.items(): - plt.hist(abs_diff, bins=50, alpha=0.5, label=label, density=True) - - plt.title("Histograms of Absolute Differences Across Models", fontsize=16) - plt.xlabel("Absolute Difference", fontsize=14) - plt.ylabel("Density", fontsize=14) - plt.legend(fontsize=12) - plt.grid(alpha=0.7) - plt.tight_layout() - plt.show() - - -# %% Paths to datasets -feature_paths = { - "7 min interval": "/hpc/projects/organelle_phenotyping/ALFI_benchmarking/predictions_final/ALFI_opp_7mins.zarr", - "21 min interval": "/hpc/projects/organelle_phenotyping/ALFI_benchmarking/predictions_final/ALFI_opp_21mins.zarr", - "28 min interval": "/hpc/projects/organelle_phenotyping/ALFI_benchmarking/predictions_final/ALFI_updated_28mins.zarr", - "56 min interval": "/hpc/projects/organelle_phenotyping/ALFI_benchmarking/predictions_final/ALFI_opp_56mins.zarr", - "Cell Aware": "/hpc/projects/organelle_phenotyping/ALFI_benchmarking/predictions_final/ALFI_opp_cellaware.zarr", -} - -no_track_path = "/hpc/projects/organelle_phenotyping/ALFI_benchmarking/predictions_final/ALFI_opp_classical.zarr" - -# %% Process Datasets -max_tau = 69 -metrics = {} - -overlay_displacement_data = {} -datasets_rms = {} -datasets_abs_diff = {} - -# Process "No Tracking" dataset -features_path_no_track = Path(no_track_path) -embedding_dataset_no_track = read_embedding_dataset(features_path_no_track) - -mean_displacement_no_track, std_displacement_no_track = compute_displacement( - embedding_dataset_no_track, max_tau=max_tau, return_mean_std=True -) -dynamic_range_no_track = compute_dynamic_range(mean_displacement_no_track) -metrics["No Tracking"] = { - "dynamic_range": dynamic_range_no_track, - "mean_displacement": mean_displacement_no_track, - "std_displacement": std_displacement_no_track, -} - -overlay_displacement_data["No Tracking"] = mean_displacement_no_track - -print("\nProcessing No Tracking dataset...") -print(f"Dynamic Range for No Tracking: {dynamic_range_no_track}") - -plot_displacement(mean_displacement_no_track, std_displacement_no_track, "No Tracking") - -rms_values_no_track = compute_rms_per_track(embedding_dataset_no_track) -datasets_rms["No Tracking"] = rms_values_no_track - -print(f"Plotting histogram of RMS values for No Tracking dataset...") -plot_rms_histogram(rms_values_no_track, "No Tracking", bins=30) - -# Compute absolute differences for "No Tracking" -abs_diff_no_track = np.concatenate( - [ - np.linalg.norm( - np.diff(embedding_dataset_no_track["features"].values[indices], axis=0), - axis=-1, - ) - for indices in np.split( - np.arange(len(embedding_dataset_no_track["track_id"])), - 
np.where(np.diff(embedding_dataset_no_track["track_id"]) != 0)[0] + 1, - ) - ] -) -datasets_abs_diff["No Tracking"] = abs_diff_no_track - -# Process other datasets -for label, path in feature_paths.items(): - print(f"\nProcessing {label} dataset...") - - features_path = Path(path) - embedding_dataset = read_embedding_dataset(features_path) - - mean_displacement, std_displacement = compute_displacement( - embedding_dataset, max_tau=max_tau, return_mean_std=True - ) - dynamic_range = compute_dynamic_range(mean_displacement) - metrics[label] = { - "dynamic_range": dynamic_range, - "mean_displacement": mean_displacement, - "std_displacement": std_displacement, - } - - overlay_displacement_data[label] = mean_displacement - - print(f"Dynamic Range for {label}: {dynamic_range}") - - plot_displacement( - mean_displacement, - std_displacement, - label, - metrics_no_track=metrics.get("No Tracking", None), - ) - - rms_values = compute_rms_per_track(embedding_dataset) - datasets_rms[label] = rms_values - - print(f"Plotting histogram of RMS values for {label}...") - plot_rms_histogram(rms_values, label, bins=30) - - abs_diff = np.concatenate( - [ - np.linalg.norm( - np.diff(embedding_dataset["features"].values[indices], axis=0), axis=-1 - ) - for indices in np.split( - np.arange(len(embedding_dataset["track_id"])), - np.where(np.diff(embedding_dataset["track_id"]) != 0)[0] + 1, - ) - ] - ) - datasets_abs_diff[label] = abs_diff - -print("\nPlotting overlayed displacement for all datasets...") -plot_overlay_displacement(overlay_displacement_data) - -print("\nSummary of Dynamic Ranges:") -for label, metric in metrics.items(): - print(f"{label}: Dynamic Range = {metric['dynamic_range']}") - -print("\nPlotting RMS boxplot across models...") -plot_boxplot_rms_across_models(datasets_rms) - -print("\nPlotting histograms of absolute differences across models...") -plot_histogram_absolute_differences(datasets_abs_diff) - - -# %% diff --git a/applications/contrastive_phenotyping/evaluation/cosine_dissimilarity_dataset.py b/applications/contrastive_phenotyping/evaluation/cosine_dissimilarity_dataset.py index 59f375afc..fb204078f 100644 --- a/applications/contrastive_phenotyping/evaluation/cosine_dissimilarity_dataset.py +++ b/applications/contrastive_phenotyping/evaluation/cosine_dissimilarity_dataset.py @@ -1,223 +1,17 @@ # %% from pathlib import Path -from typing import Optional import matplotlib.pyplot as plt -import seaborn as sns -from sklearn.preprocessing import StandardScaler -from numpy.typing import NDArray - -from viscy.representation.embedding_writer import read_embedding_dataset -from viscy.representation.evaluation.clustering import ( - compare_time_offset, - pairwise_distance_matrix, - rank_nearest_neighbors, - select_block, -) -import numpy as np from tqdm import tqdm -import pandas as pd - -from scipy.stats import gaussian_kde -from scipy.optimize import minimize_scalar +from viscy.representation.evaluation.distance import ( + compute_embedding_distances, + analyze_and_plot_distances, +) plt.style.use("../evaluation/figure.mplstyle") - -def compute_piece_wise_dissimilarity( - features_df: pd.DataFrame, cross_dist: NDArray, rank_fractions: NDArray -): - """ - Computing the smoothness and dynamic range - - Get the off diagonal per block and compute the mode - - The blocks are not square, so we need to get the off diagonal elements - - Get the 1 and 99 percentile of the off diagonal per block - """ - piece_wise_dissimilarity_per_track = [] - piece_wise_rank_difference_per_track = [] - for name, 
subdata in features_df.groupby(["fov_name", "track_id"]): - if len(subdata) > 1: - indices = subdata.index.values - single_track_dissimilarity = select_block(cross_dist, indices) - single_track_rank_fraction = select_block(rank_fractions, indices) - piece_wise_dissimilarity = compare_time_offset( - single_track_dissimilarity, time_offset=1 - ) - piece_wise_rank_difference = compare_time_offset( - single_track_rank_fraction, time_offset=1 - ) - piece_wise_dissimilarity_per_track.append(piece_wise_dissimilarity) - piece_wise_rank_difference_per_track.append(piece_wise_rank_difference) - return piece_wise_dissimilarity_per_track, piece_wise_rank_difference_per_track - - -def plot_histogram( - data, title, xlabel, ylabel, color="blue", alpha=0.5, stat="frequency" -): - plt.figure() - plt.title(title) - sns.histplot(data, bins=30, kde=True, color=color, alpha=alpha, stat=stat) - plt.xlabel(xlabel) - plt.ylabel(ylabel) - plt.tight_layout() - plt.show() - - -def find_distribution_peak(data: np.ndarray) -> float: - """ - Find the peak (mode) of a distribution using kernel density estimation. - - Args: - data: Array of values to find the peak for - - Returns: - float: The x-value where the peak occurs - """ - kde = gaussian_kde(data) - # Find the peak (maximum) of the KDE - result = minimize_scalar( - lambda x: -kde(x), bounds=(np.min(data), np.max(data)), method="bounded" - ) - return result.x - - -def analyze_embedding_smoothness( - prediction_path: Path, - verbose: bool = False, - output_path: Optional[str] = None, - loss_name: Optional[str] = None, - overwrite: bool = False, -) -> dict: - """ - Analyze the smoothness and dynamic range of embeddings. - - Args: - prediction_path: Path to the embedding dataset - verbose: If True, generates additional plots - output_path: Path to save the final plot (optional) - loss_name: Name of the loss function used (optional) - overwrite: If True, overwrites existing files. 
If False, raises error if file exists (default: False) - - Returns: - dict: Dictionary containing metrics including: - - dissimilarity_mean: Mean of adjacent frame dissimilarity - - dissimilarity_std: Standard deviation of adjacent frame dissimilarity - - dissimilarity_median: Median of adjacent frame dissimilarity - - dissimilarity_peak: Peak of adjacent frame distribution - - dissimilarity_p99: 99th percentile of adjacent frame dissimilarity - - dissimilarity_p1: 1st percentile of adjacent frame dissimilarity - - dissimilarity_distribution: Full distribution of adjacent frame dissimilarities - - random_mean: Mean of random sampling dissimilarity - - random_std: Standard deviation of random sampling dissimilarity - - random_median: Median of random sampling dissimilarity - - random_peak: Peak of random sampling distribution - - random_distribution: Full distribution of random sampling dissimilarities - - dynamic_range: Difference between random and adjacent peaks - """ - # Read the dataset - embeddings = read_embedding_dataset(prediction_path) - features = embeddings["features"] - - scaled_features = StandardScaler().fit_transform(features.values) - # Compute the cosine dissimilarity - cross_dist = pairwise_distance_matrix(scaled_features, metric="cosine") - rank_fractions = rank_nearest_neighbors(cross_dist, normalize=True) - - # Compute piece-wise dissimilarity and rank difference - features_df = features["sample"].to_dataframe().reset_index(drop=True) - piece_wise_dissimilarity_per_track, piece_wise_rank_difference_per_track = ( - compute_piece_wise_dissimilarity(features_df, cross_dist, rank_fractions) - ) - - all_dissimilarity = np.concatenate(piece_wise_dissimilarity_per_track) - - p99_piece_wise_dissimilarity = np.array( - [np.percentile(track, 99) for track in piece_wise_dissimilarity_per_track] - ) - p1_percentile_piece_wise_dissimilarity = np.array( - [np.percentile(track, 1) for track in piece_wise_dissimilarity_per_track] - ) - - # Random sampling values in the dissimilarity matrix with same size as adjacent frame measurements - n_samples = len(all_dissimilarity) - random_indices = np.random.randint(0, len(cross_dist), size=(n_samples, 2)) - sampled_values = cross_dist[random_indices[:, 0], random_indices[:, 1]] - - # Compute the peaks of both distributions using KDE - adjacent_peak = float(find_distribution_peak(all_dissimilarity)) - random_peak = float(find_distribution_peak(sampled_values)) - dynamic_range = float(random_peak - adjacent_peak) - - metrics = { - "dissimilarity_mean": float(np.mean(all_dissimilarity)), - "dissimilarity_std": float(np.std(all_dissimilarity)), - "dissimilarity_median": float(np.median(all_dissimilarity)), - "dissimilarity_peak": adjacent_peak, - "dissimilarity_p99": p99_piece_wise_dissimilarity, - "dissimilarity_p1": p1_percentile_piece_wise_dissimilarity, - "dissimilarity_distribution": all_dissimilarity, - "random_mean": float(np.mean(sampled_values)), - "random_std": float(np.std(sampled_values)), - "random_median": float(np.median(sampled_values)), - "random_peak": random_peak, - "random_distribution": sampled_values, - "dynamic_range": dynamic_range, - } - - if verbose: - - # Plot the comparison histogram and save if output_path is provided - fig = plt.figure() - sns.histplot( - metrics["dissimilarity_distribution"], - bins=30, - kde=True, - color="cyan", - alpha=0.5, - stat="density", - ) - sns.histplot( - metrics["random_distribution"], - bins=30, - kde=True, - color="red", - alpha=0.5, - stat="density", - ) - plt.xlabel("Cosine 
Dissimilarity") - plt.ylabel("Density") - # Add vertical lines for the peaks - plt.axvline( - x=metrics["dissimilarity_peak"], color="cyan", linestyle="--", alpha=0.8 - ) - plt.axvline(x=metrics["random_peak"], color="red", linestyle="--", alpha=0.8) - plt.tight_layout() - plt.legend(["Adjacent Frame", "Random Sample", "Adjacent Peak", "Random Peak"]) - - if output_path and loss_name: - output_file = Path( - f"{output_path}/cosine_dissimilarity_smoothness_{prediction_path.stem}_{loss_name}.pdf" - ) - if output_file.exists() and not overwrite: - raise FileExistsError( - f"File {output_file} already exists and overwrite=False" - ) - fig.savefig( - output_file, - dpi=600, - ) - plt.show() - - return metrics - - -# Example usage: if __name__ == "__main__": - # plotting - VERBOSE = True - PATH_TO_GDRIVE_FIGUE = "./" - # Define models as a dictionary with meaningful keys prediction_paths = { "ntxent_sensor_phase": Path( @@ -239,36 +33,30 @@ def analyze_embedding_smoothness( prediction_paths.items(), desc="Evaluating models" ): print(f"\nAnalyzing model: {prediction_path.stem} (Loss: {model_name})") - print("-" * 80) - metrics = analyze_embedding_smoothness( - prediction_path, - verbose=VERBOSE, - output_path=PATH_TO_GDRIVE_FIGUE, - loss_name=model_name, - overwrite=True, + # Compute and save distributions + distributions_df = compute_embedding_distances( + prediction_path=prediction_path, + output_folder=output_folder, + distance_metric="cosine", + verbose=True, ) - # Save distributions to CSV - distributions_df = pd.DataFrame( - { - "adjacent_frame": pd.Series(metrics["dissimilarity_distribution"]), - "random_sampling": pd.Series(metrics["random_distribution"]), - } - ) - csv_path = ( - output_folder / f"{prediction_path.stem}_{model_name}_distributions.csv" + # Analyze distributions and create plots + metrics = analyze_and_plot_distances( + distributions_df, + output_file_path=output_folder / f"{model_name}_distance_plot.pdf", + overwrite=True, ) - distributions_df.to_csv(csv_path, index=False) - # Print statistics (rest of the printing code remains the same) - print("\nAdjacent Frame Dissimilarity Statistics:") + # Print statistics + print("\nAdjacent Frame Distance Statistics:") print(f"{'Mean:':<15} {metrics['dissimilarity_mean']:.3f}") print(f"{'Std:':<15} {metrics['dissimilarity_std']:.3f}") print(f"{'Median:':<15} {metrics['dissimilarity_median']:.3f}") print(f"{'Peak:':<15} {metrics['dissimilarity_peak']:.3f}") - print(f"{'P1:':<15} {np.mean(metrics['dissimilarity_p1']):.3f}") - print(f"{'P99:':<15} {np.mean(metrics['dissimilarity_p99']):.3f}") + print(f"{'P1:':<15} {metrics['dissimilarity_p1']:.3f}") + print(f"{'P99:':<15} {metrics['dissimilarity_p99']:.3f}") # Print random sampling statistics print("\nRandom Sampling Statistics:") @@ -284,8 +72,8 @@ def analyze_embedding_smoothness( # Print distribution sizes print("\nDistribution Sizes:") print( - f"{'Adjacent Frame:':<15} {len(metrics['dissimilarity_distribution']):,d} samples" + f"{'Adjacent Frame:':<15} {len(distributions_df['adjacent_frame']):,d} samples" ) - print(f"{'Random:':<15} {len(metrics['random_distribution']):,d} samples") + print(f"{'Random:':<15} {len(distributions_df['random_sampling']):,d} samples") # %% diff --git a/applications/contrastive_phenotyping/evaluation/cosine_similarity.py b/applications/contrastive_phenotyping/evaluation/cosine_similarity_demo.py similarity index 100% rename from applications/contrastive_phenotyping/evaluation/cosine_similarity.py rename to 
applications/contrastive_phenotyping/evaluation/cosine_similarity_demo.py diff --git a/applications/contrastive_phenotyping/evaluation/euclidean_distance_dataset.py b/applications/contrastive_phenotyping/evaluation/euclidean_distance_dataset.py index ec2965533..e49a44002 100644 --- a/applications/contrastive_phenotyping/evaluation/euclidean_distance_dataset.py +++ b/applications/contrastive_phenotyping/evaluation/euclidean_distance_dataset.py @@ -1,218 +1,17 @@ # %% from pathlib import Path -from typing import Optional import matplotlib.pyplot as plt -import seaborn as sns -from numpy.typing import NDArray - -from viscy.representation.embedding_writer import read_embedding_dataset -from viscy.representation.evaluation.clustering import ( - compare_time_offset, - pairwise_distance_matrix, - rank_nearest_neighbors, - select_block, -) -import numpy as np from tqdm import tqdm -import pandas as pd - -from scipy.stats import gaussian_kde -from scipy.optimize import minimize_scalar +from viscy.representation.evaluation.distance import ( + compute_embedding_distances, + analyze_and_plot_distances, +) plt.style.use("../evaluation/figure.mplstyle") - -def compute_piece_wise_distance( - features_df: pd.DataFrame, cross_dist: NDArray, rank_fractions: NDArray -): - """ - Computing the smoothness and dynamic range - - Get the off diagonal per block and compute the mode - - The blocks are not square, so we need to get the off diagonal elements - - Get the 1 and 99 percentile of the off diagonal per block - """ - piece_wise_distance_per_track = [] - piece_wise_rank_difference_per_track = [] - for name, subdata in features_df.groupby(["fov_name", "track_id"]): - if len(subdata) > 1: - indices = subdata.index.values - single_track_distance = select_block(cross_dist, indices) - single_track_rank_fraction = select_block(rank_fractions, indices) - piece_wise_distance = compare_time_offset( - single_track_distance, time_offset=1 - ) - piece_wise_rank_difference = compare_time_offset( - single_track_rank_fraction, time_offset=1 - ) - piece_wise_distance_per_track.append(piece_wise_distance) - piece_wise_rank_difference_per_track.append(piece_wise_rank_difference) - return piece_wise_distance_per_track, piece_wise_rank_difference_per_track - - -def plot_histogram( - data, title, xlabel, ylabel, color="blue", alpha=0.5, stat="frequency" -): - plt.figure() - plt.title(title) - sns.histplot(data, bins=30, kde=True, color=color, alpha=alpha, stat=stat) - plt.xlabel(xlabel) - plt.ylabel(ylabel) - plt.tight_layout() - plt.show() - - -def find_distribution_peak(data: np.ndarray) -> float: - """ - Find the peak (mode) of a distribution using kernel density estimation. - - Args: - data: Array of values to find the peak for - - Returns: - float: The x-value where the peak occurs - """ - kde = gaussian_kde(data) - # Find the peak (maximum) of the KDE - result = minimize_scalar( - lambda x: -kde(x), bounds=(np.min(data), np.max(data)), method="bounded" - ) - return result.x - - -def analyze_embedding_smoothness( - prediction_path: Path, - verbose: bool = False, - output_path: Optional[str] = None, - loss_name: Optional[str] = None, - overwrite: bool = False, -) -> dict: - """ - Analyze the smoothness and dynamic range of embeddings using Euclidean distance. - - Args: - prediction_path: Path to the embedding dataset - verbose: If True, generates additional plots - output_path: Path to save the final plot (optional) - loss_name: Name of the loss function used (optional) - overwrite: If True, overwrites existing files. 
If False, raises error if file exists (default: False) - - Returns: - dict: Dictionary containing metrics including: - - distance_mean: Mean of adjacent frame distance - - distance_std: Standard deviation of adjacent frame distance - - distance_median: Median of adjacent frame distance - - distance_peak: Peak of adjacent frame distribution - - distance_p99: 99th percentile of adjacent frame distance - - distance_p1: 1st percentile of adjacent frame distance - - distance_distribution: Full distribution of adjacent frame distances - - random_mean: Mean of random sampling distance - - random_std: Standard deviation of random sampling distance - - random_median: Median of random sampling distance - - random_peak: Peak of random sampling distribution - - random_distribution: Full distribution of random sampling distances - - dynamic_range: Difference between random and adjacent peaks - """ - # Read the dataset - embeddings = read_embedding_dataset(prediction_path) - features = embeddings["features"] - - # Compute the Euclidean distance - cross_dist = pairwise_distance_matrix(features, metric="euclidean") - rank_fractions = rank_nearest_neighbors(cross_dist, normalize=True) - - # Compute piece-wise distance and rank difference - features_df = features["sample"].to_dataframe().reset_index(drop=True) - piece_wise_distance_per_track, piece_wise_rank_difference_per_track = ( - compute_piece_wise_distance(features_df, cross_dist, rank_fractions) - ) - - all_distance = np.concatenate(piece_wise_distance_per_track) - - p99_piece_wise_distance = np.array( - [np.percentile(track, 99) for track in piece_wise_distance_per_track] - ) - p1_percentile_piece_wise_distance = np.array( - [np.percentile(track, 1) for track in piece_wise_distance_per_track] - ) - - # Random sampling values in the distance matrix with same size as adjacent frame measurements - n_samples = len(all_distance) - random_indices = np.random.randint(0, len(cross_dist), size=(n_samples, 2)) - sampled_values = cross_dist[random_indices[:, 0], random_indices[:, 1]] - - # Compute the peaks of both distributions using KDE - adjacent_peak = float(find_distribution_peak(all_distance)) - random_peak = float(find_distribution_peak(sampled_values)) - dynamic_range = float(random_peak - adjacent_peak) - - metrics = { - "distance_mean": float(np.mean(all_distance)), - "distance_std": float(np.std(all_distance)), - "distance_median": float(np.median(all_distance)), - "distance_peak": adjacent_peak, - "distance_p99": p99_piece_wise_distance, - "distance_p1": p1_percentile_piece_wise_distance, - "distance_distribution": all_distance, - "random_mean": float(np.mean(sampled_values)), - "random_std": float(np.std(sampled_values)), - "random_median": float(np.median(sampled_values)), - "random_peak": random_peak, - "random_distribution": sampled_values, - "dynamic_range": dynamic_range, - } - - if verbose: - # Plot the comparison histogram and save if output_path is provided - fig = plt.figure() - sns.histplot( - metrics["distance_distribution"], - bins=30, - kde=True, - color="cyan", - alpha=0.5, - stat="density", - ) - sns.histplot( - metrics["random_distribution"], - bins=30, - kde=True, - color="red", - alpha=0.5, - stat="density", - ) - plt.xlabel("Euclidean Distance") - plt.ylabel("Density") - # Add vertical lines for the peaks - plt.axvline(x=metrics["distance_peak"], color="cyan", linestyle="--", alpha=0.8) - plt.axvline(x=metrics["random_peak"], color="red", linestyle="--", alpha=0.8) - plt.tight_layout() - plt.legend(["Adjacent Frame", "Random 
Sample", "Adjacent Peak", "Random Peak"]) - - if output_path and loss_name: - output_file = Path( - f"{output_path}/euclidean_distance_smoothness_{prediction_path.stem}_{loss_name}.pdf" - ) - if output_file.exists() and not overwrite: - raise FileExistsError( - f"File {output_file} already exists and overwrite=False" - ) - fig.savefig( - output_file, - dpi=600, - ) - plt.show() - - return metrics - - -# Example usage: if __name__ == "__main__": - # plotting - VERBOSE = True - PATH_TO_GDRIVE_FIGUE = "./" - # Define models as a dictionary with meaningful keys prediction_paths = { "ntxent_sensor_phase": Path( @@ -234,36 +33,30 @@ def analyze_embedding_smoothness( prediction_paths.items(), desc="Evaluating models" ): print(f"\nAnalyzing model: {prediction_path.stem} (Loss: {model_name})") - print("-" * 80) - metrics = analyze_embedding_smoothness( - prediction_path, - verbose=VERBOSE, - output_path=PATH_TO_GDRIVE_FIGUE, - loss_name=model_name, - overwrite=True, + # Compute and save distributions + distributions_df = compute_embedding_distances( + prediction_path=prediction_path, + output_folder=output_folder, + distance_metric="euclidean", + verbose=True, ) - # Save distributions to CSV - distributions_df = pd.DataFrame( - { - "adjacent_frame": pd.Series(metrics["distance_distribution"]), - "random_sampling": pd.Series(metrics["random_distribution"]), - } - ) - csv_path = ( - output_folder / f"{prediction_path.stem}_{model_name}_distributions.csv" + # Analyze distributions and create plots + metrics = analyze_and_plot_distances( + distributions_df, + output_file_path=output_folder / f"{model_name}_distance_plot.pdf", + overwrite=True, ) - distributions_df.to_csv(csv_path, index=False) - # Print statistics (existing code) + # Print statistics print("\nAdjacent Frame Distance Statistics:") - print(f"{'Mean:':<15} {metrics['distance_mean']:.3f}") - print(f"{'Std:':<15} {metrics['distance_std']:.3f}") - print(f"{'Median:':<15} {metrics['distance_median']:.3f}") - print(f"{'Peak:':<15} {metrics['distance_peak']:.3f}") - print(f"{'P1:':<15} {np.mean(metrics['distance_p1']):.3f}") - print(f"{'P99:':<15} {np.mean(metrics['distance_p99']):.3f}") + print(f"{'Mean:':<15} {metrics['dissimilarity_mean']:.3f}") + print(f"{'Std:':<15} {metrics['dissimilarity_std']:.3f}") + print(f"{'Median:':<15} {metrics['dissimilarity_median']:.3f}") + print(f"{'Peak:':<15} {metrics['dissimilarity_peak']:.3f}") + print(f"{'P1:':<15} {metrics['dissimilarity_p1']:.3f}") + print(f"{'P99:':<15} {metrics['dissimilarity_p99']:.3f}") # Print random sampling statistics print("\nRandom Sampling Statistics:") @@ -279,8 +72,8 @@ def analyze_embedding_smoothness( # Print distribution sizes print("\nDistribution Sizes:") print( - f"{'Adjacent Frame:':<15} {len(metrics['distance_distribution']):,d} samples" + f"{'Adjacent Frame:':<15} {len(distributions_df['adjacent_frame']):,d} samples" ) - print(f"{'Random:':<15} {len(metrics['random_distribution']):,d} samples") + print(f"{'Random:':<15} {len(distributions_df['random_sampling']):,d} samples") # %% diff --git a/viscy/representation/evaluation/clustering.py b/viscy/representation/evaluation/clustering.py index f94643aac..ebf49455f 100644 --- a/viscy/representation/evaluation/clustering.py +++ b/viscy/representation/evaluation/clustering.py @@ -48,12 +48,7 @@ def pairwise_distance_matrix(features: ArrayLike, metric: str = "cosine") -> NDA NDArray Distance matrix of shape (n_samples, n_samples) """ - distances = cdist(features, features, metric=metric) - if metric == "euclidean": - # 
Normalize by sqrt of embedding dimension - print(f"features.shape: {features.shape}") - distances /= np.sqrt(features.shape[1]) - return distances + return cdist(features, features, metric=metric) def rank_nearest_neighbors( diff --git a/viscy/representation/evaluation/distance.py b/viscy/representation/evaluation/distance.py index b8552fa24..3c2a17fe8 100644 --- a/viscy/representation/evaluation/distance.py +++ b/viscy/representation/evaluation/distance.py @@ -1,29 +1,75 @@ from collections import defaultdict -from typing import Dict, List, Literal, Tuple +from pathlib import Path +from typing import Dict, List, Literal, Optional, Tuple +import matplotlib.pyplot as plt import numpy as np +import pandas as pd +import seaborn as sns +from numpy.typing import NDArray +from scipy.optimize import minimize_scalar +from scipy.stats import gaussian_kde from sklearn.metrics.pairwise import cosine_similarity +from sklearn.preprocessing import StandardScaler from tqdm import tqdm +from viscy.representation.embedding_writer import read_embedding_dataset +from viscy.representation.evaluation.clustering import ( + compare_time_offset, + pairwise_distance_matrix, + rank_nearest_neighbors, + select_block, +) -def calculate_cosine_similarity_cell(embedding_dataset, fov_name, track_id): - """Extract embeddings and calculate cosine similarities for a specific cell""" - filtered_data = embedding_dataset.where( - (embedding_dataset["fov_name"] == fov_name) - & (embedding_dataset["track_id"] == track_id), - drop=True, - ) - features = filtered_data["features"].values # (sample, features) - time_points = filtered_data["t"].values # (sample,) - first_time_point_embedding = features[0].reshape(1, -1) - cosine_similarities = cosine_similarity( - first_time_point_embedding, features - ).flatten() - return time_points, cosine_similarities.tolist() +def calculate_distance_cell( + embedding_dataset, + fov_name, + track_id, + metric: Literal["cosine", "euclidean", "normalized_euclidean"] = "cosine", +): + """ + Calculate distances between a cell's first timepoint embedding and all its subsequent embeddings. + + This function extracts embeddings for a specific cell (identified by fov_name and track_id) + and calculates the distance between its first timepoint embedding and all subsequent timepoints + using the specified distance metric. + + Parameters + ---------- + embedding_dataset : xarray.Dataset + Dataset containing the embeddings and metadata. Must have dimensions for 'features', + 'fov_name', 'track_id', and 't' (time). + fov_name : str + Field of view name to identify the specific imaging area. + track_id : int + Track ID of the cell to analyze. + metric : {'cosine', 'euclidean', 'normalized_euclidean'}, default='cosine' + Distance metric to use for calculations: + - 'cosine': Cosine similarity between embeddings + - 'euclidean': Standard Euclidean distance + - 'normalized_euclidean': Euclidean distance between L2-normalized embeddings -def calculate_euclidian_distance_cell(embedding_dataset, fov_name, track_id): - """Extract embeddings and calculate euclidean distances for a specific cell""" + Returns + ------- + time_points : numpy.ndarray + Array of time points corresponding to the calculated distances. + distances : list + List of distances between the first timepoint embedding and each subsequent + timepoint embedding, calculated using the specified metric. + + Notes + ----- + For 'normalized_euclidean', embeddings are L2-normalized before distance calculation. 
+ Cosine similarity results in values between -1 and 1, where 1 indicates identical + direction, 0 indicates orthogonality, and -1 indicates opposite directions. + Euclidean distances are always non-negative. + + Examples + -------- + >>> times, distances = calculate_distance_cell(dataset, "FOV1", 1, metric="cosine") + >>> times, distances = calculate_distance_cell(dataset, "FOV1", 1, metric="euclidean") + """ filtered_data = embedding_dataset.where( (embedding_dataset["fov_name"] == fov_name) & (embedding_dataset["track_id"] == track_id), @@ -31,9 +77,18 @@ def calculate_euclidian_distance_cell(embedding_dataset, fov_name, track_id): ) features = filtered_data["features"].values # (sample, features) time_points = filtered_data["t"].values # (sample,) + + if metric == "normalized_euclidean": + features = features / np.linalg.norm(features, axis=1, keepdims=True) + first_time_point_embedding = features[0].reshape(1, -1) - euclidean_distances = np.linalg.norm(first_time_point_embedding - features, axis=1) - return time_points, euclidean_distances.tolist() + + if metric == "cosine": + distances = cosine_similarity(first_time_point_embedding, features).flatten() + else: # both euclidean and normalized_euclidean use norm + distances = np.linalg.norm(first_time_point_embedding - features, axis=1) + + return time_points, distances.tolist() def compute_displacement( @@ -201,17 +256,205 @@ def compute_rms_per_track(embedding_dataset): return rms_values -def calculate_normalized_euclidean_distance_cell(embedding_dataset, fov_name, track_id): - filtered_data = embedding_dataset.where( - (embedding_dataset["fov_name"] == fov_name) - & (embedding_dataset["track_id"] == track_id), - drop=True, +def find_distribution_peak(data: np.ndarray) -> float: + """ + Find the peak (mode) of a distribution using kernel density estimation. 
+ + Args: + data: Array of values to find the peak for + + Returns: + float: The x-value where the peak occurs + """ + kde = gaussian_kde(data) + # Find the peak (maximum) of the KDE + result = minimize_scalar( + lambda x: -kde(x), bounds=(np.min(data), np.max(data)), method="bounded" ) - features = filtered_data["features"].values # (sample, features) - time_points = filtered_data["t"].values # (sample,) - normalized_features = features / np.linalg.norm(features, axis=1, keepdims=True) - first_time_point_embedding = normalized_features[0].reshape(1, -1) - euclidean_distances = np.linalg.norm( - first_time_point_embedding - normalized_features, axis=1 + return result.x + + +def compute_piece_wise_dissimilarity( + features_df: pd.DataFrame, cross_dist: NDArray, rank_fractions: NDArray +): + """ + Computing the smoothness and dynamic range + - Get the off diagonal per block and compute the mode + - The blocks are not square, so we need to get the off diagonal elements + - Get the 1 and 99 percentile of the off diagonal per block + """ + piece_wise_dissimilarity_per_track = [] + piece_wise_rank_difference_per_track = [] + for name, subdata in features_df.groupby(["fov_name", "track_id"]): + if len(subdata) > 1: + indices = subdata.index.values + single_track_dissimilarity = select_block(cross_dist, indices) + single_track_rank_fraction = select_block(rank_fractions, indices) + piece_wise_dissimilarity = compare_time_offset( + single_track_dissimilarity, time_offset=1 + ) + piece_wise_rank_difference = compare_time_offset( + single_track_rank_fraction, time_offset=1 + ) + piece_wise_dissimilarity_per_track.append(piece_wise_dissimilarity) + piece_wise_rank_difference_per_track.append(piece_wise_rank_difference) + return piece_wise_dissimilarity_per_track, piece_wise_rank_difference_per_track + + +def compute_embedding_distances( + prediction_path: Path, + output_folder: Path, + distance_metric: Literal["cosine", "euclidean", "normalized_euclidean"] = "cosine", + verbose: bool = False, +) -> pd.DataFrame: + """ + Compute and save pairwise distances between embeddings. 
+ + Parameters + ---------- + prediction_path : Path + Path to the embedding dataset + output_folder : Path + Folder where to save the CSV file + distance_metric : str, optional + Distance metric to use for computing distances between embeddings + verbose : bool, optional + If True, plots the distance matrix visualization + + Returns + ------- + pd.DataFrame + DataFrame containing the adjacent frame and random sampling distances + """ + # Read the dataset + embeddings = read_embedding_dataset(prediction_path) + features = embeddings["features"] + + if distance_metric != "euclidean": + features = StandardScaler().fit_transform(features.values) + + # Compute the distance matrix + cross_dist = pairwise_distance_matrix(features, metric=distance_metric) + + # Normalize by sqrt of embedding dimension if using euclidean distance + if distance_metric == "euclidean": + cross_dist /= np.sqrt(features.shape[1]) + + if verbose: + # Plot the distance matrix + plt.figure(figsize=(10, 10)) + plt.imshow(cross_dist, cmap="viridis") + plt.colorbar(label=f"{distance_metric.capitalize()} Distance") + plt.title(f"{distance_metric.capitalize()} Distance Matrix") + plt.tight_layout() + plt.show() + + rank_fractions = rank_nearest_neighbors(cross_dist, normalize=True) + + # Compute piece-wise dissimilarity and rank difference + features_df = features["sample"].to_dataframe().reset_index(drop=True) + piece_wise_dissimilarity_per_track, piece_wise_rank_difference_per_track = ( + compute_piece_wise_dissimilarity(features_df, cross_dist, rank_fractions) + ) + + all_dissimilarity = np.concatenate(piece_wise_dissimilarity_per_track) + + # Random sampling values in the dissimilarity matrix + n_samples = len(all_dissimilarity) + random_indices = np.random.randint(0, len(cross_dist), size=(n_samples, 2)) + sampled_values = cross_dist[random_indices[:, 0], random_indices[:, 1]] + + # Create and save DataFrame + distributions_df = pd.DataFrame( + { + "adjacent_frame": pd.Series(all_dissimilarity), + "random_sampling": pd.Series(sampled_values), + } + ) + + csv_path = output_folder / f"{prediction_path.stem}_distributions.csv" + distributions_df.to_csv(csv_path, index=False) + + return distributions_df + + +def analyze_and_plot_distances( + distributions_df: pd.DataFrame, + output_file_path: Optional[str], + overwrite: bool = False, +) -> dict: + """ + Analyze distance distributions and create visualization plots. + + Parameters + ---------- + distributions_df : pd.DataFrame + DataFrame containing 'adjacent_frame' and 'random_sampling' columns + output_file_path : str, optional + Path to save the plot ideally with a .pdf extension. 
Uses `plt.savefig()` + overwrite : bool, default=False + If True, overwrites existing files + + Returns + ------- + dict + Dictionary containing computed metrics including means, standard deviations, + medians, peaks, and dynamic range of the distributions + """ + # Compute statistics + adjacent_dist = distributions_df["adjacent_frame"].values + random_dist = distributions_df["random_sampling"].values + + # Compute peaks + adjacent_peak = float(find_distribution_peak(adjacent_dist)) + random_peak = float(find_distribution_peak(random_dist)) + dynamic_range = float(random_peak - adjacent_peak) + + metrics = { + "dissimilarity_mean": float(np.mean(adjacent_dist)), + "dissimilarity_std": float(np.std(adjacent_dist)), + "dissimilarity_median": float(np.median(adjacent_dist)), + "dissimilarity_peak": adjacent_peak, + "dissimilarity_p99": float(np.percentile(adjacent_dist, 99)), + "dissimilarity_p1": float(np.percentile(adjacent_dist, 1)), + "random_mean": float(np.mean(random_dist)), + "random_std": float(np.std(random_dist)), + "random_median": float(np.median(random_dist)), + "random_peak": random_peak, + "dynamic_range": dynamic_range, + } + + # Create plot + fig = plt.figure() + sns.histplot( + data=distributions_df, + x="adjacent_frame", + bins=30, + kde=True, + color="cyan", + alpha=0.5, + stat="density", + ) + sns.histplot( + data=distributions_df, + x="random_sampling", + bins=30, + kde=True, + color="red", + alpha=0.5, + stat="density", ) - return time_points, euclidean_distances.tolist() + plt.xlabel("Cosine Dissimilarity") + plt.ylabel("Density") + plt.axvline(x=adjacent_peak, color="cyan", linestyle="--", alpha=0.8) + plt.axvline(x=random_peak, color="red", linestyle="--", alpha=0.8) + plt.tight_layout() + plt.legend(["Adjacent Frame", "Random Sample", "Adjacent Peak", "Random Peak"]) + if output_file_path.exists() and not overwrite: + raise FileExistsError( + f"File {output_file_path} already exists and overwrite=False" + ) + fig.savefig(output_file_path, dpi=600) + plt.show() + + return metrics From 10378930d98c3fad8f6ddc2d1107b642b27a9ce4 Mon Sep 17 00:00:00 2001 From: Soorya Pradeep Date: Mon, 13 Jan 2025 11:20:40 -0800 Subject: [PATCH 19/38] changed csv file naming --- .../evaluation/euclidean_distance_dataset.py | 17 +++++++++-------- viscy/representation/evaluation/distance.py | 8 ++++---- 2 files changed, 13 insertions(+), 12 deletions(-) diff --git a/applications/contrastive_phenotyping/evaluation/euclidean_distance_dataset.py b/applications/contrastive_phenotyping/evaluation/euclidean_distance_dataset.py index e49a44002..3ed67dcc2 100644 --- a/applications/contrastive_phenotyping/evaluation/euclidean_distance_dataset.py +++ b/applications/contrastive_phenotyping/evaluation/euclidean_distance_dataset.py @@ -1,6 +1,7 @@ # %% from pathlib import Path - +import sys +sys.path.append("/hpc/mydata/soorya.pradeep/scratch/viscy_infection_phenotyping/VisCy") import matplotlib.pyplot as plt from tqdm import tqdm @@ -9,22 +10,22 @@ analyze_and_plot_distances, ) -plt.style.use("../evaluation/figure.mplstyle") +# plt.style.use("../evaluation/figure.mplstyle") if __name__ == "__main__": # Define models as a dictionary with meaningful keys prediction_paths = { - "ntxent_sensor_phase": Path( - "/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/trainng_logs/SEC61/rev6_NTXent_sensorPhase_infection/2chan_160patch_98ckpt_rev6_2.zarr" + "ntxent_classical": Path( + 
"/hpc/projects/organelle_phenotyping/ALFI_ntxent_loss/logs_alfi_ntxent_time_intervals/predictions/ALFI_classical.zarr" ), - "triplet_sensor_phase": Path( - "/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/trainng_logs/SEC61/rev5_sensorPhase_infection/2chan_160patch_97ckpt_rev5_2.zarr" + "triplet_classical": Path( + "/hpc/projects/organelle_phenotyping/ALFI_ntxent_loss/log_alfi_triplet_time_intervals/prediction/ALFI_classical.zarr" ), } # output_folder to save the distributions as .csv output_folder = Path( - "/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/trainng_logs/SEC61/euclidean_distance_distributions" + "/hpc/projects/organelle_phenotyping/ALFI_ntxent_loss/logs_alfi_ntxent_time_intervals/metrics" ) output_folder.mkdir(parents=True, exist_ok=True) @@ -37,7 +38,7 @@ # Compute and save distributions distributions_df = compute_embedding_distances( prediction_path=prediction_path, - output_folder=output_folder, + output_path=output_folder / f"{model_name}_distance_.csv", distance_metric="euclidean", verbose=True, ) diff --git a/viscy/representation/evaluation/distance.py b/viscy/representation/evaluation/distance.py index 3c2a17fe8..cefd8b6e6 100644 --- a/viscy/representation/evaluation/distance.py +++ b/viscy/representation/evaluation/distance.py @@ -303,7 +303,7 @@ def compute_piece_wise_dissimilarity( def compute_embedding_distances( prediction_path: Path, - output_folder: Path, + output_path: Path, distance_metric: Literal["cosine", "euclidean", "normalized_euclidean"] = "cosine", verbose: bool = False, ) -> pd.DataFrame: @@ -314,8 +314,8 @@ def compute_embedding_distances( ---------- prediction_path : Path Path to the embedding dataset - output_folder : Path - Folder where to save the CSV file + output_path : Path + name of saved CSV file distance_metric : str, optional Distance metric to use for computing distances between embeddings verbose : bool, optional @@ -372,7 +372,7 @@ def compute_embedding_distances( } ) - csv_path = output_folder / f"{prediction_path.stem}_distributions.csv" + csv_path = output_path distributions_df.to_csv(csv_path, index=False) return distributions_df From 5a9c296fd750eb139ae117116e9011d9d8015bcf Mon Sep 17 00:00:00 2001 From: Soorya Pradeep Date: Wed, 15 Jan 2025 10:59:20 -0800 Subject: [PATCH 20/38] add script for ALFI cell division --- .../figures/ALFI_cell_division.py | 252 ++++++++++++++++++ 1 file changed, 252 insertions(+) create mode 100644 applications/contrastive_phenotyping/figures/ALFI_cell_division.py diff --git a/applications/contrastive_phenotyping/figures/ALFI_cell_division.py b/applications/contrastive_phenotyping/figures/ALFI_cell_division.py new file mode 100644 index 000000000..eaf941d93 --- /dev/null +++ b/applications/contrastive_phenotyping/figures/ALFI_cell_division.py @@ -0,0 +1,252 @@ + +# %% Figure on ALFI cell division model showing +# (a) Euclidean distance over a cell division event and +# (b) difference between trajectory of cell in time-aware and classical method over division event + +from pathlib import Path +from collections import defaultdict +import seaborn as sns +import matplotlib.pyplot as plt +from matplotlib.patches import FancyArrowPatch +import numpy as np +from sklearn.metrics.pairwise import cosine_similarity +import pandas as pd + +from viscy.representation.embedding_writer import read_embedding_dataset + +# %% Task A: plot the Eucledian distance for a dividing cell + +# Paths to datasets +feature_paths = { + "7 min 
interval": "/hpc/projects/organelle_phenotyping/ALFI_ntxent_loss/logs_alfi_ntxent_time_intervals/predictions/ALFI_7mins.zarr", + "14 min interval": "/hpc/projects/organelle_phenotyping/ALFI_ntxent_loss/logs_alfi_ntxent_time_intervals/predictions/ALFI_14mins.zarr", + "28 min interval": "/hpc/projects/organelle_phenotyping/ALFI_ntxent_loss/logs_alfi_ntxent_time_intervals/predictions/ALFI_28mins.zarr", + "56 min interval": "/hpc/projects/organelle_phenotyping/ALFI_ntxent_loss/logs_alfi_ntxent_time_intervals/predictions/ALFI_56mins.zarr", + "91 min interval": "/hpc/projects/organelle_phenotyping/ALFI_ntxent_loss/logs_alfi_ntxent_time_intervals/predictions/ALFI_91mins.zarr", + "Classical": "/hpc/projects/organelle_phenotyping/ALFI_ntxent_loss/logs_alfi_ntxent_time_intervals/predictions/ALFI_classical.zarr", +} + +track_well = '/0/2/0' +parent_id = 3 # 11 +daughter1_track = 4 # 12 +daughter2_track = 5 # 13 + +# %% plot the eucledian distance over time lag for a parent cell for different time intervals + +def compute_displacement_track(fov_name, track_id, current_time, distance_metric="euclidean_squared", max_delta_t=10): + + fov_names = embedding_dataset["fov_name"].values + track_ids = embedding_dataset["track_id"].values + timepoints = embedding_dataset["t"].values + embeddings = embedding_dataset["features"].values + + # find index where fov_name, track_id and current_time match + i = np.where( + (fov_names == fov_name) + & (track_ids == track_id) + & (timepoints == current_time) + )[0][0] + current_embedding = embeddings[i].reshape(1, -1) + + # Check if max_delta_t is provided, otherwise use the maximum timepoint + if max_delta_t is None: + max_delta_t = timepoints.max() + + displacement_per_delta_t = defaultdict(list) + + # Compute displacements for each delta t + for delta_t in range(1, max_delta_t + 1): + future_time = current_time + delta_t + matching_indices = np.where( + (fov_names == fov_name) + & (track_ids == track_id) + & (timepoints == future_time) + )[0] + + if len(matching_indices) == 1: + if distance_metric == "euclidean_squared": + future_embedding = embeddings[matching_indices[0]].reshape(1, -1) + displacement = np.sum((current_embedding - future_embedding) ** 2) + elif distance_metric == "cosine": + future_embedding = embeddings[matching_indices[0]].reshape(1, -1) + displacement = cosine_similarity( + current_embedding, future_embedding + ) + displacement_per_delta_t[delta_t].append(displacement) + + return displacement_per_delta_t + +# %% plot the eucledian distance for a parent cell + +plt.figure(figsize=(10, 6)) +for label, path in feature_paths.items(): + embedding_dataset = read_embedding_dataset(path) + displacement_per_delta_t = compute_displacement_track(track_well, parent_id, 1) + delta_ts = sorted(displacement_per_delta_t.keys()) + displacements = [np.mean(displacement_per_delta_t[delta_t]) for delta_t in delta_ts] + plt.plot(delta_ts, displacements, label=label) + +plt.xlabel("Time Interval (delta t)") +plt.ylabel("Displacement (Euclidean Distance)") +plt.title("Displacement vs Time Interval for Parent Cell") +plt.legend() +plt.show() + +# %% Task B: plot the phate map and overlay the dividing cell trajectory + +# for time-aware model uncomment the next three lines +# embedding_dataset = read_embedding_dataset( +# "/hpc/projects/organelle_phenotyping/ALFI_ntxent_loss/logs_alfi_ntxent_time_intervals/predictions/ALFI_28mins.zarr" +# ) + +# for classical model uncomment the next three line +embedding_dataset = read_embedding_dataset( + 
"/hpc/projects/organelle_phenotyping/ALFI_ntxent_loss/logs_alfi_ntxent_time_intervals/predictions/ALFI_classical.zarr" +) + +PHATE1 = embedding_dataset["PHATE1"].values +PHATE2 = embedding_dataset["PHATE2"].values + +# %% plot PHATE map based on the embedding dataset time points + +sns.scatterplot( + x=embedding_dataset["PHATE1"], y=embedding_dataset["PHATE2"], hue=embedding_dataset["t"], s=7, alpha=0.8 +) + +# %% color using human annotation for cell cycle state + +def load_annotation(da, path, name, categories: dict | None = None): + annotation = pd.read_csv(path) + # annotation_columns = annotation.columns.tolist() + # print(annotation_columns) + annotation["fov_name"] = "/" + annotation["fov ID"] + annotation = annotation.set_index(["fov_name", "id"]) + mi = pd.MultiIndex.from_arrays( + [da["fov_name"].values, da["id"].values], names=["fov_name", "id"] + ) + selected = annotation.reindex(mi)[name] + if categories: + selected = selected.astype("category").cat.rename_categories(categories) + return selected + + +# %% load the cell cycle state annotation + +ann_root = Path( + "/hpc/projects/organelle_phenotyping/ALFI_models_data/datasets/zarr_datasets" +) + +division = load_annotation( + embedding_dataset, + ann_root / "test_annotations.csv", + "division", + {0: "interphase", 1: "mitosis"}, +) + +# %% find a parent that divides to two daughter cells for ploting trajectory + +cell_parent = embedding_dataset.where(embedding_dataset["fov_name"] == track_well, drop=True).where( + embedding_dataset["track_id"] == parent_id, drop=True +) +cell_parent = cell_parent["PHATE1"].values, cell_parent["PHATE2"].values +cell_parent = pd.DataFrame(np.column_stack(cell_parent), columns=["PHATE1", "PHATE2"]) + +cell_daughter1 = embedding_dataset.where(embedding_dataset["fov_name"] == track_well, drop=True).where( + embedding_dataset["track_id"] == daughter1_track, drop=True +) +cell_daughter1 = cell_daughter1["PHATE1"].values, cell_daughter1["PHATE2"].values +cell_daughter1 = pd.DataFrame(np.column_stack(cell_daughter1), columns=["PHATE1", "PHATE2"]) + +cell_daughter2 = embedding_dataset.where(embedding_dataset["fov_name"] == track_well, drop=True).where( + embedding_dataset["track_id"] == daughter2_track, drop=True +) +cell_daughter2 = cell_daughter2["PHATE1"].values, cell_daughter2["PHATE2"].values +cell_daughter2 = pd.DataFrame(np.column_stack(cell_daughter2), columns=["PHATE1", "PHATE2"]) + +# %% Plot: display one arrow at end of trajectory of cell overlayed on PHATE + +sns.scatterplot( + x=embedding_dataset["PHATE1"], + y=embedding_dataset["PHATE2"], + hue=division, + palette={"interphase": "steelblue", "mitosis": "orangered", -1: "green"}, + s=7, + alpha=0.5, +) + +# sns.lineplot(x=cell_parent["PHATE1"], y=cell_parent["PHATE2"], color="black", linewidth=2) +# sns.lineplot( +# x=cell_daughter1["PHATE1"], y=cell_daughter1["PHATE2"], color="blue", linewidth=2 +# ) +# sns.lineplot( +# x=cell_daughter2["PHATE1"], y=cell_daughter2["PHATE2"], color="red", linewidth=2 +# ) + +parent_arrow = FancyArrowPatch( + (cell_parent["PHATE1"].values[28], cell_parent["PHATE2"].values[28]), + (cell_parent["PHATE1"].values[35], cell_parent["PHATE2"].values[35]), + color="black", + arrowstyle="->", + mutation_scale=20, # reduce the size of arrowhead by half + lw=2, + shrinkA=0, + shrinkB=0, +) +plt.gca().add_patch(parent_arrow) +parent_arrow = FancyArrowPatch( + (cell_parent["PHATE1"].values[35], cell_parent["PHATE2"].values[35]), + (cell_parent["PHATE1"].values[38], cell_parent["PHATE2"].values[38]), + color="black", + 
arrowstyle="->", + mutation_scale=20, # reduce the size of arrowhead by half + lw=2, + shrinkA=0, + shrinkB=0, +) +plt.gca().add_patch(parent_arrow) +daughter1_arrow = FancyArrowPatch( + (cell_daughter1["PHATE1"].values[0], cell_daughter1["PHATE2"].values[0]), + (cell_daughter1["PHATE1"].values[1], cell_daughter1["PHATE2"].values[1]), + color="blue", + arrowstyle="->", + mutation_scale=20, # reduce the size of arrowhead by half + lw=2, + shrinkA=0, + shrinkB=0, +) +plt.gca().add_patch(daughter1_arrow) +daughter1_arrow = FancyArrowPatch( + (cell_daughter1["PHATE1"].values[1], cell_daughter1["PHATE2"].values[1]), + (cell_daughter1["PHATE1"].values[10], cell_daughter1["PHATE2"].values[10]), + color="blue", + arrowstyle="->", + mutation_scale=20, # reduce the size of arrowhead by half + lw=2, + shrinkA=0, + shrinkB=0, +) +plt.gca().add_patch(daughter1_arrow) +daughter2_arrow = FancyArrowPatch( + (cell_daughter2["PHATE1"].values[0], cell_daughter2["PHATE2"].values[0]), + (cell_daughter2["PHATE1"].values[1], cell_daughter2["PHATE2"].values[1]), + color="red", + arrowstyle="->", + mutation_scale=20, # reduce the size of arrowhead by half + lw=2, + shrinkA=0, + shrinkB=0, +) +plt.gca().add_patch(daughter2_arrow) +daughter2_arrow = FancyArrowPatch( + (cell_daughter2["PHATE1"].values[1], cell_daughter2["PHATE2"].values[1]), + (cell_daughter2["PHATE1"].values[10], cell_daughter2["PHATE2"].values[10]), + color="red", + arrowstyle="->", + mutation_scale=20, # reduce the size of arrowhead by half + lw=2, + shrinkA=0, + shrinkB=0, +) +plt.gca().add_patch(daughter2_arrow) + +# %% From 4ecc79b20e9cc4712bdeb1c51298290d1a1a6133 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Thu, 16 Jan 2025 09:17:41 -0800 Subject: [PATCH 21/38] PCA analysis --- .../pseudotime_analysis/pca_analysis.py | 439 ++++++++++++++++++ 1 file changed, 439 insertions(+) create mode 100644 applications/pseudotime_analysis/pca_analysis.py diff --git a/applications/pseudotime_analysis/pca_analysis.py b/applications/pseudotime_analysis/pca_analysis.py new file mode 100644 index 000000000..d8bc91204 --- /dev/null +++ b/applications/pseudotime_analysis/pca_analysis.py @@ -0,0 +1,439 @@ +# %% +import numpy as np +import pandas as pd +from sklearn.preprocessing import StandardScaler +from sklearn.decomposition import PCA +import matplotlib.pyplot as plt +import seaborn as sns +from viscy.representation.embedding_writer import read_embedding_dataset +from scipy.spatial.distance import pdist, squareform + + +def analyze_pc_loadings(pca, feature_names=None, top_n=5): + """Analyze which features contribute most to each PC.""" + if feature_names is None: + feature_names = [f"Feature_{i}" for i in range(pca.components_[0].shape[0])] + + pc_loadings = [] + for i, pc in enumerate(pca.components_): + # Get the absolute loadings + abs_loadings = np.abs(pc) + # Get indices of top contributing features + top_indices = np.argsort(abs_loadings)[-top_n:][::-1] + + # Store the results + pc_dict = { + "PC": i + 1, + "Variance_Explained": pca.explained_variance_ratio_[i], + "Top_Features": [feature_names[idx] for idx in top_indices], + "Top_Loadings": [pc[idx] for idx in top_indices], + } + pc_loadings.append(pc_dict) + + return pd.DataFrame(pc_loadings) + + +def analyze_track_clustering( + pca_result, + track_ids, + time_points, + labels, + phenotype_of_interest, + seed_timepoint, + time_window, +): + """Analyze how tracks cluster in PC space within the time window.""" + # Get points within time window + time_mask = (time_points >= seed_timepoint 
- time_window) & ( + time_points <= seed_timepoint + time_window + ) + window_points = pca_result[time_mask] + window_tracks = track_ids[time_mask] + window_labels = labels[time_mask] + + # Calculate mean position for each track + track_means = {} + phenotype_tracks = [] + + for track_id in np.unique(window_tracks): + track_mask = (window_tracks == track_id) & ( + window_labels == phenotype_of_interest + ) + if np.any(track_mask): + track_means[track_id] = np.mean(window_points[track_mask], axis=0) + phenotype_tracks.append(track_id) + + if len(phenotype_tracks) < 2: + return None + + # Calculate pairwise distances between track means + track_positions = np.array([track_means[tid] for tid in phenotype_tracks]) + distances = pdist(track_positions) + mean_distance = np.mean(distances) + std_distance = np.std(distances) + + # Calculate spread within each track + track_spreads = {} + for track_id in phenotype_tracks: + track_mask = (window_tracks == track_id) & ( + window_labels == phenotype_of_interest + ) + if np.sum(track_mask) > 1: + track_points = window_points[track_mask] + spread = np.mean(pdist(track_points)) + track_spreads[track_id] = spread + + mean_spread = np.mean(list(track_spreads.values())) if track_spreads else 0 + + return { + "n_tracks": len(phenotype_tracks), + "mean_inter_track_distance": mean_distance, + "std_inter_track_distance": std_distance, + "mean_intra_track_spread": mean_spread, + "clustering_ratio": mean_distance / mean_spread if mean_spread > 0 else np.inf, + } + + +def analyze_pc_distributions( + pca_result, + labels, + phenotype_of_interest, + time_points=None, + seed_timepoint=None, + time_window=None, +): + """Analyze the distributions of each PC for phenotype vs background.""" + n_components = pca_result.shape[1] + results = [] + + for i in range(n_components): + # Get phenotype and background points + if ( + time_points is not None + and seed_timepoint is not None + and time_window is not None + ): + time_mask = (time_points >= seed_timepoint - time_window) & ( + time_points <= seed_timepoint + time_window + ) + pc_values_phenotype = pca_result[ + time_mask & (labels == phenotype_of_interest), i + ] + pc_values_background = pca_result[ + time_mask & (labels != phenotype_of_interest), i + ] + else: + pc_values_phenotype = pca_result[labels == phenotype_of_interest, i] + pc_values_background = pca_result[labels != phenotype_of_interest, i] + + # Calculate basic statistics + stats = { + "PC": i + 1, + "phenotype_mean": np.mean(pc_values_phenotype), + "background_mean": np.mean(pc_values_background), + "phenotype_std": np.std(pc_values_phenotype), + "background_std": np.std(pc_values_background), + "separation": abs( + np.mean(pc_values_phenotype) - np.mean(pc_values_background) + ) + / (np.std(pc_values_phenotype) + np.std(pc_values_background)), + } + + # Check for multimodality using a simple peak detection + hist, bins = np.histogram(pc_values_phenotype, bins="auto") + peaks = len( + [ + i + for i in range(1, len(hist) - 1) + if hist[i] > hist[i - 1] and hist[i] > hist[i + 1] + ] + ) + stats["n_peaks"] = peaks + + results.append(stats) + + return pd.DataFrame(results) + + +def analyze_embeddings_with_pca( + embedding_path, + annotation_path, + phenotype_of_interest=2, + n_components=8, + seed_timepoint=55, + time_window=10, +): + # Load embeddings + embedding_dataset = read_embedding_dataset(embedding_path) + features = embedding_dataset["features"] + track_ids = embedding_dataset["track_id"].values + fovs = embedding_dataset["fov_name"].values + # Add 
time information for ordering points + time_points = embedding_dataset["t"].values + + # Load annotations + annotations_df = pd.read_csv(annotation_path) + + # Create a mapping dictionary for annotations + annotation_map = { + (str(row["FOV"]), int(row["Track_id"])): row["Observed phenotype"] + for _, row in annotations_df.iterrows() + } + + # Create labels array, -1 for unannotated cells + labels = np.array( + [ + annotation_map.get((str(fov), int(track_id)), -1) + for fov, track_id in zip(fovs, track_ids) + ] + ) + + # Scale the features + scaler = StandardScaler() + scaled_features = scaler.fit_transform(features.values) + + # Perform PCA + pca = PCA(n_components=n_components) + pca_result = pca.fit_transform(scaled_features) + + # Calculate explained variance + explained_variance_ratio = pca.explained_variance_ratio_ + cumulative_variance_ratio = np.cumsum(explained_variance_ratio) + + # Create track-specific colors for the phenotype of interest + phenotype_mask = labels == phenotype_of_interest + tracks_of_interest = np.unique(track_ids[phenotype_mask]) + track_colors = plt.cm.tab10(np.linspace(0, 1, len(tracks_of_interest))) + track_color_map = dict(zip(tracks_of_interest, track_colors)) + + # Create plots + fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5)) + + # Scree plot + ax1.plot(range(1, n_components + 1), explained_variance_ratio, "bo-") + ax1.plot(range(1, n_components + 1), cumulative_variance_ratio, "ro-") + ax1.set_xlabel("Principal Component") + ax1.set_ylabel("Explained Variance Ratio") + ax1.set_title("Scree Plot") + ax1.legend(["Individual", "Cumulative"]) + + # First two components plot + # Plot other phenotypes in gray + other_mask = labels != phenotype_of_interest + ax2.scatter( + pca_result[other_mask, 0], + pca_result[other_mask, 1], + alpha=0.1, + color="gray", + label="Other cells", + s=10, + ) + + # Plot each track of the phenotype of interest with decreasing opacity + for track_id in tracks_of_interest: + track_mask = (track_ids == track_id) & phenotype_mask + track_points = pca_result[track_mask] + track_times = time_points[track_mask] + + # Sort points by time + sort_idx = np.argsort(track_times) + track_points = track_points[sort_idx] + track_times = track_times[sort_idx] + + # Select points within the time window + time_mask = (track_times >= seed_timepoint - time_window) & ( + track_times <= seed_timepoint + time_window + ) + if np.any(time_mask): # Only plot if there are points in the window + window_points = track_points[time_mask] + window_times = track_times[time_mask] + + # Normalize times within window for opacity + norm_times = (window_times - window_times.min()) / ( + window_times.max() - window_times.min() + 1e-10 + ) + alphas = 0.2 + 0.8 * norm_times # Scale to [0.2, 1.0] + + # Plot points with opacity based on normalized time + for idx in range(len(window_points)): + ax2.scatter( + window_points[idx, 0], + window_points[idx, 1], + color=track_color_map[track_id], + alpha=alphas[idx], + s=50, + label=( + f"Track {track_id}" if idx == len(window_points) - 1 else None + ), + ) + + ax2.set_xlabel("First Principal Component") + ax2.set_ylabel("Second Principal Component") + ax2.set_title( + f"First Two Principal Components - Phenotype {phenotype_of_interest}\nTime window: {seed_timepoint}±{time_window}" + ) + ax2.legend(bbox_to_anchor=(1.05, 1), loc="upper left") + + plt.tight_layout() + plt.show() + + # Pairwise component plots + fig, axes = plt.subplots(n_components, n_components, figsize=(20, 20)) + + for i in range(n_components): + for j in 
range(n_components): + if i != j: + # Plot other points first + axes[i, j].scatter( + pca_result[other_mask, j], + pca_result[other_mask, i], + alpha=0.1, + color="gray", + s=5, + ) + + # Plot each track with decreasing opacity + for track_id in tracks_of_interest: + track_mask = (track_ids == track_id) & phenotype_mask + track_points_j = pca_result[track_mask, j] + track_points_i = pca_result[track_mask, i] + track_times = time_points[track_mask] + + # Sort points by time + sort_idx = np.argsort(track_times) + track_points_j = track_points_j[sort_idx] + track_points_i = track_points_i[sort_idx] + track_times = track_times[sort_idx] + + # Select points within the time window + time_mask = (track_times >= seed_timepoint - time_window) & ( + track_times <= seed_timepoint + time_window + ) + if np.any(time_mask): # Only plot if there are points in the window + window_points_j = track_points_j[time_mask] + window_points_i = track_points_i[time_mask] + window_times = track_times[time_mask] + + # Normalize times within window for opacity + norm_times = (window_times - window_times.min()) / ( + window_times.max() - window_times.min() + 1e-10 + ) + alphas = 0.2 + 0.8 * norm_times # Scale to [0.2, 1.0] + + # Plot points with opacity based on normalized time + for idx in range(len(window_points_j)): + axes[i, j].scatter( + window_points_j[idx], + window_points_i[idx], + color=track_color_map[track_id], + alpha=alphas[idx], + s=30, + ) + + axes[i, j].set_xlabel(f"PC{j+1}") + axes[i, j].set_ylabel(f"PC{i+1}") + else: + # On diagonal, show distribution + sns.histplot( + pca_result[other_mask, i], ax=axes[i, i], color="gray", alpha=0.3 + ) + for track_id in tracks_of_interest: + track_mask = (track_ids == track_id) & phenotype_mask + # For histograms, use all points in the time window + time_mask = ( + time_points[track_mask] >= seed_timepoint - time_window + ) & (time_points[track_mask] <= seed_timepoint + time_window) + if np.any(time_mask): + sns.histplot( + pca_result[track_mask][time_mask, i], + ax=axes[i, i], + color=track_color_map[track_id], + alpha=0.5, + ) + axes[i, i].set_xlabel(f"PC{i+1}") + + plt.tight_layout() + plt.show() + + # Print variance explained + print("\nExplained variance ratio by component:") + for i, var in enumerate(explained_variance_ratio): + print(f"PC{i+1}: {var:.3f} ({cumulative_variance_ratio[i]:.3f} cumulative)") + + # Add analysis of PC loadings + pc_analysis = analyze_pc_loadings(pca) + print("\nPC Loading Analysis:") + print(pc_analysis.to_string(index=False)) + + # Add analysis of track clustering + cluster_analysis = analyze_track_clustering( + pca_result, + track_ids, + time_points, + labels, + phenotype_of_interest, + seed_timepoint, + time_window, + ) + + if cluster_analysis: + print("\nTrack Clustering Analysis:") + print(f"Number of tracks in window: {cluster_analysis['n_tracks']}") + print( + f"Mean distance between tracks: {cluster_analysis['mean_inter_track_distance']:.3f}" + ) + print( + f"Mean spread within tracks: {cluster_analysis['mean_intra_track_spread']:.3f}" + ) + print( + f"Clustering ratio (inter/intra): {cluster_analysis['clustering_ratio']:.3f}" + ) + print("(Lower clustering ratio suggests tighter clustering)") + + # Add distribution analysis + dist_analysis = analyze_pc_distributions( + pca_result, + labels, + phenotype_of_interest, + time_points, + seed_timepoint, + time_window, + ) + print("\nPC Distribution Analysis:") + print( + "(Separation score > 1 suggests good separation between phenotype and background)" + ) + 
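    # Note on the separation score printed here: analyze_pc_distributions
    # computes it per PC as |mean_phenotype - mean_background| divided by
    # (std_phenotype + std_background), so a value above 1 means the group
    # means differ by more than their combined spread along that component.
    # The "n_peaks" column comes from a simple local-maximum scan of the
    # phenotype histogram and is only a rough indicator of multimodality.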
print(dist_analysis.to_string(index=False)) + + return ( + pca, + pca_result, + explained_variance_ratio, + labels, + pc_analysis, + cluster_analysis, + dist_analysis, + ) + + +# %% +if __name__ == "__main__": + embedding_path = "/hpc/projects/intracellular_dashboard/organelle_dynamics/2024_11_07_A549_SEC61_ZIKV_DENV/3-phenotyping/predictions/timeAware_2chan__ntxent_192patch_70ckpt_rev7_GT.zarr" + annotation_path = "/home/eduardo.hirata/repos/viscy/applications/pseudotime_analysis/phenotype_observations.csv" + ( + pca, + pca_result, + variance_ratio, + labels, + pc_analysis, + cluster_analysis, + dist_analysis, + ) = analyze_embeddings_with_pca( + embedding_path, + annotation_path, + phenotype_of_interest=1, + seed_timepoint=55, + time_window=10, + ) + +# %% From 8a2073b4ddb3c074ced97bcc0b2386fab95f19bf Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Thu, 16 Jan 2025 11:25:30 -0800 Subject: [PATCH 22/38] adding the random track sampling --- .../pseudotime_analysis/pca_analysis.py | 203 ++++++++++++++---- 1 file changed, 164 insertions(+), 39 deletions(-) diff --git a/applications/pseudotime_analysis/pca_analysis.py b/applications/pseudotime_analysis/pca_analysis.py index d8bc91204..3aed085ce 100644 --- a/applications/pseudotime_analysis/pca_analysis.py +++ b/applications/pseudotime_analysis/pca_analysis.py @@ -8,6 +8,10 @@ from viscy.representation.embedding_writer import read_embedding_dataset from scipy.spatial.distance import pdist, squareform +# Set global random seed for reproducibility +RANDOM_SEED = 42 +np.random.seed(RANDOM_SEED) + def analyze_pc_loadings(pca, feature_names=None, top_n=5): """Analyze which features contribute most to each PC.""" @@ -157,36 +161,129 @@ def analyze_pc_distributions( def analyze_embeddings_with_pca( embedding_path, - annotation_path, - phenotype_of_interest=2, + annotation_path=None, + phenotype_of_interest=None, + n_random_tracks=10, n_components=8, - seed_timepoint=55, + seed_timepoint=None, time_window=10, + fov_patterns=None, ): + """Analyze embeddings using PCA, either for specific phenotypes or random tracks. + + Args: + embedding_path: Path to embedding zarr file + annotation_path: Optional path to annotation CSV file. If None, uses random tracks + phenotype_of_interest: Which phenotype to analyze (only used if annotation_path is provided) + n_random_tracks: Number of random tracks to select (only used if annotation_path is None) + n_components: Number of PCA components + seed_timepoint: Center of time window. If None, uses all timepoints + time_window: Size of time window (+/-). Only used if seed_timepoint is not None + fov_patterns: List of patterns to filter FOVs (e.g. ['/C/2/*', '/B/3/*']). + Optional even when using annotation_path - can be used to restrict + analysis to specific FOVs while still using phenotype information. 
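    Example:
        A minimal illustrative call with randomly sampled tracks (the zarr
        path is a placeholder):

            analyze_embeddings_with_pca(
                "predictions/embeddings.zarr",
                annotation_path=None,
                n_random_tracks=10,
                seed_timepoint=55,
                time_window=10,
                fov_patterns=["/C/2/", "/B/3/"],
            )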
+ """ + if annotation_path is None: + print(f"\nUsing random tracks (global seed: {RANDOM_SEED})") + + if seed_timepoint is None: + print("\nUsing all timepoints") + else: + print(f"\nUsing time window: {seed_timepoint}±{time_window}") + # Load embeddings embedding_dataset = read_embedding_dataset(embedding_path) features = embedding_dataset["features"] track_ids = embedding_dataset["track_id"].values fovs = embedding_dataset["fov_name"].values - # Add time information for ordering points time_points = embedding_dataset["t"].values - # Load annotations - annotations_df = pd.read_csv(annotation_path) + # Filter FOVs if patterns are provided + if fov_patterns is not None: + print(f"\nFiltering FOVs with patterns: {fov_patterns}") + fov_mask = np.zeros_like(fovs, dtype=bool) + for pattern in fov_patterns: + fov_mask |= np.char.find(fovs.astype(str), pattern) >= 0 + + # Update all arrays with the FOV mask + features = features[fov_mask] + track_ids = track_ids[fov_mask] + fovs = fovs[fov_mask] + time_points = time_points[fov_mask] + + print(f"Found {len(np.unique(fovs))} FOVs matching patterns") + + # Get tracks of interest + if annotation_path is not None: + # Load annotations and get phenotype tracks + annotations_df = pd.read_csv(annotation_path) + annotation_map = { + (str(row["FOV"]), int(row["Track_id"])): row["Observed phenotype"] + for _, row in annotations_df.iterrows() + } + labels = np.array( + [ + annotation_map.get((str(fov), int(track_id)), -1) + for fov, track_id in zip(fovs, track_ids) + ] + ) + selection_mask = labels == phenotype_of_interest + tracks_of_interest = np.unique(track_ids[selection_mask]) + other_mask = ~selection_mask + mode = f"phenotype {phenotype_of_interest}" + else: + # Select random tracks from different FOVs when possible + # Create a mapping of FOV to tracks + fov_track_map = {} + for fov, track_id in zip(fovs, track_ids): + if fov not in fov_track_map: + fov_track_map[fov] = [] + if track_id not in fov_track_map[fov]: # Avoid duplicates + fov_track_map[fov].append(track_id) + + # Get list of all FOVs + available_fovs = list(fov_track_map.keys()) + tracks_of_interest = [] + + # First, try to get one track from each FOV + np.random.shuffle(available_fovs) # Randomize FOV order + for fov in available_fovs: + if len(tracks_of_interest) < n_random_tracks: + # Randomly select a track from this FOV + track = np.random.choice(fov_track_map[fov]) + tracks_of_interest.append(track) + else: + break + + # If we still need more tracks, randomly select from remaining tracks + if len(tracks_of_interest) < n_random_tracks: + # Get all remaining tracks that aren't already selected + remaining_tracks = [ + track + for track in np.unique(track_ids) + if track not in tracks_of_interest + ] + # Select additional tracks + additional_tracks = np.random.choice( + remaining_tracks, + size=min( + n_random_tracks - len(tracks_of_interest), len(remaining_tracks) + ), + replace=False, + ) + tracks_of_interest.extend(additional_tracks) - # Create a mapping dictionary for annotations - annotation_map = { - (str(row["FOV"]), int(row["Track_id"])): row["Observed phenotype"] - for _, row in annotations_df.iterrows() - } + tracks_of_interest = np.array(tracks_of_interest) + selection_mask = np.isin(track_ids, tracks_of_interest) + other_mask = ~selection_mask + labels = np.where(selection_mask, 1, 0) + mode = "random tracks" - # Create labels array, -1 for unannotated cells - labels = np.array( - [ - annotation_map.get((str(fov), int(track_id)), -1) - for fov, track_id in zip(fovs, 
track_ids) - ] - ) + # Print selected tracks with their FOVs + print("\nSelected tracks:") + for track in tracks_of_interest: + track_fovs = np.unique(fovs[track_ids == track]) + print(f"Track {track}: FOV {track_fovs[0]}") # Scale the features scaler = StandardScaler() @@ -200,9 +297,7 @@ def analyze_embeddings_with_pca( explained_variance_ratio = pca.explained_variance_ratio_ cumulative_variance_ratio = np.cumsum(explained_variance_ratio) - # Create track-specific colors for the phenotype of interest - phenotype_mask = labels == phenotype_of_interest - tracks_of_interest = np.unique(track_ids[phenotype_mask]) + # Create track-specific colors track_colors = plt.cm.tab10(np.linspace(0, 1, len(tracks_of_interest))) track_color_map = dict(zip(tracks_of_interest, track_colors)) @@ -218,8 +313,7 @@ def analyze_embeddings_with_pca( ax1.legend(["Individual", "Cumulative"]) # First two components plot - # Plot other phenotypes in gray - other_mask = labels != phenotype_of_interest + # Plot other tracks/cells in gray ax2.scatter( pca_result[other_mask, 0], pca_result[other_mask, 1], @@ -229,9 +323,9 @@ def analyze_embeddings_with_pca( s=10, ) - # Plot each track of the phenotype of interest with decreasing opacity + # Plot tracks of interest with decreasing opacity for track_id in tracks_of_interest: - track_mask = (track_ids == track_id) & phenotype_mask + track_mask = track_ids == track_id track_points = pca_result[track_mask] track_times = time_points[track_mask] @@ -240,10 +334,14 @@ def analyze_embeddings_with_pca( track_points = track_points[sort_idx] track_times = track_times[sort_idx] - # Select points within the time window - time_mask = (track_times >= seed_timepoint - time_window) & ( - track_times <= seed_timepoint + time_window - ) + # Apply time window if specified + if seed_timepoint is not None: + time_mask = (track_times >= seed_timepoint - time_window) & ( + track_times <= seed_timepoint + time_window + ) + else: + time_mask = np.ones_like(track_times, dtype=bool) # Use all points + if np.any(time_mask): # Only plot if there are points in the window window_points = track_points[time_mask] window_times = track_times[time_mask] @@ -269,9 +367,10 @@ def analyze_embeddings_with_pca( ax2.set_xlabel("First Principal Component") ax2.set_ylabel("Second Principal Component") - ax2.set_title( - f"First Two Principal Components - Phenotype {phenotype_of_interest}\nTime window: {seed_timepoint}±{time_window}" - ) + title = f"First Two Principal Components - {mode}" + if seed_timepoint is not None: + title += f"\nTime window: {seed_timepoint}±{time_window}" + ax2.set_title(title) ax2.legend(bbox_to_anchor=(1.05, 1), loc="upper left") plt.tight_layout() @@ -294,7 +393,7 @@ def analyze_embeddings_with_pca( # Plot each track with decreasing opacity for track_id in tracks_of_interest: - track_mask = (track_ids == track_id) & phenotype_mask + track_mask = track_ids == track_id track_points_j = pca_result[track_mask, j] track_points_i = pca_result[track_mask, i] track_times = time_points[track_mask] @@ -338,7 +437,7 @@ def analyze_embeddings_with_pca( pca_result[other_mask, i], ax=axes[i, i], color="gray", alpha=0.3 ) for track_id in tracks_of_interest: - track_mask = (track_ids == track_id) & phenotype_mask + track_mask = track_ids == track_id # For histograms, use all points in the time window time_mask = ( time_points[track_mask] >= seed_timepoint - time_window @@ -371,7 +470,7 @@ def analyze_embeddings_with_pca( track_ids, time_points, labels, - phenotype_of_interest, + 1 if annotation_path is 
None else phenotype_of_interest, seed_timepoint, time_window, ) @@ -394,14 +493,14 @@ def analyze_embeddings_with_pca( dist_analysis = analyze_pc_distributions( pca_result, labels, - phenotype_of_interest, - time_points, + 1 if annotation_path is None else phenotype_of_interest, + time_points if seed_timepoint is not None else None, seed_timepoint, time_window, ) print("\nPC Distribution Analysis:") print( - "(Separation score > 1 suggests good separation between phenotype and background)" + "(Separation score > 1 suggests good separation between selected tracks and background)" ) print(dist_analysis.to_string(index=False)) @@ -410,6 +509,7 @@ def analyze_embeddings_with_pca( pca_result, explained_variance_ratio, labels, + tracks_of_interest, pc_analysis, cluster_analysis, dist_analysis, @@ -420,20 +520,45 @@ def analyze_embeddings_with_pca( if __name__ == "__main__": embedding_path = "/hpc/projects/intracellular_dashboard/organelle_dynamics/2024_11_07_A549_SEC61_ZIKV_DENV/3-phenotyping/predictions/timeAware_2chan__ntxent_192patch_70ckpt_rev7_GT.zarr" annotation_path = "/home/eduardo.hirata/repos/viscy/applications/pseudotime_analysis/phenotype_observations.csv" + # %% + # Using phenotype annotations with specific FOVs + print("\nAnalyzing phenotype 1 in specific FOVs:") ( pca, pca_result, variance_ratio, labels, + tracks, pc_analysis, cluster_analysis, dist_analysis, ) = analyze_embeddings_with_pca( embedding_path, - annotation_path, + annotation_path=annotation_path, phenotype_of_interest=1, seed_timepoint=55, time_window=10, + fov_patterns=["/C/2/", "/B/3/", "/B/2/"], # Specify FOV patterns + ) + + # Using random tracks from specific FOVs + print("\nAnalyzing random tracks from specific FOVs:") + ( + pca, + pca_result, + variance_ratio, + labels, + tracks, + pc_analysis, + cluster_analysis, + dist_analysis, + ) = analyze_embeddings_with_pca( + embedding_path, + annotation_path=None, # This triggers random track selection + n_random_tracks=10, + seed_timepoint=55, + time_window=30, + fov_patterns=["/C/2/", "/B/3/", "/B/2/"], # Specify FOV patterns ) # %% From d8fa880cd9a1b799421fc9b777ee729004131364 Mon Sep 17 00:00:00 2001 From: Soorya Pradeep Date: Wed, 22 Jan 2025 14:40:53 -0800 Subject: [PATCH 23/38] add script for interactive plot with image display --- .../figures/interactive_plot_wDisplay.py | 138 ++++++++++++++++++ 1 file changed, 138 insertions(+) create mode 100644 applications/contrastive_phenotyping/figures/interactive_plot_wDisplay.py diff --git a/applications/contrastive_phenotyping/figures/interactive_plot_wDisplay.py b/applications/contrastive_phenotyping/figures/interactive_plot_wDisplay.py new file mode 100644 index 000000000..43441918c --- /dev/null +++ b/applications/contrastive_phenotyping/figures/interactive_plot_wDisplay.py @@ -0,0 +1,138 @@ + +# This is a simple example of an interactive plot using Dash. 
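# The script reduces the learned features to three principal components, draws
# a PCA1 vs PCA2 Plotly scatter, and serves it through a Dash app. Because the
# scatter is created with hover_data=["id", "t", "track_id"], the hover
# callback reads the timepoint from customdata[1] and the track id from
# customdata[2] (fov_name arrives via hovertext), loads a Phase3D image patch
# for that FOV/track with dataset_of_tracks, and returns it as a
# base64-encoded PNG shown in the <img> element below the plot.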
+from pathlib import Path +import dash +from dash import dcc, html +import plotly.express as px +import pandas as pd +import numpy as np +import base64 +from io import BytesIO +from PIL import Image +from sklearn.decomposition import PCA +from sklearn.preprocessing import StandardScaler +import dash.dependencies as dd + +from viscy.representation.embedding_writer import read_embedding_dataset +from viscy.representation.evaluation import dataset_of_tracks + +# Initialize Dash app +app = dash.Dash(__name__) + +# Sample DataFrame for demonstration +features_path = Path( + "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/time_interval/predict/jun_time_interval_1_epoch_178.zarr" +) +embedding_dataset = read_embedding_dataset(features_path) +features = embedding_dataset["features"] +scaled_features = StandardScaler().fit_transform(features.values) +pca = PCA(n_components=3) +embedding = pca.fit_transform(scaled_features) +features = ( + features.assign_coords(PCA1=("sample", embedding[:, 0])) + .assign_coords(PCA2=("sample", embedding[:, 1])) + .assign_coords(PCA3=("sample", embedding[:, 2])) + .set_index(sample=["PCA1", "PCA2", "PCA3"], append=True) +) + +df = pd.DataFrame({k: v for k, v in features.coords.items() if k != "features"}) + +# Image paths for each track and time + +data_path = Path( + "/hpc/projects/organelle_phenotyping/2024_06_13_SEC61_TOMM20_ZIKV_DENGUE_1/registered_chunked.zarr" +) +tracks_path = Path( + "/hpc/projects/intracellular_dashboard/organelle_dynamics/2024_06_13_SEC61_TOMM20_ZIKV_DENGUE_1/4.2-tracking/track.zarr" +) + +# Create scatter plot with hover data (track_id, t, fov_name) +fig = px.scatter( + df, + x="PCA1", + y="PCA2", + color="PCA1", + hover_name="fov_name", + hover_data=["id", "t", "track_id"], # Include track_id and t for image lookup +) + +# Layout of the app +app.layout = html.Div([ + dcc.Graph( + id="scatter-plot", + figure=fig, + ), + html.Div([ + html.Img(id="hover-image", src="", style={"width": "150px", "height": "150px"}) + ]) +]) + +# Helper function to convert numpy array to base64 image +def numpy_to_base64(img_array): + # Clip, normalize, and scale to the range [0, 255] + img_array = np.clip(img_array, img_array.min(), img_array.max()) # Clip values to the expected range + img_array = (img_array - img_array.min()) / (img_array.max() - img_array.min()) # Normalize to [0, 1] + img_array = (img_array * 255).astype(np.uint8) # Scale to [0, 255] and convert to uint8 + + img = Image.fromarray(img_array) + buffered = BytesIO() + img.save(buffered, format="PNG") + return "data:image/png;base64," + base64.b64encode(buffered.getvalue()).decode("utf-8") + + +# Callback to update the image when a point is hovered over +@app.callback( + dd.Output("hover-image", "src"), + [dd.Input("scatter-plot", "hoverData")] +) +def update_image(hoverData): + if hoverData is None: + return "" # Return empty if no hover + + # Extract the necessary information from hoverData + fov_name = hoverData['points'][0]['hovertext'] # fov_name is in hovertext + track_id = hoverData['points'][0]['customdata'][2] # track_id from hover_data + t = hoverData['points'][0]['customdata'][1] # t from hover_data + + print(f"Hovering over: fov_name={fov_name}, track_id={track_id}, t={t}") + + # Lookup the image path based on fov_name, track_id, and t + # image_key = (fov_name, track_id, t) + + # Get the image URL if it exists + # return image_paths.get(image_key, "") # Return empty string if no match + source_channel = ["Phase3D"] + 
z_range = (33,34) + predict_dataset = dataset_of_tracks( + data_path, + tracks_path, + [fov_name], + [track_id], + z_range=z_range, + source_channel=source_channel, + ) + # image_patch = np.stack([p["anchor"][0, 7].numpy() for p in predict_dataset]) + + # Check if the dataset was retrieved successfully + if not predict_dataset: + print(f"No dataset found for fov_name={fov_name}, track_id={track_id}, t={t}") + return "" # Return empty if no dataset is found + + # Extract the image patch (assuming it's a numpy array) + try: + image_patch = np.stack([p["anchor"][0].numpy() for p in predict_dataset]) + image_patch = image_patch[0,0] + print(f"Image patch shape: {image_patch.shape}") + except Exception as e: + print(f"Error extracting image patch: {e}") + return "" + + # Check if the image is valid (this step is just a safety check) + if image_patch.ndim != 2: + print(f"Invalid image data: image_patch is not 2D.") + return "" + + return numpy_to_base64(image_patch) + +if __name__ == '__main__': + app.run_server(debug=True) \ No newline at end of file From 8326788b3a641815557bcc93d8264e96a5dd6ed4 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Thu, 23 Jan 2025 08:41:10 -0800 Subject: [PATCH 24/38] gmm clustering and silhoutte score --- .../pseudotime_analysis/pca_analysis.py | 320 +++++++++++++++++- 1 file changed, 316 insertions(+), 4 deletions(-) diff --git a/applications/pseudotime_analysis/pca_analysis.py b/applications/pseudotime_analysis/pca_analysis.py index 3aed085ce..d79730559 100644 --- a/applications/pseudotime_analysis/pca_analysis.py +++ b/applications/pseudotime_analysis/pca_analysis.py @@ -3,6 +3,8 @@ import pandas as pd from sklearn.preprocessing import StandardScaler from sklearn.decomposition import PCA +from sklearn.mixture import GaussianMixture +from sklearn.metrics import silhouette_score import matplotlib.pyplot as plt import seaborn as sns from viscy.representation.embedding_writer import read_embedding_dataset @@ -159,6 +161,261 @@ def analyze_pc_distributions( return pd.DataFrame(results) +def analyze_gmm_clustering( + pca_result, + track_ids, + time_points, + tracks_of_interest, + n_components_range=range(2, 7), + seed_timepoint=None, + time_window=None, +): + """Analyze clusters using Gaussian Mixture Models.""" + # Get points from tracks of interest + track_mask = np.isin(track_ids, tracks_of_interest) + points = pca_result[track_mask] + track_ids_subset = track_ids[track_mask] + times = time_points[track_mask] + + # Apply time window if specified + if seed_timepoint is not None and time_window is not None: + time_mask = (times >= seed_timepoint - time_window) & ( + times <= seed_timepoint + time_window + ) + points = points[time_mask] + track_ids_subset = track_ids_subset[time_mask] + times = times[time_mask] + + # Try different numbers of components + bic_scores = [] + silhouette_scores = [] + models = [] + + for n_components in n_components_range: + gmm = GaussianMixture( + n_components=n_components, random_state=RANDOM_SEED, n_init=10 + ) + gmm.fit(points) + labels = gmm.predict(points) + + bic_scores.append(gmm.bic(points)) + silhouette_scores.append(silhouette_score(points, labels)) + models.append(gmm) + + # Plot model selection metrics + fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4)) + + # BIC plot + ax1.plot(list(n_components_range), bic_scores, "bo-") + ax1.set_xlabel("Number of Components") + ax1.set_ylabel("BIC Score") + ax1.set_title("Model Selection: BIC") + + # Silhouette plot + ax2.plot(list(n_components_range), silhouette_scores, 
"ro-") + ax2.set_xlabel("Number of Components") + ax2.set_ylabel("Silhouette Score") + ax2.set_title("Model Selection: Silhouette") + + plt.tight_layout() + plt.show() + + # Select best model based on BIC + best_idx = np.argmin(bic_scores) + best_n_components = n_components_range[best_idx] + best_model = models[best_idx] + + # Get cluster assignments + labels = best_model.predict(points) + probs = best_model.predict_proba(points) + + # Plot clustering results + fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6)) + + # Scatter plot colored by cluster + scatter = ax1.scatter( + points[:, 0], points[:, 1], c=labels, cmap="tab10", alpha=0.6, s=50 + ) + ax1.set_xlabel("PC1") + ax1.set_ylabel("PC2") + ax1.set_title(f"GMM Clustering (n={best_n_components})") + plt.colorbar(scatter, ax=ax1, label="Cluster") + + # Plot cluster assignment probabilities + max_probs = np.max(probs, axis=1) + scatter = ax2.scatter( + points[:, 0], points[:, 1], c=max_probs, cmap="viridis", alpha=0.6, s=50 + ) + ax2.set_xlabel("PC1") + ax2.set_ylabel("PC2") + ax2.set_title("Cluster Assignment Probability") + plt.colorbar(scatter, ax=ax2, label="Probability") + + plt.tight_layout() + plt.show() + + # Analyze cluster composition + cluster_stats = [] + for i in range(best_n_components): + cluster_mask = labels == i + cluster_tracks = np.unique(track_ids_subset[cluster_mask]) + cluster_stats.append( + { + "cluster": i, + "n_points": np.sum(cluster_mask), + "n_tracks": len(cluster_tracks), + "tracks": cluster_tracks, + "mean_prob": np.mean(probs[cluster_mask, i]), + "std_prob": np.std(probs[cluster_mask, i]), + } + ) + + # Print cluster statistics + print(f"\nBest number of clusters (BIC): {best_n_components}") + print("\nCluster Statistics:") + for stats in cluster_stats: + print(f"\nCluster {stats['cluster']}:") + print(f" Points: {stats['n_points']}") + print(f" Tracks: {stats['n_tracks']}") + print(f" Mean probability: {stats['mean_prob']:.3f} ± {stats['std_prob']:.3f}") + print(f" Tracks in cluster: {stats['tracks']}") + + return { + "best_model": best_model, + "best_n_components": best_n_components, + "labels": labels, + "probabilities": probs, + "bic_scores": bic_scores, + "silhouette_scores": silhouette_scores, + "cluster_stats": cluster_stats, + } + + +def analyze_cluster_characteristics( + gmm_results, + pca_result, + track_ids, + time_points, + tracks_of_interest, + pc_analysis=None, + seed_timepoint=None, + time_window=None, +): + """Analyze characteristics of GMM clusters including temporal patterns and PC contributions.""" + # Get points from tracks of interest first + track_mask = np.isin(track_ids, tracks_of_interest) + points = pca_result[track_mask] + track_ids_subset = track_ids[track_mask] + times = time_points[track_mask] + + # Apply time window if specified + if seed_timepoint is not None and time_window is not None: + time_mask = (times >= seed_timepoint - time_window) & ( + times <= seed_timepoint + time_window + ) + points = points[time_mask] + track_ids_subset = track_ids_subset[time_mask] + times = times[time_mask] + + # Get cluster assignments for the filtered points + labels = gmm_results["labels"] + probs = gmm_results["probabilities"] + n_clusters = gmm_results["best_n_components"] + + # Analyze temporal patterns in each cluster + print("\nTemporal patterns in clusters:") + for i in range(n_clusters): + cluster_mask = labels == i + cluster_times = times[cluster_mask] + if len(cluster_times) > 0: + print(f"\nCluster {i}:") + print( + f" Time range: {np.min(cluster_times):.1f} to 
{np.max(cluster_times):.1f}" + ) + print( + f" Mean time: {np.mean(cluster_times):.1f} ± {np.std(cluster_times):.1f}" + ) + + # Analyze PC contributions to cluster separation + print("\nPC contributions to cluster separation:") + for pc_idx in range(min(4, points.shape[1])): # Analyze first 4 PCs + pc_values = points[:, pc_idx] + cluster_means = [np.mean(pc_values[labels == i]) for i in range(n_clusters)] + cluster_stds = [np.std(pc_values[labels == i]) for i in range(n_clusters)] + + # Calculate separation score (ratio of between-cluster to within-cluster variance) + between_var = np.var(cluster_means) + within_var = np.mean(cluster_stds) + separation_score = between_var / within_var if within_var > 0 else float("inf") + + print(f"\nPC{pc_idx + 1}:") + print(f" Separation score: {separation_score:.3f}") + if pc_analysis is not None: + pc_info = pc_analysis[pc_analysis["PC"] == pc_idx + 1].iloc[0] + print( + f" Top contributing features: {', '.join(pc_info['Top_Features'][:3])}" + ) + + # Print cluster-specific stats + for i in range(n_clusters): + cluster_mask = labels == i + print(f" Cluster {i}: {cluster_means[i]:.3f} ± {cluster_stds[i]:.3f}") + + # Analyze track transitions between clusters + print("\nTrack transitions between clusters:") + for track_id in tracks_of_interest: + track_mask = track_ids_subset == track_id + track_labels = labels[track_mask] + track_times = times[track_mask] + + if len(track_labels) > 1: + # Sort by time + sort_idx = np.argsort(track_times) + track_labels = track_labels[sort_idx] + track_times = track_times[sort_idx] + + # Find transitions + transitions = np.where(track_labels[1:] != track_labels[:-1])[0] + if len(transitions) > 0: + print(f"\nTrack {track_id}:") + for trans_idx in transitions: + from_cluster = track_labels[trans_idx] + to_cluster = track_labels[trans_idx + 1] + trans_time = track_times[trans_idx + 1] + print(f" {trans_time:.1f}: {from_cluster} -> {to_cluster}") + + return { + "temporal_patterns": { + i: { + "mean_time": np.mean(times[labels == i]), + "std_time": np.std(times[labels == i]), + } + for i in range(n_clusters) + }, + "pc_contributions": { + f"PC{pc_idx + 1}": { + "separation_score": ( + np.var( + [ + np.mean(points[labels == i, pc_idx]) + for i in range(n_clusters) + ] + ) + / np.mean( + [np.std(points[labels == i, pc_idx]) for i in range(n_clusters)] + ) + if np.mean( + [np.std(points[labels == i, pc_idx]) for i in range(n_clusters)] + ) + > 0 + else float("inf") + ) + } + for pc_idx in range(min(4, points.shape[1])) + }, + } + + def analyze_embeddings_with_pca( embedding_path, annotation_path=None, @@ -504,6 +761,7 @@ def analyze_embeddings_with_pca( ) print(dist_analysis.to_string(index=False)) + # Return PCA results and additional data needed for clustering return ( pca, pca_result, @@ -513,6 +771,8 @@ def analyze_embeddings_with_pca( pc_analysis, cluster_analysis, dist_analysis, + track_ids, + time_points, ) @@ -520,7 +780,7 @@ def analyze_embeddings_with_pca( if __name__ == "__main__": embedding_path = "/hpc/projects/intracellular_dashboard/organelle_dynamics/2024_11_07_A549_SEC61_ZIKV_DENV/3-phenotyping/predictions/timeAware_2chan__ntxent_192patch_70ckpt_rev7_GT.zarr" annotation_path = "/home/eduardo.hirata/repos/viscy/applications/pseudotime_analysis/phenotype_observations.csv" - # %% + # Using phenotype annotations with specific FOVs print("\nAnalyzing phenotype 1 in specific FOVs:") ( @@ -532,13 +792,39 @@ def analyze_embeddings_with_pca( pc_analysis, cluster_analysis, dist_analysis, + track_ids, + time_points, ) = 
analyze_embeddings_with_pca( embedding_path, annotation_path=annotation_path, phenotype_of_interest=1, seed_timepoint=55, time_window=10, - fov_patterns=["/C/2/", "/B/3/", "/B/2/"], # Specify FOV patterns + fov_patterns=["/C/2/", "/B/3/", "/B/2/"], + ) + + # Run GMM clustering analysis separately + print("\nPerforming GMM clustering analysis...") + gmm_results = analyze_gmm_clustering( + pca_result, + track_ids, + time_points, + tracks, + seed_timepoint=55, + time_window=10, + ) + + # Analyze cluster characteristics + print("\nAnalyzing cluster characteristics...") + cluster_characteristics = analyze_cluster_characteristics( + gmm_results, + pca_result, + track_ids, + time_points, + tracks, + pc_analysis=pc_analysis, + seed_timepoint=55, + time_window=10, ) # Using random tracks from specific FOVs @@ -552,13 +838,39 @@ def analyze_embeddings_with_pca( pc_analysis, cluster_analysis, dist_analysis, + track_ids, + time_points, ) = analyze_embeddings_with_pca( embedding_path, - annotation_path=None, # This triggers random track selection + annotation_path=None, n_random_tracks=10, seed_timepoint=55, time_window=30, - fov_patterns=["/C/2/", "/B/3/", "/B/2/"], # Specify FOV patterns + fov_patterns=["/C/2/", "/B/3/", "/B/2/"], + ) + # %% + # Run GMM clustering analysis for random tracks + print("\nPerforming GMM clustering analysis for random tracks...") + gmm_results = analyze_gmm_clustering( + pca_result, + track_ids, + time_points, + tracks, + seed_timepoint=55, + time_window=30, + ) + + # Analyze cluster characteristics for random tracks + print("\nAnalyzing cluster characteristics for random tracks...") + cluster_characteristics = analyze_cluster_characteristics( + gmm_results, + pca_result, + track_ids, + time_points, + tracks, + pc_analysis=pc_analysis, + seed_timepoint=55, + time_window=30, ) # %% From bf0bf1ecbe8433b2a758bb39f4abf3f781330dfb Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Thu, 23 Jan 2025 08:41:44 -0800 Subject: [PATCH 25/38] fix the issue with distance.py for plotting the pairwise matrix --- viscy/representation/evaluation/distance.py | 21 +++++++++------------ 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/viscy/representation/evaluation/distance.py b/viscy/representation/evaluation/distance.py index cefd8b6e6..85b79a701 100644 --- a/viscy/representation/evaluation/distance.py +++ b/viscy/representation/evaluation/distance.py @@ -305,7 +305,6 @@ def compute_embedding_distances( prediction_path: Path, output_path: Path, distance_metric: Literal["cosine", "euclidean", "normalized_euclidean"] = "cosine", - verbose: bool = False, ) -> pd.DataFrame: """ Compute and save pairwise distances between embeddings. 
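# compute_embedding_distances, continued below: Euclidean distances are scaled
# by 1/sqrt(n_features) so magnitudes stay comparable across embedding sizes,
# the full pairwise matrix is saved as "<prediction stem>_distance_matrix.png"
# at dpi=600, and nearest-neighbour rank fractions feed the piece-wise
# dissimilarity summary that is written to the CSV given by output_path.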
@@ -318,8 +317,6 @@ def compute_embedding_distances( name of saved CSV file distance_metric : str, optional Distance metric to use for computing distances between embeddings - verbose : bool, optional - If True, plots the distance matrix visualization Returns ------- @@ -340,15 +337,15 @@ def compute_embedding_distances( if distance_metric == "euclidean": cross_dist /= np.sqrt(features.shape[1]) - if verbose: - # Plot the distance matrix - plt.figure(figsize=(10, 10)) - plt.imshow(cross_dist, cmap="viridis") - plt.colorbar(label=f"{distance_metric.capitalize()} Distance") - plt.title(f"{distance_metric.capitalize()} Distance Matrix") - plt.tight_layout() - plt.show() - + # Plot the distance matrix + plt.figure(figsize=(10, 10)) + plt.imshow(cross_dist, cmap="viridis") + plt.colorbar(label=f"{distance_metric.capitalize()} Distance") + plt.title(f"{distance_metric.capitalize()} Distance Matrix") + plt.tight_layout() + base_name = prediction_path.stem + plt.savefig(output_path / f"{base_name}_distance_matrix.png", dpi=600) + plt.close() rank_fractions = rank_nearest_neighbors(cross_dist, normalize=True) # Compute piece-wise dissimilarity and rank difference From bec0de2688d2c651b4559bb8b151f3a2dfd7dd48 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Wed, 29 Jan 2025 14:54:45 -0800 Subject: [PATCH 26/38] prototype wrapping the timearrow model and dataloader that takes ome-zarr to spit to tarrowdataset --- viscy/data/tarrow.py | 176 ++++++++++++++++++++++++++++++ viscy/representation/timearrow.py | 133 ++++++++++++++++++++++ 2 files changed, 309 insertions(+) create mode 100644 viscy/data/tarrow.py create mode 100644 viscy/representation/timearrow.py diff --git a/viscy/data/tarrow.py b/viscy/data/tarrow.py new file mode 100644 index 000000000..c6e29f9fd --- /dev/null +++ b/viscy/data/tarrow.py @@ -0,0 +1,176 @@ +from pathlib import Path +import numpy as np +from iohub.ngff import Position, open_ome_zarr +from lightning.pytorch import LightningDataModule +from tarrow.data.tarrow_dataset import TarrowDataset +from torch.utils.data import DataLoader, ConcatDataset +import torch + + +class TarrowDataModule(LightningDataModule): + def __init__( + self, + ome_zarr_path: str | Path, + channel_name: str, + train_split: float = 0.8, + batch_size: int = 16, + num_workers: int = 8, + prefetch_factor: int | None = None, + include_fov_names: list[str] = [], + train_samples_per_epoch: int = 100000, + val_samples_per_epoch: int = 10000, + resolution: int = 0, + z_slice: int = 0, + **kwargs, + ): + """Initialize TarrowDataModule. + + Args: + ome_zarr_path: Path to OME-Zarr file + channel_name: Name of the channel to load + train_split: Fraction of data to use for training (0.0 to 1.0) + batch_size: Batch size for dataloaders + num_workers: Number of workers for dataloaders + prefetch_factor: Prefetch factor for dataloaders + include_fov_names: List of FOV names to include. If empty, use all FOVs. 
+ train_samples_per_epoch: Number of training samples per epoch + val_samples_per_epoch: Number of validation samples per epoch + resolution: Resolution level to load from OME-Zarr + z_slice: Z-slice to load + **kwargs: Additional arguments passed to TarrowDataset + """ + super().__init__() + self.ome_zarr_path = ome_zarr_path + self.channel_name = channel_name + self.train_split = train_split + self.batch_size = batch_size + self.num_workers = num_workers + self.prefetch_factor = prefetch_factor + self.include_fov_names = include_fov_names + self.train_samples_per_epoch = train_samples_per_epoch + self.val_samples_per_epoch = val_samples_per_epoch + self.resolution = resolution + self.z_slice = z_slice + self.kwargs = kwargs + + def _get_channel_index(self, plate) -> int: + """Get the index of the specified channel from the plate metadata. + + Args: + plate: OME-Zarr plate object + + Returns: + Index of the specified channel + + Raises: + ValueError: If channel_name is not found in available channels + """ + # Get channel names from first position + _, first_pos = next(plate.positions()) + try: + return first_pos.channel_names.index(self.channel_name) + except ValueError: + available_channels = ", ".join(first_pos.channel_names) + raise ValueError( + f"Channel '{self.channel_name}' not found. Available channels: {available_channels}" + ) + + def _load_images( + self, positions: list[Position], channel_idx: int + ) -> list[np.ndarray]: + """Load all images from positions into memory. + + Args: + positions: List of positions to load + channel_idx: Index of channel to load + + Returns: + List of 2D numpy arrays + """ + imgs = [] + for pos in positions: + img_arr = pos[str(self.resolution)] + # Load all timepoints for this position + for t in range(len(img_arr)): + imgs.append(img_arr[t, channel_idx, self.z_slice]) + return imgs + + def setup(self, stage: str): + plate = open_ome_zarr(self.ome_zarr_path, mode="r") + + # Get channel index once + channel_idx = self._get_channel_index(plate) + + # Get the positions to load + if self.include_fov_names: + positions = [] + for fov_str, pos in plate.positions(): + normalized_include_fovs = [ + f.lstrip("/") for f in self.include_fov_names + ] + if fov_str in normalized_include_fovs: + positions.append(pos) + else: + positions = [pos for _, pos in plate.positions()] + + # Load all images into memory using the pre-determined channel index + imgs = self._load_images(positions, channel_idx) + + # Calculate split point + split_idx = int(len(imgs) * self.train_split) + + if stage == "fit": + # Create training dataset with first train_split% of images + self.train_dataset = TarrowDataset( + imgs=imgs[:split_idx], + **self.kwargs, + ) + + # Create validation dataset with remaining images + self.val_dataset = TarrowDataset( + imgs=imgs[split_idx:], + **{k: v for k, v in self.kwargs.items() if k != "augmenter"}, + ) + + elif stage == "test": + raise NotImplementedError(f"Invalid stage: {stage}") + elif stage == "predict": + raise NotImplementedError(f"Invalid stage: {stage}") + else: + raise NotImplementedError(f"Invalid stage: {stage}") + + def train_dataloader(self): + return DataLoader( + self.train_dataset, + sampler=torch.utils.data.RandomSampler( + self.train_dataset, + replacement=True, + num_samples=self.train_samples_per_epoch, + ), + batch_size=self.batch_size, + num_workers=self.num_workers, + persistent_workers=True if self.num_workers > 0 else False, + prefetch_factor=self.prefetch_factor if self.num_workers else None, + ) + + def 
val_dataloader(self): + return DataLoader( + self.val_dataset, + sampler=torch.utils.data.RandomSampler( + self.val_dataset, + replacement=True, + num_samples=self.val_samples_per_epoch, + ), + batch_size=self.batch_size, + num_workers=self.num_workers, + persistent_workers=True if self.num_workers > 0 else False, + prefetch_factor=self.prefetch_factor if self.num_workers else None, + ) + + def test_dataloader(self): + return DataLoader( + self.test_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=False, + ) diff --git a/viscy/representation/timearrow.py b/viscy/representation/timearrow.py new file mode 100644 index 000000000..ae619d080 --- /dev/null +++ b/viscy/representation/timearrow.py @@ -0,0 +1,133 @@ +import torch +import torch.nn as nn +from lightning.pytorch import LightningModule +from tarrow.models import TimeArrowNet +from tarrow.models.losses import DecorrelationLoss +from torch.optim import Adam +from torch.optim.lr_scheduler import CyclicLR, ReduceLROnPlateau + + +class TarrowModule(LightningModule): + def __init__( + self, + backbone="unet", + projection_head="minimal_batchnorm", + classification_head="minimal", + n_frames=2, + n_features=16, + n_input_channels=1, + symmetric=False, + learning_rate=1e-4, + weight_decay=1e-6, + lambda_decorrelation=0.01, + lr_scheduler="cyclic", + lr_patience=50, + **kwargs, + ): + """Lightning Module wrapper for TimeArrowNet. + + Args: + backbone: Dense network architecture + projection_head: Dense projection head architecture + classification_head: Classification head architecture + n_frames: Number of input frames + n_features: Number of output features from the backbone + n_input_channels: Number of input channels + symmetric: If True, use permutation-equivariant classification head + learning_rate: Learning rate for optimizer + weight_decay: Weight decay for optimizer + lambda_decorrelation: Prefactor of decorrelation loss + lr_scheduler: Learning rate scheduler ('plateau' or 'cyclic') + lr_patience: Patience for learning rate scheduler + """ + super().__init__() + self.save_hyperparameters() + + self.model = TimeArrowNet( + backbone=backbone, + projection_head=projection_head, + classification_head=classification_head, + n_frames=n_frames, + n_features=n_features, + n_input_channels=n_input_channels, + symmetric=symmetric, + device="cpu", # Let Lightning handle device placement + ) + + self.criterion = nn.CrossEntropyLoss(reduction="none") + self.criterion_decorr = DecorrelationLoss() + + def forward(self, x): + return self.model(x, mode="both") + + def _shared_step(self, batch, batch_idx, phase="train"): + x, y = batch + out, pro = self(x) + + if out.ndim > 2: + y = torch.broadcast_to( + y.unsqueeze(1).unsqueeze(1), (y.shape[0],) + out.shape[-2:] + ) + loss = self.criterion(out, y) + loss = torch.mean(loss, tuple(range(1, loss.ndim))) + y = y[:, 0, 0] + u_avg = torch.mean(out, tuple(range(2, out.ndim))) + else: + u_avg = out + loss = self.criterion(out, y) + + pred = torch.argmax(u_avg.detach(), 1) + loss = torch.mean(loss) + + # decorrelation loss + pro_batched = pro.flatten(0, 1) + loss_decorr = self.criterion_decorr(pro_batched) + loss_all = loss + self.hparams.lambda_decorrelation * loss_decorr + + acc = torch.mean((pred == y).float()) + + self.log(f"{phase}_loss", loss, prog_bar=True) + self.log(f"{phase}_loss_decorr", loss_decorr, prog_bar=True) + self.log(f"{phase}_accuracy", acc, prog_bar=True) + self.log(f"{phase}_pred1_ratio", pred.sum().float() / len(pred)) + + return loss_all + + def 
training_step(self, batch, batch_idx): + return self._shared_step(batch, batch_idx, "train") + + def validation_step(self, batch, batch_idx): + return self._shared_step(batch, batch_idx, "val") + + def configure_optimizers(self): + optimizer = Adam( + self.parameters(), + lr=self.hparams.learning_rate, + weight_decay=self.hparams.weight_decay, + ) + + if self.hparams.lr_scheduler == "plateau": + scheduler = ReduceLROnPlateau( + optimizer, + factor=0.2, + patience=self.hparams.lr_patience, + verbose=True, + ) + return { + "optimizer": optimizer, + "lr_scheduler": { + "scheduler": scheduler, + "monitor": "val_loss", + }, + } + elif self.hparams.lr_scheduler == "cyclic": + scheduler = CyclicLR( + optimizer, + base_lr=self.hparams.learning_rate, + max_lr=self.hparams.learning_rate * 10, + cycle_momentum=False, + step_size_up=self.trainer.estimated_stepping_batches, + scale_mode="cycle", + scale_fn=lambda x: 0.9**x, + ) + return {"optimizer": optimizer, "lr_scheduler": scheduler} From fda9030853a8c4517b767b03c8d493ca5f7ea40d Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Wed, 29 Jan 2025 15:22:55 -0800 Subject: [PATCH 27/38] fixing tensorboard logging for metrics and losses. Adding some coments to clarify metrics. fixing docstrings to be numpy style --- viscy/data/tarrow.py | 113 +++++++++++++++++++++++------- viscy/representation/timearrow.py | 95 +++++++++++++++++++------ 2 files changed, 161 insertions(+), 47 deletions(-) diff --git a/viscy/data/tarrow.py b/viscy/data/tarrow.py index c6e29f9fd..fa7e0a711 100644 --- a/viscy/data/tarrow.py +++ b/viscy/data/tarrow.py @@ -8,6 +8,36 @@ class TarrowDataModule(LightningDataModule): + """Lightning DataModule for TimeArrowNet training. + + Parameters + ---------- + ome_zarr_path : str or Path + Path to OME-Zarr file + channel_name : str + Name of the channel to load + train_split : float, default=0.8 + Fraction of data to use for training (0.0 to 1.0) + batch_size : int, default=16 + Batch size for dataloaders + num_workers : int, default=8 + Number of workers for dataloaders + prefetch_factor : int, optional + Prefetch factor for dataloaders + include_fov_names : list[str], default=[] + List of FOV names to include. If empty, use all FOVs + train_samples_per_epoch : int, default=100000 + Number of training samples per epoch + val_samples_per_epoch : int, default=10000 + Number of validation samples per epoch + resolution : int, default=0 + Resolution level to load from OME-Zarr + z_slice : int, default=0 + Z-slice to load + **kwargs : dict + Additional arguments passed to TarrowDataset + """ + def __init__( self, ome_zarr_path: str | Path, @@ -23,22 +53,6 @@ def __init__( z_slice: int = 0, **kwargs, ): - """Initialize TarrowDataModule. - - Args: - ome_zarr_path: Path to OME-Zarr file - channel_name: Name of the channel to load - train_split: Fraction of data to use for training (0.0 to 1.0) - batch_size: Batch size for dataloaders - num_workers: Number of workers for dataloaders - prefetch_factor: Prefetch factor for dataloaders - include_fov_names: List of FOV names to include. If empty, use all FOVs. 
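# Editor's note: a minimal, hypothetical end-to-end usage sketch of the two classes
# introduced in this patch series. The zarr path, channel name, and hyperparameters
# are placeholders, not values taken from the PR.
from lightning.pytorch import Trainer

from viscy.data.tarrow import TarrowDataModule
from viscy.representation.timearrow import TarrowModule

dm = TarrowDataModule(
    ome_zarr_path="/path/to/plate.zarr",   # placeholder path
    channel_name="Phase3D",                # placeholder channel name
    train_split=0.8,
    batch_size=16,
    num_workers=8,
    train_samples_per_epoch=100_000,
    val_samples_per_epoch=10_000,
)
model = TarrowModule(backbone="unet", n_frames=2, lr_scheduler="cyclic")
Trainer(max_epochs=50).fit(model, datamodule=dm)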
- train_samples_per_epoch: Number of training samples per epoch - val_samples_per_epoch: Number of validation samples per epoch - resolution: Resolution level to load from OME-Zarr - z_slice: Z-slice to load - **kwargs: Additional arguments passed to TarrowDataset - """ super().__init__() self.ome_zarr_path = ome_zarr_path self.channel_name = channel_name @@ -56,14 +70,20 @@ def __init__( def _get_channel_index(self, plate) -> int: """Get the index of the specified channel from the plate metadata. - Args: - plate: OME-Zarr plate object + Parameters + ---------- + plate : iohub.ngff.Plate + OME-Zarr plate object - Returns: + Returns + ------- + int Index of the specified channel - Raises: - ValueError: If channel_name is not found in available channels + Raises + ------ + ValueError + If channel_name is not found in available channels """ # Get channel names from first position _, first_pos = next(plate.positions()) @@ -80,11 +100,16 @@ def _load_images( ) -> list[np.ndarray]: """Load all images from positions into memory. - Args: - positions: List of positions to load - channel_idx: Index of channel to load + Parameters + ---------- + positions : list[Position] + List of positions to load + channel_idx : int + Index of channel to load - Returns: + Returns + ------- + list[np.ndarray] List of 2D numpy arrays """ imgs = [] @@ -96,6 +121,18 @@ def _load_images( return imgs def setup(self, stage: str): + """Set up the data module for a specific stage. + + Parameters + ---------- + stage : str + Stage to set up for ("fit", "test", or "predict") + + Raises + ------ + NotImplementedError + If stage is not "fit" + """ plate = open_ome_zarr(self.ome_zarr_path, mode="r") # Get channel index once @@ -140,6 +177,13 @@ def setup(self, stage: str): raise NotImplementedError(f"Invalid stage: {stage}") def train_dataloader(self): + """Create the training dataloader. + + Returns + ------- + torch.utils.data.DataLoader + DataLoader for training data with random sampling + """ return DataLoader( self.train_dataset, sampler=torch.utils.data.RandomSampler( @@ -154,6 +198,13 @@ def train_dataloader(self): ) def val_dataloader(self): + """Create the validation dataloader. + + Returns + ------- + torch.utils.data.DataLoader + DataLoader for validation data with random sampling + """ return DataLoader( self.val_dataset, sampler=torch.utils.data.RandomSampler( @@ -168,6 +219,18 @@ def val_dataloader(self): ) def test_dataloader(self): + """Create the test dataloader. + + Returns + ------- + torch.utils.data.DataLoader + DataLoader for test data without shuffling + + Raises + ------ + NotImplementedError + Test stage is not implemented yet + """ return DataLoader( self.test_dataset, batch_size=self.batch_size, diff --git a/viscy/representation/timearrow.py b/viscy/representation/timearrow.py index ae619d080..4fb71440e 100644 --- a/viscy/representation/timearrow.py +++ b/viscy/representation/timearrow.py @@ -8,6 +8,38 @@ class TarrowModule(LightningModule): + """Lightning Module wrapper for TimeArrowNet. 
+ + Parameters + ---------- + backbone : str, default="unet" + Dense network architecture + projection_head : str, default="minimal_batchnorm" + Dense projection head architecture + classification_head : str, default="minimal" + Classification head architecture + n_frames : int, default=2 + Number of input frames + n_features : int, default=16 + Number of output features from the backbone + n_input_channels : int, default=1 + Number of input channels + symmetric : bool, default=False + If True, use permutation-equivariant classification head + learning_rate : float, default=1e-4 + Learning rate for optimizer + weight_decay : float, default=1e-6 + Weight decay for optimizer + lambda_decorrelation : float, default=0.01 + Prefactor of decorrelation loss + lr_scheduler : str, default="cyclic" + Learning rate scheduler ('plateau' or 'cyclic') + lr_patience : int, default=50 + Patience for learning rate scheduler + cam_size : tuple or int, optional + Size of the class activation map (H, W). If None, use input size. + """ + def __init__( self, backbone="unet", @@ -22,24 +54,9 @@ def __init__( lambda_decorrelation=0.01, lr_scheduler="cyclic", lr_patience=50, + cam_size=None, **kwargs, ): - """Lightning Module wrapper for TimeArrowNet. - - Args: - backbone: Dense network architecture - projection_head: Dense projection head architecture - classification_head: Classification head architecture - n_frames: Number of input frames - n_features: Number of output features from the backbone - n_input_channels: Number of input channels - symmetric: If True, use permutation-equivariant classification head - learning_rate: Learning rate for optimizer - weight_decay: Weight decay for optimizer - lambda_decorrelation: Prefactor of decorrelation loss - lr_scheduler: Learning rate scheduler ('plateau' or 'cyclic') - lr_patience: Patience for learning rate scheduler - """ super().__init__() self.save_hyperparameters() @@ -51,16 +68,46 @@ def __init__( n_features=n_features, n_input_channels=n_input_channels, symmetric=symmetric, - device="cpu", # Let Lightning handle device placement + cam_size=cam_size, ) self.criterion = nn.CrossEntropyLoss(reduction="none") self.criterion_decorr = DecorrelationLoss() def forward(self, x): + """Forward pass through the model. + + Parameters + ---------- + x : torch.Tensor + Input tensor of shape (batch_size, n_frames, channels, height, width) + + Returns + ------- + tuple + Tuple of (output, projection) where: + - output is the classification logits + - projection is the feature space projection + """ return self.model(x, mode="both") - def _shared_step(self, batch, batch_idx, phase="train"): + def _shared_step(self, batch, batch_idx, step="train"): + """Shared step for training and validation. 
+ + Parameters + ---------- + batch : tuple + Tuple of (images, labels) + batch_idx : int + Index of the current batch + step : str, default="train" + Current step type ("train" or "val") + + Returns + ------- + torch.Tensor + Combined loss (classification + decorrelation) + """ x, y = batch out, pro = self(x) @@ -86,10 +133,14 @@ def _shared_step(self, batch, batch_idx, phase="train"): acc = torch.mean((pred == y).float()) - self.log(f"{phase}_loss", loss, prog_bar=True) - self.log(f"{phase}_loss_decorr", loss_decorr, prog_bar=True) - self.log(f"{phase}_accuracy", acc, prog_bar=True) - self.log(f"{phase}_pred1_ratio", pred.sum().float() / len(pred)) + # Main classification loss + self.log(f"loss/{step}_loss", loss, prog_bar=True) + # Decorrelation loss for feature space + self.log(f"loss/{step}_loss_decorr", loss_decorr, prog_bar=True) + # Classification accuracy + self.log(f"metric/{step}_accuracy", acc, prog_bar=True) + # Ratio of positive predictions (class 1) - useful to detect class imbalance + self.log(f"metric/{step}_pred1_ratio", pred.sum().float() / len(pred)) return loss_all From 3f6bc831deb68c8341141ea4045a90891883b977 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Wed, 29 Jan 2025 16:20:40 -0800 Subject: [PATCH 28/38] adding tarrow's visualization for tensorboard --- viscy/data/tarrow.py | 6 ++ viscy/representation/timearrow.py | 121 +++++++++++++++++++++++++++++- 2 files changed, 126 insertions(+), 1 deletion(-) diff --git a/viscy/data/tarrow.py b/viscy/data/tarrow.py index fa7e0a711..f554544af 100644 --- a/viscy/data/tarrow.py +++ b/viscy/data/tarrow.py @@ -34,6 +34,10 @@ class TarrowDataModule(LightningDataModule): Resolution level to load from OME-Zarr z_slice : int, default=0 Z-slice to load + pin_memory : bool, default=True + Whether to pin memory + persistent_workers : bool, default=True + Whether to keep the workers alive between epochs **kwargs : dict Additional arguments passed to TarrowDataset """ @@ -51,6 +55,8 @@ def __init__( val_samples_per_epoch: int = 10000, resolution: int = 0, z_slice: int = 0, + pin_memory: bool = True, + persistent_workers: bool = True, **kwargs, ): super().__init__() diff --git a/viscy/representation/timearrow.py b/viscy/representation/timearrow.py index 4fb71440e..05c056ae1 100644 --- a/viscy/representation/timearrow.py +++ b/viscy/representation/timearrow.py @@ -5,6 +5,9 @@ from tarrow.models.losses import DecorrelationLoss from torch.optim import Adam from torch.optim.lr_scheduler import CyclicLR, ReduceLROnPlateau +from lightning.pytorch.callbacks import Callback +from torch.utils.data import DataLoader +import torchvision class TarrowModule(LightningModule): @@ -68,7 +71,6 @@ def __init__( n_features=n_features, n_input_channels=n_input_channels, symmetric=symmetric, - cam_size=cam_size, ) self.criterion = nn.CrossEntropyLoss(reduction="none") @@ -182,3 +184,120 @@ def configure_optimizers(self): scale_fn=lambda x: 0.9**x, ) return {"optimizer": optimizer, "lr_scheduler": scheduler} + + def embedding(self, x): + """Get dense embeddings from the model. + + Parameters + ---------- + x : torch.Tensor + Input tensor of shape (batch_size, n_frames, channels, height, width) + + Returns + ------- + torch.Tensor + Dense embeddings from the backbone network + """ + return self.model.embedding(x) + + +class TarrowVisualizationCallback(Callback): + """Callback for visualizing cells and embeddings in TensorBoard. 
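# Editor's note (illustrative only): prefixing scalar tags with "loss/" and "metric/"
# as in the self.log calls above makes TensorBoard group the curves into separate
# sections. The SummaryWriter, log directory, and values below are placeholders.
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter("runs/tarrow_demo")                       # placeholder log dir
for step, (ce, acc) in enumerate([(0.69, 0.51), (0.52, 0.68), (0.41, 0.77)]):
    writer.add_scalar("loss/train_loss", ce, step)               # grouped under "loss"
    writer.add_scalar("metric/train_accuracy", acc, step)        # grouped under "metric"
writer.close()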
+ + Parameters + ---------- + dataset : Dataset + Dataset to visualize + max_samples : int, default=100 + Maximum number of samples to visualize + log_every_n_epochs : int, default=3 + How often to log visualizations + cam_size : tuple or int, optional + Size for class activation maps. If None, use original size + """ + + def __init__(self, dataset, max_samples=100, log_every_n_epochs=3, cam_size=None): + """ + Parameters + ---------- + dataset : Dataset + Dataset to visualize + max_samples : int, default=100 + Maximum number of samples to visualize + log_every_n_epochs : int, default=3 + How often to log visualizations + cam_size : tuple or int, optional + Size for class activation maps. If None, use original size + """ + super().__init__() + self.dataset = dataset + self.max_samples = max_samples + self.log_every_n_epochs = log_every_n_epochs + self.cam_size = cam_size + + def on_train_epoch_end(self, trainer, pl_module): + if (trainer.current_epoch + 1) % self.log_every_n_epochs == 0: + # Get samples from dataset + loader = DataLoader( + self.dataset, + batch_size=min(32, self.max_samples), + shuffle=True, + ) + batch = next(iter(loader)) + images, labels = batch + images = images.to(pl_module.device) + + # Get embeddings + with torch.no_grad(): + embeddings = pl_module.embedding(images) + out, _ = pl_module(images) + preds = torch.argmax(out, dim=1) + + # Log images + grid = torchvision.utils.make_grid( + images[:, 0], # First timepoint + nrow=8, + normalize=True, + value_range=(images.min(), images.max()), + ) + trainer.logger.experiment.add_image( + "cells/timepoint1", + grid, + trainer.current_epoch, + ) + + grid = torchvision.utils.make_grid( + images[:, 1], # Second timepoint + nrow=8, + normalize=True, + value_range=(images.min(), images.max()), + ) + trainer.logger.experiment.add_image( + "cells/timepoint2", + grid, + trainer.current_epoch, + ) + + # Log embeddings + trainer.logger.experiment.add_embedding( + embeddings.reshape(len(embeddings), -1), + metadata=[ + f"label={l.item()}, pred={p.item()}" for l, p in zip(labels, preds) + ], + label_img=images[:, 0], # Use first timepoint as label image + global_step=trainer.current_epoch, + ) + + # Log CAMs if cam_size is provided + if self.cam_size is not None and hasattr(pl_module.model, "get_cam"): + cam = pl_module.model.get_cam(images, size=self.cam_size) + grid = torchvision.utils.make_grid( + cam.unsqueeze(1), # Add channel dimension + nrow=8, + normalize=True, + ) + trainer.logger.experiment.add_image( + "cells/cam", + grid, + trainer.current_epoch, + ) From 91ad80026223ded210f3b37b186af77908d90d47 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Wed, 29 Jan 2025 16:21:17 -0800 Subject: [PATCH 29/38] ruff --- viscy/data/tarrow.py | 5 +++-- viscy/representation/timearrow.py | 4 ++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/viscy/data/tarrow.py b/viscy/data/tarrow.py index f554544af..df443581b 100644 --- a/viscy/data/tarrow.py +++ b/viscy/data/tarrow.py @@ -1,10 +1,11 @@ from pathlib import Path + import numpy as np +import torch from iohub.ngff import Position, open_ome_zarr from lightning.pytorch import LightningDataModule from tarrow.data.tarrow_dataset import TarrowDataset -from torch.utils.data import DataLoader, ConcatDataset -import torch +from torch.utils.data import DataLoader class TarrowDataModule(LightningDataModule): diff --git a/viscy/representation/timearrow.py b/viscy/representation/timearrow.py index 05c056ae1..6333a3075 100644 --- a/viscy/representation/timearrow.py +++ 
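# Editor's note: hypothetical wiring of TarrowVisualizationCallback (defined above)
# into a Lightning Trainer. The stand-in dataset of (frame pair, label) samples and
# all sizes are assumptions; in practice the datamodule's validation dataset would
# typically be passed after setup("fit").
import torch
from lightning.pytorch import Trainer
from torch.utils.data import TensorDataset

from viscy.representation.timearrow import TarrowVisualizationCallback

pairs = torch.randn(64, 2, 1, 96, 96)      # (N, n_frames, C, H, W) stand-in frame pairs
labels = torch.randint(0, 2, (64,))        # temporal-order labels
viz = TarrowVisualizationCallback(
    TensorDataset(pairs, labels), max_samples=32, log_every_n_epochs=5
)
trainer = Trainer(max_epochs=10, callbacks=[viz])
# trainer.fit(model, datamodule=dm)        # model/datamodule as in the earlier sketch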
b/viscy/representation/timearrow.py @@ -1,13 +1,13 @@ import torch import torch.nn as nn +import torchvision from lightning.pytorch import LightningModule +from lightning.pytorch.callbacks import Callback from tarrow.models import TimeArrowNet from tarrow.models.losses import DecorrelationLoss from torch.optim import Adam from torch.optim.lr_scheduler import CyclicLR, ReduceLROnPlateau -from lightning.pytorch.callbacks import Callback from torch.utils.data import DataLoader -import torchvision class TarrowModule(LightningModule): From c54dd4a8648776fdceb8c8ebf43403fa158ecaa9 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Wed, 29 Jan 2025 16:23:38 -0800 Subject: [PATCH 30/38] remove unwanted files from pr --- .../pseudotime_analysis/pca_analysis.py | 876 ------------------ viscy/representation/evaluation/distance.py | 457 --------- 2 files changed, 1333 deletions(-) delete mode 100644 applications/pseudotime_analysis/pca_analysis.py delete mode 100644 viscy/representation/evaluation/distance.py diff --git a/applications/pseudotime_analysis/pca_analysis.py b/applications/pseudotime_analysis/pca_analysis.py deleted file mode 100644 index d79730559..000000000 --- a/applications/pseudotime_analysis/pca_analysis.py +++ /dev/null @@ -1,876 +0,0 @@ -# %% -import numpy as np -import pandas as pd -from sklearn.preprocessing import StandardScaler -from sklearn.decomposition import PCA -from sklearn.mixture import GaussianMixture -from sklearn.metrics import silhouette_score -import matplotlib.pyplot as plt -import seaborn as sns -from viscy.representation.embedding_writer import read_embedding_dataset -from scipy.spatial.distance import pdist, squareform - -# Set global random seed for reproducibility -RANDOM_SEED = 42 -np.random.seed(RANDOM_SEED) - - -def analyze_pc_loadings(pca, feature_names=None, top_n=5): - """Analyze which features contribute most to each PC.""" - if feature_names is None: - feature_names = [f"Feature_{i}" for i in range(pca.components_[0].shape[0])] - - pc_loadings = [] - for i, pc in enumerate(pca.components_): - # Get the absolute loadings - abs_loadings = np.abs(pc) - # Get indices of top contributing features - top_indices = np.argsort(abs_loadings)[-top_n:][::-1] - - # Store the results - pc_dict = { - "PC": i + 1, - "Variance_Explained": pca.explained_variance_ratio_[i], - "Top_Features": [feature_names[idx] for idx in top_indices], - "Top_Loadings": [pc[idx] for idx in top_indices], - } - pc_loadings.append(pc_dict) - - return pd.DataFrame(pc_loadings) - - -def analyze_track_clustering( - pca_result, - track_ids, - time_points, - labels, - phenotype_of_interest, - seed_timepoint, - time_window, -): - """Analyze how tracks cluster in PC space within the time window.""" - # Get points within time window - time_mask = (time_points >= seed_timepoint - time_window) & ( - time_points <= seed_timepoint + time_window - ) - window_points = pca_result[time_mask] - window_tracks = track_ids[time_mask] - window_labels = labels[time_mask] - - # Calculate mean position for each track - track_means = {} - phenotype_tracks = [] - - for track_id in np.unique(window_tracks): - track_mask = (window_tracks == track_id) & ( - window_labels == phenotype_of_interest - ) - if np.any(track_mask): - track_means[track_id] = np.mean(window_points[track_mask], axis=0) - phenotype_tracks.append(track_id) - - if len(phenotype_tracks) < 2: - return None - - # Calculate pairwise distances between track means - track_positions = np.array([track_means[tid] for tid in phenotype_tracks]) - 
distances = pdist(track_positions) - mean_distance = np.mean(distances) - std_distance = np.std(distances) - - # Calculate spread within each track - track_spreads = {} - for track_id in phenotype_tracks: - track_mask = (window_tracks == track_id) & ( - window_labels == phenotype_of_interest - ) - if np.sum(track_mask) > 1: - track_points = window_points[track_mask] - spread = np.mean(pdist(track_points)) - track_spreads[track_id] = spread - - mean_spread = np.mean(list(track_spreads.values())) if track_spreads else 0 - - return { - "n_tracks": len(phenotype_tracks), - "mean_inter_track_distance": mean_distance, - "std_inter_track_distance": std_distance, - "mean_intra_track_spread": mean_spread, - "clustering_ratio": mean_distance / mean_spread if mean_spread > 0 else np.inf, - } - - -def analyze_pc_distributions( - pca_result, - labels, - phenotype_of_interest, - time_points=None, - seed_timepoint=None, - time_window=None, -): - """Analyze the distributions of each PC for phenotype vs background.""" - n_components = pca_result.shape[1] - results = [] - - for i in range(n_components): - # Get phenotype and background points - if ( - time_points is not None - and seed_timepoint is not None - and time_window is not None - ): - time_mask = (time_points >= seed_timepoint - time_window) & ( - time_points <= seed_timepoint + time_window - ) - pc_values_phenotype = pca_result[ - time_mask & (labels == phenotype_of_interest), i - ] - pc_values_background = pca_result[ - time_mask & (labels != phenotype_of_interest), i - ] - else: - pc_values_phenotype = pca_result[labels == phenotype_of_interest, i] - pc_values_background = pca_result[labels != phenotype_of_interest, i] - - # Calculate basic statistics - stats = { - "PC": i + 1, - "phenotype_mean": np.mean(pc_values_phenotype), - "background_mean": np.mean(pc_values_background), - "phenotype_std": np.std(pc_values_phenotype), - "background_std": np.std(pc_values_background), - "separation": abs( - np.mean(pc_values_phenotype) - np.mean(pc_values_background) - ) - / (np.std(pc_values_phenotype) + np.std(pc_values_background)), - } - - # Check for multimodality using a simple peak detection - hist, bins = np.histogram(pc_values_phenotype, bins="auto") - peaks = len( - [ - i - for i in range(1, len(hist) - 1) - if hist[i] > hist[i - 1] and hist[i] > hist[i + 1] - ] - ) - stats["n_peaks"] = peaks - - results.append(stats) - - return pd.DataFrame(results) - - -def analyze_gmm_clustering( - pca_result, - track_ids, - time_points, - tracks_of_interest, - n_components_range=range(2, 7), - seed_timepoint=None, - time_window=None, -): - """Analyze clusters using Gaussian Mixture Models.""" - # Get points from tracks of interest - track_mask = np.isin(track_ids, tracks_of_interest) - points = pca_result[track_mask] - track_ids_subset = track_ids[track_mask] - times = time_points[track_mask] - - # Apply time window if specified - if seed_timepoint is not None and time_window is not None: - time_mask = (times >= seed_timepoint - time_window) & ( - times <= seed_timepoint + time_window - ) - points = points[time_mask] - track_ids_subset = track_ids_subset[time_mask] - times = times[time_mask] - - # Try different numbers of components - bic_scores = [] - silhouette_scores = [] - models = [] - - for n_components in n_components_range: - gmm = GaussianMixture( - n_components=n_components, random_state=RANDOM_SEED, n_init=10 - ) - gmm.fit(points) - labels = gmm.predict(points) - - bic_scores.append(gmm.bic(points)) - 
silhouette_scores.append(silhouette_score(points, labels)) - models.append(gmm) - - # Plot model selection metrics - fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4)) - - # BIC plot - ax1.plot(list(n_components_range), bic_scores, "bo-") - ax1.set_xlabel("Number of Components") - ax1.set_ylabel("BIC Score") - ax1.set_title("Model Selection: BIC") - - # Silhouette plot - ax2.plot(list(n_components_range), silhouette_scores, "ro-") - ax2.set_xlabel("Number of Components") - ax2.set_ylabel("Silhouette Score") - ax2.set_title("Model Selection: Silhouette") - - plt.tight_layout() - plt.show() - - # Select best model based on BIC - best_idx = np.argmin(bic_scores) - best_n_components = n_components_range[best_idx] - best_model = models[best_idx] - - # Get cluster assignments - labels = best_model.predict(points) - probs = best_model.predict_proba(points) - - # Plot clustering results - fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6)) - - # Scatter plot colored by cluster - scatter = ax1.scatter( - points[:, 0], points[:, 1], c=labels, cmap="tab10", alpha=0.6, s=50 - ) - ax1.set_xlabel("PC1") - ax1.set_ylabel("PC2") - ax1.set_title(f"GMM Clustering (n={best_n_components})") - plt.colorbar(scatter, ax=ax1, label="Cluster") - - # Plot cluster assignment probabilities - max_probs = np.max(probs, axis=1) - scatter = ax2.scatter( - points[:, 0], points[:, 1], c=max_probs, cmap="viridis", alpha=0.6, s=50 - ) - ax2.set_xlabel("PC1") - ax2.set_ylabel("PC2") - ax2.set_title("Cluster Assignment Probability") - plt.colorbar(scatter, ax=ax2, label="Probability") - - plt.tight_layout() - plt.show() - - # Analyze cluster composition - cluster_stats = [] - for i in range(best_n_components): - cluster_mask = labels == i - cluster_tracks = np.unique(track_ids_subset[cluster_mask]) - cluster_stats.append( - { - "cluster": i, - "n_points": np.sum(cluster_mask), - "n_tracks": len(cluster_tracks), - "tracks": cluster_tracks, - "mean_prob": np.mean(probs[cluster_mask, i]), - "std_prob": np.std(probs[cluster_mask, i]), - } - ) - - # Print cluster statistics - print(f"\nBest number of clusters (BIC): {best_n_components}") - print("\nCluster Statistics:") - for stats in cluster_stats: - print(f"\nCluster {stats['cluster']}:") - print(f" Points: {stats['n_points']}") - print(f" Tracks: {stats['n_tracks']}") - print(f" Mean probability: {stats['mean_prob']:.3f} ± {stats['std_prob']:.3f}") - print(f" Tracks in cluster: {stats['tracks']}") - - return { - "best_model": best_model, - "best_n_components": best_n_components, - "labels": labels, - "probabilities": probs, - "bic_scores": bic_scores, - "silhouette_scores": silhouette_scores, - "cluster_stats": cluster_stats, - } - - -def analyze_cluster_characteristics( - gmm_results, - pca_result, - track_ids, - time_points, - tracks_of_interest, - pc_analysis=None, - seed_timepoint=None, - time_window=None, -): - """Analyze characteristics of GMM clusters including temporal patterns and PC contributions.""" - # Get points from tracks of interest first - track_mask = np.isin(track_ids, tracks_of_interest) - points = pca_result[track_mask] - track_ids_subset = track_ids[track_mask] - times = time_points[track_mask] - - # Apply time window if specified - if seed_timepoint is not None and time_window is not None: - time_mask = (times >= seed_timepoint - time_window) & ( - times <= seed_timepoint + time_window - ) - points = points[time_mask] - track_ids_subset = track_ids_subset[time_mask] - times = times[time_mask] - - # Get cluster assignments for the filtered points - 
labels = gmm_results["labels"] - probs = gmm_results["probabilities"] - n_clusters = gmm_results["best_n_components"] - - # Analyze temporal patterns in each cluster - print("\nTemporal patterns in clusters:") - for i in range(n_clusters): - cluster_mask = labels == i - cluster_times = times[cluster_mask] - if len(cluster_times) > 0: - print(f"\nCluster {i}:") - print( - f" Time range: {np.min(cluster_times):.1f} to {np.max(cluster_times):.1f}" - ) - print( - f" Mean time: {np.mean(cluster_times):.1f} ± {np.std(cluster_times):.1f}" - ) - - # Analyze PC contributions to cluster separation - print("\nPC contributions to cluster separation:") - for pc_idx in range(min(4, points.shape[1])): # Analyze first 4 PCs - pc_values = points[:, pc_idx] - cluster_means = [np.mean(pc_values[labels == i]) for i in range(n_clusters)] - cluster_stds = [np.std(pc_values[labels == i]) for i in range(n_clusters)] - - # Calculate separation score (ratio of between-cluster to within-cluster variance) - between_var = np.var(cluster_means) - within_var = np.mean(cluster_stds) - separation_score = between_var / within_var if within_var > 0 else float("inf") - - print(f"\nPC{pc_idx + 1}:") - print(f" Separation score: {separation_score:.3f}") - if pc_analysis is not None: - pc_info = pc_analysis[pc_analysis["PC"] == pc_idx + 1].iloc[0] - print( - f" Top contributing features: {', '.join(pc_info['Top_Features'][:3])}" - ) - - # Print cluster-specific stats - for i in range(n_clusters): - cluster_mask = labels == i - print(f" Cluster {i}: {cluster_means[i]:.3f} ± {cluster_stds[i]:.3f}") - - # Analyze track transitions between clusters - print("\nTrack transitions between clusters:") - for track_id in tracks_of_interest: - track_mask = track_ids_subset == track_id - track_labels = labels[track_mask] - track_times = times[track_mask] - - if len(track_labels) > 1: - # Sort by time - sort_idx = np.argsort(track_times) - track_labels = track_labels[sort_idx] - track_times = track_times[sort_idx] - - # Find transitions - transitions = np.where(track_labels[1:] != track_labels[:-1])[0] - if len(transitions) > 0: - print(f"\nTrack {track_id}:") - for trans_idx in transitions: - from_cluster = track_labels[trans_idx] - to_cluster = track_labels[trans_idx + 1] - trans_time = track_times[trans_idx + 1] - print(f" {trans_time:.1f}: {from_cluster} -> {to_cluster}") - - return { - "temporal_patterns": { - i: { - "mean_time": np.mean(times[labels == i]), - "std_time": np.std(times[labels == i]), - } - for i in range(n_clusters) - }, - "pc_contributions": { - f"PC{pc_idx + 1}": { - "separation_score": ( - np.var( - [ - np.mean(points[labels == i, pc_idx]) - for i in range(n_clusters) - ] - ) - / np.mean( - [np.std(points[labels == i, pc_idx]) for i in range(n_clusters)] - ) - if np.mean( - [np.std(points[labels == i, pc_idx]) for i in range(n_clusters)] - ) - > 0 - else float("inf") - ) - } - for pc_idx in range(min(4, points.shape[1])) - }, - } - - -def analyze_embeddings_with_pca( - embedding_path, - annotation_path=None, - phenotype_of_interest=None, - n_random_tracks=10, - n_components=8, - seed_timepoint=None, - time_window=10, - fov_patterns=None, -): - """Analyze embeddings using PCA, either for specific phenotypes or random tracks. - - Args: - embedding_path: Path to embedding zarr file - annotation_path: Optional path to annotation CSV file. 
If None, uses random tracks - phenotype_of_interest: Which phenotype to analyze (only used if annotation_path is provided) - n_random_tracks: Number of random tracks to select (only used if annotation_path is None) - n_components: Number of PCA components - seed_timepoint: Center of time window. If None, uses all timepoints - time_window: Size of time window (+/-). Only used if seed_timepoint is not None - fov_patterns: List of patterns to filter FOVs (e.g. ['/C/2/*', '/B/3/*']). - Optional even when using annotation_path - can be used to restrict - analysis to specific FOVs while still using phenotype information. - """ - if annotation_path is None: - print(f"\nUsing random tracks (global seed: {RANDOM_SEED})") - - if seed_timepoint is None: - print("\nUsing all timepoints") - else: - print(f"\nUsing time window: {seed_timepoint}±{time_window}") - - # Load embeddings - embedding_dataset = read_embedding_dataset(embedding_path) - features = embedding_dataset["features"] - track_ids = embedding_dataset["track_id"].values - fovs = embedding_dataset["fov_name"].values - time_points = embedding_dataset["t"].values - - # Filter FOVs if patterns are provided - if fov_patterns is not None: - print(f"\nFiltering FOVs with patterns: {fov_patterns}") - fov_mask = np.zeros_like(fovs, dtype=bool) - for pattern in fov_patterns: - fov_mask |= np.char.find(fovs.astype(str), pattern) >= 0 - - # Update all arrays with the FOV mask - features = features[fov_mask] - track_ids = track_ids[fov_mask] - fovs = fovs[fov_mask] - time_points = time_points[fov_mask] - - print(f"Found {len(np.unique(fovs))} FOVs matching patterns") - - # Get tracks of interest - if annotation_path is not None: - # Load annotations and get phenotype tracks - annotations_df = pd.read_csv(annotation_path) - annotation_map = { - (str(row["FOV"]), int(row["Track_id"])): row["Observed phenotype"] - for _, row in annotations_df.iterrows() - } - labels = np.array( - [ - annotation_map.get((str(fov), int(track_id)), -1) - for fov, track_id in zip(fovs, track_ids) - ] - ) - selection_mask = labels == phenotype_of_interest - tracks_of_interest = np.unique(track_ids[selection_mask]) - other_mask = ~selection_mask - mode = f"phenotype {phenotype_of_interest}" - else: - # Select random tracks from different FOVs when possible - # Create a mapping of FOV to tracks - fov_track_map = {} - for fov, track_id in zip(fovs, track_ids): - if fov not in fov_track_map: - fov_track_map[fov] = [] - if track_id not in fov_track_map[fov]: # Avoid duplicates - fov_track_map[fov].append(track_id) - - # Get list of all FOVs - available_fovs = list(fov_track_map.keys()) - tracks_of_interest = [] - - # First, try to get one track from each FOV - np.random.shuffle(available_fovs) # Randomize FOV order - for fov in available_fovs: - if len(tracks_of_interest) < n_random_tracks: - # Randomly select a track from this FOV - track = np.random.choice(fov_track_map[fov]) - tracks_of_interest.append(track) - else: - break - - # If we still need more tracks, randomly select from remaining tracks - if len(tracks_of_interest) < n_random_tracks: - # Get all remaining tracks that aren't already selected - remaining_tracks = [ - track - for track in np.unique(track_ids) - if track not in tracks_of_interest - ] - # Select additional tracks - additional_tracks = np.random.choice( - remaining_tracks, - size=min( - n_random_tracks - len(tracks_of_interest), len(remaining_tracks) - ), - replace=False, - ) - tracks_of_interest.extend(additional_tracks) - - tracks_of_interest = 
np.array(tracks_of_interest) - selection_mask = np.isin(track_ids, tracks_of_interest) - other_mask = ~selection_mask - labels = np.where(selection_mask, 1, 0) - mode = "random tracks" - - # Print selected tracks with their FOVs - print("\nSelected tracks:") - for track in tracks_of_interest: - track_fovs = np.unique(fovs[track_ids == track]) - print(f"Track {track}: FOV {track_fovs[0]}") - - # Scale the features - scaler = StandardScaler() - scaled_features = scaler.fit_transform(features.values) - - # Perform PCA - pca = PCA(n_components=n_components) - pca_result = pca.fit_transform(scaled_features) - - # Calculate explained variance - explained_variance_ratio = pca.explained_variance_ratio_ - cumulative_variance_ratio = np.cumsum(explained_variance_ratio) - - # Create track-specific colors - track_colors = plt.cm.tab10(np.linspace(0, 1, len(tracks_of_interest))) - track_color_map = dict(zip(tracks_of_interest, track_colors)) - - # Create plots - fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5)) - - # Scree plot - ax1.plot(range(1, n_components + 1), explained_variance_ratio, "bo-") - ax1.plot(range(1, n_components + 1), cumulative_variance_ratio, "ro-") - ax1.set_xlabel("Principal Component") - ax1.set_ylabel("Explained Variance Ratio") - ax1.set_title("Scree Plot") - ax1.legend(["Individual", "Cumulative"]) - - # First two components plot - # Plot other tracks/cells in gray - ax2.scatter( - pca_result[other_mask, 0], - pca_result[other_mask, 1], - alpha=0.1, - color="gray", - label="Other cells", - s=10, - ) - - # Plot tracks of interest with decreasing opacity - for track_id in tracks_of_interest: - track_mask = track_ids == track_id - track_points = pca_result[track_mask] - track_times = time_points[track_mask] - - # Sort points by time - sort_idx = np.argsort(track_times) - track_points = track_points[sort_idx] - track_times = track_times[sort_idx] - - # Apply time window if specified - if seed_timepoint is not None: - time_mask = (track_times >= seed_timepoint - time_window) & ( - track_times <= seed_timepoint + time_window - ) - else: - time_mask = np.ones_like(track_times, dtype=bool) # Use all points - - if np.any(time_mask): # Only plot if there are points in the window - window_points = track_points[time_mask] - window_times = track_times[time_mask] - - # Normalize times within window for opacity - norm_times = (window_times - window_times.min()) / ( - window_times.max() - window_times.min() + 1e-10 - ) - alphas = 0.2 + 0.8 * norm_times # Scale to [0.2, 1.0] - - # Plot points with opacity based on normalized time - for idx in range(len(window_points)): - ax2.scatter( - window_points[idx, 0], - window_points[idx, 1], - color=track_color_map[track_id], - alpha=alphas[idx], - s=50, - label=( - f"Track {track_id}" if idx == len(window_points) - 1 else None - ), - ) - - ax2.set_xlabel("First Principal Component") - ax2.set_ylabel("Second Principal Component") - title = f"First Two Principal Components - {mode}" - if seed_timepoint is not None: - title += f"\nTime window: {seed_timepoint}±{time_window}" - ax2.set_title(title) - ax2.legend(bbox_to_anchor=(1.05, 1), loc="upper left") - - plt.tight_layout() - plt.show() - - # Pairwise component plots - fig, axes = plt.subplots(n_components, n_components, figsize=(20, 20)) - - for i in range(n_components): - for j in range(n_components): - if i != j: - # Plot other points first - axes[i, j].scatter( - pca_result[other_mask, j], - pca_result[other_mask, i], - alpha=0.1, - color="gray", - s=5, - ) - - # Plot each track with 
decreasing opacity - for track_id in tracks_of_interest: - track_mask = track_ids == track_id - track_points_j = pca_result[track_mask, j] - track_points_i = pca_result[track_mask, i] - track_times = time_points[track_mask] - - # Sort points by time - sort_idx = np.argsort(track_times) - track_points_j = track_points_j[sort_idx] - track_points_i = track_points_i[sort_idx] - track_times = track_times[sort_idx] - - # Select points within the time window - time_mask = (track_times >= seed_timepoint - time_window) & ( - track_times <= seed_timepoint + time_window - ) - if np.any(time_mask): # Only plot if there are points in the window - window_points_j = track_points_j[time_mask] - window_points_i = track_points_i[time_mask] - window_times = track_times[time_mask] - - # Normalize times within window for opacity - norm_times = (window_times - window_times.min()) / ( - window_times.max() - window_times.min() + 1e-10 - ) - alphas = 0.2 + 0.8 * norm_times # Scale to [0.2, 1.0] - - # Plot points with opacity based on normalized time - for idx in range(len(window_points_j)): - axes[i, j].scatter( - window_points_j[idx], - window_points_i[idx], - color=track_color_map[track_id], - alpha=alphas[idx], - s=30, - ) - - axes[i, j].set_xlabel(f"PC{j+1}") - axes[i, j].set_ylabel(f"PC{i+1}") - else: - # On diagonal, show distribution - sns.histplot( - pca_result[other_mask, i], ax=axes[i, i], color="gray", alpha=0.3 - ) - for track_id in tracks_of_interest: - track_mask = track_ids == track_id - # For histograms, use all points in the time window - time_mask = ( - time_points[track_mask] >= seed_timepoint - time_window - ) & (time_points[track_mask] <= seed_timepoint + time_window) - if np.any(time_mask): - sns.histplot( - pca_result[track_mask][time_mask, i], - ax=axes[i, i], - color=track_color_map[track_id], - alpha=0.5, - ) - axes[i, i].set_xlabel(f"PC{i+1}") - - plt.tight_layout() - plt.show() - - # Print variance explained - print("\nExplained variance ratio by component:") - for i, var in enumerate(explained_variance_ratio): - print(f"PC{i+1}: {var:.3f} ({cumulative_variance_ratio[i]:.3f} cumulative)") - - # Add analysis of PC loadings - pc_analysis = analyze_pc_loadings(pca) - print("\nPC Loading Analysis:") - print(pc_analysis.to_string(index=False)) - - # Add analysis of track clustering - cluster_analysis = analyze_track_clustering( - pca_result, - track_ids, - time_points, - labels, - 1 if annotation_path is None else phenotype_of_interest, - seed_timepoint, - time_window, - ) - - if cluster_analysis: - print("\nTrack Clustering Analysis:") - print(f"Number of tracks in window: {cluster_analysis['n_tracks']}") - print( - f"Mean distance between tracks: {cluster_analysis['mean_inter_track_distance']:.3f}" - ) - print( - f"Mean spread within tracks: {cluster_analysis['mean_intra_track_spread']:.3f}" - ) - print( - f"Clustering ratio (inter/intra): {cluster_analysis['clustering_ratio']:.3f}" - ) - print("(Lower clustering ratio suggests tighter clustering)") - - # Add distribution analysis - dist_analysis = analyze_pc_distributions( - pca_result, - labels, - 1 if annotation_path is None else phenotype_of_interest, - time_points if seed_timepoint is not None else None, - seed_timepoint, - time_window, - ) - print("\nPC Distribution Analysis:") - print( - "(Separation score > 1 suggests good separation between selected tracks and background)" - ) - print(dist_analysis.to_string(index=False)) - - # Return PCA results and additional data needed for clustering - return ( - pca, - pca_result, - 
explained_variance_ratio, - labels, - tracks_of_interest, - pc_analysis, - cluster_analysis, - dist_analysis, - track_ids, - time_points, - ) - - -# %% -if __name__ == "__main__": - embedding_path = "/hpc/projects/intracellular_dashboard/organelle_dynamics/2024_11_07_A549_SEC61_ZIKV_DENV/3-phenotyping/predictions/timeAware_2chan__ntxent_192patch_70ckpt_rev7_GT.zarr" - annotation_path = "/home/eduardo.hirata/repos/viscy/applications/pseudotime_analysis/phenotype_observations.csv" - - # Using phenotype annotations with specific FOVs - print("\nAnalyzing phenotype 1 in specific FOVs:") - ( - pca, - pca_result, - variance_ratio, - labels, - tracks, - pc_analysis, - cluster_analysis, - dist_analysis, - track_ids, - time_points, - ) = analyze_embeddings_with_pca( - embedding_path, - annotation_path=annotation_path, - phenotype_of_interest=1, - seed_timepoint=55, - time_window=10, - fov_patterns=["/C/2/", "/B/3/", "/B/2/"], - ) - - # Run GMM clustering analysis separately - print("\nPerforming GMM clustering analysis...") - gmm_results = analyze_gmm_clustering( - pca_result, - track_ids, - time_points, - tracks, - seed_timepoint=55, - time_window=10, - ) - - # Analyze cluster characteristics - print("\nAnalyzing cluster characteristics...") - cluster_characteristics = analyze_cluster_characteristics( - gmm_results, - pca_result, - track_ids, - time_points, - tracks, - pc_analysis=pc_analysis, - seed_timepoint=55, - time_window=10, - ) - - # Using random tracks from specific FOVs - print("\nAnalyzing random tracks from specific FOVs:") - ( - pca, - pca_result, - variance_ratio, - labels, - tracks, - pc_analysis, - cluster_analysis, - dist_analysis, - track_ids, - time_points, - ) = analyze_embeddings_with_pca( - embedding_path, - annotation_path=None, - n_random_tracks=10, - seed_timepoint=55, - time_window=30, - fov_patterns=["/C/2/", "/B/3/", "/B/2/"], - ) - # %% - # Run GMM clustering analysis for random tracks - print("\nPerforming GMM clustering analysis for random tracks...") - gmm_results = analyze_gmm_clustering( - pca_result, - track_ids, - time_points, - tracks, - seed_timepoint=55, - time_window=30, - ) - - # Analyze cluster characteristics for random tracks - print("\nAnalyzing cluster characteristics for random tracks...") - cluster_characteristics = analyze_cluster_characteristics( - gmm_results, - pca_result, - track_ids, - time_points, - tracks, - pc_analysis=pc_analysis, - seed_timepoint=55, - time_window=30, - ) - -# %% diff --git a/viscy/representation/evaluation/distance.py b/viscy/representation/evaluation/distance.py deleted file mode 100644 index 85b79a701..000000000 --- a/viscy/representation/evaluation/distance.py +++ /dev/null @@ -1,457 +0,0 @@ -from collections import defaultdict -from pathlib import Path -from typing import Dict, List, Literal, Optional, Tuple - -import matplotlib.pyplot as plt -import numpy as np -import pandas as pd -import seaborn as sns -from numpy.typing import NDArray -from scipy.optimize import minimize_scalar -from scipy.stats import gaussian_kde -from sklearn.metrics.pairwise import cosine_similarity -from sklearn.preprocessing import StandardScaler -from tqdm import tqdm - -from viscy.representation.embedding_writer import read_embedding_dataset -from viscy.representation.evaluation.clustering import ( - compare_time_offset, - pairwise_distance_matrix, - rank_nearest_neighbors, - select_block, -) - - -def calculate_distance_cell( - embedding_dataset, - fov_name, - track_id, - metric: Literal["cosine", "euclidean", "normalized_euclidean"] = 
"cosine", -): - """ - Calculate distances between a cell's first timepoint embedding and all its subsequent embeddings. - - This function extracts embeddings for a specific cell (identified by fov_name and track_id) - and calculates the distance between its first timepoint embedding and all subsequent timepoints - using the specified distance metric. - - Parameters - ---------- - embedding_dataset : xarray.Dataset - Dataset containing the embeddings and metadata. Must have dimensions for 'features', - 'fov_name', 'track_id', and 't' (time). - fov_name : str - Field of view name to identify the specific imaging area. - track_id : int - Track ID of the cell to analyze. - metric : {'cosine', 'euclidean', 'normalized_euclidean'}, default='cosine' - Distance metric to use for calculations: - - 'cosine': Cosine similarity between embeddings - - 'euclidean': Standard Euclidean distance - - 'normalized_euclidean': Euclidean distance between L2-normalized embeddings - - Returns - ------- - time_points : numpy.ndarray - Array of time points corresponding to the calculated distances. - distances : list - List of distances between the first timepoint embedding and each subsequent - timepoint embedding, calculated using the specified metric. - - Notes - ----- - For 'normalized_euclidean', embeddings are L2-normalized before distance calculation. - Cosine similarity results in values between -1 and 1, where 1 indicates identical - direction, 0 indicates orthogonality, and -1 indicates opposite directions. - Euclidean distances are always non-negative. - - Examples - -------- - >>> times, distances = calculate_distance_cell(dataset, "FOV1", 1, metric="cosine") - >>> times, distances = calculate_distance_cell(dataset, "FOV1", 1, metric="euclidean") - """ - filtered_data = embedding_dataset.where( - (embedding_dataset["fov_name"] == fov_name) - & (embedding_dataset["track_id"] == track_id), - drop=True, - ) - features = filtered_data["features"].values # (sample, features) - time_points = filtered_data["t"].values # (sample,) - - if metric == "normalized_euclidean": - features = features / np.linalg.norm(features, axis=1, keepdims=True) - - first_time_point_embedding = features[0].reshape(1, -1) - - if metric == "cosine": - distances = cosine_similarity(first_time_point_embedding, features).flatten() - else: # both euclidean and normalized_euclidean use norm - distances = np.linalg.norm(first_time_point_embedding - features, axis=1) - - return time_points, distances.tolist() - - -def compute_displacement( - embedding_dataset, - distance_metric: Literal["euclidean_squared", "cosine"] = "euclidean_squared", - max_delta_t: int = None, -) -> Dict[int, List[float]]: - """Compute displacements between embeddings at different time differences. - - For each time difference τ, computes distances between embeddings of the same cell - separated by τ timepoints. Supports multiple distance metrics. - - Parameters - ---------- - embedding_dataset : xarray.Dataset - Dataset containing embeddings and metadata with the following variables: - - features: (N, D) array of embeddings - - fov_name: (N,) array of field of view names - - track_id: (N,) array of cell track IDs - - t: (N,) array of timepoints - distance_metric : str, optional - The metric to use for computing distances between embeddings. - Valid options are: - - "euclidean_squared": Squared Euclidean distance (default) - - "cosine": Cosine similarity - max_delta_t : int, optional - Maximum time difference τ to compute displacements for. 
- If None, uses the maximum possible time difference in the dataset. - - Returns - ------- - Dict[int, List[float]] - Dictionary mapping time difference τ to list of displacements. - Each displacement value represents the distance between a pair of - embeddings from the same cell separated by τ timepoints. - """ - - # Get data from dataset - fov_names = embedding_dataset["fov_name"].values - track_ids = embedding_dataset["track_id"].values - timepoints = embedding_dataset["t"].values - embeddings = embedding_dataset["features"].values - - # Check if max_delta_t is provided, otherwise use the maximum timepoint - if max_delta_t is None: - max_delta_t = timepoints.max() - - displacement_per_delta_t = defaultdict(list) - # Process each sample - for i in tqdm(range(len(fov_names)), desc="Processing FOVs"): - fov_name = fov_names[i] - track_id = track_ids[i] - current_time = timepoints[i] - current_embedding = embeddings[i].reshape(1, -1) - - # Compute displacements for each delta t - for delta_t in range(1, max_delta_t + 1): - future_time = current_time + delta_t - matching_indices = np.where( - (fov_names == fov_name) - & (track_ids == track_id) - & (timepoints == future_time) - )[0] - - if len(matching_indices) == 1: - if distance_metric == "euclidean_squared": - future_embedding = embeddings[matching_indices[0]].reshape(1, -1) - displacement = np.sum((current_embedding - future_embedding) ** 2) - elif distance_metric == "cosine": - future_embedding = embeddings[matching_indices[0]].reshape(1, -1) - displacement = cosine_similarity( - current_embedding, future_embedding - ) - displacement_per_delta_t[delta_t].append(displacement) - return dict(displacement_per_delta_t) - - -def compute_displacement_statistics( - displacement_per_delta_t: Dict[int, List[float]] -) -> Tuple[Dict[int, float], Dict[int, float]]: - """Compute mean and standard deviation of displacements for each delta_t. - - Parameters - ---------- - displacement_per_delta_t : Dict[int, List[float]] - Dictionary mapping τ to list of displacements - - Returns - ------- - Tuple[Dict[int, float], Dict[int, float]] - Tuple of (mean_displacements, std_displacements) where each is a - dictionary mapping τ to the statistic - """ - mean_displacement_per_delta_t = { - delta_t: np.mean(displacements) - for delta_t, displacements in displacement_per_delta_t.items() - } - std_displacement_per_delta_t = { - delta_t: np.std(displacements) - for delta_t, displacements in displacement_per_delta_t.items() - } - return mean_displacement_per_delta_t, std_displacement_per_delta_t - - -def compute_dynamic_range(mean_displacement_per_delta_t): - """ - Compute the dynamic range as the difference between the maximum - and minimum mean displacement per τ. - - Parameters: - mean_displacement_per_delta_t: dict with τ as key and mean displacement as value - - Returns: - float: dynamic range (max displacement - min displacement) - """ - displacements = list(mean_displacement_per_delta_t.values()) - return max(displacements) - min(displacements) - - -def compute_rms_per_track(embedding_dataset): - """ - Compute RMS of the time derivative of embeddings per track. - - Parameters: - embedding_dataset : xarray.Dataset - The dataset containing embeddings, timepoints, fov_name, and track_id. - - Returns: - list: A list of RMS values, one for each track. 
- """ - fov_names = embedding_dataset["fov_name"].values - track_ids = embedding_dataset["track_id"].values - timepoints = embedding_dataset["t"].values - embeddings = embedding_dataset["features"].values - - cell_identifiers = np.array( - list(zip(fov_names, track_ids)), - dtype=[("fov_name", "O"), ("track_id", "int64")], - ) - unique_cells = np.unique(cell_identifiers) - - rms_values = [] - - for cell in unique_cells: - fov_name = cell["fov_name"] - track_id = cell["track_id"] - indices = np.where((fov_names == fov_name) & (track_ids == track_id))[0] - cell_timepoints = timepoints[indices] - cell_embeddings = embeddings[indices] - - if len(cell_embeddings) < 2: - continue - - sorted_indices = np.argsort(cell_timepoints) - cell_embeddings = cell_embeddings[sorted_indices] - differences = np.diff(cell_embeddings, axis=0) - - if differences.shape[0] == 0: - continue - - norms = np.linalg.norm(differences, axis=1) - rms = np.sqrt(np.mean(norms**2)) - rms_values.append(rms) - - return rms_values - - -def find_distribution_peak(data: np.ndarray) -> float: - """ - Find the peak (mode) of a distribution using kernel density estimation. - - Args: - data: Array of values to find the peak for - - Returns: - float: The x-value where the peak occurs - """ - kde = gaussian_kde(data) - # Find the peak (maximum) of the KDE - result = minimize_scalar( - lambda x: -kde(x), bounds=(np.min(data), np.max(data)), method="bounded" - ) - return result.x - - -def compute_piece_wise_dissimilarity( - features_df: pd.DataFrame, cross_dist: NDArray, rank_fractions: NDArray -): - """ - Computing the smoothness and dynamic range - - Get the off diagonal per block and compute the mode - - The blocks are not square, so we need to get the off diagonal elements - - Get the 1 and 99 percentile of the off diagonal per block - """ - piece_wise_dissimilarity_per_track = [] - piece_wise_rank_difference_per_track = [] - for name, subdata in features_df.groupby(["fov_name", "track_id"]): - if len(subdata) > 1: - indices = subdata.index.values - single_track_dissimilarity = select_block(cross_dist, indices) - single_track_rank_fraction = select_block(rank_fractions, indices) - piece_wise_dissimilarity = compare_time_offset( - single_track_dissimilarity, time_offset=1 - ) - piece_wise_rank_difference = compare_time_offset( - single_track_rank_fraction, time_offset=1 - ) - piece_wise_dissimilarity_per_track.append(piece_wise_dissimilarity) - piece_wise_rank_difference_per_track.append(piece_wise_rank_difference) - return piece_wise_dissimilarity_per_track, piece_wise_rank_difference_per_track - - -def compute_embedding_distances( - prediction_path: Path, - output_path: Path, - distance_metric: Literal["cosine", "euclidean", "normalized_euclidean"] = "cosine", -) -> pd.DataFrame: - """ - Compute and save pairwise distances between embeddings. 
- - Parameters - ---------- - prediction_path : Path - Path to the embedding dataset - output_path : Path - name of saved CSV file - distance_metric : str, optional - Distance metric to use for computing distances between embeddings - - Returns - ------- - pd.DataFrame - DataFrame containing the adjacent frame and random sampling distances - """ - # Read the dataset - embeddings = read_embedding_dataset(prediction_path) - features = embeddings["features"] - - if distance_metric != "euclidean": - features = StandardScaler().fit_transform(features.values) - - # Compute the distance matrix - cross_dist = pairwise_distance_matrix(features, metric=distance_metric) - - # Normalize by sqrt of embedding dimension if using euclidean distance - if distance_metric == "euclidean": - cross_dist /= np.sqrt(features.shape[1]) - - # Plot the distance matrix - plt.figure(figsize=(10, 10)) - plt.imshow(cross_dist, cmap="viridis") - plt.colorbar(label=f"{distance_metric.capitalize()} Distance") - plt.title(f"{distance_metric.capitalize()} Distance Matrix") - plt.tight_layout() - base_name = prediction_path.stem - plt.savefig(output_path / f"{base_name}_distance_matrix.png", dpi=600) - plt.close() - rank_fractions = rank_nearest_neighbors(cross_dist, normalize=True) - - # Compute piece-wise dissimilarity and rank difference - features_df = features["sample"].to_dataframe().reset_index(drop=True) - piece_wise_dissimilarity_per_track, piece_wise_rank_difference_per_track = ( - compute_piece_wise_dissimilarity(features_df, cross_dist, rank_fractions) - ) - - all_dissimilarity = np.concatenate(piece_wise_dissimilarity_per_track) - - # Random sampling values in the dissimilarity matrix - n_samples = len(all_dissimilarity) - random_indices = np.random.randint(0, len(cross_dist), size=(n_samples, 2)) - sampled_values = cross_dist[random_indices[:, 0], random_indices[:, 1]] - - # Create and save DataFrame - distributions_df = pd.DataFrame( - { - "adjacent_frame": pd.Series(all_dissimilarity), - "random_sampling": pd.Series(sampled_values), - } - ) - - csv_path = output_path - distributions_df.to_csv(csv_path, index=False) - - return distributions_df - - -def analyze_and_plot_distances( - distributions_df: pd.DataFrame, - output_file_path: Optional[str], - overwrite: bool = False, -) -> dict: - """ - Analyze distance distributions and create visualization plots. - - Parameters - ---------- - distributions_df : pd.DataFrame - DataFrame containing 'adjacent_frame' and 'random_sampling' columns - output_file_path : str, optional - Path to save the plot ideally with a .pdf extension. 
Uses `plt.savefig()` - overwrite : bool, default=False - If True, overwrites existing files - - Returns - ------- - dict - Dictionary containing computed metrics including means, standard deviations, - medians, peaks, and dynamic range of the distributions - """ - # Compute statistics - adjacent_dist = distributions_df["adjacent_frame"].values - random_dist = distributions_df["random_sampling"].values - - # Compute peaks - adjacent_peak = float(find_distribution_peak(adjacent_dist)) - random_peak = float(find_distribution_peak(random_dist)) - dynamic_range = float(random_peak - adjacent_peak) - - metrics = { - "dissimilarity_mean": float(np.mean(adjacent_dist)), - "dissimilarity_std": float(np.std(adjacent_dist)), - "dissimilarity_median": float(np.median(adjacent_dist)), - "dissimilarity_peak": adjacent_peak, - "dissimilarity_p99": float(np.percentile(adjacent_dist, 99)), - "dissimilarity_p1": float(np.percentile(adjacent_dist, 1)), - "random_mean": float(np.mean(random_dist)), - "random_std": float(np.std(random_dist)), - "random_median": float(np.median(random_dist)), - "random_peak": random_peak, - "dynamic_range": dynamic_range, - } - - # Create plot - fig = plt.figure() - sns.histplot( - data=distributions_df, - x="adjacent_frame", - bins=30, - kde=True, - color="cyan", - alpha=0.5, - stat="density", - ) - sns.histplot( - data=distributions_df, - x="random_sampling", - bins=30, - kde=True, - color="red", - alpha=0.5, - stat="density", - ) - plt.xlabel("Cosine Dissimilarity") - plt.ylabel("Density") - plt.axvline(x=adjacent_peak, color="cyan", linestyle="--", alpha=0.8) - plt.axvline(x=random_peak, color="red", linestyle="--", alpha=0.8) - plt.tight_layout() - plt.legend(["Adjacent Frame", "Random Sample", "Adjacent Peak", "Random Peak"]) - if output_file_path.exists() and not overwrite: - raise FileExistsError( - f"File {output_file_path} already exists and overwrite=False" - ) - fig.savefig(output_file_path, dpi=600) - plt.show() - - return metrics From 822dbba3d92c9bebcaddf51d58520dfeaa403668 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Wed, 29 Jan 2025 16:25:39 -0800 Subject: [PATCH 31/38] revert unwatned changes to files --- .../pseudotime_analysis/pca_analysis.py | 564 ++++++++++++++++++ viscy/representation/evaluation/distance.py | 460 ++++++++++++++ 2 files changed, 1024 insertions(+) create mode 100644 applications/pseudotime_analysis/pca_analysis.py create mode 100644 viscy/representation/evaluation/distance.py diff --git a/applications/pseudotime_analysis/pca_analysis.py b/applications/pseudotime_analysis/pca_analysis.py new file mode 100644 index 000000000..3aed085ce --- /dev/null +++ b/applications/pseudotime_analysis/pca_analysis.py @@ -0,0 +1,564 @@ +# %% +import numpy as np +import pandas as pd +from sklearn.preprocessing import StandardScaler +from sklearn.decomposition import PCA +import matplotlib.pyplot as plt +import seaborn as sns +from viscy.representation.embedding_writer import read_embedding_dataset +from scipy.spatial.distance import pdist, squareform + +# Set global random seed for reproducibility +RANDOM_SEED = 42 +np.random.seed(RANDOM_SEED) + + +def analyze_pc_loadings(pca, feature_names=None, top_n=5): + """Analyze which features contribute most to each PC.""" + if feature_names is None: + feature_names = [f"Feature_{i}" for i in range(pca.components_[0].shape[0])] + + pc_loadings = [] + for i, pc in enumerate(pca.components_): + # Get the absolute loadings + abs_loadings = np.abs(pc) + # Get indices of top contributing features + 
top_indices = np.argsort(abs_loadings)[-top_n:][::-1] + + # Store the results + pc_dict = { + "PC": i + 1, + "Variance_Explained": pca.explained_variance_ratio_[i], + "Top_Features": [feature_names[idx] for idx in top_indices], + "Top_Loadings": [pc[idx] for idx in top_indices], + } + pc_loadings.append(pc_dict) + + return pd.DataFrame(pc_loadings) + + +def analyze_track_clustering( + pca_result, + track_ids, + time_points, + labels, + phenotype_of_interest, + seed_timepoint, + time_window, +): + """Analyze how tracks cluster in PC space within the time window.""" + # Get points within time window + time_mask = (time_points >= seed_timepoint - time_window) & ( + time_points <= seed_timepoint + time_window + ) + window_points = pca_result[time_mask] + window_tracks = track_ids[time_mask] + window_labels = labels[time_mask] + + # Calculate mean position for each track + track_means = {} + phenotype_tracks = [] + + for track_id in np.unique(window_tracks): + track_mask = (window_tracks == track_id) & ( + window_labels == phenotype_of_interest + ) + if np.any(track_mask): + track_means[track_id] = np.mean(window_points[track_mask], axis=0) + phenotype_tracks.append(track_id) + + if len(phenotype_tracks) < 2: + return None + + # Calculate pairwise distances between track means + track_positions = np.array([track_means[tid] for tid in phenotype_tracks]) + distances = pdist(track_positions) + mean_distance = np.mean(distances) + std_distance = np.std(distances) + + # Calculate spread within each track + track_spreads = {} + for track_id in phenotype_tracks: + track_mask = (window_tracks == track_id) & ( + window_labels == phenotype_of_interest + ) + if np.sum(track_mask) > 1: + track_points = window_points[track_mask] + spread = np.mean(pdist(track_points)) + track_spreads[track_id] = spread + + mean_spread = np.mean(list(track_spreads.values())) if track_spreads else 0 + + return { + "n_tracks": len(phenotype_tracks), + "mean_inter_track_distance": mean_distance, + "std_inter_track_distance": std_distance, + "mean_intra_track_spread": mean_spread, + "clustering_ratio": mean_distance / mean_spread if mean_spread > 0 else np.inf, + } + + +def analyze_pc_distributions( + pca_result, + labels, + phenotype_of_interest, + time_points=None, + seed_timepoint=None, + time_window=None, +): + """Analyze the distributions of each PC for phenotype vs background.""" + n_components = pca_result.shape[1] + results = [] + + for i in range(n_components): + # Get phenotype and background points + if ( + time_points is not None + and seed_timepoint is not None + and time_window is not None + ): + time_mask = (time_points >= seed_timepoint - time_window) & ( + time_points <= seed_timepoint + time_window + ) + pc_values_phenotype = pca_result[ + time_mask & (labels == phenotype_of_interest), i + ] + pc_values_background = pca_result[ + time_mask & (labels != phenotype_of_interest), i + ] + else: + pc_values_phenotype = pca_result[labels == phenotype_of_interest, i] + pc_values_background = pca_result[labels != phenotype_of_interest, i] + + # Calculate basic statistics + stats = { + "PC": i + 1, + "phenotype_mean": np.mean(pc_values_phenotype), + "background_mean": np.mean(pc_values_background), + "phenotype_std": np.std(pc_values_phenotype), + "background_std": np.std(pc_values_background), + "separation": abs( + np.mean(pc_values_phenotype) - np.mean(pc_values_background) + ) + / (np.std(pc_values_phenotype) + np.std(pc_values_background)), + } + + # Check for multimodality using a simple peak detection + hist, bins 
= np.histogram(pc_values_phenotype, bins="auto") + peaks = len( + [ + i + for i in range(1, len(hist) - 1) + if hist[i] > hist[i - 1] and hist[i] > hist[i + 1] + ] + ) + stats["n_peaks"] = peaks + + results.append(stats) + + return pd.DataFrame(results) + + +def analyze_embeddings_with_pca( + embedding_path, + annotation_path=None, + phenotype_of_interest=None, + n_random_tracks=10, + n_components=8, + seed_timepoint=None, + time_window=10, + fov_patterns=None, +): + """Analyze embeddings using PCA, either for specific phenotypes or random tracks. + + Args: + embedding_path: Path to embedding zarr file + annotation_path: Optional path to annotation CSV file. If None, uses random tracks + phenotype_of_interest: Which phenotype to analyze (only used if annotation_path is provided) + n_random_tracks: Number of random tracks to select (only used if annotation_path is None) + n_components: Number of PCA components + seed_timepoint: Center of time window. If None, uses all timepoints + time_window: Size of time window (+/-). Only used if seed_timepoint is not None + fov_patterns: List of patterns to filter FOVs (e.g. ['/C/2/*', '/B/3/*']). + Optional even when using annotation_path - can be used to restrict + analysis to specific FOVs while still using phenotype information. + """ + if annotation_path is None: + print(f"\nUsing random tracks (global seed: {RANDOM_SEED})") + + if seed_timepoint is None: + print("\nUsing all timepoints") + else: + print(f"\nUsing time window: {seed_timepoint}±{time_window}") + + # Load embeddings + embedding_dataset = read_embedding_dataset(embedding_path) + features = embedding_dataset["features"] + track_ids = embedding_dataset["track_id"].values + fovs = embedding_dataset["fov_name"].values + time_points = embedding_dataset["t"].values + + # Filter FOVs if patterns are provided + if fov_patterns is not None: + print(f"\nFiltering FOVs with patterns: {fov_patterns}") + fov_mask = np.zeros_like(fovs, dtype=bool) + for pattern in fov_patterns: + fov_mask |= np.char.find(fovs.astype(str), pattern) >= 0 + + # Update all arrays with the FOV mask + features = features[fov_mask] + track_ids = track_ids[fov_mask] + fovs = fovs[fov_mask] + time_points = time_points[fov_mask] + + print(f"Found {len(np.unique(fovs))} FOVs matching patterns") + + # Get tracks of interest + if annotation_path is not None: + # Load annotations and get phenotype tracks + annotations_df = pd.read_csv(annotation_path) + annotation_map = { + (str(row["FOV"]), int(row["Track_id"])): row["Observed phenotype"] + for _, row in annotations_df.iterrows() + } + labels = np.array( + [ + annotation_map.get((str(fov), int(track_id)), -1) + for fov, track_id in zip(fovs, track_ids) + ] + ) + selection_mask = labels == phenotype_of_interest + tracks_of_interest = np.unique(track_ids[selection_mask]) + other_mask = ~selection_mask + mode = f"phenotype {phenotype_of_interest}" + else: + # Select random tracks from different FOVs when possible + # Create a mapping of FOV to tracks + fov_track_map = {} + for fov, track_id in zip(fovs, track_ids): + if fov not in fov_track_map: + fov_track_map[fov] = [] + if track_id not in fov_track_map[fov]: # Avoid duplicates + fov_track_map[fov].append(track_id) + + # Get list of all FOVs + available_fovs = list(fov_track_map.keys()) + tracks_of_interest = [] + + # First, try to get one track from each FOV + np.random.shuffle(available_fovs) # Randomize FOV order + for fov in available_fovs: + if len(tracks_of_interest) < n_random_tracks: + # Randomly select a track from this 
FOV + track = np.random.choice(fov_track_map[fov]) + tracks_of_interest.append(track) + else: + break + + # If we still need more tracks, randomly select from remaining tracks + if len(tracks_of_interest) < n_random_tracks: + # Get all remaining tracks that aren't already selected + remaining_tracks = [ + track + for track in np.unique(track_ids) + if track not in tracks_of_interest + ] + # Select additional tracks + additional_tracks = np.random.choice( + remaining_tracks, + size=min( + n_random_tracks - len(tracks_of_interest), len(remaining_tracks) + ), + replace=False, + ) + tracks_of_interest.extend(additional_tracks) + + tracks_of_interest = np.array(tracks_of_interest) + selection_mask = np.isin(track_ids, tracks_of_interest) + other_mask = ~selection_mask + labels = np.where(selection_mask, 1, 0) + mode = "random tracks" + + # Print selected tracks with their FOVs + print("\nSelected tracks:") + for track in tracks_of_interest: + track_fovs = np.unique(fovs[track_ids == track]) + print(f"Track {track}: FOV {track_fovs[0]}") + + # Scale the features + scaler = StandardScaler() + scaled_features = scaler.fit_transform(features.values) + + # Perform PCA + pca = PCA(n_components=n_components) + pca_result = pca.fit_transform(scaled_features) + + # Calculate explained variance + explained_variance_ratio = pca.explained_variance_ratio_ + cumulative_variance_ratio = np.cumsum(explained_variance_ratio) + + # Create track-specific colors + track_colors = plt.cm.tab10(np.linspace(0, 1, len(tracks_of_interest))) + track_color_map = dict(zip(tracks_of_interest, track_colors)) + + # Create plots + fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5)) + + # Scree plot + ax1.plot(range(1, n_components + 1), explained_variance_ratio, "bo-") + ax1.plot(range(1, n_components + 1), cumulative_variance_ratio, "ro-") + ax1.set_xlabel("Principal Component") + ax1.set_ylabel("Explained Variance Ratio") + ax1.set_title("Scree Plot") + ax1.legend(["Individual", "Cumulative"]) + + # First two components plot + # Plot other tracks/cells in gray + ax2.scatter( + pca_result[other_mask, 0], + pca_result[other_mask, 1], + alpha=0.1, + color="gray", + label="Other cells", + s=10, + ) + + # Plot tracks of interest with decreasing opacity + for track_id in tracks_of_interest: + track_mask = track_ids == track_id + track_points = pca_result[track_mask] + track_times = time_points[track_mask] + + # Sort points by time + sort_idx = np.argsort(track_times) + track_points = track_points[sort_idx] + track_times = track_times[sort_idx] + + # Apply time window if specified + if seed_timepoint is not None: + time_mask = (track_times >= seed_timepoint - time_window) & ( + track_times <= seed_timepoint + time_window + ) + else: + time_mask = np.ones_like(track_times, dtype=bool) # Use all points + + if np.any(time_mask): # Only plot if there are points in the window + window_points = track_points[time_mask] + window_times = track_times[time_mask] + + # Normalize times within window for opacity + norm_times = (window_times - window_times.min()) / ( + window_times.max() - window_times.min() + 1e-10 + ) + alphas = 0.2 + 0.8 * norm_times # Scale to [0.2, 1.0] + + # Plot points with opacity based on normalized time + for idx in range(len(window_points)): + ax2.scatter( + window_points[idx, 0], + window_points[idx, 1], + color=track_color_map[track_id], + alpha=alphas[idx], + s=50, + label=( + f"Track {track_id}" if idx == len(window_points) - 1 else None + ), + ) + + ax2.set_xlabel("First Principal Component") + 
ax2.set_ylabel("Second Principal Component") + title = f"First Two Principal Components - {mode}" + if seed_timepoint is not None: + title += f"\nTime window: {seed_timepoint}±{time_window}" + ax2.set_title(title) + ax2.legend(bbox_to_anchor=(1.05, 1), loc="upper left") + + plt.tight_layout() + plt.show() + + # Pairwise component plots + fig, axes = plt.subplots(n_components, n_components, figsize=(20, 20)) + + for i in range(n_components): + for j in range(n_components): + if i != j: + # Plot other points first + axes[i, j].scatter( + pca_result[other_mask, j], + pca_result[other_mask, i], + alpha=0.1, + color="gray", + s=5, + ) + + # Plot each track with decreasing opacity + for track_id in tracks_of_interest: + track_mask = track_ids == track_id + track_points_j = pca_result[track_mask, j] + track_points_i = pca_result[track_mask, i] + track_times = time_points[track_mask] + + # Sort points by time + sort_idx = np.argsort(track_times) + track_points_j = track_points_j[sort_idx] + track_points_i = track_points_i[sort_idx] + track_times = track_times[sort_idx] + + # Select points within the time window + time_mask = (track_times >= seed_timepoint - time_window) & ( + track_times <= seed_timepoint + time_window + ) + if np.any(time_mask): # Only plot if there are points in the window + window_points_j = track_points_j[time_mask] + window_points_i = track_points_i[time_mask] + window_times = track_times[time_mask] + + # Normalize times within window for opacity + norm_times = (window_times - window_times.min()) / ( + window_times.max() - window_times.min() + 1e-10 + ) + alphas = 0.2 + 0.8 * norm_times # Scale to [0.2, 1.0] + + # Plot points with opacity based on normalized time + for idx in range(len(window_points_j)): + axes[i, j].scatter( + window_points_j[idx], + window_points_i[idx], + color=track_color_map[track_id], + alpha=alphas[idx], + s=30, + ) + + axes[i, j].set_xlabel(f"PC{j+1}") + axes[i, j].set_ylabel(f"PC{i+1}") + else: + # On diagonal, show distribution + sns.histplot( + pca_result[other_mask, i], ax=axes[i, i], color="gray", alpha=0.3 + ) + for track_id in tracks_of_interest: + track_mask = track_ids == track_id + # For histograms, use all points in the time window + time_mask = ( + time_points[track_mask] >= seed_timepoint - time_window + ) & (time_points[track_mask] <= seed_timepoint + time_window) + if np.any(time_mask): + sns.histplot( + pca_result[track_mask][time_mask, i], + ax=axes[i, i], + color=track_color_map[track_id], + alpha=0.5, + ) + axes[i, i].set_xlabel(f"PC{i+1}") + + plt.tight_layout() + plt.show() + + # Print variance explained + print("\nExplained variance ratio by component:") + for i, var in enumerate(explained_variance_ratio): + print(f"PC{i+1}: {var:.3f} ({cumulative_variance_ratio[i]:.3f} cumulative)") + + # Add analysis of PC loadings + pc_analysis = analyze_pc_loadings(pca) + print("\nPC Loading Analysis:") + print(pc_analysis.to_string(index=False)) + + # Add analysis of track clustering + cluster_analysis = analyze_track_clustering( + pca_result, + track_ids, + time_points, + labels, + 1 if annotation_path is None else phenotype_of_interest, + seed_timepoint, + time_window, + ) + + if cluster_analysis: + print("\nTrack Clustering Analysis:") + print(f"Number of tracks in window: {cluster_analysis['n_tracks']}") + print( + f"Mean distance between tracks: {cluster_analysis['mean_inter_track_distance']:.3f}" + ) + print( + f"Mean spread within tracks: {cluster_analysis['mean_intra_track_spread']:.3f}" + ) + print( + f"Clustering ratio 
(inter/intra): {cluster_analysis['clustering_ratio']:.3f}" + ) + print("(Lower clustering ratio suggests tighter clustering)") + + # Add distribution analysis + dist_analysis = analyze_pc_distributions( + pca_result, + labels, + 1 if annotation_path is None else phenotype_of_interest, + time_points if seed_timepoint is not None else None, + seed_timepoint, + time_window, + ) + print("\nPC Distribution Analysis:") + print( + "(Separation score > 1 suggests good separation between selected tracks and background)" + ) + print(dist_analysis.to_string(index=False)) + + return ( + pca, + pca_result, + explained_variance_ratio, + labels, + tracks_of_interest, + pc_analysis, + cluster_analysis, + dist_analysis, + ) + + +# %% +if __name__ == "__main__": + embedding_path = "/hpc/projects/intracellular_dashboard/organelle_dynamics/2024_11_07_A549_SEC61_ZIKV_DENV/3-phenotyping/predictions/timeAware_2chan__ntxent_192patch_70ckpt_rev7_GT.zarr" + annotation_path = "/home/eduardo.hirata/repos/viscy/applications/pseudotime_analysis/phenotype_observations.csv" + # %% + # Using phenotype annotations with specific FOVs + print("\nAnalyzing phenotype 1 in specific FOVs:") + ( + pca, + pca_result, + variance_ratio, + labels, + tracks, + pc_analysis, + cluster_analysis, + dist_analysis, + ) = analyze_embeddings_with_pca( + embedding_path, + annotation_path=annotation_path, + phenotype_of_interest=1, + seed_timepoint=55, + time_window=10, + fov_patterns=["/C/2/", "/B/3/", "/B/2/"], # Specify FOV patterns + ) + + # Using random tracks from specific FOVs + print("\nAnalyzing random tracks from specific FOVs:") + ( + pca, + pca_result, + variance_ratio, + labels, + tracks, + pc_analysis, + cluster_analysis, + dist_analysis, + ) = analyze_embeddings_with_pca( + embedding_path, + annotation_path=None, # This triggers random track selection + n_random_tracks=10, + seed_timepoint=55, + time_window=30, + fov_patterns=["/C/2/", "/B/3/", "/B/2/"], # Specify FOV patterns + ) + +# %% diff --git a/viscy/representation/evaluation/distance.py b/viscy/representation/evaluation/distance.py new file mode 100644 index 000000000..cefd8b6e6 --- /dev/null +++ b/viscy/representation/evaluation/distance.py @@ -0,0 +1,460 @@ +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Literal, Optional, Tuple + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import seaborn as sns +from numpy.typing import NDArray +from scipy.optimize import minimize_scalar +from scipy.stats import gaussian_kde +from sklearn.metrics.pairwise import cosine_similarity +from sklearn.preprocessing import StandardScaler +from tqdm import tqdm + +from viscy.representation.embedding_writer import read_embedding_dataset +from viscy.representation.evaluation.clustering import ( + compare_time_offset, + pairwise_distance_matrix, + rank_nearest_neighbors, + select_block, +) + + +def calculate_distance_cell( + embedding_dataset, + fov_name, + track_id, + metric: Literal["cosine", "euclidean", "normalized_euclidean"] = "cosine", +): + """ + Calculate distances between a cell's first timepoint embedding and all its subsequent embeddings. + + This function extracts embeddings for a specific cell (identified by fov_name and track_id) + and calculates the distance between its first timepoint embedding and all subsequent timepoints + using the specified distance metric. + + Parameters + ---------- + embedding_dataset : xarray.Dataset + Dataset containing the embeddings and metadata. 
Must have dimensions for 'features', + 'fov_name', 'track_id', and 't' (time). + fov_name : str + Field of view name to identify the specific imaging area. + track_id : int + Track ID of the cell to analyze. + metric : {'cosine', 'euclidean', 'normalized_euclidean'}, default='cosine' + Distance metric to use for calculations: + - 'cosine': Cosine similarity between embeddings + - 'euclidean': Standard Euclidean distance + - 'normalized_euclidean': Euclidean distance between L2-normalized embeddings + + Returns + ------- + time_points : numpy.ndarray + Array of time points corresponding to the calculated distances. + distances : list + List of distances between the first timepoint embedding and each subsequent + timepoint embedding, calculated using the specified metric. + + Notes + ----- + For 'normalized_euclidean', embeddings are L2-normalized before distance calculation. + Cosine similarity results in values between -1 and 1, where 1 indicates identical + direction, 0 indicates orthogonality, and -1 indicates opposite directions. + Euclidean distances are always non-negative. + + Examples + -------- + >>> times, distances = calculate_distance_cell(dataset, "FOV1", 1, metric="cosine") + >>> times, distances = calculate_distance_cell(dataset, "FOV1", 1, metric="euclidean") + """ + filtered_data = embedding_dataset.where( + (embedding_dataset["fov_name"] == fov_name) + & (embedding_dataset["track_id"] == track_id), + drop=True, + ) + features = filtered_data["features"].values # (sample, features) + time_points = filtered_data["t"].values # (sample,) + + if metric == "normalized_euclidean": + features = features / np.linalg.norm(features, axis=1, keepdims=True) + + first_time_point_embedding = features[0].reshape(1, -1) + + if metric == "cosine": + distances = cosine_similarity(first_time_point_embedding, features).flatten() + else: # both euclidean and normalized_euclidean use norm + distances = np.linalg.norm(first_time_point_embedding - features, axis=1) + + return time_points, distances.tolist() + + +def compute_displacement( + embedding_dataset, + distance_metric: Literal["euclidean_squared", "cosine"] = "euclidean_squared", + max_delta_t: int = None, +) -> Dict[int, List[float]]: + """Compute displacements between embeddings at different time differences. + + For each time difference τ, computes distances between embeddings of the same cell + separated by τ timepoints. Supports multiple distance metrics. + + Parameters + ---------- + embedding_dataset : xarray.Dataset + Dataset containing embeddings and metadata with the following variables: + - features: (N, D) array of embeddings + - fov_name: (N,) array of field of view names + - track_id: (N,) array of cell track IDs + - t: (N,) array of timepoints + distance_metric : str, optional + The metric to use for computing distances between embeddings. + Valid options are: + - "euclidean_squared": Squared Euclidean distance (default) + - "cosine": Cosine similarity + max_delta_t : int, optional + Maximum time difference τ to compute displacements for. + If None, uses the maximum possible time difference in the dataset. + + Returns + ------- + Dict[int, List[float]] + Dictionary mapping time difference τ to list of displacements. + Each displacement value represents the distance between a pair of + embeddings from the same cell separated by τ timepoints. 
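+
+    Examples
+    --------
+    Illustrative usage; ``dataset`` is a placeholder for an embedding dataset
+    with the variables listed above.
+
+    >>> displacements = compute_displacement(dataset, distance_metric="euclidean_squared", max_delta_t=20)
+    >>> means, stds = compute_displacement_statistics(displacements)
+    >>> means[1]  # mean square displacement at tau = 1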
+ """ + + # Get data from dataset + fov_names = embedding_dataset["fov_name"].values + track_ids = embedding_dataset["track_id"].values + timepoints = embedding_dataset["t"].values + embeddings = embedding_dataset["features"].values + + # Check if max_delta_t is provided, otherwise use the maximum timepoint + if max_delta_t is None: + max_delta_t = timepoints.max() + + displacement_per_delta_t = defaultdict(list) + # Process each sample + for i in tqdm(range(len(fov_names)), desc="Processing FOVs"): + fov_name = fov_names[i] + track_id = track_ids[i] + current_time = timepoints[i] + current_embedding = embeddings[i].reshape(1, -1) + + # Compute displacements for each delta t + for delta_t in range(1, max_delta_t + 1): + future_time = current_time + delta_t + matching_indices = np.where( + (fov_names == fov_name) + & (track_ids == track_id) + & (timepoints == future_time) + )[0] + + if len(matching_indices) == 1: + if distance_metric == "euclidean_squared": + future_embedding = embeddings[matching_indices[0]].reshape(1, -1) + displacement = np.sum((current_embedding - future_embedding) ** 2) + elif distance_metric == "cosine": + future_embedding = embeddings[matching_indices[0]].reshape(1, -1) + displacement = cosine_similarity( + current_embedding, future_embedding + ) + displacement_per_delta_t[delta_t].append(displacement) + return dict(displacement_per_delta_t) + + +def compute_displacement_statistics( + displacement_per_delta_t: Dict[int, List[float]] +) -> Tuple[Dict[int, float], Dict[int, float]]: + """Compute mean and standard deviation of displacements for each delta_t. + + Parameters + ---------- + displacement_per_delta_t : Dict[int, List[float]] + Dictionary mapping τ to list of displacements + + Returns + ------- + Tuple[Dict[int, float], Dict[int, float]] + Tuple of (mean_displacements, std_displacements) where each is a + dictionary mapping τ to the statistic + """ + mean_displacement_per_delta_t = { + delta_t: np.mean(displacements) + for delta_t, displacements in displacement_per_delta_t.items() + } + std_displacement_per_delta_t = { + delta_t: np.std(displacements) + for delta_t, displacements in displacement_per_delta_t.items() + } + return mean_displacement_per_delta_t, std_displacement_per_delta_t + + +def compute_dynamic_range(mean_displacement_per_delta_t): + """ + Compute the dynamic range as the difference between the maximum + and minimum mean displacement per τ. + + Parameters: + mean_displacement_per_delta_t: dict with τ as key and mean displacement as value + + Returns: + float: dynamic range (max displacement - min displacement) + """ + displacements = list(mean_displacement_per_delta_t.values()) + return max(displacements) - min(displacements) + + +def compute_rms_per_track(embedding_dataset): + """ + Compute RMS of the time derivative of embeddings per track. + + Parameters: + embedding_dataset : xarray.Dataset + The dataset containing embeddings, timepoints, fov_name, and track_id. + + Returns: + list: A list of RMS values, one for each track. 
+ """ + fov_names = embedding_dataset["fov_name"].values + track_ids = embedding_dataset["track_id"].values + timepoints = embedding_dataset["t"].values + embeddings = embedding_dataset["features"].values + + cell_identifiers = np.array( + list(zip(fov_names, track_ids)), + dtype=[("fov_name", "O"), ("track_id", "int64")], + ) + unique_cells = np.unique(cell_identifiers) + + rms_values = [] + + for cell in unique_cells: + fov_name = cell["fov_name"] + track_id = cell["track_id"] + indices = np.where((fov_names == fov_name) & (track_ids == track_id))[0] + cell_timepoints = timepoints[indices] + cell_embeddings = embeddings[indices] + + if len(cell_embeddings) < 2: + continue + + sorted_indices = np.argsort(cell_timepoints) + cell_embeddings = cell_embeddings[sorted_indices] + differences = np.diff(cell_embeddings, axis=0) + + if differences.shape[0] == 0: + continue + + norms = np.linalg.norm(differences, axis=1) + rms = np.sqrt(np.mean(norms**2)) + rms_values.append(rms) + + return rms_values + + +def find_distribution_peak(data: np.ndarray) -> float: + """ + Find the peak (mode) of a distribution using kernel density estimation. + + Args: + data: Array of values to find the peak for + + Returns: + float: The x-value where the peak occurs + """ + kde = gaussian_kde(data) + # Find the peak (maximum) of the KDE + result = minimize_scalar( + lambda x: -kde(x), bounds=(np.min(data), np.max(data)), method="bounded" + ) + return result.x + + +def compute_piece_wise_dissimilarity( + features_df: pd.DataFrame, cross_dist: NDArray, rank_fractions: NDArray +): + """ + Computing the smoothness and dynamic range + - Get the off diagonal per block and compute the mode + - The blocks are not square, so we need to get the off diagonal elements + - Get the 1 and 99 percentile of the off diagonal per block + """ + piece_wise_dissimilarity_per_track = [] + piece_wise_rank_difference_per_track = [] + for name, subdata in features_df.groupby(["fov_name", "track_id"]): + if len(subdata) > 1: + indices = subdata.index.values + single_track_dissimilarity = select_block(cross_dist, indices) + single_track_rank_fraction = select_block(rank_fractions, indices) + piece_wise_dissimilarity = compare_time_offset( + single_track_dissimilarity, time_offset=1 + ) + piece_wise_rank_difference = compare_time_offset( + single_track_rank_fraction, time_offset=1 + ) + piece_wise_dissimilarity_per_track.append(piece_wise_dissimilarity) + piece_wise_rank_difference_per_track.append(piece_wise_rank_difference) + return piece_wise_dissimilarity_per_track, piece_wise_rank_difference_per_track + + +def compute_embedding_distances( + prediction_path: Path, + output_path: Path, + distance_metric: Literal["cosine", "euclidean", "normalized_euclidean"] = "cosine", + verbose: bool = False, +) -> pd.DataFrame: + """ + Compute and save pairwise distances between embeddings. 
+ + Parameters + ---------- + prediction_path : Path + Path to the embedding dataset + output_path : Path + name of saved CSV file + distance_metric : str, optional + Distance metric to use for computing distances between embeddings + verbose : bool, optional + If True, plots the distance matrix visualization + + Returns + ------- + pd.DataFrame + DataFrame containing the adjacent frame and random sampling distances + """ + # Read the dataset + embeddings = read_embedding_dataset(prediction_path) + features = embeddings["features"] + + if distance_metric != "euclidean": + features = StandardScaler().fit_transform(features.values) + + # Compute the distance matrix + cross_dist = pairwise_distance_matrix(features, metric=distance_metric) + + # Normalize by sqrt of embedding dimension if using euclidean distance + if distance_metric == "euclidean": + cross_dist /= np.sqrt(features.shape[1]) + + if verbose: + # Plot the distance matrix + plt.figure(figsize=(10, 10)) + plt.imshow(cross_dist, cmap="viridis") + plt.colorbar(label=f"{distance_metric.capitalize()} Distance") + plt.title(f"{distance_metric.capitalize()} Distance Matrix") + plt.tight_layout() + plt.show() + + rank_fractions = rank_nearest_neighbors(cross_dist, normalize=True) + + # Compute piece-wise dissimilarity and rank difference + features_df = features["sample"].to_dataframe().reset_index(drop=True) + piece_wise_dissimilarity_per_track, piece_wise_rank_difference_per_track = ( + compute_piece_wise_dissimilarity(features_df, cross_dist, rank_fractions) + ) + + all_dissimilarity = np.concatenate(piece_wise_dissimilarity_per_track) + + # Random sampling values in the dissimilarity matrix + n_samples = len(all_dissimilarity) + random_indices = np.random.randint(0, len(cross_dist), size=(n_samples, 2)) + sampled_values = cross_dist[random_indices[:, 0], random_indices[:, 1]] + + # Create and save DataFrame + distributions_df = pd.DataFrame( + { + "adjacent_frame": pd.Series(all_dissimilarity), + "random_sampling": pd.Series(sampled_values), + } + ) + + csv_path = output_path + distributions_df.to_csv(csv_path, index=False) + + return distributions_df + + +def analyze_and_plot_distances( + distributions_df: pd.DataFrame, + output_file_path: Optional[str], + overwrite: bool = False, +) -> dict: + """ + Analyze distance distributions and create visualization plots. + + Parameters + ---------- + distributions_df : pd.DataFrame + DataFrame containing 'adjacent_frame' and 'random_sampling' columns + output_file_path : str, optional + Path to save the plot ideally with a .pdf extension. 
Uses `plt.savefig()` + overwrite : bool, default=False + If True, overwrites existing files + + Returns + ------- + dict + Dictionary containing computed metrics including means, standard deviations, + medians, peaks, and dynamic range of the distributions + """ + # Compute statistics + adjacent_dist = distributions_df["adjacent_frame"].values + random_dist = distributions_df["random_sampling"].values + + # Compute peaks + adjacent_peak = float(find_distribution_peak(adjacent_dist)) + random_peak = float(find_distribution_peak(random_dist)) + dynamic_range = float(random_peak - adjacent_peak) + + metrics = { + "dissimilarity_mean": float(np.mean(adjacent_dist)), + "dissimilarity_std": float(np.std(adjacent_dist)), + "dissimilarity_median": float(np.median(adjacent_dist)), + "dissimilarity_peak": adjacent_peak, + "dissimilarity_p99": float(np.percentile(adjacent_dist, 99)), + "dissimilarity_p1": float(np.percentile(adjacent_dist, 1)), + "random_mean": float(np.mean(random_dist)), + "random_std": float(np.std(random_dist)), + "random_median": float(np.median(random_dist)), + "random_peak": random_peak, + "dynamic_range": dynamic_range, + } + + # Create plot + fig = plt.figure() + sns.histplot( + data=distributions_df, + x="adjacent_frame", + bins=30, + kde=True, + color="cyan", + alpha=0.5, + stat="density", + ) + sns.histplot( + data=distributions_df, + x="random_sampling", + bins=30, + kde=True, + color="red", + alpha=0.5, + stat="density", + ) + plt.xlabel("Cosine Dissimilarity") + plt.ylabel("Density") + plt.axvline(x=adjacent_peak, color="cyan", linestyle="--", alpha=0.8) + plt.axvline(x=random_peak, color="red", linestyle="--", alpha=0.8) + plt.tight_layout() + plt.legend(["Adjacent Frame", "Random Sample", "Adjacent Peak", "Random Peak"]) + if output_file_path.exists() and not overwrite: + raise FileExistsError( + f"File {output_file_path} already exists and overwrite=False" + ) + fig.savefig(output_file_path, dpi=600) + plt.show() + + return metrics From 1212a5a80bb7174dabe1ea3bd6cf18769ed3bfd1 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Wed, 29 Jan 2025 18:12:32 -0800 Subject: [PATCH 32/38] modify the shuffling of indices and adding concatdatasets instead of randomsampler --- viscy/data/hcs.py | 9 +- viscy/data/tarrow.py | 120 +++++++++------- viscy/representation/timearrow.py | 220 +++++++++++++----------------- 3 files changed, 170 insertions(+), 179 deletions(-) diff --git a/viscy/data/hcs.py b/viscy/data/hcs.py index 88111cc37..2170ebf13 100644 --- a/viscy/data/hcs.py +++ b/viscy/data/hcs.py @@ -25,6 +25,7 @@ from torch.utils.data import DataLoader, Dataset from viscy.data.typing import ChannelMap, DictTransform, HCSStackIndex, NormMeta, Sample +from viscy.utils.engine_state import set_fit_global_state _logger = logging.getLogger("lightning.pytorch") @@ -426,12 +427,6 @@ def setup(self, stage: Literal["fit", "validate", "test", "predict"]): else: raise NotImplementedError(f"{stage} stage") - def _set_fit_global_state(self, num_positions: int) -> torch.Tensor: - # disable metadata tracking in MONAI for performance - set_track_meta(False) - # shuffle positions, randomness is handled globally - return torch.randperm(num_positions) - def _setup_fit(self, dataset_settings: dict): """Set up the training and validation datasets.""" train_transform, val_transform = self._fit_transform() @@ -441,7 +436,7 @@ def _setup_fit(self, dataset_settings: dict): # shuffle positions, randomness is handled globally positions = [pos for _, pos in plate.positions()] - shuffled_indices 
= self._set_fit_global_state(len(positions)) + shuffled_indices = set_fit_global_state(len(positions)) positions = list(positions[i] for i in shuffled_indices) num_train_fovs = int(len(positions) * self.split_ratio) # training set needs to sample more Z range for augmentation diff --git a/viscy/data/tarrow.py b/viscy/data/tarrow.py index df443581b..d6cf423c6 100644 --- a/viscy/data/tarrow.py +++ b/viscy/data/tarrow.py @@ -1,11 +1,13 @@ from pathlib import Path +from typing import Callable import numpy as np -import torch from iohub.ngff import Position, open_ome_zarr from lightning.pytorch import LightningDataModule from tarrow.data.tarrow_dataset import TarrowDataset -from torch.utils.data import DataLoader +from torch.utils.data import ConcatDataset, DataLoader + +from viscy.utils.engine_state import set_fit_global_state class TarrowDataModule(LightningDataModule): @@ -19,6 +21,8 @@ class TarrowDataModule(LightningDataModule): Name of the channel to load train_split : float, default=0.8 Fraction of data to use for training (0.0 to 1.0) + patch_size : tuple[int, int], default=(128, 128) + Patch size for TarrowDataset batch_size : int, default=16 Batch size for dataloaders num_workers : int, default=8 @@ -33,6 +37,8 @@ class TarrowDataModule(LightningDataModule): Number of validation samples per epoch resolution : int, default=0 Resolution level to load from OME-Zarr + normalization : function, optional (default=None) + Normalization function to apply to images z_slice : int, default=0 Z-slice to load pin_memory : bool, default=True @@ -50,12 +56,14 @@ def __init__( train_split: float = 0.8, batch_size: int = 16, num_workers: int = 8, + patch_size: tuple[int, int] = (128, 128), prefetch_factor: int | None = None, include_fov_names: list[str] = [], train_samples_per_epoch: int = 100000, val_samples_per_epoch: int = 10000, resolution: int = 0, z_slice: int = 0, + normalization: Callable[[np.ndarray], np.ndarray] | None = None, pin_memory: bool = True, persistent_workers: bool = True, **kwargs, @@ -67,12 +75,17 @@ def __init__( self.batch_size = batch_size self.num_workers = num_workers self.prefetch_factor = prefetch_factor + self.path_size = patch_size self.include_fov_names = include_fov_names self.train_samples_per_epoch = train_samples_per_epoch self.val_samples_per_epoch = val_samples_per_epoch self.resolution = resolution self.z_slice = z_slice self.kwargs = kwargs + self.normalization = normalization + + self._filter_positions() + self._channel_idx = self._get_channel_index() def _get_channel_index(self, plate) -> int: """Get the index of the specified channel from the plate metadata. @@ -102,15 +115,13 @@ def _get_channel_index(self, plate) -> int: f"Channel '{self.channel_name}' not found. Available channels: {available_channels}" ) - def _load_images( - self, positions: list[Position], channel_idx: int - ) -> list[np.ndarray]: + def _load_images(self, position: Position, channel_idx: int) -> list[np.ndarray]: """Load all images from positions into memory. 
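+
+        All timepoints of the given position are read for the selected
+        channel and z-slice, one 2D array per timepoint.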
Parameters ---------- - positions : list[Position] - List of positions to load + position : Position + Position to load channel_idx : int Index of channel to load @@ -120,11 +131,10 @@ def _load_images( List of 2D numpy arrays """ imgs = [] - for pos in positions: - img_arr = pos[str(self.resolution)] - # Load all timepoints for this position - for t in range(len(img_arr)): - imgs.append(img_arr[t, channel_idx, self.z_slice]) + img_arr = position[str(self.resolution)] + # Load all timepoints for this position + for t in range(len(img_arr)): + imgs.append(img_arr[t, channel_idx, self.z_slice]) return imgs def setup(self, stage: str): @@ -140,41 +150,34 @@ def setup(self, stage: str): NotImplementedError If stage is not "fit" """ - plate = open_ome_zarr(self.ome_zarr_path, mode="r") # Get channel index once - channel_idx = self._get_channel_index(plate) - # Get the positions to load - if self.include_fov_names: - positions = [] - for fov_str, pos in plate.positions(): - normalized_include_fovs = [ - f.lstrip("/") for f in self.include_fov_names - ] - if fov_str in normalized_include_fovs: - positions.append(pos) - else: - positions = [pos for _, pos in plate.positions()] + if stage == "fit": + list_dataset = [] + for pos in self.positions: + pos_imgs = self._load_images(pos, self._channel_idx) + list_dataset.append( + TarrowDataset( + imgs=pos_imgs, + normalize=self.normalization, + size=self.path_size, + **self.kwargs, + ) + ) - # Load all images into memory using the pre-determined channel index - imgs = self._load_images(positions, channel_idx) + # Calculate split point + split_idx = int(len(self.positions) * self.train_split) - # Calculate split point - split_idx = int(len(imgs) * self.train_split) + # Shuffle the list of datasets + shuffled_indices = set_fit_global_state(len(list_dataset)) + list_dataset = [list_dataset[i] for i in shuffled_indices] - if stage == "fit": # Create training dataset with first train_split% of images - self.train_dataset = TarrowDataset( - imgs=imgs[:split_idx], - **self.kwargs, - ) + self.train_dataset = ConcatDataset(list_dataset[:split_idx]) # Create validation dataset with remaining images - self.val_dataset = TarrowDataset( - imgs=imgs[split_idx:], - **{k: v for k, v in self.kwargs.items() if k != "augmenter"}, - ) + self.val_dataset = ConcatDataset(list_dataset[split_idx:]) elif stage == "test": raise NotImplementedError(f"Invalid stage: {stage}") @@ -183,25 +186,45 @@ def setup(self, stage: str): else: raise NotImplementedError(f"Invalid stage: {stage}") + def _filter_positions(self): + """Filter positions based on include_fov_names.""" + # Get the positions to load + plate = open_ome_zarr(self.ome_zarr_path, mode="r") + if self.include_fov_names: + positions = [] + for fov_str, pos in plate.positions(): + normalized_include_fovs = [ + f.lstrip("/") for f in self.include_fov_names + ] + if fov_str in normalized_include_fovs: + positions.append(pos) + else: + positions = [pos for _, pos in plate.positions()] + + self.positions = positions + + def _get_channel_index(self): + """Get the index of the specified channel from the plate metadata.""" + with open_ome_zarr(self.ome_zarr_path, mode="r") as plate: + _, first_pos = next(plate.positions()) + return first_pos.channel_names.index(self.channel_name) + def train_dataloader(self): """Create the training dataloader. 
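+
+        Samples are drawn with ``shuffle=True`` from the concatenated
+        per-position datasets.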
Returns ------- torch.utils.data.DataLoader - DataLoader for training data with random sampling + DataLoader for training data """ return DataLoader( self.train_dataset, - sampler=torch.utils.data.RandomSampler( - self.train_dataset, - replacement=True, - num_samples=self.train_samples_per_epoch, - ), batch_size=self.batch_size, num_workers=self.num_workers, persistent_workers=True if self.num_workers > 0 else False, prefetch_factor=self.prefetch_factor if self.num_workers else None, + pin_memory=True, + shuffle=True, ) def val_dataloader(self): @@ -210,19 +233,16 @@ def val_dataloader(self): Returns ------- torch.utils.data.DataLoader - DataLoader for validation data with random sampling + DataLoader for validation data """ return DataLoader( self.val_dataset, - sampler=torch.utils.data.RandomSampler( - self.val_dataset, - replacement=True, - num_samples=self.val_samples_per_epoch, - ), batch_size=self.batch_size, num_workers=self.num_workers, persistent_workers=True if self.num_workers > 0 else False, prefetch_factor=self.prefetch_factor if self.num_workers else None, + pin_memory=True, + shuffle=False, ) def test_dataloader(self): diff --git a/viscy/representation/timearrow.py b/viscy/representation/timearrow.py index 6333a3075..077efe143 100644 --- a/viscy/representation/timearrow.py +++ b/viscy/representation/timearrow.py @@ -1,13 +1,18 @@ +import logging +from typing import Literal, Sequence + +import numpy as np import torch import torch.nn as nn -import torchvision from lightning.pytorch import LightningModule -from lightning.pytorch.callbacks import Callback from tarrow.models import TimeArrowNet from tarrow.models.losses import DecorrelationLoss from torch.optim import Adam from torch.optim.lr_scheduler import CyclicLR, ReduceLROnPlateau -from torch.utils.data import DataLoader + +from viscy.utils.log_images import render_images + +logger = logging.getLogger(__name__) class TarrowModule(LightningModule): @@ -41,6 +46,10 @@ class TarrowModule(LightningModule): Patience for learning rate scheduler cam_size : tuple or int, optional Size of the class activation map (H, W). If None, use input size. + log_batches_per_epoch : int, default=8 + Number of batches to log samples from during training + log_samples_per_batch : int, default=1 + Number of samples to log from each batch """ def __init__( @@ -58,11 +67,18 @@ def __init__( lr_scheduler="cyclic", lr_patience=50, cam_size=None, + log_batches_per_epoch=8, + log_samples_per_batch=1, **kwargs, ): super().__init__() self.save_hyperparameters() + self.log_batches_per_epoch = log_batches_per_epoch + self.log_samples_per_batch = log_samples_per_batch + self.training_step_outputs = [] + self.validation_step_outputs = [] + self.model = TimeArrowNet( backbone=backbone, projection_head=projection_head, @@ -76,6 +92,48 @@ def __init__( self.criterion = nn.CrossEntropyLoss(reduction="none") self.criterion_decorr = DecorrelationLoss() + def _log_samples(self, key: str, imgs: Sequence[Sequence[np.ndarray]]): + """Log sample images to TensorBoard. + + Parameters + ---------- + key : str + Key for logging + imgs : Sequence[Sequence[np.ndarray]] + List of image pairs to log + """ + grid = render_images(imgs, cmaps=["gray"] * 2) # Only 2 timepoints + self.logger.experiment.add_image( + key, grid, self.current_epoch, dataformats="HWC" + ) + + def _log_step_samples(self, batch_idx, images, stage: Literal["train", "val"]): + """Log samples from a batch. 
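+
+        Only the first ``log_batches_per_epoch`` batches per epoch are logged;
+        each logged sample is stored as a pair of consecutive timepoints.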
+ + Parameters + ---------- + batch_idx : int + Index of current batch + images : torch.Tensor + Batch of images with shape (B, T, C, H, W) + stage : str + Either "train" or "val" + """ + if batch_idx < self.log_batches_per_epoch: + # Get first n samples from batch + n = min(self.log_samples_per_batch, images.shape[0]) + samples = images[:n].detach().cpu().numpy() + + # Split into pairs of timepoints + pairs = [(sample[0], sample[1]) for sample in samples] + + output_list = ( + self.training_step_outputs + if stage == "train" + else self.validation_step_outputs + ) + output_list.extend(pairs) + def forward(self, x): """Forward pass through the model. @@ -111,6 +169,10 @@ def _shared_step(self, batch, batch_idx, step="train"): Combined loss (classification + decorrelation) """ x, y = batch + + # Log sample images + self._log_step_samples(batch_idx, x, step) + out, pro = self(x) if out.ndim > 2: @@ -170,134 +232,48 @@ def configure_optimizers(self): "optimizer": optimizer, "lr_scheduler": { "scheduler": scheduler, - "monitor": "val_loss", + "monitor": "loss/val_loss", + "interval": "epoch", }, } elif self.hparams.lr_scheduler == "cyclic": + # Get dataloader length accounting for DDP + dataloader = self.trainer.datamodule.train_dataloader() + steps_per_epoch = len(dataloader) + + # Account for gradient accumulation and multiple GPUs + if self.trainer.accumulate_grad_batches: + steps_per_epoch = ( + steps_per_epoch // self.trainer.accumulate_grad_batches + ) + + total_steps = steps_per_epoch * self.trainer.max_epochs + scheduler = CyclicLR( optimizer, base_lr=self.hparams.learning_rate, max_lr=self.hparams.learning_rate * 10, cycle_momentum=False, - step_size_up=self.trainer.estimated_stepping_batches, + step_size_up=total_steps // 2, # Half the total steps for one cycle scale_mode="cycle", scale_fn=lambda x: 0.9**x, ) - return {"optimizer": optimizer, "lr_scheduler": scheduler} - - def embedding(self, x): - """Get dense embeddings from the model. - - Parameters - ---------- - x : torch.Tensor - Input tensor of shape (batch_size, n_frames, channels, height, width) - - Returns - ------- - torch.Tensor - Dense embeddings from the backbone network - """ - return self.model.embedding(x) - - -class TarrowVisualizationCallback(Callback): - """Callback for visualizing cells and embeddings in TensorBoard. - - Parameters - ---------- - dataset : Dataset - Dataset to visualize - max_samples : int, default=100 - Maximum number of samples to visualize - log_every_n_epochs : int, default=3 - How often to log visualizations - cam_size : tuple or int, optional - Size for class activation maps. If None, use original size - """ - - def __init__(self, dataset, max_samples=100, log_every_n_epochs=3, cam_size=None): - """ - Parameters - ---------- - dataset : Dataset - Dataset to visualize - max_samples : int, default=100 - Maximum number of samples to visualize - log_every_n_epochs : int, default=3 - How often to log visualizations - cam_size : tuple or int, optional - Size for class activation maps. 
If None, use original size - """ - super().__init__() - self.dataset = dataset - self.max_samples = max_samples - self.log_every_n_epochs = log_every_n_epochs - self.cam_size = cam_size - - def on_train_epoch_end(self, trainer, pl_module): - if (trainer.current_epoch + 1) % self.log_every_n_epochs == 0: - # Get samples from dataset - loader = DataLoader( - self.dataset, - batch_size=min(32, self.max_samples), - shuffle=True, - ) - batch = next(iter(loader)) - images, labels = batch - images = images.to(pl_module.device) - - # Get embeddings - with torch.no_grad(): - embeddings = pl_module.embedding(images) - out, _ = pl_module(images) - preds = torch.argmax(out, dim=1) - - # Log images - grid = torchvision.utils.make_grid( - images[:, 0], # First timepoint - nrow=8, - normalize=True, - value_range=(images.min(), images.max()), - ) - trainer.logger.experiment.add_image( - "cells/timepoint1", - grid, - trainer.current_epoch, - ) - - grid = torchvision.utils.make_grid( - images[:, 1], # Second timepoint - nrow=8, - normalize=True, - value_range=(images.min(), images.max()), - ) - trainer.logger.experiment.add_image( - "cells/timepoint2", - grid, - trainer.current_epoch, - ) - - # Log embeddings - trainer.logger.experiment.add_embedding( - embeddings.reshape(len(embeddings), -1), - metadata=[ - f"label={l.item()}, pred={p.item()}" for l, p in zip(labels, preds) - ], - label_img=images[:, 0], # Use first timepoint as label image - global_step=trainer.current_epoch, - ) + return { + "optimizer": optimizer, + "lr_scheduler": { + "scheduler": scheduler, + "interval": "step", + }, + } - # Log CAMs if cam_size is provided - if self.cam_size is not None and hasattr(pl_module.model, "get_cam"): - cam = pl_module.model.get_cam(images, size=self.cam_size) - grid = torchvision.utils.make_grid( - cam.unsqueeze(1), # Add channel dimension - nrow=8, - normalize=True, - ) - trainer.logger.experiment.add_image( - "cells/cam", - grid, - trainer.current_epoch, - ) + def on_train_epoch_end(self): + """Log collected training samples at end of epoch.""" + if self.training_step_outputs: + self._log_samples("train_samples", self.training_step_outputs) + self.training_step_outputs = [] + + def on_validation_epoch_end(self): + """Log collected validation samples at end of epoch.""" + if self.validation_step_outputs: + self._log_samples("val_samples", self.validation_step_outputs) + self.validation_step_outputs = [] From 4d3e46d23965eff304bb4e3fef7c3ca1358461ba Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Thu, 30 Jan 2025 12:29:41 -0800 Subject: [PATCH 33/38] adding gradcam prototype --- tests/representation/test_gradcam.py | 181 +++++++++++++++++++++++++++ viscy/callbacks/gradcam.py | 103 +++++++++++++++ 2 files changed, 284 insertions(+) create mode 100644 tests/representation/test_gradcam.py create mode 100644 viscy/callbacks/gradcam.py diff --git a/tests/representation/test_gradcam.py b/tests/representation/test_gradcam.py new file mode 100644 index 000000000..bcad19c16 --- /dev/null +++ b/tests/representation/test_gradcam.py @@ -0,0 +1,181 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision +import torchvision.transforms as transforms +from lightning.pytorch import LightningModule, Trainer +from torch.utils.data import DataLoader +from lightning.pytorch.loggers import TensorBoardLogger + +from viscy.callbacks.gradcam import GradCAMCallback + + +class ResNetClassifier(LightningModule): + def __init__(self, num_classes=10): + super().__init__() + # Load 
pretrained ResNet18 + self.model = torchvision.models.resnet18(pretrained=True) + + # Replace final layer for CIFAR-10 + self.model.fc = nn.Linear(512, num_classes) + + # Save the target layer for GradCAM + self.target_layer = self.model.layer4[-1] + + # Ensure gradients are enabled for the target layer + for param in self.target_layer.parameters(): + param.requires_grad = True + + self.gradients = None + self.activations = None + + def forward(self, x): + return self.model(x) + + def training_step(self, batch, batch_idx): + x, y = batch + logits = self(x) + loss = F.cross_entropy(logits, y) + self.log("train_loss", loss) + return loss + + def validation_step(self, batch, batch_idx): + x, y = batch + logits = self(x) + loss = F.cross_entropy(logits, y) + self.log("val_loss", loss) + return loss + + def configure_optimizers(self): + return torch.optim.Adam(self.parameters(), lr=0.001) + + # GradCAM methods + def activations_hook(self, grad): + self.gradients = grad + + def get_activations(self, x): + return self.activations + + def gradcam(self, x): + # Store original training mode and switch to eval mode + was_training = self.training + self.eval() # Use eval mode for inference + + try: + # Register hooks + h = self.target_layer.register_forward_hook( + lambda module, input, output: setattr(self, "activations", output) + ) + h_bp = self.target_layer.register_backward_hook( + lambda module, grad_in, grad_out: self.activations_hook(grad_out[0]) + ) + + # Forward pass + x = x.unsqueeze(0).to(self.device) # Add batch dimension + + # Enable gradients for the entire computation + with torch.enable_grad(): + x = x.requires_grad_(True) + output = self(x) + + # Get predicted class + pred = output.argmax(dim=1) + + # Create one hot vector for backward pass + one_hot = torch.zeros_like(output, device=self.device) + one_hot[0][pred] = 1 + + # Clear gradients + self.zero_grad(set_to_none=False) + + # Backward pass + output.backward(gradient=one_hot) + + # Generate GradCAM + gradients = self.gradients + activations = self.activations + + # Ensure we have valid gradients + if gradients is None: + raise RuntimeError("No gradients available for GradCAM computation") + + weights = torch.mean(gradients, dim=(2, 3)) + cam = torch.sum(weights[:, :, None, None] * activations, dim=1) + cam = F.relu(cam) + cam = ( + F.interpolate( + cam.unsqueeze(0), + size=x.shape[2:], + mode="bilinear", + align_corners=False, + )[0, 0] + .cpu() + .detach() + .numpy() + ) + + return cam + + finally: + # Clean up + h.remove() + h_bp.remove() + # Restore original training mode + self.train(mode=was_training) + + +def main(): + # Data transforms + transform = transforms.Compose( + [ + transforms.Resize(224), # ResNet expects 224x224 images + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ] + ) + + # Load CIFAR-10 dataset + train_dataset = torchvision.datasets.CIFAR10( + root="./data", train=True, download=True, transform=transform + ) + + val_dataset = torchvision.datasets.CIFAR10( + root="./data", train=False, download=True, transform=transform + ) + + # Create small visualization dataset + vis_dataset = torch.utils.data.Subset(val_dataset, indices=range(10)) + + # Create data loaders + train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True) + val_loader = DataLoader(val_dataset, batch_size=32) + + # Initialize model + model = ResNetClassifier() + + # Initialize callbacks + gradcam_callback = GradCAMCallback( + visual_datasets=[vis_dataset], + every_n_epochs=1, 
# Generate visualizations every epoch + max_samples=5, # Visualize 5 samples + max_height=224, # Match ResNet input size + ) + + # Initialize trainer with specific logger + trainer = Trainer( + max_epochs=5, + callbacks=[gradcam_callback], + accelerator="auto", + devices=1, + logger=TensorBoardLogger( + save_dir="your/log/path", # specify your desired log directory + name="gradcam_experiment", # experiment name + ), + ) + + # Train model + trainer.fit(model, train_loader, val_loader) + + +if __name__ == "__main__": + main() diff --git a/viscy/callbacks/gradcam.py b/viscy/callbacks/gradcam.py new file mode 100644 index 000000000..589662474 --- /dev/null +++ b/viscy/callbacks/gradcam.py @@ -0,0 +1,103 @@ +from typing import List + +import torch +import torchvision +from lightning.pytorch import LightningModule, Trainer +from lightning.pytorch.callbacks import Callback + + +class GradCAMCallback(Callback): + """Callback for computing and logging GradCAM visualizations. + + Parameters + ---------- + visual_datasets : list + List of datasets to generate visualizations from + every_n_epochs : int, default=10 + Generate visualizations every n epochs + max_samples : int, default=5 + Maximum number of samples to visualize per dataset + max_height : int, default=720 + Maximum height of output visualization + """ + + def __init__( + self, + visual_datasets: List, + every_n_epochs: int = 10, + max_samples: int = 5, + max_height: int = 720, + ): + super().__init__() + self.visual_datasets = visual_datasets + self.every_n_epochs = every_n_epochs + self.max_samples = max_samples + self.max_height = max_height + + def on_validation_epoch_end( + self, trainer: Trainer, pl_module: LightningModule + ) -> None: + """Generate and log GradCAM visualizations""" + if (trainer.current_epoch + 1) % self.every_n_epochs != 0: + return + + pl_module.eval() + + for dataset_idx, dataset in enumerate(self.visual_datasets): + # Get a few samples + samples = [] + cams = [] + + for i, (x, _) in enumerate(dataset): + if i >= self.max_samples: + break + + # Move tensor to same device as model + x = x.to(pl_module.device) + + # Generate GradCAM - no need to add batch dimension here since gradcam() does it + cam = pl_module.gradcam(x) + + # Convert to RGB images for visualization + x_img = self._tensor_to_img( + x.cpu() + ) # removed [0] since x is already unbatched + cam_img = self._tensor_to_img(torch.from_numpy(cam)) + overlay = self._create_overlay(x_img, cam_img) + + samples.append(x_img) + cams.append(overlay) + + # Stack images for grid visualization + samples_grid = torchvision.utils.make_grid( + [torch.from_numpy(img) for img in samples], nrow=len(samples) + ) + cams_grid = torchvision.utils.make_grid( + [torch.from_numpy(img) for img in cams], nrow=len(cams) + ) + + # Log to tensorboard + trainer.logger.experiment.add_image( + f"gradcam/dataset_{dataset_idx}/samples", + samples_grid, + trainer.current_epoch, + ) + trainer.logger.experiment.add_image( + f"gradcam/dataset_{dataset_idx}/cams", + cams_grid, + trainer.current_epoch, + ) + + @staticmethod + def _tensor_to_img(tensor: torch.Tensor) -> torch.Tensor: + """Convert tensor to normalized image tensor""" + img = tensor.cpu().numpy() + img = (img - img.min()) / (img.max() - img.min() + 1e-7) + return img + + @staticmethod + def _create_overlay( + img: torch.Tensor, cam: torch.Tensor, alpha: float = 0.5 + ) -> torch.Tensor: + """Create overlay of image and CAM""" + return (1 - alpha) * img + alpha * cam From b56c6fd6258e06296be6a18864aa237c0eddff50 Mon Sep 17 
00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Thu, 30 Jan 2025 17:30:57 -0800 Subject: [PATCH 34/38] adding custom gradcam to tarrow --- tests/representation/test_gradcam.py | 70 ++++++++++++++- viscy/callbacks/gradcam.py | 122 +++++++++++++++++++-------- viscy/representation/timearrow.py | 110 ++++++++++++++++++++++++ 3 files changed, 264 insertions(+), 38 deletions(-) diff --git a/tests/representation/test_gradcam.py b/tests/representation/test_gradcam.py index bcad19c16..ec919a187 100644 --- a/tests/representation/test_gradcam.py +++ b/tests/representation/test_gradcam.py @@ -6,6 +6,8 @@ from lightning.pytorch import LightningModule, Trainer from torch.utils.data import DataLoader from lightning.pytorch.loggers import TensorBoardLogger +import matplotlib.pyplot as plt +import numpy as np from viscy.callbacks.gradcam import GradCAMCallback @@ -149,6 +151,7 @@ def main(): # Create data loaders train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True) val_loader = DataLoader(val_dataset, batch_size=32) + vis_loader = DataLoader(vis_dataset, batch_size=32) # Added visualization loader # Initialize model model = ResNetClassifier() @@ -168,14 +171,77 @@ def main(): accelerator="auto", devices=1, logger=TensorBoardLogger( - save_dir="your/log/path", # specify your desired log directory - name="gradcam_experiment", # experiment name + save_dir="/home/eduardo.hirata/repos/viscy/tests/representation/lightning_logs", # specify your desired log directory + name="gradcam_cifar", # experiment name ), ) # Train model trainer.fit(model, train_loader, val_loader) + # Test GradCAM visualization + test_gradcam_visualization(model, vis_loader) + + +def test_gradcam_visualization(model, dataloader): + """Test GradCAM visualization. + + Parameters + ---------- + model : LightningModule + The trained model + dataloader : DataLoader + DataLoader containing samples to visualize + """ + model.eval() + # Get a sample from validation set + batch = next(iter(dataloader)) + images, labels = batch + + # Generate GradCAM for first sample + sample_img = images[0] # Shape: (C, H, W) + cam = model.gradcam(sample_img) + + # Plot the results + fig, axes = plt.subplots(1, 3, figsize=(30, 10)) + + # Original image + img = images[0].squeeze().cpu().numpy() + if img.ndim == 3: # Handle RGB images + axes[0].imshow(np.transpose(img, (1, 2, 0))) + else: # Handle grayscale images + axes[0].imshow(img, cmap="gray") + axes[0].set_title("Original Image") + plt.colorbar(axes[0].images[0], ax=axes[0]) + + # GradCAM visualization + im = axes[1].imshow(cam, cmap="magma") + axes[1].set_title("GradCAM") + plt.colorbar(im, ax=axes[1]) + + # Overlay GradCAM on original image + img = images[0].squeeze().cpu().numpy() + if img.ndim == 3: # Handle RGB images + img = np.transpose(img, (1, 2, 0)) + img = (img - img.min()) / (img.max() - img.min()) # Normalize to [0,1] + cam_norm = (cam - cam.min()) / (cam.max() - cam.min()) # Normalize to [0,1] + + # Create RGB overlay + if img.ndim == 2: # Convert grayscale to RGB + img_rgb = np.stack([img] * 3, axis=-1) + else: # Already RGB + img_rgb = img + cam_rgb = plt.cm.magma(cam_norm)[..., :3] # Convert to RGB using magma colormap + overlay = 0.7 * img_rgb + 0.3 * cam_rgb + + axes[2].imshow(overlay) + axes[2].set_title("GradCAM Overlay") + + plt.suptitle(f"GradCAM Visualization (Predicted: {labels[0].item()})", y=1.05) + plt.savefig("./gradcam_cifar.png") + plt.close() + # plt.show() + if __name__ == "__main__": main() diff --git a/viscy/callbacks/gradcam.py b/viscy/callbacks/gradcam.py 
index 589662474..da23f19f5 100644 --- a/viscy/callbacks/gradcam.py +++ b/viscy/callbacks/gradcam.py @@ -4,6 +4,9 @@ import torchvision from lightning.pytorch import LightningModule, Trainer from lightning.pytorch.callbacks import Callback +import numpy as np +import matplotlib.pyplot as plt +from skimage.exposure import rescale_intensity class GradCAMCallback(Callback): @@ -19,6 +22,9 @@ class GradCAMCallback(Callback): Maximum number of samples to visualize per dataset max_height : int, default=720 Maximum height of output visualization + mode : str, default="separate" + Visualization mode: "separate" for individual images and activations, + or "overlay" for activation map overlaid on input image """ def __init__( @@ -27,12 +33,15 @@ def __init__( every_n_epochs: int = 10, max_samples: int = 5, max_height: int = 720, + mode: str = "overlay", ): super().__init__() self.visual_datasets = visual_datasets self.every_n_epochs = every_n_epochs self.max_samples = max_samples self.max_height = max_height + assert mode in ["separate", "overlay"], "Mode must be 'separate' or 'overlay'" + self.mode = mode def on_validation_epoch_end( self, trainer: Trainer, pl_module: LightningModule @@ -42,51 +51,87 @@ def on_validation_epoch_end( return pl_module.eval() + device = pl_module.device for dataset_idx, dataset in enumerate(self.visual_datasets): # Get a few samples samples = [] - cams = [] + activations = [] for i, (x, _) in enumerate(dataset): if i >= self.max_samples: break - # Move tensor to same device as model - x = x.to(pl_module.device) - - # Generate GradCAM - no need to add batch dimension here since gradcam() does it - cam = pl_module.gradcam(x) - - # Convert to RGB images for visualization - x_img = self._tensor_to_img( - x.cpu() - ) # removed [0] since x is already unbatched - cam_img = self._tensor_to_img(torch.from_numpy(cam)) - overlay = self._create_overlay(x_img, cam_img) - - samples.append(x_img) - cams.append(overlay) - - # Stack images for grid visualization - samples_grid = torchvision.utils.make_grid( - [torch.from_numpy(img) for img in samples], nrow=len(samples) - ) - cams_grid = torchvision.utils.make_grid( - [torch.from_numpy(img) for img in cams], nrow=len(cams) - ) - - # Log to tensorboard - trainer.logger.experiment.add_image( - f"gradcam/dataset_{dataset_idx}/samples", - samples_grid, - trainer.current_epoch, - ) - trainer.logger.experiment.add_image( - f"gradcam/dataset_{dataset_idx}/cams", - cams_grid, - trainer.current_epoch, - ) + try: + # Move tensor to same device as model + x = x.to(device) + + # Generate class activation map + activation_map = pl_module.gradcam(x) + + # Convert to RGB images for visualization + x_img = x[0].cpu().numpy() # Take first timepoint + if x_img.ndim == 3: # Handle (C,H,W) case + x_img = x_img[0] # Take first channel + x_img = rescale_intensity(x_img, in_range="image", out_range=(0, 1)) + + # Create activation map visualization + activation_norm = self._normalize_cam( + torch.from_numpy(activation_map) + ) + activation_rgb = plt.cm.magma(activation_norm.numpy())[..., :3] + + if self.mode == "separate": + # Keep sample as grayscale + x_vis = torch.from_numpy(x_img).unsqueeze(0).float() + activation_vis = ( + torch.from_numpy(activation_rgb).permute(2, 0, 1).float() + ) + else: # overlay mode + # Convert input to RGB + x_rgb = np.stack([x_img] * 3, axis=-1) + # Create overlay + overlay = self._create_overlay(x_rgb, activation_rgb) + x_vis = torch.from_numpy(x_rgb).permute(2, 0, 1).float() + activation_vis = ( + 
torch.from_numpy(overlay).permute(2, 0, 1).float() + ) + + samples.append(x_vis.cpu()) # Ensure on CPU for visualization + activations.append( + activation_vis.cpu() + ) # Ensure on CPU for visualization + + except Exception as e: + print(f"Error processing sample {i}: {str(e)}") + continue + + if samples: # Only proceed if we have samples + try: + # Stack images for grid visualization + samples_grid = torchvision.utils.make_grid( + samples, nrow=len(samples), normalize=True, value_range=(0, 1) + ) + activations_grid = torchvision.utils.make_grid( + activations, + nrow=len(activations), + normalize=True, + value_range=(0, 1), + ) + + # Log to tensorboard + trainer.logger.experiment.add_image( + f"gradcam/dataset_{dataset_idx}/samples", + samples_grid, + trainer.current_epoch, + ) + trainer.logger.experiment.add_image( + f"gradcam/dataset_{dataset_idx}/{'overlays' if self.mode == 'overlay' else 'activations'}", + activations_grid, + trainer.current_epoch, + ) + except Exception as e: + print(f"Error creating visualization grid: {str(e)}") @staticmethod def _tensor_to_img(tensor: torch.Tensor) -> torch.Tensor: @@ -101,3 +146,8 @@ def _create_overlay( ) -> torch.Tensor: """Create overlay of image and CAM""" return (1 - alpha) * img + alpha * cam + + @staticmethod + def _normalize_cam(cam: torch.Tensor) -> torch.Tensor: + """Normalize CAM to [0,1]""" + return (cam - cam.min()) / (cam.max() - cam.min() + 1e-8) diff --git a/viscy/representation/timearrow.py b/viscy/representation/timearrow.py index 077efe143..e61af2651 100644 --- a/viscy/representation/timearrow.py +++ b/viscy/representation/timearrow.py @@ -277,3 +277,113 @@ def on_validation_epoch_end(self): if self.validation_step_outputs: self._log_samples("val_samples", self.validation_step_outputs) self.validation_step_outputs = [] + + def gradcam(self, x, **kwargs): + """Generate GradCAM visualization for the projection layer. 
+ + Parameters + ---------- + x : torch.Tensor + Input tensor of shape (T, C, H, W) + **kwargs : dict + Additional arguments passed to model's gradcam method + + Returns + ------- + numpy.ndarray + GradCAM visualization + """ + # Store training mode and switch to eval + was_training = self.training + self.eval() + + # Store gradients and activations + self.gradients = None + self.activations = None + + # Register hooks + def save_gradients(grad): + self.gradients = grad + + def save_activations(module, input, output): + self.activations = output + + # Get the target layer (last conv layer of backbone) + target_layer = None + for module in self.model.backbone.modules(): + if isinstance(module, nn.Conv2d): + target_layer = module + + if target_layer is None: + raise RuntimeError("Could not find suitable layer for GradCAM") + + # Register hooks + h = target_layer.register_forward_hook(save_activations) + h_bp = target_layer.register_backward_hook( + lambda m, grad_in, grad_out: save_gradients(grad_out[0]) + ) + + try: + # Add batch dimension if needed + if x.ndim == 4: + x = x.unsqueeze(0) + + x = x.to(self.device) + + # Enable gradients for computation + with torch.enable_grad(): + x = x.requires_grad_(True) + + # Forward pass through model + output = self.model(x, mode="both") + if isinstance(output, tuple): + output = output[0] # Get classification output + + # Get predicted class (or use class 0 for binary case) + if output.ndim > 2: + # Handle spatial outputs by averaging + output = torch.mean(output, tuple(range(2, output.ndim))) + pred = output.argmax(dim=1) + + # Create one hot vector for backward pass + one_hot = torch.zeros_like(output, device=self.device) + one_hot[0][pred] = 1 + + # Clear gradients + self.zero_grad(set_to_none=False) + + # Backward pass + output.backward(gradient=one_hot) + + # Ensure we have valid gradients and activations + if self.gradients is None or self.activations is None: + raise RuntimeError( + "No gradients or activations available for GradCAM computation" + ) + + # Calculate weights and generate CAM + weights = torch.mean(self.gradients, dim=(2, 3)) + cam = torch.sum(weights[:, :, None, None] * self.activations, dim=1) + cam = torch.relu(cam) + + # Interpolate CAM to input size + cam = ( + torch.nn.functional.interpolate( + cam.unsqueeze(0), + size=x.shape[-2:], + mode="bilinear", + align_corners=False, + )[0, 0] + .cpu() + .detach() + .numpy() + ) + + return cam + + finally: + # Clean up + h.remove() + h_bp.remove() + # Restore original training mode + self.train(mode=was_training) From 0fef4e192a3ce2994e10379210b0425a8a274408 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Thu, 30 Jan 2025 18:51:26 -0800 Subject: [PATCH 35/38] modifying gradcam so it's compatible with the config file --- viscy/callbacks/gradcam.py | 178 ++++++++++++++++-------------- viscy/data/tarrow.py | 58 ++++++++-- viscy/representation/timearrow.py | 1 - 3 files changed, 145 insertions(+), 92 deletions(-) diff --git a/viscy/callbacks/gradcam.py b/viscy/callbacks/gradcam.py index da23f19f5..f70ac9c33 100644 --- a/viscy/callbacks/gradcam.py +++ b/viscy/callbacks/gradcam.py @@ -1,3 +1,4 @@ +import logging from typing import List import torch @@ -8,35 +9,33 @@ import matplotlib.pyplot as plt from skimage.exposure import rescale_intensity +logger = logging.getLogger(__name__) + class GradCAMCallback(Callback): """Callback for computing and logging GradCAM visualizations. 
Parameters ---------- - visual_datasets : list - List of datasets to generate visualizations from every_n_epochs : int, default=10 Generate visualizations every n epochs max_samples : int, default=5 Maximum number of samples to visualize per dataset max_height : int, default=720 Maximum height of output visualization - mode : str, default="separate" + mode : str, default="overlay" Visualization mode: "separate" for individual images and activations, or "overlay" for activation map overlaid on input image """ def __init__( self, - visual_datasets: List, every_n_epochs: int = 10, max_samples: int = 5, max_height: int = 720, mode: str = "overlay", ): super().__init__() - self.visual_datasets = visual_datasets self.every_n_epochs = every_n_epochs self.max_samples = max_samples self.max_height = max_height @@ -50,88 +49,99 @@ def on_validation_epoch_end( if (trainer.current_epoch + 1) % self.every_n_epochs != 0: return + if not hasattr(trainer.datamodule, "visual_dataloader"): + logger.warning( + "DataModule does not have visual_dataloader method. Skipping GradCAM visualization." + ) + return + pl_module.eval() device = pl_module.device - for dataset_idx, dataset in enumerate(self.visual_datasets): - # Get a few samples - samples = [] - activations = [] - - for i, (x, _) in enumerate(dataset): - if i >= self.max_samples: - break - - try: - # Move tensor to same device as model - x = x.to(device) - - # Generate class activation map - activation_map = pl_module.gradcam(x) - - # Convert to RGB images for visualization - x_img = x[0].cpu().numpy() # Take first timepoint - if x_img.ndim == 3: # Handle (C,H,W) case - x_img = x_img[0] # Take first channel - x_img = rescale_intensity(x_img, in_range="image", out_range=(0, 1)) - - # Create activation map visualization - activation_norm = self._normalize_cam( - torch.from_numpy(activation_map) - ) - activation_rgb = plt.cm.magma(activation_norm.numpy())[..., :3] - - if self.mode == "separate": - # Keep sample as grayscale - x_vis = torch.from_numpy(x_img).unsqueeze(0).float() - activation_vis = ( - torch.from_numpy(activation_rgb).permute(2, 0, 1).float() - ) - else: # overlay mode - # Convert input to RGB - x_rgb = np.stack([x_img] * 3, axis=-1) - # Create overlay - overlay = self._create_overlay(x_rgb, activation_rgb) - x_vis = torch.from_numpy(x_rgb).permute(2, 0, 1).float() - activation_vis = ( - torch.from_numpy(overlay).permute(2, 0, 1).float() - ) - - samples.append(x_vis.cpu()) # Ensure on CPU for visualization - activations.append( - activation_vis.cpu() - ) # Ensure on CPU for visualization - - except Exception as e: - print(f"Error processing sample {i}: {str(e)}") - continue - - if samples: # Only proceed if we have samples - try: - # Stack images for grid visualization - samples_grid = torchvision.utils.make_grid( - samples, nrow=len(samples), normalize=True, value_range=(0, 1) - ) - activations_grid = torchvision.utils.make_grid( - activations, - nrow=len(activations), - normalize=True, - value_range=(0, 1), - ) - - # Log to tensorboard - trainer.logger.experiment.add_image( - f"gradcam/dataset_{dataset_idx}/samples", - samples_grid, - trainer.current_epoch, - ) - trainer.logger.experiment.add_image( - f"gradcam/dataset_{dataset_idx}/{'overlays' if self.mode == 'overlay' else 'activations'}", - activations_grid, - trainer.current_epoch, - ) - except Exception as e: - print(f"Error creating visualization grid: {str(e)}") + # Get visual dataloader from the datamodule + visual_loader = trainer.datamodule.visual_dataloader() + + # Get a few 
samples + samples = [] + activations = [] + + for batch_idx, (x, _) in enumerate(visual_loader): + if batch_idx >= self.max_samples: + break + + try: + # Move tensor to same device as model + x = x.to(device) + + # Generate class activation map + activation_map = pl_module.gradcam(x) + + # Convert to RGB images for visualization + # Handle 5D tensor [B, T, C, H, W] -> take first batch and timepoint + x_img = x[0, 0].cpu().numpy() # Take first batch and timepoint + if x_img.ndim == 3: # Handle [C, H, W] case + x_img = x_img[0] # Take first channel to get [H, W] + x_img = rescale_intensity(x_img, in_range="image", out_range=(0, 1)) + + # Create activation map visualization + activation_norm = self._normalize_cam(torch.from_numpy(activation_map)) + activation_rgb = plt.cm.magma(activation_norm.numpy())[..., :3] + + if self.mode == "separate": + # Keep sample as grayscale + x_vis = ( + torch.from_numpy(x_img).unsqueeze(0).float() + ) # Add channel dim [1, H, W] + activation_vis = ( + torch.from_numpy(activation_rgb).permute(2, 0, 1).float() + ) # [3, H, W] + else: # overlay mode + # Convert input to RGB + x_rgb = np.stack([x_img] * 3, axis=-1) # [H, W, 3] + # Create overlay + overlay = self._create_overlay(x_rgb, activation_rgb) + x_vis = ( + torch.from_numpy(x_rgb).permute(2, 0, 1).float() + ) # [3, H, W] + activation_vis = ( + torch.from_numpy(overlay).permute(2, 0, 1).float() + ) # [3, H, W] + + samples.append(x_vis.cpu()) # Ensure on CPU for visualization + activations.append( + activation_vis.cpu() + ) # Ensure on CPU for visualization + + except Exception as e: + logger.error(f"Error processing sample {batch_idx}: {str(e)}") + continue + + if samples: # Only proceed if we have samples + try: + # Stack images for grid visualization + samples_grid = torchvision.utils.make_grid( + samples, nrow=len(samples), normalize=True, value_range=(0, 1) + ) + activations_grid = torchvision.utils.make_grid( + activations, + nrow=len(activations), + normalize=True, + value_range=(0, 1), + ) + + # Log to tensorboard + trainer.logger.experiment.add_image( + f"gradcam/samples", + samples_grid, + trainer.current_epoch, + ) + trainer.logger.experiment.add_image( + f"gradcam/{'overlays' if self.mode == 'overlay' else 'activations'}", + activations_grid, + trainer.current_epoch, + ) + except Exception as e: + logger.error(f"Error creating visualization grid: {str(e)}") @staticmethod def _tensor_to_img(tensor: torch.Tensor) -> torch.Tensor: diff --git a/viscy/data/tarrow.py b/viscy/data/tarrow.py index d6cf423c6..4a58e2470 100644 --- a/viscy/data/tarrow.py +++ b/viscy/data/tarrow.py @@ -23,6 +23,10 @@ class TarrowDataModule(LightningDataModule): Fraction of data to use for training (0.0 to 1.0) patch_size : tuple[int, int], default=(128, 128) Patch size for TarrowDataset + visual_patch_size : tuple[int, int] | None, default=None + Patch size for visualization dataset + visual_batch_size : int | None, default=None + Batch size for visualization dataloader batch_size : int, default=16 Batch size for dataloaders num_workers : int, default=8 @@ -57,6 +61,8 @@ def __init__( batch_size: int = 16, num_workers: int = 8, patch_size: tuple[int, int] = (128, 128), + visual_patch_size: tuple[int, int] | None = None, + visual_batch_size: int | None = None, prefetch_factor: int | None = None, include_fov_names: list[str] = [], train_samples_per_epoch: int = 100000, @@ -75,7 +81,9 @@ def __init__( self.batch_size = batch_size self.num_workers = num_workers self.prefetch_factor = prefetch_factor - self.path_size = patch_size + 
self.patch_size = patch_size + self.visual_patch_size = visual_patch_size or tuple(4 * x for x in patch_size) + self.visual_batch_size = visual_batch_size or min(4, batch_size) self.include_fov_names = include_fov_names self.train_samples_per_epoch = train_samples_per_epoch self.val_samples_per_epoch = val_samples_per_epoch @@ -150,18 +158,26 @@ def setup(self, stage: str): NotImplementedError If stage is not "fit" """ - - # Get channel index once - if stage == "fit": list_dataset = [] + list_visual_dataset = [] + for pos in self.positions: pos_imgs = self._load_images(pos, self._channel_idx) list_dataset.append( TarrowDataset( imgs=pos_imgs, normalize=self.normalization, - size=self.path_size, + size=self.patch_size, + **self.kwargs, + ) + ) + # Create visualization dataset with larger patches + list_visual_dataset.append( + TarrowDataset( + imgs=pos_imgs, + normalize=self.normalization, + size=self.visual_patch_size, **self.kwargs, ) ) @@ -172,13 +188,23 @@ def setup(self, stage: str): # Shuffle the list of datasets shuffled_indices = set_fit_global_state(len(list_dataset)) list_dataset = [list_dataset[i] for i in shuffled_indices] + list_visual_dataset = [ + list_visual_dataset[i] for i in shuffled_indices + ] # Use same shuffling # Create training dataset with first train_split% of images self.train_dataset = ConcatDataset(list_dataset[:split_idx]) - - # Create validation dataset with remaining images self.val_dataset = ConcatDataset(list_dataset[split_idx:]) + # Take up to n_visual samples from validation set + # NOTE fixed to take the first n_visual samples from validation set + self.visual_batch_size = max( + len(list_visual_dataset[split_idx:]), self.visual_batch_size + ) + self.visual_dataset = ConcatDataset( + list_visual_dataset[split_idx : split_idx + self.visual_batch_size] + ) + elif stage == "test": raise NotImplementedError(f"Invalid stage: {stage}") elif stage == "predict": @@ -245,6 +271,24 @@ def val_dataloader(self): shuffle=False, ) + def visual_dataloader(self): + """Create the visualization dataloader. + + Returns + ------- + torch.utils.data.DataLoader + DataLoader for visualization data + """ + return DataLoader( + self.visual_dataset, + batch_size=self.visual_batch_size, + num_workers=self.num_workers, + persistent_workers=True if self.num_workers > 0 else False, + prefetch_factor=self.prefetch_factor if self.num_workers else None, + pin_memory=True, + shuffle=False, + ) + def test_dataloader(self): """Create the test dataloader. 
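The hunks above give TarrowDataModule a dedicated visualization path: larger visual_patch_size crops (defaulting to 4x patch_size), a small visual_batch_size, a visual_dataset drawn from the validation split, and a visual_dataloader() method that GradCAMCallback later pulls from trainer.datamodule. A minimal usage sketch follows; the zarr path and channel name are placeholders in the spirit of the test script, not values prescribed by this patch.

```python
# Sketch: exercise the visualization loader added in this patch on its own.
# The path and channel name are placeholders; point them at a real OME-Zarr store.
from viscy.data.tarrow import TarrowDataModule

dm = TarrowDataModule(
    ome_zarr_path="/path/to/dataset.zarr",  # placeholder
    channel_name="DIC",
    patch_size=(128, 128),
    visual_patch_size=(512, 512),  # omitted -> 4 * patch_size per this patch
    visual_batch_size=4,           # omitted -> min(4, batch_size) per this patch
    batch_size=16,
    num_workers=4,
    train_split=0.8,
)
dm.setup("fit")

# The loader yields (images, labels) batches of large validation patches,
# which is what GradCAMCallback consumes via trainer.datamodule.visual_dataloader().
images, _ = next(iter(dm.visual_dataloader()))
print(images.shape)
```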
diff --git a/viscy/representation/timearrow.py b/viscy/representation/timearrow.py index e61af2651..1b4fbac2c 100644 --- a/viscy/representation/timearrow.py +++ b/viscy/representation/timearrow.py @@ -66,7 +66,6 @@ def __init__( lambda_decorrelation=0.01, lr_scheduler="cyclic", lr_patience=50, - cam_size=None, log_batches_per_epoch=8, log_samples_per_batch=1, **kwargs, From b8eaf2ed7bdcf099aeceb9bcc2fa29ee695f4bc9 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Thu, 30 Jan 2025 20:44:02 -0800 Subject: [PATCH 36/38] swapping to use callables so we can configure the augmentations --- viscy/callbacks/gradcam.py | 9 ++--- viscy/data/tarrow.py | 60 +++++++++++++++---------------- viscy/representation/timearrow.py | 2 -- 3 files changed, 30 insertions(+), 41 deletions(-) diff --git a/viscy/callbacks/gradcam.py b/viscy/callbacks/gradcam.py index f70ac9c33..1a2c04dbb 100644 --- a/viscy/callbacks/gradcam.py +++ b/viscy/callbacks/gradcam.py @@ -1,12 +1,11 @@ import logging -from typing import List +import matplotlib.pyplot as plt +import numpy as np import torch import torchvision from lightning.pytorch import LightningModule, Trainer from lightning.pytorch.callbacks import Callback -import numpy as np -import matplotlib.pyplot as plt from skimage.exposure import rescale_intensity logger = logging.getLogger(__name__) @@ -21,8 +20,6 @@ class GradCAMCallback(Callback): Generate visualizations every n epochs max_samples : int, default=5 Maximum number of samples to visualize per dataset - max_height : int, default=720 - Maximum height of output visualization mode : str, default="overlay" Visualization mode: "separate" for individual images and activations, or "overlay" for activation map overlaid on input image @@ -32,13 +29,11 @@ def __init__( self, every_n_epochs: int = 10, max_samples: int = 5, - max_height: int = 720, mode: str = "overlay", ): super().__init__() self.every_n_epochs = every_n_epochs self.max_samples = max_samples - self.max_height = max_height assert mode in ["separate", "overlay"], "Mode must be 'separate' or 'overlay'" self.mode = mode diff --git a/viscy/data/tarrow.py b/viscy/data/tarrow.py index 4a58e2470..a20997cf8 100644 --- a/viscy/data/tarrow.py +++ b/viscy/data/tarrow.py @@ -1,7 +1,8 @@ from pathlib import Path -from typing import Callable +from typing import Callable, Sequence import numpy as np +import torch.nn as nn from iohub.ngff import Position, open_ome_zarr from lightning.pytorch import LightningDataModule from tarrow.data.tarrow_dataset import TarrowDataset @@ -49,6 +50,8 @@ class TarrowDataModule(LightningDataModule): Whether to pin memory persistent_workers : bool, default=True Whether to keep the workers alive between epochs + augmentations : list[nn.Module], default=[] + List of Kornia augmentation transforms to apply during training **kwargs : dict Additional arguments passed to TarrowDataset """ @@ -72,6 +75,7 @@ def __init__( normalization: Callable[[np.ndarray], np.ndarray] | None = None, pin_memory: bool = True, persistent_workers: bool = True, + augmentations: Sequence[nn.Module] = [], **kwargs, ): super().__init__() @@ -91,37 +95,31 @@ def __init__( self.z_slice = z_slice self.kwargs = kwargs self.normalization = normalization + self.pin_memory = pin_memory + self.persistent_workers = persistent_workers + self.augmentations = augmentations self._filter_positions() self._channel_idx = self._get_channel_index() - def _get_channel_index(self, plate) -> int: - """Get the index of the specified channel from the plate metadata. 
+ def _get_channel_index(self) -> int: + """Get the index of the specified channel from the plate metadata.""" + with open_ome_zarr(self.ome_zarr_path, mode="r") as plate: + _, first_pos = next(plate.positions()) + return first_pos.channel_names.index(self.channel_name) - Parameters - ---------- - plate : iohub.ngff.Plate - OME-Zarr plate object + def _create_augmentation_pipeline(self) -> nn.Sequential | None: + """Create the augmentation pipeline for training. Returns ------- - int - Index of the specified channel - - Raises - ------ - ValueError - If channel_name is not found in available channels + nn.Sequential | None + Sequential container of Kornia augmentations or None if no augmentations """ - # Get channel names from first position - _, first_pos = next(plate.positions()) - try: - return first_pos.channel_names.index(self.channel_name) - except ValueError: - available_channels = ", ".join(first_pos.channel_names) - raise ValueError( - f"Channel '{self.channel_name}' not found. Available channels: {available_channels}" - ) + if not self.augmentations: + return None + + return nn.Sequential(*self.augmentations) def _load_images(self, position: Position, channel_idx: int) -> list[np.ndarray]: """Load all images from positions into memory. @@ -162,6 +160,9 @@ def setup(self, stage: str): list_dataset = [] list_visual_dataset = [] + # Create augmentation pipeline + augmenter = self._create_augmentation_pipeline() + for pos in self.positions: pos_imgs = self._load_images(pos, self._channel_idx) list_dataset.append( @@ -169,6 +170,7 @@ def setup(self, stage: str): imgs=pos_imgs, normalize=self.normalization, size=self.patch_size, + augmenter=augmenter, # Pass augmenter to dataset **self.kwargs, ) ) @@ -229,12 +231,6 @@ def _filter_positions(self): self.positions = positions - def _get_channel_index(self): - """Get the index of the specified channel from the plate metadata.""" - with open_ome_zarr(self.ome_zarr_path, mode="r") as plate: - _, first_pos = next(plate.positions()) - return first_pos.channel_names.index(self.channel_name) - def train_dataloader(self): """Create the training dataloader. @@ -249,7 +245,7 @@ def train_dataloader(self): num_workers=self.num_workers, persistent_workers=True if self.num_workers > 0 else False, prefetch_factor=self.prefetch_factor if self.num_workers else None, - pin_memory=True, + pin_memory=self.pin_memory, shuffle=True, ) @@ -267,7 +263,7 @@ def val_dataloader(self): num_workers=self.num_workers, persistent_workers=True if self.num_workers > 0 else False, prefetch_factor=self.prefetch_factor if self.num_workers else None, - pin_memory=True, + pin_memory=self.pin_memory, shuffle=False, ) @@ -285,7 +281,7 @@ def visual_dataloader(self): num_workers=self.num_workers, persistent_workers=True if self.num_workers > 0 else False, prefetch_factor=self.prefetch_factor if self.num_workers else None, - pin_memory=True, + pin_memory=self.pin_memory, shuffle=False, ) diff --git a/viscy/representation/timearrow.py b/viscy/representation/timearrow.py index 1b4fbac2c..6df8a1ccf 100644 --- a/viscy/representation/timearrow.py +++ b/viscy/representation/timearrow.py @@ -44,8 +44,6 @@ class TarrowModule(LightningModule): Learning rate scheduler ('plateau' or 'cyclic') lr_patience : int, default=50 Patience for learning rate scheduler - cam_size : tuple or int, optional - Size of the class activation map (H, W). If None, use input size. 
log_batches_per_epoch : int, default=8 Number of batches to log samples from during training log_samples_per_batch : int, default=1 From 473d3ab3b111c34dea3356c051b7c0d2b6cb6b82 Mon Sep 17 00:00:00 2001 From: Shalin Mehta Date: Sun, 9 Feb 2025 10:59:29 -0800 Subject: [PATCH 37/38] temporary fix for missing utility --- viscy/data/tarrow.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/viscy/data/tarrow.py b/viscy/data/tarrow.py index a20997cf8..d89771301 100644 --- a/viscy/data/tarrow.py +++ b/viscy/data/tarrow.py @@ -8,7 +8,9 @@ from tarrow.data.tarrow_dataset import TarrowDataset from torch.utils.data import ConcatDataset, DataLoader -from viscy.utils.engine_state import set_fit_global_state +# FIXME: This module is not available in the viscy package,so shuffle the list of datasets manually. +# from viscy.utils.engine_state import set_fit_global_state +import random class TarrowDataModule(LightningDataModule): @@ -188,7 +190,11 @@ def setup(self, stage: str): split_idx = int(len(self.positions) * self.train_split) # Shuffle the list of datasets - shuffled_indices = set_fit_global_state(len(list_dataset)) + + #FIXME: This module is not available in the viscy package,so shuffle the list of datasets manually. + # shuffled_indices = set_fit_global_state(len(list_dataset)) + shuffled_indices = list(range(len(list_dataset))) + random.shuffle(shuffled_indices) list_dataset = [list_dataset[i] for i in shuffled_indices] list_visual_dataset = [ list_visual_dataset[i] for i in shuffled_indices From 22bd10e586ba35d3fa18df57def8544aeb016557 Mon Sep 17 00:00:00 2001 From: Shalin Mehta Date: Sun, 9 Feb 2025 10:59:51 -0800 Subject: [PATCH 38/38] test the model and dataloader --- .../timearrow_phenotyping/test_tarrow.py | 57 +++++++++++++++++++ 1 file changed, 57 insertions(+) create mode 100644 applications/timearrow_phenotyping/test_tarrow.py diff --git a/applications/timearrow_phenotyping/test_tarrow.py b/applications/timearrow_phenotyping/test_tarrow.py new file mode 100644 index 000000000..f4d2530a7 --- /dev/null +++ b/applications/timearrow_phenotyping/test_tarrow.py @@ -0,0 +1,57 @@ +# %% Imports +import torch +import torchview +from viscy.data.tarrow import TarrowDataModule +from viscy.representation.timearrow import TarrowModule + + +# %% Load minimal config +config = { + 'data': { + 'init_args': { + 'ome_zarr_path': '/hpc/projects/organelle_phenotyping/ALFI_models_data/datasets/zarr_datasets/float_phase_ome_zarr_output_valtrain.zarr', # Replace with actual path + 'channel_name': 'DIC', + 'patch_size': [256, 256], + 'batch_size': 32, + 'num_workers': 4, + 'train_split': 0.8 + } + }, + 'model': { + 'init_args': { + 'backbone': 'unet', + 'projection_head': 'minimal_batchnorm', + 'classification_head': 'minimal', + } + } +} + +# # Optionally load config from file +# config_path = "/hpc/projects/organelle_phenotyping/models/ALFI/tarrow_test/tarrow.yml" +# with open(config_path) as f: +# config = yaml.safe_load(f) + +# %% Initialize data and model +data_module = TarrowDataModule(**config['data']['init_args']) +model = TarrowModule(**config['model']['init_args']) +# %% Construct a batch of data from the data module +data_module.setup('fit') +batch = next(iter(data_module.train_dataloader())) +images, labels = batch +print(model) +# %% Print model graph. 
+try:
+    # Construct the graph and display it only if construction succeeds,
+    # so a failed draw_graph call does not raise a NameError below.
+    model_graph = torchview.draw_graph(
+        model,
+        input_data=images,
+        save_graph=False,  # Don't save, just display
+        expand_nested=True,
+        device='cpu',  # specify CPU device
+    )
+    model_graph.visual_graph  # Display the graph
+except Exception as e:
+    print(f"Error generating model graph: {e}")
+
+# %%
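A natural follow-up to the smoke test above is to push one visualization sample through the gradcam() method added to TarrowModule and save an overlay, similar to what GradCAMCallback logs in its "overlay" mode. The sketch below assumes the data_module and model objects defined in the cells above; the output filename is a placeholder, and with an untrained model the map is only a plumbing check rather than a meaningful saliency estimate.

```python
# %% Sketch: GradCAM overlay check using the visualization loader and model above.
import matplotlib.pyplot as plt
import numpy as np

x, _ = next(iter(data_module.visual_dataloader()))
sample = x[0]  # (T, C, H, W); gradcam() adds the batch dimension itself
cam = model.gradcam(sample)  # 2D map interpolated to the input height/width

# Normalize the first timepoint/channel and the CAM to [0, 1] for display
img = sample[0, 0].cpu().numpy()
img = (img - img.min()) / (img.max() - img.min() + 1e-8)
cam_norm = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)
overlay = 0.7 * np.stack([img] * 3, axis=-1) + 0.3 * plt.cm.magma(cam_norm)[..., :3]

fig, axes = plt.subplots(1, 3, figsize=(15, 5))
axes[0].imshow(img, cmap="gray")
axes[0].set_title("Input (t=0)")
axes[1].imshow(cam_norm, cmap="magma")
axes[1].set_title("GradCAM")
axes[2].imshow(overlay)
axes[2].set_title("Overlay")
for ax in axes:
    ax.axis("off")
fig.savefig("gradcam_tarrow_check.png")  # placeholder output path
plt.close(fig)
```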