diff --git a/examples/DynaCLR/embedding-web-visualization/interactive_visualizer.py b/examples/DynaCLR/embedding-web-visualization/interactive_visualizer.py index cce99c9e4..80f2fd056 100644 --- a/examples/DynaCLR/embedding-web-visualization/interactive_visualizer.py +++ b/examples/DynaCLR/embedding-web-visualization/interactive_visualizer.py @@ -1,11 +1,12 @@ -"""Interactive visualization of phenotype data.""" - import logging from pathlib import Path +import click from numpy.random import seed -from viscy.representation.evaluation.visualization import EmbeddingVisualizationApp +from viscy.representation.visualization.app import EmbeddingVisualizationApp +from viscy.representation.visualization.settings import VizConfig +from viscy.utils.cli_utils import yaml_to_model logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) @@ -13,33 +14,34 @@ seed(42) -def main(): +@click.command() +@click.option( + "--config-filepath", + "-c", + required=True, + help="Path to YAML configuration file.", +) +def main(config_filepath): """Main function to run the visualization app.""" - # Config for the visualization app - # TODO: Update the paths to the downloaded data. By default the data is downloaded to ~/data/dynaclr/demo - download_root = Path.home() / "data/dynaclr/demo" - output_path = Path.home() / "data/dynaclr/demo/embedding-web-visualization" - viz_config = { - "data_path": download_root / "registered_test.zarr", # TODO add path to data - "tracks_path": download_root / "track_test.zarr", # TODO add path to tracks - "features_path": download_root - / "precomputed_embeddings/infection_160patch_94ckpt_rev6_dynaclr.zarr", # TODO add path to features - "channels_to_display": ["Phase3D", "RFP"], - "fov_tracks": { - "/A/3/9": list(range(50)), - "/B/4/9": list(range(50)), - }, - "yx_patch_size": (160, 160), - "z_range": (24, 29), - "num_PC_components": 8, - "output_dir": output_path, - } + # Load and validate configuration from YAML file + viz_config = yaml_to_model(yaml_path=config_filepath, model=VizConfig) + + # Use configured paths, with fallbacks to current defaults if not specified + output_dir = viz_config.output_dir or str(Path(__file__).parent / "output") + cache_path = viz_config.cache_path + + logger.info(f"Using output directory: {output_dir}") + logger.info(f"Using cache path: {cache_path}") # Create and run the visualization app try: - # Create and run the visualization app - app = EmbeddingVisualizationApp(**viz_config) + app = EmbeddingVisualizationApp( + viz_config=viz_config, + cache_path=cache_path, + num_loading_workers=16, + output_dir=output_dir, + ) app.preload_images() app.run(debug=True) diff --git a/examples/DynaCLR/embedding-web-visualization/viz_config.yaml b/examples/DynaCLR/embedding-web-visualization/viz_config.yaml new file mode 100644 index 000000000..ac35b0e98 --- /dev/null +++ b/examples/DynaCLR/embedding-web-visualization/viz_config.yaml @@ -0,0 +1,52 @@ +# Multi-dataset visualization configuration +datasets: + g3bp1_sensor: + data_path: "/hpc/projects/intracellular_dashboard/organelle_dynamics/rerun/2025_01_28_A549_G3BP1_ZIKV_DENV/4-phenotyping/train-test/2025_01_28_A549_G3BP1_ZIKV_DENV.zarr" + tracks_path: "/hpc/projects/intracellular_dashboard/organelle_dynamics/rerun/2025_01_28_A549_G3BP1_ZIKV_DENV/1-preprocess/label-free/3-track/2025_01_28_A549_G3BP1_ZIKV_DENV_cropped.zarr" + features_path: "/hpc/projects/intracellular_dashboard/organelle_dynamics/rerun/2025_01_28_A549_G3BP1_ZIKV_DENV/4-phenotyping/predictions/sensor_160patch_99ckpt_max.zarr" 
+ channels_to_display: ["raw mCherry EX561 EM600-37"] + z_range: [0, 1] + yx_patch_size: [192, 192] + fov_tracks: + "/B/3/000000": [42, 47, 56, 57, 58, 59, 61] + "/B/4/000001": [57, 83, 89, 90, 91, 101] + + + g3bp1_organelle: + data_path: "/hpc/projects/intracellular_dashboard/organelle_dynamics/rerun/2025_01_28_A549_G3BP1_ZIKV_DENV/4-phenotyping/train-test/2025_01_28_A549_G3BP1_ZIKV_DENV.zarr" + tracks_path: "/hpc/projects/intracellular_dashboard/organelle_dynamics/rerun/2025_01_28_A549_G3BP1_ZIKV_DENV/1-preprocess/label-free/3-track/2025_01_28_A549_G3BP1_ZIKV_DENV_cropped.zarr" + features_path: "/hpc/projects/intracellular_dashboard/organelle_dynamics/rerun/2025_01_28_A549_G3BP1_ZIKV_DENV/4-phenotyping/predictions/organelle_160patch_99ckpt_max.zarr" + channels_to_display: ["raw GFP EX488 EM525-45"] + z_range: [0, 1] + yx_patch_size: [192, 192] + fov_tracks: + # "/B/3/000000": "all" + # "/B/4/000001": "all" + "/B/2/000000": [15, 47, 61, 62, 63, 64, 65] + "/C/2/000001": "all" + # "/C/2/000000": [] + + # sec61_organelle: + # data_path: "/hpc/projects/intracellular_dashboard/organelle_dynamics/rerun/2025_01_28_A549_G3BP1_ZIKV_DENV/4-phenotyping/train-test/2025_01_28_A549_G3BP1_ZIKV_DENV.zarr" + # tracks_path: "/hpc/projects/intracellular_dashboard/organelle_dynamics/rerun/2025_01_28_A549_G3BP1_ZIKV_DENV/1-preprocess/label-free/3-track/2025_01_28_A549_G3BP1_ZIKV_DENV_cropped.zarr" + # features_path: "/hpc/projects/intracellular_dashboard/organelle_dynamics/rerun/2024_11_07_A549_SEC61_DENV/4-phenotyping/2-predictions/organelle_160patch_99ckpt_max.zarr" + # channels_to_display: ["raw GFP EX488 EM525-45"] + # z_range: [0, 1] + # yx_patch_size: [192, 192] + +# Global settings +num_PC_components: 8 +# phate_kwargs: +# n_components: 4 # Number of PHATE components +# knn: 5 # Number of nearest neighbors for KNN graph +# decay: 40 # Decay parameter for Markov operator +# n_jobs: -1 # Number of parallel jobs (-1 uses all cores) +# random_state: 42 # Random seed for reproducibility +# # gamma: 1.0 # Informational scale parameter (optional) +# # t: "auto" # Number of diffusion steps (optional, auto-selects optimal) +# # a: 1 # Alpha parameter for alpha-decay kernel (optional) +# # verbose: 0 # Verbosity level (optional) + +# File system paths +output_dir: "./output" # Directory to save CSV files and other outputs +cache_path: "/home/eduardo.hirata/mydata/tmp/pcviewer/cache/image_cache1.pkl" # Path to save/load image cache \ No newline at end of file diff --git a/viscy/representation/evaluation/combined_analysis.py b/viscy/representation/evaluation/combined_analysis.py new file mode 100644 index 000000000..cc8966174 --- /dev/null +++ b/viscy/representation/evaluation/combined_analysis.py @@ -0,0 +1,172 @@ +import logging +from pathlib import Path +from typing import Literal, Optional + +import numpy as np +import pandas as pd +from xarray import Dataset + +from viscy.representation.embedding_writer import ( + read_embedding_dataset, + write_embedding_dataset, +) +from viscy.representation.evaluation.data_loading import ( + EmbeddingDataLoader, + TripletEmbeddingLoader, +) +from viscy.representation.evaluation.dimensionality_reduction import compute_phate + +__all__ = ["load_and_combine_features", "compute_phate_for_combined_datasets"] + +_logger = logging.getLogger("lightning.pytorch") + + +def load_and_combine_features( + feature_paths: list[Path], + dataset_names: Optional[list[str]] = None, + loader: Literal[ + EmbeddingDataLoader, TripletEmbeddingLoader + ] = TripletEmbeddingLoader(), +) -> 
tuple[np.ndarray, pd.DataFrame]: + """ + Load features from multiple datasets and combine them using a pluggable loader. + + Parameters + ---------- + feature_paths : list[Path] + Paths to embedding datasets + dataset_names : list[str], optional + Names for datasets. If None, uses file stems + loader : EmbeddingDataLoader | TripletEmbeddingLoader, optional + Custom data loader. If None, uses TripletEmbeddingLoader + + Returns + ------- + tuple[np.ndarray, pd.DataFrame] + Combined features array and index DataFrame with dataset_pair column + """ + if dataset_names is None: + dataset_names = [path.stem for path in feature_paths] + + if len(dataset_names) != len(feature_paths): + raise ValueError("Number of dataset names must match number of feature paths") + + all_features = [] + all_indices = [] + + for path, dataset_name in zip(feature_paths, dataset_names): + _logger.info(f"Loading features from {path}") + + dataset = loader.load_dataset(path) + features = loader.extract_features(dataset) + index_df = loader.extract_metadata(dataset) + + index_df["dataset_pair"] = dataset_name + index_df["dataset_path"] = str(path) + + all_features.append(features) + all_indices.append(index_df) + + _logger.info(f"Loaded {len(features)} samples from {dataset_name}") + + combined_features = np.vstack(all_features) + combined_indices = pd.concat(all_indices, ignore_index=True) + + _logger.info( + f"Combined {len(combined_features)} total samples from {len(feature_paths)} datasets" + ) + + return combined_features, combined_indices + + +def compute_phate_for_combined_datasets( + feature_paths: list[Path], + output_path: Path, + dataset_names: Optional[list[str]] = None, + phate_kwargs: Optional[dict] = None, + overwrite: bool = False, + loader: Literal[ + EmbeddingDataLoader, TripletEmbeddingLoader + ] = TripletEmbeddingLoader(), +) -> Dataset: + """ + Compute PHATE embeddings on combined features from multiple datasets. + + Parameters + ---------- + feature_paths : list[Path] + List of paths to zarr stores containing embedding datasets + output_path : Path + Path to save the combined dataset with PHATE embeddings + dataset_names : list[str], optional + Names for each dataset. If None, uses file stems + phate_kwargs : dict, optional + Parameters for PHATE computation. Default: {"knn": 5, "decay": 40, "n_components": 2} + overwrite : bool, optional + Whether to overwrite existing output file + loader : EmbeddingDataLoader | TripletEmbeddingLoader, optional + Custom data loader. If None, uses TripletEmbeddingLoader + + Returns + ------- + Dataset + Combined xarray dataset with original features and PHATE coordinates + + Raises + ------ + FileExistsError + If output_path exists and overwrite is False + ImportError + If PHATE is not installed + """ + output_path = Path(output_path) + + if output_path.exists() and not overwrite: + raise FileExistsError( + f"Output path {output_path} already exists. Use overwrite=True to overwrite." 
+ ) + + if phate_kwargs is None: + phate_kwargs = {"knn": 5, "decay": 40, "n_components": 2} + + _logger.info( + f"Computing PHATE for combined datasets with parameters: {phate_kwargs}" + ) + + combined_features, combined_indices = load_and_combine_features( + feature_paths, dataset_names, loader + ) + + n_samples = len(combined_features) + if phate_kwargs.get("knn", 5) >= n_samples: + original_knn = phate_kwargs["knn"] + phate_kwargs["knn"] = max(2, n_samples // 2) + _logger.warning( + f"Reducing knn from {original_knn} to {phate_kwargs['knn']} due to dataset size ({n_samples} samples)" + ) + + _logger.info("Computing PHATE embeddings on combined features") + try: + _, phate_embedding = compute_phate(combined_features, **phate_kwargs) + _logger.info( + f"PHATE computation successful, embedding shape: {phate_embedding.shape}" + ) + except Exception as e: + _logger.error(f"PHATE computation failed: {str(e)}") + raise + + _logger.info(f"Writing combined dataset with PHATE embeddings to {output_path}") + write_embedding_dataset( + output_path=output_path, + features=combined_features, + index_df=combined_indices, + phate_kwargs=phate_kwargs, + overwrite=overwrite, + ) + + result_dataset = read_embedding_dataset(output_path) + _logger.info( + f"Successfully created combined dataset with {len(result_dataset.sample)} samples" + ) + + return result_dataset diff --git a/viscy/representation/evaluation/data_loading.py b/viscy/representation/evaluation/data_loading.py new file mode 100644 index 000000000..f1625c2c2 --- /dev/null +++ b/viscy/representation/evaluation/data_loading.py @@ -0,0 +1,163 @@ +import logging +from pathlib import Path +from typing import Protocol, runtime_checkable + +import numpy as np +import pandas as pd +from xarray import Dataset + +__all__ = ["EmbeddingDataLoader", "TripletEmbeddingLoader"] + +_logger = logging.getLogger("lightning.pytorch") + + +@runtime_checkable +class EmbeddingDataLoader(Protocol): + """Protocol for embedding dataloaders that can be used with combined analysis.""" + + def load_dataset(self, path: Path) -> Dataset: + """ + Load dataset from path and return xarray Dataset with 'features' data variable. + + Parameters + ---------- + path : Path + Path to the dataset file + + Returns + ------- + Dataset + Xarray dataset with 'features' data variable containing embeddings + """ + ... + + def extract_features(self, dataset: Dataset) -> np.ndarray: + """ + Extract feature embeddings from dataset. + + Parameters + ---------- + dataset : Dataset + Xarray dataset containing features + + Returns + ------- + np.ndarray + Array of shape (n_samples, n_features) containing embeddings + """ + ... + + def extract_metadata(self, dataset: Dataset) -> pd.DataFrame: + """ + Extract metadata/index information from dataset. + + Parameters + ---------- + dataset : Dataset + Xarray dataset containing metadata + + Returns + ------- + pd.DataFrame + DataFrame containing metadata for each sample, including any existing + dimensionality reduction coordinates (PHATE, UMAP, etc.) + """ + ... 
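Since `EmbeddingDataLoader` is a `runtime_checkable` `Protocol`, any object exposing these three methods can be passed as the `loader` argument of `load_and_combine_features` — no subclassing required. Below is a minimal sketch of a duck-typed loader for plain `.npz` archives; the class name, file layout, and array keys (`features`, `track_id`, `t`, `fov_name`) are illustrative assumptions, not part of the viscy API:

```python
from pathlib import Path

import numpy as np
import pandas as pd
import xarray as xr

from viscy.representation.evaluation.data_loading import EmbeddingDataLoader


class NpzEmbeddingLoader:
    """Hypothetical loader for .npz archives holding embeddings plus per-sample metadata."""

    def load_dataset(self, path: Path) -> xr.Dataset:
        # Build the shape the protocol implies: a 'features' variable over a
        # 'sample' dimension, with metadata attached as coordinates.
        arrays = np.load(path, allow_pickle=False)
        n_samples = arrays["features"].shape[0]
        return xr.Dataset(
            {"features": (("sample", "feature"), arrays["features"])},
            coords={
                "sample": np.arange(n_samples),
                "track_id": ("sample", arrays["track_id"]),
                "t": ("sample", arrays["t"]),
                "fov_name": ("sample", arrays["fov_name"]),
            },
        )

    def extract_features(self, dataset: xr.Dataset) -> np.ndarray:
        return dataset["features"].values

    def extract_metadata(self, dataset: xr.Dataset) -> pd.DataFrame:
        # One row per sample; combined_analysis appends dataset_pair/dataset_path.
        return pd.DataFrame(
            {
                "track_id": dataset["track_id"].values,
                "t": dataset["t"].values,
                "fov_name": dataset["fov_name"].values,
            }
        )


# The runtime-checkable protocol verifies method presence at runtime:
assert isinstance(NpzEmbeddingLoader(), EmbeddingDataLoader)
```

With that in place, `load_and_combine_features(paths, loader=NpzEmbeddingLoader())` routes loading through the custom class while keeping the concatenation and bookkeeping logic unchanged.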
+ + +class TripletEmbeddingLoader: + """Default loader for triplet-based embedding datasets.""" + + def load_dataset(self, path: Path) -> Dataset: + """Load embedding dataset using the standard embedding writer format.""" + from viscy.representation.embedding_writer import read_embedding_dataset + + _logger.debug(f"Loading dataset from {path} using TripletEmbeddingLoader") + return read_embedding_dataset(path) + + def extract_features(self, dataset: Dataset) -> np.ndarray: + """Extract features from the 'features' data variable.""" + return dataset["features"].values + + def extract_metadata(self, dataset: Dataset) -> pd.DataFrame: + """ + Extract metadata from dataset coordinates and data variables. + + This includes sample coordinates and any existing dimensionality reduction + coordinates like PHATE, UMAP, PCA that were previously computed. + """ + features_data_array = dataset["features"] + + try: + coord_df = features_data_array["sample"].to_dataframe() + + if coord_df.index.names != [None]: + index_df = coord_df.reset_index() + if "features" in index_df.columns: + index_df = index_df.drop(columns=["features"]) + else: + index_df = coord_df.reset_index(drop=True) + + dim_reduction_cols = [ + col + for col in index_df.columns + if any(col.startswith(prefix) for prefix in ["PHATE", "UMAP", "PCA"]) + ] + if dim_reduction_cols: + _logger.debug( + f"Found dimensionality reduction coordinates: {dim_reduction_cols}" + ) + + _logger.debug( + f"Extracted metadata with {len(index_df.columns)} columns: {list(index_df.columns)}" + ) + return index_df + + except Exception as e: + _logger.error(f"Error extracting metadata: {e}") + index_df = ( + features_data_array["sample"].to_dataframe().reset_index(drop=True) + ) + _logger.warning( + f"Using fallback metadata extraction with {len(index_df.columns)} columns" + ) + return index_df + + +# Example of how to implement a custom loader +# TODO: replace with the other dataloaders +class CustomEmbeddingLoader: + """ + Example implementation of a custom embedding loader. + + This serves as a template for implementing loaders for different data formats. + Replace the method implementations with your specific loading logic. + """ + + def load_dataset(self, path: Path) -> Dataset: + """ + Load your custom dataset format. + + This should return an xarray Dataset with at least a 'features' data variable + containing the embeddings with a 'sample' dimension. + """ + raise NotImplementedError("Implement your custom dataset loading logic here") + + def extract_features(self, dataset: Dataset) -> np.ndarray: + """ + Extract features from your dataset format. + + Should return a numpy array of shape (n_samples, n_features). + """ + return dataset["features"].values # Modify if your format is different + + def extract_metadata(self, dataset: Dataset) -> pd.DataFrame: + """ + Extract metadata from your dataset format. + + Should return a DataFrame with one row per sample containing metadata + like sample IDs, FOV names, track IDs, etc. 
+ """ + raise NotImplementedError( + "Implement your custom metadata extraction logic here" + ) diff --git a/viscy/representation/evaluation/dimensionality_reduction.py b/viscy/representation/evaluation/dimensionality_reduction.py index eb5d43f91..80634c7a6 100644 --- a/viscy/representation/evaluation/dimensionality_reduction.py +++ b/viscy/representation/evaluation/dimensionality_reduction.py @@ -61,7 +61,11 @@ def compute_phate( # Compute PHATE embeddings phate_model = phate.PHATE( - n_components=n_components, knn=knn, decay=decay, **phate_kwargs + n_components=n_components, + knn=knn, + decay=decay, + random_state=42, + **phate_kwargs, ) phate_embedding = phate_model.fit_transform(embeddings) diff --git a/viscy/representation/evaluation/visualization.py b/viscy/representation/evaluation/visualization.py deleted file mode 100644 index 9d787fe05..000000000 --- a/viscy/representation/evaluation/visualization.py +++ /dev/null @@ -1,2244 +0,0 @@ -import atexit -import base64 -import json -import logging -from io import BytesIO -from pathlib import Path - -import dash -import dash.dependencies as dd -import matplotlib.pyplot as plt -import numpy as np -import pandas as pd -import plotly.graph_objects as go -from dash import dcc, html -from PIL import Image -from sklearn.decomposition import PCA -from sklearn.preprocessing import StandardScaler - -from viscy.data.triplet import TripletDataModule -from viscy.representation.embedding_writer import read_embedding_dataset - -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) - - -class EmbeddingVisualizationApp: - def __init__( - self, - data_path: str, - tracks_path: str, - features_path: str, - channels_to_display: list[str] | str, - fov_tracks: dict[str, list[int] | str], - z_range: tuple[int, int] = (0, 1), - yx_patch_size: tuple[int, int] = (128, 128), - num_PC_components: int = 3, - cache_path: str | None = None, - num_loading_workers: int = 16, - output_dir: str | None = None, - ) -> None: - """ - Initialize a Dash application for visualizing the DynaCLR embeddings. - - This class provides a visualization tool for visualizing the DynaCLR embeddings into a 2D space (e.g. PCA, UMAP, PHATE). - It allows users to interactively explore and analyze trajectories, visualize clusters, and explore the embedding space. - - Parameters - ---------- - data_path: str - Path to the data directory. - tracks_path: str - Path to the tracks directory. - features_path: str - Path to the features directory. - channels_to_display: list[str] | str - List of channels to display. - fov_tracks: dict[str, list[int] | str] - Dictionary of FOV names and track IDs. - z_range: tuple[int, int] | list[int,int] - Range of z-slices to display. - yx_patch_size: tuple[int, int] | list[int,int] - Size of the yx-patch to display. - num_PC_components: int - Number of PCA components to use. - cache_path: str | None - Path to the cache directory. - num_loading_workers: int - Number of workers to use for loading data. - output_dir: str | None, optional - Directory to save CSV files and other outputs. If None, uses current working directory. - Returns - ------- - None - Initializes the visualization app. 
- """ - self.data_path = Path(data_path) - self.tracks_path = Path(tracks_path) - self.features_path = Path(features_path) - self.fov_tracks = fov_tracks - self.image_cache = {} - self.cache_path = Path(cache_path) if cache_path else None - self.output_dir = Path(output_dir) if output_dir else Path.cwd() - self.app = None - self.features_df = None - self.fig = None - self.channels_to_display = channels_to_display - self.z_range = z_range - self.yx_patch_size = yx_patch_size - self.filtered_tracks_by_fov = {} - self._z_idx = (self.z_range[1] - self.z_range[0]) // 2 - self.num_PC_components = num_PC_components - self.num_loading_workers = num_loading_workers - # Initialize cluster storage before preparing data and creating figure - self.clusters = [] # List to store all clusters - self.cluster_points = set() # Set to track all points in clusters - self.cluster_names = {} # Dictionary to store cluster names - self.next_cluster_id = 1 # Counter for cluster IDs - # Initialize data - self._prepare_data() - self._create_figure() - self._init_app() - atexit.register(self._cleanup_cache) - - def _prepare_data(self): - """Prepare the feature data and PCA transformation""" - embedding_dataset = read_embedding_dataset(self.features_path) - features = embedding_dataset["features"] - self.features_df = features["sample"].to_dataframe().reset_index(drop=True) - - # Check if UMAP or PHATE columns already exist - existing_dims = [] - dim_options = [] - - # Check for PCA and compute if needed - if not any(col.startswith("PCA") for col in self.features_df.columns): - # PCA transformation - scaled_features = StandardScaler().fit_transform(features.values) - pca = PCA(n_components=self.num_PC_components) - pca_coords = pca.fit_transform(scaled_features) - - # Add PCA coordinates to the features dataframe - for i in range(self.num_PC_components): - self.features_df[f"PCA{i + 1}"] = pca_coords[:, i] - - # Store explained variance for PCA - self.pca_explained_variance = [ - f"PC{i + 1} ({var:.1f}%)" - for i, var in enumerate(pca.explained_variance_ratio_ * 100) - ] - - # Add PCA options - for i, pc_label in enumerate(self.pca_explained_variance): - dim_options.append({"label": pc_label, "value": f"PCA{i + 1}"}) - existing_dims.append(f"PCA{i + 1}") - - # Check for UMAP coordinates - umap_dims = [col for col in self.features_df.columns if col.startswith("UMAP")] - if umap_dims: - for dim in umap_dims: - dim_options.append({"label": dim, "value": dim}) - existing_dims.append(dim) - - # Check for PHATE coordinates - phate_dims = [ - col for col in self.features_df.columns if col.startswith("PHATE") - ] - if phate_dims: - for dim in phate_dims: - dim_options.append({"label": dim, "value": dim}) - existing_dims.append(dim) - - # Store dimension options for dropdowns - self.dim_options = dim_options - - # Set default x and y axes based on available dimensions - self.default_x = existing_dims[0] if existing_dims else "PCA1" - self.default_y = existing_dims[1] if len(existing_dims) > 1 else "PCA2" - - # Process each FOV and its track IDs - all_filtered_features = [] - for fov_name, track_ids in self.fov_tracks.items(): - if track_ids == "all": - fov_tracks = ( - self.features_df[self.features_df["fov_name"] == fov_name][ - "track_id" - ] - .unique() - .tolist() - ) - else: - fov_tracks = track_ids - - self.filtered_tracks_by_fov[fov_name] = fov_tracks - - # Filter features for this FOV and its track IDs - fov_features = self.features_df[ - (self.features_df["fov_name"] == fov_name) - & 
(self.features_df["track_id"].isin(fov_tracks)) - ] - all_filtered_features.append(fov_features) - - # Combine all filtered features - self.filtered_features_df = pd.concat(all_filtered_features, axis=0) - - def _create_figure(self): - """Create the initial scatter plot figure""" - self.fig = self._create_track_colored_figure() - - def _init_app(self): - """Initialize the Dash application""" - self.app = dash.Dash(__name__) - - # Add cluster assignment button next to clear selection - cluster_controls = html.Div( - [ - html.Button( - "Assign to New Cluster", - id="assign-cluster", - style={ - "backgroundColor": "#28a745", - "color": "white", - "border": "none", - "padding": "5px 10px", - "borderRadius": "4px", - "cursor": "pointer", - "marginRight": "10px", - }, - ), - html.Button( - "Clear All Clusters", - id="clear-clusters", - style={ - "backgroundColor": "#dc3545", - "color": "white", - "border": "none", - "padding": "5px 10px", - "borderRadius": "4px", - "cursor": "pointer", - "marginRight": "10px", - }, - ), - html.Button( - "Save Clusters to CSV", - id="save-clusters-csv", - style={ - "backgroundColor": "#17a2b8", - "color": "white", - "border": "none", - "padding": "5px 10px", - "borderRadius": "4px", - "cursor": "pointer", - "marginRight": "10px", - }, - ), - html.Button( - "Clear Selection", - id="clear-selection", - style={ - "backgroundColor": "#6c757d", - "color": "white", - "border": "none", - "padding": "5px 10px", - "borderRadius": "4px", - "cursor": "pointer", - }, - ), - ], - style={"marginLeft": "10px", "display": "inline-block"}, - ) - # Create tabs for different views - tabs = dcc.Tabs( - id="view-tabs", - value="timeline-tab", - children=[ - dcc.Tab( - label="Track Timeline", - value="timeline-tab", - children=[ - html.Div( - id="track-timeline", - style={ - "height": "auto", - "overflowY": "auto", - "maxHeight": "80vh", - "padding": "10px", - "marginTop": "10px", - }, - ), - ], - ), - dcc.Tab( - label="Clusters", - value="clusters-tab", - id="clusters-tab", - children=[ - html.Div( - id="cluster-container", - style={ - "padding": "10px", - "marginTop": "10px", - }, - ), - ], - style={"display": "none"}, # Initially hidden - ), - ], - style={"marginTop": "20px"}, - ) - - # Add modal for cluster naming - cluster_name_modal = html.Div( - id="cluster-name-modal", - children=[ - html.Div( - [ - html.H3("Name Your Cluster", style={"marginBottom": "20px"}), - html.Label("Cluster Name:"), - dcc.Input( - id="cluster-name-input", - type="text", - placeholder="Enter cluster name...", - style={"width": "100%", "marginBottom": "20px"}, - ), - html.Div( - [ - html.Button( - "Save", - id="save-cluster-name", - style={ - "backgroundColor": "#28a745", - "color": "white", - "border": "none", - "padding": "8px 16px", - "borderRadius": "4px", - "cursor": "pointer", - "marginRight": "10px", - }, - ), - html.Button( - "Cancel", - id="cancel-cluster-name", - style={ - "backgroundColor": "#6c757d", - "color": "white", - "border": "none", - "padding": "8px 16px", - "borderRadius": "4px", - "cursor": "pointer", - }, - ), - ], - style={"textAlign": "right"}, - ), - ], - style={ - "backgroundColor": "white", - "padding": "30px", - "borderRadius": "8px", - "maxWidth": "400px", - "margin": "auto", - "boxShadow": "0 4px 6px rgba(0, 0, 0, 0.1)", - "border": "1px solid #ddd", - }, - ) - ], - style={ - "display": "none", - "position": "fixed", - "top": "0", - "left": "0", - "width": "100%", - "height": "100%", - "backgroundColor": "rgba(0, 0, 0, 0.5)", - "zIndex": "1000", - "justifyContent": 
"center", - "alignItems": "center", - }, - ) - - # Update layout to use tabs - self.app.layout = html.Div( - style={ - "maxWidth": "95vw", - "margin": "auto", - "padding": "20px", - }, - children=[ - html.H1( - "Track Visualization", - style={"textAlign": "center", "marginBottom": "20px"}, - ), - html.Div( - [ - html.Div( - style={ - "width": "100%", - "display": "inline-block", - "verticalAlign": "top", - }, - children=[ - html.Div( - style={ - "marginBottom": "20px", - "display": "flex", - "alignItems": "center", - "gap": "20px", - "flexWrap": "wrap", - }, - children=[ - html.Div( - [ - html.Label( - "Color by:", - style={"marginRight": "10px"}, - ), - dcc.Dropdown( - id="color-mode", - options=[ - { - "label": "Track ID", - "value": "track", - }, - { - "label": "Time", - "value": "time", - }, - ], - value="track", - style={"width": "200px"}, - ), - ] - ), - html.Div( - [ - dcc.Checklist( - id="show-arrows", - options=[ - { - "label": "Show arrows", - "value": "show", - } - ], - value=[], - style={"marginLeft": "20px"}, - ), - ] - ), - html.Div( - [ - html.Label( - "X-axis:", - style={"marginRight": "10px"}, - ), - dcc.Dropdown( - id="x-axis", - options=self.dim_options, - value=self.default_x, - style={"width": "200px"}, - ), - ] - ), - html.Div( - [ - html.Label( - "Y-axis:", - style={"marginRight": "10px"}, - ), - dcc.Dropdown( - id="y-axis", - options=self.dim_options, - value=self.default_y, - style={"width": "200px"}, - ), - ] - ), - cluster_controls, - ], - ), - ], - ), - ] - ), - dcc.Loading( - id="loading", - children=[ - dcc.Graph( - id="scatter-plot", - figure=self.fig, - config={ - "displayModeBar": True, - "editable": False, - "showEditInChartStudio": False, - "modeBarButtonsToRemove": [ - "select2d", - "resetScale2d", - ], - "edits": { - "annotationPosition": False, - "annotationTail": False, - "annotationText": False, - "shapePosition": True, - }, - "scrollZoom": True, - }, - style={"height": "50vh"}, - ), - ], - type="default", - ), - tabs, - cluster_name_modal, - ], - ) - - @self.app.callback( - [ - dd.Output("scatter-plot", "figure", allow_duplicate=True), - dd.Output("scatter-plot", "selectedData", allow_duplicate=True), - ], - [ - dd.Input("color-mode", "value"), - dd.Input("show-arrows", "value"), - dd.Input("x-axis", "value"), - dd.Input("y-axis", "value"), - dd.Input("scatter-plot", "relayoutData"), - dd.Input("scatter-plot", "selectedData"), - ], - [dd.State("scatter-plot", "figure")], - prevent_initial_call=True, - ) - def update_figure( - color_mode, - show_arrows, - x_axis, - y_axis, - relayout_data, - selected_data, - current_figure, - ): - show_arrows = len(show_arrows or []) > 0 - - ctx = dash.callback_context - if not ctx.triggered: - triggered_id = "No clicks yet" - else: - triggered_id = ctx.triggered[0]["prop_id"].split(".")[0] - - # Create new figure when necessary - if triggered_id in [ - "color-mode", - "show-arrows", - "x-axis", - "y-axis", - ]: - if color_mode == "track": - fig = self._create_track_colored_figure(show_arrows, x_axis, y_axis) - else: - fig = self._create_time_colored_figure(show_arrows, x_axis, y_axis) - - # Update dragmode and selection settings - fig.update_layout( - dragmode="lasso", - clickmode="event+select", - uirevision="true", - selectdirection="any", - ) - else: - fig = dash.no_update - - return fig, selected_data - - @self.app.callback( - dd.Output("track-timeline", "children"), - [dd.Input("scatter-plot", "clickData")], - prevent_initial_call=True, - ) - def update_track_timeline(clickData): - """Update the track timeline 
based on the clicked point""" - if clickData is None: - return html.Div("Click on a point to see the track timeline") - - # Parse the hover text to get track_id, time and fov_name - hover_text = clickData["points"][0]["text"] - track_id = int(hover_text.split("
")[0].split(": ")[1]) - clicked_time = int(hover_text.split("
")[1].split(": ")[1]) - fov_name = hover_text.split("
")[2].split(": ")[1] - - # Get all timepoints for this track - track_data = self.features_df[ - (self.features_df["fov_name"] == fov_name) - & (self.features_df["track_id"] == track_id) - ].sort_values("t") - - if track_data.empty: - return html.Div(f"No data found for track {track_id}") - - # Get unique timepoints - timepoints = track_data["t"].unique() - - # Create a list to store all timepoint columns - timepoint_columns = [] - - # First create the time labels row - time_labels = [] - for t in timepoints: - is_clicked = t == clicked_time - time_style = { - "width": "150px", - "textAlign": "center", - "padding": "5px", - "fontWeight": "bold" if is_clicked else "normal", - "color": "#007bff" if is_clicked else "black", - } - time_labels.append(html.Div(f"t={t}", style=time_style)) - - timepoint_columns.append( - html.Div( - time_labels, - style={ - "display": "flex", - "flexDirection": "row", - "minWidth": "fit-content", - "borderBottom": "2px solid #ddd", - "marginBottom": "10px", - "paddingBottom": "5px", - }, - ) - ) - - # Then create image rows for each channel - for channel in self.channels_to_display: - channel_images = [] - for t in timepoints: - cache_key = (fov_name, track_id, t) - if ( - cache_key in self.image_cache - and channel in self.image_cache[cache_key] - ): - is_clicked = t == clicked_time - image_style = { - "width": "150px", - "height": "150px", - "border": ( - "3px solid #007bff" if is_clicked else "1px solid #ddd" - ), - "borderRadius": "4px", - } - channel_images.append( - html.Div( - html.Img( - src=self.image_cache[cache_key][channel], - style=image_style, - ), - style={ - "width": "150px", - "padding": "5px", - }, - ) - ) - - if channel_images: - # Add channel label - timepoint_columns.append( - html.Div( - [ - html.Div( - channel, - style={ - "width": "100px", - "fontWeight": "bold", - "fontSize": "14px", - "padding": "5px", - "backgroundColor": "#f8f9fa", - "borderRadius": "4px", - "marginBottom": "5px", - "textAlign": "center", - }, - ), - html.Div( - channel_images, - style={ - "display": "flex", - "flexDirection": "row", - "minWidth": "fit-content", - "marginBottom": "15px", - }, - ), - ] - ) - ) - - # Create the main container with synchronized scrolling - return html.Div( - [ - html.H4( - f"Track {track_id} (FOV: {fov_name})", - style={ - "marginBottom": "20px", - "fontSize": "20px", - "fontWeight": "bold", - "color": "#2c3e50", - }, - ), - html.Div( - timepoint_columns, - style={ - "overflowX": "auto", - "overflowY": "hidden", - "whiteSpace": "nowrap", - "backgroundColor": "white", - "padding": "20px", - "borderRadius": "8px", - "boxShadow": "0 2px 4px rgba(0,0,0,0.1)", - "marginBottom": "20px", - }, - ), - ] - ) - - # Add callback to show/hide clusters tab and handle modal - @self.app.callback( - [ - dd.Output("clusters-tab", "style"), - dd.Output("cluster-container", "children"), - dd.Output("view-tabs", "value"), - dd.Output("scatter-plot", "figure", allow_duplicate=True), - dd.Output("cluster-name-modal", "style"), - dd.Output("cluster-name-input", "value"), - dd.Output("scatter-plot", "selectedData", allow_duplicate=True), - ], - [ - dd.Input("assign-cluster", "n_clicks"), - dd.Input("clear-clusters", "n_clicks"), - dd.Input("save-cluster-name", "n_clicks"), - dd.Input("cancel-cluster-name", "n_clicks"), - dd.Input({"type": "edit-cluster-name", "index": dash.ALL}, "n_clicks"), - ], - [ - dd.State("scatter-plot", "selectedData"), - dd.State("scatter-plot", "figure"), - dd.State("color-mode", "value"), - dd.State("show-arrows", "value"), - 
dd.State("x-axis", "value"), - dd.State("y-axis", "value"), - dd.State("cluster-name-input", "value"), - ], - prevent_initial_call=True, - ) - def update_clusters_tab( - assign_clicks, - clear_clicks, - save_name_clicks, - cancel_name_clicks, - edit_name_clicks, - selected_data, - current_figure, - color_mode, - show_arrows, - x_axis, - y_axis, - cluster_name, - ): - ctx = dash.callback_context - if not ctx.triggered: - return ( - dash.no_update, - dash.no_update, - dash.no_update, - dash.no_update, - dash.no_update, - dash.no_update, - dash.no_update, - ) - - button_id = ctx.triggered[0]["prop_id"].split(".")[0] - - # Handle edit cluster name button clicks - if button_id.startswith('{"type":"edit-cluster-name"'): - try: - id_dict = json.loads(button_id) - cluster_idx = id_dict["index"] - - # Get current cluster name - current_name = self.cluster_names.get( - cluster_idx, f"Cluster {cluster_idx + 1}" - ) - - # Show modal - modal_style = { - "display": "flex", - "position": "fixed", - "top": "0", - "left": "0", - "width": "100%", - "height": "100%", - "backgroundColor": "rgba(0, 0, 0, 0.5)", - "zIndex": "1000", - "justifyContent": "center", - "alignItems": "center", - } - - return ( - {"display": "block"}, - self._get_cluster_images(), - "clusters-tab", - dash.no_update, - modal_style, - current_name, - dash.no_update, # Don't change selection - ) - except Exception: - return ( - dash.no_update, - dash.no_update, - dash.no_update, - dash.no_update, - dash.no_update, - dash.no_update, - dash.no_update, - ) - - if ( - button_id == "assign-cluster" - and selected_data - and selected_data.get("points") - ): - # Create new cluster from selected points - new_cluster = [] - for point in selected_data["points"]: - text = point["text"] - lines = text.split("
") - track_id = int(lines[0].split(": ")[1]) - t = int(lines[1].split(": ")[1]) - fov = lines[2].split(": ")[1] - - cache_key = (fov, track_id, t) - if cache_key in self.image_cache: - new_cluster.append( - { - "track_id": track_id, - "t": t, - "fov_name": fov, - } - ) - self.cluster_points.add(cache_key) - - if new_cluster: - # Add cluster to list but don't assign name yet - self.clusters.append(new_cluster) - # Open modal for naming - modal_style = { - "display": "flex", - "position": "fixed", - "top": "0", - "left": "0", - "width": "100%", - "height": "100%", - "backgroundColor": "rgba(0, 0, 0, 0.5)", - "zIndex": "1000", - "justifyContent": "center", - "alignItems": "center", - } - return ( - {"display": "block"}, - self._get_cluster_images(), - "clusters-tab", - dash.no_update, # Don't update figure yet - modal_style, # Show modal - "", # Clear input - None, # Clear selection - ) - - elif button_id == "save-cluster-name" and cluster_name: - # Assign name to the most recently created cluster - if self.clusters: - cluster_id = len(self.clusters) - 1 - self.cluster_names[cluster_id] = cluster_name.strip() - - # Create new figure with updated colors - fig = self._create_track_colored_figure( - len(show_arrows or []) > 0, - x_axis, - y_axis, - ) - # Ensure the dragmode is set based on selection_mode - fig.update_layout( - dragmode="lasso", - clickmode="event+select", - uirevision="true", # Keep the UI state - selectdirection="any", - ) - modal_style = {"display": "none"} - return ( - {"display": "block"}, - self._get_cluster_images(), - "clusters-tab", - fig, - modal_style, # Hide modal - "", # Clear input - None, # Clear selection - ) - - elif button_id == "cancel-cluster-name": - # Remove the cluster that was just created - if self.clusters: - # Remove points from cluster_points set - for point in self.clusters[-1]: - cache_key = (point["fov_name"], point["track_id"], point["t"]) - self.cluster_points.discard(cache_key) - # Remove the cluster - self.clusters.pop() - - # Create new figure with updated colors - fig = self._create_track_colored_figure( - len(show_arrows or []) > 0, - x_axis, - y_axis, - ) - # Ensure the dragmode is set based on selection_mode - fig.update_layout( - dragmode="lasso", - clickmode="event+select", - uirevision="true", # Keep the UI state - selectdirection="any", - ) - modal_style = {"display": "none"} - return ( - ( - {"display": "none"} - if not self.clusters - else {"display": "block"} - ), - self._get_cluster_images() if self.clusters else None, - "timeline-tab" if not self.clusters else "clusters-tab", - fig, - modal_style, # Hide modal - "", # Clear input - None, # Clear selection - ) - - elif button_id == "clear-clusters": - self.clusters = [] - self.cluster_points.clear() - self.cluster_names.clear() - # Restore original coloring - fig = self._create_track_colored_figure( - len(show_arrows or []) > 0, - x_axis, - y_axis, - ) - # Reset UI state completely to ensure clean slate - fig.update_layout( - dragmode="lasso", - clickmode="event+select", - uirevision=None, # Reset UI state completely - selectdirection="any", - ) - modal_style = {"display": "none"} - return ( - {"display": "none"}, - None, - "timeline-tab", - fig, - modal_style, - "", - None, - ) # Clear selection - - return ( - dash.no_update, - dash.no_update, - dash.no_update, - dash.no_update, - dash.no_update, - dash.no_update, - dash.no_update, - ) - - # Add callback for saving clusters to CSV - @self.app.callback( - dd.Output("cluster-container", "children", allow_duplicate=True), - 
[dd.Input("save-clusters-csv", "n_clicks")], - prevent_initial_call=True, - ) - def save_clusters_csv(n_clicks): - """Callback to save clusters to CSV file""" - if n_clicks and self.clusters: - try: - output_path = self.save_clusters_to_csv() - return html.Div( - [ - html.H3("Clusters", style={"marginBottom": "20px"}), - html.Div( - f"✅ Successfully saved {len(self.clusters)} clusters to: {output_path}", - style={ - "backgroundColor": "#d4edda", - "color": "#155724", - "padding": "10px", - "borderRadius": "4px", - "marginBottom": "20px", - "border": "1px solid #c3e6cb", - }, - ), - self._get_cluster_images(), - ] - ) - except Exception as e: - return html.Div( - [ - html.H3("Clusters", style={"marginBottom": "20px"}), - html.Div( - f"❌ Error saving clusters: {str(e)}", - style={ - "backgroundColor": "#f8d7da", - "color": "#721c24", - "padding": "10px", - "borderRadius": "4px", - "marginBottom": "20px", - "border": "1px solid #f5c6cb", - }, - ), - self._get_cluster_images(), - ] - ) - elif n_clicks and not self.clusters: - return html.Div( - [ - html.H3("Clusters", style={"marginBottom": "20px"}), - html.Div( - "⚠️ No clusters to save. Create clusters first by selecting points and clicking 'Assign to New Cluster'.", - style={ - "backgroundColor": "#fff3cd", - "color": "#856404", - "padding": "10px", - "borderRadius": "4px", - "marginBottom": "20px", - "border": "1px solid #ffeaa7", - }, - ), - ] - ) - return dash.no_update - - @self.app.callback( - [ - dd.Output("scatter-plot", "figure", allow_duplicate=True), - dd.Output("scatter-plot", "selectedData", allow_duplicate=True), - ], - [dd.Input("clear-selection", "n_clicks")], - [ - dd.State("color-mode", "value"), - dd.State("show-arrows", "value"), - dd.State("x-axis", "value"), - dd.State("y-axis", "value"), - ], - prevent_initial_call=True, - ) - def clear_selection(n_clicks, color_mode, show_arrows, x_axis, y_axis): - """Callback to clear the selection and restore original opacity""" - if n_clicks: - # Create a new figure with no selections - if color_mode == "track": - fig = self._create_track_colored_figure( - len(show_arrows or []) > 0, - x_axis, - y_axis, - ) - else: - fig = self._create_time_colored_figure( - len(show_arrows or []) > 0, - x_axis, - y_axis, - ) - - # Update layout to maintain lasso mode but clear selections - fig.update_layout( - dragmode="lasso", - clickmode="event+select", - uirevision=None, # Reset UI state - selectdirection="any", - ) - - return fig, None # Return new figure and clear selectedData - return dash.no_update, dash.no_update - - def _calculate_equal_aspect_ranges(self, x_data, y_data): - """Calculate ranges for x and y axes to ensure equal aspect ratio. 
- - Parameters - ---------- - x_data : array-like - Data for x-axis - y_data : array-like - Data for y-axis - - Returns - ------- - tuple - (x_range, y_range) as tuples of (min, max) with equal scaling - """ - # Get data ranges - x_min, x_max = np.min(x_data), np.max(x_data) - y_min, y_max = np.min(y_data), np.max(y_data) - - # Add padding (5% on each side) - x_padding = 0.05 * (x_max - x_min) - y_padding = 0.05 * (y_max - y_min) - - x_min -= x_padding - x_max += x_padding - y_min -= y_padding - y_max += y_padding - - # Ensure equal scaling by using the larger range - x_range = x_max - x_min - y_range = y_max - y_min - - if x_range > y_range: - # Expand y-range to match x-range aspect ratio - y_center = (y_max + y_min) / 2 - y_min = y_center - x_range / 2 - y_max = y_center + x_range / 2 - else: - # Expand x-range to match y-range aspect ratio - x_center = (x_max + x_min) / 2 - x_min = x_center - y_range / 2 - x_max = x_center + y_range / 2 - - return (x_min, x_max), (y_min, y_max) - - def _create_track_colored_figure( - self, - show_arrows=False, - x_axis=None, - y_axis=None, - ): - """Create scatter plot with track-based coloring""" - x_axis = x_axis or self.default_x - y_axis = y_axis or self.default_y - - unique_tracks = self.filtered_features_df["track_id"].unique() - cmap = plt.cm.tab20 - track_colors = { - track_id: f"rgb{tuple(int(x * 255) for x in cmap(i % 20)[:3])}" - for i, track_id in enumerate(unique_tracks) - } - - fig = go.Figure() - - # Set initial layout with lasso mode - fig.update_layout( - dragmode="lasso", - clickmode="event+select", - selectdirection="any", - plot_bgcolor="white", - title="PCA visualization of Selected Tracks", - xaxis_title=x_axis, - yaxis_title=y_axis, - uirevision=True, - hovermode="closest", - showlegend=True, - legend=dict( - yanchor="top", - y=1, - xanchor="left", - x=1.02, - title="Tracks", - bordercolor="Black", - borderwidth=1, - ), - margin=dict(l=50, r=150, t=50, b=50), - autosize=True, - ) - fig.update_xaxes(showgrid=False) - fig.update_yaxes(showgrid=False) - - # Add background points with hover info (excluding the colored tracks) - background_df = self.features_df[ - (self.features_df["fov_name"].isin(self.fov_tracks.keys())) - & (~self.features_df["track_id"].isin(unique_tracks)) - ] - - if not background_df.empty: - # Subsample background points if there are too many - if len(background_df) > 5000: # Adjust this threshold as needed - background_df = background_df.sample(n=5000, random_state=42) - - fig.add_trace( - go.Scattergl( - x=background_df[x_axis], - y=background_df[y_axis], - mode="markers", - marker=dict(size=12, color="lightgray", opacity=0.3), - name=f"Other tracks (showing {len(background_df)} of {len(self.features_df)} points)", - text=[ - f"Track: {track_id}
<br>Time: {t}<br>
FOV: {fov}" - for track_id, t, fov in zip( - background_df["track_id"], - background_df["t"], - background_df["fov_name"], - ) - ], - hoverinfo="text", - showlegend=True, - hoverlabel=dict(namelength=-1), - ) - ) - - # Add points for each selected track - for track_id in unique_tracks: - track_data = self.filtered_features_df[ - self.filtered_features_df["track_id"] == track_id - ].sort_values("t") - - # Get points for this track that are in clusters - track_points = list( - zip( - [fov for fov in track_data["fov_name"]], - [track_id] * len(track_data), - [t for t in track_data["t"]], - ) - ) - - # Determine colors based on cluster membership - colors = [] - opacities = [] - if self.clusters: - cluster_colors = [ - f"rgb{tuple(int(x * 255) for x in plt.cm.Set2(i % 8)[:3])}" - for i in range(len(self.clusters)) - ] - point_to_cluster = {} - for cluster_idx, cluster in enumerate(self.clusters): - for point in cluster: - point_key = (point["fov_name"], point["track_id"], point["t"]) - point_to_cluster[point_key] = cluster_idx - - for point in track_points: - if point in point_to_cluster: - colors.append(cluster_colors[point_to_cluster[point]]) - opacities.append(1.0) - else: - colors.append("lightgray") - opacities.append(0.3) - else: - colors = [track_colors[track_id]] * len(track_data) - opacities = [1.0] * len(track_data) - - # Add points using Scattergl for better performance - scatter_kwargs = { - "x": track_data[x_axis], - "y": track_data[y_axis], - "mode": "markers", - "marker": dict( - size=10, # Reduced size - color=colors, - line=dict(width=1, color="black"), - opacity=opacities, - ), - "name": f"Track {track_id}", - "text": [ - f"Track: {track_id}
<br>Time: {t}<br>
FOV: {fov}" - for t, fov in zip(track_data["t"], track_data["fov_name"]) - ], - "hoverinfo": "text", - "hoverlabel": dict(namelength=-1), # Show full text in hover - } - - # Only apply selection properties if there are clusters - # This prevents opacity conflicts when no clusters exist - if self.clusters: - scatter_kwargs.update( - { - "unselected": dict(marker=dict(opacity=0.3, size=10)), - "selected": dict(marker=dict(size=12, opacity=1.0)), - } - ) - - fig.add_trace(go.Scattergl(**scatter_kwargs)) - - # Add trajectory lines and arrows if requested - if show_arrows and len(track_data) > 1: - x_coords = track_data[x_axis].values - y_coords = track_data[y_axis].values - - # Add dashed lines for the trajectory using Scattergl - fig.add_trace( - go.Scattergl( - x=x_coords, - y=y_coords, - mode="lines", - line=dict( - color=track_colors[track_id], - width=1, - dash="dot", - ), - showlegend=False, - hoverinfo="skip", - ) - ) - - # Add arrows at regular intervals (reduced frequency) - arrow_interval = max( - 1, len(track_data) // 3 - ) # Reduced number of arrows - for i in range(0, len(track_data) - 1, arrow_interval): - # Calculate arrow angle - dx = x_coords[i + 1] - x_coords[i] - dy = y_coords[i + 1] - y_coords[i] - - # Only add arrow if there's significant movement - if dx * dx + dy * dy > 1e-6: # Minimum distance threshold - # Add arrow annotation - fig.add_annotation( - x=x_coords[i + 1], - y=y_coords[i + 1], - ax=x_coords[i], - ay=y_coords[i], - xref="x", - yref="y", - axref="x", - ayref="y", - showarrow=True, - arrowhead=2, - arrowsize=1, # Reduced size - arrowwidth=1, # Reduced width - arrowcolor=track_colors[track_id], - opacity=0.8, - ) - - # Compute axis ranges to ensure equal aspect ratio - all_x_data = self.filtered_features_df[x_axis] - all_y_data = self.filtered_features_df[y_axis] - - if not all_x_data.empty and not all_y_data.empty: - x_range, y_range = self._calculate_equal_aspect_ranges( - all_x_data, all_y_data - ) - - # Set equal aspect ratio and range - fig.update_layout( - xaxis=dict( - range=x_range, scaleanchor="y", scaleratio=1, constrain="domain" - ), - yaxis=dict(range=y_range, constrain="domain"), - ) - - return fig - - def _create_time_colored_figure( - self, - show_arrows=False, - x_axis=None, - y_axis=None, - ): - """Create scatter plot with time-based coloring""" - x_axis = x_axis or self.default_x - y_axis = y_axis or self.default_y - - fig = go.Figure() - - # Set initial layout with lasso mode - fig.update_layout( - dragmode="lasso", - clickmode="event+select", - selectdirection="any", - plot_bgcolor="white", - title="PCA visualization of Selected Tracks", - xaxis_title=x_axis, - yaxis_title=y_axis, - uirevision=True, - hovermode="closest", - showlegend=True, - legend=dict( - yanchor="top", - y=1, - xanchor="left", - x=1.02, - title="Tracks", - bordercolor="Black", - borderwidth=1, - ), - margin=dict(l=50, r=150, t=50, b=50), - autosize=True, - ) - fig.update_xaxes(showgrid=False) - fig.update_yaxes(showgrid=False) - - # Add background points with hover info - all_tracks_df = self.features_df[ - self.features_df["fov_name"].isin(self.fov_tracks.keys()) - ] - - # Subsample background points if there are too many - if len(all_tracks_df) > 5000: # Adjust this threshold as needed - all_tracks_df = all_tracks_df.sample(n=5000, random_state=42) - - fig.add_trace( - go.Scattergl( - x=all_tracks_df[x_axis], - y=all_tracks_df[y_axis], - mode="markers", - marker=dict(size=12, color="lightgray", opacity=0.3), - name=f"Other points (showing {len(all_tracks_df)} of 
{len(self.features_df)} points)", - text=[ - f"Track: {track_id}
Time: {t}
FOV: {fov}" - for track_id, t, fov in zip( - all_tracks_df["track_id"], - all_tracks_df["t"], - all_tracks_df["fov_name"], - ) - ], - hoverinfo="text", - hoverlabel=dict(namelength=-1), - ) - ) - - # Add time-colored points using Scattergl - fig.add_trace( - go.Scattergl( - x=self.filtered_features_df[x_axis], - y=self.filtered_features_df[y_axis], - mode="markers", - marker=dict( - size=10, # Reduced size - color=self.filtered_features_df["t"], - colorscale="Viridis", - colorbar=dict(title="Time"), - ), - text=[ - f"Track: {track_id}
<br>Time: {t}<br>
FOV: {fov}" - for track_id, t, fov in zip( - self.filtered_features_df["track_id"], - self.filtered_features_df["t"], - self.filtered_features_df["fov_name"], - ) - ], - hoverinfo="text", - showlegend=False, - hoverlabel=dict(namelength=-1), # Show full text in hover - ) - ) - - # Add arrows if requested, but more efficiently - if show_arrows: - for track_id in self.filtered_features_df["track_id"].unique(): - track_data = self.filtered_features_df[ - self.filtered_features_df["track_id"] == track_id - ].sort_values("t") - - if len(track_data) > 1: - # Calculate distances between consecutive points - x_coords = track_data[x_axis].values - y_coords = track_data[y_axis].values - distances = np.sqrt(np.diff(x_coords) ** 2 + np.diff(y_coords) ** 2) - - # Only show arrows for movements larger than the median distance - threshold = np.median(distances) * 0.5 - - # Add arrows as a single trace - arrow_x = [] - arrow_y = [] - - for i in range(len(track_data) - 1): - if distances[i] > threshold: - arrow_x.extend([x_coords[i], x_coords[i + 1], None]) - arrow_y.extend([y_coords[i], y_coords[i + 1], None]) - - if arrow_x: # Only add if there are arrows to show - fig.add_trace( - go.Scatter( - x=arrow_x, - y=arrow_y, - mode="lines", - line=dict( - color="rgba(128, 128, 128, 0.5)", - width=1, - dash="dot", - ), - showlegend=False, - hoverinfo="skip", - ) - ) - - # Compute axis ranges to ensure equal aspect ratio - all_x_data = self.filtered_features_df[x_axis] - all_y_data = self.filtered_features_df[y_axis] - if not all_x_data.empty and not all_y_data.empty: - x_range, y_range = self._calculate_equal_aspect_ranges( - all_x_data, all_y_data - ) - - # Set equal aspect ratio and range - fig.update_layout( - xaxis=dict( - range=x_range, scaleanchor="y", scaleratio=1, constrain="domain" - ), - yaxis=dict(range=y_range, constrain="domain"), - ) - - return fig - - @staticmethod - def _normalize_image(img_array): - """Normalize a single image array to [0, 255] more efficiently""" - min_val = img_array.min() - max_val = img_array.max() - if min_val == max_val: - return np.zeros_like(img_array, dtype=np.uint8) - # Normalize in one step - return ((img_array - min_val) * 255 / (max_val - min_val)).astype(np.uint8) - - @staticmethod - def _numpy_to_base64(img_array): - """Convert numpy array to base64 string with compression""" - if not isinstance(img_array, np.uint8): - img_array = img_array.astype(np.uint8) - img = Image.fromarray(img_array) - buffered = BytesIO() - # Use JPEG format with quality=85 for better compression - img.save(buffered, format="JPEG", quality=85, optimize=True) - return "data:image/jpeg;base64," + base64.b64encode(buffered.getvalue()).decode( - "utf-8" - ) - - def save_cache(self, cache_path: str | None = None): - """Save the image cache to disk using pickle. - - Parameters - ---------- - cache_path : str | None, optional - Path to save the cache. 
If None, uses self.cache_path, by default None - """ - import pickle - - if cache_path is None: - if self.cache_path is None: - logger.warning("No cache path specified, skipping cache save") - return - cache_path = self.cache_path - else: - cache_path = Path(cache_path) - - # Create parent directory if it doesn't exist - cache_path.parent.mkdir(parents=True, exist_ok=True) - - # Save cache metadata for validation - cache_metadata = { - "data_path": str(self.data_path), - "tracks_path": str(self.tracks_path), - "features_path": str(self.features_path), - "channels": self.channels_to_display, - "z_range": self.z_range, - "yx_patch_size": self.yx_patch_size, - "cache_size": len(self.image_cache), - } - - try: - logger.info(f"Saving image cache to {cache_path}") - with open(cache_path, "wb") as f: - pickle.dump((cache_metadata, self.image_cache), f) - logger.info(f"Successfully saved cache with {len(self.image_cache)} images") - except Exception as e: - logger.error(f"Error saving cache: {e}") - - def load_cache(self, cache_path: str | None = None) -> bool: - """Load the image cache from disk using pickle. - - Parameters - ---------- - cache_path : str | None, optional - Path to load the cache from. If None, uses self.cache_path, by default None - - Returns - ------- - bool - True if cache was successfully loaded, False otherwise - """ - import pickle - - if cache_path is None: - if self.cache_path is None: - logger.warning("No cache path specified, skipping cache load") - return False - cache_path = self.cache_path - else: - cache_path = Path(cache_path) - - if not cache_path.exists(): - logger.warning(f"Cache file {cache_path} does not exist") - return False - - try: - logger.info(f"Loading image cache from {cache_path}") - with open(cache_path, "rb") as f: - cache_metadata, loaded_cache = pickle.load(f) - - # Validate cache metadata - if ( - cache_metadata["data_path"] != str(self.data_path) - or cache_metadata["tracks_path"] != str(self.tracks_path) - or cache_metadata["features_path"] != str(self.features_path) - or cache_metadata["channels"] != self.channels_to_display - or cache_metadata["z_range"] != self.z_range - or cache_metadata["yx_patch_size"] != self.yx_patch_size - ): - logger.warning("Cache metadata mismatch, skipping cache load") - return False - - self.image_cache = loaded_cache - logger.info( - f"Successfully loaded cache with {len(self.image_cache)} images" - ) - return True - except Exception as e: - logger.error(f"Error loading cache: {e}") - return False - - def preload_images(self): - """Preload all images into memory""" - # Try to load from cache first - if self.cache_path and self.load_cache(): - return - - logger.info("Preloading images into cache...") - logger.info(f"FOVs to process: {list(self.filtered_tracks_by_fov.keys())}") - - # Process each FOV and its tracks - for fov_name, track_ids in self.filtered_tracks_by_fov.items(): - if not track_ids: # Skip FOVs with no tracks - logger.info(f"Skipping FOV {fov_name} as it has no tracks") - continue - - logger.info(f"Processing FOV {fov_name} with tracks {track_ids}") - - try: - data_module = TripletDataModule( - data_path=self.data_path, - tracks_path=self.tracks_path, - include_fov_names=[fov_name] * len(track_ids), - include_track_ids=track_ids, - source_channel=self.channels_to_display, - z_range=self.z_range, - initial_yx_patch_size=self.yx_patch_size, - final_yx_patch_size=self.yx_patch_size, - batch_size=1, - num_workers=self.num_loading_workers, - normalizations=None, - predict_cells=True, - ) - 
data_module.setup("predict") - - for batch in data_module.predict_dataloader(): - try: - images = batch["anchor"].numpy() - indices = batch["index"] - track_id = indices["track_id"].tolist() - t = indices["t"].tolist() - - img = np.stack(images) - cache_key = (fov_name, track_id[0], t[0]) - - logger.debug(f"Processing cache key: {cache_key}") - - # Process each channel based on its type - processed_channels = {} - for idx, channel in enumerate(self.channels_to_display): - try: - if channel in ["Phase3D", "DIC", "BF"]: - # For phase contrast, use the middle z-slice - z_idx = (self.z_range[1] - self.z_range[0]) // 2 - processed = self._normalize_image( - img[0, idx, z_idx] - ) - else: - # For fluorescence, use max projection - processed = self._normalize_image( - np.max(img[0, idx], axis=0) - ) - - processed_channels[channel] = self._numpy_to_base64( - processed - ) - logger.debug( - f"Successfully processed channel {channel} for {cache_key}" - ) - except Exception as e: - logger.error( - f"Error processing channel {channel} for {cache_key}: {e}" - ) - continue - - if ( - processed_channels - ): # Only store if at least one channel was processed - self.image_cache[cache_key] = processed_channels - - except Exception as e: - logger.error( - f"Error processing batch for {fov_name}, track {track_id}: {e}" - ) - continue - - except Exception as e: - logger.error(f"Error setting up data module for FOV {fov_name}: {e}") - continue - - logger.info(f"Successfully cached {len(self.image_cache)} images") - # Log some statistics about the cache - cached_fovs = set(key[0] for key in self.image_cache.keys()) - cached_tracks = set((key[0], key[1]) for key in self.image_cache.keys()) - logger.info(f"Cached FOVs: {cached_fovs}") - logger.info(f"Number of unique track-FOV combinations: {len(cached_tracks)}") - - # Save cache if path is specified - if self.cache_path: - self.save_cache() - - def _cleanup_cache(self): - """Clear the image cache when the program exits""" - logging.info("Cleaning up image cache...") - self.image_cache.clear() - - def _get_trajectory_images_lasso(self, x_axis, y_axis, selected_data): - """Get images of points selected by lasso""" - if not selected_data or not selected_data.get("points"): - return html.Div("Use the lasso tool to select points") - - # Dictionary to store points for each lasso selection - lasso_clusters = {} - - # Track which points we've seen to avoid duplicates within clusters - seen_points = set() - - # Process each selected point - for point in selected_data["points"]: - text = point["text"] - lines = text.split("
") - track_id = int(lines[0].split(": ")[1]) - t = int(lines[1].split(": ")[1]) - fov = lines[2].split(": ")[1] - - point_id = (track_id, t, fov) - cache_key = (fov, track_id, t) - - # Skip if we don't have the image in cache - if cache_key not in self.image_cache: - logger.debug(f"Skipping point {point_id} as it's not in the cache") - continue - - # Determine which curve (lasso selection) this point belongs to - curve_number = point.get("curveNumber", 0) - if curve_number not in lasso_clusters: - lasso_clusters[curve_number] = [] - - # Only add if we haven't seen this point in this cluster - cluster_point_id = (curve_number, point_id) - if cluster_point_id not in seen_points: - seen_points.add(cluster_point_id) - lasso_clusters[curve_number].append( - { - "track_id": track_id, - "t": t, - "fov_name": fov, - x_axis: point["x"], - y_axis: point["y"], - } - ) - - if not lasso_clusters: - return html.Div("No cached images found for the selected points") - - # Create sections for each lasso selection - cluster_sections = [] - for cluster_idx, points in lasso_clusters.items(): - cluster_df = pd.DataFrame(points) - - # Create channel rows for this cluster - channel_rows = [] - for channel in self.channels_to_display: - images = [] - for _, row in cluster_df.iterrows(): - cache_key = (row["fov_name"], row["track_id"], row["t"]) - images.append( - html.Div( - [ - html.Img( - src=self.image_cache[cache_key][channel], - style={ - "width": "150px", - "height": "150px", - "margin": "5px", - "border": "1px solid #ddd", - }, - ), - html.Div( - f"Track {row['track_id']}, t={row['t']}", - style={ - "textAlign": "center", - "fontSize": "12px", - }, - ), - ], - style={ - "display": "inline-block", - "margin": "5px", - "verticalAlign": "top", - }, - ) - ) - - if images: # Only add row if there are images - channel_rows.extend( - [ - html.H5( - f"{channel}", - style={ - "margin": "10px 5px", - "fontSize": "16px", - "fontWeight": "bold", - }, - ), - html.Div( - images, - style={ - "overflowX": "auto", - "whiteSpace": "nowrap", - "padding": "10px", - "border": "1px solid #ddd", - "borderRadius": "5px", - "marginBottom": "20px", - "backgroundColor": "#f8f9fa", - }, - ), - ] - ) - - if channel_rows: # Only add cluster section if it has images - cluster_sections.append( - html.Div( - [ - html.H3( - f"Lasso Selection {cluster_idx + 1}", - style={ - "marginTop": "30px", - "marginBottom": "15px", - "fontSize": "24px", - "fontWeight": "bold", - "borderBottom": "2px solid #007bff", - "paddingBottom": "5px", - }, - ), - html.Div( - channel_rows, - style={ - "backgroundColor": "#ffffff", - "padding": "15px", - "borderRadius": "8px", - "boxShadow": "0 2px 4px rgba(0,0,0,0.1)", - }, - ), - ] - ) - ) - - return html.Div( - [ - html.H2( - f"Selected Points ({len(cluster_sections)} selections)", - style={ - "marginBottom": "20px", - "fontSize": "28px", - "fontWeight": "bold", - "color": "#2c3e50", - }, - ), - html.Div(cluster_sections), - ] - ) - - def _get_output_info_display(self) -> html.Div: - """ - Create a display component showing the output directory information. 
- - Returns - ------- - html.Div - HTML component displaying output directory info - """ - return html.Div( - [ - html.H4( - "Output Directory", - style={"marginBottom": "10px", "fontSize": "16px"}, - ), - html.Div( - [ - html.Span("📁 ", style={"fontSize": "14px"}), - html.Span( - str(self.output_dir), - style={ - "fontFamily": "monospace", - "backgroundColor": "#f8f9fa", - "padding": "4px 8px", - "borderRadius": "4px", - "border": "1px solid #dee2e6", - "fontSize": "12px", - }, - ), - ], - style={"marginBottom": "10px"}, - ), - html.Div( - "CSV files will be saved to this directory with timestamped names.", - style={ - "fontSize": "12px", - "color": "#6c757d", - "fontStyle": "italic", - }, - ), - ], - style={ - "backgroundColor": "#e9ecef", - "padding": "10px", - "borderRadius": "6px", - "marginBottom": "15px", - "border": "1px solid #ced4da", - }, - ) - - def _get_cluster_images(self): - """Display images for all clusters in a grid layout""" - if not self.clusters: - return html.Div( - [self._get_output_info_display(), html.Div("No clusters created yet")] - ) - - # Create cluster colors once - cluster_colors = [ - f"rgb{tuple(int(x * 255) for x in plt.cm.Set2(i % 8)[:3])}" - for i in range(len(self.clusters)) - ] - - # Create individual cluster panels - cluster_panels = [] - for cluster_idx, cluster_points in enumerate(self.clusters): - # Get cluster name or use default - cluster_name = self.cluster_names.get( - cluster_idx, f"Cluster {cluster_idx + 1}" - ) - - # Create a single scrollable container for all channels - all_channel_images = [] - for channel in self.channels_to_display: - images = [] - for point in cluster_points: - cache_key = (point["fov_name"], point["track_id"], point["t"]) - - images.append( - html.Div( - [ - html.Img( - src=self.image_cache[cache_key][channel], - style={ - "width": "100px", - "height": "100px", - "margin": "2px", - "border": f"2px solid {cluster_colors[cluster_idx]}", - "borderRadius": "4px", - }, - ), - html.Div( - f"Track {point['track_id']}, t={point['t']}", - style={ - "textAlign": "center", - "fontSize": "10px", - }, - ), - ], - style={ - "display": "inline-block", - "margin": "2px", - "verticalAlign": "top", - }, - ) - ) - - if images: - all_channel_images.extend( - [ - html.H6( - f"{channel}", - style={ - "margin": "5px", - "fontSize": "12px", - "fontWeight": "bold", - "position": "sticky", - "left": "0", - "backgroundColor": "#f8f9fa", - "zIndex": "1", - "paddingLeft": "5px", - }, - ), - html.Div( - images, - style={ - "whiteSpace": "nowrap", - "marginBottom": "10px", - }, - ), - ] - ) - - if all_channel_images: - # Create a panel for this cluster with synchronized scrolling - cluster_panels.append( - html.Div( - [ - html.Div( - [ - html.Span( - cluster_name, - style={ - "color": cluster_colors[cluster_idx], - "fontWeight": "bold", - "fontSize": "16px", - }, - ), - html.Span( - f" ({len(cluster_points)} points)", - style={ - "color": "#2c3e50", - "fontSize": "14px", - }, - ), - html.Button( - "✏️", - id={ - "type": "edit-cluster-name", - "index": cluster_idx, - }, - style={ - "backgroundColor": "transparent", - "border": "none", - "cursor": "pointer", - "fontSize": "12px", - "marginLeft": "5px", - "color": "#6c757d", - }, - title="Edit cluster name", - ), - ], - style={ - "marginBottom": "10px", - "borderBottom": f"2px solid {cluster_colors[cluster_idx]}", - "paddingBottom": "5px", - "position": "sticky", - "top": "0", - "backgroundColor": "white", - "zIndex": "1", - }, - ), - html.Div( - all_channel_images, - style={ - "overflowX": "auto", - 
"overflowY": "auto", - "height": "400px", - "backgroundColor": "#ffffff", - "padding": "10px", - "borderRadius": "8px", - "boxShadow": "0 2px 4px rgba(0,0,0,0.1)", - }, - ), - ], - style={ - "width": "24%", - "display": "inline-block", - "verticalAlign": "top", - "padding": "5px", - "boxSizing": "border-box", - }, - ) - ) - - # Create rows of 4 panels each - rows = [] - for i in range(0, len(cluster_panels), 4): - row = html.Div( - cluster_panels[i : i + 4], - style={ - "display": "flex", - "justifyContent": "flex-start", - "gap": "10px", - "marginBottom": "10px", - }, - ) - rows.append(row) - - return html.Div( - [ - html.H2( - [ - "Clusters ", - html.Span( - f"({len(self.clusters)} total)", - style={"color": "#666"}, - ), - ], - style={ - "marginBottom": "20px", - "fontSize": "28px", - "fontWeight": "bold", - "color": "#2c3e50", - }, - ), - self._get_output_info_display(), - html.Div( - rows, - style={ - "maxHeight": "calc(100vh - 200px)", - "overflowY": "auto", - "padding": "10px", - }, - ), - ] - ) - - def get_output_dir(self) -> Path: - """ - Get the output directory for saving files. - - Returns - ------- - Path - The output directory path - """ - return self.output_dir - - def save_clusters_to_csv(self, output_path: str | None = None) -> str: - """ - Save cluster information to CSV file. - - This method exports all cluster data including track_id, time, FOV, - cluster assignment, and cluster names to a CSV file for further analysis. - - Parameters - ---------- - output_path : str | None, optional - Path to save the CSV file. If None, generates a timestamped filename - in the output directory, by default None - - Returns - ------- - str - Path to the saved CSV file - - Notes - ----- - The CSV will contain columns: - - cluster_id: The cluster number (1-indexed) - - cluster_name: The custom name assigned to the cluster - - track_id: The track identifier - - time: The timepoint - - fov_name: The field of view name - - cluster_size: Number of points in the cluster - """ - if not self.clusters: - logger.warning("No clusters to save") - return "" - - # Prepare data for CSV export - csv_data = [] - for cluster_idx, cluster in enumerate(self.clusters): - cluster_id = cluster_idx + 1 # 1-indexed for user-friendly output - cluster_size = len(cluster) - cluster_name = self.cluster_names.get(cluster_idx, f"Cluster {cluster_id}") - - for point in cluster: - csv_data.append( - { - "cluster_id": cluster_id, - "cluster_name": cluster_name, - "track_id": point["track_id"], - "time": point["t"], - "fov_name": point["fov_name"], - "cluster_size": cluster_size, - } - ) - - # Create DataFrame and save to CSV - df = pd.DataFrame(csv_data) - - if output_path is None: - # Generate timestamped filename in output directory - from datetime import datetime - - timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - output_path = self.output_dir / f"clusters_{timestamp}.csv" - else: - output_path = Path(output_path) - # If only filename is provided, use output directory - if not output_path.parent.name: - output_path = self.output_dir / output_path.name - - try: - # Create parent directory if it doesn't exist - output_path.parent.mkdir(parents=True, exist_ok=True) - - df.to_csv(output_path, index=False) - logger.info(f"Successfully saved {len(df)} cluster points to {output_path}") - return str(output_path) - - except Exception as e: - logger.error(f"Error saving clusters to CSV: {e}") - raise - - def run(self, debug=False, port=None): - """Run the Dash server - - Parameters - ---------- - debug : bool, optional - 
Whether to run in debug mode, by default False - port : int, optional - Port to run on. If None, will try ports from 8050-8070, by default None - """ - import socket - - def is_port_in_use(port): - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - try: - s.bind(("127.0.0.1", port)) - return False - except socket.error: - return True - - if port is None: - # Try ports from 8050 to 8070 - # FIXME: set a range for the ports - port_range = list(range(8050, 8071)) - for p in port_range: - if not is_port_in_use(p): - port = p - break - if port is None: - raise RuntimeError( - f"Could not find an available port in range {port_range[0]}-{port_range[-1]}" - ) - - try: - logger.info(f"Starting server on port {port}") - self.app.run( - debug=debug, - port=port, - use_reloader=False, # Disable reloader to prevent multiple instances - ) - except KeyboardInterrupt: - logger.info("Server shutdown requested...") - except Exception as e: - logger.error(f"Error running server: {e}") - finally: - self._cleanup_cache() - logger.info("Server shutdown complete") diff --git a/viscy/representation/visualization/__init__.py b/viscy/representation/visualization/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/viscy/representation/visualization/app.py b/viscy/representation/visualization/app.py new file mode 100644 index 000000000..9403317ed --- /dev/null +++ b/viscy/representation/visualization/app.py @@ -0,0 +1,2706 @@ +import atexit +import json +import logging +import time +from pathlib import Path +from typing import Union + +import dash +import dash.dependencies as dd +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import plotly.graph_objects as go +from dash import dcc, html +from sklearn.decomposition import PCA +from sklearn.preprocessing import StandardScaler + +from viscy.representation.embedding_writer import read_embedding_dataset +from viscy.representation.evaluation.combined_analysis import ( + compute_phate_for_combined_datasets, + load_and_combine_features, +) +from viscy.representation.evaluation.data_loading import EmbeddingDataLoader +from viscy.representation.evaluation.dimensionality_reduction import ( + compute_pca, +) +from viscy.representation.visualization.cluster import ( + ClusterManager, +) +from viscy.representation.visualization.settings import VizConfig + +logger = logging.getLogger("lightning.pytorch") + + +class EmbeddingVisualizationApp: + def __init__( + self, + viz_config: VizConfig, + cache_path: str | None = None, + num_loading_workers: int = 16, + output_dir: str | None = None, + loader: EmbeddingDataLoader | None = None, + ) -> None: + """ + Initialize a Dash application for visualizing the DynaCLR embeddings. + + This class provides a visualization tool for visualizing the DynaCLR embeddings into a 2D space (e.g. PCA, UMAP, PHATE). + It allows users to interactively explore and analyze trajectories, visualize clusters, and explore the embedding space. + + Parameters + ---------- + viz_config: VizConfig + Configuration object for visualization. + cache_path: str | None + Path to the cache directory. + num_loading_workers: int + Number of workers to use for loading data. + output_dir: str | None, optional + Directory to save CSV files and other outputs. If None, uses current working directory. + loader: EmbeddingDataLoader | None + Custom data loader for embeddings. If None, uses TripletEmbeddingLoader. + Returns + ------- + None + Initializes the visualization app. 
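`run()` above falls back to probing ports 8050-8070 by attempting to bind each one on localhost and taking the first that succeeds. The same probe, condensed into a standalone helper (a sketch, not part of the diff):

import socket


def find_free_port(start: int = 8050, end: int = 8070) -> int:
    """Return the first port in [start, end] that accepts a localhost bind."""
    for port in range(start, end + 1):
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            try:
                s.bind(("127.0.0.1", port))
                return port  # bind succeeded, so the port is free
            except OSError:  # socket.error is an alias of OSError
                continue
    raise RuntimeError(f"No free port in range {start}-{end}")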
+ """ + self.viz_config = viz_config + self.image_cache = {} + self.cache_path = Path(cache_path) if cache_path else None + self.output_dir = Path(output_dir) if output_dir else Path.cwd() + self.app = None + self.features_df: pd.DataFrame | None = None + self.fig = None + self.num_loading_workers = num_loading_workers + self.loader = loader + + # Initialize cluster storage before preparing data and creating figure + self.cluster_manager = ClusterManager() + + # Store datasets for per-dataset access + self.datasets = viz_config.get_datasets() + self._DEFAULT_MARKER_SIZE = 15 + + # Debouncing for dropdown updates + # TODO: check if this hack works for plotting to be successful + self._last_update_time = 0 + self._debounce_delay = 0.3 # 300ms debounce delay + + # Initialize data + self._prepare_data() + self._create_figure() + self._init_app() + atexit.register(self._cleanup_cache) + + def _prepare_data(self): + """Load and prepare the data for visualization""" + # Extract feature paths and dataset names + feature_paths = [ + Path(config.features_path) for config in self.datasets.values() + ] + dataset_names = list(self.datasets.keys()) + + if self.viz_config.phate_kwargs is not None and len(self.datasets) > 1: + self._prepare_data_with_combined_phate(feature_paths, dataset_names) + else: + self._loading_and_prepare_data(feature_paths, dataset_names) + + def _prepare_data_with_combined_phate(self, feature_paths, dataset_names): + """Prepare data using the new combined PHATE approach""" + # Set up combined PHATE cache path + if self.viz_config.combined_phate_cache_path: + combined_cache_path = Path(self.viz_config.combined_phate_cache_path) + elif self.cache_path: + combined_cache_path = self.cache_path / "combined_phate.zarr" + else: + combined_cache_path = Path.cwd() / "combined_phate.zarr" + + # Check if we should use cached results + use_cache = ( + self.viz_config.use_cached_combined_phate and combined_cache_path.exists() + ) + + if use_cache: + logger.info( + f"Loading cached combined PHATE results from {combined_cache_path}" + ) + try: + combined_dataset = read_embedding_dataset(combined_cache_path) + self.features_df = self.loader.extract_metadata(combined_dataset) + combined_embeddings = combined_dataset["features"].values + logger.info( + f"Loaded cached combined dataset with {len(self.features_df)} samples" + ) + except Exception as e: + logger.warning( + f"Failed to load cached results: {e}. Computing fresh PHATE." 
+ ) + use_cache = False + + if not use_cache: + logger.info("Computing fresh combined PHATE embeddings") + try: + if self.loader is not None: + combined_dataset = compute_phate_for_combined_datasets( + feature_paths=feature_paths, + output_path=combined_cache_path, + dataset_names=dataset_names, + phate_kwargs=self.viz_config.phate_kwargs, + overwrite=True, + loader=self.loader, + ) + else: + combined_dataset = compute_phate_for_combined_datasets( + feature_paths=feature_paths, + output_path=combined_cache_path, + dataset_names=dataset_names, + phate_kwargs=self.viz_config.phate_kwargs, + overwrite=True, + ) + self.features_df = combined_dataset.to_dataframe().reset_index() + combined_embeddings = combined_dataset["features"].values + logger.info( + f"Successfully computed combined PHATE with {len(self.features_df)} samples" + ) + except Exception as e: + logger.error(f"Combined PHATE computation failed: {e}") + # Fall back to traditional approach + self._prepare_data_traditional(feature_paths, dataset_names) + return + + if "dataset_pair" in self.features_df.columns: + self.features_df["dataset"] = self.features_df["dataset_pair"] + + self._compute_pca_on_combined_embeddings(combined_embeddings) + + def _loading_and_prepare_data(self, feature_paths, dataset_names): + """Load and prepare data using modular functions""" + if self.loader is not None: + combined_embeddings, combined_indices = load_and_combine_features( + feature_paths, dataset_names, self.loader + ) + else: + # Let the function use its default TripletEmbeddingLoader + combined_embeddings, combined_indices = load_and_combine_features( + feature_paths, dataset_names + ) + + # Convert to DataFrame and rename dataset_pair to dataset for compatibility + self.features_df = combined_indices.copy() + if "dataset_pair" in self.features_df.columns: + self.features_df["dataset"] = self.features_df["dataset_pair"] + + logger.info(f"Combined embeddings shape: {combined_embeddings.shape}") + + # Compute PCA and PHATE on combined embeddings + self._compute_pca_on_combined_embeddings(combined_embeddings) + self._compute_phate_on_combined_embeddings(combined_embeddings) + + def _compute_pca_on_combined_embeddings(self, combined_embeddings): + """Compute PCA on combined embeddings and set up dimension options""" + # Check if dimensionality reduction columns already exist + existing_dims = [] + dim_options = [] + + # Always recompute PCA on combined embeddings for multi-dataset scenario + if len(self.datasets) > 1 or not any( + col.startswith("PCA") for col in self.features_df.columns + ): + logger.info( + f"Computing PCA with {self.viz_config.num_PC_components} components on combined embeddings" + ) + + # Use the compute_pca function + pca_coords, _ = compute_pca( + combined_embeddings, + n_components=self.viz_config.num_PC_components, + normalize_features=True, + ) + + # We need to get the explained variance separately since compute_pca doesn't return the model + # FIXME: ideally the compute_pca function should return the model + scaler = StandardScaler() + scaled_features = scaler.fit_transform(combined_embeddings) + pca_model = PCA( + n_components=self.viz_config.num_PC_components, random_state=42 + ) + pca_model.fit(scaled_features) + + # Store explained variance for PCA labels + self.pca_explained_variance = [ + f"PC{i + 1} ({var:.1f}%)" + for i, var in enumerate(pca_model.explained_variance_ratio_ * 100) + ] + + # Add PCA coordinates to the features dataframe + for i in range(self.viz_config.num_PC_components): + self.features_df[f"PCA{i 
+ 1}"] = pca_coords[:, i] + + # Add PCA options to dropdown + for i, pc_label in enumerate(self.pca_explained_variance): + dim_options.append({"label": pc_label, "value": f"PCA{i + 1}"}) + existing_dims.append(f"PCA{i + 1}") + + # Check for existing PHATE coordinates (if they exist in the data already) + phate_dims = [ + col for col in self.features_df.columns if col.startswith("PHATE") + ] + if phate_dims: + for dim in phate_dims: + dim_options.append({"label": dim, "value": dim}) + existing_dims.append(dim) + + # Check for existing UMAP coordinates (if they exist in the data) + umap_dims = [col for col in self.features_df.columns if col.startswith("UMAP")] + if umap_dims: + for dim in umap_dims: + dim_options.append({"label": dim, "value": dim}) + existing_dims.append(dim) + + # Store dimension options for dropdowns + self.dim_options = dim_options + + # Set default x and y axes based on available dimensions + # TODO: hardcoding to default to PCA1 and PCA2 + self.default_x = existing_dims[0] if existing_dims else "PCA1" + self.default_y = existing_dims[1] if len(existing_dims) > 1 else "PCA2" + + def _compute_phate_on_combined_embeddings(self, combined_embeddings): + """Compute PHATE on combined embeddings (traditional approach)""" + # Compute PHATE if specified in config + if self.viz_config.phate_kwargs is not None: + logger.info( + f"Computing PHATE with {self.viz_config.phate_kwargs['n_components']} components on combined embeddings" + ) + + try: + from viscy.representation.evaluation.dimensionality_reduction import ( + compute_phate, + ) + + # Use the compute_phate function with configurable parameters + logger.info(f"Using PHATE parameters: {self.viz_config.phate_kwargs}") + + phate_model, phate_coords = compute_phate( + combined_embeddings, **self.viz_config.phate_kwargs + ) + + # Add PHATE coordinates to the features dataframe + for i in range(self.viz_config.phate_kwargs["n_components"]): + self.features_df[f"PHATE{i + 1}"] = phate_coords[:, i] + + logger.info( + f"Successfully computed PHATE with {self.viz_config.phate_kwargs['n_components']} components" + ) + + except ImportError: + logger.warning( + "PHATE is not available. 
Install with: pip install viscy[phate]" + ) + except Exception as e: + logger.warning(f"PHATE computation failed: {str(e)}") + + # Collect all valid (dataset, fov, track) combinations + self.valid_combinations = [] + + for dataset_name, dataset_config in self.datasets.items(): + logger.info(f"Processing dataset: {dataset_name}") + logger.info(f" fov_tracks: {dataset_config.fov_tracks}") + + for fov_name, track_ids in dataset_config.fov_tracks.items(): + if track_ids == "all": + # Get all tracks for this dataset/FOV combination + fov_tracks = self.features_df[ + (self.features_df["dataset"] == dataset_name) + & (self.features_df["fov_name"] == fov_name) + ]["track_id"].unique() + + logger.info( + f" FOV {fov_name}: found {len(fov_tracks)} tracks for 'all'" + ) + + for track_id in fov_tracks: + self.valid_combinations.append( + (dataset_name, fov_name, track_id) + ) + else: + logger.info(f" FOV {fov_name}: using specific tracks {track_ids}") + for track_id in track_ids: + self.valid_combinations.append( + (dataset_name, fov_name, track_id) + ) + + logger.debug(f"Total valid filtered tracks: {len(self.valid_combinations)}") + + # Create a MultiIndex for efficient filtering + if self.valid_combinations: + # Create temporary column for filtering + self.features_df["_temp_combo"] = list( + zip( + self.features_df["dataset"], + self.features_df["fov_name"], + self.features_df["track_id"], + ) + ) + + # Create mask for selected data + selected_mask = self.features_df["_temp_combo"].isin( + self.valid_combinations + ) + + # Apply mask FIRST, then drop the temporary column + filtered_df_with_temp = self.features_df[selected_mask].copy() + background_df_with_temp = self.features_df[~selected_mask].copy() + + # Now drop the temporary column from the filtered dataframes + self.filtered_features_df = filtered_df_with_temp.drop( + "_temp_combo", axis=1 + ) + self.background_features_df = background_df_with_temp.drop( + "_temp_combo", axis=1 + ) + + # Subsample background points if there are too many + if len(self.background_features_df) > 5000: + self.background_features_df = self.background_features_df.sample( + n=5000, random_state=42 + ) + + # Pre-compute track colors + cmap = plt.cm.get_cmap("tab20") + self.track_colors = { + track_key: f"rgb{tuple(int(x * 255) for x in cmap(i % 20)[:3])}" + for i, track_key in enumerate(self.valid_combinations) + } + + # Drop the temporary column from the original dataframe + self.features_df = self.features_df.drop("_temp_combo", axis=1) + else: + self.filtered_features_df = pd.DataFrame() + self.background_features_df = pd.DataFrame() + self.track_colors = {} + + logger.info( + f"Prepared data with {len(self.features_df)} total samples, " + f"{len(self.filtered_features_df)} filtered samples, " + f"and {len(self.background_features_df)} background samples, " + f"with {len(self.valid_combinations)} unique tracks" + ) + + def _create_figure(self): + """Create the initial scatter plot figure""" + self.fig = self._create_track_colored_figure() + + def _init_app(self): + """Initialize the Dash application""" + self.app = dash.Dash(__name__) + + # Add cluster assignment button next to clear selection + cluster_controls = html.Div( + [ + html.Button( + "Assign to New Cluster", + id="assign-cluster", + style={ + "backgroundColor": "#28a745", + "color": "white", + "border": "none", + "padding": "5px 10px", + "borderRadius": "4px", + "cursor": "pointer", + "marginRight": "10px", + }, + ), + html.Button( + "Clear All Clusters", + id="clear-clusters", + style={ + 
"backgroundColor": "#dc3545", + "color": "white", + "border": "none", + "padding": "5px 10px", + "borderRadius": "4px", + "cursor": "pointer", + "marginRight": "10px", + }, + ), + html.Button( + "Save Clusters to CSV", + id="save-clusters-csv", + style={ + "backgroundColor": "#17a2b8", + "color": "white", + "border": "none", + "padding": "5px 10px", + "borderRadius": "4px", + "cursor": "pointer", + "marginRight": "10px", + }, + ), + html.Button( + "Clear Selection", + id="clear-selection", + style={ + "backgroundColor": "#6c757d", + "color": "white", + "border": "none", + "padding": "5px 10px", + "borderRadius": "4px", + "cursor": "pointer", + }, + ), + ], + style={"marginLeft": "10px", "display": "inline-block"}, + ) + # Create tabs for different views + tabs = dcc.Tabs( + id="view-tabs", + value="timeline-tab", + children=[ + dcc.Tab( + label="Track Timeline", + value="timeline-tab", + children=[ + html.Div( + id="track-timeline", + style={ + "height": "auto", + "overflowY": "auto", + "maxHeight": "80vh", + "padding": "10px", + "marginTop": "10px", + }, + ), + ], + ), + dcc.Tab( + label="Clusters", + value="clusters-tab", + id="clusters-tab", + children=[ + html.Div( + id="cluster-container", + style={ + "padding": "10px", + "marginTop": "10px", + }, + ), + ], + style={"display": "none"}, # Initially hidden + ), + ], + style={"marginTop": "20px"}, + ) + + # Add modal for cluster naming + cluster_name_modal = html.Div( + id="cluster-name-modal", + children=[ + html.Div( + [ + html.H3("Name Your Cluster", style={"marginBottom": "20px"}), + html.Label("Cluster Name:"), + dcc.Input( + id="cluster-name-input", + type="text", + placeholder="Enter cluster name...", + style={"width": "100%", "marginBottom": "20px"}, + ), + html.Div( + [ + html.Button( + "Save", + id="save-cluster-name", + style={ + "backgroundColor": "#28a745", + "color": "white", + "border": "none", + "padding": "8px 16px", + "borderRadius": "4px", + "cursor": "pointer", + "marginRight": "10px", + }, + ), + html.Button( + "Cancel", + id="cancel-cluster-name", + style={ + "backgroundColor": "#6c757d", + "color": "white", + "border": "none", + "padding": "8px 16px", + "borderRadius": "4px", + "cursor": "pointer", + }, + ), + ], + style={"textAlign": "right"}, + ), + ], + style={ + "backgroundColor": "white", + "padding": "30px", + "borderRadius": "8px", + "maxWidth": "400px", + "margin": "auto", + "boxShadow": "0 4px 6px rgba(0, 0, 0, 0.1)", + "border": "1px solid #ddd", + }, + ) + ], + style={ + "display": "none", + "position": "fixed", + "top": "0", + "left": "0", + "width": "100%", + "height": "100%", + "backgroundColor": "rgba(0, 0, 0, 0.5)", + "zIndex": "1000", + "justifyContent": "center", + "alignItems": "center", + }, + ) + + # Update layout to use tabs + self.app.layout = html.Div( + style={ + "maxWidth": "95vw", + "margin": "auto", + "padding": "20px", + }, + children=[ + html.H1( + "Track Visualization", + style={"textAlign": "center", "marginBottom": "20px"}, + ), + html.Div( + [ + html.Div( + style={ + "width": "100%", + "display": "inline-block", + "verticalAlign": "top", + }, + children=[ + html.Div( + style={ + "marginBottom": "20px", + "display": "flex", + "alignItems": "center", + "gap": "20px", + "flexWrap": "wrap", + }, + children=[ + html.Div( + [ + html.Label( + "Color by:", + style={"marginRight": "10px"}, + ), + dcc.Dropdown( + id="color-mode", + options=[ + { + "label": "Track ID", + "value": "track", + }, + { + "label": "Time", + "value": "time", + }, + ], + value="track", + style={"width": "200px"}, + 
), + ] + ), + html.Div( + [ + dcc.Checklist( + id="show-arrows", + options=[ + { + "label": "Show arrows", + "value": "show", + } + ], + value=[], + style={"marginLeft": "20px"}, + ), + ] + ), + html.Div( + [ + html.Label( + "X-axis:", + style={"marginRight": "10px"}, + ), + dcc.Dropdown( + id="x-axis", + options=self.dim_options, + value=self.default_x, + style={"width": "200px"}, + ), + ] + ), + html.Div( + [ + html.Label( + "Y-axis:", + style={"marginRight": "10px"}, + ), + dcc.Dropdown( + id="y-axis", + options=self.dim_options, + value=self.default_y, + style={"width": "200px"}, + ), + ] + ), + cluster_controls, + ], + ), + ], + ), + ] + ), + dcc.Loading( + id="loading", + children=[ + dcc.Graph( + id="scatter-plot", + figure=self.fig, + config={ + "displayModeBar": True, + "editable": False, + "showEditInChartStudio": False, + "modeBarButtonsToRemove": [ + "select2d", + "resetScale2d", + ], + "edits": { + "annotationPosition": False, + "annotationTail": False, + "annotationText": False, + "shapePosition": True, + }, + "scrollZoom": True, + }, + style={"height": "50vh"}, + ), + ], + type="default", + ), + tabs, + cluster_name_modal, + html.Div( + id="dummy-output", style={"display": "none"} + ), # Hidden dummy output + html.Div( + id="notification-area", + style={ + "position": "fixed", + "top": "20px", + "right": "20px", + "zIndex": "2000", + "maxWidth": "300px", + }, + ), + ], + ) + + @self.app.callback( + [ + dd.Output("scatter-plot", "figure", allow_duplicate=True), + dd.Output("scatter-plot", "selectedData", allow_duplicate=True), + ], + [ + dd.Input("color-mode", "value"), + dd.Input("show-arrows", "value"), + dd.Input("x-axis", "value"), + dd.Input("y-axis", "value"), + dd.Input("scatter-plot", "relayoutData"), + dd.Input("scatter-plot", "selectedData"), + ], + [dd.State("scatter-plot", "figure")], + prevent_initial_call=True, + # Add debouncing to prevent rapid successive updates + config={"suppress_callback_exceptions": True}, + ) + def update_figure( + color_mode, + show_arrows, + x_axis, + y_axis, + relayout_data, + selected_data, + current_figure, + ): + # Input validation + if not color_mode: + color_mode = "track" + show_arrows = len(show_arrows or []) > 0 + + # Validate axis values exist in available options + valid_axis_values = [opt["value"] for opt in self.dim_options] + if not x_axis or x_axis not in valid_axis_values: + x_axis = self.default_x + if not y_axis or y_axis not in valid_axis_values: + y_axis = self.default_y + + ctx = dash.callback_context + if not ctx.triggered: + triggered_id = "No clicks yet" + else: + triggered_id = ctx.triggered[0]["prop_id"].split(".")[0] + + # Debouncing for axis changes to prevent rapid successive updates + current_time = time.time() + if triggered_id in ["x-axis", "y-axis"]: + if current_time - self._last_update_time < self._debounce_delay: + return dash.no_update, selected_data + self._last_update_time = current_time + + # Always create new figure when control inputs change (remove dependency on callback context) + if triggered_id in [ + "color-mode", + "show-arrows", + "x-axis", + "y-axis", + "No clicks yet" + ]: + if color_mode == "track": + fig = self._create_track_colored_figure(show_arrows, x_axis, y_axis) + else: + fig = self._create_time_colored_figure(show_arrows, x_axis, y_axis) + + # Update dragmode and selection settings + fig.update_layout( + dragmode="lasso", + clickmode="event+select", + uirevision="true", + selectdirection="any", + ) + else: + fig = dash.no_update + + return fig, selected_data + + 
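The `update_figure` callback above debounces x/y-axis changes with a timestamp comparison (`_last_update_time` vs. `_debounce_delay`) so that rapid successive dropdown events do not each trigger a full re-plot. The same idea factored into a reusable helper (a sketch; it drops stale events, matching the callback's `dash.no_update` behavior):

import time


class Debounce:
    """Accept an event only if `delay` seconds passed since the last accepted one."""

    def __init__(self, delay: float = 0.3):
        self.delay = delay
        self._last = 0.0

    def ready(self) -> bool:
        now = time.time()
        if now - self._last < self.delay:
            return False  # too soon: caller should return dash.no_update
        self._last = now
        return True


axis_debounce = Debounce(delay=0.3)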
@self.app.callback( + [ + dd.Output("track-timeline", "children"), + dd.Output("scatter-plot", "figure", allow_duplicate=True), + ], + [dd.Input("scatter-plot", "clickData")], + [ + dd.State("color-mode", "value"), + dd.State("show-arrows", "value"), + dd.State("x-axis", "value"), + dd.State("y-axis", "value"), + ], + prevent_initial_call=True, + ) + def update_track_timeline(clickData, color_mode, show_arrows, x_axis, y_axis): + """Update the track timeline based on the clicked point""" + if clickData is None or self.features_df is None: + return ( + html.Div("Click on a point to see the track timeline"), + dash.no_update, + ) + + # Parse the hover text to get dataset, track_id, time and fov_name + hover_text = clickData["points"][0]["text"] + lines = hover_text.split("
") + dataset_name = lines[0].split(": ")[1] + track_id = int(lines[1].split(": ")[1]) + clicked_time = int(lines[2].split(": ")[1]) + fov_name = lines[3].split(": ")[1] + # Get channels specific to this dataset + channels_to_display = self.datasets[dataset_name].channels_to_display + + # Get all timepoints for this track + track_data = self.features_df[ + (self.features_df["dataset"] == dataset_name) + & (self.features_df["fov_name"] == fov_name) + & (self.features_df["track_id"] == int(track_id)) + ] + + if track_data.empty: + return ( + html.Div( + f"No data found for track {track_id} in dataset {dataset_name}" + ), + dash.no_update, + ) + + # Sort by time + track_data = track_data.sort_values("t") + timepoints = track_data["t"].unique() + + # Create a list to store all timepoint columns + timepoint_columns = [] + + # First create the time labels row + time_labels = [] + for t in timepoints: + is_clicked = t == clicked_time + time_style = { + "width": "150px", + "textAlign": "center", + "padding": "5px", + "fontSize": "20px" if is_clicked else "14px", + "fontWeight": "bold" if is_clicked else "normal", + "color": "#007bff" if is_clicked else "black", + } + time_labels.append(html.Div(f"t={t}", style=time_style)) + + timepoint_columns.append( + html.Div( + time_labels, + style={ + "display": "flex", + "flexDirection": "row", + "minWidth": "fit-content", + "borderBottom": "2px solid #ddd", + "marginBottom": "10px", + "paddingBottom": "5px", + }, + ) + ) + + # Then create image rows for each channel + for channel in channels_to_display: + channel_images = [] + for t in timepoints: + # Use correct 4-tuple cache key format + cache_key = (dataset_name, fov_name, int(track_id), int(t)) + + if ( + cache_key in self.image_cache + and channel in self.image_cache[cache_key] + ): + is_clicked = t == clicked_time + image_style = { + "width": "150px", + "height": "150px", + "border": "1px solid #ddd", + "borderRadius": "4px", + } + channel_images.append( + html.Div( + html.Img( + src=self.image_cache[cache_key][channel], + style=image_style, + ), + style={ + "width": "150px", + "padding": "5px", + }, + ) + ) + + if channel_images: + # Add channel label + timepoint_columns.append( + html.Div( + [ + html.Div( + channel, + style={ + "width": "100px", + "fontWeight": "bold", + "fontSize": "14px", + "padding": "5px", + "backgroundColor": "#f8f9fa", + "borderRadius": "4px", + "marginBottom": "5px", + "textAlign": "center", + }, + ), + html.Div( + channel_images, + style={ + "display": "flex", + "flexDirection": "row", + "minWidth": "fit-content", + "marginBottom": "15px", + }, + ), + ] + ) + ) + + # Create the main container + timeline_content = html.Div( + [ + html.H4( + f"Track {track_id} (FOV: {fov_name})", + style={ + "marginBottom": "20px", + "fontSize": "20px", + "fontWeight": "bold", + "color": "#2c3e50", + }, + ), + html.Div( + timepoint_columns, + style={ + "overflowX": "auto", + "overflowY": "hidden", + "whiteSpace": "nowrap", + "backgroundColor": "white", + "padding": "20px", + "borderRadius": "8px", + "boxShadow": "0 2px 4px rgba(0,0,0,0.1)", + "marginBottom": "20px", + }, + ), + ] + ) + + return timeline_content, dash.no_update + + # Add callback to show/hide clusters tab and handle modal + @self.app.callback( + [ + dd.Output("clusters-tab", "style"), + dd.Output("cluster-container", "children"), + dd.Output("view-tabs", "value"), + dd.Output("scatter-plot", "figure", allow_duplicate=True), + dd.Output("cluster-name-modal", "style"), + dd.Output("cluster-name-input", "value"), + 
dd.Output("scatter-plot", "selectedData", allow_duplicate=True), + ], + [ + dd.Input("assign-cluster", "n_clicks"), + dd.Input("clear-clusters", "n_clicks"), + dd.Input("save-cluster-name", "n_clicks"), + dd.Input("cancel-cluster-name", "n_clicks"), + dd.Input({"type": "edit-cluster-name", "index": dash.ALL}, "n_clicks"), + ], + [ + dd.State("scatter-plot", "selectedData"), + dd.State("scatter-plot", "figure"), + dd.State("color-mode", "value"), + dd.State("show-arrows", "value"), + dd.State("x-axis", "value"), + dd.State("y-axis", "value"), + dd.State("cluster-name-input", "value"), + ], + prevent_initial_call=True, + ) + def update_clusters_tab( + assign_clicks, + clear_clicks, + save_name_clicks, + cancel_name_clicks, + edit_name_clicks, + selected_data, + current_figure, + color_mode, + show_arrows, + x_axis, + y_axis, + cluster_name, + ): + ctx = dash.callback_context + if not ctx.triggered: + return ( + dash.no_update, + dash.no_update, + dash.no_update, + dash.no_update, + dash.no_update, + dash.no_update, + dash.no_update, + ) + + button_id = ctx.triggered[0]["prop_id"].split(".")[0] + + # Handle edit cluster name button clicks + if button_id.startswith('{"type":"edit-cluster-name"'): + try: + id_dict = json.loads(button_id) + cluster_idx = id_dict["index"] + + # Get current cluster name using the manager + cluster = self.cluster_manager.get_cluster_by_index(cluster_idx) + current_name = ( + cluster.name if cluster else f"Cluster {cluster_idx + 1}" + ) + + # Show modal + modal_style = { + "display": "flex", + "position": "fixed", + "top": "0", + "left": "0", + "width": "100%", + "height": "100%", + "backgroundColor": "rgba(0, 0, 0, 0.5)", + "zIndex": "1000", + "justifyContent": "center", + "alignItems": "center", + } + + return ( + {"display": "block"}, + self._get_cluster_images(), + "clusters-tab", + dash.no_update, + modal_style, + current_name, + dash.no_update, + ) + except Exception: + return ( + dash.no_update, + dash.no_update, + dash.no_update, + dash.no_update, + dash.no_update, + dash.no_update, + dash.no_update, + ) + + # Handle clear clusters button + elif button_id == "clear-clusters" and clear_clicks: + self.cluster_manager.clear_all_clusters() + logger.info("Cleared all clusters") + + # Update figure to remove cluster coloring + if color_mode == "track": + fig = self._create_track_colored_figure(show_arrows, x_axis, y_axis) + else: + fig = self._create_time_colored_figure(show_arrows, x_axis, y_axis) + + return ( + {"display": "none"}, # Hide clusters tab + html.Div("No clusters created yet"), + "timeline-tab", # Switch back to timeline tab + fig, + {"display": "none"}, # Hide modal + "", + None, # Clear selection + ) + + # Handle save cluster name button + elif button_id == "save-cluster-name" and save_name_clicks and cluster_name: + # Get the most recent cluster and update its name + if self.cluster_manager.clusters: + latest_cluster = self.cluster_manager.clusters[-1] + latest_cluster.name = cluster_name + logger.info( + f"Named cluster {latest_cluster.id} as '{cluster_name}'" + ) + + # Close modal and update clusters display + modal_style = {"display": "none"} + return ( + {"display": "block"}, + self._get_cluster_images(), + "clusters-tab", + dash.no_update, + modal_style, + "", + dash.no_update, + ) + + # Handle cancel cluster name button + elif button_id == "cancel-cluster-name" and cancel_name_clicks: + # Just close the modal without saving + modal_style = {"display": "none"} + return ( + {"display": "block"}, + self._get_cluster_images(), + "clusters-tab", 
+ dash.no_update, + modal_style, + "", + dash.no_update, + ) + + # Handle assign cluster button + elif button_id == "assign-cluster" and assign_clicks and selected_data: + if not selected_data or not selected_data.get("points"): + logger.warning("No points selected for clustering") + return ( + dash.no_update, + dash.no_update, + dash.no_update, + dash.no_update, + dash.no_update, + dash.no_update, + dash.no_update, + ) + + selected_points_list = [] + + # Determine if we're in cluster mode by checking if clusters exist + is_cluster_mode = len(self.cluster_manager.clusters) > 0 + + # Get information about each selected point + for point in selected_data["points"]: + curve_number = point.get("curveNumber", 0) + point_index = point.get("pointIndex") + + if point_index is None: + continue + + if is_cluster_mode: + # Handle cluster-colored figure + # Structure: [unclustered_trace, cluster1_trace, cluster2_trace, ...] + + if curve_number == 0: + # This is an unclustered point + # Get the DataFrame row from the unclustered points + df_cache_keys = [ + ( + row["dataset"], + row["fov_name"], + row["track_id"], + row["t"], + ) + for _, row in self.filtered_features_df.iterrows() + ] + + clustered_cache_keys = set() + for cluster in self.cluster_manager.clusters: + clustered_cache_keys.update(cluster.cache_keys) + + unclustered_mask = [ + cache_key not in clustered_cache_keys + for cache_key in df_cache_keys + ] + + if any(unclustered_mask): + unclustered_df = self.filtered_features_df[ + unclustered_mask + ] + if point_index < len(unclustered_df): + selected_point = unclustered_df.iloc[point_index] + selected_points_list.append( + { + "track_id": selected_point["track_id"], + "t": selected_point["t"], + "fov_name": selected_point["fov_name"], + "dataset": selected_point["dataset"], + } + ) + else: + # This is a point from an existing cluster + cluster_index = curve_number - 1 + if cluster_index < len(self.cluster_manager.clusters): + cluster = self.cluster_manager.clusters[cluster_index] + + # Get the DataFrame rows for this cluster + df_cache_keys = [ + ( + row["dataset"], + row["fov_name"], + row["track_id"], + row["t"], + ) + for _, row in self.filtered_features_df.iterrows() + ] + + cluster_mask = [ + cache_key in cluster.cache_keys + for cache_key in df_cache_keys + ] + + if any(cluster_mask): + cluster_df = self.filtered_features_df[cluster_mask] + if point_index < len(cluster_df): + selected_point = cluster_df.iloc[point_index] + selected_points_list.append( + { + "track_id": selected_point["track_id"], + "t": selected_point["t"], + "fov_name": selected_point["fov_name"], + "dataset": selected_point["dataset"], + } + ) + else: + # Handle track-colored figure + # Skip background points (curve 0 if background exists) + background_offset = ( + 1 if not self.background_features_df.empty else 0 + ) + + if curve_number < background_offset: + # This is a background point, skip it + continue + + # Find which track this point belongs to + track_curve_index = curve_number - background_offset + + # Map to the actual track + if track_curve_index < len(self.valid_combinations): + dataset_name, fov_name, track_id = self.valid_combinations[ + track_curve_index + ] + + # Get the track data + track_data = self.filtered_features_df[ + (self.filtered_features_df["dataset"] == dataset_name) + & (self.filtered_features_df["fov_name"] == fov_name) + & ( + self.filtered_features_df["track_id"] + == int(track_id) + ) + ].sort_values("t") + + # Get the specific point within this track + if point_index < 
len(track_data): + selected_point = track_data.iloc[point_index] + selected_points_list.append( + { + "track_id": selected_point["track_id"], + "t": selected_point["t"], + "fov_name": selected_point["fov_name"], + "dataset": selected_point["dataset"], + } + ) + + if not selected_points_list: + logger.warning("No valid points found in selection") + return ( + dash.no_update, + dash.no_update, + dash.no_update, + dash.no_update, + dash.no_update, + dash.no_update, + dash.no_update, + ) + + # Create a new cluster using the cluster manager + cluster_id = self.cluster_manager.create_cluster_from_points( + selected_points_list + ) + logger.info( + f"Created cluster {cluster_id} with {len(selected_points_list)} points" + ) + + # Show modal for naming the cluster + modal_style = { + "display": "flex", + "position": "fixed", + "top": "0", + "left": "0", + "width": "100%", + "height": "100%", + "backgroundColor": "rgba(0, 0, 0, 0.5)", + "zIndex": "1000", + "justifyContent": "center", + "alignItems": "center", + } + + # Update figure to show cluster coloring + fig = self._create_cluster_colored_figure(show_arrows, x_axis, y_axis) + + return ( + {"display": "block"}, # Show clusters tab + self._get_cluster_images(), + "clusters-tab", # Switch to clusters tab + fig, # Updated figure with cluster colors + modal_style, # Show naming modal + f"Cluster {len(self.cluster_manager.clusters)}", # Default name + None, # Clear selection + ) + + # Default return (no updates) + return ( + dash.no_update, + dash.no_update, + dash.no_update, + dash.no_update, + dash.no_update, + dash.no_update, + dash.no_update, + ) + + @self.app.callback( + [ + dd.Output("scatter-plot", "figure", allow_duplicate=True), + dd.Output("scatter-plot", "selectedData", allow_duplicate=True), + dd.Output("track-timeline", "children", allow_duplicate=True), + ], + [dd.Input("clear-selection", "n_clicks")], + [ + dd.State("color-mode", "value"), + dd.State("show-arrows", "value"), + dd.State("x-axis", "value"), + dd.State("y-axis", "value"), + ], + prevent_initial_call=True, + ) + def clear_selection(n_clicks, color_mode, show_arrows, x_axis, y_axis): + """Callback to clear the selection and restore original opacity""" + if n_clicks: + # Create a new figure with no selections + if color_mode == "track": + fig = self._create_track_colored_figure( + len(show_arrows or []) > 0, + x_axis, + y_axis, + ) + else: + fig = self._create_time_colored_figure( + len(show_arrows or []) > 0, + x_axis, + y_axis, + ) + + # Update layout to maintain lasso mode but clear selections + fig.update_layout( + dragmode="lasso", + clickmode="event+select", + uirevision=None, # Reset UI state to clear selections + selectdirection="any", + ) + + # Clear the track timeline as well + empty_timeline = html.Div( + "Click on a point to see the track timeline", + style={ + "textAlign": "center", + "color": "#666", + "fontSize": "16px", + "padding": "40px", + "fontStyle": "italic", + }, + ) + + return ( + fig, + None, + empty_timeline, + ) # Clear figure selection, selectedData, and timeline + return dash.no_update, dash.no_update, dash.no_update + + @self.app.callback( + [ + dd.Output("dummy-output", "children"), + dd.Output("notification-area", "children"), + ], + [dd.Input("save-clusters-csv", "n_clicks")], + prevent_initial_call=True, + ) + def save_clusters_to_csv(n_clicks): + """Save all clusters to CSV files""" + notification = "" + + if n_clicks and self.cluster_manager.clusters: + try: + # Create output directory if it doesn't exist + 
self.output_dir.mkdir(parents=True, exist_ok=True) + + saved_files = [] + + # Save each cluster to a separate CSV file + for i, cluster in enumerate(self.cluster_manager.clusters): + cluster_name = cluster.name or f"Cluster_{i + 1}" + # Clean the cluster name for filename + safe_name = "".join( + c + for c in cluster_name + if c.isalnum() or c in (" ", "-", "_") + ).rstrip() + safe_name = safe_name.replace(" ", "_") + + # Create DataFrame from cluster points + cluster_data = [] + for point in cluster.points: + # Get the full row data for this point + point_row = self.filtered_features_df[ + (self.filtered_features_df["dataset"] == point.dataset) + & ( + self.filtered_features_df["fov_name"] + == point.fov_name + ) + & ( + self.filtered_features_df["track_id"] + == point.track_id + ) + & (self.filtered_features_df["t"] == point.t) + ] + + if not point_row.empty: + cluster_data.append(point_row.iloc[0]) + + if cluster_data: + cluster_df = pd.DataFrame(cluster_data) + + # Add cluster information (dataset is already in the dataframe) + cluster_df["cluster_name"] = ( + cluster.name or f"Cluster {i + 1}" + ) + cluster_df["cluster_size"] = len(cluster.points) + + # Reorder columns to put dataset and cluster info at the front + cols = cluster_df.columns.tolist() + priority_cols = [ + "dataset", + "cluster_name", + "cluster_size", + ] + other_cols = [ + col for col in cols if col not in priority_cols + ] + cluster_df = cluster_df[priority_cols + other_cols] + + # Save to CSV + csv_path = self.output_dir / f"{safe_name}.csv" + cluster_df.to_csv(csv_path, index=False) + saved_files.append(csv_path.name) + logger.info(f"Saved cluster '{cluster.name}' to {csv_path}") + + # Also save a summary CSV with all clusters + if self.cluster_manager.clusters: + all_cluster_data = [] + for i, cluster in enumerate(self.cluster_manager.clusters): + for point in cluster.points: + point_row = self.filtered_features_df[ + ( + self.filtered_features_df["dataset"] + == point.dataset + ) + & ( + self.filtered_features_df["fov_name"] + == point.fov_name + ) + & ( + self.filtered_features_df["track_id"] + == point.track_id + ) + & (self.filtered_features_df["t"] == point.t) + ] + + if not point_row.empty: + row_data = point_row.iloc[0].to_dict() + row_data["cluster_name"] = ( + cluster.name or f"Cluster {i + 1}" + ) + row_data["cluster_index"] = i + all_cluster_data.append(row_data) + + if all_cluster_data: + summary_df = pd.DataFrame(all_cluster_data) + + # Reorder columns to put dataset and cluster info at the front + cols = summary_df.columns.tolist() + priority_cols = [ + "dataset", + "cluster_name", + "cluster_index", + ] + other_cols = [ + col for col in cols if col not in priority_cols + ] + summary_df = summary_df[priority_cols + other_cols] + + summary_path = self.output_dir / "all_clusters_summary.csv" + summary_df.to_csv(summary_path, index=False) + saved_files.append(summary_path.name) + logger.info( + f"Saved summary of all clusters to {summary_path}" + ) + + # Log summary statistics + dataset_counts = summary_df["dataset"].value_counts() + logger.info( + f"Exported {len(self.cluster_manager.clusters)} clusters with {len(all_cluster_data)} total points" + ) + logger.info( + f"Points per dataset: {dataset_counts.to_dict()}" + ) + + # Create success notification + notification = html.Div( + [ + html.Div( + [ + html.Strong("✅ Clusters Saved Successfully!"), + html.Br(), + html.Small( + f"Saved {len(saved_files)} files to {self.output_dir}" + ), + html.Br(), + html.Small( + f"Files: {', '.join(saved_files[:3])}" + 
+ ("..." if len(saved_files) > 3 else "") + ), + ], + style={ + "backgroundColor": "#d4edda", + "color": "#155724", + "border": "1px solid #c3e6cb", + "borderRadius": "4px", + "padding": "10px", + "marginBottom": "10px", + "boxShadow": "0 2px 4px rgba(0,0,0,0.1)", + }, + ) + ], + id="success-notification", + ) + + except Exception as e: + logger.error(f"Error saving clusters to CSV: {e}") + # Create error notification + notification = html.Div( + [ + html.Div( + [ + html.Strong("❌ Error Saving Clusters"), + html.Br(), + html.Small(f"Error: {str(e)}"), + ], + style={ + "backgroundColor": "#f8d7da", + "color": "#721c24", + "border": "1px solid #f5c6cb", + "borderRadius": "4px", + "padding": "10px", + "marginBottom": "10px", + "boxShadow": "0 2px 4px rgba(0,0,0,0.1)", + }, + ) + ], + id="error-notification", + ) + + elif n_clicks and not self.cluster_manager.clusters: + # No clusters to save + notification = html.Div( + [ + html.Div( + [ + html.Strong("ℹ️ No Clusters to Save"), + html.Br(), + html.Small("Create some clusters first before saving."), + ], + style={ + "backgroundColor": "#d1ecf1", + "color": "#0c5460", + "border": "1px solid #bee5eb", + "borderRadius": "4px", + "padding": "10px", + "marginBottom": "10px", + "boxShadow": "0 2px 4px rgba(0,0,0,0.1)", + }, + ) + ], + id="info-notification", + ) + + return "", notification + + def _create_track_colored_figure( + self, + show_arrows=False, + x_axis=None, + y_axis=None, + ): + """Create scatter plot with track-based coloring""" + if self.filtered_features_df is None or self.filtered_features_df.empty: + return go.Figure() + + x_axis = x_axis or self.default_x + y_axis = y_axis or self.default_y + + fig = go.Figure() + + # Set initial layout with lasso mode + fig.update_layout( + dragmode="lasso", + clickmode="event+select", + selectdirection="any", + plot_bgcolor="white", + title="Embedding Visualization", + xaxis_title=x_axis, + yaxis_title=y_axis, + uirevision=True, + hovermode="closest", + showlegend=True, + legend=dict( + yanchor="top", + y=1, + xanchor="left", + x=1.02, + title="Tracks", + bordercolor="Black", + borderwidth=1, + ), + margin=dict(l=50, r=150, t=50, b=50), + autosize=True, + ) + fig.update_xaxes(showgrid=False) + fig.update_yaxes(showgrid=False, scaleanchor="x", scaleratio=1) + + # Use pre-computed filtered and background data + filtered_features_df = self.filtered_features_df + + # Add background points using pre-computed background data + # Make them non-interactive and render behind main points + if not self.background_features_df.empty: + fig.add_trace( + go.Scattergl( + x=self.background_features_df[x_axis], + y=self.background_features_df[y_axis], + mode="markers", + marker=dict( + size=12, color="lightgray", opacity=0.3 + ), # Smaller and more transparent + name=f"Other tracks ({len(self.background_features_df)} points)", + text=[ + f"Dataset: {dataset}
<br>Track: {track_id}<br>Time: {t}<br>
FOV: {fov}" + for dataset, track_id, t, fov in zip( + self.background_features_df["dataset"], + self.background_features_df["track_id"], + self.background_features_df["t"], + self.background_features_df["fov_name"], + ) + ], + hoverinfo="text", + showlegend=True, + hoverlabel=dict(namelength=-1), + selectedpoints=False, + unselected=dict(marker=dict(opacity=0.3)), + selected=dict(marker=dict(opacity=0.3)), + ) + ) + + # Use pre-computed unique track keys and colors + # Add points for each selected track with cluster coloring + for dataset_name, fov_name, track_id in self.valid_combinations: + track_data = filtered_features_df[ + (filtered_features_df["dataset"] == dataset_name) + & (filtered_features_df["fov_name"] == fov_name) + & (filtered_features_df["track_id"] == int(track_id)) + ] + + if track_data.empty: + logger.warning( + f"No data found for track {track_id} in dataset {dataset_name} and fov {fov_name}" + ) + continue + + # Sort by time + if hasattr(track_data, "sort_values"): + track_data = track_data.sort_values("t") + + # Add track points + fig.add_trace( + go.Scattergl( + x=track_data[x_axis], + y=track_data[y_axis], + mode="markers", + marker=dict( + size=self._DEFAULT_MARKER_SIZE, + color=self.track_colors[(dataset_name, fov_name, track_id)], + opacity=1.0, + line=dict(width=0.5, color="black"), + ), + name=f"{dataset_name}:{track_id}", + text=[ + f"Dataset: {dataset_name}
+                        for t, fov in zip(track_data["t"], track_data["fov_name"])
+                    ],
+                    hoverinfo="text",
+                    unselected=dict(
+                        marker=dict(opacity=0.6, size=self._DEFAULT_MARKER_SIZE)
+                    ),
+                    selected=dict(
+                        marker=dict(size=self._DEFAULT_MARKER_SIZE * 2.0, opacity=1.0)
+                    ),
+                    hoverlabel=dict(namelength=-1),
+                )
+            )
+
+            # Add trajectory lines and arrows if requested
+            if show_arrows and len(track_data) > 1:
+                x_coords = track_data[x_axis].values
+                y_coords = track_data[y_axis].values
+                track_key = (dataset_name, fov_name, track_id)
+
+                # Add dashed lines for the trajectory
+                fig.add_trace(
+                    go.Scattergl(
+                        x=x_coords,
+                        y=y_coords,
+                        mode="lines",
+                        line=dict(
+                            color=self.track_colors[track_key],
+                            width=3,
+                            dash="dot",
+                        ),
+                        showlegend=False,
+                        hoverinfo="skip",
+                        selectedpoints=False,
+                    )
+                )
+
+                # Add arrows at regular intervals
+                arrow_interval = max(1, len(track_data) // 3)
+                for i in range(0, len(track_data) - 1, arrow_interval):
+                    dx = x_coords[i + 1] - x_coords[i]
+                    dy = y_coords[i + 1] - y_coords[i]
+
+                    # Only add arrow if there's significant movement
+                    if dx * dx + dy * dy > 1e-6:
+                        fig.add_annotation(
+                            x=x_coords[i + 1],
+                            y=y_coords[i + 1],
+                            ax=x_coords[i],
+                            ay=y_coords[i],
+                            xref="x",
+                            yref="y",
+                            axref="x",
+                            ayref="y",
+                            showarrow=True,
+                            arrowhead=2,
+                            arrowsize=1,
+                            arrowwidth=1,
+                            arrowcolor=self.track_colors[track_key],
+                            opacity=0.8,
+                        )
+
+        return fig
+
+    def _create_time_colored_figure(
+        self,
+        show_arrows=False,
+        x_axis=None,
+        y_axis=None,
+    ):
+        """Create scatter plot with time-based coloring"""
+        if self.filtered_features_df is None or self.filtered_features_df.empty:
+            return go.Figure()
+
+        x_axis = x_axis or self.default_x
+        y_axis = y_axis or self.default_y
+
+        fig = go.Figure()
+
+        # Set initial layout with lasso mode
+        fig.update_layout(
+            dragmode="lasso",
+            clickmode="event+select",
+            selectdirection="any",
+            plot_bgcolor="white",
+            title="Embedding Visualization",
+            xaxis_title=x_axis,
+            yaxis_title=y_axis,
+            uirevision=True,
+            hovermode="closest",
+            showlegend=True,
+            legend=dict(
+                yanchor="top",
+                y=1,
+                xanchor="left",
+                x=1.02,
+                bordercolor="Black",
+                borderwidth=1,
+            ),
+            margin=dict(l=50, r=150, t=50, b=50),
+            autosize=True,
+        )
+        fig.update_xaxes(showgrid=False)
+        fig.update_yaxes(showgrid=False, scaleanchor="x", scaleratio=1)
+
+        # Use pre-computed filtered and background data
+        filtered_features_df = self.filtered_features_df
+
+        # Add background points using pre-computed background data
+        # Make them non-interactive and render behind main points
+        if not self.background_features_df.empty:
+            fig.add_trace(
+                go.Scattergl(
+                    x=self.background_features_df[x_axis],
+                    y=self.background_features_df[y_axis],
+                    mode="markers",
+                    marker=dict(
+                        size=12, color="lightgray", opacity=0.3
+                    ),  # Smaller and more transparent
+                    name=f"Other points ({len(self.background_features_df)} points)",
+                    text=[
+                        f"Dataset: {dataset}<br>Track: {track_id}<br>Time: {t}<br>FOV: {fov}"
+                        for dataset, track_id, t, fov in zip(
+                            self.background_features_df["dataset"],
+                            self.background_features_df["track_id"],
+                            self.background_features_df["t"],
+                            self.background_features_df["fov_name"],
+                        )
+                    ],
+                    hoverinfo="text",
+                    hoverlabel=dict(namelength=-1),
+                    selectedpoints=False,
+                    unselected=dict(marker=dict(opacity=0.3)),
+                    selected=dict(marker=dict(opacity=0.3)),
+                )
+            )
+
+        # Add time-colored points
+        if not filtered_features_df.empty:
+            fig.add_trace(
+                go.Scattergl(
+                    x=filtered_features_df[x_axis],
+                    y=filtered_features_df[y_axis],
+                    mode="markers",
+                    marker=dict(
+                        size=self._DEFAULT_MARKER_SIZE,
+                        color=filtered_features_df["t"],
+                        colorscale="Viridis",
+                        colorbar=dict(title="Time"),
+                    ),
+                    text=[
+                        f"Dataset: {dataset}<br>Track: {track_id}<br>Time: {t}<br>FOV: {fov}"
+                        for dataset, track_id, t, fov in zip(
+                            filtered_features_df["dataset"],
+                            filtered_features_df["track_id"],
+                            filtered_features_df["t"],
+                            filtered_features_df["fov_name"],
+                        )
+                    ],
+                    hoverinfo="text",
+                    showlegend=False,
+                    hoverlabel=dict(namelength=-1),
+                )
+            )
+
+        # Add arrows if requested (same as before, but make them non-selectable)
+        if show_arrows and not filtered_features_df.empty:
+            # Keep track_id's native dtype so the equality filter below matches
+            for dataset_name, fov_name, track_id in filtered_features_df.apply(
+                lambda row: (row["dataset"], row["fov_name"], row["track_id"]),
+                axis=1,
+            ).unique():
+                track_data = filtered_features_df[
+                    (filtered_features_df["dataset"] == dataset_name)
+                    & (filtered_features_df["fov_name"] == fov_name)
+                    & (filtered_features_df["track_id"] == track_id)
+                ]
+
+                if len(track_data) <= 1:
+                    continue
+
+                # Sort by time
+                if hasattr(track_data, "sort_values"):
+                    track_data = track_data.sort_values("t")
+                x_coords = track_data[x_axis].values
+                y_coords = track_data[y_axis].values
+                distances = np.sqrt(
+                    np.diff(np.array(x_coords)) ** 2 + np.diff(np.array(y_coords)) ** 2
+                )
+
+                # Only show arrows for movements larger than the median distance
+                threshold = np.median(distances) * 0.5 if len(distances) > 0 else 0
+
+                arrow_x = []
+                arrow_y = []
+
+                for i in range(len(track_data) - 1):
+                    if distances[i] > threshold:
+                        arrow_x.extend([x_coords[i], x_coords[i + 1], None])
+                        arrow_y.extend([y_coords[i], y_coords[i + 1], None])
+
+                if arrow_x:
+                    fig.add_trace(
+                        go.Scatter(
+                            x=arrow_x,
+                            y=arrow_y,
+                            mode="lines",
+                            line=dict(
+                                color="rgba(128, 128, 128, 0.5)",
+                                width=1,
+                                dash="dot",
+                            ),
+                            showlegend=False,
+                            hoverinfo="skip",
+                            selectedpoints=False,
+                        )
+                    )
+
+        return fig
+
+    def _create_cluster_colored_figure(
+        self,
+        show_arrows=False,
+        x_axis=None,
+        y_axis=None,
+    ):
+        """Create scatter plot with cluster-based coloring"""
+        if self.filtered_features_df is None or self.filtered_features_df.empty:
+            return go.Figure()
+
+        x_axis = x_axis or self.default_x
+        y_axis = y_axis or self.default_y
+
+        fig = go.Figure()
+
+        # Set initial layout
+        fig.update_layout(
+            dragmode="lasso",
+            showlegend=True,
+            height=700,
+            xaxis=dict(scaleanchor="y", scaleratio=1),  # Square plot
+            yaxis=dict(scaleanchor="x", scaleratio=1),
+        )
+        fig.update_xaxes(showgrid=False)
+        fig.update_yaxes(showgrid=False)
+
+        # Use cluster manager's color scheme
+        cluster_colors = self.cluster_manager.get_cluster_colors_by_index()
+
+        # Get all clustered cache keys
+        clustered_cache_keys = set()
+        for cluster in self.cluster_manager.clusters:
+            clustered_cache_keys.update(cluster.cache_keys)
+
+        # Create a mask for unclustered points
+        # Map DataFrame rows to cache keys and check if they're in any cluster
+        df_cache_keys = [
+            (row["dataset"], row["fov_name"], row["track_id"], row["t"])
+            for _, row in self.filtered_features_df.iterrows()
+        ]
+
+        unclustered_mask = [
+            cache_key not in clustered_cache_keys for cache_key in df_cache_keys
+        ]
+
+        # Add unclustered points (background) - make them non-interactive
+        if any(unclustered_mask):
+            unclustered_df = self.filtered_features_df[unclustered_mask]
+
+            fig.add_trace(
+                go.Scatter(
+                    x=unclustered_df[x_axis],
+                    y=unclustered_df[y_axis],
+                    mode="markers",
+                    marker=dict(
+                        size=12,
+                        color="lightgray",
+                        opacity=0.6,  # More transparent
+                    ),
+                    name="Unclustered",
+                    hovertemplate="Unclustered<br>"
" + + f"{x_axis}: %{{x}}
" + + f"{y_axis}: %{{y}}
" + + "", + selectedpoints=False, + unselected=dict(marker=dict(opacity=0.3)), + selected=dict(marker=dict(opacity=0.3)), + ) + ) + + # Add each cluster as a separate trace + for i, cluster in enumerate(self.cluster_manager.clusters): + # Create a mask for points in this cluster + cluster_mask = [ + cache_key in cluster.cache_keys for cache_key in df_cache_keys + ] + + if any(cluster_mask): + cluster_df = self.filtered_features_df[cluster_mask] + # Use the color from cluster manager (consistent with tabs) + color = cluster_colors[i] if i < len(cluster_colors) else "gray" + + fig.add_trace( + go.Scatter( + x=cluster_df[x_axis], + y=cluster_df[y_axis], + mode="markers", + marker=dict( + size=self._DEFAULT_MARKER_SIZE, + color=color, + line=dict(width=0.5, color="black"), + opacity=0.8, + ), + name=cluster.name or f"Cluster {i + 1}", + hovertemplate=f"{cluster.name or f'Cluster {i + 1}'}
" + + f"{x_axis}: %{{x}}
" + + f"{y_axis}: %{{y}}
" + + "", + ) + ) + + return fig + + def _cleanup_cache(self): + """Clear the image cache when the program exits""" + logging.info("Cleaning up image cache...") + self.image_cache.clear() + + def _get_cluster_images(self): + """Display images for all clusters in a grid layout""" + if not self.cluster_manager.clusters: + return html.Div("No clusters created yet") + + # Debug information + logger.info(f"Image cache size: {len(self.image_cache)}") + logger.info(f"Number of clusters: {len(self.cluster_manager.clusters)}") + + # Use cluster manager's color scheme (consistent with scatter plot) + cluster_colors = self.cluster_manager.get_cluster_colors_by_index() + + # Collect all unique channels across all datasets in the cluster points + all_channels_in_cluster = set() + for cluster in self.cluster_manager.clusters: + for point in cluster.points: + # Get channels for this specific dataset + if point.dataset in self.datasets: + dataset_channels = self.datasets[point.dataset].channels_to_display + all_channels_in_cluster.update(dataset_channels) + + logger.info(f"All channels found in clusters: {all_channels_in_cluster}") + + # Create individual cluster panels + cluster_panels = [] + for cluster_idx, cluster in enumerate(self.cluster_manager.clusters): + logger.info( + f"Processing cluster {cluster_idx} with {len(cluster.points)} points" + ) + + # Group points by dataset to handle different channels + points_by_dataset = {} + for point in cluster.points: + if point.dataset not in points_by_dataset: + points_by_dataset[point.dataset] = [] + points_by_dataset[point.dataset].append(point) + + # Create images organized by dataset and then by channel + all_channel_images = [] + images_found = 0 + + for dataset_name, dataset_points in points_by_dataset.items(): + if dataset_name not in self.datasets: + continue + + dataset_channels = self.datasets[dataset_name].channels_to_display + + # Add dataset header if there are multiple datasets + if len(points_by_dataset) > 1: + all_channel_images.append( + html.H5( + f"Dataset: {dataset_name}", + style={ + "margin": "10px 5px 5px 5px", + "fontSize": "14px", + "fontWeight": "bold", + "color": "#2c3e50", + "borderBottom": "1px solid #dee2e6", + "paddingBottom": "5px", + }, + ) + ) + + for channel in dataset_channels: + images = [] + for point in dataset_points: + cache_key = ( + point.dataset, + point.fov_name, + point.track_id, + point.t, + ) + + # Debug: Check if cache key exists + if cache_key in self.image_cache: + if channel in self.image_cache[cache_key]: + images_found += 1 + images.append( + html.Div( + [ + html.Img( + src=self.image_cache[cache_key][ + channel + ], + style={ + "width": "100px", + "height": "100px", + "margin": "2px", + "border": f"2px solid {cluster_colors[cluster_idx]}", + "borderRadius": "4px", + }, + ), + html.Div( + f"{dataset_name}:T{point.track_id}:t{point.t}", + style={ + "textAlign": "center", + "fontSize": "9px", + "maxWidth": "100px", + "overflow": "hidden", + "textOverflow": "ellipsis", + }, + ), + ], + style={ + "display": "inline-block", + "margin": "2px", + "verticalAlign": "top", + }, + ) + ) + else: + logger.debug( + f"Channel {channel} not found for cache key {cache_key}" + ) + else: + logger.debug( + f"Cache key {cache_key} not found in image cache" + ) + + if images: + all_channel_images.extend( + [ + html.H6( + f"{channel}", + style={ + "margin": "5px", + "fontSize": "12px", + "fontWeight": "bold", + "position": "sticky", + "left": "0", + "backgroundColor": "#f8f9fa", + "zIndex": "1", + "paddingLeft": "5px", + }, + 
+                                ),
+                                html.Div(
+                                    images,
+                                    style={
+                                        "whiteSpace": "nowrap",
+                                        "marginBottom": "10px",
+                                    },
+                                ),
+                            ]
+                        )
+
+            logger.info(f"Cluster {cluster_idx}: Found {images_found} images")
+
+            if all_channel_images:
+                # Create a panel for this cluster with synchronized scrolling
+                cluster_name = (
+                    cluster.name if cluster.name else f"Cluster {cluster_idx + 1}"
+                )
+                cluster_panels.append(
+                    html.Div(
+                        [
+                            html.Div(
+                                [
+                                    html.Span(
+                                        cluster_name,
+                                        style={
+                                            "color": cluster_colors[cluster_idx],
+                                            "fontWeight": "bold",
+                                            "fontSize": "16px",
+                                        },
+                                    ),
+                                    html.Span(
+                                        f" ({len(cluster.points)} points)",
+                                        style={
+                                            "color": "#2c3e50",
+                                            "fontSize": "14px",
+                                        },
+                                    ),
+                                    html.Button(
+                                        "✏️",
+                                        id={
+                                            "type": "edit-cluster-name",
+                                            "index": cluster.id,
+                                        },
+                                        style={
+                                            "marginLeft": "10px",
+                                            "backgroundColor": "transparent",
+                                            "border": "none",
+                                            "cursor": "pointer",
+                                            "fontSize": "12px",
+                                        },
+                                        title="Edit cluster name",
+                                    ),
+                                ],
+                                style={
+                                    "marginBottom": "10px",
+                                    "borderBottom": f"2px solid {cluster_colors[cluster_idx]}",
+                                    "paddingBottom": "5px",
+                                    "position": "sticky",
+                                    "top": "0",
+                                    "backgroundColor": "white",
+                                    "zIndex": "1",
+                                },
+                            ),
+                            html.Div(
+                                all_channel_images,
+                                style={
+                                    "overflowX": "auto",
+                                    "overflowY": "auto",
+                                    "height": "400px",
+                                    "backgroundColor": "#ffffff",
+                                    "padding": "10px",
+                                    "borderRadius": "8px",
+                                    "boxShadow": "0 2px 4px rgba(0,0,0,0.1)",
+                                },
+                            ),
+                        ],
+                        style={
+                            "width": "24%",
+                            "display": "inline-block",
+                            "verticalAlign": "top",
+                            "padding": "5px",
+                            "boxSizing": "border-box",
+                        },
+                    )
+                )
+            else:
+                # Show a message if no images were found for this cluster
+                cluster_name = (
+                    cluster.name if cluster.name else f"Cluster {cluster_idx + 1}"
+                )
+                cluster_panels.append(
+                    html.Div(
+                        [
+                            html.Div(
+                                [
+                                    html.Span(
+                                        cluster_name,
+                                        style={
+                                            "color": cluster_colors[cluster_idx],
+                                            "fontWeight": "bold",
+                                            "fontSize": "16px",
+                                        },
+                                    ),
+                                    html.Span(
+                                        f" ({len(cluster.points)} points)",
+                                        style={
+                                            "color": "#2c3e50",
+                                            "fontSize": "14px",
+                                        },
+                                    ),
+                                ],
+                                style={
+                                    "marginBottom": "10px",
+                                    "borderBottom": f"2px solid {cluster_colors[cluster_idx]}",
+                                    "paddingBottom": "5px",
+                                },
+                            ),
+                            html.Div(
+                                [
+                                    html.P("No images available for this cluster."),
+                                    html.P("This might be because:"),
+                                    html.Ul(
+                                        [
+                                            html.Li("Images haven't been preloaded"),
+                                            html.Li(
+                                                "Cache keys don't match the cluster points"
+                                            ),
+                                            html.Li(
+                                                "Channels are not configured correctly"
+                                            ),
+                                        ]
+                                    ),
+                                    html.P(
+                                        f"Debug info: {len(cluster.points)} points in cluster"
+                                    ),
+                                ],
+                                style={
+                                    "padding": "20px",
+                                    "backgroundColor": "#f8f9fa",
+                                    "borderRadius": "8px",
+                                    "textAlign": "center",
+                                    "color": "#6c757d",
+                                },
+                            ),
+                        ],
+                        style={
+                            "width": "24%",
+                            "display": "inline-block",
+                            "verticalAlign": "top",
+                            "padding": "5px",
+                            "boxSizing": "border-box",
+                        },
+                    )
+                )
+
+        # If no cluster panels were created, show a debug message
+        if not cluster_panels:
+            return html.Div(
+                [
+                    html.H2("Clusters", style={"marginBottom": "20px"}),
+                    html.Div(
+                        [
+                            html.P("No cluster images could be displayed."),
+                            html.P("Debug information:"),
+                            html.Ul(
+                                [
+                                    html.Li(
+                                        f"Number of clusters: {len(self.cluster_manager.clusters)}"
+                                    ),
+                                    html.Li(
+                                        f"Image cache size: {len(self.image_cache)}"
+                                    ),
+                                    html.Li(
+                                        f"Channels to display: {all_channels_in_cluster}"
+                                    ),
+                                ]
+                            ),
+                        ],
+                        style={
+                            "padding": "20px",
+                            "backgroundColor": "#f8f9fa",
+                            "borderRadius": "8px",
+                            "margin": "20px",
+                        },
+                    ),
+                ]
+            )
+
+        # Create rows of 4 panels each
+        rows = []
+        for i in range(0, len(cluster_panels), 4):
+            row = html.Div(
+                cluster_panels[i : i + 4],
+                style={
+                    "display": "flex",
+                    "justifyContent": "flex-start",
+                    "gap": "10px",
+                    "marginBottom": "10px",
+                },
+            )
+            rows.append(row)
+
+        return html.Div(
+            [
+                html.H2(
+                    [
+                        "Clusters ",
+                        html.Span(
+                            f"({len(self.cluster_manager.clusters)} total)",
+                            style={"color": "#666"},
+                        ),
+                    ],
+                    style={
+                        "marginBottom": "20px",
+                        "fontSize": "28px",
+                        "fontWeight": "bold",
+                        "color": "#2c3e50",
+                    },
+                ),
+                html.Div(
+                    rows,
+                    style={
+                        "maxHeight": "calc(100vh - 200px)",
+                        "overflowY": "auto",
+                        "padding": "10px",
+                    },
+                ),
+            ]
+        )
+
+    @staticmethod
+    def _normalize_image(img_array):
+        """Normalize a single image array to [0, 255] more efficiently"""
+        min_val = img_array.min()
+        max_val = img_array.max()
+        if min_val == max_val:
+            return np.zeros_like(img_array, dtype=np.uint8)
+        # Normalize in one step
+        return ((img_array - min_val) * 255 / (max_val - min_val)).astype(np.uint8)
+
+    @staticmethod
+    def _numpy_to_base64(img_array):
+        """Convert numpy array to base64 string with compression"""
+        import base64
+        from io import BytesIO
+
+        from PIL import Image
+
+        if img_array.dtype != np.uint8:
+            img_array = img_array.astype(np.uint8)
+        img = Image.fromarray(img_array)
+        buffered = BytesIO()
+        # Use JPEG format with quality=85 for better compression
+        img.save(buffered, format="JPEG", quality=85, optimize=True)
+        return "data:image/jpeg;base64," + base64.b64encode(buffered.getvalue()).decode(
+            "utf-8"
+        )
+
+    def preload_images(self):
+        """Preload all images into memory for all datasets"""
+        from viscy.data.triplet import TripletDataModule
+
+        # Try to load from cache first
+        if self.cache_path and self.load_cache():
+            logger.info("Loaded images from cache, skipping preload")
+            return
+
+        logger.info("Cache not found, preloading and caching images...")
+        # Process each dataset
+        for dataset_name, dataset_config in self.datasets.items():
+            logger.info(f"Processing dataset: {dataset_name}")
+
+            if (
+                not hasattr(dataset_config, "fov_tracks")
+                or not dataset_config.fov_tracks
+            ):
+                logger.info(f"Skipping dataset {dataset_name} as it has no FOV tracks")
+                continue
+
+            # Process each FOV and resolve track IDs
+            fov_track_mapping = {}
+            for fov_name, tracks in dataset_config.fov_tracks.items():
+                if isinstance(tracks, list):
+                    fov_track_mapping[fov_name] = tracks
+                elif tracks == "all":
+                    # Get all tracks for this FOV from features
+                    if self.features_df is not None:
+                        fov_tracks_series = self.features_df[
+                            (self.features_df["dataset"] == dataset_name)
+                            & (self.features_df["fov_name"] == fov_name)
+                        ]["track_id"]
+                        fov_track_mapping[fov_name] = (
+                            fov_tracks_series.unique().tolist()
+                        )
+                        logger.info(
+                            f"Resolved 'all' tracks for FOV {fov_name}: {len(fov_track_mapping[fov_name])} tracks"
+                        )
+                    else:
+                        logger.warning(
+                            f"Cannot resolve 'all' tracks for FOV {fov_name}: features_df is None"
+                        )
+                        fov_track_mapping[fov_name] = []
+                else:
+                    logger.warning(
+                        f"Unknown track specification for FOV {fov_name}: {tracks}"
+                    )
+                    fov_track_mapping[fov_name] = []
+
+            logger.debug(f"FOV-track mapping for {dataset_name}: {fov_track_mapping}")
+
+            # Process each FOV and its resolved tracks
+            for fov_name, track_ids in fov_track_mapping.items():
+                if not track_ids:  # Skip FOVs with no tracks
+                    logger.debug(f"Skipping FOV {fov_name} as it has no tracks")
+                    continue
+
+                logger.debug(
+                    f"Processing FOV {fov_name} with {len(track_ids)} tracks: {track_ids}"
+                )
+
+                try:
+                    data_module = TripletDataModule(
+                        data_path=dataset_config.data_path,
+                        tracks_path=dataset_config.tracks_path,
+                        include_fov_names=[fov_name] * len(track_ids),
+                        include_track_ids=track_ids,
+                        source_channel=dataset_config.channels_to_display,
+                        z_range=dataset_config.z_range,
+                        initial_yx_patch_size=dataset_config.yx_patch_size,
+                        final_yx_patch_size=dataset_config.yx_patch_size,
+                        batch_size=1,
+                        num_workers=self.num_loading_workers,
+                        normalizations=[],
+                        predict_cells=True,
+                    )
+                    data_module.setup("predict")
+
+                    for batch in data_module.predict_dataloader():
+                        try:
+                            images = batch["anchor"].numpy()
+                            indices = batch["index"]
+                            track_id = indices["track_id"].item()
+                            t = indices["t"].item()
+
+                            img = np.stack(images)
+
+                            cache_key = (dataset_name, fov_name, track_id, t)
+
+                            logger.debug(f"Processing cache key: {cache_key}")
+
+                            # Process each channel based on its type
+                            processed_channels = {}
+                            for idx, channel in enumerate(
+                                dataset_config.channels_to_display
+                            ):
+                                try:
+                                    if channel in ["Phase3D", "DIC", "BF"]:
+                                        # For phase contrast, use the middle z-slice
+                                        z_idx = (
+                                            dataset_config.z_range[1]
+                                            - dataset_config.z_range[0]
+                                        ) // 2
+                                        processed = self._normalize_image(
+                                            img[0, idx, z_idx]
+                                        )
+                                    else:
+                                        # For fluorescence, use max projection
+                                        processed = self._normalize_image(
+                                            np.max(img[0, idx], axis=0)
+                                        )
+
+                                    processed_channels[channel] = self._numpy_to_base64(
+                                        processed
+                                    )
+                                    logger.debug(
+                                        f"Successfully processed channel {channel} for {cache_key}"
+                                    )
+                                except Exception as e:
+                                    logger.error(
+                                        f"Error processing channel {channel} for {cache_key}: {e}"
+                                    )
+                                    continue
+
+                            if (
+                                processed_channels
+                            ):  # Only store if at least one channel was processed
+                                self.image_cache[cache_key] = processed_channels
+
+                        except Exception as e:
+                            # track_id may not be bound yet if the batch failed early
+                            logger.error(
+                                f"Error processing batch for FOV {fov_name}: {e}"
+                            )
+                            continue
+
+                except Exception as e:
+                    logger.error(
+                        f"Error setting up data module for FOV {fov_name}: {e}"
+                    )
+                    continue
+
+        # Log some statistics about the cache
+        cached_datasets = set(key[0] for key in self.image_cache.keys())
+        cached_fovs = set((key[0], key[1]) for key in self.image_cache.keys())
+        cached_tracks = set((key[0], key[1], key[2]) for key in self.image_cache.keys())
+        logger.info(f"Cached datasets: {cached_datasets}")
+        logger.debug(f"Cached dataset-FOV combinations: {len(cached_fovs)}")
+        logger.debug(
+            f"Number of unique dataset-FOV-track combinations: {len(cached_tracks)}"
+        )
+
+        # Save cache if path is specified
+        if self.cache_path:
+            self.save_cache()
+
+    def save_cache(self, cache_path: Union[str, Path, None] = None):
+        """Save the image cache to disk using pickle.
+
+        Parameters
+        ----------
+        cache_path : Union[str, Path, None], optional
+            Path to save the cache. If None, uses self.cache_path, by default None
+        """
+        import pickle
+
+        if cache_path is None:
+            if self.cache_path is None:
+                logger.warning("No cache path specified, skipping cache save")
+                return
+            cache_path = self.cache_path
+        else:
+            cache_path = Path(cache_path)
+
+        # Create parent directory if it doesn't exist
+        cache_path.parent.mkdir(parents=True, exist_ok=True)
+
+        # Save cache metadata for validation
+        cache_metadata = {
+            "datasets": {
+                name: {
+                    "data_path": str(config.data_path),
+                    "tracks_path": str(config.tracks_path),
+                    "features_path": str(config.features_path),
+                    "channels": config.channels_to_display,
+                    "z_range": config.z_range,
+                    "yx_patch_size": config.yx_patch_size,
+                }
+                for name, config in self.datasets.items()
+            },
+            "cache_size": len(self.image_cache),
+        }
+
+        try:
+            logger.info(f"Saving image cache to {cache_path}")
+            with open(cache_path, "wb") as f:
+                pickle.dump((cache_metadata, self.image_cache), f)
+        except Exception as e:
+            logger.error(f"Error saving cache: {e}")
+
+    def load_cache(self, cache_path: Union[str, Path, None] = None) -> bool:
+        """Load the image cache from disk using pickle.
+
+        Parameters
+        ----------
+        cache_path : Union[str, Path, None], optional
+            Path to load the cache from. If None, uses self.cache_path
+        """
+        import pickle
+
+        if cache_path is None:
+            if self.cache_path is None:
+                logger.warning("No cache path specified, skipping cache load")
+                return False
+            cache_path = self.cache_path
+        else:
+            cache_path = Path(cache_path)
+
+        try:
+            logger.info(f"Loading image cache from {cache_path}")
+            with open(cache_path, "rb") as f:
+                cache_metadata, self.image_cache = pickle.load(f)
+            logger.info(
+                f"Successfully loaded cache with {len(self.image_cache)} images"
+            )
+            return True
+        except Exception as e:
+            logger.error(f"Error loading cache: {e}")
+            return False
+
+    def run(self, debug=False, port=None):
+        """Run the Dash server
+
+        Parameters
+        ----------
+        debug : bool, optional
+            Whether to run in debug mode, by default False
+        port : int, optional
+            Port to run on. If None, will try ports from 8050-8070, by default None
+        """
+        import socket
+
+        def is_port_in_use(port):
+            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+                try:
+                    s.bind(("127.0.0.1", port))
+                    return False
+                except socket.error:
+                    return True
+
+        if port is None:
+            # Try ports from 8050 to 8070
+            port_range = list(range(8050, 8071))
+            for p in port_range:
+                if not is_port_in_use(p):
+                    port = p
+                    break
+            if port is None:
+                raise RuntimeError(
+                    f"Could not find an available port in range {port_range[0]}-{port_range[-1]}"
+                )
+
+        try:
+            logger.info(f"Starting server on port {port}")
+            if self.app is not None:
+                self.app.run(
+                    debug=debug,
+                    port=port,
+                    use_reloader=False,
+                )
+        except KeyboardInterrupt:
+            logger.info("Server shutdown requested...")
+        except Exception as e:
+            logger.error(f"Error running server: {e}")
+        finally:
+            self._cleanup_cache()
+            logger.info("Server shutdown complete")
diff --git a/viscy/representation/visualization/cluster.py b/viscy/representation/visualization/cluster.py
new file mode 100644
index 000000000..39fc0b404
--- /dev/null
+++ b/viscy/representation/visualization/cluster.py
@@ -0,0 +1,262 @@
+import uuid
+from dataclasses import dataclass, field
+from datetime import datetime
+from typing import Dict, List, Optional, Set, Tuple
+
+
+@dataclass
+class ClusterPoint:
+    """Represents a single point in a cluster."""
+
+    track_id: int
+    t: int
+    fov_name: str
+    dataset: str
+    x_coord: Optional[float] = None
+    y_coord: Optional[float] = None
+    z_coord: Optional[float] = None
+
+    @property
+    def cache_key(self) -> Tuple[str, str, int, int]:
+        """Get the cache key for this point."""
+        return (self.dataset, self.fov_name, self.track_id, self.t)
+
+    @property
+    def unique_track_id(self) -> str:
+        """Get a globally unique track identifier."""
+        return f"{self.dataset}_{self.track_id}"
+
+    def __eq__(self, other) -> bool:
+        """Two points are equal if they have the same cache key."""
+        if not isinstance(other, ClusterPoint):
+            return False
+        return self.cache_key == other.cache_key
+
+    def __hash__(self) -> int:
+        """Hash based on cache key for use in sets."""
+        return hash(self.cache_key)
+
+
+@dataclass
+class Cluster:
+    """Represents a cluster of points with metadata."""
+
+    points: List[ClusterPoint] = field(default_factory=list)
+    name: str = ""
+    color: str = ""
+    created_at: datetime = field(default_factory=datetime.now)
+    _id: str = field(default_factory=lambda: str(uuid.uuid4()))
+
+    @property
+    def id(self) -> str:
+        """Unique identifier for this cluster."""
+        return self._id
+
+    @property
+    def size(self) -> int:
+        """Number of points in this cluster."""
+        return len(self.points)
+
+    @property
+    def datasets(self) -> Set[str]:
+        """Get all datasets represented in this cluster."""
+        return {point.dataset for point in self.points}
+
+    @property
+    def cache_keys(self) -> Set[Tuple[str, str, int, int]]:
+        """Get all cache keys for points in this cluster."""
+        return {point.cache_key for point in self.points}
+
+    def add_point(self, point: ClusterPoint) -> None:
+        """Add a point to this cluster."""
+        if point not in self.points:
+            self.points.append(point)
+
+    def remove_point(self, point: ClusterPoint) -> bool:
+        """Remove a point from this cluster. Returns True if point was found and removed."""
+        try:
+            self.points.remove(point)
+            return True
+        except ValueError:
+            return False
+
+    def contains_cache_key(self, cache_key: Tuple[str, str, int, int]) -> bool:
+        """Check if this cluster contains a point with the given cache key."""
+        return cache_key in self.cache_keys
+
+    def get_default_name(self, cluster_number: int) -> str:
+        """Generate a default name for this cluster."""
+        if len(self.datasets) == 1:
+            dataset_name = list(self.datasets)[0]
+            return f"{dataset_name}: Cluster {cluster_number}"
+        else:
+            datasets_str = ", ".join(sorted(self.datasets))
+            return f"[{datasets_str}]: Cluster {cluster_number}"
+
+    def to_dict_list(self) -> List[Dict]:
+        """Convert cluster points to the legacy dictionary format for compatibility."""
+        return [
+            {
+                "track_id": point.track_id,
+                "t": point.t,
+                "fov_name": point.fov_name,
+                "dataset": point.dataset,
+            }
+            for point in self.points
+        ]
+
+
+class ClusterManager:
+    """Manages a collection of clusters."""
+
+    def __init__(self):
+        self._clusters: List[Cluster] = []
+
+    @property
+    def clusters(self) -> List[Cluster]:
+        """Get all clusters."""
+        return self._clusters
+
+    @property
+    def cluster_count(self) -> int:
+        """Get the number of clusters."""
+        return len(self._clusters)
+
+    @property
+    def all_cluster_points(self) -> Set[Tuple[str, str, int, int]]:
+        """Get all cache keys from all clusters."""
+        all_keys = set()
+        for cluster in self._clusters:
+            all_keys.update(cluster.cache_keys)
+        return all_keys
+
+    def add_cluster(self, cluster: Cluster) -> str:
+        """Add a cluster to the manager and return its ID."""
+        self._clusters.append(cluster)
+        return cluster.id
+
+    def create_cluster_from_points(
+        self, points_data: List[Dict], name: str = ""
+    ) -> str:
+        """Create a new cluster from point data and add it to the manager."""
+        cluster = Cluster()
+
+        for point_data in points_data:
+            point = ClusterPoint(
+                track_id=point_data["track_id"],
+                t=point_data["t"],
+                fov_name=point_data["fov_name"],
+                dataset=point_data["dataset"],
+            )
+            cluster.add_point(point)
+
+        if name:
+            cluster.name = name
+        else:
+            # Generate default name
+            cluster.name = cluster.get_default_name(len(self._clusters) + 1)
+
+        return self.add_cluster(cluster)
+
+    def remove_cluster(self, cluster_id: str) -> bool:
+        """Remove a cluster by ID. Returns True if cluster was found and removed."""
+        for i, cluster in enumerate(self._clusters):
+            if cluster.id == cluster_id:
+                del self._clusters[i]
+                return True
+        return False
+
+    def remove_last_cluster(self) -> Optional[Cluster]:
+        """Remove and return the most recently added cluster."""
+        if self._clusters:
+            return self._clusters.pop()
+        return None
+
+    def get_cluster_by_id(self, cluster_id: str) -> Optional[Cluster]:
+        """Get a cluster by its ID."""
+        for cluster in self._clusters:
+            if cluster.id == cluster_id:
+                return cluster
+        return None
+
+    def get_cluster_by_index(self, index: int) -> Optional[Cluster]:
+        """Get a cluster by its index (for backward compatibility)."""
+        if 0 <= index < len(self._clusters):
+            return self._clusters[index]
+        return None
+
+    def clear_all_clusters(self) -> None:
+        """Remove all clusters."""
+        self._clusters.clear()
+
+    def get_cluster_colors(self) -> Dict[str, str]:
+        """Get a mapping of cluster IDs to their colors."""
+        import matplotlib.pyplot as plt
+
+        colors = {}
+        for i, cluster in enumerate(self._clusters):
+            if not cluster.color:
+                # Auto-assign color if not set
+                # plt.get_cmap: matplotlib.cm.get_cmap was removed in matplotlib 3.9
+                cmap = plt.get_cmap("Set2")
+                cluster.color = f"rgb{tuple(int(x * 255) for x in cmap(i % 8)[:3])}"
+            colors[cluster.id] = cluster.color
+        return colors
+
+    def get_cluster_colors_by_index(self) -> List[str]:
+        """Get cluster colors as a list (for backward compatibility)."""
+        import matplotlib.pyplot as plt
+
+        colors = []
+        for i, cluster in enumerate(self._clusters):
+            if not cluster.color:
+                cmap = plt.get_cmap("Set2")
+                cluster.color = f"rgb{tuple(int(x * 255) for x in cmap(i % 8)[:3])}"
+            colors.append(cluster.color)
+        return colors
+
+    def is_point_in_any_cluster(self, cache_key: Tuple[str, str, int, int]) -> bool:
+        """Check if a point is in any cluster."""
+        return any(cluster.contains_cache_key(cache_key) for cluster in self._clusters)
+
+    def get_cluster_containing_point(
+        self, cache_key: Tuple[str, str, int, int]
+    ) -> Optional[Cluster]:
+        """Get the cluster containing a specific point."""
+        for cluster in self._clusters:
+            if cluster.contains_cache_key(cache_key):
+                return cluster
+        return None
+
+    def get_point_to_cluster_mapping(self) -> Dict[Tuple[str, str, int, int], int]:
+        """Get a mapping from cache keys to cluster indices (for backward compatibility)."""
+        point_to_cluster = {}
+        for cluster_idx, cluster in enumerate(self._clusters):
+            for cache_key in cluster.cache_keys:
+                point_to_cluster[cache_key] = cluster_idx
+        return point_to_cluster
+
+    def update_cluster_name(self, cluster_id: str, new_name: str) -> bool:
+        """Update a cluster's name. Returns True if successful."""
+        cluster = self.get_cluster_by_id(cluster_id)
+        if cluster:
+            cluster.name = new_name
+            return True
+        return False
+
+    def update_cluster_name_by_index(self, index: int, new_name: str) -> bool:
+        """Update a cluster's name by index (for backward compatibility)."""
+        cluster = self.get_cluster_by_index(index)
+        if cluster:
+            cluster.name = new_name
+            return True
+        return False
+
+    def get_cluster_names_by_index(self) -> Dict[int, str]:
+        """Get cluster names mapped by index (for backward compatibility)."""
+        return {i: cluster.name for i, cluster in enumerate(self._clusters)}
+
+    def to_legacy_format(self) -> Tuple[List[List[Dict]], Dict[int, str]]:
+        """Convert to the legacy format for backward compatibility."""
+        clusters_data = [cluster.to_dict_list() for cluster in self._clusters]
+        cluster_names = self.get_cluster_names_by_index()
+        return clusters_data, cluster_names
diff --git a/viscy/representation/visualization/settings.py b/viscy/representation/visualization/settings.py
new file mode 100644
index 000000000..d90520103
--- /dev/null
+++ b/viscy/representation/visualization/settings.py
@@ -0,0 +1,84 @@
+import logging
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+from pydantic import BaseModel, Field, field_validator
+
+
+class DatasetConfig(BaseModel):
+    """Configuration for a single dataset."""
+
+    features_path: str
+    data_path: str
+    tracks_path: str
+    channels_to_display: List[str]
+    z_range: Tuple[int, int]
+    yx_patch_size: Tuple[int, int]
+    fov_tracks: Dict[str, Union[List[int], str]] = Field(default_factory=dict)
+
+    @field_validator("features_path", "data_path", "tracks_path")
+    @classmethod
+    def validate_paths(cls, v):
+        if not Path(v).exists():
+            logging.warning(f"Path does not exist: {v}")
+        return v
+
+
+class VizConfig(BaseModel):
+    """Configuration for visualization app."""
+
+    datasets: Dict[str, DatasetConfig] = Field(default_factory=dict)
+
+    num_PC_components: int = Field(default=8, ge=1, le=10)
+    phate_kwargs: Optional[Dict[str, Any]] = Field(
+        default=None,
+        description="PHATE parameters. If None, PHATE will not be computed.",
+    )
+
+    # Combined analysis options
+    use_cached_combined_phate: bool = Field(
+        default=True,
+        description="Use cached combined PHATE results if available",
+    )
+    combined_phate_cache_path: Optional[str] = Field(
+        default=None,
+        description="Path to cache combined PHATE results. If None, uses cache_path/combined_phate.zarr",
+    )
+
+    # File system paths
+    output_dir: Optional[str] = Field(
+        default=None,
+        description="Directory to save CSV files and other outputs. If None, uses current working directory.",
+    )
+    cache_path: Optional[str] = Field(
+        default=None,
+        description="Path to save/load image cache. If None, images will not be cached to disk.",
+    )
+
+    @field_validator("output_dir", "cache_path", "combined_phate_cache_path")
+    @classmethod
+    def validate_optional_paths(cls, v):
+        if v is not None:
+            # Create parent directory if it doesn't exist
+            path = Path(v)
+            if not path.parent.exists():
+                logging.info(f"Creating parent directory for: {v}")
+                path.parent.mkdir(parents=True, exist_ok=True)
+        return v
+
+    def get_datasets(self) -> Dict[str, DatasetConfig]:
+        """Get the datasets configuration."""
+        return self.datasets
+
+    def get_all_fov_tracks(self) -> Dict[str, Dict[str, Union[List[int], str]]]:
+        """Get all FOV tracks from all datasets, keyed by dataset name."""
+        all_fov_tracks = {}
+
+        for dataset_name, dataset_config in self.datasets.items():
+            all_fov_tracks[dataset_name] = {}
+            for fov_name, track_ids in dataset_config.fov_tracks.items():
+                # Keep track IDs as original integers
+                # Uniqueness will be handled by (dataset, track_id) tuple
+                all_fov_tracks[dataset_name][fov_name] = track_ids
+
+        return all_fov_tracks
diff --git a/viscy/utils/cli_utils.py b/viscy/utils/cli_utils.py
index 4223e6784..fdbbf7e9b 100644
--- a/viscy/utils/cli_utils.py
+++ b/viscy/utils/cli_utils.py
@@ -1,9 +1,11 @@
 import collections
 import os
 import re
+from pathlib import Path
 
 import numpy as np
 import torch
+import yaml
 from PIL import Image
 
@@ -117,3 +119,55 @@ def save_figure(data, save_folder, name, title=None, vmax=0, ext=".png"):
     im = Image.fromarray(data).convert("L")
     im.info["size"] = data.shape
     im.save(os.path.join(save_folder, name + ext))
+
+
+def yaml_to_model(yaml_path: Path, model):
+    """
+    Load model settings from a YAML file and create a model instance.
+
+    Borrowed from recOrder==0.4.0.
+
+    Parameters
+    ----------
+    yaml_path : Path
+        The path to the YAML file containing the model settings.
+    model : class
+        The model class used to create an instance with the loaded settings.
+
+    Returns
+    -------
+    object
+        An instance of the model class with the loaded settings.
+
+    Raises
+    ------
+    TypeError
+        If the provided model is not a class or does not have a callable constructor.
+    FileNotFoundError
+        If the YAML file specified by `yaml_path` does not exist.
+
+    Notes
+    -----
+    This function loads model settings from a YAML file using `yaml.safe_load()`.
+    It then creates an instance of the provided `model` class using the loaded settings.
+
+    Examples
+    --------
+    >>> from my_model import MyModel
+    >>> settings = yaml_to_model('model.yaml', MyModel)
+
+    """
+    yaml_path = Path(yaml_path)
+
+    if not callable(getattr(model, "__init__", None)):
+        raise TypeError(
+            "The provided model must be a class with a callable constructor."
+        )
+
+    try:
+        with open(yaml_path, "r") as file:
+            raw_settings = yaml.safe_load(file)
+    except FileNotFoundError:
+        raise FileNotFoundError(f"The YAML file '{yaml_path}' does not exist.")
+
+    return model(**raw_settings)
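
Usage sketch: `yaml_to_model` parses a YAML file and validates it against `VizConfig` (pydantic), whose per-dataset `DatasetConfig` entries drive image loading, while `ClusterManager` does the bookkeeping for lasso-selected points. The following is a minimal, hedged illustration of how these pieces compose; the dataset name `demo`, all paths, and the track IDs are hypothetical placeholders, not values from this patch:

    # Contents of a hypothetical viz_config.yaml:
    #
    #   datasets:
    #     demo:
    #       data_path: "/path/to/registered.zarr"
    #       tracks_path: "/path/to/tracks.zarr"
    #       features_path: "/path/to/embeddings.zarr"
    #       channels_to_display: ["Phase3D", "RFP"]
    #       z_range: [24, 29]
    #       yx_patch_size: [160, 160]
    #       fov_tracks:
    #         "/A/3/9": [1, 2, 3]
    #   output_dir: "/path/to/output"
    #   cache_path: "/path/to/image_cache.pkl"

    from pathlib import Path

    from viscy.representation.visualization.cluster import ClusterManager
    from viscy.representation.visualization.settings import VizConfig
    from viscy.utils.cli_utils import yaml_to_model

    # Parse and validate the YAML; malformed fields fail loudly here.
    config = yaml_to_model(Path("viz_config.yaml"), VizConfig)
    print(config.get_all_fov_tracks())  # -> {"demo": {"/A/3/9": [1, 2, 3]}}

    # The cluster bookkeeping is plain Python and can be exercised standalone.
    manager = ClusterManager()
    cid = manager.create_cluster_from_points(
        [{"track_id": 1, "t": 0, "fov_name": "/A/3/9", "dataset": "demo"}]
    )
    assert manager.get_cluster_by_id(cid).size == 1

Keeping the configuration in a pydantic model means path existence and tuple shapes are checked once, at load time, instead of failing deep inside the Dash callbacks.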