@@ -1,45 +1,47 @@
"""Interactive visualization of phenotype data."""

import logging
from pathlib import Path

import click
from numpy.random import seed

from viscy.representation.evaluation.visualization import EmbeddingVisualizationApp
from viscy.representation.visualization.app import EmbeddingVisualizationApp
from viscy.representation.visualization.settings import VizConfig
from viscy.utils.cli_utils import yaml_to_model

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

seed(42)


def main():
@click.command()
@click.option(
"--config-filepath",
"-c",
required=True,
help="Path to YAML configuration file.",
)
def main(config_filepath):
"""Main function to run the visualization app."""

# Config for the visualization app
# TODO: Update the paths to the downloaded data. By default the data is downloaded to ~/data/dynaclr/demo
download_root = Path.home() / "data/dynaclr/demo"
output_path = Path.home() / "data/dynaclr/demo/embedding-web-visualization"
viz_config = {
"data_path": download_root / "registered_test.zarr", # TODO add path to data
"tracks_path": download_root / "track_test.zarr", # TODO add path to tracks
"features_path": download_root
/ "precomputed_embeddings/infection_160patch_94ckpt_rev6_dynaclr.zarr", # TODO add path to features
"channels_to_display": ["Phase3D", "RFP"],
"fov_tracks": {
"/A/3/9": list(range(50)),
"/B/4/9": list(range(50)),
},
"yx_patch_size": (160, 160),
"z_range": (24, 29),
"num_PC_components": 8,
"output_dir": output_path,
}
# Load and validate configuration from YAML file
viz_config = yaml_to_model(yaml_path=config_filepath, model=VizConfig)

# Use configured paths, with fallbacks to current defaults if not specified
output_dir = viz_config.output_dir or str(Path(__file__).parent / "output")
cache_path = viz_config.cache_path

logger.info(f"Using output directory: {output_dir}")
logger.info(f"Using cache path: {cache_path}")

# Create and run the visualization app
try:
# Create and run the visualization app
app = EmbeddingVisualizationApp(**viz_config)
app = EmbeddingVisualizationApp(
viz_config=viz_config,
cache_path=cache_path,
num_loading_workers=16,
output_dir=output_dir,
)
app.preload_images()
app.run(debug=True)

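The refactored script now takes all of its paths from a YAML file validated against VizConfig. A minimal smoke-test sketch using click's test runner; the module name demo_visualization is an assumption, since the script's filename is not visible in this diff:

# Hypothetical smoke test for the new CLI entry point.
# "demo_visualization" is an assumed module name for the script changed above.
from click.testing import CliRunner

from demo_visualization import main

runner = CliRunner()
result = runner.invoke(main, ["--config-filepath", "viz_config.yaml"])
print(result.exit_code, result.output)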
52 changes: 52 additions & 0 deletions examples/DynaCLR/embedding-web-visualization/viz_config.yaml
@@ -0,0 +1,52 @@
# Multi-dataset visualization configuration
datasets:
  g3bp1_sensor:
    data_path: "/hpc/projects/intracellular_dashboard/organelle_dynamics/rerun/2025_01_28_A549_G3BP1_ZIKV_DENV/4-phenotyping/train-test/2025_01_28_A549_G3BP1_ZIKV_DENV.zarr"
    tracks_path: "/hpc/projects/intracellular_dashboard/organelle_dynamics/rerun/2025_01_28_A549_G3BP1_ZIKV_DENV/1-preprocess/label-free/3-track/2025_01_28_A549_G3BP1_ZIKV_DENV_cropped.zarr"
    features_path: "/hpc/projects/intracellular_dashboard/organelle_dynamics/rerun/2025_01_28_A549_G3BP1_ZIKV_DENV/4-phenotyping/predictions/sensor_160patch_99ckpt_max.zarr"
    channels_to_display: ["raw mCherry EX561 EM600-37"]
    z_range: [0, 1]
    yx_patch_size: [192, 192]
    fov_tracks:
      "/B/3/000000": [42, 47, 56, 57, 58, 59, 61]
      "/B/4/000001": [57, 83, 89, 90, 91, 101]

  g3bp1_organelle:
    data_path: "/hpc/projects/intracellular_dashboard/organelle_dynamics/rerun/2025_01_28_A549_G3BP1_ZIKV_DENV/4-phenotyping/train-test/2025_01_28_A549_G3BP1_ZIKV_DENV.zarr"
    tracks_path: "/hpc/projects/intracellular_dashboard/organelle_dynamics/rerun/2025_01_28_A549_G3BP1_ZIKV_DENV/1-preprocess/label-free/3-track/2025_01_28_A549_G3BP1_ZIKV_DENV_cropped.zarr"
    features_path: "/hpc/projects/intracellular_dashboard/organelle_dynamics/rerun/2025_01_28_A549_G3BP1_ZIKV_DENV/4-phenotyping/predictions/organelle_160patch_99ckpt_max.zarr"
    channels_to_display: ["raw GFP EX488 EM525-45"]
    z_range: [0, 1]
    yx_patch_size: [192, 192]
    fov_tracks:
      # "/B/3/000000": "all"
      # "/B/4/000001": "all"
      "/B/2/000000": [15, 47, 61, 62, 63, 64, 65]
      "/C/2/000001": "all"
      # "/C/2/000000": []

  # sec61_organelle:
  #   data_path: "/hpc/projects/intracellular_dashboard/organelle_dynamics/rerun/2025_01_28_A549_G3BP1_ZIKV_DENV/4-phenotyping/train-test/2025_01_28_A549_G3BP1_ZIKV_DENV.zarr"
  #   tracks_path: "/hpc/projects/intracellular_dashboard/organelle_dynamics/rerun/2025_01_28_A549_G3BP1_ZIKV_DENV/1-preprocess/label-free/3-track/2025_01_28_A549_G3BP1_ZIKV_DENV_cropped.zarr"
  #   features_path: "/hpc/projects/intracellular_dashboard/organelle_dynamics/rerun/2024_11_07_A549_SEC61_DENV/4-phenotyping/2-predictions/organelle_160patch_99ckpt_max.zarr"
  #   channels_to_display: ["raw GFP EX488 EM525-45"]
  #   z_range: [0, 1]
  #   yx_patch_size: [192, 192]

# Global settings
num_PC_components: 8
# phate_kwargs:
#   n_components: 4 # Number of PHATE components
#   knn: 5 # Number of nearest neighbors for KNN graph
#   decay: 40 # Decay parameter for Markov operator
#   n_jobs: -1 # Number of parallel jobs (-1 uses all cores)
#   random_state: 42 # Random seed for reproducibility
#   # gamma: 1.0 # Informational scale parameter (optional)
#   # t: "auto" # Number of diffusion steps (optional, auto-selects optimal)
#   # a: 1 # Alpha parameter for alpha-decay kernel (optional)
#   # verbose: 0 # Verbosity level (optional)

# File system paths
output_dir: "./output" # Directory to save CSV files and other outputs
cache_path: "/home/eduardo.hirata/mydata/tmp/pcviewer/cache/image_cache1.pkl" # Path to save/load image cache
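Each entry under datasets pairs one embedding store with the FOVs and tracks to display; fov_tracks values are either explicit track-ID lists or the literal string "all". A minimal sketch of walking this structure with PyYAML (illustrative only; resolving "all" against the actual track IDs happens inside the app and is not shown in this PR):

# Illustrative config walk; assumes PyYAML is installed.
import yaml

with open("viz_config.yaml") as f:
    cfg = yaml.safe_load(f)

for name, ds in cfg["datasets"].items():
    for fov, tracks in ds["fov_tracks"].items():
        label = "all tracks" if tracks == "all" else f"{len(tracks)} tracks"
        print(f"{name} {fov}: {label}")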
172 changes: 172 additions & 0 deletions viscy/representation/evaluation/combined_analysis.py
@@ -0,0 +1,172 @@
import logging
from pathlib import Path
from typing import Optional, Union

import numpy as np
import pandas as pd
from xarray import Dataset

from viscy.representation.embedding_writer import (
    read_embedding_dataset,
    write_embedding_dataset,
)
from viscy.representation.evaluation.data_loading import (
    EmbeddingDataLoader,
    TripletEmbeddingLoader,
)
from viscy.representation.evaluation.dimensionality_reduction import compute_phate

__all__ = ["load_and_combine_features", "compute_phate_for_combined_datasets"]

_logger = logging.getLogger("lightning.pytorch")

def load_and_combine_features(
    feature_paths: list[Path],
    dataset_names: Optional[list[str]] = None,
    loader: Union[
        EmbeddingDataLoader, TripletEmbeddingLoader
    ] = TripletEmbeddingLoader(),
) -> tuple[np.ndarray, pd.DataFrame]:
    """
    Load features from multiple datasets and combine them using a pluggable loader.

    Parameters
    ----------
    feature_paths : list[Path]
        Paths to embedding datasets.
    dataset_names : list[str], optional
        Names for datasets. If None, uses file stems.
    loader : EmbeddingDataLoader | TripletEmbeddingLoader, optional
        Data loader instance. Defaults to ``TripletEmbeddingLoader()``.

    Returns
    -------
    tuple[np.ndarray, pd.DataFrame]
        Combined features array and index DataFrame with a ``dataset_pair`` column.
    """
    if dataset_names is None:
        dataset_names = [path.stem for path in feature_paths]

    if len(dataset_names) != len(feature_paths):
        raise ValueError("Number of dataset names must match number of feature paths")

    all_features = []
    all_indices = []

    for path, dataset_name in zip(feature_paths, dataset_names):
        _logger.info(f"Loading features from {path}")

        dataset = loader.load_dataset(path)
        features = loader.extract_features(dataset)
        index_df = loader.extract_metadata(dataset)

        # Tag each row with its source dataset so the combined DataFrame
        # can be split apart again after dimensionality reduction.
        index_df["dataset_pair"] = dataset_name
        index_df["dataset_path"] = str(path)

        all_features.append(features)
        all_indices.append(index_df)

        _logger.info(f"Loaded {len(features)} samples from {dataset_name}")

    combined_features = np.vstack(all_features)
    combined_indices = pd.concat(all_indices, ignore_index=True)

    _logger.info(
        f"Combined {len(combined_features)} total samples from {len(feature_paths)} datasets"
    )

    return combined_features, combined_indices


def compute_phate_for_combined_datasets(
    feature_paths: list[Path],
    output_path: Path,
    dataset_names: Optional[list[str]] = None,
    phate_kwargs: Optional[dict] = None,
    overwrite: bool = False,
    loader: Union[
        EmbeddingDataLoader, TripletEmbeddingLoader
    ] = TripletEmbeddingLoader(),
) -> Dataset:
    """
    Compute PHATE embeddings on combined features from multiple datasets.

    Parameters
    ----------
    feature_paths : list[Path]
        List of paths to zarr stores containing embedding datasets.
    output_path : Path
        Path to save the combined dataset with PHATE embeddings.
    dataset_names : list[str], optional
        Names for each dataset. If None, uses file stems.
    phate_kwargs : dict, optional
        Parameters for PHATE computation.
        Default: ``{"knn": 5, "decay": 40, "n_components": 2}``.
    overwrite : bool, optional
        Whether to overwrite an existing output file.
    loader : EmbeddingDataLoader | TripletEmbeddingLoader, optional
        Data loader instance. Defaults to ``TripletEmbeddingLoader()``.

    Returns
    -------
    Dataset
        Combined xarray dataset with original features and PHATE coordinates.

    Raises
    ------
    FileExistsError
        If output_path exists and overwrite is False.
    ImportError
        If PHATE is not installed.
    """
    output_path = Path(output_path)

    if output_path.exists() and not overwrite:
        raise FileExistsError(
            f"Output path {output_path} already exists. Use overwrite=True to overwrite."
        )

    if phate_kwargs is None:
        phate_kwargs = {"knn": 5, "decay": 40, "n_components": 2}

    _logger.info(
        f"Computing PHATE for combined datasets with parameters: {phate_kwargs}"
    )

    combined_features, combined_indices = load_and_combine_features(
        feature_paths, dataset_names, loader
    )

    # PHATE requires knn < n_samples; clamp it for very small combined datasets.
    n_samples = len(combined_features)
    if phate_kwargs.get("knn", 5) >= n_samples:
        original_knn = phate_kwargs.get("knn", 5)
        phate_kwargs["knn"] = max(2, n_samples // 2)
        _logger.warning(
            f"Reducing knn from {original_knn} to {phate_kwargs['knn']} due to dataset size ({n_samples} samples)"
        )

    _logger.info("Computing PHATE embeddings on combined features")
    try:
        _, phate_embedding = compute_phate(combined_features, **phate_kwargs)
        _logger.info(
            f"PHATE computation successful, embedding shape: {phate_embedding.shape}"
        )
    except Exception as e:
        _logger.error(f"PHATE computation failed: {str(e)}")
        raise

    _logger.info(f"Writing combined dataset with PHATE embeddings to {output_path}")
    write_embedding_dataset(
        output_path=output_path,
        features=combined_features,
        index_df=combined_indices,
        phate_kwargs=phate_kwargs,
        overwrite=overwrite,
    )

    result_dataset = read_embedding_dataset(output_path)
    _logger.info(
        f"Successfully created combined dataset with {len(result_dataset.sample)} samples"
    )

    return result_dataset
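A minimal usage sketch of the new module; the .zarr paths below are placeholders, not files from this PR:

from pathlib import Path

from viscy.representation.evaluation.combined_analysis import (
    compute_phate_for_combined_datasets,
)

# Placeholder inputs: any two embedding stores written by viscy.
combined = compute_phate_for_combined_datasets(
    feature_paths=[Path("sensor_embeddings.zarr"), Path("organelle_embeddings.zarr")],
    output_path=Path("combined_phate.zarr"),
    dataset_names=["sensor", "organelle"],
    overwrite=True,
)
print(combined)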