diff --git a/examples/DynaCLR/embedding-web-visualization/interactive_visualizer.py b/examples/DynaCLR/embedding-web-visualization/interactive_visualizer.py
index cce99c9e4..80f2fd056 100644
--- a/examples/DynaCLR/embedding-web-visualization/interactive_visualizer.py
+++ b/examples/DynaCLR/embedding-web-visualization/interactive_visualizer.py
@@ -1,11 +1,12 @@
-"""Interactive visualization of phenotype data."""
-
import logging
from pathlib import Path
+import click
from numpy.random import seed
-from viscy.representation.evaluation.visualization import EmbeddingVisualizationApp
+from viscy.representation.visualization.app import EmbeddingVisualizationApp
+from viscy.representation.visualization.settings import VizConfig
+from viscy.utils.cli_utils import yaml_to_model
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
@@ -13,33 +14,34 @@
seed(42)
-def main():
+@click.command()
+@click.option(
+ "--config-filepath",
+ "-c",
+ required=True,
+ help="Path to YAML configuration file.",
+)
+def main(config_filepath):
"""Main function to run the visualization app."""
- # Config for the visualization app
- # TODO: Update the paths to the downloaded data. By default the data is downloaded to ~/data/dynaclr/demo
- download_root = Path.home() / "data/dynaclr/demo"
- output_path = Path.home() / "data/dynaclr/demo/embedding-web-visualization"
- viz_config = {
- "data_path": download_root / "registered_test.zarr", # TODO add path to data
- "tracks_path": download_root / "track_test.zarr", # TODO add path to tracks
- "features_path": download_root
- / "precomputed_embeddings/infection_160patch_94ckpt_rev6_dynaclr.zarr", # TODO add path to features
- "channels_to_display": ["Phase3D", "RFP"],
- "fov_tracks": {
- "/A/3/9": list(range(50)),
- "/B/4/9": list(range(50)),
- },
- "yx_patch_size": (160, 160),
- "z_range": (24, 29),
- "num_PC_components": 8,
- "output_dir": output_path,
- }
+ # Load and validate configuration from YAML file
+ viz_config = yaml_to_model(yaml_path=config_filepath, model=VizConfig)
+
+    # Use configured paths, falling back to local defaults when not specified
+ output_dir = viz_config.output_dir or str(Path(__file__).parent / "output")
+ cache_path = viz_config.cache_path
+
+ logger.info(f"Using output directory: {output_dir}")
+ logger.info(f"Using cache path: {cache_path}")
# Create and run the visualization app
try:
- # Create and run the visualization app
- app = EmbeddingVisualizationApp(**viz_config)
+ app = EmbeddingVisualizationApp(
+ viz_config=viz_config,
+ cache_path=cache_path,
+ num_loading_workers=16,
+ output_dir=output_dir,
+ )
app.preload_images()
app.run(debug=True)
diff --git a/examples/DynaCLR/embedding-web-visualization/viz_config.yaml b/examples/DynaCLR/embedding-web-visualization/viz_config.yaml
new file mode 100644
index 000000000..ac35b0e98
--- /dev/null
+++ b/examples/DynaCLR/embedding-web-visualization/viz_config.yaml
@@ -0,0 +1,55 @@
+# Multi-dataset visualization configuration
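+# Launch the viewer with this config (illustrative command; run from the
+# examples/DynaCLR/embedding-web-visualization directory):
+#   python interactive_visualizer.py --config-filepath viz_config.yaml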
+datasets:
+ g3bp1_sensor:
+ data_path: "/hpc/projects/intracellular_dashboard/organelle_dynamics/rerun/2025_01_28_A549_G3BP1_ZIKV_DENV/4-phenotyping/train-test/2025_01_28_A549_G3BP1_ZIKV_DENV.zarr"
+ tracks_path: "/hpc/projects/intracellular_dashboard/organelle_dynamics/rerun/2025_01_28_A549_G3BP1_ZIKV_DENV/1-preprocess/label-free/3-track/2025_01_28_A549_G3BP1_ZIKV_DENV_cropped.zarr"
+ features_path: "/hpc/projects/intracellular_dashboard/organelle_dynamics/rerun/2025_01_28_A549_G3BP1_ZIKV_DENV/4-phenotyping/predictions/sensor_160patch_99ckpt_max.zarr"
+ channels_to_display: ["raw mCherry EX561 EM600-37"]
+ z_range: [0, 1]
+ yx_patch_size: [192, 192]
+ fov_tracks:
+ "/B/3/000000": [42, 47, 56, 57, 58, 59, 61]
+ "/B/4/000001": [57, 83, 89, 90, 91, 101]
+
+
+ g3bp1_organelle:
+ data_path: "/hpc/projects/intracellular_dashboard/organelle_dynamics/rerun/2025_01_28_A549_G3BP1_ZIKV_DENV/4-phenotyping/train-test/2025_01_28_A549_G3BP1_ZIKV_DENV.zarr"
+ tracks_path: "/hpc/projects/intracellular_dashboard/organelle_dynamics/rerun/2025_01_28_A549_G3BP1_ZIKV_DENV/1-preprocess/label-free/3-track/2025_01_28_A549_G3BP1_ZIKV_DENV_cropped.zarr"
+ features_path: "/hpc/projects/intracellular_dashboard/organelle_dynamics/rerun/2025_01_28_A549_G3BP1_ZIKV_DENV/4-phenotyping/predictions/organelle_160patch_99ckpt_max.zarr"
+ channels_to_display: ["raw GFP EX488 EM525-45"]
+ z_range: [0, 1]
+ yx_patch_size: [192, 192]
+ fov_tracks:
+ # "/B/3/000000": "all"
+ # "/B/4/000001": "all"
+ "/B/2/000000": [15, 47, 61, 62, 63, 64, 65]
+ "/C/2/000001": "all"
+ # "/C/2/000000": []
+
+ # sec61_organelle:
+ # data_path: "/hpc/projects/intracellular_dashboard/organelle_dynamics/rerun/2025_01_28_A549_G3BP1_ZIKV_DENV/4-phenotyping/train-test/2025_01_28_A549_G3BP1_ZIKV_DENV.zarr"
+ # tracks_path: "/hpc/projects/intracellular_dashboard/organelle_dynamics/rerun/2025_01_28_A549_G3BP1_ZIKV_DENV/1-preprocess/label-free/3-track/2025_01_28_A549_G3BP1_ZIKV_DENV_cropped.zarr"
+ # features_path: "/hpc/projects/intracellular_dashboard/organelle_dynamics/rerun/2024_11_07_A549_SEC61_DENV/4-phenotyping/2-predictions/organelle_160patch_99ckpt_max.zarr"
+ # channels_to_display: ["raw GFP EX488 EM525-45"]
+ # z_range: [0, 1]
+ # yx_patch_size: [192, 192]
+
+# Global settings
+num_PC_components: 8
+# phate_kwargs:
+# n_components: 4 # Number of PHATE components
+# knn: 5 # Number of nearest neighbors for KNN graph
+# decay: 40 # Decay parameter for Markov operator
+# n_jobs: -1 # Number of parallel jobs (-1 uses all cores)
+# random_state: 42 # Random seed for reproducibility
+# # gamma: 1.0 # Informational scale parameter (optional)
+# # t: "auto" # Number of diffusion steps (optional, auto-selects optimal)
+# # a: 1 # Alpha parameter for alpha-decay kernel (optional)
+# # verbose: 0 # Verbosity level (optional)
+
+# File system paths
+output_dir: "./output" # Directory to save CSV files and other outputs
+cache_path: "/home/eduardo.hirata/mydata/tmp/pcviewer/cache/image_cache1.pkl" # Path to save/load image cache
\ No newline at end of file
diff --git a/viscy/representation/evaluation/combined_analysis.py b/viscy/representation/evaluation/combined_analysis.py
new file mode 100644
index 000000000..cc8966174
--- /dev/null
+++ b/viscy/representation/evaluation/combined_analysis.py
@@ -0,0 +1,190 @@
+import logging
+from pathlib import Path
+from typing import Optional
+
+import numpy as np
+import pandas as pd
+from xarray import Dataset
+
+from viscy.representation.embedding_writer import (
+ read_embedding_dataset,
+ write_embedding_dataset,
+)
+from viscy.representation.evaluation.data_loading import (
+ EmbeddingDataLoader,
+ TripletEmbeddingLoader,
+)
+from viscy.representation.evaluation.dimensionality_reduction import compute_phate
+
+__all__ = ["load_and_combine_features", "compute_phate_for_combined_datasets"]
+
+_logger = logging.getLogger("lightning.pytorch")
+
+
+def load_and_combine_features(
+ feature_paths: list[Path],
+ dataset_names: Optional[list[str]] = None,
+    loader: EmbeddingDataLoader = TripletEmbeddingLoader(),
+) -> tuple[np.ndarray, pd.DataFrame]:
+ """
+ Load features from multiple datasets and combine them using a pluggable loader.
+
+ Parameters
+ ----------
+ feature_paths : list[Path]
+ Paths to embedding datasets
+ dataset_names : list[str], optional
+ Names for datasets. If None, uses file stems
+    loader : EmbeddingDataLoader, optional
+        Loader implementing the protocol. Defaults to TripletEmbeddingLoader()
+
+ Returns
+ -------
+ tuple[np.ndarray, pd.DataFrame]
+ Combined features array and index DataFrame with dataset_pair column
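+
+    Examples
+    --------
+    Minimal sketch; the zarr paths below are placeholders, not shipped data:
+
+    >>> from pathlib import Path
+    >>> features, index = load_and_combine_features(
+    ...     [Path("sensor.zarr"), Path("organelle.zarr")],
+    ...     dataset_names=["sensor", "organelle"],
+    ... )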
+ """
+ if dataset_names is None:
+ dataset_names = [path.stem for path in feature_paths]
+
+ if len(dataset_names) != len(feature_paths):
+ raise ValueError("Number of dataset names must match number of feature paths")
+
+ all_features = []
+ all_indices = []
+
+ for path, dataset_name in zip(feature_paths, dataset_names):
+ _logger.info(f"Loading features from {path}")
+
+ dataset = loader.load_dataset(path)
+ features = loader.extract_features(dataset)
+ index_df = loader.extract_metadata(dataset)
+
+ index_df["dataset_pair"] = dataset_name
+ index_df["dataset_path"] = str(path)
+
+ all_features.append(features)
+ all_indices.append(index_df)
+
+ _logger.info(f"Loaded {len(features)} samples from {dataset_name}")
+
+ combined_features = np.vstack(all_features)
+ combined_indices = pd.concat(all_indices, ignore_index=True)
+
+ _logger.info(
+ f"Combined {len(combined_features)} total samples from {len(feature_paths)} datasets"
+ )
+
+ return combined_features, combined_indices
+
+
+def compute_phate_for_combined_datasets(
+ feature_paths: list[Path],
+ output_path: Path,
+ dataset_names: Optional[list[str]] = None,
+ phate_kwargs: Optional[dict] = None,
+ overwrite: bool = False,
+    loader: EmbeddingDataLoader = TripletEmbeddingLoader(),
+) -> Dataset:
+ """
+ Compute PHATE embeddings on combined features from multiple datasets.
+
+ Parameters
+ ----------
+ feature_paths : list[Path]
+ List of paths to zarr stores containing embedding datasets
+ output_path : Path
+ Path to save the combined dataset with PHATE embeddings
+ dataset_names : list[str], optional
+ Names for each dataset. If None, uses file stems
+ phate_kwargs : dict, optional
+ Parameters for PHATE computation. Default: {"knn": 5, "decay": 40, "n_components": 2}
+ overwrite : bool, optional
+ Whether to overwrite existing output file
+    loader : EmbeddingDataLoader, optional
+        Loader implementing the protocol. Defaults to TripletEmbeddingLoader()
+
+ Returns
+ -------
+ Dataset
+ Combined xarray dataset with original features and PHATE coordinates
+
+ Raises
+ ------
+ FileExistsError
+ If output_path exists and overwrite is False
+ ImportError
+ If PHATE is not installed
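+
+    Examples
+    --------
+    Minimal sketch; all paths below are placeholders:
+
+    >>> from pathlib import Path
+    >>> ds = compute_phate_for_combined_datasets(
+    ...     feature_paths=[Path("sensor.zarr"), Path("organelle.zarr")],
+    ...     output_path=Path("combined_phate.zarr"),
+    ...     phate_kwargs={"knn": 5, "decay": 40, "n_components": 2},
+    ...     overwrite=True,
+    ... )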
+ """
+ output_path = Path(output_path)
+
+ if output_path.exists() and not overwrite:
+ raise FileExistsError(
+ f"Output path {output_path} already exists. Use overwrite=True to overwrite."
+ )
+
+ if phate_kwargs is None:
+ phate_kwargs = {"knn": 5, "decay": 40, "n_components": 2}
+
+ _logger.info(
+ f"Computing PHATE for combined datasets with parameters: {phate_kwargs}"
+ )
+
+ combined_features, combined_indices = load_and_combine_features(
+ feature_paths, dataset_names, loader
+ )
+
+ n_samples = len(combined_features)
+ if phate_kwargs.get("knn", 5) >= n_samples:
+        original_knn = phate_kwargs.get("knn", 5)
+ phate_kwargs["knn"] = max(2, n_samples // 2)
+ _logger.warning(
+ f"Reducing knn from {original_knn} to {phate_kwargs['knn']} due to dataset size ({n_samples} samples)"
+ )
+
+ _logger.info("Computing PHATE embeddings on combined features")
+ try:
+ _, phate_embedding = compute_phate(combined_features, **phate_kwargs)
+ _logger.info(
+ f"PHATE computation successful, embedding shape: {phate_embedding.shape}"
+ )
+ except Exception as e:
+ _logger.error(f"PHATE computation failed: {str(e)}")
+ raise
+
+ _logger.info(f"Writing combined dataset with PHATE embeddings to {output_path}")
+ write_embedding_dataset(
+ output_path=output_path,
+ features=combined_features,
+ index_df=combined_indices,
+ phate_kwargs=phate_kwargs,
+ overwrite=overwrite,
+ )
+
+ result_dataset = read_embedding_dataset(output_path)
+ _logger.info(
+ f"Successfully created combined dataset with {len(result_dataset.sample)} samples"
+ )
+
+ return result_dataset
diff --git a/viscy/representation/evaluation/data_loading.py b/viscy/representation/evaluation/data_loading.py
new file mode 100644
index 000000000..f1625c2c2
--- /dev/null
+++ b/viscy/representation/evaluation/data_loading.py
@@ -0,0 +1,168 @@
+import logging
+from pathlib import Path
+from typing import Protocol, runtime_checkable
+
+import numpy as np
+import pandas as pd
+from xarray import Dataset
+
+__all__ = ["EmbeddingDataLoader", "TripletEmbeddingLoader"]
+
+_logger = logging.getLogger("lightning.pytorch")
+
+
+@runtime_checkable
+class EmbeddingDataLoader(Protocol):
+ """Protocol for embedding dataloaders that can be used with combined analysis."""
+
+ def load_dataset(self, path: Path) -> Dataset:
+ """
+ Load dataset from path and return xarray Dataset with 'features' data variable.
+
+ Parameters
+ ----------
+ path : Path
+ Path to the dataset file
+
+ Returns
+ -------
+ Dataset
+ Xarray dataset with 'features' data variable containing embeddings
+ """
+ ...
+
+ def extract_features(self, dataset: Dataset) -> np.ndarray:
+ """
+ Extract feature embeddings from dataset.
+
+ Parameters
+ ----------
+ dataset : Dataset
+ Xarray dataset containing features
+
+ Returns
+ -------
+ np.ndarray
+ Array of shape (n_samples, n_features) containing embeddings
+ """
+ ...
+
+ def extract_metadata(self, dataset: Dataset) -> pd.DataFrame:
+ """
+ Extract metadata/index information from dataset.
+
+ Parameters
+ ----------
+ dataset : Dataset
+ Xarray dataset containing metadata
+
+ Returns
+ -------
+ pd.DataFrame
+ DataFrame containing metadata for each sample, including any existing
+ dimensionality reduction coordinates (PHATE, UMAP, etc.)
+ """
+ ...
+
+
+class TripletEmbeddingLoader:
+ """Default loader for triplet-based embedding datasets."""
+
+ def load_dataset(self, path: Path) -> Dataset:
+ """Load embedding dataset using the standard embedding writer format."""
+ from viscy.representation.embedding_writer import read_embedding_dataset
+
+ _logger.debug(f"Loading dataset from {path} using TripletEmbeddingLoader")
+ return read_embedding_dataset(path)
+
+ def extract_features(self, dataset: Dataset) -> np.ndarray:
+ """Extract features from the 'features' data variable."""
+ return dataset["features"].values
+
+ def extract_metadata(self, dataset: Dataset) -> pd.DataFrame:
+ """
+ Extract metadata from dataset coordinates and data variables.
+
+ This includes sample coordinates and any existing dimensionality reduction
+ coordinates like PHATE, UMAP, PCA that were previously computed.
+ """
+ features_data_array = dataset["features"]
+
+ try:
+ coord_df = features_data_array["sample"].to_dataframe()
+
+ if coord_df.index.names != [None]:
+ index_df = coord_df.reset_index()
+ if "features" in index_df.columns:
+ index_df = index_df.drop(columns=["features"])
+ else:
+ index_df = coord_df.reset_index(drop=True)
+
+ dim_reduction_cols = [
+ col
+ for col in index_df.columns
+ if any(col.startswith(prefix) for prefix in ["PHATE", "UMAP", "PCA"])
+ ]
+ if dim_reduction_cols:
+ _logger.debug(
+ f"Found dimensionality reduction coordinates: {dim_reduction_cols}"
+ )
+
+ _logger.debug(
+ f"Extracted metadata with {len(index_df.columns)} columns: {list(index_df.columns)}"
+ )
+ return index_df
+
+ except Exception as e:
+ _logger.error(f"Error extracting metadata: {e}")
+ index_df = (
+ features_data_array["sample"].to_dataframe().reset_index(drop=True)
+ )
+ _logger.warning(
+ f"Using fallback metadata extraction with {len(index_df.columns)} columns"
+ )
+ return index_df
+
+
+# Example of how to implement a custom loader
+# TODO: replace with concrete loaders for the other data formats
+class CustomEmbeddingLoader:
+ """
+ Example implementation of a custom embedding loader.
+
+ This serves as a template for implementing loaders for different data formats.
+ Replace the method implementations with your specific loading logic.
+ """
+
+ def load_dataset(self, path: Path) -> Dataset:
+ """
+ Load your custom dataset format.
+
+ This should return an xarray Dataset with at least a 'features' data variable
+ containing the embeddings with a 'sample' dimension.
+ """
+ raise NotImplementedError("Implement your custom dataset loading logic here")
+
+ def extract_features(self, dataset: Dataset) -> np.ndarray:
+ """
+ Extract features from your dataset format.
+
+ Should return a numpy array of shape (n_samples, n_features).
+ """
+ return dataset["features"].values # Modify if your format is different
+
+ def extract_metadata(self, dataset: Dataset) -> pd.DataFrame:
+ """
+ Extract metadata from your dataset format.
+
+ Should return a DataFrame with one row per sample containing metadata
+ like sample IDs, FOV names, track IDs, etc.
+ """
+ raise NotImplementedError(
+ "Implement your custom metadata extraction logic here"
+ )
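+
+
+# Conformance sketch: EmbeddingDataLoader is runtime-checkable, so isinstance()
+# verifies that a loader provides all three protocol methods, e.g.:
+# assert isinstance(TripletEmbeddingLoader(), EmbeddingDataLoader)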
diff --git a/viscy/representation/evaluation/dimensionality_reduction.py b/viscy/representation/evaluation/dimensionality_reduction.py
index eb5d43f91..80634c7a6 100644
--- a/viscy/representation/evaluation/dimensionality_reduction.py
+++ b/viscy/representation/evaluation/dimensionality_reduction.py
@@ -61,7 +61,13 @@ def compute_phate(
# Compute PHATE embeddings
phate_model = phate.PHATE(
- n_components=n_components, knn=knn, decay=decay, **phate_kwargs
+ n_components=n_components,
+ knn=knn,
+ decay=decay,
+        # Seed for reproducibility; pop() lets callers override random_state
+        # via phate_kwargs without raising a duplicate-keyword TypeError.
+        random_state=phate_kwargs.pop("random_state", 42),
+ **phate_kwargs,
)
phate_embedding = phate_model.fit_transform(embeddings)
diff --git a/viscy/representation/evaluation/visualization.py b/viscy/representation/evaluation/visualization.py
deleted file mode 100644
index 9d787fe05..000000000
--- a/viscy/representation/evaluation/visualization.py
+++ /dev/null
@@ -1,2244 +0,0 @@
-import atexit
-import base64
-import json
-import logging
-from io import BytesIO
-from pathlib import Path
-
-import dash
-import dash.dependencies as dd
-import matplotlib.pyplot as plt
-import numpy as np
-import pandas as pd
-import plotly.graph_objects as go
-from dash import dcc, html
-from PIL import Image
-from sklearn.decomposition import PCA
-from sklearn.preprocessing import StandardScaler
-
-from viscy.data.triplet import TripletDataModule
-from viscy.representation.embedding_writer import read_embedding_dataset
-
-logging.basicConfig(level=logging.INFO)
-logger = logging.getLogger(__name__)
-
-
-class EmbeddingVisualizationApp:
- def __init__(
- self,
- data_path: str,
- tracks_path: str,
- features_path: str,
- channels_to_display: list[str] | str,
- fov_tracks: dict[str, list[int] | str],
- z_range: tuple[int, int] = (0, 1),
- yx_patch_size: tuple[int, int] = (128, 128),
- num_PC_components: int = 3,
- cache_path: str | None = None,
- num_loading_workers: int = 16,
- output_dir: str | None = None,
- ) -> None:
- """
- Initialize a Dash application for visualizing the DynaCLR embeddings.
-
- This class provides a visualization tool for visualizing the DynaCLR embeddings into a 2D space (e.g. PCA, UMAP, PHATE).
- It allows users to interactively explore and analyze trajectories, visualize clusters, and explore the embedding space.
-
- Parameters
- ----------
- data_path: str
- Path to the data directory.
- tracks_path: str
- Path to the tracks directory.
- features_path: str
- Path to the features directory.
- channels_to_display: list[str] | str
- List of channels to display.
- fov_tracks: dict[str, list[int] | str]
- Dictionary of FOV names and track IDs.
- z_range: tuple[int, int] | list[int,int]
- Range of z-slices to display.
- yx_patch_size: tuple[int, int] | list[int,int]
- Size of the yx-patch to display.
- num_PC_components: int
- Number of PCA components to use.
- cache_path: str | None
- Path to the cache directory.
- num_loading_workers: int
- Number of workers to use for loading data.
- output_dir: str | None, optional
- Directory to save CSV files and other outputs. If None, uses current working directory.
- Returns
- -------
- None
- Initializes the visualization app.
- """
- self.data_path = Path(data_path)
- self.tracks_path = Path(tracks_path)
- self.features_path = Path(features_path)
- self.fov_tracks = fov_tracks
- self.image_cache = {}
- self.cache_path = Path(cache_path) if cache_path else None
- self.output_dir = Path(output_dir) if output_dir else Path.cwd()
- self.app = None
- self.features_df = None
- self.fig = None
- self.channels_to_display = channels_to_display
- self.z_range = z_range
- self.yx_patch_size = yx_patch_size
- self.filtered_tracks_by_fov = {}
- self._z_idx = (self.z_range[1] - self.z_range[0]) // 2
- self.num_PC_components = num_PC_components
- self.num_loading_workers = num_loading_workers
- # Initialize cluster storage before preparing data and creating figure
- self.clusters = [] # List to store all clusters
- self.cluster_points = set() # Set to track all points in clusters
- self.cluster_names = {} # Dictionary to store cluster names
- self.next_cluster_id = 1 # Counter for cluster IDs
- # Initialize data
- self._prepare_data()
- self._create_figure()
- self._init_app()
- atexit.register(self._cleanup_cache)
-
- def _prepare_data(self):
- """Prepare the feature data and PCA transformation"""
- embedding_dataset = read_embedding_dataset(self.features_path)
- features = embedding_dataset["features"]
- self.features_df = features["sample"].to_dataframe().reset_index(drop=True)
-
- # Check if UMAP or PHATE columns already exist
- existing_dims = []
- dim_options = []
-
- # Check for PCA and compute if needed
- if not any(col.startswith("PCA") for col in self.features_df.columns):
- # PCA transformation
- scaled_features = StandardScaler().fit_transform(features.values)
- pca = PCA(n_components=self.num_PC_components)
- pca_coords = pca.fit_transform(scaled_features)
-
- # Add PCA coordinates to the features dataframe
- for i in range(self.num_PC_components):
- self.features_df[f"PCA{i + 1}"] = pca_coords[:, i]
-
- # Store explained variance for PCA
- self.pca_explained_variance = [
- f"PC{i + 1} ({var:.1f}%)"
- for i, var in enumerate(pca.explained_variance_ratio_ * 100)
- ]
-
- # Add PCA options
- for i, pc_label in enumerate(self.pca_explained_variance):
- dim_options.append({"label": pc_label, "value": f"PCA{i + 1}"})
- existing_dims.append(f"PCA{i + 1}")
-
- # Check for UMAP coordinates
- umap_dims = [col for col in self.features_df.columns if col.startswith("UMAP")]
- if umap_dims:
- for dim in umap_dims:
- dim_options.append({"label": dim, "value": dim})
- existing_dims.append(dim)
-
- # Check for PHATE coordinates
- phate_dims = [
- col for col in self.features_df.columns if col.startswith("PHATE")
- ]
- if phate_dims:
- for dim in phate_dims:
- dim_options.append({"label": dim, "value": dim})
- existing_dims.append(dim)
-
- # Store dimension options for dropdowns
- self.dim_options = dim_options
-
- # Set default x and y axes based on available dimensions
- self.default_x = existing_dims[0] if existing_dims else "PCA1"
- self.default_y = existing_dims[1] if len(existing_dims) > 1 else "PCA2"
-
- # Process each FOV and its track IDs
- all_filtered_features = []
- for fov_name, track_ids in self.fov_tracks.items():
- if track_ids == "all":
- fov_tracks = (
- self.features_df[self.features_df["fov_name"] == fov_name][
- "track_id"
- ]
- .unique()
- .tolist()
- )
- else:
- fov_tracks = track_ids
-
- self.filtered_tracks_by_fov[fov_name] = fov_tracks
-
- # Filter features for this FOV and its track IDs
- fov_features = self.features_df[
- (self.features_df["fov_name"] == fov_name)
- & (self.features_df["track_id"].isin(fov_tracks))
- ]
- all_filtered_features.append(fov_features)
-
- # Combine all filtered features
- self.filtered_features_df = pd.concat(all_filtered_features, axis=0)
-
- def _create_figure(self):
- """Create the initial scatter plot figure"""
- self.fig = self._create_track_colored_figure()
-
- def _init_app(self):
- """Initialize the Dash application"""
- self.app = dash.Dash(__name__)
-
- # Add cluster assignment button next to clear selection
- cluster_controls = html.Div(
- [
- html.Button(
- "Assign to New Cluster",
- id="assign-cluster",
- style={
- "backgroundColor": "#28a745",
- "color": "white",
- "border": "none",
- "padding": "5px 10px",
- "borderRadius": "4px",
- "cursor": "pointer",
- "marginRight": "10px",
- },
- ),
- html.Button(
- "Clear All Clusters",
- id="clear-clusters",
- style={
- "backgroundColor": "#dc3545",
- "color": "white",
- "border": "none",
- "padding": "5px 10px",
- "borderRadius": "4px",
- "cursor": "pointer",
- "marginRight": "10px",
- },
- ),
- html.Button(
- "Save Clusters to CSV",
- id="save-clusters-csv",
- style={
- "backgroundColor": "#17a2b8",
- "color": "white",
- "border": "none",
- "padding": "5px 10px",
- "borderRadius": "4px",
- "cursor": "pointer",
- "marginRight": "10px",
- },
- ),
- html.Button(
- "Clear Selection",
- id="clear-selection",
- style={
- "backgroundColor": "#6c757d",
- "color": "white",
- "border": "none",
- "padding": "5px 10px",
- "borderRadius": "4px",
- "cursor": "pointer",
- },
- ),
- ],
- style={"marginLeft": "10px", "display": "inline-block"},
- )
- # Create tabs for different views
- tabs = dcc.Tabs(
- id="view-tabs",
- value="timeline-tab",
- children=[
- dcc.Tab(
- label="Track Timeline",
- value="timeline-tab",
- children=[
- html.Div(
- id="track-timeline",
- style={
- "height": "auto",
- "overflowY": "auto",
- "maxHeight": "80vh",
- "padding": "10px",
- "marginTop": "10px",
- },
- ),
- ],
- ),
- dcc.Tab(
- label="Clusters",
- value="clusters-tab",
- id="clusters-tab",
- children=[
- html.Div(
- id="cluster-container",
- style={
- "padding": "10px",
- "marginTop": "10px",
- },
- ),
- ],
- style={"display": "none"}, # Initially hidden
- ),
- ],
- style={"marginTop": "20px"},
- )
-
- # Add modal for cluster naming
- cluster_name_modal = html.Div(
- id="cluster-name-modal",
- children=[
- html.Div(
- [
- html.H3("Name Your Cluster", style={"marginBottom": "20px"}),
- html.Label("Cluster Name:"),
- dcc.Input(
- id="cluster-name-input",
- type="text",
- placeholder="Enter cluster name...",
- style={"width": "100%", "marginBottom": "20px"},
- ),
- html.Div(
- [
- html.Button(
- "Save",
- id="save-cluster-name",
- style={
- "backgroundColor": "#28a745",
- "color": "white",
- "border": "none",
- "padding": "8px 16px",
- "borderRadius": "4px",
- "cursor": "pointer",
- "marginRight": "10px",
- },
- ),
- html.Button(
- "Cancel",
- id="cancel-cluster-name",
- style={
- "backgroundColor": "#6c757d",
- "color": "white",
- "border": "none",
- "padding": "8px 16px",
- "borderRadius": "4px",
- "cursor": "pointer",
- },
- ),
- ],
- style={"textAlign": "right"},
- ),
- ],
- style={
- "backgroundColor": "white",
- "padding": "30px",
- "borderRadius": "8px",
- "maxWidth": "400px",
- "margin": "auto",
- "boxShadow": "0 4px 6px rgba(0, 0, 0, 0.1)",
- "border": "1px solid #ddd",
- },
- )
- ],
- style={
- "display": "none",
- "position": "fixed",
- "top": "0",
- "left": "0",
- "width": "100%",
- "height": "100%",
- "backgroundColor": "rgba(0, 0, 0, 0.5)",
- "zIndex": "1000",
- "justifyContent": "center",
- "alignItems": "center",
- },
- )
-
- # Update layout to use tabs
- self.app.layout = html.Div(
- style={
- "maxWidth": "95vw",
- "margin": "auto",
- "padding": "20px",
- },
- children=[
- html.H1(
- "Track Visualization",
- style={"textAlign": "center", "marginBottom": "20px"},
- ),
- html.Div(
- [
- html.Div(
- style={
- "width": "100%",
- "display": "inline-block",
- "verticalAlign": "top",
- },
- children=[
- html.Div(
- style={
- "marginBottom": "20px",
- "display": "flex",
- "alignItems": "center",
- "gap": "20px",
- "flexWrap": "wrap",
- },
- children=[
- html.Div(
- [
- html.Label(
- "Color by:",
- style={"marginRight": "10px"},
- ),
- dcc.Dropdown(
- id="color-mode",
- options=[
- {
- "label": "Track ID",
- "value": "track",
- },
- {
- "label": "Time",
- "value": "time",
- },
- ],
- value="track",
- style={"width": "200px"},
- ),
- ]
- ),
- html.Div(
- [
- dcc.Checklist(
- id="show-arrows",
- options=[
- {
- "label": "Show arrows",
- "value": "show",
- }
- ],
- value=[],
- style={"marginLeft": "20px"},
- ),
- ]
- ),
- html.Div(
- [
- html.Label(
- "X-axis:",
- style={"marginRight": "10px"},
- ),
- dcc.Dropdown(
- id="x-axis",
- options=self.dim_options,
- value=self.default_x,
- style={"width": "200px"},
- ),
- ]
- ),
- html.Div(
- [
- html.Label(
- "Y-axis:",
- style={"marginRight": "10px"},
- ),
- dcc.Dropdown(
- id="y-axis",
- options=self.dim_options,
- value=self.default_y,
- style={"width": "200px"},
- ),
- ]
- ),
- cluster_controls,
- ],
- ),
- ],
- ),
- ]
- ),
- dcc.Loading(
- id="loading",
- children=[
- dcc.Graph(
- id="scatter-plot",
- figure=self.fig,
- config={
- "displayModeBar": True,
- "editable": False,
- "showEditInChartStudio": False,
- "modeBarButtonsToRemove": [
- "select2d",
- "resetScale2d",
- ],
- "edits": {
- "annotationPosition": False,
- "annotationTail": False,
- "annotationText": False,
- "shapePosition": True,
- },
- "scrollZoom": True,
- },
- style={"height": "50vh"},
- ),
- ],
- type="default",
- ),
- tabs,
- cluster_name_modal,
- ],
- )
-
- @self.app.callback(
- [
- dd.Output("scatter-plot", "figure", allow_duplicate=True),
- dd.Output("scatter-plot", "selectedData", allow_duplicate=True),
- ],
- [
- dd.Input("color-mode", "value"),
- dd.Input("show-arrows", "value"),
- dd.Input("x-axis", "value"),
- dd.Input("y-axis", "value"),
- dd.Input("scatter-plot", "relayoutData"),
- dd.Input("scatter-plot", "selectedData"),
- ],
- [dd.State("scatter-plot", "figure")],
- prevent_initial_call=True,
- )
- def update_figure(
- color_mode,
- show_arrows,
- x_axis,
- y_axis,
- relayout_data,
- selected_data,
- current_figure,
- ):
- show_arrows = len(show_arrows or []) > 0
-
- ctx = dash.callback_context
- if not ctx.triggered:
- triggered_id = "No clicks yet"
- else:
- triggered_id = ctx.triggered[0]["prop_id"].split(".")[0]
-
- # Create new figure when necessary
- if triggered_id in [
- "color-mode",
- "show-arrows",
- "x-axis",
- "y-axis",
- ]:
- if color_mode == "track":
- fig = self._create_track_colored_figure(show_arrows, x_axis, y_axis)
- else:
- fig = self._create_time_colored_figure(show_arrows, x_axis, y_axis)
-
- # Update dragmode and selection settings
- fig.update_layout(
- dragmode="lasso",
- clickmode="event+select",
- uirevision="true",
- selectdirection="any",
- )
- else:
- fig = dash.no_update
-
- return fig, selected_data
-
- @self.app.callback(
- dd.Output("track-timeline", "children"),
- [dd.Input("scatter-plot", "clickData")],
- prevent_initial_call=True,
- )
- def update_track_timeline(clickData):
- """Update the track timeline based on the clicked point"""
- if clickData is None:
- return html.Div("Click on a point to see the track timeline")
-
- # Parse the hover text to get track_id, time and fov_name
- hover_text = clickData["points"][0]["text"]
- track_id = int(hover_text.split("
")[0].split(": ")[1])
- clicked_time = int(hover_text.split("
")[1].split(": ")[1])
- fov_name = hover_text.split("
")[2].split(": ")[1]
-
- # Get all timepoints for this track
- track_data = self.features_df[
- (self.features_df["fov_name"] == fov_name)
- & (self.features_df["track_id"] == track_id)
- ].sort_values("t")
-
- if track_data.empty:
- return html.Div(f"No data found for track {track_id}")
-
- # Get unique timepoints
- timepoints = track_data["t"].unique()
-
- # Create a list to store all timepoint columns
- timepoint_columns = []
-
- # First create the time labels row
- time_labels = []
- for t in timepoints:
- is_clicked = t == clicked_time
- time_style = {
- "width": "150px",
- "textAlign": "center",
- "padding": "5px",
- "fontWeight": "bold" if is_clicked else "normal",
- "color": "#007bff" if is_clicked else "black",
- }
- time_labels.append(html.Div(f"t={t}", style=time_style))
-
- timepoint_columns.append(
- html.Div(
- time_labels,
- style={
- "display": "flex",
- "flexDirection": "row",
- "minWidth": "fit-content",
- "borderBottom": "2px solid #ddd",
- "marginBottom": "10px",
- "paddingBottom": "5px",
- },
- )
- )
-
- # Then create image rows for each channel
- for channel in self.channels_to_display:
- channel_images = []
- for t in timepoints:
- cache_key = (fov_name, track_id, t)
- if (
- cache_key in self.image_cache
- and channel in self.image_cache[cache_key]
- ):
- is_clicked = t == clicked_time
- image_style = {
- "width": "150px",
- "height": "150px",
- "border": (
- "3px solid #007bff" if is_clicked else "1px solid #ddd"
- ),
- "borderRadius": "4px",
- }
- channel_images.append(
- html.Div(
- html.Img(
- src=self.image_cache[cache_key][channel],
- style=image_style,
- ),
- style={
- "width": "150px",
- "padding": "5px",
- },
- )
- )
-
- if channel_images:
- # Add channel label
- timepoint_columns.append(
- html.Div(
- [
- html.Div(
- channel,
- style={
- "width": "100px",
- "fontWeight": "bold",
- "fontSize": "14px",
- "padding": "5px",
- "backgroundColor": "#f8f9fa",
- "borderRadius": "4px",
- "marginBottom": "5px",
- "textAlign": "center",
- },
- ),
- html.Div(
- channel_images,
- style={
- "display": "flex",
- "flexDirection": "row",
- "minWidth": "fit-content",
- "marginBottom": "15px",
- },
- ),
- ]
- )
- )
-
- # Create the main container with synchronized scrolling
- return html.Div(
- [
- html.H4(
- f"Track {track_id} (FOV: {fov_name})",
- style={
- "marginBottom": "20px",
- "fontSize": "20px",
- "fontWeight": "bold",
- "color": "#2c3e50",
- },
- ),
- html.Div(
- timepoint_columns,
- style={
- "overflowX": "auto",
- "overflowY": "hidden",
- "whiteSpace": "nowrap",
- "backgroundColor": "white",
- "padding": "20px",
- "borderRadius": "8px",
- "boxShadow": "0 2px 4px rgba(0,0,0,0.1)",
- "marginBottom": "20px",
- },
- ),
- ]
- )
-
- # Add callback to show/hide clusters tab and handle modal
- @self.app.callback(
- [
- dd.Output("clusters-tab", "style"),
- dd.Output("cluster-container", "children"),
- dd.Output("view-tabs", "value"),
- dd.Output("scatter-plot", "figure", allow_duplicate=True),
- dd.Output("cluster-name-modal", "style"),
- dd.Output("cluster-name-input", "value"),
- dd.Output("scatter-plot", "selectedData", allow_duplicate=True),
- ],
- [
- dd.Input("assign-cluster", "n_clicks"),
- dd.Input("clear-clusters", "n_clicks"),
- dd.Input("save-cluster-name", "n_clicks"),
- dd.Input("cancel-cluster-name", "n_clicks"),
- dd.Input({"type": "edit-cluster-name", "index": dash.ALL}, "n_clicks"),
- ],
- [
- dd.State("scatter-plot", "selectedData"),
- dd.State("scatter-plot", "figure"),
- dd.State("color-mode", "value"),
- dd.State("show-arrows", "value"),
- dd.State("x-axis", "value"),
- dd.State("y-axis", "value"),
- dd.State("cluster-name-input", "value"),
- ],
- prevent_initial_call=True,
- )
- def update_clusters_tab(
- assign_clicks,
- clear_clicks,
- save_name_clicks,
- cancel_name_clicks,
- edit_name_clicks,
- selected_data,
- current_figure,
- color_mode,
- show_arrows,
- x_axis,
- y_axis,
- cluster_name,
- ):
- ctx = dash.callback_context
- if not ctx.triggered:
- return (
- dash.no_update,
- dash.no_update,
- dash.no_update,
- dash.no_update,
- dash.no_update,
- dash.no_update,
- dash.no_update,
- )
-
- button_id = ctx.triggered[0]["prop_id"].split(".")[0]
-
- # Handle edit cluster name button clicks
- if button_id.startswith('{"type":"edit-cluster-name"'):
- try:
- id_dict = json.loads(button_id)
- cluster_idx = id_dict["index"]
-
- # Get current cluster name
- current_name = self.cluster_names.get(
- cluster_idx, f"Cluster {cluster_idx + 1}"
- )
-
- # Show modal
- modal_style = {
- "display": "flex",
- "position": "fixed",
- "top": "0",
- "left": "0",
- "width": "100%",
- "height": "100%",
- "backgroundColor": "rgba(0, 0, 0, 0.5)",
- "zIndex": "1000",
- "justifyContent": "center",
- "alignItems": "center",
- }
-
- return (
- {"display": "block"},
- self._get_cluster_images(),
- "clusters-tab",
- dash.no_update,
- modal_style,
- current_name,
- dash.no_update, # Don't change selection
- )
- except Exception:
- return (
- dash.no_update,
- dash.no_update,
- dash.no_update,
- dash.no_update,
- dash.no_update,
- dash.no_update,
- dash.no_update,
- )
-
- if (
- button_id == "assign-cluster"
- and selected_data
- and selected_data.get("points")
- ):
- # Create new cluster from selected points
- new_cluster = []
- for point in selected_data["points"]:
- text = point["text"]
- lines = text.split("
")
- track_id = int(lines[0].split(": ")[1])
- t = int(lines[1].split(": ")[1])
- fov = lines[2].split(": ")[1]
-
- cache_key = (fov, track_id, t)
- if cache_key in self.image_cache:
- new_cluster.append(
- {
- "track_id": track_id,
- "t": t,
- "fov_name": fov,
- }
- )
- self.cluster_points.add(cache_key)
-
- if new_cluster:
- # Add cluster to list but don't assign name yet
- self.clusters.append(new_cluster)
- # Open modal for naming
- modal_style = {
- "display": "flex",
- "position": "fixed",
- "top": "0",
- "left": "0",
- "width": "100%",
- "height": "100%",
- "backgroundColor": "rgba(0, 0, 0, 0.5)",
- "zIndex": "1000",
- "justifyContent": "center",
- "alignItems": "center",
- }
- return (
- {"display": "block"},
- self._get_cluster_images(),
- "clusters-tab",
- dash.no_update, # Don't update figure yet
- modal_style, # Show modal
- "", # Clear input
- None, # Clear selection
- )
-
- elif button_id == "save-cluster-name" and cluster_name:
- # Assign name to the most recently created cluster
- if self.clusters:
- cluster_id = len(self.clusters) - 1
- self.cluster_names[cluster_id] = cluster_name.strip()
-
- # Create new figure with updated colors
- fig = self._create_track_colored_figure(
- len(show_arrows or []) > 0,
- x_axis,
- y_axis,
- )
- # Ensure the dragmode is set based on selection_mode
- fig.update_layout(
- dragmode="lasso",
- clickmode="event+select",
- uirevision="true", # Keep the UI state
- selectdirection="any",
- )
- modal_style = {"display": "none"}
- return (
- {"display": "block"},
- self._get_cluster_images(),
- "clusters-tab",
- fig,
- modal_style, # Hide modal
- "", # Clear input
- None, # Clear selection
- )
-
- elif button_id == "cancel-cluster-name":
- # Remove the cluster that was just created
- if self.clusters:
- # Remove points from cluster_points set
- for point in self.clusters[-1]:
- cache_key = (point["fov_name"], point["track_id"], point["t"])
- self.cluster_points.discard(cache_key)
- # Remove the cluster
- self.clusters.pop()
-
- # Create new figure with updated colors
- fig = self._create_track_colored_figure(
- len(show_arrows or []) > 0,
- x_axis,
- y_axis,
- )
- # Ensure the dragmode is set based on selection_mode
- fig.update_layout(
- dragmode="lasso",
- clickmode="event+select",
- uirevision="true", # Keep the UI state
- selectdirection="any",
- )
- modal_style = {"display": "none"}
- return (
- (
- {"display": "none"}
- if not self.clusters
- else {"display": "block"}
- ),
- self._get_cluster_images() if self.clusters else None,
- "timeline-tab" if not self.clusters else "clusters-tab",
- fig,
- modal_style, # Hide modal
- "", # Clear input
- None, # Clear selection
- )
-
- elif button_id == "clear-clusters":
- self.clusters = []
- self.cluster_points.clear()
- self.cluster_names.clear()
- # Restore original coloring
- fig = self._create_track_colored_figure(
- len(show_arrows or []) > 0,
- x_axis,
- y_axis,
- )
- # Reset UI state completely to ensure clean slate
- fig.update_layout(
- dragmode="lasso",
- clickmode="event+select",
- uirevision=None, # Reset UI state completely
- selectdirection="any",
- )
- modal_style = {"display": "none"}
- return (
- {"display": "none"},
- None,
- "timeline-tab",
- fig,
- modal_style,
- "",
- None,
- ) # Clear selection
-
- return (
- dash.no_update,
- dash.no_update,
- dash.no_update,
- dash.no_update,
- dash.no_update,
- dash.no_update,
- dash.no_update,
- )
-
- # Add callback for saving clusters to CSV
- @self.app.callback(
- dd.Output("cluster-container", "children", allow_duplicate=True),
- [dd.Input("save-clusters-csv", "n_clicks")],
- prevent_initial_call=True,
- )
- def save_clusters_csv(n_clicks):
- """Callback to save clusters to CSV file"""
- if n_clicks and self.clusters:
- try:
- output_path = self.save_clusters_to_csv()
- return html.Div(
- [
- html.H3("Clusters", style={"marginBottom": "20px"}),
- html.Div(
- f"✅ Successfully saved {len(self.clusters)} clusters to: {output_path}",
- style={
- "backgroundColor": "#d4edda",
- "color": "#155724",
- "padding": "10px",
- "borderRadius": "4px",
- "marginBottom": "20px",
- "border": "1px solid #c3e6cb",
- },
- ),
- self._get_cluster_images(),
- ]
- )
- except Exception as e:
- return html.Div(
- [
- html.H3("Clusters", style={"marginBottom": "20px"}),
- html.Div(
- f"❌ Error saving clusters: {str(e)}",
- style={
- "backgroundColor": "#f8d7da",
- "color": "#721c24",
- "padding": "10px",
- "borderRadius": "4px",
- "marginBottom": "20px",
- "border": "1px solid #f5c6cb",
- },
- ),
- self._get_cluster_images(),
- ]
- )
- elif n_clicks and not self.clusters:
- return html.Div(
- [
- html.H3("Clusters", style={"marginBottom": "20px"}),
- html.Div(
- "⚠️ No clusters to save. Create clusters first by selecting points and clicking 'Assign to New Cluster'.",
- style={
- "backgroundColor": "#fff3cd",
- "color": "#856404",
- "padding": "10px",
- "borderRadius": "4px",
- "marginBottom": "20px",
- "border": "1px solid #ffeaa7",
- },
- ),
- ]
- )
- return dash.no_update
-
- @self.app.callback(
- [
- dd.Output("scatter-plot", "figure", allow_duplicate=True),
- dd.Output("scatter-plot", "selectedData", allow_duplicate=True),
- ],
- [dd.Input("clear-selection", "n_clicks")],
- [
- dd.State("color-mode", "value"),
- dd.State("show-arrows", "value"),
- dd.State("x-axis", "value"),
- dd.State("y-axis", "value"),
- ],
- prevent_initial_call=True,
- )
- def clear_selection(n_clicks, color_mode, show_arrows, x_axis, y_axis):
- """Callback to clear the selection and restore original opacity"""
- if n_clicks:
- # Create a new figure with no selections
- if color_mode == "track":
- fig = self._create_track_colored_figure(
- len(show_arrows or []) > 0,
- x_axis,
- y_axis,
- )
- else:
- fig = self._create_time_colored_figure(
- len(show_arrows or []) > 0,
- x_axis,
- y_axis,
- )
-
- # Update layout to maintain lasso mode but clear selections
- fig.update_layout(
- dragmode="lasso",
- clickmode="event+select",
- uirevision=None, # Reset UI state
- selectdirection="any",
- )
-
- return fig, None # Return new figure and clear selectedData
- return dash.no_update, dash.no_update
-
- def _calculate_equal_aspect_ranges(self, x_data, y_data):
- """Calculate ranges for x and y axes to ensure equal aspect ratio.
-
- Parameters
- ----------
- x_data : array-like
- Data for x-axis
- y_data : array-like
- Data for y-axis
-
- Returns
- -------
- tuple
- (x_range, y_range) as tuples of (min, max) with equal scaling
- """
- # Get data ranges
- x_min, x_max = np.min(x_data), np.max(x_data)
- y_min, y_max = np.min(y_data), np.max(y_data)
-
- # Add padding (5% on each side)
- x_padding = 0.05 * (x_max - x_min)
- y_padding = 0.05 * (y_max - y_min)
-
- x_min -= x_padding
- x_max += x_padding
- y_min -= y_padding
- y_max += y_padding
-
- # Ensure equal scaling by using the larger range
- x_range = x_max - x_min
- y_range = y_max - y_min
-
- if x_range > y_range:
- # Expand y-range to match x-range aspect ratio
- y_center = (y_max + y_min) / 2
- y_min = y_center - x_range / 2
- y_max = y_center + x_range / 2
- else:
- # Expand x-range to match y-range aspect ratio
- x_center = (x_max + x_min) / 2
- x_min = x_center - y_range / 2
- x_max = x_center + y_range / 2
-
- return (x_min, x_max), (y_min, y_max)
-
- def _create_track_colored_figure(
- self,
- show_arrows=False,
- x_axis=None,
- y_axis=None,
- ):
- """Create scatter plot with track-based coloring"""
- x_axis = x_axis or self.default_x
- y_axis = y_axis or self.default_y
-
- unique_tracks = self.filtered_features_df["track_id"].unique()
- cmap = plt.cm.tab20
- track_colors = {
- track_id: f"rgb{tuple(int(x * 255) for x in cmap(i % 20)[:3])}"
- for i, track_id in enumerate(unique_tracks)
- }
-
- fig = go.Figure()
-
- # Set initial layout with lasso mode
- fig.update_layout(
- dragmode="lasso",
- clickmode="event+select",
- selectdirection="any",
- plot_bgcolor="white",
- title="PCA visualization of Selected Tracks",
- xaxis_title=x_axis,
- yaxis_title=y_axis,
- uirevision=True,
- hovermode="closest",
- showlegend=True,
- legend=dict(
- yanchor="top",
- y=1,
- xanchor="left",
- x=1.02,
- title="Tracks",
- bordercolor="Black",
- borderwidth=1,
- ),
- margin=dict(l=50, r=150, t=50, b=50),
- autosize=True,
- )
- fig.update_xaxes(showgrid=False)
- fig.update_yaxes(showgrid=False)
-
- # Add background points with hover info (excluding the colored tracks)
- background_df = self.features_df[
- (self.features_df["fov_name"].isin(self.fov_tracks.keys()))
- & (~self.features_df["track_id"].isin(unique_tracks))
- ]
-
- if not background_df.empty:
- # Subsample background points if there are too many
- if len(background_df) > 5000: # Adjust this threshold as needed
- background_df = background_df.sample(n=5000, random_state=42)
-
- fig.add_trace(
- go.Scattergl(
- x=background_df[x_axis],
- y=background_df[y_axis],
- mode="markers",
- marker=dict(size=12, color="lightgray", opacity=0.3),
- name=f"Other tracks (showing {len(background_df)} of {len(self.features_df)} points)",
- text=[
- f"Track: {track_id}
Time: {t}
FOV: {fov}"
- for track_id, t, fov in zip(
- background_df["track_id"],
- background_df["t"],
- background_df["fov_name"],
- )
- ],
- hoverinfo="text",
- showlegend=True,
- hoverlabel=dict(namelength=-1),
- )
- )
-
- # Add points for each selected track
- for track_id in unique_tracks:
- track_data = self.filtered_features_df[
- self.filtered_features_df["track_id"] == track_id
- ].sort_values("t")
-
- # Get points for this track that are in clusters
- track_points = list(
- zip(
- [fov for fov in track_data["fov_name"]],
- [track_id] * len(track_data),
- [t for t in track_data["t"]],
- )
- )
-
- # Determine colors based on cluster membership
- colors = []
- opacities = []
- if self.clusters:
- cluster_colors = [
- f"rgb{tuple(int(x * 255) for x in plt.cm.Set2(i % 8)[:3])}"
- for i in range(len(self.clusters))
- ]
- point_to_cluster = {}
- for cluster_idx, cluster in enumerate(self.clusters):
- for point in cluster:
- point_key = (point["fov_name"], point["track_id"], point["t"])
- point_to_cluster[point_key] = cluster_idx
-
- for point in track_points:
- if point in point_to_cluster:
- colors.append(cluster_colors[point_to_cluster[point]])
- opacities.append(1.0)
- else:
- colors.append("lightgray")
- opacities.append(0.3)
- else:
- colors = [track_colors[track_id]] * len(track_data)
- opacities = [1.0] * len(track_data)
-
- # Add points using Scattergl for better performance
- scatter_kwargs = {
- "x": track_data[x_axis],
- "y": track_data[y_axis],
- "mode": "markers",
- "marker": dict(
- size=10, # Reduced size
- color=colors,
- line=dict(width=1, color="black"),
- opacity=opacities,
- ),
- "name": f"Track {track_id}",
- "text": [
- f"Track: {track_id}
Time: {t}
FOV: {fov}"
- for t, fov in zip(track_data["t"], track_data["fov_name"])
- ],
- "hoverinfo": "text",
- "hoverlabel": dict(namelength=-1), # Show full text in hover
- }
-
- # Only apply selection properties if there are clusters
- # This prevents opacity conflicts when no clusters exist
- if self.clusters:
- scatter_kwargs.update(
- {
- "unselected": dict(marker=dict(opacity=0.3, size=10)),
- "selected": dict(marker=dict(size=12, opacity=1.0)),
- }
- )
-
- fig.add_trace(go.Scattergl(**scatter_kwargs))
-
- # Add trajectory lines and arrows if requested
- if show_arrows and len(track_data) > 1:
- x_coords = track_data[x_axis].values
- y_coords = track_data[y_axis].values
-
- # Add dashed lines for the trajectory using Scattergl
- fig.add_trace(
- go.Scattergl(
- x=x_coords,
- y=y_coords,
- mode="lines",
- line=dict(
- color=track_colors[track_id],
- width=1,
- dash="dot",
- ),
- showlegend=False,
- hoverinfo="skip",
- )
- )
-
- # Add arrows at regular intervals (reduced frequency)
- arrow_interval = max(
- 1, len(track_data) // 3
- ) # Reduced number of arrows
- for i in range(0, len(track_data) - 1, arrow_interval):
- # Calculate arrow angle
- dx = x_coords[i + 1] - x_coords[i]
- dy = y_coords[i + 1] - y_coords[i]
-
- # Only add arrow if there's significant movement
- if dx * dx + dy * dy > 1e-6: # Minimum distance threshold
- # Add arrow annotation
- fig.add_annotation(
- x=x_coords[i + 1],
- y=y_coords[i + 1],
- ax=x_coords[i],
- ay=y_coords[i],
- xref="x",
- yref="y",
- axref="x",
- ayref="y",
- showarrow=True,
- arrowhead=2,
- arrowsize=1, # Reduced size
- arrowwidth=1, # Reduced width
- arrowcolor=track_colors[track_id],
- opacity=0.8,
- )
-
- # Compute axis ranges to ensure equal aspect ratio
- all_x_data = self.filtered_features_df[x_axis]
- all_y_data = self.filtered_features_df[y_axis]
-
- if not all_x_data.empty and not all_y_data.empty:
- x_range, y_range = self._calculate_equal_aspect_ranges(
- all_x_data, all_y_data
- )
-
- # Set equal aspect ratio and range
- fig.update_layout(
- xaxis=dict(
- range=x_range, scaleanchor="y", scaleratio=1, constrain="domain"
- ),
- yaxis=dict(range=y_range, constrain="domain"),
- )
-
- return fig
-
- def _create_time_colored_figure(
- self,
- show_arrows=False,
- x_axis=None,
- y_axis=None,
- ):
- """Create scatter plot with time-based coloring"""
- x_axis = x_axis or self.default_x
- y_axis = y_axis or self.default_y
-
- fig = go.Figure()
-
- # Set initial layout with lasso mode
- fig.update_layout(
- dragmode="lasso",
- clickmode="event+select",
- selectdirection="any",
- plot_bgcolor="white",
- title="PCA visualization of Selected Tracks",
- xaxis_title=x_axis,
- yaxis_title=y_axis,
- uirevision=True,
- hovermode="closest",
- showlegend=True,
- legend=dict(
- yanchor="top",
- y=1,
- xanchor="left",
- x=1.02,
- title="Tracks",
- bordercolor="Black",
- borderwidth=1,
- ),
- margin=dict(l=50, r=150, t=50, b=50),
- autosize=True,
- )
- fig.update_xaxes(showgrid=False)
- fig.update_yaxes(showgrid=False)
-
- # Add background points with hover info
- all_tracks_df = self.features_df[
- self.features_df["fov_name"].isin(self.fov_tracks.keys())
- ]
-
- # Subsample background points if there are too many
- if len(all_tracks_df) > 5000: # Adjust this threshold as needed
- all_tracks_df = all_tracks_df.sample(n=5000, random_state=42)
-
- fig.add_trace(
- go.Scattergl(
- x=all_tracks_df[x_axis],
- y=all_tracks_df[y_axis],
- mode="markers",
- marker=dict(size=12, color="lightgray", opacity=0.3),
- name=f"Other points (showing {len(all_tracks_df)} of {len(self.features_df)} points)",
- text=[
- f"Track: {track_id}
Time: {t}
FOV: {fov}"
- for track_id, t, fov in zip(
- all_tracks_df["track_id"],
- all_tracks_df["t"],
- all_tracks_df["fov_name"],
- )
- ],
- hoverinfo="text",
- hoverlabel=dict(namelength=-1),
- )
- )
-
- # Add time-colored points using Scattergl
- fig.add_trace(
- go.Scattergl(
- x=self.filtered_features_df[x_axis],
- y=self.filtered_features_df[y_axis],
- mode="markers",
- marker=dict(
- size=10, # Reduced size
- color=self.filtered_features_df["t"],
- colorscale="Viridis",
- colorbar=dict(title="Time"),
- ),
- text=[
- f"Track: {track_id}
Time: {t}
FOV: {fov}"
- for track_id, t, fov in zip(
- self.filtered_features_df["track_id"],
- self.filtered_features_df["t"],
- self.filtered_features_df["fov_name"],
- )
- ],
- hoverinfo="text",
- showlegend=False,
- hoverlabel=dict(namelength=-1), # Show full text in hover
- )
- )
-
- # Add arrows if requested, but more efficiently
- if show_arrows:
- for track_id in self.filtered_features_df["track_id"].unique():
- track_data = self.filtered_features_df[
- self.filtered_features_df["track_id"] == track_id
- ].sort_values("t")
-
- if len(track_data) > 1:
- # Calculate distances between consecutive points
- x_coords = track_data[x_axis].values
- y_coords = track_data[y_axis].values
- distances = np.sqrt(np.diff(x_coords) ** 2 + np.diff(y_coords) ** 2)
-
- # Only show arrows for movements larger than the median distance
- threshold = np.median(distances) * 0.5
-
- # Add arrows as a single trace
- arrow_x = []
- arrow_y = []
-
- for i in range(len(track_data) - 1):
- if distances[i] > threshold:
- arrow_x.extend([x_coords[i], x_coords[i + 1], None])
- arrow_y.extend([y_coords[i], y_coords[i + 1], None])
-
- if arrow_x: # Only add if there are arrows to show
- fig.add_trace(
- go.Scatter(
- x=arrow_x,
- y=arrow_y,
- mode="lines",
- line=dict(
- color="rgba(128, 128, 128, 0.5)",
- width=1,
- dash="dot",
- ),
- showlegend=False,
- hoverinfo="skip",
- )
- )
-
- # Compute axis ranges to ensure equal aspect ratio
- all_x_data = self.filtered_features_df[x_axis]
- all_y_data = self.filtered_features_df[y_axis]
- if not all_x_data.empty and not all_y_data.empty:
- x_range, y_range = self._calculate_equal_aspect_ranges(
- all_x_data, all_y_data
- )
-
- # Set equal aspect ratio and range
- fig.update_layout(
- xaxis=dict(
- range=x_range, scaleanchor="y", scaleratio=1, constrain="domain"
- ),
- yaxis=dict(range=y_range, constrain="domain"),
- )
-
- return fig
-
- @staticmethod
- def _normalize_image(img_array):
- """Normalize a single image array to [0, 255] more efficiently"""
- min_val = img_array.min()
- max_val = img_array.max()
- if min_val == max_val:
- return np.zeros_like(img_array, dtype=np.uint8)
- # Normalize in one step
- return ((img_array - min_val) * 255 / (max_val - min_val)).astype(np.uint8)
-
- @staticmethod
- def _numpy_to_base64(img_array):
- """Convert numpy array to base64 string with compression"""
- if not isinstance(img_array, np.uint8):
- img_array = img_array.astype(np.uint8)
- img = Image.fromarray(img_array)
- buffered = BytesIO()
- # Use JPEG format with quality=85 for better compression
- img.save(buffered, format="JPEG", quality=85, optimize=True)
- return "data:image/jpeg;base64," + base64.b64encode(buffered.getvalue()).decode(
- "utf-8"
- )
-
- def save_cache(self, cache_path: str | None = None):
- """Save the image cache to disk using pickle.
-
- Parameters
- ----------
- cache_path : str | None, optional
- Path to save the cache. If None, uses self.cache_path, by default None
- """
- import pickle
-
- if cache_path is None:
- if self.cache_path is None:
- logger.warning("No cache path specified, skipping cache save")
- return
- cache_path = self.cache_path
- else:
- cache_path = Path(cache_path)
-
- # Create parent directory if it doesn't exist
- cache_path.parent.mkdir(parents=True, exist_ok=True)
-
- # Save cache metadata for validation
- cache_metadata = {
- "data_path": str(self.data_path),
- "tracks_path": str(self.tracks_path),
- "features_path": str(self.features_path),
- "channels": self.channels_to_display,
- "z_range": self.z_range,
- "yx_patch_size": self.yx_patch_size,
- "cache_size": len(self.image_cache),
- }
-
- try:
- logger.info(f"Saving image cache to {cache_path}")
- with open(cache_path, "wb") as f:
- pickle.dump((cache_metadata, self.image_cache), f)
- logger.info(f"Successfully saved cache with {len(self.image_cache)} images")
- except Exception as e:
- logger.error(f"Error saving cache: {e}")
-
- def load_cache(self, cache_path: str | None = None) -> bool:
- """Load the image cache from disk using pickle.
-
- Parameters
- ----------
- cache_path : str | None, optional
- Path to load the cache from. If None, uses self.cache_path, by default None
-
- Returns
- -------
- bool
- True if cache was successfully loaded, False otherwise
- """
- import pickle
-
- if cache_path is None:
- if self.cache_path is None:
- logger.warning("No cache path specified, skipping cache load")
- return False
- cache_path = self.cache_path
- else:
- cache_path = Path(cache_path)
-
- if not cache_path.exists():
- logger.warning(f"Cache file {cache_path} does not exist")
- return False
-
- try:
- logger.info(f"Loading image cache from {cache_path}")
- with open(cache_path, "rb") as f:
- cache_metadata, loaded_cache = pickle.load(f)
-
- # Validate cache metadata
- if (
- cache_metadata["data_path"] != str(self.data_path)
- or cache_metadata["tracks_path"] != str(self.tracks_path)
- or cache_metadata["features_path"] != str(self.features_path)
- or cache_metadata["channels"] != self.channels_to_display
- or cache_metadata["z_range"] != self.z_range
- or cache_metadata["yx_patch_size"] != self.yx_patch_size
- ):
- logger.warning("Cache metadata mismatch, skipping cache load")
- return False
-
- self.image_cache = loaded_cache
- logger.info(
- f"Successfully loaded cache with {len(self.image_cache)} images"
- )
- return True
- except Exception as e:
- logger.error(f"Error loading cache: {e}")
- return False
-
- def preload_images(self):
- """Preload all images into memory"""
- # Try to load from cache first
- if self.cache_path and self.load_cache():
- return
-
- logger.info("Preloading images into cache...")
- logger.info(f"FOVs to process: {list(self.filtered_tracks_by_fov.keys())}")
-
- # Process each FOV and its tracks
- for fov_name, track_ids in self.filtered_tracks_by_fov.items():
- if not track_ids: # Skip FOVs with no tracks
- logger.info(f"Skipping FOV {fov_name} as it has no tracks")
- continue
-
- logger.info(f"Processing FOV {fov_name} with tracks {track_ids}")
-
- try:
- data_module = TripletDataModule(
- data_path=self.data_path,
- tracks_path=self.tracks_path,
- include_fov_names=[fov_name] * len(track_ids),
- include_track_ids=track_ids,
- source_channel=self.channels_to_display,
- z_range=self.z_range,
- initial_yx_patch_size=self.yx_patch_size,
- final_yx_patch_size=self.yx_patch_size,
- batch_size=1,
- num_workers=self.num_loading_workers,
- normalizations=None,
- predict_cells=True,
- )
- data_module.setup("predict")
-
- for batch in data_module.predict_dataloader():
- try:
- images = batch["anchor"].numpy()
- indices = batch["index"]
- track_id = indices["track_id"].tolist()
- t = indices["t"].tolist()
-
- img = np.stack(images)
- cache_key = (fov_name, track_id[0], t[0])
-
- logger.debug(f"Processing cache key: {cache_key}")
-
- # Process each channel based on its type
- processed_channels = {}
- for idx, channel in enumerate(self.channels_to_display):
- try:
- if channel in ["Phase3D", "DIC", "BF"]:
- # For phase contrast, use the middle z-slice
- z_idx = (self.z_range[1] - self.z_range[0]) // 2
- processed = self._normalize_image(
- img[0, idx, z_idx]
- )
- else:
- # For fluorescence, use max projection
- processed = self._normalize_image(
- np.max(img[0, idx], axis=0)
- )
-
- processed_channels[channel] = self._numpy_to_base64(
- processed
- )
- logger.debug(
- f"Successfully processed channel {channel} for {cache_key}"
- )
- except Exception as e:
- logger.error(
- f"Error processing channel {channel} for {cache_key}: {e}"
- )
- continue
-
- if (
- processed_channels
- ): # Only store if at least one channel was processed
- self.image_cache[cache_key] = processed_channels
-
- except Exception as e:
- logger.error(
- f"Error processing batch for {fov_name}, track {track_id}: {e}"
- )
- continue
-
- except Exception as e:
- logger.error(f"Error setting up data module for FOV {fov_name}: {e}")
- continue
-
- logger.info(f"Successfully cached {len(self.image_cache)} images")
- # Log some statistics about the cache
- cached_fovs = set(key[0] for key in self.image_cache.keys())
- cached_tracks = set((key[0], key[1]) for key in self.image_cache.keys())
- logger.info(f"Cached FOVs: {cached_fovs}")
- logger.info(f"Number of unique track-FOV combinations: {len(cached_tracks)}")
-
- # Save cache if path is specified
- if self.cache_path:
- self.save_cache()
-
- def _cleanup_cache(self):
- """Clear the image cache when the program exits"""
- logging.info("Cleaning up image cache...")
- self.image_cache.clear()
-
- def _get_trajectory_images_lasso(self, x_axis, y_axis, selected_data):
- """Get images of points selected by lasso"""
- if not selected_data or not selected_data.get("points"):
- return html.Div("Use the lasso tool to select points")
-
- # Dictionary to store points for each lasso selection
- lasso_clusters = {}
-
- # Track which points we've seen to avoid duplicates within clusters
- seen_points = set()
-
- # Process each selected point
- for point in selected_data["points"]:
- text = point["text"]
- lines = text.split("
")
- track_id = int(lines[0].split(": ")[1])
- t = int(lines[1].split(": ")[1])
- fov = lines[2].split(": ")[1]
-
- point_id = (track_id, t, fov)
- cache_key = (fov, track_id, t)
-
- # Skip if we don't have the image in cache
- if cache_key not in self.image_cache:
- logger.debug(f"Skipping point {point_id} as it's not in the cache")
- continue
-
- # Determine which curve (lasso selection) this point belongs to
- curve_number = point.get("curveNumber", 0)
- if curve_number not in lasso_clusters:
- lasso_clusters[curve_number] = []
-
- # Only add if we haven't seen this point in this cluster
- cluster_point_id = (curve_number, point_id)
- if cluster_point_id not in seen_points:
- seen_points.add(cluster_point_id)
- lasso_clusters[curve_number].append(
- {
- "track_id": track_id,
- "t": t,
- "fov_name": fov,
- x_axis: point["x"],
- y_axis: point["y"],
- }
- )
-
- if not lasso_clusters:
- return html.Div("No cached images found for the selected points")
-
- # Create sections for each lasso selection
- cluster_sections = []
- for cluster_idx, points in lasso_clusters.items():
- cluster_df = pd.DataFrame(points)
-
- # Create channel rows for this cluster
- channel_rows = []
- for channel in self.channels_to_display:
- images = []
- for _, row in cluster_df.iterrows():
- cache_key = (row["fov_name"], row["track_id"], row["t"])
- images.append(
- html.Div(
- [
- html.Img(
- src=self.image_cache[cache_key][channel],
- style={
- "width": "150px",
- "height": "150px",
- "margin": "5px",
- "border": "1px solid #ddd",
- },
- ),
- html.Div(
- f"Track {row['track_id']}, t={row['t']}",
- style={
- "textAlign": "center",
- "fontSize": "12px",
- },
- ),
- ],
- style={
- "display": "inline-block",
- "margin": "5px",
- "verticalAlign": "top",
- },
- )
- )
-
- if images: # Only add row if there are images
- channel_rows.extend(
- [
- html.H5(
- f"{channel}",
- style={
- "margin": "10px 5px",
- "fontSize": "16px",
- "fontWeight": "bold",
- },
- ),
- html.Div(
- images,
- style={
- "overflowX": "auto",
- "whiteSpace": "nowrap",
- "padding": "10px",
- "border": "1px solid #ddd",
- "borderRadius": "5px",
- "marginBottom": "20px",
- "backgroundColor": "#f8f9fa",
- },
- ),
- ]
- )
-
- if channel_rows: # Only add cluster section if it has images
- cluster_sections.append(
- html.Div(
- [
- html.H3(
- f"Lasso Selection {cluster_idx + 1}",
- style={
- "marginTop": "30px",
- "marginBottom": "15px",
- "fontSize": "24px",
- "fontWeight": "bold",
- "borderBottom": "2px solid #007bff",
- "paddingBottom": "5px",
- },
- ),
- html.Div(
- channel_rows,
- style={
- "backgroundColor": "#ffffff",
- "padding": "15px",
- "borderRadius": "8px",
- "boxShadow": "0 2px 4px rgba(0,0,0,0.1)",
- },
- ),
- ]
- )
- )
-
- return html.Div(
- [
- html.H2(
- f"Selected Points ({len(cluster_sections)} selections)",
- style={
- "marginBottom": "20px",
- "fontSize": "28px",
- "fontWeight": "bold",
- "color": "#2c3e50",
- },
- ),
- html.Div(cluster_sections),
- ]
- )
-
- def _get_output_info_display(self) -> html.Div:
- """
- Create a display component showing the output directory information.
-
- Returns
- -------
- html.Div
- HTML component displaying output directory info
- """
- return html.Div(
- [
- html.H4(
- "Output Directory",
- style={"marginBottom": "10px", "fontSize": "16px"},
- ),
- html.Div(
- [
- html.Span("📁 ", style={"fontSize": "14px"}),
- html.Span(
- str(self.output_dir),
- style={
- "fontFamily": "monospace",
- "backgroundColor": "#f8f9fa",
- "padding": "4px 8px",
- "borderRadius": "4px",
- "border": "1px solid #dee2e6",
- "fontSize": "12px",
- },
- ),
- ],
- style={"marginBottom": "10px"},
- ),
- html.Div(
- "CSV files will be saved to this directory with timestamped names.",
- style={
- "fontSize": "12px",
- "color": "#6c757d",
- "fontStyle": "italic",
- },
- ),
- ],
- style={
- "backgroundColor": "#e9ecef",
- "padding": "10px",
- "borderRadius": "6px",
- "marginBottom": "15px",
- "border": "1px solid #ced4da",
- },
- )
-
- def _get_cluster_images(self):
- """Display images for all clusters in a grid layout"""
- if not self.clusters:
- return html.Div(
- [self._get_output_info_display(), html.Div("No clusters created yet")]
- )
-
- # Create cluster colors once
- cluster_colors = [
- f"rgb{tuple(int(x * 255) for x in plt.cm.Set2(i % 8)[:3])}"
- for i in range(len(self.clusters))
- ]
-
- # Create individual cluster panels
- cluster_panels = []
- for cluster_idx, cluster_points in enumerate(self.clusters):
- # Get cluster name or use default
- cluster_name = self.cluster_names.get(
- cluster_idx, f"Cluster {cluster_idx + 1}"
- )
-
- # Create a single scrollable container for all channels
- all_channel_images = []
- for channel in self.channels_to_display:
- images = []
- for point in cluster_points:
- cache_key = (point["fov_name"], point["track_id"], point["t"])
-
- images.append(
- html.Div(
- [
- html.Img(
- src=self.image_cache[cache_key][channel],
- style={
- "width": "100px",
- "height": "100px",
- "margin": "2px",
- "border": f"2px solid {cluster_colors[cluster_idx]}",
- "borderRadius": "4px",
- },
- ),
- html.Div(
- f"Track {point['track_id']}, t={point['t']}",
- style={
- "textAlign": "center",
- "fontSize": "10px",
- },
- ),
- ],
- style={
- "display": "inline-block",
- "margin": "2px",
- "verticalAlign": "top",
- },
- )
- )
-
- if images:
- all_channel_images.extend(
- [
- html.H6(
- f"{channel}",
- style={
- "margin": "5px",
- "fontSize": "12px",
- "fontWeight": "bold",
- "position": "sticky",
- "left": "0",
- "backgroundColor": "#f8f9fa",
- "zIndex": "1",
- "paddingLeft": "5px",
- },
- ),
- html.Div(
- images,
- style={
- "whiteSpace": "nowrap",
- "marginBottom": "10px",
- },
- ),
- ]
- )
-
- if all_channel_images:
- # Create a panel for this cluster with synchronized scrolling
- cluster_panels.append(
- html.Div(
- [
- html.Div(
- [
- html.Span(
- cluster_name,
- style={
- "color": cluster_colors[cluster_idx],
- "fontWeight": "bold",
- "fontSize": "16px",
- },
- ),
- html.Span(
- f" ({len(cluster_points)} points)",
- style={
- "color": "#2c3e50",
- "fontSize": "14px",
- },
- ),
- html.Button(
- "✏️",
- id={
- "type": "edit-cluster-name",
- "index": cluster_idx,
- },
- style={
- "backgroundColor": "transparent",
- "border": "none",
- "cursor": "pointer",
- "fontSize": "12px",
- "marginLeft": "5px",
- "color": "#6c757d",
- },
- title="Edit cluster name",
- ),
- ],
- style={
- "marginBottom": "10px",
- "borderBottom": f"2px solid {cluster_colors[cluster_idx]}",
- "paddingBottom": "5px",
- "position": "sticky",
- "top": "0",
- "backgroundColor": "white",
- "zIndex": "1",
- },
- ),
- html.Div(
- all_channel_images,
- style={
- "overflowX": "auto",
- "overflowY": "auto",
- "height": "400px",
- "backgroundColor": "#ffffff",
- "padding": "10px",
- "borderRadius": "8px",
- "boxShadow": "0 2px 4px rgba(0,0,0,0.1)",
- },
- ),
- ],
- style={
- "width": "24%",
- "display": "inline-block",
- "verticalAlign": "top",
- "padding": "5px",
- "boxSizing": "border-box",
- },
- )
- )
-
- # Create rows of 4 panels each
- rows = []
- for i in range(0, len(cluster_panels), 4):
- row = html.Div(
- cluster_panels[i : i + 4],
- style={
- "display": "flex",
- "justifyContent": "flex-start",
- "gap": "10px",
- "marginBottom": "10px",
- },
- )
- rows.append(row)
-
- return html.Div(
- [
- html.H2(
- [
- "Clusters ",
- html.Span(
- f"({len(self.clusters)} total)",
- style={"color": "#666"},
- ),
- ],
- style={
- "marginBottom": "20px",
- "fontSize": "28px",
- "fontWeight": "bold",
- "color": "#2c3e50",
- },
- ),
- self._get_output_info_display(),
- html.Div(
- rows,
- style={
- "maxHeight": "calc(100vh - 200px)",
- "overflowY": "auto",
- "padding": "10px",
- },
- ),
- ]
- )
-
- def get_output_dir(self) -> Path:
- """
- Get the output directory for saving files.
-
- Returns
- -------
- Path
- The output directory path
- """
- return self.output_dir
-
- def save_clusters_to_csv(self, output_path: str | None = None) -> str:
- """
- Save cluster information to CSV file.
-
- This method exports all cluster data including track_id, time, FOV,
- cluster assignment, and cluster names to a CSV file for further analysis.
-
- Parameters
- ----------
- output_path : str | None, optional
- Path to save the CSV file. If None, generates a timestamped filename
- in the output directory, by default None
-
- Returns
- -------
- str
- Path to the saved CSV file
-
- Notes
- -----
- The CSV will contain columns:
- - cluster_id: The cluster number (1-indexed)
- - cluster_name: The custom name assigned to the cluster
- - track_id: The track identifier
- - time: The timepoint
- - fov_name: The field of view name
- - cluster_size: Number of points in the cluster
- """
- if not self.clusters:
- logger.warning("No clusters to save")
- return ""
-
- # Prepare data for CSV export
- csv_data = []
- for cluster_idx, cluster in enumerate(self.clusters):
- cluster_id = cluster_idx + 1 # 1-indexed for user-friendly output
- cluster_size = len(cluster)
- cluster_name = self.cluster_names.get(cluster_idx, f"Cluster {cluster_id}")
-
- for point in cluster:
- csv_data.append(
- {
- "cluster_id": cluster_id,
- "cluster_name": cluster_name,
- "track_id": point["track_id"],
- "time": point["t"],
- "fov_name": point["fov_name"],
- "cluster_size": cluster_size,
- }
- )
-
- # Create DataFrame and save to CSV
- df = pd.DataFrame(csv_data)
-
- if output_path is None:
- # Generate timestamped filename in output directory
- from datetime import datetime
-
- timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
- output_path = self.output_dir / f"clusters_{timestamp}.csv"
- else:
- output_path = Path(output_path)
- # If only filename is provided, use output directory
- if not output_path.parent.name:
- output_path = self.output_dir / output_path.name
-
- try:
- # Create parent directory if it doesn't exist
- output_path.parent.mkdir(parents=True, exist_ok=True)
-
- df.to_csv(output_path, index=False)
- logger.info(f"Successfully saved {len(df)} cluster points to {output_path}")
- return str(output_path)
-
- except Exception as e:
- logger.error(f"Error saving clusters to CSV: {e}")
- raise
-
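The default naming above boils down to a timestamp pattern; as a standalone sketch (the directory is an illustrative stand-in for `self.output_dir`):

```python
# Sketch of the timestamped default filename used by save_clusters_to_csv.
from datetime import datetime
from pathlib import Path

output_dir = Path("output")  # illustrative stand-in for self.output_dir
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
print(output_dir / f"clusters_{timestamp}.csv")  # e.g. output/clusters_20250611_093015.csv
```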
- def run(self, debug=False, port=None):
- """Run the Dash server
-
- Parameters
- ----------
- debug : bool, optional
- Whether to run in debug mode, by default False
- port : int, optional
- Port to run on. If None, will try ports from 8050-8070, by default None
- """
- import socket
-
- def is_port_in_use(port):
- with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
- try:
- s.bind(("127.0.0.1", port))
- return False
- except socket.error:
- return True
-
- if port is None:
- # Try ports from 8050 to 8070
- # FIXME: set a range for the ports
- port_range = list(range(8050, 8071))
- for p in port_range:
- if not is_port_in_use(p):
- port = p
- break
- if port is None:
- raise RuntimeError(
- f"Could not find an available port in range {port_range[0]}-{port_range[-1]}"
- )
-
- try:
- logger.info(f"Starting server on port {port}")
- self.app.run(
- debug=debug,
- port=port,
- use_reloader=False, # Disable reloader to prevent multiple instances
- )
- except KeyboardInterrupt:
- logger.info("Server shutdown requested...")
- except Exception as e:
- logger.error(f"Error running server: {e}")
- finally:
- self._cleanup_cache()
- logger.info("Server shutdown complete")
diff --git a/viscy/representation/visualization/__init__.py b/viscy/representation/visualization/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/viscy/representation/visualization/app.py b/viscy/representation/visualization/app.py
new file mode 100644
index 000000000..9403317ed
--- /dev/null
+++ b/viscy/representation/visualization/app.py
@@ -0,0 +1,2706 @@
+import atexit
+import json
+import logging
+import time
+from pathlib import Path
+from typing import Union
+
+import dash
+import dash.dependencies as dd
+import matplotlib.pyplot as plt
+import numpy as np
+import pandas as pd
+import plotly.graph_objects as go
+from dash import dcc, html
+from sklearn.decomposition import PCA
+from sklearn.preprocessing import StandardScaler
+
+from viscy.representation.embedding_writer import read_embedding_dataset
+from viscy.representation.evaluation.combined_analysis import (
+ compute_phate_for_combined_datasets,
+ load_and_combine_features,
+)
+from viscy.representation.evaluation.data_loading import EmbeddingDataLoader
+from viscy.representation.evaluation.dimensionality_reduction import (
+ compute_pca,
+)
+from viscy.representation.visualization.cluster import (
+ ClusterManager,
+)
+from viscy.representation.visualization.settings import VizConfig
+
+logger = logging.getLogger("lightning.pytorch")
+
+
+class EmbeddingVisualizationApp:
+ def __init__(
+ self,
+ viz_config: VizConfig,
+ cache_path: str | None = None,
+ num_loading_workers: int = 16,
+ output_dir: str | None = None,
+ loader: EmbeddingDataLoader | None = None,
+ ) -> None:
+ """
+ Initialize a Dash application for visualizing the DynaCLR embeddings.
+
+        This class provides an interactive tool for visualizing DynaCLR embeddings in a 2D space (e.g. PCA, UMAP, PHATE).
+        It allows users to explore and analyze trajectories, visualize clusters, and navigate the embedding space.
+
+ Parameters
+ ----------
+ viz_config: VizConfig
+ Configuration object for visualization.
+ cache_path: str | None
+ Path to the cache directory.
+ num_loading_workers: int
+ Number of workers to use for loading data.
+ output_dir: str | None, optional
+ Directory to save CSV files and other outputs. If None, uses current working directory.
+ loader: EmbeddingDataLoader | None
+ Custom data loader for embeddings. If None, uses TripletEmbeddingLoader.
+
+        Returns
+ -------
+ None
+ Initializes the visualization app.
+ """
+ self.viz_config = viz_config
+ self.image_cache = {}
+ self.cache_path = Path(cache_path) if cache_path else None
+ self.output_dir = Path(output_dir) if output_dir else Path.cwd()
+ self.app = None
+ self.features_df: pd.DataFrame | None = None
+ self.fig = None
+ self.num_loading_workers = num_loading_workers
+ self.loader = loader
+
+ # Initialize cluster storage before preparing data and creating figure
+ self.cluster_manager = ClusterManager()
+
+ # Store datasets for per-dataset access
+ self.datasets = viz_config.get_datasets()
+ self._DEFAULT_MARKER_SIZE = 15
+
+ # Debouncing for dropdown updates
+        # TODO: verify that this debounce workaround keeps plot updates reliable
+ self._last_update_time = 0
+ self._debounce_delay = 0.3 # 300ms debounce delay
+
+ # Initialize data
+ self._prepare_data()
+ self._create_figure()
+ self._init_app()
+ atexit.register(self._cleanup_cache)
+
+ def _prepare_data(self):
+ """Load and prepare the data for visualization"""
+ # Extract feature paths and dataset names
+ feature_paths = [
+ Path(config.features_path) for config in self.datasets.values()
+ ]
+ dataset_names = list(self.datasets.keys())
+
+ if self.viz_config.phate_kwargs is not None and len(self.datasets) > 1:
+ self._prepare_data_with_combined_phate(feature_paths, dataset_names)
+ else:
+ self._loading_and_prepare_data(feature_paths, dataset_names)
+
+ def _prepare_data_with_combined_phate(self, feature_paths, dataset_names):
+ """Prepare data using the new combined PHATE approach"""
+ # Set up combined PHATE cache path
+ if self.viz_config.combined_phate_cache_path:
+ combined_cache_path = Path(self.viz_config.combined_phate_cache_path)
+ elif self.cache_path:
+ combined_cache_path = self.cache_path / "combined_phate.zarr"
+ else:
+ combined_cache_path = Path.cwd() / "combined_phate.zarr"
+
+ # Check if we should use cached results
+ use_cache = (
+ self.viz_config.use_cached_combined_phate and combined_cache_path.exists()
+ )
+
+ if use_cache:
+ logger.info(
+ f"Loading cached combined PHATE results from {combined_cache_path}"
+ )
+ try:
+ combined_dataset = read_embedding_dataset(combined_cache_path)
+                if self.loader is not None:
+                    self.features_df = self.loader.extract_metadata(
+                        combined_dataset
+                    )
+                else:
+                    self.features_df = (
+                        combined_dataset.to_dataframe().reset_index()
+                    )
+ combined_embeddings = combined_dataset["features"].values
+ logger.info(
+ f"Loaded cached combined dataset with {len(self.features_df)} samples"
+ )
+ except Exception as e:
+ logger.warning(
+ f"Failed to load cached results: {e}. Computing fresh PHATE."
+ )
+ use_cache = False
+
+ if not use_cache:
+ logger.info("Computing fresh combined PHATE embeddings")
+ try:
+ if self.loader is not None:
+ combined_dataset = compute_phate_for_combined_datasets(
+ feature_paths=feature_paths,
+ output_path=combined_cache_path,
+ dataset_names=dataset_names,
+ phate_kwargs=self.viz_config.phate_kwargs,
+ overwrite=True,
+ loader=self.loader,
+ )
+ else:
+ combined_dataset = compute_phate_for_combined_datasets(
+ feature_paths=feature_paths,
+ output_path=combined_cache_path,
+ dataset_names=dataset_names,
+ phate_kwargs=self.viz_config.phate_kwargs,
+ overwrite=True,
+ )
+ self.features_df = combined_dataset.to_dataframe().reset_index()
+ combined_embeddings = combined_dataset["features"].values
+ logger.info(
+ f"Successfully computed combined PHATE with {len(self.features_df)} samples"
+ )
+ except Exception as e:
+ logger.error(f"Combined PHATE computation failed: {e}")
+ # Fall back to traditional approach
+                self._loading_and_prepare_data(feature_paths, dataset_names)
+ return
+
+ if "dataset_pair" in self.features_df.columns:
+ self.features_df["dataset"] = self.features_df["dataset_pair"]
+
+ self._compute_pca_on_combined_embeddings(combined_embeddings)
+
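When the cache-hit path above succeeds, the combined store is read back with `read_embedding_dataset`; a sketch of that round trip, assuming the function returns an xarray-style dataset (as its use here implies) and using a hypothetical path:

```python
# Sketch: reading a cached combined-PHATE store the way the app does.
from viscy.representation.embedding_writer import read_embedding_dataset

ds = read_embedding_dataset("combined_phate.zarr")  # hypothetical path
combined_embeddings = ds["features"].values  # (n_samples, n_features)
features_df = ds.to_dataframe().reset_index()  # per-sample metadata
```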
+ def _loading_and_prepare_data(self, feature_paths, dataset_names):
+ """Load and prepare data using modular functions"""
+ if self.loader is not None:
+ combined_embeddings, combined_indices = load_and_combine_features(
+ feature_paths, dataset_names, self.loader
+ )
+ else:
+ # Let the function use its default TripletEmbeddingLoader
+ combined_embeddings, combined_indices = load_and_combine_features(
+ feature_paths, dataset_names
+ )
+
+ # Convert to DataFrame and rename dataset_pair to dataset for compatibility
+ self.features_df = combined_indices.copy()
+ if "dataset_pair" in self.features_df.columns:
+ self.features_df["dataset"] = self.features_df["dataset_pair"]
+
+ logger.info(f"Combined embeddings shape: {combined_embeddings.shape}")
+
+ # Compute PCA and PHATE on combined embeddings
+ self._compute_pca_on_combined_embeddings(combined_embeddings)
+ self._compute_phate_on_combined_embeddings(combined_embeddings)
+
+ def _compute_pca_on_combined_embeddings(self, combined_embeddings):
+ """Compute PCA on combined embeddings and set up dimension options"""
+ # Check if dimensionality reduction columns already exist
+ existing_dims = []
+ dim_options = []
+
+ # Always recompute PCA on combined embeddings for multi-dataset scenario
+ if len(self.datasets) > 1 or not any(
+ col.startswith("PCA") for col in self.features_df.columns
+ ):
+ logger.info(
+ f"Computing PCA with {self.viz_config.num_PC_components} components on combined embeddings"
+ )
+
+ # Use the compute_pca function
+ pca_coords, _ = compute_pca(
+ combined_embeddings,
+ n_components=self.viz_config.num_PC_components,
+ normalize_features=True,
+ )
+
+ # We need to get the explained variance separately since compute_pca doesn't return the model
+ # FIXME: ideally the compute_pca function should return the model
+ scaler = StandardScaler()
+ scaled_features = scaler.fit_transform(combined_embeddings)
+ pca_model = PCA(
+ n_components=self.viz_config.num_PC_components, random_state=42
+ )
+ pca_model.fit(scaled_features)
+
+ # Store explained variance for PCA labels
+ self.pca_explained_variance = [
+ f"PC{i + 1} ({var:.1f}%)"
+ for i, var in enumerate(pca_model.explained_variance_ratio_ * 100)
+ ]
+
+ # Add PCA coordinates to the features dataframe
+ for i in range(self.viz_config.num_PC_components):
+ self.features_df[f"PCA{i + 1}"] = pca_coords[:, i]
+
+ # Add PCA options to dropdown
+ for i, pc_label in enumerate(self.pca_explained_variance):
+ dim_options.append({"label": pc_label, "value": f"PCA{i + 1}"})
+ existing_dims.append(f"PCA{i + 1}")
+
+ # Check for existing PHATE coordinates (if they exist in the data already)
+ phate_dims = [
+ col for col in self.features_df.columns if col.startswith("PHATE")
+ ]
+ if phate_dims:
+ for dim in phate_dims:
+ dim_options.append({"label": dim, "value": dim})
+ existing_dims.append(dim)
+
+ # Check for existing UMAP coordinates (if they exist in the data)
+ umap_dims = [col for col in self.features_df.columns if col.startswith("UMAP")]
+ if umap_dims:
+ for dim in umap_dims:
+ dim_options.append({"label": dim, "value": dim})
+ existing_dims.append(dim)
+
+ # Store dimension options for dropdowns
+ self.dim_options = dim_options
+
+ # Set default x and y axes based on available dimensions
+        # TODO: avoid hardcoding PCA1/PCA2 as the fallback default axes
+ self.default_x = existing_dims[0] if existing_dims else "PCA1"
+ self.default_y = existing_dims[1] if len(existing_dims) > 1 else "PCA2"
+
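The PC axis labels built above come from the refit PCA model's explained-variance ratios; a self-contained sketch with toy data:

```python
# Sketch of the "PC1 (12.3%)" label construction used for the axis dropdowns.
import numpy as np
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler

X = np.random.default_rng(42).normal(size=(100, 16))  # toy embeddings
pca = PCA(n_components=8, random_state=42).fit(StandardScaler().fit_transform(X))
labels = [
    f"PC{i + 1} ({var:.1f}%)"
    for i, var in enumerate(pca.explained_variance_ratio_ * 100)
]
print(labels[:2])
```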
+ def _compute_phate_on_combined_embeddings(self, combined_embeddings):
+ """Compute PHATE on combined embeddings (traditional approach)"""
+ # Compute PHATE if specified in config
+ if self.viz_config.phate_kwargs is not None:
+ logger.info(
+ f"Computing PHATE with {self.viz_config.phate_kwargs['n_components']} components on combined embeddings"
+ )
+
+ try:
+ from viscy.representation.evaluation.dimensionality_reduction import (
+ compute_phate,
+ )
+
+ # Use the compute_phate function with configurable parameters
+ logger.info(f"Using PHATE parameters: {self.viz_config.phate_kwargs}")
+
+ phate_model, phate_coords = compute_phate(
+ combined_embeddings, **self.viz_config.phate_kwargs
+ )
+
+ # Add PHATE coordinates to the features dataframe
+ for i in range(self.viz_config.phate_kwargs["n_components"]):
+ self.features_df[f"PHATE{i + 1}"] = phate_coords[:, i]
+
+ logger.info(
+ f"Successfully computed PHATE with {self.viz_config.phate_kwargs['n_components']} components"
+ )
+
+ except ImportError:
+ logger.warning(
+ "PHATE is not available. Install with: pip install viscy[phate]"
+ )
+ except Exception as e:
+ logger.warning(f"PHATE computation failed: {str(e)}")
+
+ # Collect all valid (dataset, fov, track) combinations
+ self.valid_combinations = []
+
+ for dataset_name, dataset_config in self.datasets.items():
+ logger.info(f"Processing dataset: {dataset_name}")
+ logger.info(f" fov_tracks: {dataset_config.fov_tracks}")
+
+ for fov_name, track_ids in dataset_config.fov_tracks.items():
+ if track_ids == "all":
+ # Get all tracks for this dataset/FOV combination
+ fov_tracks = self.features_df[
+ (self.features_df["dataset"] == dataset_name)
+ & (self.features_df["fov_name"] == fov_name)
+ ]["track_id"].unique()
+
+ logger.info(
+ f" FOV {fov_name}: found {len(fov_tracks)} tracks for 'all'"
+ )
+
+ for track_id in fov_tracks:
+ self.valid_combinations.append(
+ (dataset_name, fov_name, track_id)
+ )
+ else:
+ logger.info(f" FOV {fov_name}: using specific tracks {track_ids}")
+ for track_id in track_ids:
+ self.valid_combinations.append(
+ (dataset_name, fov_name, track_id)
+ )
+
+ logger.debug(f"Total valid filtered tracks: {len(self.valid_combinations)}")
+
+ # Create a MultiIndex for efficient filtering
+ if self.valid_combinations:
+ # Create temporary column for filtering
+ self.features_df["_temp_combo"] = list(
+ zip(
+ self.features_df["dataset"],
+ self.features_df["fov_name"],
+ self.features_df["track_id"],
+ )
+ )
+
+ # Create mask for selected data
+ selected_mask = self.features_df["_temp_combo"].isin(
+ self.valid_combinations
+ )
+
+ # Apply mask FIRST, then drop the temporary column
+ filtered_df_with_temp = self.features_df[selected_mask].copy()
+ background_df_with_temp = self.features_df[~selected_mask].copy()
+
+ # Now drop the temporary column from the filtered dataframes
+ self.filtered_features_df = filtered_df_with_temp.drop(
+ "_temp_combo", axis=1
+ )
+ self.background_features_df = background_df_with_temp.drop(
+ "_temp_combo", axis=1
+ )
+
+ # Subsample background points if there are too many
+ if len(self.background_features_df) > 5000:
+ self.background_features_df = self.background_features_df.sample(
+ n=5000, random_state=42
+ )
+
+ # Pre-compute track colors
+ cmap = plt.cm.get_cmap("tab20")
+ self.track_colors = {
+ track_key: f"rgb{tuple(int(x * 255) for x in cmap(i % 20)[:3])}"
+ for i, track_key in enumerate(self.valid_combinations)
+ }
+
+ # Drop the temporary column from the original dataframe
+ self.features_df = self.features_df.drop("_temp_combo", axis=1)
+ else:
+ self.filtered_features_df = pd.DataFrame()
+ self.background_features_df = pd.DataFrame()
+ self.track_colors = {}
+
+ logger.info(
+ f"Prepared data with {len(self.features_df)} total samples, "
+ f"{len(self.filtered_features_df)} filtered samples, "
+ f"and {len(self.background_features_df)} background samples, "
+ f"with {len(self.valid_combinations)} unique tracks"
+ )
+
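The `_temp_combo` trick above splits rows by membership in a list of `(dataset, fov_name, track_id)` tuples; the same pattern on a toy frame:

```python
# Sketch of the tuple-membership split into selected vs. background rows.
import pandas as pd

df = pd.DataFrame(
    {"dataset": ["a", "a", "b"], "fov_name": ["/0", "/0", "/1"], "track_id": [1, 2, 1]}
)
valid = [("a", "/0", 1), ("b", "/1", 1)]
combos = pd.Series(list(zip(df["dataset"], df["fov_name"], df["track_id"])))
mask = combos.isin(valid).values
selected, background = df[mask], df[~mask]
```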
+ def _create_figure(self):
+ """Create the initial scatter plot figure"""
+ self.fig = self._create_track_colored_figure()
+
+ def _init_app(self):
+ """Initialize the Dash application"""
+ self.app = dash.Dash(__name__)
+
+ # Add cluster assignment button next to clear selection
+ cluster_controls = html.Div(
+ [
+ html.Button(
+ "Assign to New Cluster",
+ id="assign-cluster",
+ style={
+ "backgroundColor": "#28a745",
+ "color": "white",
+ "border": "none",
+ "padding": "5px 10px",
+ "borderRadius": "4px",
+ "cursor": "pointer",
+ "marginRight": "10px",
+ },
+ ),
+ html.Button(
+ "Clear All Clusters",
+ id="clear-clusters",
+ style={
+ "backgroundColor": "#dc3545",
+ "color": "white",
+ "border": "none",
+ "padding": "5px 10px",
+ "borderRadius": "4px",
+ "cursor": "pointer",
+ "marginRight": "10px",
+ },
+ ),
+ html.Button(
+ "Save Clusters to CSV",
+ id="save-clusters-csv",
+ style={
+ "backgroundColor": "#17a2b8",
+ "color": "white",
+ "border": "none",
+ "padding": "5px 10px",
+ "borderRadius": "4px",
+ "cursor": "pointer",
+ "marginRight": "10px",
+ },
+ ),
+ html.Button(
+ "Clear Selection",
+ id="clear-selection",
+ style={
+ "backgroundColor": "#6c757d",
+ "color": "white",
+ "border": "none",
+ "padding": "5px 10px",
+ "borderRadius": "4px",
+ "cursor": "pointer",
+ },
+ ),
+ ],
+ style={"marginLeft": "10px", "display": "inline-block"},
+ )
+ # Create tabs for different views
+ tabs = dcc.Tabs(
+ id="view-tabs",
+ value="timeline-tab",
+ children=[
+ dcc.Tab(
+ label="Track Timeline",
+ value="timeline-tab",
+ children=[
+ html.Div(
+ id="track-timeline",
+ style={
+ "height": "auto",
+ "overflowY": "auto",
+ "maxHeight": "80vh",
+ "padding": "10px",
+ "marginTop": "10px",
+ },
+ ),
+ ],
+ ),
+ dcc.Tab(
+ label="Clusters",
+ value="clusters-tab",
+ id="clusters-tab",
+ children=[
+ html.Div(
+ id="cluster-container",
+ style={
+ "padding": "10px",
+ "marginTop": "10px",
+ },
+ ),
+ ],
+ style={"display": "none"}, # Initially hidden
+ ),
+ ],
+ style={"marginTop": "20px"},
+ )
+
+ # Add modal for cluster naming
+ cluster_name_modal = html.Div(
+ id="cluster-name-modal",
+ children=[
+ html.Div(
+ [
+ html.H3("Name Your Cluster", style={"marginBottom": "20px"}),
+ html.Label("Cluster Name:"),
+ dcc.Input(
+ id="cluster-name-input",
+ type="text",
+ placeholder="Enter cluster name...",
+ style={"width": "100%", "marginBottom": "20px"},
+ ),
+ html.Div(
+ [
+ html.Button(
+ "Save",
+ id="save-cluster-name",
+ style={
+ "backgroundColor": "#28a745",
+ "color": "white",
+ "border": "none",
+ "padding": "8px 16px",
+ "borderRadius": "4px",
+ "cursor": "pointer",
+ "marginRight": "10px",
+ },
+ ),
+ html.Button(
+ "Cancel",
+ id="cancel-cluster-name",
+ style={
+ "backgroundColor": "#6c757d",
+ "color": "white",
+ "border": "none",
+ "padding": "8px 16px",
+ "borderRadius": "4px",
+ "cursor": "pointer",
+ },
+ ),
+ ],
+ style={"textAlign": "right"},
+ ),
+ ],
+ style={
+ "backgroundColor": "white",
+ "padding": "30px",
+ "borderRadius": "8px",
+ "maxWidth": "400px",
+ "margin": "auto",
+ "boxShadow": "0 4px 6px rgba(0, 0, 0, 0.1)",
+ "border": "1px solid #ddd",
+ },
+ )
+ ],
+ style={
+ "display": "none",
+ "position": "fixed",
+ "top": "0",
+ "left": "0",
+ "width": "100%",
+ "height": "100%",
+ "backgroundColor": "rgba(0, 0, 0, 0.5)",
+ "zIndex": "1000",
+ "justifyContent": "center",
+ "alignItems": "center",
+ },
+ )
+
+ # Update layout to use tabs
+ self.app.layout = html.Div(
+ style={
+ "maxWidth": "95vw",
+ "margin": "auto",
+ "padding": "20px",
+ },
+ children=[
+ html.H1(
+ "Track Visualization",
+ style={"textAlign": "center", "marginBottom": "20px"},
+ ),
+ html.Div(
+ [
+ html.Div(
+ style={
+ "width": "100%",
+ "display": "inline-block",
+ "verticalAlign": "top",
+ },
+ children=[
+ html.Div(
+ style={
+ "marginBottom": "20px",
+ "display": "flex",
+ "alignItems": "center",
+ "gap": "20px",
+ "flexWrap": "wrap",
+ },
+ children=[
+ html.Div(
+ [
+ html.Label(
+ "Color by:",
+ style={"marginRight": "10px"},
+ ),
+ dcc.Dropdown(
+ id="color-mode",
+ options=[
+ {
+ "label": "Track ID",
+ "value": "track",
+ },
+ {
+ "label": "Time",
+ "value": "time",
+ },
+ ],
+ value="track",
+ style={"width": "200px"},
+ ),
+ ]
+ ),
+ html.Div(
+ [
+ dcc.Checklist(
+ id="show-arrows",
+ options=[
+ {
+ "label": "Show arrows",
+ "value": "show",
+ }
+ ],
+ value=[],
+ style={"marginLeft": "20px"},
+ ),
+ ]
+ ),
+ html.Div(
+ [
+ html.Label(
+ "X-axis:",
+ style={"marginRight": "10px"},
+ ),
+ dcc.Dropdown(
+ id="x-axis",
+ options=self.dim_options,
+ value=self.default_x,
+ style={"width": "200px"},
+ ),
+ ]
+ ),
+ html.Div(
+ [
+ html.Label(
+ "Y-axis:",
+ style={"marginRight": "10px"},
+ ),
+ dcc.Dropdown(
+ id="y-axis",
+ options=self.dim_options,
+ value=self.default_y,
+ style={"width": "200px"},
+ ),
+ ]
+ ),
+ cluster_controls,
+ ],
+ ),
+ ],
+ ),
+ ]
+ ),
+ dcc.Loading(
+ id="loading",
+ children=[
+ dcc.Graph(
+ id="scatter-plot",
+ figure=self.fig,
+ config={
+ "displayModeBar": True,
+ "editable": False,
+ "showEditInChartStudio": False,
+ "modeBarButtonsToRemove": [
+ "select2d",
+ "resetScale2d",
+ ],
+ "edits": {
+ "annotationPosition": False,
+ "annotationTail": False,
+ "annotationText": False,
+ "shapePosition": True,
+ },
+ "scrollZoom": True,
+ },
+ style={"height": "50vh"},
+ ),
+ ],
+ type="default",
+ ),
+ tabs,
+ cluster_name_modal,
+ html.Div(
+ id="dummy-output", style={"display": "none"}
+ ), # Hidden dummy output
+ html.Div(
+ id="notification-area",
+ style={
+ "position": "fixed",
+ "top": "20px",
+ "right": "20px",
+ "zIndex": "2000",
+ "maxWidth": "300px",
+ },
+ ),
+ ],
+ )
+
+ @self.app.callback(
+ [
+ dd.Output("scatter-plot", "figure", allow_duplicate=True),
+ dd.Output("scatter-plot", "selectedData", allow_duplicate=True),
+ ],
+ [
+ dd.Input("color-mode", "value"),
+ dd.Input("show-arrows", "value"),
+ dd.Input("x-axis", "value"),
+ dd.Input("y-axis", "value"),
+ dd.Input("scatter-plot", "relayoutData"),
+ dd.Input("scatter-plot", "selectedData"),
+ ],
+ [dd.State("scatter-plot", "figure")],
+ prevent_initial_call=True,
+            # Note: debouncing of rapid successive updates is handled inside
+            # the callback body via self._debounce_delay.
+ )
+ def update_figure(
+ color_mode,
+ show_arrows,
+ x_axis,
+ y_axis,
+ relayout_data,
+ selected_data,
+ current_figure,
+ ):
+ # Input validation
+ if not color_mode:
+ color_mode = "track"
+ show_arrows = len(show_arrows or []) > 0
+
+ # Validate axis values exist in available options
+ valid_axis_values = [opt["value"] for opt in self.dim_options]
+ if not x_axis or x_axis not in valid_axis_values:
+ x_axis = self.default_x
+ if not y_axis or y_axis not in valid_axis_values:
+ y_axis = self.default_y
+
+ ctx = dash.callback_context
+ if not ctx.triggered:
+ triggered_id = "No clicks yet"
+ else:
+ triggered_id = ctx.triggered[0]["prop_id"].split(".")[0]
+
+ # Debouncing for axis changes to prevent rapid successive updates
+ current_time = time.time()
+ if triggered_id in ["x-axis", "y-axis"]:
+ if current_time - self._last_update_time < self._debounce_delay:
+ return dash.no_update, selected_data
+ self._last_update_time = current_time
+
+ # Always create new figure when control inputs change (remove dependency on callback context)
+ if triggered_id in [
+ "color-mode",
+ "show-arrows",
+ "x-axis",
+ "y-axis",
+ "No clicks yet"
+ ]:
+ if color_mode == "track":
+ fig = self._create_track_colored_figure(show_arrows, x_axis, y_axis)
+ else:
+ fig = self._create_time_colored_figure(show_arrows, x_axis, y_axis)
+
+ # Update dragmode and selection settings
+ fig.update_layout(
+ dragmode="lasso",
+ clickmode="event+select",
+ uirevision="true",
+ selectdirection="any",
+ )
+ else:
+ fig = dash.no_update
+
+ return fig, selected_data
+
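The axis-change debounce above is a simple last-timestamp check; extracted into a reusable sketch (0.3 s delay assumed, matching `_debounce_delay`):

```python
# Sketch of the time-based debounce used for x-/y-axis dropdown changes.
import time

class Debouncer:
    def __init__(self, delay: float = 0.3):
        self.delay = delay
        self._last = 0.0

    def ready(self) -> bool:
        """Return True (and reset the clock) if the delay has elapsed."""
        now = time.time()
        if now - self._last < self.delay:
            return False
        self._last = now
        return True
```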
+ @self.app.callback(
+ [
+ dd.Output("track-timeline", "children"),
+ dd.Output("scatter-plot", "figure", allow_duplicate=True),
+ ],
+ [dd.Input("scatter-plot", "clickData")],
+ [
+ dd.State("color-mode", "value"),
+ dd.State("show-arrows", "value"),
+ dd.State("x-axis", "value"),
+ dd.State("y-axis", "value"),
+ ],
+ prevent_initial_call=True,
+ )
+ def update_track_timeline(clickData, color_mode, show_arrows, x_axis, y_axis):
+ """Update the track timeline based on the clicked point"""
+ if clickData is None or self.features_df is None:
+ return (
+ html.Div("Click on a point to see the track timeline"),
+ dash.no_update,
+ )
+
+ # Parse the hover text to get dataset, track_id, time and fov_name
+ hover_text = clickData["points"][0]["text"]
+ lines = hover_text.split("
")
+ dataset_name = lines[0].split(": ")[1]
+ track_id = int(lines[1].split(": ")[1])
+ clicked_time = int(lines[2].split(": ")[1])
+ fov_name = lines[3].split(": ")[1]
+ # Get channels specific to this dataset
+ channels_to_display = self.datasets[dataset_name].channels_to_display
+
+ # Get all timepoints for this track
+ track_data = self.features_df[
+ (self.features_df["dataset"] == dataset_name)
+ & (self.features_df["fov_name"] == fov_name)
+ & (self.features_df["track_id"] == int(track_id))
+ ]
+
+ if track_data.empty:
+ return (
+ html.Div(
+ f"No data found for track {track_id} in dataset {dataset_name}"
+ ),
+ dash.no_update,
+ )
+
+ # Sort by time
+ track_data = track_data.sort_values("t")
+ timepoints = track_data["t"].unique()
+
+ # Create a list to store all timepoint columns
+ timepoint_columns = []
+
+ # First create the time labels row
+ time_labels = []
+ for t in timepoints:
+ is_clicked = t == clicked_time
+ time_style = {
+ "width": "150px",
+ "textAlign": "center",
+ "padding": "5px",
+ "fontSize": "20px" if is_clicked else "14px",
+ "fontWeight": "bold" if is_clicked else "normal",
+ "color": "#007bff" if is_clicked else "black",
+ }
+ time_labels.append(html.Div(f"t={t}", style=time_style))
+
+ timepoint_columns.append(
+ html.Div(
+ time_labels,
+ style={
+ "display": "flex",
+ "flexDirection": "row",
+ "minWidth": "fit-content",
+ "borderBottom": "2px solid #ddd",
+ "marginBottom": "10px",
+ "paddingBottom": "5px",
+ },
+ )
+ )
+
+ # Then create image rows for each channel
+ for channel in channels_to_display:
+ channel_images = []
+ for t in timepoints:
+ # Use correct 4-tuple cache key format
+ cache_key = (dataset_name, fov_name, int(track_id), int(t))
+
+ if (
+ cache_key in self.image_cache
+ and channel in self.image_cache[cache_key]
+ ):
+ is_clicked = t == clicked_time
+ image_style = {
+ "width": "150px",
+ "height": "150px",
+ "border": "1px solid #ddd",
+ "borderRadius": "4px",
+ }
+ channel_images.append(
+ html.Div(
+ html.Img(
+ src=self.image_cache[cache_key][channel],
+ style=image_style,
+ ),
+ style={
+ "width": "150px",
+ "padding": "5px",
+ },
+ )
+ )
+
+ if channel_images:
+ # Add channel label
+ timepoint_columns.append(
+ html.Div(
+ [
+ html.Div(
+ channel,
+ style={
+ "width": "100px",
+ "fontWeight": "bold",
+ "fontSize": "14px",
+ "padding": "5px",
+ "backgroundColor": "#f8f9fa",
+ "borderRadius": "4px",
+ "marginBottom": "5px",
+ "textAlign": "center",
+ },
+ ),
+ html.Div(
+ channel_images,
+ style={
+ "display": "flex",
+ "flexDirection": "row",
+ "minWidth": "fit-content",
+ "marginBottom": "15px",
+ },
+ ),
+ ]
+ )
+ )
+
+ # Create the main container
+ timeline_content = html.Div(
+ [
+ html.H4(
+ f"Track {track_id} (FOV: {fov_name})",
+ style={
+ "marginBottom": "20px",
+ "fontSize": "20px",
+ "fontWeight": "bold",
+ "color": "#2c3e50",
+ },
+ ),
+ html.Div(
+ timepoint_columns,
+ style={
+ "overflowX": "auto",
+ "overflowY": "hidden",
+ "whiteSpace": "nowrap",
+ "backgroundColor": "white",
+ "padding": "20px",
+ "borderRadius": "8px",
+ "boxShadow": "0 2px 4px rgba(0,0,0,0.1)",
+ "marginBottom": "20px",
+ },
+ ),
+ ]
+ )
+
+ return timeline_content, dash.no_update
+
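The timeline callback recovers a 4-tuple cache key by parsing the hover text built with `<br>` separators; a round-trip sketch with made-up values:

```python
# Sketch: parsing hover text back into the (dataset, fov, track, t) cache key.
hover_text = "Dataset: sensor<br>Track: 7<br>Time: 3<br>FOV: /A/3/9"  # made up
lines = hover_text.split("<br>")
dataset = lines[0].split(": ")[1]
track_id = int(lines[1].split(": ")[1])
t = int(lines[2].split(": ")[1])
fov = lines[3].split(": ")[1]
cache_key = (dataset, fov, track_id, t)  # matches the 4-tuple keys above
```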
+ # Add callback to show/hide clusters tab and handle modal
+ @self.app.callback(
+ [
+ dd.Output("clusters-tab", "style"),
+ dd.Output("cluster-container", "children"),
+ dd.Output("view-tabs", "value"),
+ dd.Output("scatter-plot", "figure", allow_duplicate=True),
+ dd.Output("cluster-name-modal", "style"),
+ dd.Output("cluster-name-input", "value"),
+ dd.Output("scatter-plot", "selectedData", allow_duplicate=True),
+ ],
+ [
+ dd.Input("assign-cluster", "n_clicks"),
+ dd.Input("clear-clusters", "n_clicks"),
+ dd.Input("save-cluster-name", "n_clicks"),
+ dd.Input("cancel-cluster-name", "n_clicks"),
+ dd.Input({"type": "edit-cluster-name", "index": dash.ALL}, "n_clicks"),
+ ],
+ [
+ dd.State("scatter-plot", "selectedData"),
+ dd.State("scatter-plot", "figure"),
+ dd.State("color-mode", "value"),
+ dd.State("show-arrows", "value"),
+ dd.State("x-axis", "value"),
+ dd.State("y-axis", "value"),
+ dd.State("cluster-name-input", "value"),
+ ],
+ prevent_initial_call=True,
+ )
+ def update_clusters_tab(
+ assign_clicks,
+ clear_clicks,
+ save_name_clicks,
+ cancel_name_clicks,
+ edit_name_clicks,
+ selected_data,
+ current_figure,
+ color_mode,
+ show_arrows,
+ x_axis,
+ y_axis,
+ cluster_name,
+ ):
+ ctx = dash.callback_context
+ if not ctx.triggered:
+ return (
+ dash.no_update,
+ dash.no_update,
+ dash.no_update,
+ dash.no_update,
+ dash.no_update,
+ dash.no_update,
+ dash.no_update,
+ )
+
+ button_id = ctx.triggered[0]["prop_id"].split(".")[0]
+
+ # Handle edit cluster name button clicks
+ if button_id.startswith('{"type":"edit-cluster-name"'):
+ try:
+ id_dict = json.loads(button_id)
+ cluster_idx = id_dict["index"]
+
+ # Get current cluster name using the manager
+ cluster = self.cluster_manager.get_cluster_by_index(cluster_idx)
+ current_name = (
+ cluster.name if cluster else f"Cluster {cluster_idx + 1}"
+ )
+
+ # Show modal
+ modal_style = {
+ "display": "flex",
+ "position": "fixed",
+ "top": "0",
+ "left": "0",
+ "width": "100%",
+ "height": "100%",
+ "backgroundColor": "rgba(0, 0, 0, 0.5)",
+ "zIndex": "1000",
+ "justifyContent": "center",
+ "alignItems": "center",
+ }
+
+ return (
+ {"display": "block"},
+ self._get_cluster_images(),
+ "clusters-tab",
+ dash.no_update,
+ modal_style,
+ current_name,
+ dash.no_update,
+ )
+ except Exception:
+ return (
+ dash.no_update,
+ dash.no_update,
+ dash.no_update,
+ dash.no_update,
+ dash.no_update,
+ dash.no_update,
+ dash.no_update,
+ )
+
+ # Handle clear clusters button
+ elif button_id == "clear-clusters" and clear_clicks:
+ self.cluster_manager.clear_all_clusters()
+ logger.info("Cleared all clusters")
+
+ # Update figure to remove cluster coloring
+ if color_mode == "track":
+ fig = self._create_track_colored_figure(show_arrows, x_axis, y_axis)
+ else:
+ fig = self._create_time_colored_figure(show_arrows, x_axis, y_axis)
+
+ return (
+ {"display": "none"}, # Hide clusters tab
+ html.Div("No clusters created yet"),
+ "timeline-tab", # Switch back to timeline tab
+ fig,
+ {"display": "none"}, # Hide modal
+ "",
+ None, # Clear selection
+ )
+
+ # Handle save cluster name button
+ elif button_id == "save-cluster-name" and save_name_clicks and cluster_name:
+ # Get the most recent cluster and update its name
+ if self.cluster_manager.clusters:
+ latest_cluster = self.cluster_manager.clusters[-1]
+ latest_cluster.name = cluster_name
+ logger.info(
+ f"Named cluster {latest_cluster.id} as '{cluster_name}'"
+ )
+
+ # Close modal and update clusters display
+ modal_style = {"display": "none"}
+ return (
+ {"display": "block"},
+ self._get_cluster_images(),
+ "clusters-tab",
+ dash.no_update,
+ modal_style,
+ "",
+ dash.no_update,
+ )
+
+ # Handle cancel cluster name button
+ elif button_id == "cancel-cluster-name" and cancel_name_clicks:
+ # Just close the modal without saving
+ modal_style = {"display": "none"}
+ return (
+ {"display": "block"},
+ self._get_cluster_images(),
+ "clusters-tab",
+ dash.no_update,
+ modal_style,
+ "",
+ dash.no_update,
+ )
+
+ # Handle assign cluster button
+ elif button_id == "assign-cluster" and assign_clicks and selected_data:
+ if not selected_data or not selected_data.get("points"):
+ logger.warning("No points selected for clustering")
+ return (
+ dash.no_update,
+ dash.no_update,
+ dash.no_update,
+ dash.no_update,
+ dash.no_update,
+ dash.no_update,
+ dash.no_update,
+ )
+
+ selected_points_list = []
+
+ # Determine if we're in cluster mode by checking if clusters exist
+ is_cluster_mode = len(self.cluster_manager.clusters) > 0
+
+ # Get information about each selected point
+ for point in selected_data["points"]:
+ curve_number = point.get("curveNumber", 0)
+ point_index = point.get("pointIndex")
+
+ if point_index is None:
+ continue
+
+ if is_cluster_mode:
+ # Handle cluster-colored figure
+ # Structure: [unclustered_trace, cluster1_trace, cluster2_trace, ...]
+
+ if curve_number == 0:
+ # This is an unclustered point
+ # Get the DataFrame row from the unclustered points
+ df_cache_keys = [
+ (
+ row["dataset"],
+ row["fov_name"],
+ row["track_id"],
+ row["t"],
+ )
+ for _, row in self.filtered_features_df.iterrows()
+ ]
+
+ clustered_cache_keys = set()
+ for cluster in self.cluster_manager.clusters:
+ clustered_cache_keys.update(cluster.cache_keys)
+
+ unclustered_mask = [
+ cache_key not in clustered_cache_keys
+ for cache_key in df_cache_keys
+ ]
+
+ if any(unclustered_mask):
+ unclustered_df = self.filtered_features_df[
+ unclustered_mask
+ ]
+ if point_index < len(unclustered_df):
+ selected_point = unclustered_df.iloc[point_index]
+ selected_points_list.append(
+ {
+ "track_id": selected_point["track_id"],
+ "t": selected_point["t"],
+ "fov_name": selected_point["fov_name"],
+ "dataset": selected_point["dataset"],
+ }
+ )
+ else:
+ # This is a point from an existing cluster
+ cluster_index = curve_number - 1
+ if cluster_index < len(self.cluster_manager.clusters):
+ cluster = self.cluster_manager.clusters[cluster_index]
+
+ # Get the DataFrame rows for this cluster
+ df_cache_keys = [
+ (
+ row["dataset"],
+ row["fov_name"],
+ row["track_id"],
+ row["t"],
+ )
+ for _, row in self.filtered_features_df.iterrows()
+ ]
+
+ cluster_mask = [
+ cache_key in cluster.cache_keys
+ for cache_key in df_cache_keys
+ ]
+
+ if any(cluster_mask):
+ cluster_df = self.filtered_features_df[cluster_mask]
+ if point_index < len(cluster_df):
+ selected_point = cluster_df.iloc[point_index]
+ selected_points_list.append(
+ {
+ "track_id": selected_point["track_id"],
+ "t": selected_point["t"],
+ "fov_name": selected_point["fov_name"],
+ "dataset": selected_point["dataset"],
+ }
+ )
+ else:
+ # Handle track-colored figure
+ # Skip background points (curve 0 if background exists)
+ background_offset = (
+ 1 if not self.background_features_df.empty else 0
+ )
+
+ if curve_number < background_offset:
+ # This is a background point, skip it
+ continue
+
+ # Find which track this point belongs to
+ track_curve_index = curve_number - background_offset
+
+ # Map to the actual track
+ if track_curve_index < len(self.valid_combinations):
+ dataset_name, fov_name, track_id = self.valid_combinations[
+ track_curve_index
+ ]
+
+ # Get the track data
+ track_data = self.filtered_features_df[
+ (self.filtered_features_df["dataset"] == dataset_name)
+ & (self.filtered_features_df["fov_name"] == fov_name)
+ & (
+ self.filtered_features_df["track_id"]
+ == int(track_id)
+ )
+ ].sort_values("t")
+
+ # Get the specific point within this track
+ if point_index < len(track_data):
+ selected_point = track_data.iloc[point_index]
+ selected_points_list.append(
+ {
+ "track_id": selected_point["track_id"],
+ "t": selected_point["t"],
+ "fov_name": selected_point["fov_name"],
+ "dataset": selected_point["dataset"],
+ }
+ )
+
+ if not selected_points_list:
+ logger.warning("No valid points found in selection")
+ return (
+ dash.no_update,
+ dash.no_update,
+ dash.no_update,
+ dash.no_update,
+ dash.no_update,
+ dash.no_update,
+ dash.no_update,
+ )
+
+ # Create a new cluster using the cluster manager
+ cluster_id = self.cluster_manager.create_cluster_from_points(
+ selected_points_list
+ )
+ logger.info(
+ f"Created cluster {cluster_id} with {len(selected_points_list)} points"
+ )
+
+ # Show modal for naming the cluster
+ modal_style = {
+ "display": "flex",
+ "position": "fixed",
+ "top": "0",
+ "left": "0",
+ "width": "100%",
+ "height": "100%",
+ "backgroundColor": "rgba(0, 0, 0, 0.5)",
+ "zIndex": "1000",
+ "justifyContent": "center",
+ "alignItems": "center",
+ }
+
+ # Update figure to show cluster coloring
+ fig = self._create_cluster_colored_figure(show_arrows, x_axis, y_axis)
+
+ return (
+ {"display": "block"}, # Show clusters tab
+ self._get_cluster_images(),
+ "clusters-tab", # Switch to clusters tab
+ fig, # Updated figure with cluster colors
+ modal_style, # Show naming modal
+ f"Cluster {len(self.cluster_manager.clusters)}", # Default name
+ None, # Clear selection
+ )
+
+ # Default return (no updates)
+ return (
+ dash.no_update,
+ dash.no_update,
+ dash.no_update,
+ dash.no_update,
+ dash.no_update,
+ dash.no_update,
+ dash.no_update,
+ )
+
+ @self.app.callback(
+ [
+ dd.Output("scatter-plot", "figure", allow_duplicate=True),
+ dd.Output("scatter-plot", "selectedData", allow_duplicate=True),
+ dd.Output("track-timeline", "children", allow_duplicate=True),
+ ],
+ [dd.Input("clear-selection", "n_clicks")],
+ [
+ dd.State("color-mode", "value"),
+ dd.State("show-arrows", "value"),
+ dd.State("x-axis", "value"),
+ dd.State("y-axis", "value"),
+ ],
+ prevent_initial_call=True,
+ )
+ def clear_selection(n_clicks, color_mode, show_arrows, x_axis, y_axis):
+ """Callback to clear the selection and restore original opacity"""
+ if n_clicks:
+ # Create a new figure with no selections
+ if color_mode == "track":
+ fig = self._create_track_colored_figure(
+ len(show_arrows or []) > 0,
+ x_axis,
+ y_axis,
+ )
+ else:
+ fig = self._create_time_colored_figure(
+ len(show_arrows or []) > 0,
+ x_axis,
+ y_axis,
+ )
+
+ # Update layout to maintain lasso mode but clear selections
+ fig.update_layout(
+ dragmode="lasso",
+ clickmode="event+select",
+ uirevision=None, # Reset UI state to clear selections
+ selectdirection="any",
+ )
+
+ # Clear the track timeline as well
+ empty_timeline = html.Div(
+ "Click on a point to see the track timeline",
+ style={
+ "textAlign": "center",
+ "color": "#666",
+ "fontSize": "16px",
+ "padding": "40px",
+ "fontStyle": "italic",
+ },
+ )
+
+ return (
+ fig,
+ None,
+ empty_timeline,
+ ) # Clear figure selection, selectedData, and timeline
+ return dash.no_update, dash.no_update, dash.no_update
+
+ @self.app.callback(
+ [
+ dd.Output("dummy-output", "children"),
+ dd.Output("notification-area", "children"),
+ ],
+ [dd.Input("save-clusters-csv", "n_clicks")],
+ prevent_initial_call=True,
+ )
+ def save_clusters_to_csv(n_clicks):
+ """Save all clusters to CSV files"""
+ notification = ""
+
+ if n_clicks and self.cluster_manager.clusters:
+ try:
+ # Create output directory if it doesn't exist
+ self.output_dir.mkdir(parents=True, exist_ok=True)
+
+ saved_files = []
+
+ # Save each cluster to a separate CSV file
+ for i, cluster in enumerate(self.cluster_manager.clusters):
+ cluster_name = cluster.name or f"Cluster_{i + 1}"
+ # Clean the cluster name for filename
+ safe_name = "".join(
+ c
+ for c in cluster_name
+ if c.isalnum() or c in (" ", "-", "_")
+ ).rstrip()
+ safe_name = safe_name.replace(" ", "_")
+
+ # Create DataFrame from cluster points
+ cluster_data = []
+ for point in cluster.points:
+ # Get the full row data for this point
+ point_row = self.filtered_features_df[
+ (self.filtered_features_df["dataset"] == point.dataset)
+ & (
+ self.filtered_features_df["fov_name"]
+ == point.fov_name
+ )
+ & (
+ self.filtered_features_df["track_id"]
+ == point.track_id
+ )
+ & (self.filtered_features_df["t"] == point.t)
+ ]
+
+ if not point_row.empty:
+ cluster_data.append(point_row.iloc[0])
+
+ if cluster_data:
+ cluster_df = pd.DataFrame(cluster_data)
+
+ # Add cluster information (dataset is already in the dataframe)
+ cluster_df["cluster_name"] = (
+ cluster.name or f"Cluster {i + 1}"
+ )
+ cluster_df["cluster_size"] = len(cluster.points)
+
+ # Reorder columns to put dataset and cluster info at the front
+ cols = cluster_df.columns.tolist()
+ priority_cols = [
+ "dataset",
+ "cluster_name",
+ "cluster_size",
+ ]
+ other_cols = [
+ col for col in cols if col not in priority_cols
+ ]
+ cluster_df = cluster_df[priority_cols + other_cols]
+
+ # Save to CSV
+ csv_path = self.output_dir / f"{safe_name}.csv"
+ cluster_df.to_csv(csv_path, index=False)
+ saved_files.append(csv_path.name)
+ logger.info(f"Saved cluster '{cluster.name}' to {csv_path}")
+
+ # Also save a summary CSV with all clusters
+ if self.cluster_manager.clusters:
+ all_cluster_data = []
+ for i, cluster in enumerate(self.cluster_manager.clusters):
+ for point in cluster.points:
+ point_row = self.filtered_features_df[
+ (
+ self.filtered_features_df["dataset"]
+ == point.dataset
+ )
+ & (
+ self.filtered_features_df["fov_name"]
+ == point.fov_name
+ )
+ & (
+ self.filtered_features_df["track_id"]
+ == point.track_id
+ )
+ & (self.filtered_features_df["t"] == point.t)
+ ]
+
+ if not point_row.empty:
+ row_data = point_row.iloc[0].to_dict()
+ row_data["cluster_name"] = (
+ cluster.name or f"Cluster {i + 1}"
+ )
+ row_data["cluster_index"] = i
+ all_cluster_data.append(row_data)
+
+ if all_cluster_data:
+ summary_df = pd.DataFrame(all_cluster_data)
+
+ # Reorder columns to put dataset and cluster info at the front
+ cols = summary_df.columns.tolist()
+ priority_cols = [
+ "dataset",
+ "cluster_name",
+ "cluster_index",
+ ]
+ other_cols = [
+ col for col in cols if col not in priority_cols
+ ]
+ summary_df = summary_df[priority_cols + other_cols]
+
+ summary_path = self.output_dir / "all_clusters_summary.csv"
+ summary_df.to_csv(summary_path, index=False)
+ saved_files.append(summary_path.name)
+ logger.info(
+ f"Saved summary of all clusters to {summary_path}"
+ )
+
+ # Log summary statistics
+ dataset_counts = summary_df["dataset"].value_counts()
+ logger.info(
+ f"Exported {len(self.cluster_manager.clusters)} clusters with {len(all_cluster_data)} total points"
+ )
+ logger.info(
+ f"Points per dataset: {dataset_counts.to_dict()}"
+ )
+
+ # Create success notification
+ notification = html.Div(
+ [
+ html.Div(
+ [
+ html.Strong("✅ Clusters Saved Successfully!"),
+ html.Br(),
+ html.Small(
+ f"Saved {len(saved_files)} files to {self.output_dir}"
+ ),
+ html.Br(),
+ html.Small(
+ f"Files: {', '.join(saved_files[:3])}"
+ + ("..." if len(saved_files) > 3 else "")
+ ),
+ ],
+ style={
+ "backgroundColor": "#d4edda",
+ "color": "#155724",
+ "border": "1px solid #c3e6cb",
+ "borderRadius": "4px",
+ "padding": "10px",
+ "marginBottom": "10px",
+ "boxShadow": "0 2px 4px rgba(0,0,0,0.1)",
+ },
+ )
+ ],
+ id="success-notification",
+ )
+
+ except Exception as e:
+ logger.error(f"Error saving clusters to CSV: {e}")
+ # Create error notification
+ notification = html.Div(
+ [
+ html.Div(
+ [
+ html.Strong("❌ Error Saving Clusters"),
+ html.Br(),
+ html.Small(f"Error: {str(e)}"),
+ ],
+ style={
+ "backgroundColor": "#f8d7da",
+ "color": "#721c24",
+ "border": "1px solid #f5c6cb",
+ "borderRadius": "4px",
+ "padding": "10px",
+ "marginBottom": "10px",
+ "boxShadow": "0 2px 4px rgba(0,0,0,0.1)",
+ },
+ )
+ ],
+ id="error-notification",
+ )
+
+ elif n_clicks and not self.cluster_manager.clusters:
+ # No clusters to save
+ notification = html.Div(
+ [
+ html.Div(
+ [
+ html.Strong("ℹ️ No Clusters to Save"),
+ html.Br(),
+ html.Small("Create some clusters first before saving."),
+ ],
+ style={
+ "backgroundColor": "#d1ecf1",
+ "color": "#0c5460",
+ "border": "1px solid #bee5eb",
+ "borderRadius": "4px",
+ "padding": "10px",
+ "marginBottom": "10px",
+ "boxShadow": "0 2px 4px rgba(0,0,0,0.1)",
+ },
+ )
+ ],
+ id="info-notification",
+ )
+
+ return "", notification
+
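Cluster names are sanitized into filenames with an allowlist of characters; the same logic isolated (the example input is made up):

```python
# Sketch of the filename sanitization used when saving per-cluster CSVs.
def safe_filename(name: str) -> str:
    cleaned = "".join(c for c in name if c.isalnum() or c in (" ", "-", "_")).rstrip()
    return cleaned.replace(" ", "_")

print(safe_filename("Stress granules / early!"))  # -> Stress_granules__early
```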
+ def _create_track_colored_figure(
+ self,
+ show_arrows=False,
+ x_axis=None,
+ y_axis=None,
+ ):
+ """Create scatter plot with track-based coloring"""
+ if self.filtered_features_df is None or self.filtered_features_df.empty:
+ return go.Figure()
+
+ x_axis = x_axis or self.default_x
+ y_axis = y_axis or self.default_y
+
+ fig = go.Figure()
+
+ # Set initial layout with lasso mode
+ fig.update_layout(
+ dragmode="lasso",
+ clickmode="event+select",
+ selectdirection="any",
+ plot_bgcolor="white",
+ title="Embedding Visualization",
+ xaxis_title=x_axis,
+ yaxis_title=y_axis,
+ uirevision=True,
+ hovermode="closest",
+ showlegend=True,
+ legend=dict(
+ yanchor="top",
+ y=1,
+ xanchor="left",
+ x=1.02,
+ title="Tracks",
+ bordercolor="Black",
+ borderwidth=1,
+ ),
+ margin=dict(l=50, r=150, t=50, b=50),
+ autosize=True,
+ )
+ fig.update_xaxes(showgrid=False)
+ fig.update_yaxes(showgrid=False, scaleanchor="x", scaleratio=1)
+
+ # Use pre-computed filtered and background data
+ filtered_features_df = self.filtered_features_df
+
+ # Add background points using pre-computed background data
+ # Make them non-interactive and render behind main points
+ if not self.background_features_df.empty:
+ fig.add_trace(
+ go.Scattergl(
+ x=self.background_features_df[x_axis],
+ y=self.background_features_df[y_axis],
+ mode="markers",
+ marker=dict(
+ size=12, color="lightgray", opacity=0.3
+ ), # Smaller and more transparent
+ name=f"Other tracks ({len(self.background_features_df)} points)",
+ text=[
+ f"Dataset: {dataset}
Track: {track_id}
Time: {t}
FOV: {fov}"
+ for dataset, track_id, t, fov in zip(
+ self.background_features_df["dataset"],
+ self.background_features_df["track_id"],
+ self.background_features_df["t"],
+ self.background_features_df["fov_name"],
+ )
+ ],
+ hoverinfo="text",
+ showlegend=True,
+ hoverlabel=dict(namelength=-1),
+ selectedpoints=False,
+ unselected=dict(marker=dict(opacity=0.3)),
+ selected=dict(marker=dict(opacity=0.3)),
+ )
+ )
+
+ # Use pre-computed unique track keys and colors
+ # Add points for each selected track, colored per track
+ for dataset_name, fov_name, track_id in self.valid_combinations:
+ track_data = filtered_features_df[
+ (filtered_features_df["dataset"] == dataset_name)
+ & (filtered_features_df["fov_name"] == fov_name)
+ & (filtered_features_df["track_id"] == int(track_id))
+ ]
+
+ if track_data.empty:
+ logger.warning(
+ f"No data found for track {track_id} in dataset {dataset_name} and fov {fov_name}"
+ )
+ continue
+
+ # Sort chronologically; track_data here is always a DataFrame
+ track_data = track_data.sort_values("t")
+
+ # Add track points
+ fig.add_trace(
+ go.Scattergl(
+ x=track_data[x_axis],
+ y=track_data[y_axis],
+ mode="markers",
+ marker=dict(
+ size=self._DEFAULT_MARKER_SIZE,
+ color=self.track_colors[(dataset_name, fov_name, track_id)],
+ opacity=1.0,
+ line=dict(width=0.5, color="black"),
+ ),
+ name=f"{dataset_name}:{track_id}",
+ text=[
+ f"Dataset: {dataset_name}
Track: {track_id}
Time: {t}
FOV: {fov}"
+ for t, fov in zip(track_data["t"], track_data["fov_name"])
+ ],
+ hoverinfo="text",
+ unselected=dict(
+ marker=dict(opacity=0.6, size=self._DEFAULT_MARKER_SIZE)
+ ),
+ selected=dict(
+ marker=dict(size=self._DEFAULT_MARKER_SIZE * 2.0, opacity=1.0)
+ ),
+ hoverlabel=dict(namelength=-1),
+ )
+ )
+
+ # Add trajectory lines and arrows if requested
+ if show_arrows and len(track_data) > 1:
+ x_coords = track_data[x_axis].values
+ y_coords = track_data[y_axis].values
+ track_key = (dataset_name, fov_name, track_id)
+
+ # Add dashed lines for the trajectory
+ fig.add_trace(
+ go.Scattergl(
+ x=x_coords,
+ y=y_coords,
+ mode="lines",
+ line=dict(
+ color=self.track_colors[track_key],
+ width=3,
+ dash="dot",
+ ),
+ showlegend=False,
+ hoverinfo="skip",
+ selectedpoints=False,
+ )
+ )
+
+ # Add arrows at regular intervals
+ arrow_interval = max(1, len(track_data) // 3)
+ for i in range(0, len(track_data) - 1, arrow_interval):
+ dx = x_coords[i + 1] - x_coords[i]
+ dy = y_coords[i + 1] - y_coords[i]
+
+ # Only add arrow if there's significant movement
+ if dx * dx + dy * dy > 1e-6:
+ fig.add_annotation(
+ x=x_coords[i + 1],
+ y=y_coords[i + 1],
+ ax=x_coords[i],
+ ay=y_coords[i],
+ xref="x",
+ yref="y",
+ axref="x",
+ ayref="y",
+ showarrow=True,
+ arrowhead=2,
+ arrowsize=1,
+ arrowwidth=1,
+ arrowcolor=self.track_colors[track_key],
+ opacity=0.8,
+ )
+
+ return fig
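
For reference, the arrow technique above in isolation: an annotation whose axref/ayref are set to data coordinates draws an arrow from (ax, ay) to (x, y), so chaining consecutive points yields direction-of-motion arrows. A minimal self-contained sketch with toy coordinates:

    import plotly.graph_objects as go

    xs, ys = [0.0, 1.0, 1.5, 2.5], [0.0, 0.5, 1.5, 1.8]
    fig = go.Figure(go.Scattergl(x=xs, y=ys, mode="lines+markers"))
    for i in range(len(xs) - 1):
        fig.add_annotation(
            x=xs[i + 1], y=ys[i + 1],  # arrow head at the later point
            ax=xs[i], ay=ys[i],        # arrow tail at the earlier point
            xref="x", yref="y", axref="x", ayref="y",
            showarrow=True, arrowhead=2,
        )
    fig.show()
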
+
+ def _create_time_colored_figure(
+ self,
+ show_arrows=False,
+ x_axis=None,
+ y_axis=None,
+ ):
+ """Create scatter plot with time-based coloring"""
+ if self.filtered_features_df is None or self.filtered_features_df.empty:
+ return go.Figure()
+
+ x_axis = x_axis or self.default_x
+ y_axis = y_axis or self.default_y
+
+ fig = go.Figure()
+
+ # Set initial layout with lasso mode
+ fig.update_layout(
+ dragmode="lasso",
+ clickmode="event+select",
+ selectdirection="any",
+ plot_bgcolor="white",
+ title="Embedding Visualization",
+ xaxis_title=x_axis,
+ yaxis_title=y_axis,
+ uirevision=True,
+ hovermode="closest",
+ showlegend=True,
+ legend=dict(
+ yanchor="top",
+ y=1,
+ xanchor="left",
+ x=1.02,
+ bordercolor="Black",
+ borderwidth=1,
+ ),
+ margin=dict(l=50, r=150, t=50, b=50),
+ autosize=True,
+ )
+ fig.update_xaxes(showgrid=False)
+ fig.update_yaxes(showgrid=False, scaleanchor="x", scaleratio=1)
+
+ # Use pre-computed filtered and background data
+ filtered_features_df = self.filtered_features_df
+
+ # Add background points using pre-computed background data
+ # Make them non-interactive and render behind main points
+ if not self.background_features_df.empty:
+ fig.add_trace(
+ go.Scattergl(
+ x=self.background_features_df[x_axis],
+ y=self.background_features_df[y_axis],
+ mode="markers",
+ marker=dict(
+ size=12, color="lightgray", opacity=0.3
+ ), # Smaller and more transparent
+ name=f"Other points ({len(self.background_features_df)} points)",
+ text=[
+ f"Dataset: {dataset}
Track: {track_id}
Time: {t}
FOV: {fov}"
+ for dataset, track_id, t, fov in zip(
+ self.background_features_df["dataset"],
+ self.background_features_df["track_id"],
+ self.background_features_df["t"],
+ self.background_features_df["fov_name"],
+ )
+ ],
+ hoverinfo="text",
+ hoverlabel=dict(namelength=-1),
+ selectedpoints=False,
+ unselected=dict(marker=dict(opacity=0.3)),
+ selected=dict(marker=dict(opacity=0.3)),
+ )
+ )
+
+ # Add time-colored points
+ if not filtered_features_df.empty:
+ fig.add_trace(
+ go.Scattergl(
+ x=filtered_features_df[x_axis],
+ y=filtered_features_df[y_axis],
+ mode="markers",
+ marker=dict(
+ size=self._DEFAULT_MARKER_SIZE,
+ color=filtered_features_df["t"],
+ colorscale="Viridis",
+ colorbar=dict(title="Time"),
+ ),
+ text=[
+ f"Dataset: {dataset}
Track: {track_id}
Time: {t}
FOV: {fov}"
+ for dataset, track_id, t, fov in zip(
+ filtered_features_df["dataset"],
+ filtered_features_df["track_id"],
+ filtered_features_df["t"],
+ filtered_features_df["fov_name"],
+ )
+ ],
+ hoverinfo="text",
+ showlegend=False,
+ hoverlabel=dict(namelength=-1),
+ )
+ )
+
+ # Add dotted trajectory segments if requested (non-selectable)
+ if show_arrows and not filtered_features_df.empty:
+ # Keep track_id in its native dtype; casting it to str would never
+ # match the integer "track_id" column in the filter below
+ for dataset_name, fov_name, track_id in filtered_features_df.apply(
+ lambda row: (row["dataset"], row["fov_name"], row["track_id"]),
+ axis=1,
+ ).unique():
+ track_data = filtered_features_df[
+ (filtered_features_df["dataset"] == dataset_name)
+ & (filtered_features_df["fov_name"] == fov_name)
+ & (filtered_features_df["track_id"] == track_id)
+ ]
+
+ if len(track_data) <= 1:
+ continue
+
+ # Sort chronologically; track_data here is always a DataFrame
+ track_data = track_data.sort_values("t")
+ x_coords = track_data[x_axis].values
+ y_coords = track_data[y_axis].values
+ distances = np.sqrt(
+ np.diff(np.array(x_coords)) ** 2 + np.diff(np.array(y_coords)) ** 2
+ )
+
+ # Only draw segments for movements larger than half the median step
+ threshold = np.median(distances) * 0.5 if len(distances) > 0 else 0
+
+ arrow_x = []
+ arrow_y = []
+
+ for i in range(len(track_data) - 1):
+ if distances[i] > threshold:
+ arrow_x.extend([x_coords[i], x_coords[i + 1], None])
+ arrow_y.extend([y_coords[i], y_coords[i + 1], None])
+
+ if arrow_x:
+ fig.add_trace(
+ go.Scatter(
+ x=arrow_x,
+ y=arrow_y,
+ mode="lines",
+ line=dict(
+ color="rgba(128, 128, 128, 0.5)",
+ width=1,
+ dash="dot",
+ ),
+ showlegend=False,
+ hoverinfo="skip",
+ selectedpoints=False,
+ )
+ )
+
+ return fig
+
+ def _create_cluster_colored_figure(
+ self,
+ show_arrows=False,
+ x_axis=None,
+ y_axis=None,
+ ):
+ """Create scatter plot with cluster-based coloring"""
+ if self.filtered_features_df is None or self.filtered_features_df.empty:
+ return go.Figure()
+
+ x_axis = x_axis or self.default_x
+ y_axis = y_axis or self.default_y
+
+ fig = go.Figure()
+
+ # Set initial layout
+ fig.update_layout(
+ dragmode="lasso",
+ showlegend=True,
+ height=700,
+ xaxis=dict(scaleanchor="y", scaleratio=1), # Square plot
+ yaxis=dict(scaleanchor="x", scaleratio=1),
+ )
+ fig.update_xaxes(showgrid=False)
+ fig.update_yaxes(showgrid=False)
+
+ # Use cluster manager's color scheme
+ cluster_colors = self.cluster_manager.get_cluster_colors_by_index()
+
+ # Get all clustered cache keys
+ clustered_cache_keys = set()
+ for cluster in self.cluster_manager.clusters:
+ clustered_cache_keys.update(cluster.cache_keys)
+
+ # Create a mask for unclustered points
+ # Map DataFrame rows to cache keys and check if they're in any cluster
+ df_cache_keys = [
+ (row["dataset"], row["fov_name"], row["track_id"], row["t"])
+ for _, row in self.filtered_features_df.iterrows()
+ ]
+
+ unclustered_mask = [
+ cache_key not in clustered_cache_keys for cache_key in df_cache_keys
+ ]
+
+ # Add unclustered points (background) - make them non-interactive
+ if any(unclustered_mask):
+ unclustered_df = self.filtered_features_df[unclustered_mask]
+
+ fig.add_trace(
+ go.Scatter(
+ x=unclustered_df[x_axis],
+ y=unclustered_df[y_axis],
+ mode="markers",
+ marker=dict(
+ size=12,
+ color="lightgray",
+ opacity=0.6, # More transparent
+ ),
+ name="Unclustered",
+ hovertemplate="Unclustered
"
+ + f"{x_axis}: %{{x}}
"
+ + f"{y_axis}: %{{y}}
"
+ + "",
+ selectedpoints=False,
+ unselected=dict(marker=dict(opacity=0.3)),
+ selected=dict(marker=dict(opacity=0.3)),
+ )
+ )
+
+ # Add each cluster as a separate trace
+ for i, cluster in enumerate(self.cluster_manager.clusters):
+ # Create a mask for points in this cluster
+ cluster_mask = [
+ cache_key in cluster.cache_keys for cache_key in df_cache_keys
+ ]
+
+ if any(cluster_mask):
+ cluster_df = self.filtered_features_df[cluster_mask]
+ # Use the color from cluster manager (consistent with tabs)
+ color = cluster_colors[i] if i < len(cluster_colors) else "gray"
+
+ fig.add_trace(
+ go.Scatter(
+ x=cluster_df[x_axis],
+ y=cluster_df[y_axis],
+ mode="markers",
+ marker=dict(
+ size=self._DEFAULT_MARKER_SIZE,
+ color=color,
+ line=dict(width=0.5, color="black"),
+ opacity=0.8,
+ ),
+ name=cluster.name or f"Cluster {i + 1}",
+ hovertemplate=f"{cluster.name or f'Cluster {i + 1}'}
"
+ + f"{x_axis}: %{{x}}
"
+ + f"{y_axis}: %{{y}}
"
+ + "",
+ )
+ )
+
+ return fig
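
A side note on the cache-key construction in this method: DataFrame.iterrows() builds a Series per row and is slow on large frames. A sketch of an equivalent vectorized form, using a toy frame with the same column names:

    import pandas as pd

    df = pd.DataFrame(
        {
            "dataset": ["g3bp1_sensor", "g3bp1_sensor"],
            "fov_name": ["/A/3/9", "/B/4/9"],
            "track_id": [7, 8],
            "t": [0, 1],
        }
    )
    # Zipping the columns yields the same (dataset, fov, track, t) tuples
    # as the iterrows() loop, without per-row Series construction.
    cache_keys = list(zip(df["dataset"], df["fov_name"], df["track_id"], df["t"]))
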
+
+ def _cleanup_cache(self):
+ """Clear the image cache when the program exits"""
+ logging.info("Cleaning up image cache...")
+ self.image_cache.clear()
+
+ def _get_cluster_images(self):
+ """Display images for all clusters in a grid layout"""
+ if not self.cluster_manager.clusters:
+ return html.Div("No clusters created yet")
+
+ # Debug information
+ logger.info(f"Image cache size: {len(self.image_cache)}")
+ logger.info(f"Number of clusters: {len(self.cluster_manager.clusters)}")
+
+ # Use cluster manager's color scheme (consistent with scatter plot)
+ cluster_colors = self.cluster_manager.get_cluster_colors_by_index()
+
+ # Collect all unique channels across all datasets in the cluster points
+ all_channels_in_cluster = set()
+ for cluster in self.cluster_manager.clusters:
+ for point in cluster.points:
+ # Get channels for this specific dataset
+ if point.dataset in self.datasets:
+ dataset_channels = self.datasets[point.dataset].channels_to_display
+ all_channels_in_cluster.update(dataset_channels)
+
+ logger.info(f"All channels found in clusters: {all_channels_in_cluster}")
+
+ # Create individual cluster panels
+ cluster_panels = []
+ for cluster_idx, cluster in enumerate(self.cluster_manager.clusters):
+ logger.info(
+ f"Processing cluster {cluster_idx} with {len(cluster.points)} points"
+ )
+
+ # Group points by dataset to handle different channels
+ points_by_dataset = {}
+ for point in cluster.points:
+ if point.dataset not in points_by_dataset:
+ points_by_dataset[point.dataset] = []
+ points_by_dataset[point.dataset].append(point)
+
+ # Create images organized by dataset and then by channel
+ all_channel_images = []
+ images_found = 0
+
+ for dataset_name, dataset_points in points_by_dataset.items():
+ if dataset_name not in self.datasets:
+ continue
+
+ dataset_channels = self.datasets[dataset_name].channels_to_display
+
+ # Add dataset header if there are multiple datasets
+ if len(points_by_dataset) > 1:
+ all_channel_images.append(
+ html.H5(
+ f"Dataset: {dataset_name}",
+ style={
+ "margin": "10px 5px 5px 5px",
+ "fontSize": "14px",
+ "fontWeight": "bold",
+ "color": "#2c3e50",
+ "borderBottom": "1px solid #dee2e6",
+ "paddingBottom": "5px",
+ },
+ )
+ )
+
+ for channel in dataset_channels:
+ images = []
+ for point in dataset_points:
+ cache_key = (
+ point.dataset,
+ point.fov_name,
+ point.track_id,
+ point.t,
+ )
+
+ # Debug: Check if cache key exists
+ if cache_key in self.image_cache:
+ if channel in self.image_cache[cache_key]:
+ images_found += 1
+ images.append(
+ html.Div(
+ [
+ html.Img(
+ src=self.image_cache[cache_key][
+ channel
+ ],
+ style={
+ "width": "100px",
+ "height": "100px",
+ "margin": "2px",
+ "border": f"2px solid {cluster_colors[cluster_idx]}",
+ "borderRadius": "4px",
+ },
+ ),
+ html.Div(
+ f"{dataset_name}:T{point.track_id}:t{point.t}",
+ style={
+ "textAlign": "center",
+ "fontSize": "9px",
+ "maxWidth": "100px",
+ "overflow": "hidden",
+ "textOverflow": "ellipsis",
+ },
+ ),
+ ],
+ style={
+ "display": "inline-block",
+ "margin": "2px",
+ "verticalAlign": "top",
+ },
+ )
+ )
+ else:
+ logger.debug(
+ f"Channel {channel} not found for cache key {cache_key}"
+ )
+ else:
+ logger.debug(
+ f"Cache key {cache_key} not found in image cache"
+ )
+
+ if images:
+ all_channel_images.extend(
+ [
+ html.H6(
+ f"{channel}",
+ style={
+ "margin": "5px",
+ "fontSize": "12px",
+ "fontWeight": "bold",
+ "position": "sticky",
+ "left": "0",
+ "backgroundColor": "#f8f9fa",
+ "zIndex": "1",
+ "paddingLeft": "5px",
+ },
+ ),
+ html.Div(
+ images,
+ style={
+ "whiteSpace": "nowrap",
+ "marginBottom": "10px",
+ },
+ ),
+ ]
+ )
+
+ logger.info(f"Cluster {cluster_idx}: Found {images_found} images")
+
+ if all_channel_images:
+ # Create a panel for this cluster with synchronized scrolling
+ cluster_name = (
+ cluster.name if cluster.name else f"Cluster {cluster_idx + 1}"
+ )
+ cluster_panels.append(
+ html.Div(
+ [
+ html.Div(
+ [
+ html.Span(
+ cluster_name,
+ style={
+ "color": cluster_colors[cluster_idx],
+ "fontWeight": "bold",
+ "fontSize": "16px",
+ },
+ ),
+ html.Span(
+ f" ({len(cluster.points)} points)",
+ style={
+ "color": "#2c3e50",
+ "fontSize": "14px",
+ },
+ ),
+ html.Button(
+ "✏️",
+ id={
+ "type": "edit-cluster-name",
+ "index": cluster.id,
+ },
+ style={
+ "marginLeft": "10px",
+ "backgroundColor": "transparent",
+ "border": "none",
+ "cursor": "pointer",
+ "fontSize": "12px",
+ },
+ title="Edit cluster name",
+ ),
+ ],
+ style={
+ "marginBottom": "10px",
+ "borderBottom": f"2px solid {cluster_colors[cluster_idx]}",
+ "paddingBottom": "5px",
+ "position": "sticky",
+ "top": "0",
+ "backgroundColor": "white",
+ "zIndex": "1",
+ },
+ ),
+ html.Div(
+ all_channel_images,
+ style={
+ "overflowX": "auto",
+ "overflowY": "auto",
+ "height": "400px",
+ "backgroundColor": "#ffffff",
+ "padding": "10px",
+ "borderRadius": "8px",
+ "boxShadow": "0 2px 4px rgba(0,0,0,0.1)",
+ },
+ ),
+ ],
+ style={
+ "width": "24%",
+ "display": "inline-block",
+ "verticalAlign": "top",
+ "padding": "5px",
+ "boxSizing": "border-box",
+ },
+ )
+ )
+ else:
+ # Show a message if no images were found for this cluster
+ cluster_name = (
+ cluster.name if cluster.name else f"Cluster {cluster_idx + 1}"
+ )
+ cluster_panels.append(
+ html.Div(
+ [
+ html.Div(
+ [
+ html.Span(
+ cluster_name,
+ style={
+ "color": cluster_colors[cluster_idx],
+ "fontWeight": "bold",
+ "fontSize": "16px",
+ },
+ ),
+ html.Span(
+ f" ({len(cluster.points)} points)",
+ style={
+ "color": "#2c3e50",
+ "fontSize": "14px",
+ },
+ ),
+ ],
+ style={
+ "marginBottom": "10px",
+ "borderBottom": f"2px solid {cluster_colors[cluster_idx]}",
+ "paddingBottom": "5px",
+ },
+ ),
+ html.Div(
+ [
+ html.P("No images available for this cluster."),
+ html.P("This might be because:"),
+ html.Ul(
+ [
+ html.Li("Images haven't been preloaded"),
+ html.Li(
+ "Cache keys don't match the cluster points"
+ ),
+ html.Li(
+ "Channels are not configured correctly"
+ ),
+ ]
+ ),
+ html.P(
+ f"Debug info: {len(cluster.points)} points in cluster"
+ ),
+ ],
+ style={
+ "padding": "20px",
+ "backgroundColor": "#f8f9fa",
+ "borderRadius": "8px",
+ "textAlign": "center",
+ "color": "#6c757d",
+ },
+ ),
+ ],
+ style={
+ "width": "24%",
+ "display": "inline-block",
+ "verticalAlign": "top",
+ "padding": "5px",
+ "boxSizing": "border-box",
+ },
+ )
+ )
+
+ # If no cluster panels were created, show a debug message
+ if not cluster_panels:
+ return html.Div(
+ [
+ html.H2("Clusters", style={"marginBottom": "20px"}),
+ html.Div(
+ [
+ html.P("No cluster images could be displayed."),
+ html.P("Debug information:"),
+ html.Ul(
+ [
+ html.Li(
+ f"Number of clusters: {len(self.cluster_manager.clusters)}"
+ ),
+ html.Li(
+ f"Image cache size: {len(self.image_cache)}"
+ ),
+ html.Li(
+ f"Channels to display: {all_channels_in_cluster}"
+ ),
+ ]
+ ),
+ ],
+ style={
+ "padding": "20px",
+ "backgroundColor": "#f8f9fa",
+ "borderRadius": "8px",
+ "margin": "20px",
+ },
+ ),
+ ]
+ )
+
+ # Create rows of 4 panels each
+ rows = []
+ for i in range(0, len(cluster_panels), 4):
+ row = html.Div(
+ cluster_panels[i : i + 4],
+ style={
+ "display": "flex",
+ "justifyContent": "flex-start",
+ "gap": "10px",
+ "marginBottom": "10px",
+ },
+ )
+ rows.append(row)
+
+ return html.Div(
+ [
+ html.H2(
+ [
+ "Clusters ",
+ html.Span(
+ f"({len(self.cluster_manager.clusters)} total)",
+ style={"color": "#666"},
+ ),
+ ],
+ style={
+ "marginBottom": "20px",
+ "fontSize": "28px",
+ "fontWeight": "bold",
+ "color": "#2c3e50",
+ },
+ ),
+ html.Div(
+ rows,
+ style={
+ "maxHeight": "calc(100vh - 200px)",
+ "overflowY": "auto",
+ "padding": "10px",
+ },
+ ),
+ ]
+ )
+
+ @staticmethod
+ def _normalize_image(img_array):
+ """Normalize a single image array to [0, 255] more efficiently"""
+ min_val = img_array.min()
+ max_val = img_array.max()
+ if min_val == max_val:
+ return np.zeros_like(img_array, dtype=np.uint8)
+ # Normalize in one step
+ return ((img_array - min_val) * 255 / (max_val - min_val)).astype(np.uint8)
+
+ @staticmethod
+ def _numpy_to_base64(img_array):
+ """Convert numpy array to base64 string with compression"""
+ import base64
+ from io import BytesIO
+
+ from PIL import Image
+
+ if img_array.dtype != np.uint8:
+ img_array = img_array.astype(np.uint8)
+ img = Image.fromarray(img_array)
+ buffered = BytesIO()
+ # Use JPEG format with quality=85 for better compression
+ img.save(buffered, format="JPEG", quality=85, optimize=True)
+ return "data:image/jpeg;base64," + base64.b64encode(buffered.getvalue()).decode(
+ "utf-8"
+ )
+
+ def preload_images(self):
+ """Preload all images into memory for all datasets"""
+ from viscy.data.triplet import TripletDataModule
+
+ # Try to load from cache first
+ if self.cache_path and self.load_cache():
+ logger.info("Preloading images into cache...")
+ return
+
+ logger.info("Cache not found, preloading and caching images...")
+ # Process each dataset
+ for dataset_name, dataset_config in self.datasets.items():
+ logger.info(f"Processing dataset: {dataset_name}")
+
+ if (
+ not hasattr(dataset_config, "fov_tracks")
+ or not dataset_config.fov_tracks
+ ):
+ logger.info(f"Skipping dataset {dataset_name} as it has no FOV tracks")
+ continue
+
+ # Process each FOV and resolve track IDs
+ fov_track_mapping = {}
+ for fov_name, tracks in dataset_config.fov_tracks.items():
+ if isinstance(tracks, list):
+ fov_track_mapping[fov_name] = tracks
+ elif tracks == "all":
+ # Get all tracks for this FOV from features
+ if self.features_df is not None:
+ fov_tracks_series = self.features_df[
+ (self.features_df["dataset"] == dataset_name)
+ & (self.features_df["fov_name"] == fov_name)
+ ]["track_id"]
+ fov_track_mapping[fov_name] = (
+ fov_tracks_series.unique().tolist()
+ )
+ logger.info(
+ f"Resolved 'all' tracks for FOV {fov_name}: {len(fov_track_mapping[fov_name])} tracks"
+ )
+ else:
+ logger.warning(
+ f"Cannot resolve 'all' tracks for FOV {fov_name}: features_df is None"
+ )
+ fov_track_mapping[fov_name] = []
+ else:
+ logger.warning(
+ f"Unknown track specification for FOV {fov_name}: {tracks}"
+ )
+ fov_track_mapping[fov_name] = []
+
+ logger.debug(f"FOV-track mapping for {dataset_name}: {fov_track_mapping}")
+
+ # Process each FOV and its resolved tracks
+ for fov_name, track_ids in fov_track_mapping.items():
+ if not track_ids: # Skip FOVs with no tracks
+ logger.debug(f"Skipping FOV {fov_name} as it has no tracks")
+ continue
+
+ logger.debug(
+ f"Processing FOV {fov_name} with {len(track_ids)} tracks: {track_ids}"
+ )
+
+ try:
+ data_module = TripletDataModule(
+ data_path=dataset_config.data_path,
+ tracks_path=dataset_config.tracks_path,
+ include_fov_names=[fov_name] * len(track_ids),
+ include_track_ids=track_ids,
+ source_channel=dataset_config.channels_to_display,
+ z_range=dataset_config.z_range,
+ initial_yx_patch_size=dataset_config.yx_patch_size,
+ final_yx_patch_size=dataset_config.yx_patch_size,
+ batch_size=1,
+ num_workers=self.num_loading_workers,
+ normalizations=[],
+ predict_cells=True,
+ )
+ data_module.setup("predict")
+
+ for batch in data_module.predict_dataloader():
+ try:
+ images = batch["anchor"].numpy()
+ indices = batch["index"]
+ track_id = indices["track_id"].item()
+ t = indices["t"].item()
+
+ img = np.stack(images)
+
+ cache_key = (dataset_name, fov_name, track_id, t)
+
+ logger.debug(f"Processing cache key: {cache_key}")
+
+ # Process each channel based on its type
+ processed_channels = {}
+ for idx, channel in enumerate(
+ dataset_config.channels_to_display
+ ):
+ try:
+ if channel in ["Phase3D", "DIC", "BF"]:
+ # For phase contrast, use the middle z-slice
+ z_idx = (
+ dataset_config.z_range[1]
+ - dataset_config.z_range[0]
+ ) // 2
+ processed = self._normalize_image(
+ img[0, idx, z_idx]
+ )
+ else:
+ # For fluorescence, use max projection
+ processed = self._normalize_image(
+ np.max(img[0, idx], axis=0)
+ )
+
+ processed_channels[channel] = self._numpy_to_base64(
+ processed
+ )
+ logger.debug(
+ f"Successfully processed channel {channel} for {cache_key}"
+ )
+ except Exception as e:
+ logger.error(
+ f"Error processing channel {channel} for {cache_key}: {e}"
+ )
+ continue
+
+ if (
+ processed_channels
+ ): # Only store if at least one channel was processed
+ self.image_cache[cache_key] = processed_channels
+
+ except Exception as e:
+ logger.error(
+ f"Error processing batch for {fov_name}, track {track_id}: {e}"
+ )
+ continue
+
+ except Exception as e:
+ logger.error(
+ f"Error setting up data module for FOV {fov_name}: {e}"
+ )
+ continue
+
+ # Log some statistics about the cache
+ cached_datasets = set(key[0] for key in self.image_cache.keys())
+ cached_fovs = set((key[0], key[1]) for key in self.image_cache.keys())
+ cached_tracks = set((key[0], key[1], key[2]) for key in self.image_cache.keys())
+ logger.info(f"Cached datasets: {cached_datasets}")
+ logger.debug(f"Cached dataset-FOV combinations: {len(cached_fovs)}")
+ logger.debug(
+ f"Number of unique dataset-FOV-track combinations: {len(cached_tracks)}"
+ )
+
+ # Save cache if path is specified
+ if self.cache_path:
+ self.save_cache()
+
+ def save_cache(self, cache_path: Union[str, Path, None] = None):
+ """Save the image cache to disk using pickle.
+
+ Parameters
+ ----------
+ cache_path : Union[str, Path, None], optional
+ Path to save the cache. If None, uses self.cache_path, by default None
+ """
+ import pickle
+
+ if cache_path is None:
+ if self.cache_path is None:
+ logger.warning("No cache path specified, skipping cache save")
+ return
+ cache_path = self.cache_path
+ # Normalize to Path regardless of source, since self.cache_path
+ # may be a plain string taken from the YAML config
+ cache_path = Path(cache_path)
+
+ # Create parent directory if it doesn't exist
+ cache_path.parent.mkdir(parents=True, exist_ok=True)
+
+ # Save cache metadata for validation
+ cache_metadata = {
+ "datasets": {
+ name: {
+ "data_path": str(config.data_path),
+ "tracks_path": str(config.tracks_path),
+ "features_path": str(config.features_path),
+ "channels": config.channels_to_display,
+ "z_range": config.z_range,
+ "yx_patch_size": config.yx_patch_size,
+ }
+ for name, config in self.datasets.items()
+ },
+ "cache_size": len(self.image_cache),
+ }
+
+ try:
+ logger.info(f"Saving image cache to {cache_path}")
+ with open(cache_path, "wb") as f:
+ pickle.dump((cache_metadata, self.image_cache), f)
+ except Exception as e:
+ logger.error(f"Error saving cache: {e}")
+
+ def load_cache(self, cache_path: Union[str, Path, None] = None) -> bool:
+ """Load the image cache from disk using pickle.
+
+ Parameters
+ ----------
+ cache_path : Union[str, Path, None], optional
+ Path to load the cache from. If None, uses self.cache_path
+
+ Returns
+ -------
+ bool
+ True if the cache was loaded successfully, False otherwise.
+ """
+ import pickle
+
+ if cache_path is None:
+ if self.cache_path is None:
+ logger.warning("No cache path specified, skipping cache load")
+ return False
+ cache_path = self.cache_path
+ else:
+ cache_path = Path(cache_path)
+
+ try:
+ logger.info(f"Loading image cache from {cache_path}")
+ with open(cache_path, "rb") as f:
+ cache_metadata, self.image_cache = pickle.load(f)
+ logger.info(
+ f"Successfully loaded cache with {len(self.image_cache)} images"
+ )
+ return True
+ except Exception as e:
+ logger.error(f"Error loading cache: {e}")
+ return False
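
Note that the metadata written by save_cache is unpacked here but never compared against the current configuration, and that pickle executes arbitrary code on load, so cache_path should only point at trusted storage. A sketch of a validation step that could reject a stale cache (hypothetical method, assuming the metadata structure written by save_cache):

    def _cache_matches_config(self, cache_metadata) -> bool:
        """Return True if the cached dataset paths match the current config."""
        cached = {
            name: meta.get("data_path")
            for name, meta in cache_metadata.get("datasets", {}).items()
        }
        current = {
            name: str(config.data_path) for name, config in self.datasets.items()
        }
        return cached == current
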
+
+ def run(self, debug=False, port=None):
+ """Run the Dash server
+
+ Parameters
+ ----------
+ debug : bool, optional
+ Whether to run in debug mode, by default False
+ port : int, optional
+ Port to run on. If None, will try ports from 8050-8070, by default None
+ """
+ import socket
+
+ def is_port_in_use(port):
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+ try:
+ s.bind(("127.0.0.1", port))
+ return False
+ except socket.error:
+ return True
+
+ if port is None:
+ # Try ports from 8050 to 8070
+ port_range = list(range(8050, 8071))
+ for p in port_range:
+ if not is_port_in_use(p):
+ port = p
+ break
+ if port is None:
+ raise RuntimeError(
+ f"Could not find an available port in range {port_range[0]}-{port_range[-1]}"
+ )
+
+ try:
+ logger.info(f"Starting server on port {port}")
+ if self.app is not None:
+ self.app.run(
+ debug=debug,
+ port=port,
+ use_reloader=False,
+ )
+ except KeyboardInterrupt:
+ logger.info("Server shutdown requested...")
+ except Exception as e:
+ logger.error(f"Error running server: {e}")
+ finally:
+ self._cleanup_cache()
+ logger.info("Server shutdown complete")
diff --git a/viscy/representation/visualization/cluster.py b/viscy/representation/visualization/cluster.py
new file mode 100644
index 000000000..39fc0b404
--- /dev/null
+++ b/viscy/representation/visualization/cluster.py
@@ -0,0 +1,262 @@
+import uuid
+from dataclasses import dataclass, field
+from datetime import datetime
+from typing import Dict, List, Optional, Set, Tuple
+
+
+@dataclass
+class ClusterPoint:
+ """Represents a single point in a cluster."""
+
+ track_id: int
+ t: int
+ fov_name: str
+ dataset: str
+ x_coord: Optional[float] = None
+ y_coord: Optional[float] = None
+ z_coord: Optional[float] = None
+
+ @property
+ def cache_key(self) -> Tuple[str, str, int, int]:
+ """Get the cache key for this point."""
+ return (self.dataset, self.fov_name, self.track_id, self.t)
+
+ @property
+ def unique_track_id(self) -> str:
+ """Get a globally unique track identifier."""
+ return f"{self.dataset}_{self.track_id}"
+
+ def __eq__(self, other) -> bool:
+ """Two points are equal if they have the same cache key."""
+ if not isinstance(other, ClusterPoint):
+ return False
+ return self.cache_key == other.cache_key
+
+ def __hash__(self) -> int:
+ """Hash based on cache key for use in sets."""
+ return hash(self.cache_key)
+
+
+@dataclass
+class Cluster:
+ """Represents a cluster of points with metadata."""
+
+ points: List[ClusterPoint] = field(default_factory=list)
+ name: str = ""
+ color: str = ""
+ created_at: datetime = field(default_factory=datetime.now)
+ _id: str = field(default_factory=lambda: str(uuid.uuid4()))
+
+ @property
+ def id(self) -> str:
+ """Unique identifier for this cluster."""
+ return self._id
+
+ @property
+ def size(self) -> int:
+ """Number of points in this cluster."""
+ return len(self.points)
+
+ @property
+ def datasets(self) -> Set[str]:
+ """Get all datasets represented in this cluster."""
+ return {point.dataset for point in self.points}
+
+ @property
+ def cache_keys(self) -> Set[Tuple[str, str, int, int]]:
+ """Get all cache keys for points in this cluster."""
+ return {point.cache_key for point in self.points}
+
+ def add_point(self, point: ClusterPoint) -> None:
+ """Add a point to this cluster."""
+ if point not in self.points:
+ self.points.append(point)
+
+ def remove_point(self, point: ClusterPoint) -> bool:
+ """Remove a point from this cluster. Returns True if point was found and removed."""
+ try:
+ self.points.remove(point)
+ return True
+ except ValueError:
+ return False
+
+ def contains_cache_key(self, cache_key: Tuple[str, str, int, int]) -> bool:
+ """Check if this cluster contains a point with the given cache key."""
+ return cache_key in self.cache_keys
+
+ def get_default_name(self, cluster_number: int) -> str:
+ """Generate a default name for this cluster."""
+ if len(self.datasets) == 1:
+ dataset_name = list(self.datasets)[0]
+ return f"{dataset_name}: Cluster {cluster_number}"
+ else:
+ datasets_str = ", ".join(sorted(self.datasets))
+ return f"[{datasets_str}]: Cluster {cluster_number}"
+
+ def to_dict_list(self) -> List[Dict]:
+ """Convert cluster points to the legacy dictionary format for compatibility."""
+ return [
+ {
+ "track_id": point.track_id,
+ "t": point.t,
+ "fov_name": point.fov_name,
+ "dataset": point.dataset,
+ }
+ for point in self.points
+ ]
+
+
+class ClusterManager:
+ """Manages a collection of clusters."""
+
+ def __init__(self):
+ self._clusters: List[Cluster] = []
+
+ @property
+ def clusters(self) -> List[Cluster]:
+ """Get all clusters."""
+ return self._clusters
+
+ @property
+ def cluster_count(self) -> int:
+ """Get the number of clusters."""
+ return len(self._clusters)
+
+ @property
+ def all_cluster_points(self) -> Set[Tuple[str, str, int, int]]:
+ """Get all cache keys from all clusters."""
+ all_keys = set()
+ for cluster in self._clusters:
+ all_keys.update(cluster.cache_keys)
+ return all_keys
+
+ def add_cluster(self, cluster: Cluster) -> str:
+ """Add a cluster to the manager and return its ID."""
+ self._clusters.append(cluster)
+ return cluster.id
+
+ def create_cluster_from_points(
+ self, points_data: List[Dict], name: str = ""
+ ) -> str:
+ """Create a new cluster from point data and add it to the manager."""
+ cluster = Cluster()
+
+ for point_data in points_data:
+ point = ClusterPoint(
+ track_id=point_data["track_id"],
+ t=point_data["t"],
+ fov_name=point_data["fov_name"],
+ dataset=point_data["dataset"],
+ )
+ cluster.add_point(point)
+
+ if name:
+ cluster.name = name
+ else:
+ # Generate default name
+ cluster.name = cluster.get_default_name(len(self._clusters) + 1)
+
+ return self.add_cluster(cluster)
+
+ def remove_cluster(self, cluster_id: str) -> bool:
+ """Remove a cluster by ID. Returns True if cluster was found and removed."""
+ for i, cluster in enumerate(self._clusters):
+ if cluster.id == cluster_id:
+ del self._clusters[i]
+ return True
+ return False
+
+ def remove_last_cluster(self) -> Optional[Cluster]:
+ """Remove and return the most recently added cluster."""
+ if self._clusters:
+ return self._clusters.pop()
+ return None
+
+ def get_cluster_by_id(self, cluster_id: str) -> Optional[Cluster]:
+ """Get a cluster by its ID."""
+ for cluster in self._clusters:
+ if cluster.id == cluster_id:
+ return cluster
+ return None
+
+ def get_cluster_by_index(self, index: int) -> Optional[Cluster]:
+ """Get a cluster by its index (for backward compatibility)."""
+ if 0 <= index < len(self._clusters):
+ return self._clusters[index]
+ return None
+
+ def clear_all_clusters(self) -> None:
+ """Remove all clusters."""
+ self._clusters.clear()
+
+ def get_cluster_colors(self) -> Dict[str, str]:
+ """Get a mapping of cluster IDs to their colors."""
+ import matplotlib
+
+ # plt.cm.get_cmap was removed in matplotlib 3.9; use the registry
+ cmap = matplotlib.colormaps["Set2"]
+ colors = {}
+ for i, cluster in enumerate(self._clusters):
+ if not cluster.color:
+ # Auto-assign color if not set
+ cluster.color = f"rgb{tuple(int(x * 255) for x in cmap(i % 8)[:3])}"
+ colors[cluster.id] = cluster.color
+ return colors
+
+ def get_cluster_colors_by_index(self) -> List[str]:
+ """Get cluster colors as a list (for backward compatibility)."""
+ import matplotlib
+
+ cmap = matplotlib.colormaps["Set2"]
+ colors = []
+ for i, cluster in enumerate(self._clusters):
+ if not cluster.color:
+ cluster.color = f"rgb{tuple(int(x * 255) for x in cmap(i % 8)[:3])}"
+ colors.append(cluster.color)
+ return colors
+
+ def is_point_in_any_cluster(self, cache_key: Tuple[str, str, int, int]) -> bool:
+ """Check if a point is in any cluster."""
+ return any(cluster.contains_cache_key(cache_key) for cluster in self._clusters)
+
+ def get_cluster_containing_point(
+ self, cache_key: Tuple[str, str, int, int]
+ ) -> Optional[Cluster]:
+ """Get the cluster containing a specific point."""
+ for cluster in self._clusters:
+ if cluster.contains_cache_key(cache_key):
+ return cluster
+ return None
+
+ def get_point_to_cluster_mapping(self) -> Dict[Tuple[str, str, int, int], int]:
+ """Get a mapping from cache keys to cluster indices (for backward compatibility)."""
+ point_to_cluster = {}
+ for cluster_idx, cluster in enumerate(self._clusters):
+ for cache_key in cluster.cache_keys:
+ point_to_cluster[cache_key] = cluster_idx
+ return point_to_cluster
+
+ def update_cluster_name(self, cluster_id: str, new_name: str) -> bool:
+ """Update a cluster's name. Returns True if successful."""
+ cluster = self.get_cluster_by_id(cluster_id)
+ if cluster:
+ cluster.name = new_name
+ return True
+ return False
+
+ def update_cluster_name_by_index(self, index: int, new_name: str) -> bool:
+ """Update a cluster's name by index (for backward compatibility)."""
+ cluster = self.get_cluster_by_index(index)
+ if cluster:
+ cluster.name = new_name
+ return True
+ return False
+
+ def get_cluster_names_by_index(self) -> Dict[int, str]:
+ """Get cluster names mapped by index (for backward compatibility)."""
+ return {i: cluster.name for i, cluster in enumerate(self._clusters)}
+
+ def to_legacy_format(self) -> Tuple[List[List[Dict]], Dict[int, str]]:
+ """Convert to the legacy format for backward compatibility."""
+ clusters_data = [cluster.to_dict_list() for cluster in self._clusters]
+ cluster_names = self.get_cluster_names_by_index()
+ return clusters_data, cluster_names
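
A usage sketch for the cluster API above, with made-up point data:

    manager = ClusterManager()
    cluster_id = manager.create_cluster_from_points(
        [
            {"track_id": 7, "t": 3, "fov_name": "/A/3/9", "dataset": "g3bp1_sensor"},
            {"track_id": 7, "t": 4, "fov_name": "/A/3/9", "dataset": "g3bp1_sensor"},
        ]
    )
    cluster = manager.get_cluster_by_id(cluster_id)
    print(cluster.name)  # "g3bp1_sensor: Cluster 1" (auto-generated default)
    print(cluster.size)  # 2
    # Membership queries use the (dataset, fov_name, track_id, t) cache key
    print(manager.is_point_in_any_cluster(("g3bp1_sensor", "/A/3/9", 7, 3)))  # True
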
diff --git a/viscy/representation/visualization/settings.py b/viscy/representation/visualization/settings.py
new file mode 100644
index 000000000..d90520103
--- /dev/null
+++ b/viscy/representation/visualization/settings.py
@@ -0,0 +1,84 @@
+import logging
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+from pydantic import BaseModel, Field, field_validator
+
+
+class DatasetConfig(BaseModel):
+ """Configuration for a single dataset."""
+
+ features_path: str
+ data_path: str
+ tracks_path: str
+ channels_to_display: List[str]
+ z_range: Tuple[int, int]
+ yx_patch_size: Tuple[int, int]
+ fov_tracks: Dict[str, Union[List[int], str]] = Field(default_factory=dict)
+
+ @field_validator("features_path", "data_path", "tracks_path")
+ @classmethod
+ def validate_paths(cls, v):
+ if not Path(v).exists():
+ logging.warning(f"Path does not exist: {v}")
+ return v
+
+
+class VizConfig(BaseModel):
+ """Configuration for visualization app."""
+
+ datasets: Dict[str, DatasetConfig] = Field(default_factory=dict)
+
+ num_PC_components: int = Field(default=8, ge=1, le=10)
+ phate_kwargs: Optional[Dict[str, Any]] = Field(
+ default=None,
+ description="PHATE parameters. If None, PHATE will not be computed.",
+ )
+
+ # Combined analysis options
+ use_cached_combined_phate: bool = Field(
+ default=True,
+ description="Use cached combined PHATE results if available",
+ )
+ combined_phate_cache_path: Optional[str] = Field(
+ default=None,
+ description="Path to cache combined PHATE results. If None, uses cache_path/combined_phate.zarr",
+ )
+
+ # File system paths
+ output_dir: Optional[str] = Field(
+ default=None,
+ description="Directory to save CSV files and other outputs. If None, uses current working directory.",
+ )
+ cache_path: Optional[str] = Field(
+ default=None,
+ description="Path to save/load image cache. If None, images will not be cached to disk.",
+ )
+
+ @field_validator("output_dir", "cache_path", "combined_phate_cache_path")
+ @classmethod
+ def validate_optional_paths(cls, v):
+ if v is not None:
+ # Create parent directory if it doesn't exist
+ path = Path(v)
+ if not path.parent.exists():
+ logging.info(f"Creating parent directory for: {v}")
+ path.parent.mkdir(parents=True, exist_ok=True)
+ return v
+
+ def get_datasets(self) -> Dict[str, DatasetConfig]:
+ """Get the datasets configuration."""
+ return self.datasets
+
+ def get_all_fov_tracks(self) -> Dict[str, Dict[str, Union[List[int], str]]]:
+ """Get all FOV tracks from all datasets combined."""
+ all_fov_tracks = {}
+
+ for dataset_name, dataset_config in self.datasets.items():
+ all_fov_tracks[dataset_name] = {}
+ for fov_name, track_ids in dataset_config.fov_tracks.items():
+ # Keep track IDs as original integers
+ # Uniqueness will be handled by (dataset, track_id) tuple
+ all_fov_tracks[dataset_name][fov_name] = track_ids
+
+ return all_fov_tracks
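
Because these are plain pydantic models, a configuration can also be constructed programmatically instead of from YAML; a sketch with placeholder paths (the path validators above only warn, so nonexistent paths are accepted):

    cfg = VizConfig(
        datasets={
            "demo": DatasetConfig(
                features_path="/path/to/features.zarr",
                data_path="/path/to/data.zarr",
                tracks_path="/path/to/tracks.zarr",
                channels_to_display=["Phase3D", "RFP"],
                z_range=(24, 29),
                yx_patch_size=(160, 160),
                fov_tracks={"/A/3/9": [1, 2, 3], "/B/4/9": "all"},
            )
        },
        num_PC_components=8,
    )
    print(cfg.get_all_fov_tracks())
    # {"demo": {"/A/3/9": [1, 2, 3], "/B/4/9": "all"}}
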
diff --git a/viscy/utils/cli_utils.py b/viscy/utils/cli_utils.py
index 4223e6784..fdbbf7e9b 100644
--- a/viscy/utils/cli_utils.py
+++ b/viscy/utils/cli_utils.py
@@ -1,9 +1,11 @@
import collections
import os
import re
+from pathlib import Path
import numpy as np
import torch
+import yaml
from PIL import Image
@@ -117,3 +119,55 @@ def save_figure(data, save_folder, name, title=None, vmax=0, ext=".png"):
im = Image.fromarray(data).convert("L")
im.info["size"] = data.shape
im.save(os.path.join(save_folder, name + ext))
+
+
+def yaml_to_model(yaml_path: Path, model):
+ """
+ Load model settings from a YAML file and create a model instance.
+
+ Borrowing from recOrder==0.4.0
+
+ Parameters
+ ----------
+ yaml_path : Path
+ The path to the YAML file containing the model settings.
+ model : class
+ The model class used to create an instance with the loaded settings.
+
+ Returns
+ -------
+ object
+ An instance of the model class with the loaded settings.
+
+ Raises
+ ------
+ TypeError
+ If the provided model is not a class or does not have a callable constructor.
+ FileNotFoundError
+ If the YAML file specified by `yaml_path` does not exist.
+
+ Notes
+ -----
+ This function loads model settings from a YAML file using `yaml.safe_load()`.
+ It then creates an instance of the provided `model` class using the loaded settings.
+
+ Examples
+ --------
+ >>> from my_model import MyModel
+ >>> model = yaml_to_model('model.yaml', MyModel)
+
+ """
+ yaml_path = Path(yaml_path)
+
+ # Instances also have a callable __init__, so check for a class
+ # explicitly, as documented above
+ if not isinstance(model, type):
+ raise TypeError(
+ "The provided model must be a class with a callable constructor."
+ )
+
+ try:
+ with open(yaml_path, "r") as file:
+ raw_settings = yaml.safe_load(file)
+ except FileNotFoundError:
+ raise FileNotFoundError(f"The YAML file '{yaml_path}' does not exist.")
+
+ return model(**raw_settings)
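
A minimal usage sketch with a throwaway pydantic model (the Settings class and the YAML contents are hypothetical):

    from pydantic import BaseModel

    from viscy.utils.cli_utils import yaml_to_model

    class Settings(BaseModel):
        name: str
        retries: int = 3

    # Given a settings.yaml containing the single line "name: demo", this
    # returns Settings(name="demo", retries=3); missing or invalid fields
    # surface as pydantic validation errors at construction time.
    settings = yaml_to_model("settings.yaml", Settings)
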