diff --git a/applications/benchmarking/DynaCLR/DINOV3/config_dinov3_convnext_tiny.yml b/applications/benchmarking/DynaCLR/DINOV3/config_dinov3_convnext_tiny.yml new file mode 100644 index 000000000..195aa6db9 --- /dev/null +++ b/applications/benchmarking/DynaCLR/DINOV3/config_dinov3_convnext_tiny.yml @@ -0,0 +1,65 @@ +datamodule_class: viscy.data.triplet.TripletDataModule +datamodule: + data_path: /hpc/projects/intracellular_dashboard/organelle_dynamics/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/registered_test.zarr + tracks_path: /hpc/projects/intracellular_dashboard/organelle_dynamics/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/track_test.zarr + batch_size: 32 + final_yx_patch_size: + - 256 + - 256 + include_fov_names: null + include_track_ids: null + initial_yx_patch_size: + - 256 + - 256 + normalizations: + - class_path: viscy.transforms.ScaleIntensityRangePercentilesd + init_args: + b_max: 1.0 + b_min: 0.0 + keys: + - RFP + lower: 50 + upper: 99 + - class_path: viscy.transforms.NormalizeIntensityd + init_args: + keys: + - Phase3D + num_workers: 10 + source_channel: + - RFP + - Phase3D + z_range: + - 15 + - 45 + +embedding: + pca_kwargs: + n_components: 8 + phate_kwargs: + decay: 40 + knn: 5 + n_components: 2 + n_jobs: -1 + random_state: 42 + reductions: + - PHATE + - PCA + +execution: + overwrite: false + save_config: true + show_config: true + +model: + model_name: facebook/dinov3-convnext-tiny-pretrain-lvd1689m + pooling_method: mean # Options: "mean", "max", "cls_token" + middle_slice_index: 18 # Specific z-slice index (if null, uses D//2) + channel_reduction_methods: + Phase3D: middle_slice + RFP: max + channel_names: + - RFP + - Phase3D + +paths: + output_path: /hpc/mydata/eduardo.hirata/repos/viscy/applications/benchmarking/DynaCLR/DINOV3/embeddings_convnext_tiny_mean.zarr diff --git a/applications/benchmarking/DynaCLR/DINOV3/dinov3_embeddings.py b/applications/benchmarking/DynaCLR/DINOV3/dinov3_embeddings.py new file mode 100644 index 000000000..38164e523 --- /dev/null +++ b/applications/benchmarking/DynaCLR/DINOV3/dinov3_embeddings.py @@ -0,0 +1,170 @@ +import sys +from pathlib import Path +from typing import Dict, List, Literal, Optional + +import numpy as np +import torch +from PIL import Image +from skimage.exposure import rescale_intensity +from transformers import AutoImageProcessor, AutoModel + +sys.path.append(str(Path(__file__).parent.parent)) + +from base_embedding_module import BaseEmbeddingModule, create_embedding_cli + + +class DINOv3Module(BaseEmbeddingModule): + def __init__( + self, + model_name: str = "facebook/dinov3-vitb16-pretrain-lvd1689m", + channel_reduction_methods: Optional[ + Dict[str, Literal["middle_slice", "mean", "max"]] + ] = None, + channel_names: Optional[List[str]] = None, + pooling_method: Literal["mean", "max", "cls_token"] = "mean", + middle_slice_index: Optional[int] = None, + ): + super().__init__(channel_reduction_methods, channel_names, middle_slice_index) + self.model_name = model_name + self.pooling_method = pooling_method + + self.model = None + self.processor = None + + @classmethod + def from_config(cls, cfg): + """Create model instance from configuration.""" + model_config = cfg.get("model", {}) + return cls( + model_name=model_config.get( + "model_name", "facebook/dinov3-vitb16-pretrain-lvd1689m" + ), + pooling_method=model_config.get("pooling_method", "mean"), + channel_reduction_methods=model_config.get("channel_reduction_methods", {}), + channel_names=model_config.get("channel_names", []), + 
middle_slice_index=model_config.get("middle_slice_index", None), + ) + + def on_predict_start(self): + if self.model is None: + self.processor = AutoImageProcessor.from_pretrained(self.model_name) + self.model = AutoModel.from_pretrained(self.model_name) + self.model.eval() + self.model.to(self.device) + + def _process_input(self, x: torch.Tensor): + """Convert tensor to PIL Images for DINOv3 processing.""" + return self._convert_to_pil_images(x) + + def _extract_features(self, pil_images): + """Extract features using DINOv3 model.""" + inputs = self.processor(pil_images, return_tensors="pt") + inputs = {k: v.to(self.device) for k, v in inputs.items()} + + with torch.no_grad(): + outputs = self.model(**inputs) + token_features = outputs.last_hidden_state + features = self._pool_features(token_features) + + return features + + def _convert_to_pil_images(self, x: torch.Tensor) -> List[Image.Image]: + """ + Convert tensor to list of PIL Images for DINOv3 processing. + + Parameters + ---------- + x : torch.Tensor + Input tensor with shape (B, C, H, W). + + Returns + ------- + list of PIL.Image.Image + List of PIL Images ready for DINOv3 processing. + """ + images = [] + + for b in range(x.shape[0]): + img_tensor = x[b] # (C, H, W) + + if img_tensor.shape[0] == 1: + # Single channel - convert to grayscale PIL + img_array = img_tensor[0].cpu().numpy() + # Normalize to 0-255 + img_normalized = ( + (img_array - img_array.min()) + / (img_array.max() - img_array.min()) + * 255 + ).astype(np.uint8) + pil_img = Image.fromarray(img_normalized, mode="L") + + elif img_tensor.shape[0] == 2: + img_array = img_tensor.cpu().numpy() + rgb_array = np.zeros( + (img_array.shape[1], img_array.shape[2], 3), dtype=np.uint8 + ) + + ch0_norm = rescale_intensity(img_array[0], out_range=(0, 255)).astype( + np.uint8 + ) + ch1_norm = rescale_intensity(img_array[1], out_range=(0, 255)).astype( + np.uint8 + ) + + rgb_array[:, :, 0] = ch0_norm # Red + rgb_array[:, :, 1] = ch1_norm # Green + rgb_array[:, :, 2] = (ch0_norm + ch1_norm) // 2 # Blue + + pil_img = Image.fromarray(rgb_array, mode="RGB") + + elif img_tensor.shape[0] == 3: + # Three channels - direct RGB + img_array = img_tensor.cpu().numpy().transpose(1, 2, 0) # HWC + img_normalized = rescale_intensity( + img_array, out_range=(0, 255) + ).astype(np.uint8) + pil_img = Image.fromarray(img_normalized, mode="RGB") + + else: + # More than 3 channels - use first 3 + img_array = img_tensor[:3].cpu().numpy().transpose(1, 2, 0) # HWC + img_normalized = rescale_intensity( + img_array, out_range=(0, 255) + ).astype(np.uint8) + pil_img = Image.fromarray(img_normalized, mode="RGB") + + images.append(pil_img) + + return images + + def _pool_features(self, features: torch.Tensor) -> torch.Tensor: + """ + Pool spatial features from DINOv3 tokens. + + Parameters + ---------- + features : torch.Tensor + Token features with shape (B, num_tokens, hidden_dim). + + Returns + ------- + torch.Tensor + Pooled features with shape (B, hidden_dim). 
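+
+        Notes
+        -----
+        ConvNeXt checkpoints do not emit a CLS token, so requesting
+        ``cls_token`` pooling for a non-ViT model name falls back to mean
+        pooling over all tokens (see the implementation below).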
+ """ + if self.pooling_method == "cls_token": + # For ViT models, first token is usually CLS token + if "vit" in self.model_name.lower(): + return features[:, 0, :] # CLS token + else: + # For ConvNeXt, no CLS token, fall back to mean + return features.mean(dim=1) + + elif self.pooling_method == "max": + return features.max(dim=1)[0] + else: # mean pooling + return features.mean(dim=1) + + +if __name__ == "__main__": + main = create_embedding_cli(DINOv3Module, "DINOv3") + main() diff --git a/applications/benchmarking/DynaCLR/OpenPhenom/config_template.yml b/applications/benchmarking/DynaCLR/OpenPhenom/config_template.yml index a8d31f231..826e894a0 100644 --- a/applications/benchmarking/DynaCLR/OpenPhenom/config_template.yml +++ b/applications/benchmarking/DynaCLR/OpenPhenom/config_template.yml @@ -2,8 +2,6 @@ # Paths section paths: - data_path: /hpc/projects/intracellular_dashboard/organelle_dynamics/2024_11_07_A549_SEC61_ZIKV_DENV/2-assemble/2024_11_07_A549_SEC61_DENV.zarr - tracks_path: /hpc/projects/intracellular_dashboard/organelle_dynamics/2024_11_07_A549_SEC61_ZIKV_DENV/1-preprocess/label-free/4-track-gt/2024_11_07_A549_SEC61_ZIKV_DENV_2_cropped.zarr output_path: "/home/eduardo.hirata/repos/viscy/applications/benchmarking/DynaCLR/OpenPhenom/openphenom_sec61b_n_phase_3.zarr" # Model configuration @@ -16,7 +14,10 @@ model: "raw GFP EX488 EM525-45": "max" # Data module configuration +datamodule_class: viscy.data.triplet.TripletDataModule datamodule: + data_path: /hpc/projects/intracellular_dashboard/organelle_dynamics/2024_11_07_A549_SEC61_ZIKV_DENV/2-assemble/2024_11_07_A549_SEC61_DENV.zarr + tracks_path: /hpc/projects/intracellular_dashboard/organelle_dynamics/2024_11_07_A549_SEC61_ZIKV_DENV/1-preprocess/label-free/4-track-gt/2024_11_07_A549_SEC61_ZIKV_DENV_2_cropped.zarr source_channel: - Phase3D - "raw GFP EX488 EM525-45" diff --git a/applications/benchmarking/DynaCLR/OpenPhenom/openphenom_embeddings.py b/applications/benchmarking/DynaCLR/OpenPhenom/openphenom_embeddings.py index f07a84a09..11d9eaee9 100644 --- a/applications/benchmarking/DynaCLR/OpenPhenom/openphenom_embeddings.py +++ b/applications/benchmarking/DynaCLR/OpenPhenom/openphenom_embeddings.py @@ -1,56 +1,27 @@ -""" -Generate embeddings using the OpenPhenom model and save them to a zarr store -using VisCy Trainer and EmbeddingWriter callback. -""" - -import importlib -import logging -import os +import sys from pathlib import Path from typing import Dict, List, Literal, Optional -import click import torch -import yaml -from lightning.pytorch import LightningModule from transformers import AutoModel -from viscy.data.triplet import TripletDataModule -from viscy.representation.embedding_writer import EmbeddingWriter -from viscy.trainer import VisCyTrainer +sys.path.append(str(Path(__file__).parent.parent)) + +from base_embedding_module import BaseEmbeddingModule, create_embedding_cli -class OpenPhenomModule(LightningModule): +class OpenPhenomModule(BaseEmbeddingModule): def __init__( self, channel_reduction_methods: Optional[ Dict[str, Literal["middle_slice", "mean", "max"]] ] = None, channel_names: Optional[List[str]] = None, + middle_slice_index: Optional[int] = None, ): - """Initialize the OpenPhenom module. 
- - Parameters - ---------- - channel_reduction_methods : dict, optional - Dictionary mapping channel names to reduction methods: - - "middle_slice": Take the middle slice along the depth dimension - - "mean": Average across the depth dimension - - "max": Take the maximum value across the depth dimension - channel_names : list of str, optional - List of channel names corresponding to the input channels - - Notes - ----- - The module uses the OpenPhenom model from HuggingFace for generating embeddings. - """ - super().__init__() - - self.channel_reduction_methods = channel_reduction_methods or {} - self.channel_names = channel_names or [] + super().__init__(channel_reduction_methods, channel_names, middle_slice_index) try: - torch.set_float32_matmul_precision("high") self.model = AutoModel.from_pretrained( "recursionpharma/OpenPhenom", trust_remote_code=True ) @@ -60,270 +31,39 @@ def __init__( "Please install the OpenPhenom dependencies: pip install transformers" ) - def on_predict_start(self): - # Move model to GPU when prediction starts - self.model.to(self.device) - - def _reduce_5d_input(self, x: torch.Tensor) -> torch.Tensor: - """Reduce 5D input (B, C, D, H, W) to 4D (B, C, H, W) using specified methods. - - Args: - x: 5D input tensor - - Returns: - 4D tensor after applying reduction methods - """ - if x.dim() != 5: - return x - - B, C, D, H, W = x.shape - result = torch.zeros((B, C, H, W), device=x.device) - - # Apply reduction method for each channel - for c in range(C): - channel_name = ( - self.channel_names[c] if c < len(self.channel_names) else f"channel_{c}" - ) - # Default to middle slice if not specified - method = self.channel_reduction_methods.get(channel_name, "middle_slice") - - if method == "middle_slice": - result[:, c] = x[:, c, D // 2] - elif method == "mean": - result[:, c] = x[:, c].mean(dim=1) - elif method == "max": - result[:, c] = x[:, c].max(dim=1)[0] - else: - # Fallback to middle slice for unknown methods - result[:, c] = x[:, c, D // 2] - - return result + @classmethod + def from_config(cls, cfg): + """Create model instance from configuration.""" + model_config = cfg.get("model", {}) + dm_config = cfg.get("datamodule", {}) - def predict_step(self, batch, batch_idx, dataloader_idx=0): - """Extract features from the input images. 
- - Returns: - Dictionary with features, projections (None), and index information - """ - x = batch["anchor"] + return cls( + channel_reduction_methods=model_config.get("channel_reduction_methods", {}), + channel_names=dm_config.get("source_channel", []), + ) - # OpenPhenom expects [B, C, H, W] but our data might be [B, C, D, H, W] - # If 5D input, handle according to specified reduction methods - if x.dim() == 5: - x = self._reduce_5d_input(x) + def on_predict_start(self): + """Move model to GPU when prediction starts.""" + self.model.to(self.device) - # Convert to uint8 as OpenPhenom expects uint8 inputs + def _process_input(self, x: torch.Tensor): + """Convert to uint8 as OpenPhenom expects uint8 inputs.""" if x.dtype != torch.uint8: x = ( ((x - x.min()) / (x.max() - x.min()) * 255) .clamp(0, 255) .to(torch.uint8) ) + return x + def _extract_features(self, processed_input): + """Extract features using OpenPhenom model.""" # Get embeddings self.model.return_channelwise_embeddings = False - features = self.model.predict(x) - # Create empty projections tensor with same batch size as features - # This ensures the EmbeddingWriter can process it - projections = torch.zeros((features.shape[0], 0), device=features.device) - - return { - "features": features, - "projections": projections, - "index": batch["index"], - } - - -def load_config(config_file): - """Load configuration from a YAML file.""" - with open(config_file, "r") as f: - config = yaml.safe_load(f) - return config - - -def load_normalization_from_config(norm_config): - """Load a normalization transform from a configuration dictionary.""" - class_path = norm_config["class_path"] - init_args = norm_config.get("init_args", {}) - - # Split module and class name - module_path, class_name = class_path.rsplit(".", 1) - - # Import the module - module = importlib.import_module(module_path) - - # Get the class - transform_class = getattr(module, class_name) - - # Instantiate the transform - return transform_class(**init_args) - - -@click.command() -@click.option( - "--config", - "-c", - type=click.Path(exists=True), - required=True, - help="Path to YAML configuration file", -) -def main(config): - """Extract OpenPhenom embeddings and save to zarr format using VisCy Trainer.""" - # Configure logging - logging.basicConfig(level=logging.INFO) - logger = logging.getLogger(__name__) - - # Load config file - cfg = load_config(config) - logger.info(f"Loaded configuration from {config}") - - # Prepare datamodule parameters - dm_params = {} - - # Add data and tracks paths from the paths section - if "paths" not in cfg: - raise ValueError("Configuration must contain a 'paths' section") - - if "data_path" not in cfg["paths"]: - raise ValueError( - "Data path is required in the configuration file (paths.data_path)" - ) - dm_params["data_path"] = cfg["paths"]["data_path"] - - if "tracks_path" not in cfg["paths"]: - raise ValueError( - "Tracks path is required in the configuration file (paths.tracks_path)" - ) - dm_params["tracks_path"] = cfg["paths"]["tracks_path"] - - # Add datamodule parameters - if "datamodule" not in cfg: - raise ValueError("Configuration must contain a 'datamodule' section") - - # Prepare normalizations - if ( - "normalizations" not in cfg["datamodule"] - or not cfg["datamodule"]["normalizations"] - ): - raise ValueError( - "Normalizations are required in the configuration file (datamodule.normalizations)" - ) - - norm_configs = cfg["datamodule"]["normalizations"] - normalizations = [load_normalization_from_config(norm) for norm in 
norm_configs] - dm_params["normalizations"] = normalizations - - # Copy all other datamodule parameters - for param, value in cfg["datamodule"].items(): - if param != "normalizations": - # Handle patch sizes - if param == "patch_size": - dm_params["initial_yx_patch_size"] = value - dm_params["final_yx_patch_size"] = value - else: - dm_params[param] = value - - # Set up the data module - logger.info("Setting up data module") - dm = TripletDataModule(**dm_params) - - # Get model parameters for handling 5D inputs - channel_reduction_methods = {} - - if "model" in cfg and "channel_reduction_methods" in cfg["model"]: - channel_reduction_methods = cfg["model"]["channel_reduction_methods"] - - # Initialize OpenPhenom model with reduction settings - logger.info("Loading OpenPhenom model") - model = OpenPhenomModule( - channel_reduction_methods=channel_reduction_methods, - channel_names=dm_params.get("source_channel", []), - ) - - # Get dimensionality reduction parameters from config - phate_kwargs = None - pca_kwargs = None - - if "embedding" in cfg: - if "phate_kwargs" in cfg["embedding"]: - phate_kwargs = cfg["embedding"]["phate_kwargs"] - if "pca_kwargs" in cfg["embedding"]: - pca_kwargs = cfg["embedding"]["pca_kwargs"] - # Check if output path exists and should be overwritten - if "output_path" not in cfg["paths"]: - raise ValueError( - "Output path is required in the configuration file (paths.output_path)" - ) - - output_path = Path(cfg["paths"]["output_path"]) - output_dir = output_path.parent - output_dir.mkdir(parents=True, exist_ok=True) - - overwrite = False - if "execution" in cfg and "overwrite" in cfg["execution"]: - overwrite = cfg["execution"]["overwrite"] - elif output_path.exists(): - logger.warning(f"Output path {output_path} already exists, will overwrite") - overwrite = True - - # Set up EmbeddingWriter callback - embedding_writer = EmbeddingWriter( - output_path=output_path, - phate_kwargs=phate_kwargs, - pca_kwargs=pca_kwargs, - overwrite=overwrite, - ) - - # Set up and run VisCy trainer - logger.info("Setting up VisCy trainer") - trainer = VisCyTrainer( - accelerator="gpu" if torch.cuda.is_available() else "cpu", - devices=1, - callbacks=[embedding_writer], - inference_mode=True, - ) - - logger.info(f"Running prediction and saving to {output_path}") - trainer.predict(model, datamodule=dm) - - # Save configuration if requested - save_config_flag = True - show_config_flag = True - - if "execution" in cfg: - if "save_config" in cfg["execution"]: - save_config_flag = cfg["execution"]["save_config"] - if "show_config" in cfg["execution"]: - show_config_flag = cfg["execution"]["show_config"] - - # Save configuration if requested - if save_config_flag: - config_path = os.path.join(output_dir, "config.yml") - with open(config_path, "w") as f: - yaml.dump(cfg, f, default_flow_style=False) - logger.info(f"Configuration saved to {config_path}") - - # Display configuration if requested - if show_config_flag: - click.echo("\nConfiguration used:") - click.echo("-" * 40) - for key, value in cfg.items(): - click.echo(f"{key}:") - if isinstance(value, dict): - for subkey, subvalue in value.items(): - if isinstance(subvalue, list) and subkey == "normalizations": - click.echo(f" {subkey}:") - for norm in subvalue: - click.echo(f" - class_path: {norm['class_path']}") - click.echo(f" init_args: {norm['init_args']}") - else: - click.echo(f" {subkey}: {subvalue}") - else: - click.echo(f" {value}") - click.echo("-" * 40) - - logger.info("Done!") + features = self.model.predict(processed_input) + 
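+        # With return_channelwise_embeddings disabled above, predict() yields a
+        # single pooled embedding per image rather than one vector per channel.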
return features if __name__ == "__main__": + main = create_embedding_cli(OpenPhenomModule, "OpenPhenom") main() diff --git a/applications/benchmarking/DynaCLR/SAM2/run_sam2.sh b/applications/benchmarking/DynaCLR/SAM2/run_sam2.sh new file mode 100644 index 000000000..9bf781bc0 --- /dev/null +++ b/applications/benchmarking/DynaCLR/SAM2/run_sam2.sh @@ -0,0 +1,18 @@ +#!/bin/bash + +#SBATCH --job-name=dynaclr_imagenet +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=1 +#SBATCH --partition=gpu +#SBATCH --gres=gpu:1 +#SBATCH --cpus-per-task=64 +#SBATCH --mem-per-cpu=7G +#SBATCH --time=0-02:00:00 +#SBATCH --output=./slurm_logs/%j_dynaclr_sam2.out + + +module load anaconda/latest +conda activate viscy + +CONFIG_PATH=/home/eduardo.hirata/repos/viscy/applications/benchmarking/DynaCLR/SAM2/sam2_sensor_only.yml +python /home/eduardo.hirata/repos/viscy/applications/benchmarking/DynaCLR/SAM2/sam2_embeddings.py -c $CONFIG_PATH diff --git a/applications/benchmarking/DynaCLR/SAM2/sam2_config.yml b/applications/benchmarking/DynaCLR/SAM2/sam2_config.yml new file mode 100644 index 000000000..5e3771d25 --- /dev/null +++ b/applications/benchmarking/DynaCLR/SAM2/sam2_config.yml @@ -0,0 +1,60 @@ +datamodule_class: viscy.data.triplet.TripletDataModule +datamodule: + data_path: /hpc/projects/intracellular_dashboard/organelle_dynamics/2024_11_07_A549_SEC61_ZIKV_DENV/2-assemble/2024_11_07_A549_SEC61_DENV.zarr + tracks_path: /hpc/projects/intracellular_dashboard/organelle_dynamics/2024_11_07_A549_SEC61_ZIKV_DENV/1-preprocess/label-free/4-track-gt/2024_11_07_A549_SEC61_ZIKV_DENV_2_cropped.zarr + batch_size: 32 + final_yx_patch_size: + - 192 + - 192 + include_fov_names: null + include_track_ids: null + initial_yx_patch_size: + - 192 + - 192 + normalizations: + - class_path: viscy.transforms.ScaleIntensityRangePercentilesd + init_args: + b_max: 1.0 + b_min: 0.0 + keys: + - Phase3D + lower: 50 + upper: 99 + - class_path: viscy.transforms.ScaleIntensityRangePercentilesd + init_args: + b_max: 1.0 + b_min: 0.0 + keys: + - raw GFP EX488 EM525-45 + lower: 50 + upper: 99 + num_workers: 10 + source_channel: + - Phase3D + - raw GFP EX488 EM525-45 + z_range: + - 25 + - 40 +embedding: + pca_kwargs: + n_components: 8 + phate_kwargs: + decay: 40 + knn: 5 + n_components: 2 + n_jobs: -1 + random_state: 42 + reductions: + - PHATE + - PCA +execution: + overwrite: false + save_config: true + show_config: true +model: + model_name: facebook/sam2-hiera-base-plus + channel_reduction_methods: + Phase3D: middle_slice + raw GFP EX488 EM525-45: max +paths: + output_path: /home/eduardo.hirata/repos/viscy/applications/benchmarking/DynaCLR/SAM2/sam2_sec61b_n_phase_all_highresfeats0.zarr diff --git a/applications/benchmarking/DynaCLR/SAM2/sam2_embeddings.py b/applications/benchmarking/DynaCLR/SAM2/sam2_embeddings.py new file mode 100644 index 000000000..c664ca5cd --- /dev/null +++ b/applications/benchmarking/DynaCLR/SAM2/sam2_embeddings.py @@ -0,0 +1,101 @@ +import sys +from pathlib import Path +from typing import Dict, List, Literal, Optional + +import torch +from sam2.sam2_image_predictor import SAM2ImagePredictor +from skimage.exposure import rescale_intensity + +sys.path.append(str(Path(__file__).parent.parent)) + +from base_embedding_module import BaseEmbeddingModule, create_embedding_cli + + +class SAM2Module(BaseEmbeddingModule): + def __init__( + self, + model_name: str = "facebook/sam2-hiera-base-plus", + channel_reduction_methods: Optional[ + Dict[str, Literal["middle_slice", "mean", "max"]] + ] = None, + channel_names: Optional[List[str]] = 
None, + middle_slice_index: Optional[int] = None, + ): + super().__init__(channel_reduction_methods, channel_names, middle_slice_index) + self.model_name = model_name + self.model = None # Initialize in on_predict_start when device is set + + @classmethod + def from_config(cls, cfg): + """Create model instance from configuration.""" + model_config = cfg.get("model", {}) + + return cls( + model_name=model_config.get("model_name", "facebook/sam2-hiera-base-plus"), + channel_reduction_methods=model_config.get("channel_reduction_methods", {}), + middle_slice_index=model_config.get("middle_slice_index", None), + ) + + def on_predict_start(self): + """Initialize model with proper device when prediction starts.""" + if self.model is None: + self.model = SAM2ImagePredictor.from_pretrained( + self.model_name, device=self.device + ) + + def _process_input(self, x: torch.Tensor): + """Convert input tensor to 3-channel RGB format as needed for SAM2.""" + return self._convert_to_rgb(x) + + def _extract_features(self, image_list): + """Extract features using SAM2 model.""" + self.model.set_image_batch(image_list) + # Extract high-resolution features and apply global average pooling + features = self.model._features["high_res_feats"][0].mean(dim=(2, 3)) + return features + + def _convert_to_rgb(self, x: torch.Tensor) -> list: + """ + Convert input tensor to 3-channel RGB format as needed for SAM2. + + Parameters + ---------- + x : torch.Tensor + Input tensor with 1, 2, or 3+ channels and shape (B, C, H, W). + + Returns + ------- + list of numpy.ndarray + List of numpy arrays in HWC format for SAM2 processing. + """ + # Convert to RGB and scale to [0, 255] range for SAM2 + if x.shape[1] == 1: + x_rgb = x.repeat(1, 3, 1, 1) * 255.0 + elif x.shape[1] == 2: + x_3ch = torch.zeros( + (x.shape[0], 3, x.shape[2], x.shape[3]), device=x.device + ) + x[:, 0] = rescale_intensity(x[:, 0], out_range="uint8") + x[:, 1] = rescale_intensity(x[:, 1], out_range="uint8") + + x_3ch[:, 0] = x[:, 0] + x_3ch[:, 1] = x[:, 1] + x_3ch[:, 2] = 0.5 * (x[:, 0] + x[:, 1]) # B channel as blend + x_rgb = x_3ch + + elif x.shape[1] == 3: + x_rgb = rescale_intensity(x, out_range="uint8") + else: + # More than 3 channels, normalize first 3 and scale + x_3ch = x[:, :3] + x_rgb = rescale_intensity(x_3ch, out_range="uint8") + + # Convert to list of numpy arrays in HWC format for SAM2 + return [ + x_rgb[i].cpu().numpy().transpose(1, 2, 0) for i in range(x_rgb.shape[0]) + ] + + +if __name__ == "__main__": + main = create_embedding_cli(SAM2Module, "SAM2") + main() diff --git a/applications/benchmarking/DynaCLR/SAM2/sam2_visualizations.py b/applications/benchmarking/DynaCLR/SAM2/sam2_visualizations.py new file mode 100644 index 000000000..46a9b2eac --- /dev/null +++ b/applications/benchmarking/DynaCLR/SAM2/sam2_visualizations.py @@ -0,0 +1,213 @@ +# %% +""" +Test script to visualize SAM2 input images and feature processing. +This script helps debug what images are being passed to SAM2 and how they're processed. 
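+Visualizations are written under ./debug_images/ (one sub-directory per sample).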
+""" + +import os +from pathlib import Path + +import matplotlib.pyplot as plt +from sam2_embeddings import SAM2Module, load_config, load_normalization_from_config + +from viscy.data.triplet import TripletDataModule + + +def visualize_rgb_conversion(x_original, x_rgb_list, save_dir="./debug_images"): + """Visualize the RGB conversion process""" + os.makedirs(save_dir, exist_ok=True) + + print(f"Original input shape: {x_original.shape}") + print(f"Original input range: [{x_original.min():.3f}, {x_original.max():.3f}]") + + # Plot original channels + B, C = x_original.shape[:2] + fig, axes = plt.subplots(3, max(3, C), figsize=(15, 12)) + + # Plot original channels + for c in range(C): + ax = axes[0, c] if C > 1 else axes[0, 0] + img = x_original[0, c].cpu().numpy() + im = ax.imshow(img, cmap="gray") + ax.set_title(f"Original Channel {c}") + ax.axis("off") + plt.colorbar(im, ax=ax) + + # Plot RGB conversion + rgb_img = x_rgb_list[0] # First batch item + print(f"RGB image shape: {rgb_img.shape}") + print(f"RGB image range: [{rgb_img.min():.3f}, {rgb_img.max():.3f}]") + + for c in range(3): + ax = axes[1, c] + im = ax.imshow(rgb_img[:, :, c], cmap="gray") + ax.set_title(f"RGB Channel {c}") + ax.axis("off") + plt.colorbar(im, ax=ax) + + # Plot merged RGB image + ax = axes[2, 0] + # Normalize to 0-1 for display + rgb_display = rgb_img.copy() + rgb_display = (rgb_display - rgb_display.min()) / ( + rgb_display.max() - rgb_display.min() + ) + ax.imshow(rgb_display) + ax.set_title("Merged RGB Image") + ax.axis("off") + + # Check if RGB is properly scaled to 0-255 + ax = axes[2, 1] + ax.text( + 0.1, + 0.8, + f"RGB Range: [{rgb_img.min():.1f}, {rgb_img.max():.1f}]", + transform=ax.transAxes, + ) + ax.text(0.1, 0.6, "Expected: [0, 255]", transform=ax.transAxes) + ax.text( + 0.1, + 0.4, + f"Properly scaled: {rgb_img.min() >= 0 and rgb_img.max() <= 255}", + transform=ax.transAxes, + ) + ax.text(0.1, 0.2, f"Mean: {rgb_img.mean():.1f}", transform=ax.transAxes) + ax.set_title("RGB Scaling Check") + ax.axis("off") + + plt.tight_layout() + plt.savefig(f"{save_dir}/rgb_conversion.png", dpi=150, bbox_inches="tight") + plt.close() + + +def test_sam2_processing(config_path, num_samples=3): + """Test SAM2 processing with visualization""" + + # Load configuration + cfg = load_config(config_path) + print(f"Loaded config from: {config_path}") + + # Setup data module (same as in main function) + dm_params = {} + dm_params["data_path"] = cfg["paths"]["data_path"] + dm_params["tracks_path"] = cfg["paths"]["tracks_path"] + + # Setup normalizations + norm_configs = cfg["datamodule"]["normalizations"] + normalizations = [load_normalization_from_config(norm) for norm in norm_configs] + dm_params["normalizations"] = normalizations + + # Copy other datamodule parameters + for param, value in cfg["datamodule"].items(): + if param != "normalizations": + if param == "patch_size": + dm_params["initial_yx_patch_size"] = value + dm_params["final_yx_patch_size"] = value + else: + dm_params[param] = value + + print("Setting up data module...") + dm = TripletDataModule(**dm_params) + dm.setup(stage="predict") + + # Get model parameters + channel_reduction_methods = {} + if "model" in cfg and "channel_reduction_methods" in cfg["model"]: + channel_reduction_methods = cfg["model"]["channel_reduction_methods"] + + # Initialize SAM2 model + print("Loading SAM2 model...") + model = SAM2Module( + model_name=cfg["model"]["model_name"], + channel_reduction_methods=channel_reduction_methods, + ) + + # Get dataloader + predict_dataloader = 
dm.predict_dataloader() + + print(f"Testing with {num_samples} samples...") + + # Test processing + for i, batch in enumerate(predict_dataloader): + if i >= num_samples: + break + + print(f"\n--- Sample {i + 1} ---") + x = batch["anchor"] + print(f"Input tensor shape: {x.shape}") + print(f"Input tensor range: [{x.min():.3f}, {x.max():.3f}]") + + # Test 5D reduction if needed + if x.dim() == 5: + print("Applying 5D reduction...") + x_reduced = model._reduce_5d_input(x) + print(f"After 5D reduction: {x_reduced.shape}") + print(f"Reduction methods: {model.channel_reduction_methods}") + else: + x_reduced = x + + # Test RGB conversion + print("Converting to RGB...") + x_rgb_list = model._convert_to_rgb(x_reduced) + print(f"RGB conversion result: {len(x_rgb_list)} images") + print(f"First RGB image shape: {x_rgb_list[0].shape}") + + # Visualize the conversion + visualize_rgb_conversion(x_reduced, x_rgb_list, f"./debug_images/sample_{i}") + + # Test feature extraction (if model is available) + try: + print("Testing feature extraction...") + model.model = model.model or model.on_predict_start() + model.model.set_image_batch(x_rgb_list) + + # Check what features are available + features_dict = model.model._features + print(f"Available features: {list(features_dict.keys())}") + + if "high_res_feats" in features_dict: + high_res_feats = features_dict["high_res_feats"] + print(f"High-res features length: {len(high_res_feats)}") + for j, feat in enumerate(high_res_feats): + print(f" Layer {j}: {feat.shape}") + + if "image_embed" in features_dict: + image_embed = features_dict["image_embed"] + print(f"Image embed shape: {image_embed.shape}") + + # Extract final features (current approach) + features = model.model._features["high_res_feats"][1].mean(dim=(2, 3)) + print(f"Final features shape: {features.shape}") + print(f"Final features range: [{features.min():.3f}, {features.max():.3f}]") + + except Exception as e: + print(f"Feature extraction failed: {e}") + + print("-" * 50) + + +def main(): + """Main function to run the test""" + config_path = "/home/eduardo.hirata/repos/viscy/applications/benchmarking/DynaCLR/SAM2/sam2_sensor_only.yml" + + if not Path(config_path).exists(): + print(f"Config file not found: {config_path}") + print("Please provide a valid config file path") + return + + try: + test_sam2_processing(config_path, num_samples=3) + print("\nTest completed successfully!") + print("Check ./debug_images/ for visualization outputs") + except Exception as e: + print(f"Test failed: {e}") + import traceback + + traceback.print_exc() + + +# %% +if __name__ == "__main__": + main() + +# %% diff --git a/applications/contrastive_phenotyping/evaluation/ALFI_MSD_v2.py b/applications/contrastive_phenotyping/evaluation/ALFI_MSD_v2.py deleted file mode 100644 index b77e8a0f6..000000000 --- a/applications/contrastive_phenotyping/evaluation/ALFI_MSD_v2.py +++ /dev/null @@ -1,229 +0,0 @@ -# %% -from pathlib import Path - -import matplotlib.pyplot as plt -import numpy as np - -from viscy.representation.embedding_writer import read_embedding_dataset -from viscy.representation.evaluation.distance import ( - compute_displacement, - compute_displacement_statistics, -) - -# Paths to datasets -feature_paths = { - "7 min interval": "/hpc/projects/organelle_phenotyping/ALFI_benchmarking/predictions_final/ALFI_opp_7mins.zarr", - "21 min interval": "/hpc/projects/organelle_phenotyping/ALFI_benchmarking/predictions_final/ALFI_opp_21mins.zarr", -} - -# Colors for different time intervals -interval_colors = { - "7 min 
interval": "blue", - "21 min interval": "red", -} - -# %% Compute MSD for each dataset -results = {} -raw_displacements = {} - -for label, path in feature_paths.items(): - print(f"\nProcessing {label}...") - embedding_dataset = read_embedding_dataset(Path(path)) - - # Compute displacements - displacements = compute_displacement( - embedding_dataset=embedding_dataset, - distance_metric="euclidean_squared", - ) - means, stds = compute_displacement_statistics(displacements) - results[label] = (means, stds) - raw_displacements[label] = displacements - - # Print some statistics - taus = sorted(means.keys()) - print(f" Number of different τ values: {len(taus)}") - print(f" τ range: {min(taus)} to {max(taus)}") - print(f" MSD at τ=1: {means[1]:.4f} ± {stds[1]:.4f}") - -# %% Plot MSD vs time (linear scale) -plt.figure(figsize=(10, 6)) - -# Plot each time interval -for interval_label, path in feature_paths.items(): - means, stds = results[interval_label] - - # Sort by tau for plotting - taus = sorted(means.keys()) - mean_values = [means[tau] for tau in taus] - std_values = [stds[tau] for tau in taus] - - plt.plot( - taus, - mean_values, - "-", - color=interval_colors[interval_label], - alpha=0.5, - zorder=1, - ) - plt.scatter( - taus, - mean_values, - color=interval_colors[interval_label], - s=20, - label=interval_label, - zorder=2, - ) - -plt.xlabel("Time Shift (τ)") -plt.ylabel("Mean Square Displacement") -plt.title("MSD vs Time Shift") -plt.grid(True, alpha=0.3) -plt.legend() -plt.tight_layout() -plt.show() - -# %% Plot MSD vs time (log-log scale with slopes) -plt.figure(figsize=(10, 6)) - -# Plot each time interval -for interval_label, path in feature_paths.items(): - means, stds = results[interval_label] - - # Sort by tau for plotting - taus = sorted(means.keys()) - mean_values = [means[tau] for tau in taus] - std_values = [stds[tau] for tau in taus] - - # Filter out non-positive values for log scale - valid_mask = np.array(mean_values) > 0 - valid_taus = np.array(taus)[valid_mask] - valid_means = np.array(mean_values)[valid_mask] - - # Calculate slopes for different regions - log_taus = np.log(valid_taus) - log_means = np.log(valid_means) - - # Early slope (first third of points) - n_points = len(log_taus) - early_end = n_points // 3 - early_slope, early_intercept = np.polyfit( - log_taus[:early_end], log_means[:early_end], 1 - ) - - # Late slope (last third of points) - late_start = 2 * (n_points // 3) - late_slope, late_intercept = np.polyfit( - log_taus[late_start:], log_means[late_start:], 1 - ) - - plt.plot( - valid_taus, - valid_means, - "-", - color=interval_colors[interval_label], - alpha=0.5, - zorder=1, - ) - plt.scatter( - valid_taus, - valid_means, - color=interval_colors[interval_label], - s=20, - label=f"{interval_label} (α_early={early_slope:.2f}, α_late={late_slope:.2f})", - zorder=2, - ) - - # Plot fitted lines for early and late regions - early_fit = np.exp(early_intercept + early_slope * log_taus[:early_end]) - late_fit = np.exp(late_intercept + late_slope * log_taus[late_start:]) - - plt.plot( - valid_taus[:early_end], - early_fit, - "--", - color=interval_colors[interval_label], - alpha=0.3, - zorder=1, - ) - plt.plot( - valid_taus[late_start:], - late_fit, - "--", - color=interval_colors[interval_label], - alpha=0.3, - zorder=1, - ) - -plt.xscale("log") -plt.yscale("log") -plt.xlabel("Time Shift (τ)") -plt.ylabel("Mean Square Displacement") -plt.title("MSD vs Time Shift (log-log)") -plt.grid(True, alpha=0.3, which="both") -plt.legend( - title="α = slope in log-log 
space", bbox_to_anchor=(1.05, 1), loc="upper left" -) -plt.tight_layout() -plt.show() - -# %% Plot slopes analysis -early_slopes = [] -late_slopes = [] -intervals = [] - -for interval_label in feature_paths.keys(): - means, _ = results[interval_label] - - # Calculate slopes - taus = np.array(sorted(means.keys())) - mean_values = np.array([means[tau] for tau in taus]) - valid_mask = mean_values > 0 - - if np.sum(valid_mask) > 3: # Need at least 4 points to calculate both slopes - log_taus = np.log(taus[valid_mask]) - log_means = np.log(mean_values[valid_mask]) - - # Calculate early and late slopes - n_points = len(log_taus) - early_end = n_points // 3 - late_start = 2 * (n_points // 3) - - early_slope, _ = np.polyfit(log_taus[:early_end], log_means[:early_end], 1) - late_slope, _ = np.polyfit(log_taus[late_start:], log_means[late_start:], 1) - - early_slopes.append(early_slope) - late_slopes.append(late_slope) - intervals.append(interval_label) - -# Create bar plot -plt.figure(figsize=(12, 6)) - -x = np.arange(len(intervals)) -width = 0.35 - -plt.bar(x - width / 2, early_slopes, width, label="Early slope", alpha=0.7) -plt.bar(x + width / 2, late_slopes, width, label="Late slope", alpha=0.7) - -# Add reference lines -plt.axhline(y=1, color="k", linestyle="--", alpha=0.3, label="Normal diffusion (α=1)") -plt.axhline(y=0, color="k", linestyle="-", alpha=0.2) - -plt.xlabel("Time Interval") -plt.ylabel("Slope (α)") -plt.title("MSD Slopes by Time Interval") -plt.xticks(x, intervals, rotation=45) -plt.legend() - -# Add annotations for diffusion regimes -plt.text( - plt.xlim()[1] * 1.2, 1.5, "Super-diffusion", rotation=90, verticalalignment="center" -) -plt.text( - plt.xlim()[1] * 1.2, 0.5, "Sub-diffusion", rotation=90, verticalalignment="center" -) - -plt.grid(True, alpha=0.3) -plt.tight_layout() -plt.show() - -# %% diff --git a/applications/contrastive_phenotyping/evaluation/archive/ALFI_MSD_v2.py b/applications/contrastive_phenotyping/evaluation/archive/ALFI_MSD_v2.py new file mode 100644 index 000000000..ed86e0b4c --- /dev/null +++ b/applications/contrastive_phenotyping/evaluation/archive/ALFI_MSD_v2.py @@ -0,0 +1,215 @@ +# %% +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np +import xarray as xr +from scipy import stats + +from viscy.representation.embedding_writer import read_embedding_dataset +from viscy.representation.evaluation.distance import ( + compute_track_displacement, +) + +# Paths to datasets +feature_paths = { + "7 min interval": "/hpc/projects/organelle_phenotyping/ALFI_ntxent_loss/logs_alfi_ntxent_time_intervals/predictions/ALFI_7mins.zarr", + "14 min interval": "/hpc/projects/organelle_phenotyping/ALFI_ntxent_loss/logs_alfi_ntxent_time_intervals/predictions/ALFI_14mins.zarr", + "28 min interval": "/hpc/projects/organelle_phenotyping/ALFI_ntxent_loss/logs_alfi_ntxent_time_intervals/predictions/ALFI_28mins.zarr", + "56 min interval": "/hpc/projects/organelle_phenotyping/ALFI_ntxent_loss/logs_alfi_ntxent_time_intervals/predictions/ALFI_56mins.zarr", + "91 min interval": "/hpc/projects/organelle_phenotyping/ALFI_ntxent_loss/logs_alfi_ntxent_time_intervals/predictions/ALFI_91mins.zarr", +} + + +cmap = plt.get_cmap("tab10") # or use "Set2", "tab20", etc. 
+labels = list(feature_paths.keys()) +interval_colors = {label: cmap(i % cmap.N) for i, label in enumerate(labels)} + +# Print and check each path +for label, path in feature_paths.items(): + print(f"{label} color: {interval_colors[label]}") + assert Path(path).exists(), f"Path {path} does not exist" + +# %% Compute MSD for each dataset +results = {} +raw_displacements = {} + +DISTANCE_METRIC = "cosine" +for label, path in feature_paths.items(): + results[label] = {} + print(f"\nProcessing {label}...") + embedding_dataset = read_embedding_dataset(Path(path)) + + # Compute displacements + displacements_per_tau = compute_track_displacement( + embedding_dataset=embedding_dataset, + distance_metric=DISTANCE_METRIC, + ) + + # Store displacements with conditional normalization + if DISTANCE_METRIC == "cosine": + # Cosine distance is already scale-invariant, no normalization needed + for tau, displacements in displacements_per_tau.items(): + results[label][tau] = displacements + else: + # Normalize by embeddings variance for euclidean distance + embeddings_variance = np.var(embedding_dataset["features"].values) + for tau, displacements in displacements_per_tau.items(): + results[label][tau] = [disp / embeddings_variance for disp in displacements] + + +# %% Plot MSD vs time (linear scale) +show_power_law_fits = True +log_scale = True +title = "Mean Track Displacement vs Time Shift" + +fig, ax = plt.subplots(figsize=(10, 7)) + +for model_type, msd_data in results.items(): + time_lags = sorted(msd_data.keys()) + msd_means = [] + msd_stds = [] + + # Compute mean and std of MSD for each time lag + for tau in time_lags: + displacements = np.array(msd_data[tau]) + msd_means.append(np.mean(displacements)) + msd_stds.append(np.std(displacements) / np.sqrt(len(displacements))) + + time_lags = np.array(time_lags) + msd_means = np.array(msd_means) + msd_stds = np.array(msd_stds) + + # Plot with error bars + color = interval_colors.get(model_type, "#1f77b4") + ax.errorbar( + time_lags, + msd_means, + yerr=msd_stds, + marker="o", + label=f"{model_type.replace('_', ' ').title()}", + color=color, + capsize=3, + capthick=1, + linewidth=2, + markersize=6, + ) + # Fit power law if requested + if show_power_law_fits and len(time_lags) > 3: + valid_mask = (time_lags > 0) & (msd_means > 0) + if np.sum(valid_mask) > 3: + log_tau = np.log(time_lags[valid_mask]) + log_msd = np.log(msd_means[valid_mask]) + + slope, intercept, r_value, p_value, std_err = stats.linregress( + log_tau, log_msd + ) + + # Plot fit line + tau_fit = np.linspace( + time_lags[valid_mask][0], time_lags[valid_mask][-1], 50 + ) + msd_fit = np.exp(intercept) * tau_fit**slope + + ax.plot( + tau_fit, + msd_fit, + "--", + color=color, + alpha=0.7, + label=f"{model_type}: α={slope:.2f} (R²={r_value**2:.3f})", + ) + + ax.set_xlabel("Time Lag (τ)", fontsize=12) + ax.set_ylabel("Mean Track Displacement", fontsize=12) + ax.set_title(title, fontsize=14) + + if log_scale: + ax.set_xscale("log") + ax.set_yscale("log") + ax.grid(True, alpha=0.3) + + ax.legend() + plt.tight_layout() +plt.savefig(f"msd_vs_time_shift_{DISTANCE_METRIC}.png", dpi=300) +# %% +# Step size analysis + + +def extract_step_sizes(embedding_dataset: xr.Dataset): + """Extract step sizes with simple coordinate access.""" + + unique_tracks_df = ( + embedding_dataset[["fov_name", "track_id"]].to_dataframe().drop_duplicates() + ) + all_step_sizes = [] + + for fov_name, track_id in zip( + unique_tracks_df["fov_name"], unique_tracks_df["track_id"] + ): + track_data = embedding_dataset.where( + 
(embedding_dataset["fov_name"] == fov_name) + & (embedding_dataset["track_id"] == track_id), + drop=True, + ) + time_order = np.argsort(track_data["t"].values) + times = track_data["t"].values[time_order] + track_embeddings = track_data["features"].values[time_order] + if len(times) != len(np.unique(times)): + print(f"Duplicates found in FOV {fov_name}, track {track_id}") + + if len(track_embeddings) > 1: + steps = np.diff(track_embeddings, axis=0) + step_sizes = np.linalg.norm(steps, axis=1) + all_step_sizes.extend(step_sizes) + + return np.array(all_step_sizes) + + +all_step_data = {} +cv_values = [] +labels = [] + +for label, path in feature_paths.items(): + print(f"\nProcessing {label}...") + embedding_dataset = read_embedding_dataset(Path(path)) + steps = extract_step_sizes(embedding_dataset) + all_step_data[label] = steps + + # Calculate coefficient of variation + cv = np.std(steps) / np.mean(steps) + cv_values.append(cv) + labels.append(label.replace("_", " ").title()) + +# %% +# Plot histograms +ax1, ax2 = plt.subplots(1, 2, figsize=(15, 6))[1] + +for model_type, steps in all_step_data.items(): + color = interval_colors.get(model_type, "#1f77b4") + ax1.hist( + steps, + bins=50, + alpha=0.7, + color=color, + label=f"{model_type.replace('_', ' ').title()} (n={len(steps)}, μ={np.mean(steps):.3f}, σ={np.std(steps):.3f})", + ) + +ax1.set_xlabel("Step Size") +ax1.set_ylabel("Frequency") +ax1.set_title("Step Size Distributions") +ax1.legend() + +# Plot coefficient of variation +bar_colors = [ + interval_colors.get(model_type, "#1f77b4") for model_type in results.keys() +] +bars = ax2.bar(labels, cv_values, color=bar_colors, alpha=0.7) +ax2.set_ylabel("Coefficient of Variation (σ/μ)") +ax2.set_title("Step Size Variability") +ax2.tick_params(axis="x", rotation=45) +plt.tight_layout() +# plt.show() +plt.savefig(f"step_size_distributions_{DISTANCE_METRIC}.png", dpi=300) + +# %% diff --git a/applications/contrastive_phenotyping/evaluation/analyze_embeddings.py b/applications/contrastive_phenotyping/evaluation/archive/analyze_embeddings.py similarity index 100% rename from applications/contrastive_phenotyping/evaluation/analyze_embeddings.py rename to applications/contrastive_phenotyping/evaluation/archive/analyze_embeddings.py diff --git a/applications/contrastive_phenotyping/evaluation/cosine_dissimilarity_dataset.py b/applications/contrastive_phenotyping/evaluation/archive/cosine_dissimilarity_dataset.py similarity index 100% rename from applications/contrastive_phenotyping/evaluation/cosine_dissimilarity_dataset.py rename to applications/contrastive_phenotyping/evaluation/archive/cosine_dissimilarity_dataset.py diff --git a/applications/contrastive_phenotyping/evaluation/cosine_similarity.py b/applications/contrastive_phenotyping/evaluation/archive/cosine_similarity.py similarity index 100% rename from applications/contrastive_phenotyping/evaluation/cosine_similarity.py rename to applications/contrastive_phenotyping/evaluation/archive/cosine_similarity.py diff --git a/applications/contrastive_phenotyping/evaluation/linear_probing.py b/applications/contrastive_phenotyping/evaluation/archive/linear_probing.py similarity index 100% rename from applications/contrastive_phenotyping/evaluation/linear_probing.py rename to applications/contrastive_phenotyping/evaluation/archive/linear_probing.py diff --git a/applications/contrastive_phenotyping/evaluation/log_regresssion_training.py b/applications/contrastive_phenotyping/evaluation/archive/log_regresssion_training.py similarity index 100% rename 
from applications/contrastive_phenotyping/evaluation/log_regresssion_training.py rename to applications/contrastive_phenotyping/evaluation/archive/log_regresssion_training.py diff --git a/applications/contrastive_phenotyping/evaluation/time_decay_knn.py b/applications/contrastive_phenotyping/evaluation/archive/time_decay_knn.py similarity index 100% rename from applications/contrastive_phenotyping/evaluation/time_decay_knn.py rename to applications/contrastive_phenotyping/evaluation/archive/time_decay_knn.py diff --git a/applications/contrastive_phenotyping/evaluation/displacement.py b/applications/contrastive_phenotyping/evaluation/displacement.py deleted file mode 100644 index 5505f706c..000000000 --- a/applications/contrastive_phenotyping/evaluation/displacement.py +++ /dev/null @@ -1,166 +0,0 @@ -# %% -from pathlib import Path - -import matplotlib.pyplot as plt -import numpy as np - -from viscy.representation.embedding_writer import read_embedding_dataset -from viscy.representation.evaluation.distance import ( - calculate_normalized_euclidean_distance_cell, - compute_displacement_mean_std_full, -) - -# %% paths - -features_path_30_min = Path( - "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/time_interval/predict/feb_test_time_interval_1_epoch_178.zarr" -) - -feature_path_no_track = Path( - "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/negpair_random_sampling2/feb_fixed_test_predict.zarr" -) - -features_path_any_time = Path( - "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/negpair_difcell_randomtime_sampling/Ver2_updateTracking_refineModel/predictions/Feb_2chan_128patch_32projDim/2chan_128patch_56ckpt_FebTest.zarr" -) - -data_path = Path( - "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/registered_test.zarr" -) - -tracks_path = Path( - "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/track_test.zarr" -) - -# %% Load embedding datasets for all three sampling -fov_name = "/B/4/6" -track_id = 52 - -embedding_dataset_30_min = read_embedding_dataset(features_path_30_min) -embedding_dataset_no_track = read_embedding_dataset(feature_path_no_track) -embedding_dataset_any_time = read_embedding_dataset(features_path_any_time) - -# %% -# Calculate displacement for each sampling -time_points_30_min, cosine_similarities_30_min = ( - calculate_normalized_euclidean_distance_cell( - embedding_dataset_30_min, fov_name, track_id - ) -) -time_points_no_track, cosine_similarities_no_track = ( - calculate_normalized_euclidean_distance_cell( - embedding_dataset_no_track, fov_name, track_id - ) -) -time_points_any_time, cosine_similarities_any_time = ( - calculate_normalized_euclidean_distance_cell( - embedding_dataset_any_time, fov_name, track_id - ) -) - -# %% Plot displacement over time for all three conditions - -plt.figure(figsize=(10, 6)) - -plt.plot( - time_points_no_track, - cosine_similarities_no_track, - marker="o", - label="classical contrastive (no tracking)", -) -plt.plot( - time_points_any_time, cosine_similarities_any_time, marker="o", label="cell aware" -) -plt.plot( - time_points_30_min, - cosine_similarities_30_min, - marker="o", - label="cell & time aware (interval 30 min)", -) - -plt.xlabel("Time Delay (t)", fontsize=10) -plt.ylabel("Normalized Euclidean Distance with First Time Point", 
fontsize=10) -plt.title( - "Normalized Euclidean Distance (Features) Over Time for Infected Cell", fontsize=12 -) - -plt.grid(True) -plt.legend(fontsize=10) - -# plt.savefig('4_euc_dist_full.svg', format='svg') -plt.show() - - -# %% Paths to datasets -features_path_30_min = Path( - "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/time_interval/predict/feb_test_time_interval_1_epoch_178.zarr" -) -feature_path_no_track = Path( - "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/negpair_random_sampling2/feb_fixed_test_predict.zarr" -) - -embedding_dataset_30_min = read_embedding_dataset(features_path_30_min) -embedding_dataset_no_track = read_embedding_dataset(feature_path_no_track) - - -# %% -max_tau = 10 - -mean_displacement_30_min_euc, std_displacement_30_min_euc = ( - compute_displacement_mean_std_full(embedding_dataset_30_min, max_tau) -) -mean_displacement_no_track_euc, std_displacement_no_track_euc = ( - compute_displacement_mean_std_full(embedding_dataset_no_track, max_tau) -) - -# %% Plot 2: Cosine Displacements -plt.figure(figsize=(10, 6)) - -taus = list(mean_displacement_30_min_euc.keys()) - -mean_values_30_min_euc = list(mean_displacement_30_min_euc.values()) -std_values_30_min_euc = list(std_displacement_30_min_euc.values()) - -plt.plot( - taus, - mean_values_30_min_euc, - marker="o", - label="Cell & Time Aware (30 min interval)", - color="green", -) -plt.fill_between( - taus, - np.array(mean_values_30_min_euc) - np.array(std_values_30_min_euc), - np.array(mean_values_30_min_euc) + np.array(std_values_30_min_euc), - color="green", - alpha=0.3, - label="Std Dev (30 min interval)", -) - -mean_values_no_track_euc = list(mean_displacement_no_track_euc.values()) -std_values_no_track_euc = list(std_displacement_no_track_euc.values()) - -plt.plot( - taus, - mean_values_no_track_euc, - marker="o", - label="Classical Contrastive (No Tracking)", - color="blue", -) -plt.fill_between( - taus, - np.array(mean_values_no_track_euc) - np.array(std_values_no_track_euc), - np.array(mean_values_no_track_euc) + np.array(std_values_no_track_euc), - color="blue", - alpha=0.3, - label="Std Dev (No Tracking)", -) - -plt.xlabel("Time Shift (τ)") -plt.ylabel("Euclidean Distance") -plt.title("Embedding Displacement Over Time (Features)") - -plt.grid(True) -plt.legend() - -plt.show() diff --git a/applications/contrastive_phenotyping/evaluation/imagenet_pretrained_features.py b/applications/contrastive_phenotyping/evaluation/imagenet/imagenet_pretrained_features.py similarity index 100% rename from applications/contrastive_phenotyping/evaluation/imagenet_pretrained_features.py rename to applications/contrastive_phenotyping/evaluation/imagenet/imagenet_pretrained_features.py diff --git a/applications/contrastive_phenotyping/evaluation/knowledge_distillation.py b/applications/contrastive_phenotyping/evaluation/knowledge_distillation/knowledge_distillation.py similarity index 100% rename from applications/contrastive_phenotyping/evaluation/knowledge_distillation.py rename to applications/contrastive_phenotyping/evaluation/knowledge_distillation/knowledge_distillation.py diff --git a/applications/contrastive_phenotyping/evaluation/knowledge_distillation_teacher.py b/applications/contrastive_phenotyping/evaluation/knowledge_distillation/knowledge_distillation_teacher.py similarity index 100% rename from applications/contrastive_phenotyping/evaluation/knowledge_distillation_teacher.py rename to 
applications/contrastive_phenotyping/evaluation/knowledge_distillation/knowledge_distillation_teacher.py diff --git a/applications/contrastive_phenotyping/evaluation/PC_vs_computed_features.py b/applications/contrastive_phenotyping/evaluation/pc_vs_computed_features/PC_vs_computed_features.py similarity index 100% rename from applications/contrastive_phenotyping/evaluation/PC_vs_computed_features.py rename to applications/contrastive_phenotyping/evaluation/pc_vs_computed_features/PC_vs_computed_features.py diff --git a/applications/contrastive_phenotyping/evaluation/compute_pca_features.py b/applications/contrastive_phenotyping/evaluation/pc_vs_computed_features/compute_pca_features.py similarity index 100% rename from applications/contrastive_phenotyping/evaluation/compute_pca_features.py rename to applications/contrastive_phenotyping/evaluation/pc_vs_computed_features/compute_pca_features.py diff --git a/applications/contrastive_phenotyping/evaluation/rpe1_fucci/linear_classifier.py b/applications/contrastive_phenotyping/evaluation/rpe1_fucci/linear_classifier.py new file mode 100644 index 000000000..89d8bb6fc --- /dev/null +++ b/applications/contrastive_phenotyping/evaluation/rpe1_fucci/linear_classifier.py @@ -0,0 +1,142 @@ +# %% +from pathlib import Path + +import numpy as np +import pandas as pd +from sklearn.linear_model import LogisticRegression +from sklearn.metrics import accuracy_score, classification_report +from sklearn.model_selection import train_test_split + +from viscy.representation.embedding_writer import read_embedding_dataset + +test_data_features_path = Path( + "/hpc/projects/intracellular_dashboard/organelle_dynamics/2025_rpe_fucci_leger_weigert/0-phenotyping/bf_only_timeaware_ntxent_lr2e-5_temp_7e-2_tau1_w_augmentations_2_ckpt306.zarr" +) +cell_cycle_labels_path = "/hpc/projects/organelle_phenotyping/models/rpe_fucci/dynaclr/pseudolabels/cell_cycle_labels_w_mitosis.csv" + +# %% +# Load the data +cell_cycle_labels_df = pd.read_csv(cell_cycle_labels_path, dtype={"dataset_name": str}) +test_embeddings = read_embedding_dataset(test_data_features_path) + +# Extract features (768-dimensional embeddings) +features = test_embeddings.features.values + +# %% +sample_coords = test_embeddings.coords["sample"].values +fov_names = [coord[0] for coord in sample_coords] +ids = [coord[1] for coord in sample_coords] + +# Create DataFrame with embeddings and identifiers +embedding_df = pd.DataFrame( + { + "dataset_name": fov_names, + "timepoint": ids, + } +) + +# Merge with cell cycle labels +merged_data = embedding_df.merge( + cell_cycle_labels_df, on=["dataset_name", "timepoint"], how="inner" +) + +print(f"Original embeddings: {len(embedding_df)}") +print(f"Cell cycle labels: {len(cell_cycle_labels_df)}") +print(f"Merged data: {len(merged_data)}") +print(f"Cell cycle distribution:\n{merged_data['cell_cycle_state'].value_counts()}") + +# Get corresponding features for merged samples +merged_indices = merged_data.index.values +X = features[merged_indices] +y = merged_data["cell_cycle_state"].values + +# %% +# First split: 80% train+val, 20% test +X_train, X_test, y_train, y_test = train_test_split( + X, y, test_size=0.2, random_state=42, stratify=y +) +print(f"Training set: {X_train.shape[0]} samples") +print(f"Test set: {X_test.shape[0]} samples") + +# %% +# Train logistic regression model +clf = LogisticRegression(random_state=42, max_iter=1000) +clf.fit(X_train, y_train) + +y_test_pred = clf.predict(X_test) +test_accuracy = accuracy_score(y_test, y_test_pred) +print(f"Test 
accuracy: {test_accuracy:.4f}") + +print("\nTest set classification report:") +print(classification_report(y_test, y_test_pred)) + +# %% +# Enhanced evaluation and visualization +import matplotlib.pyplot as plt +from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix + +# 1. Confusion Matrix - shows which classes are confused with each other (rows/columns ordered by clf.classes_) +cm = confusion_matrix(y_test, y_test_pred, labels=clf.classes_) +plt.figure(figsize=(8, 6)) +ConfusionMatrixDisplay(cm, display_labels=clf.classes_).plot(cmap="Blues") +plt.title("Confusion Matrix") +plt.show() + +# 2. Per-class errors breakdown +print("\nDetailed per-class analysis:") +for class_name in ["G1", "G2", "S", "M"]: +    mask = y_test == class_name +    correct = (y_test_pred[mask] == class_name).sum() +    total = mask.sum() +    print(f"{class_name}: {correct}/{total} correct ({correct / total:.3f})") + +    # Show what this class was misclassified as +    if total > correct: +        wrong_preds = y_test_pred[mask & (y_test_pred != class_name)] +        unique, counts = np.unique(wrong_preds, return_counts=True) +        print(f"  Misclassified as: {dict(zip(unique, counts))}") + +# 3. Prediction confidence (probabilities) +y_test_proba = clf.predict_proba(X_test) +class_names = clf.classes_ + +plt.figure(figsize=(12, 4)) +for i, class_name in enumerate(class_names): +    plt.subplot(1, 4, i + 1) +    plt.hist( +        y_test_proba[:, i], +        bins=20, +        alpha=0.7, +        color=["blue", "orange", "green", "red"][i], +    ) +    plt.title(f"Confidence for {class_name}") +    plt.xlabel("Probability") +    plt.ylabel("Count") +plt.tight_layout() +plt.show() + +# 4. Most confident correct and incorrect predictions +print("\nMost confident predictions:") +max_proba = np.max(y_test_proba, axis=1) +pred_correct = y_test == y_test_pred + +# Most confident correct predictions +correct_idx = np.where(pred_correct)[0] +most_confident_correct = correct_idx[np.argsort(max_proba[correct_idx])[-5:]] +print("Top 5 most confident CORRECT predictions:") +for idx in most_confident_correct: +    print( +        f"  True: {y_test[idx]}, Pred: {y_test_pred[idx]}, Confidence: {max_proba[idx]:.3f}" +    ) + +# Most confident incorrect predictions +incorrect_idx = np.where(~pred_correct)[0] +if len(incorrect_idx) > 0: +    most_confident_wrong = incorrect_idx[np.argsort(max_proba[incorrect_idx])[-5:]] +    print("\nTop 5 most confident WRONG predictions:") +    for idx in most_confident_wrong: +        print( +            f"  True: {y_test[idx]}, Pred: {y_test_pred[idx]}, Confidence: {max_proba[idx]:.3f}" +        ) + +# %% diff --git a/applications/contrastive_phenotyping/evaluation/rpe1_fucci/phate_plot.py b/applications/contrastive_phenotyping/evaluation/rpe1_fucci/phate_plot.py new file mode 100644 index 000000000..e4a56ab88 --- /dev/null +++ b/applications/contrastive_phenotyping/evaluation/rpe1_fucci/phate_plot.py @@ -0,0 +1,171 @@ +# %% Imports +from pathlib import Path + +import matplotlib.pyplot as plt +import pandas as pd +import seaborn as sns + +from viscy.representation.embedding_writer import read_embedding_dataset +from viscy.representation.evaluation.dimensionality_reduction import compute_phate + +# %% +test_data_features_path = Path( +    "/hpc/projects/intracellular_dashboard/organelle_dynamics/2025_rpe_fucci_leger_weigert/0-phenotyping/rpe_fucci_test_data_ckpt264.zarr" +) +test_drugs_path = Path( +    "/hpc/projects/intracellular_dashboard/organelle_dynamics/2025_rpe_fucci_leger_weigert/0-phenotyping/rpe_fucci_test_drugs_ckpt264.zarr" +) +cell_cycle_labels_path = "/hpc/projects/organelle_phenotyping/models/rpe_fucci/pseudolabels/cell_cycle_labels.csv" + +# %% 
Load embeddings and annotations. + +test_features = read_embedding_dataset(test_data_features_path) +# test_drugs = read_embedding_dataset(test_drugs_path) + +# Load cell cycle labels +cell_cycle_labels_df = pd.read_csv(cell_cycle_labels_path, dtype={"dataset_name": str}) + +# Create a combined identifier for matching +sample_coords = test_features.coords["sample"].values +fov_names = [coord[0] for coord in sample_coords] +ids = [coord[1] for coord in sample_coords] + +# Create DataFrame with embeddings and identifiers +embedding_df = pd.DataFrame( + { + "dataset_name": fov_names, + "timepoint": ids, + } +) + +# Merge with cell cycle labels +merged_data = embedding_df.merge( + cell_cycle_labels_df, on=["dataset_name", "timepoint"], how="inner" +) + +print(f"Original embeddings: {len(embedding_df)}") +print(f"Cell cycle labels: {len(cell_cycle_labels_df)}") +print(f"Merged data: {len(merged_data)}") +print(f"Cell cycle distribution:\n{merged_data['cell_cycle_state'].value_counts()}") + +# Get corresponding features for merged samples +merged_indices = merged_data.index.values +cell_cycle_states = merged_data["cell_cycle_state"].values + +# %% +# compute phate +phate_kwargs = { + "knn": 10, + "decay": 20, + "n_components": 2, + "gamma": 1, + "t": "auto", + "n_jobs": -1, +} + +phate_model, phate_embedding = compute_phate(test_features, **phate_kwargs) +# %% + +# Define colorblind-friendly palette for cell cycle states (blue/orange as requested) +cycle_colors = {"G1": "#1f77b4", "G2": "#ff7f0e", "S": "#9467bd"} + +plt.figure(figsize=(10, 10)) +sns.scatterplot( + x=phate_embedding[merged_indices, 0], + y=phate_embedding[merged_indices, 1], + hue=cell_cycle_states, + palette=cycle_colors, + alpha=0.6, +) +plt.title("PHATE Embedding Colored by Cell Cycle State") +plt.legend(bbox_to_anchor=(1.05, 1), loc="upper left") + + +# %% +# Plot the PHATE embedding from the xarray + +plt.figure(figsize=(10, 10)) +sns.scatterplot( + x=test_features["PHATE1"][merged_indices], + y=test_features["PHATE2"][merged_indices], + hue=cell_cycle_states, + palette=cycle_colors, + alpha=0.6, +) +plt.title("PHATE1 vs PHATE2 Colored by Cell Cycle State") +plt.legend(bbox_to_anchor=(1.05, 1), loc="upper left") +# %% +# plot the 3D PHATE embedding (Note: seaborn scatterplot doesn't support 3D, using matplotlib) + +fig = plt.figure(figsize=(10, 10)) +ax = fig.add_subplot(111, projection="3d") + +for state in ["G1", "G2", "S"]: + mask = cell_cycle_states == state + ax.scatter( + test_features["PHATE1"][merged_indices][mask], + test_features["PHATE2"][merged_indices][mask], + test_features["PHATE3"][merged_indices][mask], + c=cycle_colors[state], + alpha=0.6, + label=state, + ) + +ax.set_xlabel("PHATE1") +ax.set_ylabel("PHATE2") +ax.set_zlabel("PHATE3") +ax.set_title("3D PHATE Embedding Colored by Cell Cycle State") +ax.legend() + +# %% +# Plot the PHATE embedding from test_drugs (commented out since not loaded) +# plt.figure(figsize=(10, 10)) +# sns.scatterplot( +# x=test_drugs["PHATE1"], +# y=test_drugs["PHATE2"], +# # hue=test_drugs["t"], +# alpha=0.5, +# ) +# plt.legend(bbox_to_anchor=(1.05, 1), loc="upper left") +# %% +fig, axes = plt.subplots(1, 3, figsize=(18, 6)) + +# PHATE1 vs PHATE2 +sns.scatterplot( + x=test_features["PHATE1"][merged_indices], + y=test_features["PHATE2"][merged_indices], + hue=cell_cycle_states, + palette=cycle_colors, + alpha=0.6, + ax=axes[0], +) +axes[0].set_title("PHATE1 vs PHATE2") +axes[0].legend(bbox_to_anchor=(1.05, 1), loc="upper left") + +# PHATE1 vs PHATE3 +sns.scatterplot( + 
x=test_features["PHATE1"][merged_indices], + y=test_features["PHATE3"][merged_indices], + hue=cell_cycle_states, + palette=cycle_colors, + alpha=0.6, + ax=axes[1], +) +axes[1].set_title("PHATE1 vs PHATE3") +axes[1].legend(bbox_to_anchor=(1.05, 1), loc="upper left") + +# PHATE2 vs PHATE3 +sns.scatterplot( + x=test_features["PHATE2"][merged_indices], + y=test_features["PHATE3"][merged_indices], + hue=cell_cycle_states, + palette=cycle_colors, + alpha=0.6, + ax=axes[2], +) +axes[2].set_title("PHATE2 vs PHATE3") +axes[2].legend(bbox_to_anchor=(1.05, 1), loc="upper left") + +plt.tight_layout() +plt.show() +# %% diff --git a/applications/contrastive_phenotyping/evaluation/smoothness/compute_smoothness.py b/applications/contrastive_phenotyping/evaluation/smoothness/compute_smoothness.py new file mode 100644 index 000000000..8047c3d63 --- /dev/null +++ b/applications/contrastive_phenotyping/evaluation/smoothness/compute_smoothness.py @@ -0,0 +1,103 @@ +# %% +from pathlib import Path + +import matplotlib.pyplot as plt +import pandas as pd +import seaborn as sns + +from viscy.representation.embedding_writer import read_embedding_dataset +from viscy.representation.evaluation.smoothness import compute_embeddings_smoothness + +# %% +# FEATURES + +# openphenom_features_path = Path("/home/jason/projects/contrastive_phenotyping/data/open_phenom/features/open_phenom_features.csv") +# imagenet_features_path = Path("/home/jason/projects/contrastive_phenotyping/data/imagenet/features/imagenet_features.csv") +dynaclr_features_path = Path( + "/hpc/projects/intracellular_dashboard/organelle_dynamics/rerun/2024_11_07_A549_SEC61_DENV/4-phenotyping/dtw_evaluation/SAM2/sam2_sensor_only.zarr" +) +dinov3_features_path = Path( + "/home/eduardo.hirata/repos/viscy/applications/benchmarking/DynaCLR/DINOV3/embeddings_convnext_tiny_phase_only_2.zarr" +) + +# LOADING DATASETS +# openphenom_features = read_embedding_dataset(openphenom_features_path) +# imagenet_features = read_embedding_dataset(imagenet_features_path) +dynaclr_embedding_dataset = read_embedding_dataset(dynaclr_features_path) +dinov3_embedding_dataset = read_embedding_dataset(dinov3_features_path) +# %% +# Compute the smoothness of the features +DISTANCE_METRIC = "cosine" +feature_paths = { + # "dynaclr": dynaclr_features_path, + "dinov3": dinov3_features_path, +} +cmap = plt.get_cmap("tab10") # or use "Set2", "tab20", etc. 
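# --- Illustrative sketch (not part of the script above or of the viscy API) ---
# The smoothness statistics consumed below compare distances between adjacent
# frames of the same track against distances between randomly paired frames.
# The real definitions live in viscy.representation.evaluation.smoothness; the
# helper here only shows one plausible way such a score could be derived from
# the two distance distributions, and its exact formula is an assumption.
import numpy as np


def _sketch_smoothness(adjacent: np.ndarray, random_pairs: np.ndarray) -> dict:
    """Hypothetical summary of two 1D arrays of pairwise feature distances."""
    adj_peak = float(np.median(adjacent))
    rand_peak = float(np.median(random_pairs))
    return {
        # fraction by which adjacent frames are closer than random pairs
        # (higher would mean smoother trajectories in feature space)
        "smoothness_score": (rand_peak - adj_peak) / max(rand_peak, 1e-12),
        "dynamic_range": rand_peak - adj_peak,
    }
# --- end of sketch ---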
+labels = list(feature_paths.keys()) +interval_colors = {label: cmap(i % cmap.N) for i, label in enumerate(labels)} +# Print and check each path +for label, path in feature_paths.items(): + print(f"{label} color: {interval_colors[label]}") + assert Path(path).exists(), f"Path {path} does not exist" + +output_dir = Path("./smoothness_metrics") +output_dir.mkdir(parents=True, exist_ok=True) + +results = {} +for label, path in feature_paths.items(): + results[label] = {} + print(f"\nProcessing - {label}") + embedding_dataset = read_embedding_dataset(Path(path)) + + # Compute displacements + stats, distributions, _ = compute_embeddings_smoothness( + embedding_dataset=embedding_dataset, + distance_metric=DISTANCE_METRIC, + verbose=True, + ) + + # Plot the piecewise distances + plt.figure() + sns.histplot( + distributions["adjacent_frame_distribution"], + bins=30, + kde=True, + color="cyan", + alpha=0.5, + stat="density", + ) + sns.histplot( + distributions["random_frame_distribution"], + bins=30, + kde=True, + color="red", + alpha=0.5, + stat="density", + ) + plt.xlabel(f"{DISTANCE_METRIC} Distance") + plt.ylabel("Density") + # Add vertical lines for the peaks + plt.axvline(x=stats["adjacent_frame_peak"], color="cyan", linestyle="--", alpha=0.8) + plt.axvline(x=stats["random_frame_peak"], color="red", linestyle="--", alpha=0.8) + plt.tight_layout() + plt.legend(["Adjacent Frame", "Random Sample", "Adjacent Peak", "Random Peak"]) + plt.savefig(output_dir / f"{label}_smoothness.pdf", dpi=300) + plt.savefig(output_dir / f"{label}_smoothness.png", dpi=300) + plt.close() + + # metrics to csv + scalar_metrics = { + "adjacent_frame_mean": stats["adjacent_frame_mean"], + "adjacent_frame_std": stats["adjacent_frame_std"], + "adjacent_frame_median": stats["adjacent_frame_median"], + "adjacent_frame_peak": stats["adjacent_frame_peak"], + "random_frame_mean": stats["random_frame_mean"], + "random_frame_std": stats["random_frame_std"], + "random_frame_median": stats["random_frame_median"], + "random_frame_peak": stats["random_frame_peak"], + "smoothness_score": stats["smoothness_score"], + "dynamic_range": stats["dynamic_range"], + } + # Create DataFrame with single row + stats_df = pd.DataFrame(scalar_metrics, index=[0]) + stats_df.to_csv(output_dir / f"{label}_smoothness_stats.csv", index=False) diff --git a/applications/contrastive_phenotyping/figures/grad_attr.py b/applications/contrastive_phenotyping/figures/grad_attr.py index 8dfca61dd..038cc5c96 100644 --- a/applications/contrastive_phenotyping/figures/grad_attr.py +++ b/applications/contrastive_phenotyping/figures/grad_attr.py @@ -1,11 +1,15 @@ # %% +import logging +import warnings from pathlib import Path import matplotlib as mpl +import matplotlib.animation as animation import matplotlib.pyplot as plt import numpy as np import pandas as pd import torch +import xarray as xr from cmap import Colormap from lightning.pytorch import seed_everything from skimage.exposure import rescale_intensity @@ -19,21 +23,34 @@ fit_logistic_regression, linear_from_binary_logistic_regression, ) -from viscy.transforms import NormalizeSampled, ScaleIntensityRangePercentilesd +from viscy.transforms import ( + Decollated, + NormalizeSampled, + ScaleIntensityRangePercentilesd, +) -# %% seed_everything(42, workers=True) +logger = logging.getLogger(__name__) +logger.setLevel(logging.DEBUG) + + +# %% +# Dataset for display and occlusion analysis +data_path = 
"/hpc/projects/intracellular_dashboard/organelle_dynamics/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/registered_test.zarr" +tracks_path = "/hpc/projects/intracellular_dashboard/organelle_dynamics/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/track_test.zarr" +annotation_occlusion_infection_path = "/hpc/projects/intracellular_dashboard/organelle_dynamics/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/supervised_inf_pred/extracted_inf_state.csv" +annotation_occlusion_division_path = "/hpc/projects/intracellular_dashboard/organelle_dynamics/2024_02_04_A549_DENV_ZIKV_timelapse/9-lineage-cell-division/lineages_gt/cell_division_state_test_set.csv" fov = "/B/4/8" -track = 44 +track = [44, 46] # %% dm = TripletDataModule( - data_path="/hpc/projects/intracellular_dashboard/organelle_dynamics/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/registered_test.zarr", - tracks_path="/hpc/projects/intracellular_dashboard/organelle_dynamics/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/track_test.zarr", + data_path=data_path, + tracks_path=tracks_path, source_channel=["Phase3D", "RFP"], z_range=[25, 40], - batch_size=48, + batch_size=1, num_workers=0, initial_yx_patch_size=(128, 128), final_yx_patch_size=(128, 128), @@ -44,10 +61,13 @@ ScaleIntensityRangePercentilesd( keys=["RFP"], lower=50, upper=99, b_min=0.0, b_max=1.0 ), + Decollated( + keys=["Phase3D", "RFP"], + ), ], predict_cells=True, - include_fov_names=[fov], - include_track_ids=[track], + include_fov_names=[fov] * len(track), + include_track_ids=track, ) dm.setup("predict") len(dm.predict_dataset) @@ -67,6 +87,253 @@ ), ).eval() + +# %% +def load_and_combine_datasets( + datasets, + target_type="infection", + standardization_mapping=None, +): + """Load and combine multiple embedding datasets with their annotations. + + Parameters + ---------- + datasets : list of tuple + List of (embedding_path, annotation_path, train_fovs) tuples containing + paths to embedding files, annotation CSV files, and training FOVs. + target_type : str, default='infection' + Type of classification target. Either 'infection' or 'division' - determines + which column to look for in the annotation files. + standardization_mapping : dict, optional + Dictionary to standardize different annotation formats across datasets. + Maps original values to standardized values. + Example: {'infected': 2, 'uninfected': 1, 'background': 0, + 2.0: 2, 1.0: 1, 0.0: 0, 'mitosis': 2, 'interphase': 1, 'unknown': 0} + + Returns + ------- + combined_features : xarray.DataArray + Combined feature embeddings from all successfully loaded datasets. + combined_annotations : pandas.Series + Combined and standardized annotations from all datasets. + + Raises + ------ + ValueError + If no datasets were successfully loaded. 
+ """ + + all_features = [] + all_annotations = [] + + # Default standardization mappings + if standardization_mapping is None: + if target_type == "infection": + standardization_mapping = { + # String formats + "infected": 2, + "uninfected": 1, + "background": 0, + "unknown": 0, + # Numeric formats + 2.0: 2, + 1.0: 1, + 0.0: 0, + 2: 2, + } + elif target_type == "division": + standardization_mapping = { + # String formats + "mitosis": 2, + "interphase": 1, + "unknown": 0, + # Numeric formats + 2.0: 2, + 1.0: 1, + 0.0: 0, + 2: 2, + } + + for emb_path, ann_path, train_fovs in datasets: + try: + logger.debug(f"Loading dataset: {emb_path}") + dataset = read_embedding_dataset(emb_path) + + # Read annotation CSV to detect column names + logger.debug(f"Reading annotation CSV: {ann_path}") + ann_df = pd.read_csv(ann_path) + # make sure the ann_fov_names start with '/' otherwise add it, and strip whitespace + ann_df["fov_name"] = ann_df["fov_name"].apply( + lambda x: ( + "/" + x.strip() if not x.strip().startswith("/") else x.strip() + ) + ) + + if train_fovs == "all": + train_fovs = np.unique(dataset["fov_name"]) + + # Auto-detect annotation column based on target_type + annotation_key = None + if target_type == "infection": + for col in [ + "infection_state", + "infection", + "infection_status", + ]: + if col in ann_df.columns: + annotation_key = col + break + + elif target_type == "division": + for col in ["division", "cell_division", "cell_state"]: + if col in ann_df.columns: + annotation_key = col + break + + if annotation_key is None: + print(f" No {target_type} column found, skipping...") + continue + + # Filter the dataset to only include the FOVs in the annotation + # Use xarray's native filtering methods + ann_fov_names = set(ann_df["fov_name"].unique()) + train_fovs = set(train_fovs) + + logger.debug(f"Dataset FOVs: {dataset['fov_name'].values}") + logger.debug(f"Annotation FOV names: {ann_fov_names}") + logger.debug(f"Train FOVs: {train_fovs}") + logger.debug(f"Dataset samples before filtering: {len(dataset.sample)}") + + # Filter and get only the intersection of train_fovs and ann_fov_names + common_fovs = train_fovs.intersection(ann_fov_names) + # missed out fovs in the dataset + missed_fovs = train_fovs - common_fovs + # missed out fovs in the annotations + missed_fovs_ann = ann_fov_names - common_fovs + + if len(common_fovs) == 0: + raise ValueError( + f"No common FOVs found between dataset and annotations: {train_fovs} not in {ann_fov_names}" + ) + elif len(missed_fovs) > 0: + warnings.warn( + f"No matching found for FOVs in the train dataset: {missed_fovs}" + ) + elif len(missed_fovs_ann) > 0: + warnings.warn( + f"No matching found for FOVs in the annotations: {missed_fovs_ann}" + ) + + logger.debug(f"Intersection of train_fovs and ann_fov_names: {common_fovs}") + + # Filter the dataset to only include the intersection of train_fovs and ann_fov_names + dataset = dataset.where( + dataset["fov_name"].isin(list(common_fovs)), drop=True + ) + + logger.debug(f"Dataset samples after filtering: {len(dataset.sample)}") + + # Load annotations without class mapping first + annotations = load_annotation(dataset, ann_path, annotation_key) + + # Check unique values before standardization + unique_vals = annotations.unique() + logger.debug(f"Original unique values: {unique_vals}") + + # Apply standardization mapping + standardized_annotations = annotations.copy() + if standardization_mapping: + for original_val, standard_val in standardization_mapping.items(): + mask = annotations == 
original_val + if mask.any(): + standardized_annotations[mask] = standard_val + logger.debug( + f"Mapped {original_val} -> {standard_val} ({mask.sum()} instances)" + ) + + # Check standardized values + std_unique_vals = standardized_annotations.unique() + logger.debug(f"Standardized unique values: {std_unique_vals}") + + # Convert to categorical for consistency + standardized_annotations = standardized_annotations.astype("category") + + # Keep features as xarray DataArray for compatibility with fit_logistic_regression + all_features.append(dataset["features"]) + all_annotations.append(standardized_annotations) + + logger.debug(f"Features shape: {dataset['features'].shape}") + logger.debug(f"Annotations shape: {standardized_annotations.shape}") + except Exception as e: + raise ValueError(f"Error loading dataset {emb_path}: {e}") + + # Combine all datasets + if all_features: + # Extract features and coordinates from each dataset + all_features_arrays = [] + all_coords = [] + + for dataset in all_features: + # Extract the features array + features_array = dataset["features"].values + all_features_arrays.append(features_array) + + # Extract coordinates + coords_dict = {} + for coord_name in dataset.coords: + if coord_name != "sample": # skip sample coordinate + coords_dict[coord_name] = dataset.coords[coord_name].values + all_coords.append(coords_dict) + + # Combine feature arrays + combined_features_array = np.concatenate(all_features_arrays, axis=0) + + # Combine coordinates (excluding 'features' from coordinates) + combined_coords = {} + for coord_name in all_coords[0].keys(): + if coord_name != "features": # Don't include 'features' in coordinates + coord_values = [] + for coords_dict in all_coords: + coord_values.extend(coords_dict[coord_name]) + combined_coords[coord_name] = coord_values + + # Create new combined dataset in the correct format + coords_dict = { + "sample": range(len(combined_features_array)), + } + + # Add each coordinate as a 1D coordinate along the sample dimension + for coord_name, coord_values in combined_coords.items(): + coords_dict[coord_name] = ("sample", coord_values) + + combined_dataset = xr.Dataset( + { + "features": (("sample", "features"), combined_features_array), + }, + coords=coords_dict, + ) + + # Set the index properly like the original datasets + if "fov_name" in combined_coords: + available_coords = [ + coord + for coord in combined_coords.keys() + if coord in ["fov_name", "track_id", "t"] + ] + combined_dataset = combined_dataset.set_index(sample=available_coords) + + combined_annotations = pd.concat(all_annotations, ignore_index=True) + + logger.debug(f"Combined features shape: {combined_dataset['features'].shape}") + logger.debug(f"Combined annotations shape: {combined_annotations.shape}") + + # Final check of combined annotations + final_unique = combined_annotations.unique() + logger.debug(f"Final combined unique values: {final_unique}") + + return combined_dataset["features"], combined_annotations + + # %% # train linear classifier path_infection_embedding = Path( @@ -149,7 +416,7 @@ ) track_classes_infection = infection[infection["fov_name"] == fov[1:]] track_classes_infection = track_classes_infection[ - track_classes_infection["track_id"] == track + track_classes_infection["track_id"].isin(track) ]["infection_state"] # %% @@ -159,13 +426,17 @@ ) track_classes_division = division[division["fov_name"] == fov[1:]] track_classes_division = track_classes_division[ - track_classes_division["track_id"] == track + 
track_classes_division["track_id"].isin(track) ]["division"] # %% +# Loading the lineage images +img = [] for sample in dm.predict_dataloader(): - img = sample["anchor"].numpy() + img.append(sample["anchor"].numpy()) +img = np.concatenate(img, axis=0) +print(f"Loaded images with shape: {img.shape}") # %% img_tensor = torch.from_numpy(img).to(model.device) @@ -217,18 +488,17 @@ def clim_percentile(heatmap, low=1, high=99): np.concatenate([phase_heatmap_div, rfp_heatmap_div], axis=2), -g_lim, g_lim ) - # %% plt.style.use("./figure.mplstyle") selected_time_points = [3, 6, 15, 16] selected_div_states = [False] * 3 + [True] -sps = len(selected_time_points) - icefire = Colormap("icefire").to_mpl() -f, ax = plt.subplots(3, sps, figsize=(5.5, 3), layout="compressed") +f, ax = plt.subplots( + 3, len(selected_time_points), figsize=(5.5, 3), layout="compressed" +) for i, time in enumerate(selected_time_points): hpi = 3 + 0.5 * time prob = infection_probs[time].item() @@ -264,3 +534,111 @@ def clim_percentile(heatmap, low=1, high=99): ) # %% +# Create video animation of occlusion analysis +icefire = Colormap("icefire").to_mpl() +plt.style.use("./figure.mplstyle") + +fig, ax = plt.subplots(3, 1, figsize=(6, 8), layout="compressed") + +# Initialize plots +im1 = ax[0].imshow(img_render[0], cmap="gray") +ax[0].set_title("Original Image") +ax[0].axis("off") + +im2 = ax[1].imshow(inf_render[0], cmap=icefire, vmin=0, vmax=1) +ax[1].set_title("Infection Occlusion Attribution") +ax[1].axis("off") + +im3 = ax[2].imshow(div_render[0], cmap=icefire, vmin=0, vmax=1) +ax[2].set_title("Division Occlusion Attribution") +ax[2].axis("off") + +# Store initial border colors +for a in ax: + for spine in a.spines.values(): + spine.set_linewidth(3) + spine.set_color("black") + +# Add colorbar +norm = mpl.colors.Normalize(vmin=-g_lim, vmax=g_lim) +cbar = fig.colorbar( + mpl.cm.ScalarMappable(norm=norm, cmap=icefire), + ax=ax[1:], + orientation="horizontal", + shrink=0.8, + pad=0.1, +) +cbar.set_label("Occlusion Attribution") + + +# Animation function +def animate(frame): + time = frame + hpi = 3 + 0.5 * time + + # Update images + im1.set_array(img_render[time]) + im2.set_array(inf_render[time]) + im3.set_array(div_render[time]) + + # Update titles with probabilities + inf_prob = infection_probs[time].item() + div_prob = division_probs[time].item() + inf_binary = bool(track_classes_infection.iloc[time] - 1) + div_binary = bool(track_classes_division.iloc[time] - 1) + + # Color code labels - red for true, green for false + inf_color = "darkorange" if inf_binary else "blue" + div_color = "darkorange" if div_binary else "blue" + + # Make label text bold when true + inf_weight = "bold" if inf_binary else "normal" + div_weight = "bold" if div_binary else "normal" + + # Update border colors to highlight true labels + for spine in ax[1].spines.values(): + spine.set_color(inf_color) + spine.set_linewidth(4 if inf_binary else 2) + + for spine in ax[2].spines.values(): + spine.set_color(div_color) + spine.set_linewidth(4 if div_binary else 2) + + ax[0].set_title(f"Original Image - {hpi:.1f} HPI", fontsize=12, fontweight="bold") + ax[1].set_title( + f"Infection Attribution - Prob: {inf_prob:.3f} (Label: {str(inf_binary).lower()})", + fontsize=12, + fontweight=inf_weight, + color=inf_color, + ) + ax[2].set_title( + f"Division Attribution - Prob: {div_prob:.3f} (Label: {str(div_binary).lower()})", + fontsize=12, + fontweight=div_weight, + color=div_color, + ) + + return [im1, im2, im3] + + +# %% + +# Create animation +anim = 
animation.FuncAnimation( + fig, animate, frames=len(img_render), interval=200, blit=True, repeat=True +) + +# Save as video +video_path = ( + Path.home() + / "mydata" + / "gdrive/2025_dynaCLR_paper/fig_manuscript_svg/figure_occlusion_analysis/figure_parts/occlusion_analysis_video.mp4" +) +video_path.parent.mkdir(parents=True, exist_ok=True) + +# Save as MP4 +Writer = animation.writers["ffmpeg"] +writer = Writer(fps=5, metadata=dict(artist="VisCy"), bitrate=1800) +anim.save(str(video_path), writer=writer) + +print(f"Video saved to: {video_path}") diff --git a/applications/pseudotime_analysis/evaluation/compare_dtw_embeddings_sam2.py b/applications/pseudotime_analysis/evaluation/compare_dtw_embeddings_sam2.py new file mode 100644 index 000000000..0381c12c4 --- /dev/null +++ b/applications/pseudotime_analysis/evaluation/compare_dtw_embeddings_sam2.py @@ -0,0 +1,759 @@ +# %% +import ast +import logging +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import seaborn as sns +from plotting_utils import ( + find_pattern_matches, + identify_lineages, + plot_pc_trajectories, +) +from tqdm import tqdm + +from viscy.data.triplet import TripletDataModule +from viscy.representation.embedding_writer import read_embedding_dataset + +logger = logging.getLogger("viscy") +logger.setLevel(logging.INFO) +console_handler = logging.StreamHandler() +console_handler.setLevel(logging.INFO) +formatter = logging.Formatter("%(message)s") # Simplified format +console_handler.setFormatter(formatter) +logger.addHandler(console_handler) + + +NAPARI = True +if NAPARI: + import os + + import napari + + os.environ["DISPLAY"] = ":1" + viewer = napari.Viewer() +# %% +# Organelle and Phate aligned to infection + +input_data_path = Path( + "/hpc/projects/intracellular_dashboard/organelle_dynamics/2024_11_07_A549_SEC61_ZIKV_DENV/2-assemble/2024_11_07_A549_SEC61_DENV.zarr" +) +tracks_path = Path( + "/hpc/projects/intracellular_dashboard/organelle_dynamics/2024_11_07_A549_SEC61_ZIKV_DENV/1-preprocess/label-free/4-track-gt/2024_11_07_A549_SEC61_ZIKV_DENV_2_cropped.zarr" +) +infection_annotations_path = Path( + "/hpc/projects/intracellular_dashboard/organelle_dynamics/rerun/2024_11_07_A549_SEC61_DENV/4-phenotyping/0-annotation/combined_annotations_n_tracks_infection.csv" +) + +pretrain_features_root = Path( + "/hpc/projects/intracellular_dashboard/organelle_dynamics/2024_11_07_A549_SEC61_ZIKV_DENV/4-phenotyping/prediction_pretrained_models" +) +# Phase n organelle +# dynaclr_features_path = "/hpc/projects/intracellular_dashboard/organelle_dynamics/2024_11_07_A549_SEC61_ZIKV_DENV/4-phenotyping/predictions/timeAware_2chan__ntxent_192patch_70ckpt_rev7_GT.zarr" + +# pahe n sensor +dynaclr_features_path = "/hpc/projects/intracellular_dashboard/organelle_dynamics/2024_11_07_A549_SEC61_ZIKV_DENV/3-phenotyping/predictions_infection/2chan_192patch_100ckpt_timeAware_ntxent_GT.zarr" + +output_root = Path( + "/hpc/projects/intracellular_dashboard/organelle_dynamics/2024_11_07_A549_SEC61_ZIKV_DENV/4-phenotyping/figure/SEC61B/model_comparison" +) + + +# Load embeddings +imagenet_features_path = ( + pretrain_features_root / "ImageNet/20241107_sensor_n_phase_imagenet.zarr" +) +openphenom_features_path = ( + pretrain_features_root / "OpenPhenom/20241107_sensor_n_phase_openphenom.zarr" +) + +dynaclr_embeddings = read_embedding_dataset(dynaclr_features_path) +imagenet_embeddings = read_embedding_dataset(imagenet_features_path) +openphenom_embeddings = read_embedding_dataset(openphenom_features_path) 
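# --- Illustrative sketch (not part of the script above) ---
# `find_pattern_matches` and `plot_pc_trajectories`, used below, come from the
# local plotting_utils module that is not shown in this diff. The standalone
# helper here only demonstrates the dynamic time warping idea on two (T, D)
# feature sequences; the path-"skewness" definition (mean deviation of the
# warping path from the diagonal) is an assumption, not necessarily the one
# implemented in plotting_utils.
def _sketch_dtw_alignment(reference: np.ndarray, query: np.ndarray):
    """Return (distance, warp_path, skewness) for two (T, D) feature sequences."""
    from scipy.spatial.distance import cdist

    cost = cdist(reference, query, metric="cosine")
    n, m = cost.shape
    acc = np.full((n + 1, m + 1), np.inf)
    acc[0, 0] = 0.0
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            acc[i, j] = cost[i - 1, j - 1] + min(
                acc[i - 1, j - 1], acc[i - 1, j], acc[i, j - 1]
            )
    # Backtrack the optimal warping path from the end to the start
    path = [(n - 1, m - 1)]
    i, j = n, m
    while (i, j) != (1, 1):
        candidates = {
            (i - 1, j - 1): acc[i - 1, j - 1],
            (i - 1, j): acc[i - 1, j],
            (i, j - 1): acc[i, j - 1],
        }
        i, j = min(candidates, key=candidates.get)
        path.append((i - 1, j - 1))
    path.reverse()
    # 0 = path hugs the diagonal (uniform pacing); larger = more time distortion
    skewness = float(
        np.mean([abs(pi / max(n - 1, 1) - pj / max(m - 1, 1)) for pi, pj in path])
    )
    return float(acc[n, m]), path, skewness
# --- end of sketch ---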
+ +# Load infection annotations +infection_annotations_df = pd.read_csv(infection_annotations_path) +infection_annotations_df["fov_name"] = "/C/2/000001" + +process_embeddings = [ + (dynaclr_embeddings, "dynaclr"), + (imagenet_embeddings, "imagenet"), + (openphenom_embeddings, "openphenom"), +] + + +output_root.mkdir(parents=True, exist_ok=True) +# %% +feature_df = dynaclr_embeddings["sample"].to_dataframe().reset_index(drop=True) + +# Logic to find lineages +lineages = identify_lineages(feature_df) +logger.info(f"Found {len(lineages)} distinct lineages") +filtered_lineages = [] +min_timepoints = 20 +for fov_id, track_ids in lineages: + # Get all rows for this lineage + lineage_rows = feature_df[ + (feature_df["fov_name"] == fov_id) & (feature_df["track_id"].isin(track_ids)) + ] + + # Count the total number of timepoints + total_timepoints = len(lineage_rows) + + # Only keep lineages with at least min_timepoints + if total_timepoints >= min_timepoints: + filtered_lineages.append((fov_id, track_ids)) +logger.info( + f"Found {len(filtered_lineages)} lineages with at least {min_timepoints} timepoints" +) + +# %% +# Aligning condition embeddings to infection +# OPTION 1: Use the infection annotations to find the reference lineage + +# Option 2: from the filtered lineages find one from FOV C/2/000001 +reference_lineage_fov = "/C/2/000001" +for fov_id, track_ids in filtered_lineages: + if reference_lineage_fov == fov_id: + break +reference_lineage_track_id = track_ids +reference_timepoints = [8, 70] # sensor rellocalization and partial remodelling + +# %% +# Dictionary to store alignment results for comparison +alignment_results = {} + +for embeddings, name in process_embeddings: + # Get the reference pattern from the current embedding space + reference_pattern = None + reference_lineage = [] + for fov_id, track_ids in filtered_lineages: + if fov_id == reference_lineage_fov and all( + track_id in track_ids for track_id in reference_lineage_track_id + ): + logger.info( + f"Found reference pattern for {fov_id} {reference_lineage_track_id} using {name} embeddings" + ) + reference_pattern = embeddings.sel( + sample=(fov_id, reference_lineage_track_id) + ).features.values + reference_lineage.append(reference_pattern) + break + if reference_pattern is None: + logger.info(f"Reference pattern not found for {name} embeddings. 
Skipping.") + continue + reference_pattern = np.concatenate(reference_lineage) + reference_pattern = reference_pattern[ + reference_timepoints[0] : reference_timepoints[1] + ] + + # Find all matches to the reference pattern + metric = "cosine" + all_match_positions = find_pattern_matches( + reference_pattern, + filtered_lineages, + embeddings, + window_step_fraction=0.1, + num_candidates=4, + method="bernd_clifford", + save_path=output_root / f"{name}_matching_lineages_{metric}.csv", + metric=metric, + ) + + # Store results for later comparison + alignment_results[name] = all_match_positions + +# Visualize warping paths in PC space instead of raw embedding dimensions +for name, match_positions in alignment_results.items(): + if match_positions is not None and not match_positions.empty: + # Call the new function from plotting_utils + plot_pc_trajectories( + reference_lineage_fov=reference_lineage_fov, + reference_lineage_track_id=reference_lineage_track_id, + reference_timepoints=reference_timepoints, + match_positions=match_positions, + embeddings_dataset=next( + emb for emb, emb_name in process_embeddings if emb_name == name + ), + filtered_lineages=filtered_lineages, + name=name, + save_path=output_root / f"{name}_pc_lineage_alignment.png", + ) + + +# %% +# Compare DTW performance between embedding methods + +# Create a DataFrame to collect the alignment statistics for comparison +match_data = [] +for name, match_positions in alignment_results.items(): + if match_positions is not None and not match_positions.empty: + for i, row in match_positions.head(10).iterrows(): # Take top 10 matches + warping_path = ( + ast.literal_eval(row["warp_path"]) + if isinstance(row["warp_path"], str) + else row["warp_path"] + ) + match_data.append( + { + "model": name, + "match_position": row["start_timepoint"], + "dtw_distance": row["distance"], + "path_skewness": row["skewness"], + "path_length": len(warping_path), + } + ) + +comparison_df = pd.DataFrame(match_data) + +# Create visualizations to compare alignment quality +plt.figure(figsize=(12, 10)) + +# 1. Compare DTW distances +plt.subplot(2, 2, 1) +sns.boxplot(x="model", y="dtw_distance", data=comparison_df) +plt.title("DTW Distance by Model") +plt.ylabel("DTW Distance (lower is better)") + +# 2. Compare path skewness +plt.subplot(2, 2, 2) +sns.boxplot(x="model", y="path_skewness", data=comparison_df) +plt.title("Path Skewness by Model") +plt.ylabel("Skewness (lower is better)") + +# 3. Compare path lengths +plt.subplot(2, 2, 3) +sns.boxplot(x="model", y="path_length", data=comparison_df) +plt.title("Warping Path Length by Model") +plt.ylabel("Path Length") + +# 4. 
Scatterplot of distance vs skewness +plt.subplot(2, 2, 4) +scatter = sns.scatterplot( + x="dtw_distance", y="path_skewness", hue="model", data=comparison_df +) +plt.title("DTW Distance vs Path Skewness") +plt.xlabel("DTW Distance") +plt.ylabel("Path Skewness") +plt.legend(title="Model") + +plt.tight_layout() +plt.savefig(output_root / "dtw_alignment_comparison.png", dpi=300) +plt.close() + +# %% +# Analyze warping path step patterns for better understanding of alignment quality + +# Step pattern analysis +step_pattern_counts = { + name: {"diagonal": 0, "horizontal": 0, "vertical": 0, "total": 0} + for name in alignment_results.keys() +} + +for name, match_positions in alignment_results.items(): + if match_positions is not None and not match_positions.empty: + # Get the top match + top_match = match_positions.iloc[0] + path = ( + ast.literal_eval(top_match["warp_path"]) + if isinstance(top_match["warp_path"], str) + else top_match["warp_path"] + ) + + # Count step types + for i in range(1, len(path)): + prev_i, prev_j = path[i - 1] + curr_i, curr_j = path[i] + + step_i = curr_i - prev_i + step_j = curr_j - prev_j + + if step_i == 1 and step_j == 1: + step_pattern_counts[name]["diagonal"] += 1 + elif step_i == 1 and step_j == 0: + step_pattern_counts[name]["vertical"] += 1 + elif step_i == 0 and step_j == 1: + step_pattern_counts[name]["horizontal"] += 1 + + step_pattern_counts[name]["total"] += 1 + +# Convert to percentages +for name in step_pattern_counts: + total = step_pattern_counts[name]["total"] + if total > 0: + for key in ["diagonal", "horizontal", "vertical"]: + step_pattern_counts[name][key] = ( + step_pattern_counts[name][key] / total + ) * 100 + +# Visualize step pattern distributions +step_df = pd.DataFrame( + { + "model": [name for name in step_pattern_counts.keys() for _ in range(3)], + "step_type": ["diagonal", "horizontal", "vertical"] * len(step_pattern_counts), + "percentage": [ + step_pattern_counts[name]["diagonal"] for name in step_pattern_counts.keys() + ] + + [ + step_pattern_counts[name]["horizontal"] + for name in step_pattern_counts.keys() + ] + + [ + step_pattern_counts[name]["vertical"] for name in step_pattern_counts.keys() + ], + } +) + +plt.figure(figsize=(10, 6)) +sns.barplot(x="model", y="percentage", hue="step_type", data=step_df) +plt.title("Step Pattern Distribution in Warping Paths") +plt.ylabel("Percentage (%)") +plt.savefig(output_root / "step_pattern_distribution.png", dpi=300) +plt.close() + +# %% +# Find all matches to the reference pattern +MODEL = "openphenom" +alignment_df_path = output_root / f"{MODEL}_matching_lineages_cosine.csv" +alignment_df = pd.read_csv(alignment_df_path) + +# Get the top N aligned cells + +source_channels = [ + "Phase3D", + "raw GFP EX488 EM525-45", + "raw mCherry EX561 EM600-37", +] +yx_patch_size = (192, 192) +z_range = (10, 30) +view_ref_sector_only = (True,) + +all_lineage_images = [] +all_aligned_stacks = [] +all_unaligned_stacks = [] + +# Get aligned and unaligned stacks +top_aligned_cells = alignment_df.head(5) +napari_viewer = viewer if NAPARI else None +# Plot the aligned and unaligned stacks +for idx, row in tqdm( + top_aligned_cells.iterrows(), + total=len(top_aligned_cells), + desc="Aligning images", +): + fov_name = row["fov_name"] + track_ids = ast.literal_eval(row["track_ids"]) + warp_path = ast.literal_eval(row["warp_path"]) + start_time = int(row["start_timepoint"]) + + print(f"Aligning images for {fov_name} with track ids: {track_ids}") + data_module = TripletDataModule( + data_path=input_data_path, + 
tracks_path=tracks_path, + source_channel=source_channels, + z_range=z_range, + initial_yx_patch_size=yx_patch_size, + final_yx_patch_size=yx_patch_size, + batch_size=1, + num_workers=12, + predict_cells=True, + include_fov_names=[fov_name] * len(track_ids), + include_track_ids=track_ids, + ) + data_module.setup("predict") + + # Get the images for the lineage + lineage_images = [] + for batch in data_module.predict_dataloader(): + image = batch["anchor"].numpy()[0] + lineage_images.append(image) + + lineage_images = np.array(lineage_images) + all_lineage_images.append(lineage_images) + print(f"Lineage images shape: {np.array(lineage_images).shape}") + + # Create an aligned stack based on the warping path + if view_ref_sector_only: + aligned_stack = np.zeros( + (len(reference_pattern),) + lineage_images.shape[-4:], + dtype=lineage_images.dtype, + ) + unaligned_stack = np.zeros( + (len(reference_pattern),) + lineage_images.shape[-4:], + dtype=lineage_images.dtype, + ) + + # Map each reference timepoint to the corresponding lineage timepoint + for ref_idx in range(len(reference_pattern)): + # Find matches in warping path for this reference index + matches = [(i, q) for i, q in warp_path if i == ref_idx] + unaligned_stack[ref_idx] = lineage_images[ref_idx] + if matches: + # Get the corresponding lineage timepoint (first match if multiple) + print(f"Found match for ref idx: {ref_idx}") + match = matches[0] + query_idx = match[1] + lineage_idx = int(start_time + query_idx) + print( + f"Lineage index: {lineage_idx}, start time: {start_time}, query idx: {query_idx}, ref idx: {ref_idx}" + ) + # Copy the image if it's within bounds + if 0 <= lineage_idx < len(lineage_images): + aligned_stack[ref_idx] = lineage_images[lineage_idx] + else: + # Find nearest valid timepoint if out of bounds + nearest_idx = min(max(0, lineage_idx), len(lineage_images) - 1) + aligned_stack[ref_idx] = lineage_images[nearest_idx] + else: + # If no direct match, find closest reference timepoint in warping path + print(f"No match found for ref idx: {ref_idx}") + all_ref_indices = [i for i, _ in warp_path] + if all_ref_indices: + closest_ref_idx = min( + all_ref_indices, key=lambda x: abs(x - ref_idx) + ) + closest_matches = [ + (i, q) for i, q in warp_path if i == closest_ref_idx + ] + + if closest_matches: + closest_query_idx = closest_matches[0][1] + lineage_idx = int(start_time + closest_query_idx) + + if 0 <= lineage_idx < len(lineage_images): + aligned_stack[ref_idx] = lineage_images[lineage_idx] + else: + # Bound to valid range + nearest_idx = min( + max(0, lineage_idx), len(lineage_images) - 1 + ) + aligned_stack[ref_idx] = lineage_images[nearest_idx] + + all_aligned_stacks.append(aligned_stack) + all_unaligned_stacks.append(unaligned_stack) + +all_aligned_stacks = np.array(all_aligned_stacks) +all_unaligned_stacks = np.array(all_unaligned_stacks) +# %% +if NAPARI: + for idx, row in tqdm( + top_aligned_cells.reset_index().iterrows(), + total=len(top_aligned_cells), + desc="Plotting aligned and unaligned stacks", + ): + fov_name = row["fov_name"] + # track_ids = ast.literal_eval(row["track_ids"]) + track_ids = row["track_ids"] + + aligned_stack = all_aligned_stacks[idx] + unaligned_stack = all_unaligned_stacks[idx] + + unaligned_gfp_mip = np.max(unaligned_stack[:, 1, :, :], axis=1) + aligned_gfp_mip = np.max(aligned_stack[:, 1, :, :], axis=1) + unaligned_mcherry_mip = np.max(unaligned_stack[:, 2, :, :], axis=1) + aligned_mcherry_mip = np.max(aligned_stack[:, 2, :, :], axis=1) + + z_slice = 15 + unaligned_phase = 
unaligned_stack[:, 0, z_slice, :] + aligned_phase = aligned_stack[:, 0, z_slice, :] + + # unaligned + viewer.add_image( + unaligned_gfp_mip, + name=f"unaligned_gfp_{fov_name}_{track_ids[0]}", + colormap="green", + contrast_limits=(106, 215), + ) + viewer.add_image( + unaligned_mcherry_mip, + name=f"unaligned_mcherry_{fov_name}_{track_ids[0]}", + colormap="magenta", + contrast_limits=(106, 190), + ) + viewer.add_image( + unaligned_phase, + name=f"unaligned_phase_{fov_name}_{track_ids[0]}", + colormap="gray", + contrast_limits=(-0.74, 0.4), + ) + # aligned + viewer.add_image( + aligned_gfp_mip, + name=f"aligned_gfp_{fov_name}_{track_ids[0]}", + colormap="green", + contrast_limits=(106, 215), + ) + viewer.add_image( + aligned_mcherry_mip, + name=f"aligned_mcherry_{fov_name}_{track_ids[0]}", + colormap="magenta", + contrast_limits=(106, 190), + ) + viewer.add_image( + aligned_phase, + name=f"aligned_phase_{fov_name}_{track_ids[0]}", + colormap="gray", + contrast_limits=(-0.74, 0.4), + ) + viewer.grid.enabled = True + viewer.grid.shape = (-1, 6) +# %% +# Evaluate model performance based on infection state warping accuracy +# Check unique infection status values +unique_infection_statuses = infection_annotations_df["infection_status"].unique() +logger.info(f"Unique infection status values: {unique_infection_statuses}") + +# If "infected" is not in the unique values, this could explain zero precision/recall +if "infected" not in unique_infection_statuses: + logger.warning('The label "infected" is not found in the infection_status column!') + logger.info(f"Using these values instead: {unique_infection_statuses}") + + # If we need to map values, we could do it here + if len(unique_infection_statuses) >= 2: + logger.info( + f'Will treat "{unique_infection_statuses[1]}" as "infected" for metrics calculation' + ) + infection_target_value = unique_infection_statuses[1] + else: + infection_target_value = unique_infection_statuses[0] +else: + infection_target_value = "infected" + +logger.info(f'Using "{infection_target_value}" as positive class for F1 calculation') + +# Check if the reference track is in the annotations +logger.info( + f"Looking for infection annotations for reference lineage: {reference_lineage_fov}, tracks: {reference_lineage_track_id}" +) +print(f"Sample of infection_annotations_df: {infection_annotations_df.head()}") + +reference_infection_states = {} +for track_id in reference_lineage_track_id: + reference_annotations = infection_annotations_df[ + (infection_annotations_df["fov_name"] == reference_lineage_fov) + & (infection_annotations_df["track_id"] == track_id) + ] + + # Add annotations for this reference track + annotation_count = len(reference_annotations) + logger.info(f"Found {annotation_count} annotations for track {track_id}") + if annotation_count > 0: + print( + f"Sample annotations for track {track_id}: {reference_annotations.head()}" + ) + + for _, row in reference_annotations.iterrows(): + reference_infection_states[row["t"]] = row["infection_status"] + +if reference_infection_states: + logger.info( + f"Total reference timepoints with infection status: {len(reference_infection_states)}" + ) + reference_t_range = range(reference_timepoints[0], reference_timepoints[1]) + reference_gt_states = [ + reference_infection_states.get(t, "unknown") for t in reference_t_range + ] + logger.info(f"Reference track infection states: {reference_gt_states[:5]}...") + + # Evaluate warping accuracy for each model + model_performance = [] + + for name, match_positions in 
alignment_results.items(): + if match_positions is not None and not match_positions.empty: + total_correct = 0 + total_predictions = 0 + true_positives = 0 + false_positives = 0 + false_negatives = 0 + + # Analyze top alignments for this model + alignment_details = [] + for i, row in match_positions.head(10).iterrows(): + fov_name = row["fov_name"] + track_ids = row[ + "track_ids" + ] # This is already a list of track IDs for the lineage + warp_path = ( + ast.literal_eval(row["warp_path"]) + if isinstance(row["warp_path"], str) + else row["warp_path"] + ) + start_time = int(row["start_timepoint"]) + + # Get annotations for all tracks in this lineage + track_infection_states = {} + for track_id in track_ids: + track_annotations = infection_annotations_df[ + (infection_annotations_df["fov_name"] == fov_name) + & (infection_annotations_df["track_id"] == track_id) + ] + + # Add annotations for this track to the combined dictionary + for _, annotation_row in track_annotations.iterrows(): + # Use t + track-specific offset if needed to handle timepoint overlaps between tracks + track_infection_states[annotation_row["t"]] = annotation_row[ + "infection_status" + ] + + # Only proceed if we found annotations for at least one track + if track_infection_states: + # For each reference timepoint, check if the warped timepoint maintains the infection state + track_correct = 0 + track_predictions = 0 + track_tp = 0 + track_fp = 0 + track_fn = 0 + + for ref_idx, query_idx in warp_path: + # Map to actual timepoints + ref_t = reference_timepoints[0] + ref_idx + query_t = start_time + query_idx + + # Get ground truth infection states + ref_state = reference_infection_states.get(ref_t, "unknown") + query_state = track_infection_states.get(query_t, "unknown") + + # Skip unknown states + if ref_state != "unknown" and query_state != "unknown": + track_predictions += 1 + + # Count correct alignments + if ref_state == query_state: + track_correct += 1 + + # Calculate F1 score components for "infected" state + if ( + ref_state == infection_target_value + and query_state == infection_target_value + ): + track_tp += 1 + elif ( + ref_state != infection_target_value + and query_state == infection_target_value + ): + track_fp += 1 + elif ( + ref_state == infection_target_value + and query_state != infection_target_value + ): + track_fn += 1 + + # Calculate track-specific metrics + if track_predictions > 0: + track_accuracy = track_correct / track_predictions + track_precision = ( + track_tp / (track_tp + track_fp) + if (track_tp + track_fp) > 0 + else 0 + ) + track_recall = ( + track_tp / (track_tp + track_fn) + if (track_tp + track_fn) > 0 + else 0 + ) + track_f1 = ( + 2 + * (track_precision * track_recall) + / (track_precision + track_recall) + if (track_precision + track_recall) > 0 + else 0 + ) + + alignment_details.append( + { + "fov_name": fov_name, + "track_ids": track_ids, + "accuracy": track_accuracy, + "precision": track_precision, + "recall": track_recall, + "f1_score": track_f1, + "correct": track_correct, + "total": track_predictions, + } + ) + + # Add to model totals + total_correct += track_correct + total_predictions += track_predictions + true_positives += track_tp + false_positives += track_fp + false_negatives += track_fn + + # Calculate metrics + accuracy = total_correct / total_predictions if total_predictions > 0 else 0 + precision = ( + true_positives / (true_positives + false_positives) + if (true_positives + false_positives) > 0 + else 0 + ) + recall = ( + true_positives / (true_positives + 
false_negatives) + if (true_positives + false_negatives) > 0 + else 0 + ) + f1 = ( + 2 * (precision * recall) / (precision + recall) + if (precision + recall) > 0 + else 0 + ) + + # Store alignment details for this model + if alignment_details: + alignment_details_df = pd.DataFrame(alignment_details) + print(f"\nDetailed alignment results for {name}:") + print(alignment_details_df) + alignment_details_df.to_csv( + output_root / f"{name}_alignment_details.csv", index=False + ) + + model_performance.append( + { + "model": name, + "accuracy": accuracy, + "precision": precision, + "recall": recall, + "f1_score": f1, + "total_predictions": total_predictions, + } + ) + + # Create performance DataFrame and visualize + performance_df = pd.DataFrame(model_performance) + print(performance_df) + + # Plot performance metrics + plt.figure(figsize=(12, 8)) + + # Accuracy plot + plt.subplot(2, 2, 1) + sns.barplot(x="model", y="accuracy", data=performance_df) + plt.title("Infection State Warping Accuracy") + plt.ylabel("Accuracy") + + # Precision plot + plt.subplot(2, 2, 2) + sns.barplot(x="model", y="precision", data=performance_df) + plt.title("Precision for Infected State") + plt.ylabel("Precision") + + # Recall plot + plt.subplot(2, 2, 3) + sns.barplot(x="model", y="recall", data=performance_df) + plt.title("Recall for Infected State") + plt.ylabel("Recall") + + # F1 score plot + plt.subplot(2, 2, 4) + sns.barplot(x="model", y="f1_score", data=performance_df) + plt.title("F1 Score for Infected State") + plt.ylabel("F1 Score") + + plt.tight_layout() + # plt.savefig(output_root / "infection_state_warping_performance.png", dpi=300) + # plt.close() +else: + logger.warning("Reference track annotations not found in infection_annotations_df") + +# %% diff --git a/tests/representation/evaluation/test_clustering.py b/tests/representation/evaluation/test_clustering.py new file mode 100644 index 000000000..cfff09694 --- /dev/null +++ b/tests/representation/evaluation/test_clustering.py @@ -0,0 +1,181 @@ +import numpy as np +import pytest +from numpy.typing import NDArray + +from viscy.representation.evaluation.clustering import pairwise_distance_matrix + + +@pytest.fixture +def sample_features(): + """Create sample features for testing.""" + np.random.seed(42) + return np.random.randn(50, 128).astype(np.float64) + + +@pytest.fixture +def small_features(): + """Create small sample with known values for numerical testing.""" + return np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, 0.5]]) + + +class TestPairwiseDistanceMatrix: + """Tests for pairwise_distance_matrix function.""" + + @pytest.mark.parametrize("metric", ["cosine", "euclidean"]) + def test_scipy_baseline(self, sample_features: NDArray, metric: str): + """Test that scipy backend produces valid distance matrices.""" + dist_matrix = pairwise_distance_matrix( + sample_features, metric=metric, device="scipy" + ) + + # Check shape + n = len(sample_features) + assert dist_matrix.shape == (n, n) + + # Check symmetry + assert np.allclose(dist_matrix, dist_matrix.T) + + # Check diagonal is zero (or near zero for numerical precision) + assert np.allclose(np.diag(dist_matrix), 0, atol=1e-10) + + # Check all distances are non-negative + assert np.all(dist_matrix >= 0) + + @pytest.mark.parametrize("metric", ["cosine", "euclidean"]) + @pytest.mark.parametrize("device", ["cpu", "auto"]) + def test_torch_vs_scipy(self, sample_features: NDArray, metric: str, device: str): + """Test that PyTorch implementation matches scipy results.""" + pytest.importorskip("torch") 
+ + dist_scipy = pairwise_distance_matrix( + sample_features, metric=metric, device="scipy" + ) + dist_torch = pairwise_distance_matrix( + sample_features, metric=metric, device=device + ) + + # Check numerical agreement + assert np.allclose(dist_scipy, dist_torch, rtol=1e-5, atol=1e-6) + + @pytest.mark.skipif( + not pytest.importorskip("torch").cuda.is_available(), + reason="CUDA not available", + ) + @pytest.mark.parametrize("metric", ["cosine", "euclidean"]) + def test_gpu_vs_scipy(self, sample_features: NDArray, metric: str): + """Test that GPU implementation matches scipy results.""" + dist_scipy = pairwise_distance_matrix( + sample_features, metric=metric, device="scipy" + ) + dist_gpu = pairwise_distance_matrix( + sample_features, metric=metric, device="cuda" + ) + + # Check numerical agreement + assert np.allclose(dist_scipy, dist_gpu, rtol=1e-5, atol=1e-6) + + def test_cosine_distance_known_values(self, small_features: NDArray): + """Test cosine distance with known values.""" + dist_matrix = pairwise_distance_matrix( + small_features, metric="cosine", device="scipy" + ) + + # [1,0] and [0,1] are orthogonal: cosine distance = 1 + assert np.isclose(dist_matrix[0, 1], 1.0, atol=1e-10) + + # [1,1] and [0.5, 0.5] are parallel: cosine distance = 0 + assert np.isclose(dist_matrix[2, 3], 0.0, atol=1e-10) + + # [1,0] and [1,1]: cosine similarity = 1/sqrt(2), distance = 1 - 1/sqrt(2) + expected = 1 - 1 / np.sqrt(2) + assert np.isclose(dist_matrix[0, 2], expected, atol=1e-10) + + def test_euclidean_distance_known_values(self, small_features: NDArray): + """Test euclidean distance with known values.""" + dist_matrix = pairwise_distance_matrix( + small_features, metric="euclidean", device="scipy" + ) + + # Distance between [1,0] and [0,1] is sqrt(2) + assert np.isclose(dist_matrix[0, 1], np.sqrt(2), atol=1e-10) + + # Distance between [1,1] and [0.5, 0.5] is sqrt(0.5) + assert np.isclose(dist_matrix[2, 3], np.sqrt(0.5), atol=1e-10) + + def test_unsupported_metric_falls_back_to_scipy(self, sample_features: NDArray): + """Test that unsupported metrics fall back to scipy.""" + # These metrics are only supported by scipy, not PyTorch + dist_matrix = pairwise_distance_matrix( + sample_features, metric="cityblock", device="auto" + ) + + # Should still produce valid results via scipy fallback + n = len(sample_features) + assert dist_matrix.shape == (n, n) + assert np.allclose(dist_matrix, dist_matrix.T) + + def test_device_options(self, sample_features: NDArray): + """Test various device options.""" + # Test scipy explicitly + dist_scipy = pairwise_distance_matrix( + sample_features, metric="cosine", device="scipy" + ) + assert dist_scipy is not None + + # Test None as scipy + dist_none = pairwise_distance_matrix( + sample_features, metric="cosine", device=None + ) + assert np.allclose(dist_scipy, dist_none) + + @pytest.mark.skipif( + not pytest.importorskip("torch").cuda.is_available(), + reason="CUDA not available", + ) + def test_cuda_aliases(self, sample_features: NDArray): + """Test that cuda and gpu device names work.""" + dist_cuda = pairwise_distance_matrix( + sample_features, metric="cosine", device="cuda" + ) + dist_gpu = pairwise_distance_matrix( + sample_features, metric="cosine", device="gpu" + ) + + assert np.allclose(dist_cuda, dist_gpu) + + def test_invalid_device_raises_error(self, sample_features: NDArray): + """Test that invalid device names raise appropriate errors.""" + pytest.importorskip("torch") + + with pytest.raises(ValueError, match="Invalid device"): + 
pairwise_distance_matrix( + sample_features, metric="cosine", device="invalid_device" + ) + + def test_float32_input_preserves_precision(self): + """Test that float32 input is converted to float64 for precision.""" + pytest.importorskip("torch") + + features_f32 = np.random.randn(10, 32).astype(np.float32) + + dist_scipy = pairwise_distance_matrix( + features_f32, metric="cosine", device="scipy" + ) + dist_torch = pairwise_distance_matrix( + features_f32, metric="cosine", device="cpu" + ) + + # Should still have good agreement despite float32 input + assert np.allclose(dist_scipy, dist_torch, rtol=1e-5, atol=1e-6) + + def test_large_matrix_shape(self): + """Test with larger matrix to ensure it works at scale.""" + large_features = np.random.randn(500, 64).astype(np.float64) + + dist_matrix = pairwise_distance_matrix( + large_features, metric="cosine", device="auto" + ) + + assert dist_matrix.shape == (500, 500) + assert np.allclose(dist_matrix, dist_matrix.T) + assert np.allclose(np.diag(dist_matrix), 0, atol=1e-6) diff --git a/tests/representation/evaluation/test_distance.py b/tests/representation/evaluation/test_distance.py new file mode 100644 index 000000000..6f82d5e54 --- /dev/null +++ b/tests/representation/evaluation/test_distance.py @@ -0,0 +1,117 @@ +import numpy as np +import pytest +import xarray as xr + +from viscy.representation.evaluation.distance import ( + calculate_cosine_similarity_cell, + compute_track_displacement, +) + + +@pytest.fixture +def sample_embedding_dataset(): + """Create a sample embedding dataset for testing.""" + n_samples = 10 + n_features = 5 + + features = np.random.rand(n_samples, n_features) + fov_names = ["fov1"] * 5 + ["fov2"] * 5 + track_ids = [1, 1, 1, 2, 2, 3, 3, 3, 4, 4] + time_points = [0, 1, 2, 0, 1, 0, 1, 2, 0, 1] + + dataset = xr.Dataset( + { + "features": (["sample", "features"], features), + "fov_name": (["sample"], fov_names), + "track_id": (["sample"], track_ids), + "t": (["sample"], time_points), + } + ) + return dataset + + +def test_calculate_cosine_similarity_cell(sample_embedding_dataset): + """Test cosine similarity calculation for a single track.""" + time_points, similarities = calculate_cosine_similarity_cell( + sample_embedding_dataset, "fov1", 1 + ) + + assert len(time_points) == len(similarities) + assert len(time_points) == 3 + assert np.isclose(similarities[0], 1.0, atol=1e-6) + assert all(-1 <= sim <= 1 for sim in similarities) + + +@pytest.mark.parametrize("distance_metric", ["cosine", "euclidean", "sqeuclidean"]) +def test_compute_track_displacement(sample_embedding_dataset, distance_metric): + """Test track displacement computation with different metrics.""" + result = compute_track_displacement( + sample_embedding_dataset, distance_metric=distance_metric + ) + + assert isinstance(result, dict) + assert all(isinstance(tau, int) for tau in result.keys()) + assert all(isinstance(displacements, list) for displacements in result.values()) + assert all( + all(isinstance(d, (int, float)) and d >= 0 for d in displacements) + for displacements in result.values() + ) + + +def test_compute_track_displacement_numerical(): + """Test compute_track_displacement with known embeddings and expected results.""" + features = np.array( + [ + [1.0, 0.0], + [0.0, 1.0], + [1.0, 1.0], + ] + ) + + dataset = xr.Dataset( + { + "features": (["sample", "features"], features), + "fov_name": (["sample"], ["fov1", "fov1", "fov1"]), + "track_id": (["sample"], [1, 1, 1]), + "t": (["sample"], [0, 1, 2]), + } + ) + result_euclidean = 
compute_track_displacement(dataset, distance_metric="euclidean") + + assert 1 in result_euclidean + assert 2 in result_euclidean + assert len(result_euclidean[1]) == 2 + assert len(result_euclidean[2]) == 1 + + result_sqeuclidean = compute_track_displacement( + dataset, distance_metric="sqeuclidean" + ) + expected_tau1 = [2.0, 1.0] + expected_tau2 = [1.0] + + assert np.allclose(sorted(result_sqeuclidean[1]), sorted(expected_tau1), atol=1e-10) + assert np.allclose(result_sqeuclidean[2], expected_tau2, atol=1e-10) + + result_cosine = compute_track_displacement(dataset, distance_metric="cosine") + expected_cosine_tau1 = [1.0, 1 - 1 / np.sqrt(2)] + expected_cosine_tau2 = [1 - 1 / np.sqrt(2)] + + assert np.allclose( + sorted(result_cosine[1]), sorted(expected_cosine_tau1), atol=1e-10 + ) + assert np.allclose(result_cosine[2], expected_cosine_tau2, atol=1e-10) + + +def test_compute_track_displacement_empty_dataset(): + """Test behavior with empty dataset.""" + empty_dataset = xr.Dataset( + { + "features": (["sample", "features"], np.empty((0, 5))), + "fov_name": (["sample"], []), + "track_id": (["sample"], []), + "t": (["sample"], []), + } + ) + + result = compute_track_displacement(empty_dataset) + assert result == {} diff --git a/viscy/data/cell_division_triplet.py b/viscy/data/cell_division_triplet.py new file mode 100644 index 000000000..4547adaa1 --- /dev/null +++ b/viscy/data/cell_division_triplet.py @@ -0,0 +1,449 @@ +import logging +import random +from pathlib import Path +from typing import Literal, Sequence + +import numpy as np +import torch +from monai.transforms import Compose, MapTransform +from torch import Tensor +from torch.utils.data import Dataset + +from viscy.data.hcs import HCSDataModule +from viscy.data.triplet import ( + _transform_channel_wise, +) +from viscy.data.typing import DictTransform, TripletSample + +_logger = logging.getLogger("lightning.pytorch") + + +class CellDivisionTripletDataset(Dataset): + """Dataset for triplet sampling of cell division data from npy files. + + For the dataset from the paper: + https://arxiv.org/html/2502.02182v1 + """ + + # NOTE: Hardcoded channel mapping for .npy files + CHANNEL_MAPPING = { + # Channel 0 aliases (brightfield) + "bf": 0, + "brightfield": 0, + # Channel 1 aliases (h2b) + "h2b": 1, + "nuclei": 1, + } + + def __init__( + self, + data_paths: list[Path], + channel_names: list[str], + anchor_transform: DictTransform | None = None, + positive_transform: DictTransform | None = None, + negative_transform: DictTransform | None = None, + fit: bool = True, + time_interval: Literal["any"] | int = "any", + return_negative: bool = True, + output_2d: bool = False, + ) -> None: + """Dataset for triplet sampling of cell division data from npy files. 
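+
+        Channel names are resolved through ``CHANNEL_MAPPING``
+        ("bf"/"brightfield" -> 0, "h2b"/"nuclei" -> 1); names not in the
+        mapping are parsed as integer channel indices.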
+ + Parameters + ---------- + data_paths : list[Path] + List of paths to npy files containing cell division tracks (T,C,Y,X format) + channel_names : list[str] + Input channel names + anchor_transform : DictTransform | None, optional + Transforms applied to the anchor sample, by default None + positive_transform : DictTransform | None, optional + Transforms applied to the positive sample, by default None + negative_transform : DictTransform | None, optional + Transforms applied to the negative sample, by default None + fit : bool, optional + Fitting mode in which the full triplet will be sampled, + only sample anchor if False, by default True + time_interval : Literal["any"] | int, optional + Future time interval to sample positive and anchor from, + by default "any" + return_negative : bool, optional + Whether to return the negative sample during the fit stage, by default True + output_2d : bool, optional + Whether to return 2D tensors (C,Y,X) instead of 3D (C,1,Y,X), by default False + """ + self.channel_names = channel_names + self.anchor_transform = anchor_transform + self.positive_transform = positive_transform + self.negative_transform = negative_transform + self.fit = fit + self.time_interval = time_interval + self.return_negative = return_negative + self.output_2d = output_2d + + # Load and process all data files + self.cell_tracks = self._load_data(data_paths) + self.valid_anchors = self._filter_anchors() + + # Create arrays for vectorized operations + self.track_ids = np.array([t["track_id"] for t in self.cell_tracks]) + self.cell_tracks_array = np.array(self.cell_tracks) + + # Map channel names to indices using CHANNEL_MAPPING + self.channel_indices = self._map_channel_indices(channel_names) + + def _map_channel_indices(self, channel_names: list[str]) -> list[int]: + """Map channel names to their corresponding indices in the data array.""" + channel_indices = [] + for name in channel_names: + if name in self.CHANNEL_MAPPING: + channel_indices.append(self.CHANNEL_MAPPING[name]) + else: + # Try to parse as integer if not in mapping + try: + channel_indices.append(int(name)) + except ValueError: + raise ValueError( + f"Channel '{name}' not found in CHANNEL_MAPPING and is not a valid integer" + ) + return channel_indices + + def _select_channels(self, patch: Tensor) -> Tensor: + """Select only the requested channels from the patch.""" + return patch[self.channel_indices] + + def _load_data(self, data_paths: list[Path]) -> list[dict]: + """Load npy files.""" + all_tracks = [] + + for path in data_paths: + data = np.load(path) # Shape: (T, C, Y, X) + T, C, Y, X = data.shape + + # Create track info for this file + # NOTE: using the filename as track ID as UID. 
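+            # Example: a file "division_0007.npy" with shape (40, 2, 64, 64) becomes
+            # one track dict with track_id "division_0007" and num_timepoints 40
+            # (the filename here is hypothetical).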
+ track_info = { + "data": torch.from_numpy(data.astype(np.float32)), + "file_path": str(path), + "track_id": path.stem, + "num_timepoints": T, + "shape": (T, C, Y, X), + } + all_tracks.append(track_info) + + _logger.info(f"Loaded {len(all_tracks)} tracks") + return all_tracks + + def _filter_anchors(self) -> list[dict]: + """Create valid anchor points based on time interval constraints.""" + valid_anchors = [] + + for track in self.cell_tracks: + num_timepoints = track["num_timepoints"] + + if self.time_interval == "any" or not self.fit: + valid_timepoints = list(range(num_timepoints)) + else: + # Only timepoints that have a future timepoint at the specified interval + valid_timepoints = list(range(num_timepoints - self.time_interval)) + + for t in valid_timepoints: + anchor_info = { + "track": track, + "timepoint": t, + "track_id": track["track_id"], + "file_path": track["file_path"], + } + valid_anchors.append(anchor_info) + + return valid_anchors + + def __len__(self) -> int: + return len(self.valid_anchors) + + def _sample_positive(self, anchor_info: dict) -> Tensor: + """Select a positive sample from the same track.""" + track = anchor_info["track"] + anchor_t = anchor_info["timepoint"] + + if self.time_interval == "any": + # Use the same anchor patch (will be augmented differently) + positive_t = anchor_t + else: + # Use future timepoint + positive_t = anchor_t + self.time_interval + + positive_patch = track["data"][positive_t] + positive_patch = self._select_channels(positive_patch) + if not self.output_2d: + positive_patch = positive_patch.unsqueeze(1) + return positive_patch + + def _sample_negative(self, anchor_info: dict) -> Tensor: + """Select a negative sample from a different track.""" + anchor_track_id = anchor_info["track_id"] + + # Vectorized filtering using boolean indexing + mask = self.track_ids != anchor_track_id + negative_candidates = self.cell_tracks_array[mask].tolist() + + if not negative_candidates: + # Fallback: use different timepoint from same track + track = anchor_info["track"] + anchor_t = anchor_info["timepoint"] + available_times = [ + t for t in range(track["num_timepoints"]) if t != anchor_t + ] + if available_times: + neg_t = random.choice(available_times) + negative_patch = track["data"][neg_t] + negative_patch = self._select_channels(negative_patch) + else: + # Ultimate fallback: use same patch (transforms will differentiate) + negative_patch = track["data"][anchor_t] + negative_patch = self._select_channels(negative_patch) + else: + # Sample from different track + neg_track = random.choice(negative_candidates) + + if self.time_interval == "any": + neg_t = random.randint(0, neg_track["num_timepoints"] - 1) + else: + # Try to use same relative timepoint, fallback to random + anchor_t = anchor_info["timepoint"] + target_t = anchor_t + self.time_interval + if target_t < neg_track["num_timepoints"]: + neg_t = target_t + else: + neg_t = random.randint(0, neg_track["num_timepoints"] - 1) + + negative_patch = neg_track["data"][neg_t] + negative_patch = self._select_channels(negative_patch) + + # Add depth dimension only if not output_2d: (C, Y, X) -> (C, D=1, Y, X) + if not self.output_2d: + negative_patch = negative_patch.unsqueeze(1) # Shape: (C, 1, Y, X) + return negative_patch + + def __getitem__(self, index: int) -> TripletSample: + anchor_info = self.valid_anchors[index] + track = anchor_info["track"] + anchor_t = anchor_info["timepoint"] + + # Get anchor patch and select requested channels + anchor_patch = track["data"][anchor_t] # Shape: (C, Y, X) 
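+        # Keep only the requested channels, then (unless output_2d) add a
+        # singleton depth axis so the patch matches the (C, D=1, Y, X) layout
+        # used elsewhere in the triplet pipeline.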
+ anchor_patch = self._select_channels(anchor_patch) + if not self.output_2d: + anchor_patch = anchor_patch.unsqueeze(1) + + sample = {"anchor": anchor_patch} + + if self.fit: + positive_patch = self._sample_positive(anchor_info) + + if self.positive_transform: + positive_patch = _transform_channel_wise( + transform=self.positive_transform, + channel_names=self.channel_names, + patch=positive_patch, + norm_meta=None, + ) + + if self.return_negative: + negative_patch = self._sample_negative(anchor_info) + + if self.negative_transform: + negative_patch = _transform_channel_wise( + transform=self.negative_transform, + channel_names=self.channel_names, + patch=negative_patch, + norm_meta=None, + ) + + sample.update({"positive": positive_patch, "negative": negative_patch}) + else: + sample.update({"positive": positive_patch}) + else: + # For prediction mode, include index information + index_dict = { + "fov_name": anchor_info["track_id"], + "id": anchor_t, + } + sample.update({"index": index_dict}) + + if self.anchor_transform: + sample["anchor"] = _transform_channel_wise( + transform=self.anchor_transform, + channel_names=self.channel_names, + patch=sample["anchor"], + norm_meta=None, + ) + + return sample + + +class CellDivisionTripletDataModule(HCSDataModule): + def __init__( + self, + data_path: str, + source_channel: str | Sequence[str], + final_yx_patch_size: tuple[int, int] = (64, 64), # Match dataset size + split_ratio: float = 0.8, + batch_size: int = 16, + num_workers: int = 8, + normalizations: list[MapTransform] = [], + augmentations: list[MapTransform] = [], + augment_validation: bool = True, + time_interval: Literal["any"] | int = "any", + return_negative: bool = True, + output_2d: bool = False, + persistent_workers: bool = False, + prefetch_factor: int | None = None, + pin_memory: bool = False, + ): + """Lightning data module for cell division triplet sampling. 
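+
+        Expects ``data_path`` to contain one ``*.npy`` file per track; the
+        train/validation split is performed over files, not individual frames.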
+ + Parameters + ---------- + data_path : str + Path to directory containing npy files + source_channel : str | Sequence[str] + List of input channel names + final_yx_patch_size : tuple[int, int], optional + Output patch size, by default (64, 64) + split_ratio : float, optional + Ratio of training samples, by default 0.8 + batch_size : int, optional + Batch size, by default 16 + num_workers : int, optional + Number of data-loading workers, by default 8 + normalizations : list[MapTransform], optional + Normalization transforms, by default [] + augmentations : list[MapTransform], optional + Augmentation transforms, by default [] + augment_validation : bool, optional + Apply augmentations to validation data, by default True + time_interval : Literal["any"] | int, optional + Future time interval to sample positive and anchor from, by default "any" + return_negative : bool, optional + Whether to return the negative sample during the fit stage, by default True + output_2d : bool, optional + Whether to return 2D tensors (C,Y,X) instead of 3D (C,1,Y,X), by default False + persistent_workers : bool, optional + Whether to keep worker processes alive between iterations, by default False + prefetch_factor : int | None, optional + Number of batches loaded in advance by each worker, by default None + pin_memory : bool, optional + Whether to pin memory in CPU for faster GPU transfer, by default False + """ + # Initialize parent class with minimal required parameters + super().__init__( + data_path=data_path, + source_channel=source_channel, + target_channel=[], + z_window_size=1, + split_ratio=split_ratio, + batch_size=batch_size, + num_workers=num_workers, + target_2d=False, # Set to False since we're adding depth dimension + yx_patch_size=final_yx_patch_size, + normalizations=normalizations, + augmentations=augmentations, + caching=False, # NOTE: Not applicable for npy files + persistent_workers=persistent_workers, + prefetch_factor=prefetch_factor, + pin_memory=pin_memory, + ) + self.split_ratio = split_ratio + self.data_path = Path(data_path) + self.time_interval = time_interval + self.return_negative = return_negative + self.output_2d = output_2d + self.augment_validation = augment_validation + + # Find all npy files in the data directory + self.npy_files = list(self.data_path.glob("*.npy")) + if not self.npy_files: + raise ValueError(f"No .npy files found in {data_path}") + + _logger.info(f"Found {len(self.npy_files)} .npy files in {data_path}") + + @property + def _base_dataset_settings(self) -> dict: + return { + "channel_names": self.source_channel, + "time_interval": self.time_interval, + "output_2d": self.output_2d, + } + + def _setup_fit(self, dataset_settings: dict): + augment_transform, no_aug_transform = self._fit_transform() + + # Shuffle and split the npy files + shuffled_indices = self._set_fit_global_state(len(self.npy_files)) + npy_files = [self.npy_files[i] for i in shuffled_indices] + + # Set the train and eval positions + num_train_files = int(len(self.npy_files) * self.split_ratio) + train_npy_files = npy_files[:num_train_files] + val_npy_files = npy_files[num_train_files:] + + _logger.debug(f"Number of training files: {len(train_npy_files)}") + _logger.debug(f"Number of validation files: {len(val_npy_files)}") + + # Determine anchor transform based on time interval + anchor_transform = ( + no_aug_transform + if (self.time_interval == "any" or self.time_interval == 0) + else augment_transform + ) + + # Create training dataset + self.train_dataset = CellDivisionTripletDataset( + 
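+            # Only training files go in here; validation files get their own
+            # dataset below, so no track is shared across splits.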
data_paths=train_npy_files, + anchor_transform=anchor_transform, + positive_transform=augment_transform, + negative_transform=augment_transform, + fit=True, + return_negative=self.return_negative, + **dataset_settings, + ) + + # Choose transforms for validation based on augment_validation parameter + val_positive_transform = ( + augment_transform if self.augment_validation else no_aug_transform + ) + val_negative_transform = ( + augment_transform if self.augment_validation else no_aug_transform + ) + val_anchor_transform = ( + anchor_transform if self.augment_validation else no_aug_transform + ) + + # Create validation dataset + self.val_dataset = CellDivisionTripletDataset( + data_paths=val_npy_files, + anchor_transform=val_anchor_transform, + positive_transform=val_positive_transform, + negative_transform=val_negative_transform, + fit=True, + return_negative=self.return_negative, + **dataset_settings, + ) + + _logger.info(f"Training dataset size: {len(self.train_dataset)}") + _logger.info(f"Validation dataset size: {len(self.val_dataset)}") + + def _setup_predict(self, dataset_settings: dict): + self._set_predict_global_state() + + # For prediction, use all data + self.predict_dataset = CellDivisionTripletDataset( + data_paths=self.npy_files, + anchor_transform=Compose(self.normalizations), + fit=False, + **dataset_settings, + ) + + def _setup_test(self, *args, **kwargs): + raise NotImplementedError("Self-supervised model does not support testing") diff --git a/viscy/data/triplet.py b/viscy/data/triplet.py index 8f5f7826e..489a09c65 100644 --- a/viscy/data/triplet.py +++ b/viscy/data/triplet.py @@ -335,6 +335,7 @@ def __init__( num_workers: int = 1, normalizations: list[MapTransform] = [], augmentations: list[MapTransform] = [], + augment_validation: bool = True, caching: bool = False, fit_include_wells: list[str] | None = None, fit_exclude_fovs: list[str] | None = None, @@ -377,6 +378,9 @@ def __init__( Normalization transforms, by default [] augmentations : list[MapTransform], optional Augmentation transforms, by default [] + augment_validation : bool, optional + Apply augmentations to validation data, by default True. + Set to False for VAE training where clean validation is needed. 
caching : bool, optional Whether to cache the dataset, by default False fit_include_wells : list[str], optional @@ -439,6 +443,7 @@ def __init__( self.include_track_ids = include_track_ids self.time_interval = time_interval self.return_negative = return_negative + self.augment_validation = augment_validation self._cache_pool_bytes = cache_pool_bytes self._augmentation_transform = Compose( self.normalizations + self.augmentations + [self._final_crop()] @@ -501,6 +506,7 @@ def _setup_fit(self, dataset_settings: dict): return_negative=self.return_negative, **dataset_settings, ) + self.val_dataset = TripletDataset( positions=val_positions, tracks_tables=val_tracks_tables, @@ -584,6 +590,8 @@ def _find_transform(self, key: str): if self.trainer: if self.trainer.predicting: return self._no_augmentation_transform + if self.trainer.validating and not self.augment_validation: + return self._no_augmentation_transform # NOTE: for backwards compatibility if key == "anchor" and self.time_interval in ("any", 0): return self._no_augmentation_transform diff --git a/viscy/representation/engine.py b/viscy/representation/engine.py index b7938013e..c515c6145 100644 --- a/viscy/representation/engine.py +++ b/viscy/representation/engine.py @@ -11,6 +11,8 @@ from viscy.data.typing import TrackingIndex, TripletSample from viscy.representation.contrastive import ContrastiveEncoder +from viscy.representation.vae import BetaVae25D, BetaVaeMonai +from viscy.representation.vae_logging import BetaVaeLogger from viscy.utils.log_images import detach_sample, render_images _logger = logging.getLogger("lightning.pytorch") @@ -221,12 +223,6 @@ def training_step(self, batch: TripletSample, batch_idx: int) -> Tensor: def on_train_epoch_end(self) -> None: super().on_train_epoch_end() self._log_samples("train_samples", self.training_step_outputs) - # Log UMAP embeddings for validation - if self.log_embeddings: - embeddings = torch.cat( - [output["embeddings"] for output in self.validation_step_outputs] - ) - self.log_embedding_umap(embeddings, tag="train") self.training_step_outputs = [] def validation_step(self, batch: TripletSample, batch_idx: int) -> Tensor: @@ -264,13 +260,6 @@ def validation_step(self, batch: TripletSample, batch_idx: int) -> Tensor: def on_validation_epoch_end(self) -> None: super().on_validation_epoch_end() self._log_samples("val_samples", self.validation_step_outputs) - # Log UMAP embeddings for training - if self.log_embeddings: - embeddings = torch.cat( - [output["embeddings"] for output in self.training_step_outputs] - ) - self.log_embedding_umap(embeddings, tag="val") - self.validation_step_outputs = [] def configure_optimizers(self): @@ -287,3 +276,328 @@ def predict_step( "projections": projections, "index": batch["index"], } + + +class BetaVaeModule(LightningModule): + def __init__( + self, + vae: nn.Module | BetaVae25D | BetaVaeMonai, + loss_function: nn.Module | nn.MSELoss = nn.MSELoss(reduction="sum"), + beta: float = 1.0, + beta_schedule: Literal["linear", "cosine", "warmup"] | None = None, + beta_min: float = 0.1, + beta_warmup_epochs: int = 50, + lr: float = 1e-5, + lr_schedule: Literal["WarmupCosine", "Constant"] = "Constant", + log_batches_per_epoch: int = 8, + log_samples_per_batch: int = 1, + example_input_array_shape: Sequence[int] = (1, 2, 30, 256, 256), + log_enhanced_visualizations: bool = False, + log_enhanced_visualizations_frequency: int = 30, + ): + super().__init__() + + self.model = vae + self.loss_function = loss_function + + self.beta = beta + self.beta_schedule = 
beta_schedule + self.beta_min = beta_min + self.beta_warmup_epochs = beta_warmup_epochs + + self.lr = lr + self.lr_schedule = lr_schedule + + self.log_batches_per_epoch = log_batches_per_epoch + self.log_samples_per_batch = log_samples_per_batch + + self.example_input_array = torch.rand(*example_input_array_shape) + + self.log_enhanced_visualizations = log_enhanced_visualizations + self.log_enhanced_visualizations_frequency = ( + log_enhanced_visualizations_frequency + ) + self.training_step_outputs = [] + self.validation_step_outputs = [] + + self._min_beta = 1e-15 + self._logvar_minmax = (-20, 20) + + # Handle different parameter names for latent dimensions + latent_dim = None + if hasattr(self.model, "latent_dim"): + latent_dim = self.model.latent_dim + elif hasattr(self.model, "latent_size"): + latent_dim = self.model.latent_size + elif hasattr(self.model, "encoder") and hasattr( + self.model.encoder, "latent_dim" + ): + latent_dim = self.model.encoder.latent_dim + + if latent_dim is not None: + self.vae_logger = BetaVaeLogger(latent_dim=latent_dim) + else: + _logger.warning( + "No latent dimension provided for BetaVaeLogger. Using default with 128 dimensions." + ) + self.vae_logger = BetaVaeLogger() + + def setup(self, stage: str = None): + """Setup hook to initialize device-dependent components.""" + super().setup(stage) + + # Initialize the VAE logger with proper device + self.vae_logger.setup(device=self.device) + + def _get_current_beta(self) -> float: + """Get current beta value based on scheduling.""" + if self.beta_schedule is None: + return max(self.beta, self._min_beta) + + epoch = self.current_epoch + + if self.beta_schedule == "linear": + # Linear warmup from beta_min to beta + if epoch < self.beta_warmup_epochs: + beta_val = ( + self.beta_min + + (self.beta - self.beta_min) * epoch / self.beta_warmup_epochs + ) + return max(beta_val, self._min_beta) + else: + return max(self.beta, self._min_beta) + + elif self.beta_schedule == "cosine": + # Cosine warmup from beta_min to beta + if epoch < self.beta_warmup_epochs: + import math + + progress = epoch / self.beta_warmup_epochs + beta_val = self.beta_min + (self.beta - self.beta_min) * 0.5 * ( + 1 + math.cos(math.pi * (1 - progress)) + ) + return max(beta_val, self._min_beta) + else: + return max(self.beta, self._min_beta) + + elif self.beta_schedule == "warmup": + # Keep beta_min for warmup epochs, then jump to beta + beta_val = self.beta_min if epoch < self.beta_warmup_epochs else self.beta + return max(beta_val, self._min_beta) + + else: + return max(self.beta, self._min_beta) + + def forward(self, x: Tensor) -> dict: + """Forward pass through Beta-VAE.""" + + original_shape = x.shape + is_monai_2d = ( + isinstance(self.model, BetaVaeMonai) + and hasattr(self.model, "spatial_dims") + and self.model.spatial_dims == 2 + ) + if is_monai_2d and len(x.shape) == 5 and x.shape[2] == 1: + x = x.squeeze(2) + + # Handle different model output formats + model_output = self.model(x) + + recon_x = model_output.recon_x + mu = model_output.mean + logvar = model_output.logvar + z = model_output.z + + if is_monai_2d and len(original_shape) == 5 and original_shape[2] == 1: + # Convert back (B, C, H, W) to (B, C, 1, H, W) + recon_x = recon_x.unsqueeze(2) + + current_beta = self._get_current_beta() + batch_size = original_shape[0] + + # Use original input for loss computation to ensure shape consistency + x_original = ( + x + if not (is_monai_2d and len(original_shape) == 5 and original_shape[2] == 1) + else x.unsqueeze(2) + ) + recon_loss = 
self.loss_function(recon_x, x_original) + if isinstance(self.loss_function, nn.MSELoss): + if ( + hasattr(self.loss_function, "reduction") + and self.loss_function.reduction == "sum" + ): + recon_loss = recon_loss / batch_size + elif ( + hasattr(self.loss_function, "reduction") + and self.loss_function.reduction == "mean" + ): + # Correct the over-normalization by PyTorch's mean reduction by multiplying by the number of elements per image + num_elements_per_image = x_original[0].numel() + recon_loss = recon_loss * num_elements_per_image + + kl_loss = -0.5 * torch.sum( + 1 + + torch.clamp(logvar, self._logvar_minmax[0], self._logvar_minmax[1]) + - mu.pow(2) + - logvar.exp(), + dim=1, + ) + kl_loss = torch.mean(kl_loss) + + total_loss = recon_loss + current_beta * kl_loss + + return { + "recon_x": recon_x, + "z": z, + "mu": mu, + "logvar": logvar, + "recon_loss": recon_loss, + "kl_loss": kl_loss, + "total_loss": total_loss, + } + + def training_step(self, batch: TripletSample, batch_idx: int) -> Tensor: + """Training step with VAE loss computation.""" + + x = batch["anchor"] + model_output = self(x) + loss = model_output["total_loss"] + + # Log enhanced β-VAE metrics + self.vae_logger.log_enhanced_metrics( + lightning_module=self, model_output=model_output, batch=batch, stage="train" + ) + # Log samples + self._log_step_samples(batch_idx, x, model_output["recon_x"], "train") + + return loss + + def validation_step(self, batch: TripletSample, batch_idx: int) -> Tensor: + """Validation step with VAE loss computation.""" + x = batch["anchor"] + model_output = self(x) + loss = model_output["total_loss"] + + # Log enhanced β-VAE metrics + self.vae_logger.log_enhanced_metrics( + lightning_module=self, model_output=model_output, batch=batch, stage="val" + ) + + # Log samples + self._log_step_samples(batch_idx, x, model_output["recon_x"], "val") + + return loss + + def _log_step_samples( + self, batch_idx, original, reconstruction, stage: Literal["train", "val"] + ): + """Log sample reconstructions.""" + if batch_idx < self.log_batches_per_epoch: + output_list = ( + self.training_step_outputs + if stage == "train" + else self.validation_step_outputs + ) + + # Store samples for epoch end logging + samples = { + "original": original.detach().cpu()[: self.log_samples_per_batch], + "reconstruction": reconstruction.detach().cpu()[ + : self.log_samples_per_batch + ], + } + output_list.append(samples) + + def _log_samples(self, key: str, samples_list: list): + """Log reconstruction samples at epoch end.""" + if len(samples_list) > 0: + # Take middle z-slice for visualization + mid_z = samples_list[0]["original"].shape[2] // 2 + + originals = [] + reconstructions = [] + + for sample in samples_list: + orig = sample["original"][:, :, mid_z].numpy() + recon = sample["reconstruction"][:, :, mid_z].numpy() + + originals.extend([orig[i] for i in range(orig.shape[0])]) + reconstructions.extend([recon[i] for i in range(recon.shape[0])]) + + # Create grid with originals and reconstructions + combined = [] + for orig, recon in zip(originals[:4], reconstructions[:4]): + combined.append([orig, recon]) + + grid = render_images(combined, cmaps=["gray", "gray"]) + self.logger.experiment.add_image( + key, grid, self.current_epoch, dataformats="HWC" + ) + + def on_train_epoch_end(self) -> None: + """Log training samples at epoch end.""" + super().on_train_epoch_end() + self._log_samples("train_reconstructions", self.training_step_outputs) + self.training_step_outputs = [] + + def on_validation_epoch_end(self) -> None: + 
"""Log validation samples at epoch end.""" + super().on_validation_epoch_end() + self._log_samples("val_reconstructions", self.validation_step_outputs) + self.validation_step_outputs = [] + + if ( + self.log_enhanced_visualizations + and self.current_epoch % self.log_enhanced_visualizations_frequency == 0 + and self.current_epoch > 0 + ): + self._log_enhanced_visualizations() + + def _log_enhanced_visualizations(self): + """Log enhanced β-VAE visualizations.""" + try: + val_dataloaders = self.trainer.val_dataloaders + if val_dataloaders is None: + val_dataloader = None + elif isinstance(val_dataloaders, list): + val_dataloader = val_dataloaders[0] if val_dataloaders else None + else: + val_dataloader = val_dataloaders + + if val_dataloader is None: + _logger.warning("No validation dataloader available for visualizations") + return + + _logger.info( + f"Logging enhanced β-VAE visualizations at epoch {self.current_epoch}" + ) + + self.vae_logger.log_latent_traversal( + lightning_module=self, n_dims=8, n_steps=11 + ) + self.vae_logger.log_latent_interpolation( + lightning_module=self, n_pairs=3, n_steps=11 + ) + self.vae_logger.log_factor_traversal_matrix( + lightning_module=self, n_dims=8, n_steps=7 + ) + + except Exception as e: + _logger.error(f"Error logging enhanced visualizations: {e}") + + def configure_optimizers(self): + """Configure optimizer for VAE training.""" + optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr) + return optimizer + + def predict_step(self, batch: TripletSample, batch_idx, dataloader_idx=0) -> dict: + """Prediction step for VAE inference.""" + x = batch["anchor"] + model_output = self(x) + + return { + "latent": model_output["z"], + "reconstruction": model_output["recon_x"], + "index": batch["index"], + } diff --git a/viscy/representation/evaluation/clustering.py b/viscy/representation/evaluation/clustering.py index ebf49455f..8f58ef0b5 100644 --- a/viscy/representation/evaluation/clustering.py +++ b/viscy/representation/evaluation/clustering.py @@ -33,22 +33,72 @@ def knn_accuracy(embeddings, annotations, k=5): return accuracy -def pairwise_distance_matrix(features: ArrayLike, metric: str = "cosine") -> NDArray: +def pairwise_distance_matrix( + features: ArrayLike, metric: str = "cosine", device: str = "auto" +) -> NDArray: """Compute pairwise distances between all samples in the feature matrix. + Uses PyTorch with GPU acceleration when available for significant speedup. + Falls back to scipy for unsupported metrics or when PyTorch is unavailable. + Parameters ---------- features : ArrayLike Feature matrix (n_samples, n_features) metric : str, optional Distance metric to use, by default "cosine" + Supports "cosine" and "euclidean" with PyTorch acceleration. + Other scipy metrics will use scipy fallback. 
+ device : str, optional + Device to use for computation, by default "auto" + - "auto": automatically use GPU if available, otherwise CPU + - "cuda" or "gpu": force GPU usage + - "cpu": force CPU usage + - None or "scipy": force scipy fallback Returns ------- NDArray Distance matrix of shape (n_samples, n_samples) """ - return cdist(features, features, metric=metric) + if device in (None, "scipy") or metric not in ("cosine", "euclidean"): + return cdist(features, features, metric=metric) + + try: + import torch + + if device == "auto": + device_torch = torch.device("cuda" if torch.cuda.is_available() else "cpu") + elif device in ("cuda", "gpu"): + if not torch.cuda.is_available(): + raise RuntimeError("CUDA requested but not available") + device_torch = torch.device("cuda") + elif device == "cpu": + device_torch = torch.device("cpu") + else: + raise ValueError( + f"Invalid device: {device}. Use 'auto', 'cuda', 'cpu', or 'scipy'" + ) + features_array = np.asarray(features) + if features_array.dtype == np.float32: + features_tensor = torch.from_numpy(features_array).double().to(device_torch) + else: + features_tensor = torch.from_numpy(features_array).to(device_torch) + if features_tensor.dtype not in (torch.float32, torch.float64): + features_tensor = features_tensor.double() + + if metric == "cosine": + features_norm = torch.nn.functional.normalize(features_tensor, p=2, dim=1) + similarity = features_norm @ features_norm.T + distances = 1 - similarity + elif metric == "euclidean": + distances = torch.cdist(features_tensor, features_tensor, p=2) + return distances.cpu().numpy() + + except ImportError: + return cdist(features, features, metric=metric) + except (RuntimeError, torch.cuda.OutOfMemoryError): + return cdist(features, features, metric=metric) def rank_nearest_neighbors( diff --git a/viscy/representation/evaluation/dimensionality_reduction.py b/viscy/representation/evaluation/dimensionality_reduction.py index 81d0194f6..c97bf39f4 100644 --- a/viscy/representation/evaluation/dimensionality_reduction.py +++ b/viscy/representation/evaluation/dimensionality_reduction.py @@ -10,6 +10,7 @@ def compute_phate( embedding_dataset, + scale_embeddings: bool = False, n_components: int = 2, knn: int = 5, decay: int = 40, @@ -59,11 +60,18 @@ def compute_phate( else embedding_dataset ) + if scale_embeddings: + scaler = StandardScaler() + embeddings_scaled = scaler.fit_transform(embeddings) + else: + embeddings_scaled = embeddings + # Compute PHATE embeddings phate_model = phate.PHATE( - n_components=n_components, knn=knn, decay=decay, **phate_kwargs + n_components=n_components, knn=knn, decay=decay, random_state=42, **phate_kwargs ) - phate_embedding = phate_model.fit_transform(embeddings) + + phate_embedding = phate_model.fit_transform(embeddings_scaled) # Update dataset if requested if update_dataset and isinstance(embedding_dataset, Dataset): diff --git a/viscy/representation/evaluation/distance.py b/viscy/representation/evaluation/distance.py index a920eb072..f5fea2f2b 100644 --- a/viscy/representation/evaluation/distance.py +++ b/viscy/representation/evaluation/distance.py @@ -1,9 +1,14 @@ from collections import defaultdict -from typing import Literal import numpy as np +import xarray as xr from sklearn.metrics.pairwise import cosine_similarity +from viscy.representation.evaluation.clustering import ( + compare_time_offset, + pairwise_distance_matrix, +) + def calculate_cosine_similarity_cell(embedding_dataset, fov_name, track_id): """Extract embeddings and calculate cosine similarities for 
a specific cell""" @@ -12,198 +17,71 @@ def calculate_cosine_similarity_cell(embedding_dataset, fov_name, track_id): & (embedding_dataset["track_id"] == track_id), drop=True, ) - features = filtered_data["features"].values # (sample, features) - time_points = filtered_data["t"].values # (sample,) + features = filtered_data["features"].values + time_points = filtered_data["t"].values first_time_point_embedding = features[0].reshape(1, -1) cosine_similarities = cosine_similarity( first_time_point_embedding, features ).flatten() + cosine_similarities = np.clip(cosine_similarities, -1.0, 1.0) return time_points, cosine_similarities.tolist() -def compute_displacement( - embedding_dataset, - distance_metric: Literal["euclidean_squared", "cosine"] = "euclidean_squared", +def compute_track_displacement( + embedding_dataset: xr.Dataset, + distance_metric: str = "cosine", ) -> dict[int, list[float]]: - """Compute the displacement or mean square displacement (MSD) of embeddings. - - For each time difference τ, computes either: - - |r(t + τ) - r(t)|² for squared Euclidean (MSD) - - cos_sim(r(t + τ), r(t)) for cosine - for all particles and initial times t. + """ + Compute Mean Squared Displacement using pairwise distance matrix. Parameters ---------- - embedding_dataset : xarray.Dataset + embedding_dataset : xr.Dataset Dataset containing embeddings and metadata distance_metric : str - The metric to use for computing distances between embeddings. - Valid options are: - - "euclidean": Euclidean distance (L2 norm) - - "euclidean_squared": Squared Euclidean distance (for MSD, default) - - "cosine": Cosine similarity - - "cosine_dissimilarity": 1 - cosine similarity + Distance metric to use. Default is cosine. + See for other supported distance metrics. + https://github.com/scipy/scipy/blob/main/scipy/spatial/distance.py Returns ------- dict[int, list[float]] - Dictionary mapping τ to list of displacements for all particles and initial times + Dictionary mapping time lag τ to list of squared displacements """ - # Get unique tracks efficiently using pandas operations + unique_tracks_df = ( embedding_dataset[["fov_name", "track_id"]].to_dataframe().drop_duplicates() ) - # Get data from dataset - fov_names = embedding_dataset["fov_name"].values - track_ids = embedding_dataset["track_id"].values - timepoints = embedding_dataset["t"].values - embeddings = embedding_dataset["features"].values - - # Initialize results dictionary with empty lists displacement_per_tau = defaultdict(list) - # Process each track for fov_name, track_id in zip( unique_tracks_df["fov_name"], unique_tracks_df["track_id"] ): - # Get sorted track data - mask = (fov_names == fov_name) & (track_ids == track_id) - times = timepoints[mask] - track_embeddings = embeddings[mask] + # Filter data for this track + track_data = embedding_dataset.where( + (embedding_dataset["fov_name"] == fov_name) + & (embedding_dataset["track_id"] == track_id), + drop=True, + ) # Sort by time - time_order = np.argsort(times) - times = times[time_order] - track_embeddings = track_embeddings[time_order] - - # Process each time point - for t_idx, t in enumerate(times[:-1]): - current_embedding = track_embeddings[t_idx] - - # Check all possible future time points - for future_idx, future_time in enumerate( - times[t_idx + 1 :], start=t_idx + 1 - ): - tau = future_time - t - future_embedding = track_embeddings[future_idx] - - if distance_metric in ["cosine"]: - dot_product = np.dot(current_embedding, future_embedding) - norms = np.linalg.norm(current_embedding) * 
np.linalg.norm( - future_embedding - ) - similarity = dot_product / norms - displacement = similarity - else: # Euclidean metrics - diff_squared = np.sum((current_embedding - future_embedding) ** 2) - displacement = diff_squared - displacement_per_tau[int(tau)].append(displacement) - - return dict(displacement_per_tau) - - -def compute_displacement_statistics( - displacement_per_tau: dict[int, list[float]], -) -> tuple[dict[int, float], dict[int, float]]: - """Compute mean and standard deviation of displacements for each tau. - - Parameters - ---------- - displacement_per_tau : dict[int, list[float]] - Dictionary mapping τ to list of displacements - - Returns - ------- - tuple[dict[int, float], dict[int, float]] - Tuple of (mean_displacements, std_displacements) where each is a - dictionary mapping τ to the statistic - """ - mean_displacement_per_tau = { - tau: np.mean(displacements) - for tau, displacements in displacement_per_tau.items() - } - std_displacement_per_tau = { - tau: np.std(displacements) - for tau, displacements in displacement_per_tau.items() - } - return mean_displacement_per_tau, std_displacement_per_tau - - -def compute_dynamic_range(mean_displacement_per_tau): - """ - Compute the dynamic range as the difference between the maximum - and minimum mean displacement per τ. + time_order = np.argsort(track_data["t"].values) + times = track_data["t"].values[time_order] + track_embeddings = track_data["features"].values[time_order] - Parameters: - mean_displacement_per_tau: dict with τ as key and mean displacement as value + # Compute pairwise distance matrix + distance_matrix = pairwise_distance_matrix( + track_embeddings, metric=distance_metric + ) - Returns: - float: dynamic range (max displacement - min displacement) - """ - displacements = list(mean_displacement_per_tau.values()) - return max(displacements) - min(displacements) + # Extract displacements using diagonal offsets + n_timepoints = len(times) + for time_offset in range(1, n_timepoints): + diagonal_displacements = compare_time_offset(distance_matrix, time_offset) + for i, displacement in enumerate(diagonal_displacements): + tau = int(times[i + time_offset] - times[i]) + displacement_per_tau[tau].append(displacement) -def compute_rms_per_track(embedding_dataset): - """ - Compute RMS of the time derivative of embeddings per track. - - Parameters: - embedding_dataset : xarray.Dataset - The dataset containing embeddings, timepoints, fov_name, and track_id. - - Returns: - list: A list of RMS values, one for each track. 
- """ - fov_names = embedding_dataset["fov_name"].values - track_ids = embedding_dataset["track_id"].values - timepoints = embedding_dataset["t"].values - embeddings = embedding_dataset["features"].values - - cell_identifiers = np.array( - list(zip(fov_names, track_ids)), - dtype=[("fov_name", "O"), ("track_id", "int64")], - ) - unique_cells = np.unique(cell_identifiers) - - rms_values = [] - - for cell in unique_cells: - fov_name = cell["fov_name"] - track_id = cell["track_id"] - indices = np.where((fov_names == fov_name) & (track_ids == track_id))[0] - cell_timepoints = timepoints[indices] - cell_embeddings = embeddings[indices] - - if len(cell_embeddings) < 2: - continue - - sorted_indices = np.argsort(cell_timepoints) - cell_embeddings = cell_embeddings[sorted_indices] - differences = np.diff(cell_embeddings, axis=0) - - if differences.shape[0] == 0: - continue - - norms = np.linalg.norm(differences, axis=1) - rms = np.sqrt(np.mean(norms**2)) - rms_values.append(rms) - - return rms_values - - -def calculate_normalized_euclidean_distance_cell(embedding_dataset, fov_name, track_id): - filtered_data = embedding_dataset.where( - (embedding_dataset["fov_name"] == fov_name) - & (embedding_dataset["track_id"] == track_id), - drop=True, - ) - features = filtered_data["features"].values # (sample, features) - time_points = filtered_data["t"].values # (sample,) - normalized_features = features / np.linalg.norm(features, axis=1, keepdims=True) - first_time_point_embedding = normalized_features[0].reshape(1, -1) - euclidean_distances = np.linalg.norm( - first_time_point_embedding - normalized_features, axis=1 - ) - return time_points, euclidean_distances.tolist() + return dict(displacement_per_tau) diff --git a/viscy/representation/evaluation/lca.py b/viscy/representation/evaluation/lca.py index 7c5216193..d82d324c0 100644 --- a/viscy/representation/evaluation/lca.py +++ b/viscy/representation/evaluation/lca.py @@ -9,6 +9,7 @@ from numpy.typing import NDArray from sklearn.linear_model import LogisticRegression from sklearn.metrics import classification_report +from sklearn.model_selection import train_test_split from sklearn.preprocessing import StandardScaler from torch import Tensor from xarray import DataArray @@ -19,7 +20,8 @@ def fit_logistic_regression( features: DataArray, annotations: pd.Series, - train_fovs: list[str], + train_fovs: list[str] | None = None, + train_ratio: float = 0.8, remove_background_class: bool = True, scale_features: bool = False, class_weight: Mapping | str | None = "balanced", @@ -38,8 +40,14 @@ def fit_logistic_regression( annotations : pd.Series Categorical class annotations with label values starting from 0. Must have 3 classes (when remove background is True) or 2 classes. - train_fovs : list[str] + train_fovs : list[str] | None, optional List of FOVs to use for training. The rest will be used for testing. + If None, uses stratified sampling based on train_ratio. + train_ratio : float, optional + Proportion of samples to use for training (0.0 to 1.0). + Used when train_fovs is None. + Uses stratified sampling to ensure balanced class representation. + Default is 0.8 (80% training, 20% testing). remove_background_class : bool, optional Remove background class (0), by default True scale_features : bool, optional @@ -56,23 +64,48 @@ def fit_logistic_regression( tuple[LogisticRegression, tuple[tuple[NDArray, NDArray], tuple[NDArray, NDArray]]] Trained classifier and data split [[X_train, y_train], [X_test, y_test]]. 
""" - fov_selection = features["fov_name"].isin(train_fovs) - train_selection = fov_selection - test_selection = ~fov_selection annotations = annotations.cat.codes.values.copy() + + # Handle background class removal before splitting for stratification if remove_background_class: - label_selection = annotations != 0 - train_selection &= label_selection - test_selection &= label_selection - annotations -= 1 - train_features = features.values[train_selection] - test_features = features.values[test_selection] + valid_indices = annotations != 0 + features_filtered = features[valid_indices] + annotations_filtered = annotations[valid_indices] - 1 + else: + features_filtered = features + annotations_filtered = annotations + + # Determine train FOVs + if train_fovs is None: + unique_fovs = features_filtered["fov_name"].unique() + + fov_class_dist = [] + for fov in unique_fovs: + fov_mask = features_filtered["fov_name"] == fov + fov_classes = annotations_filtered[fov_mask] + # Use majority class for stratification or class distribution + majority_class = pd.Series(fov_classes).mode()[0] + fov_class_dist.append(majority_class) + + # Split FOVs, not individual samples + train_fovs, test_fovs = train_test_split( + unique_fovs, + test_size=1 - train_ratio, + stratify=fov_class_dist, + random_state=random_state, + ) + + # Create train/test selections + train_selection = features_filtered["fov_name"].isin(train_fovs) + test_selection = ~train_selection + train_features = features_filtered.values[train_selection] + test_features = features_filtered.values[test_selection] + train_annotations = annotations_filtered[train_selection] + test_annotations = annotations_filtered[test_selection] + if scale_features: - scaler = StandardScaler() - train_features = scaler.fit_transform(train_features) - test_features = scaler.fit_transform(test_features) - train_annotations = annotations[train_selection] - test_annotations = annotations[test_selection] + train_features = StandardScaler().fit_transform(train_features) + test_features = StandardScaler().fit_transform(test_features) logistic_regression = LogisticRegression( class_weight=class_weight, random_state=random_state, diff --git a/viscy/representation/evaluation/smoothness.py b/viscy/representation/evaluation/smoothness.py new file mode 100644 index 000000000..117aa3d15 --- /dev/null +++ b/viscy/representation/evaluation/smoothness.py @@ -0,0 +1,205 @@ +from typing import Literal + +import anndata as ad +import numpy as np +import pandas as pd +from numpy.typing import NDArray +from scipy.signal import find_peaks +from scipy.stats import gaussian_kde +from sklearn.preprocessing import StandardScaler + +from viscy.representation.evaluation.clustering import ( + compare_time_offset, + pairwise_distance_matrix, + rank_nearest_neighbors, + select_block, +) + + +def compute_piece_wise_distance( + features_df: pd.DataFrame, + cross_dist: NDArray, + rank_fractions: NDArray, + groupby: list[str] = ["fov_name", "track_id"], +) -> tuple[list[list[float]], list[list[float]]]: + """ + Computing the piece-wise distance and rank difference + - Get the off diagonal per block and compute the mode + - The blocks are not square, so we need to get the off diagonal elements + - Get the 1 and 99 percentile of the off diagonal per block + + Parameters + ---------- + features_df : pd.DataFrame + DataFrame containing the features + cross_dist : NDArray + Cross-distance matrix + rank_fractions : NDArray + Rank fractions + groupby : list[str], optional + Columns to group by, by default 
["fov_name", "track_id"] + + Returns + ------- + piece_wise_dissimilarity_per_track : list + Piece-wise dissimilarity per track + piece_wise_rank_difference_per_track : list + Piece-wise rank difference per track + """ + piece_wise_dissimilarity_per_track = [] + piece_wise_rank_difference_per_track = [] + for _, subdata in features_df.groupby(groupby): + if len(subdata) > 1: + indices = subdata.index.values + single_track_dissimilarity = select_block(cross_dist, indices) + single_track_rank_fraction = select_block(rank_fractions, indices) + piece_wise_dissimilarity = compare_time_offset( + single_track_dissimilarity, time_offset=1 + ) + piece_wise_rank_difference = compare_time_offset( + single_track_rank_fraction, time_offset=1 + ) + piece_wise_dissimilarity_per_track.append(piece_wise_dissimilarity) + piece_wise_rank_difference_per_track.append(piece_wise_rank_difference) + return piece_wise_dissimilarity_per_track, piece_wise_rank_difference_per_track + + +def find_distribution_peak( + data: np.ndarray, method: Literal["histogram", "kde_robust"] = "kde_robust" +) -> float: + """Find the peak of a distribution + + Parameters + ---------- + data: np.ndarray + The data to find the peak of + method: Literal["histogram", "kde_robust"], optional + The method to use to find the peak, by default "kde_robust" + + Returns + ------- + float: The peak of the distribution (highest peak if multiple) + """ + if method == "histogram": + # Simple histogram-based peak finding + hist, bin_edges = np.histogram(data, bins=50, density=True) + bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2 + peaks, properties = find_peaks( + hist, height=np.max(hist) * 0.1 + ) # 10% of max height + if len(peaks) == 0: + return bin_centers[np.argmax(hist)] # Fallback to global max + # Return peak with highest density + peak_heights = properties["peak_heights"] + return bin_centers[peaks[np.argmax(peak_heights)]] + + elif method == "kde_robust": + # More robust KDE approach + kde = gaussian_kde(data) + x_range = np.linspace(np.min(data), np.max(data), 1000) + kde_vals = kde(x_range) + peaks, properties = find_peaks(kde_vals, height=np.max(kde_vals) * 0.1) + if len(peaks) == 0: + return x_range[np.argmax(kde_vals)] # Fallback to global max + # Return peak with highest KDE value + peak_heights = properties["peak_heights"] + return x_range[peaks[np.argmax(peak_heights)]] + + else: + raise ValueError(f"Unknown method: {method}. 
Use 'histogram' or 'kde_robust'.") + + +def compute_embeddings_smoothness( + features_ad: ad.AnnData, + distance_metric: Literal["cosine", "euclidean"] = "cosine", + verbose: bool = False, +) -> tuple[dict, dict, list[list[float]]]: + """ + Compute the smoothness statistics of embeddings + + Parameters + -------- + features_ad: adAnnData + distance_metric: Distance metric to use, by default "cosine" + + Returns: + ------- + stats: dict: Dictionary containing metrics including: + - adjacent_frame_mean: Mean of adjacent frame dissimilarity + - adjacent_frame_std: Standard deviation of adjacent frame dissimilarity + - adjacent_frame_median: Median of adjacent frame dissimilarity + - adjacent_frame_peak: Peak of adjacent frame distribution + - random_frame_mean: Mean of random sampling dissimilarity + - random_frame_std: Standard deviation of random sampling dissimilarity + - random_frame_median: Median of random sampling dissimilarity + - random_frame_peak: Peak of random sampling distribution + - smoothness_score: Score of smoothness + - dynamic_range: Difference between random and adjacent peaks + distributions: dict: Dictionary containing distributions including: + - adjacent_frame_distribution: Full distribution of adjacent frame dissimilarities + - random_frame_distribution: Full distribution of random sampling dissimilarities + piecewise_distance_per_track: list[list[float]] + Piece-wise distance per track + """ + features = features_ad.X + scaled_features = StandardScaler().fit_transform(features) + + # Compute the distance matrix + cross_dist = pairwise_distance_matrix(scaled_features, metric=distance_metric) + rank_fractions = rank_nearest_neighbors(cross_dist, normalize=True) + + # Compute piece-wise distance and rank difference + features_df = features_ad.obs.reset_index(drop=True) + piecewise_distance_per_track, _ = compute_piece_wise_distance( + features_df, cross_dist, rank_fractions + ) + + all_piecewise_distances = np.concatenate(piecewise_distance_per_track) + + # Random sampling values in the distance matrix with same size as adjacent frame measurements + n_samples = len(all_piecewise_distances) + # Avoid sampling the diagonal elements + np.random.seed(42) + i_indices = np.random.randint(0, len(cross_dist), size=n_samples) + j_indices = np.random.randint(0, len(cross_dist), size=n_samples) + + diagonal_mask = i_indices == j_indices + while diagonal_mask.any(): + j_indices[diagonal_mask] = np.random.randint( + 0, len(cross_dist), size=diagonal_mask.sum() + ) + diagonal_mask = i_indices == j_indices + sampled_values = cross_dist[i_indices, j_indices] + + # Compute the peaks of both distributions using KDE + adjacent_peak = find_distribution_peak(all_piecewise_distances, method="kde_robust") + random_peak = find_distribution_peak(sampled_values, method="kde_robust") + smoothness_score = np.mean(all_piecewise_distances) / np.mean(sampled_values) + dynamic_range = random_peak - adjacent_peak + + stats = { + "adjacent_frame_mean": float(np.mean(all_piecewise_distances)), + "adjacent_frame_std": float(np.std(all_piecewise_distances)), + "adjacent_frame_median": float(np.median(all_piecewise_distances)), + "adjacent_frame_peak": float(adjacent_peak), + # "adjacent_frame_p99": p99_piece_wise_distance, + # "adjacent_frame_p1": p1_percentile_piece_wise_distance, + # "adjacent_frame_distribution": all_piecewise_distances, + "random_frame_mean": float(np.mean(sampled_values)), + "random_frame_std": float(np.std(sampled_values)), + "random_frame_median": float(np.median(sampled_values)), 
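+        # Peaks are KDE-based modes; smoothness_score = mean(adjacent) /
+        # mean(random), so values < 1 mean adjacent frames sit closer than
+        # random pairs, and dynamic_range = random_peak - adjacent_peak.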
+ "random_frame_peak": float(random_peak), + # "random_frame_distribution": sampled_values, + "smoothness_score": float(smoothness_score), + "dynamic_range": float(dynamic_range), + } + distributions = { + "adjacent_frame_distribution": all_piecewise_distances, + "random_frame_distribution": sampled_values, + } + + if verbose: + for key, value in stats.items(): + print(f"{key}: {value}") + + return stats, distributions, piecewise_distance_per_track diff --git a/viscy/representation/multi_modal.py b/viscy/representation/multi_modal.py index 55481d434..ad4f48717 100644 --- a/viscy/representation/multi_modal.py +++ b/viscy/representation/multi_modal.py @@ -50,6 +50,7 @@ def __init__( log_batches_per_epoch: int = 8, log_samples_per_batch: int = 1, log_embeddings: bool = False, + embedding_log_frequency: int = 10, example_input_array_shape: Sequence[int] = (1, 2, 15, 256, 256), prediction_arm: Literal["source", "target"] = "source", ) -> None: @@ -61,6 +62,7 @@ def __init__( log_batches_per_epoch=log_batches_per_epoch, log_samples_per_batch=log_samples_per_batch, log_embeddings=log_embeddings, + embedding_log_frequency=embedding_log_frequency, example_input_array_shape=example_input_array_shape, ) self.example_input_array = (self.example_input_array, self.example_input_array) diff --git a/viscy/representation/vae.py b/viscy/representation/vae.py new file mode 100644 index 000000000..f34c2828d --- /dev/null +++ b/viscy/representation/vae.py @@ -0,0 +1,412 @@ +from collections.abc import Sequence +from types import SimpleNamespace +from typing import Callable, Literal + +import timm +import torch +from monai.networks.blocks import ResidualUnit, UpSample +from monai.networks.blocks.dynunet_block import get_conv_layer +from monai.networks.layers.factories import Norm +from monai.networks.nets import VarAutoEncoder +from torch import Tensor, nn + +from viscy.unet.networks.unext2 import ( + PixelToVoxelHead, + StemDepthtoChannels, +) + + +class VaeUpStage(nn.Module): + """VAE upsampling stage without skip connections.""" + + def __init__( + self, + in_channels: int, + out_channels: int, + scale_factor: int, + mode: Literal["deconv", "pixelshuffle"], + conv_blocks: int, + norm_name: Literal["batch", "instance"], + upsample_pre_conv: Literal["default"] | Callable | None, + ) -> None: + super().__init__() + spatial_dims = 2 + + if mode == "deconv": + self.upsample = get_conv_layer( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=out_channels, + stride=scale_factor, + kernel_size=scale_factor, + norm=norm_name, + is_transposed=True, + ) + # Simple conv blocks for deconv mode + self.conv = nn.Sequential( + ResidualUnit( + spatial_dims=spatial_dims, + in_channels=out_channels, + out_channels=out_channels, + norm=norm_name, + ), + nn.Conv2d(out_channels, out_channels, kernel_size=1), + ) + elif mode == "pixelshuffle": + mid_channels = in_channels // scale_factor**2 + self.upsample = UpSample( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=mid_channels, + scale_factor=scale_factor, + mode=mode, + pre_conv=upsample_pre_conv, + apply_pad_pool=False, + ) + conv_layers = [] + current_channels = mid_channels + + for i in range(conv_blocks): + block_out_channels = out_channels + conv_layers.append( + ResidualUnit( + spatial_dims=spatial_dims, + in_channels=current_channels, + out_channels=block_out_channels, + norm=norm_name, + ) + ) + current_channels = block_out_channels + + self.conv = nn.Sequential(*conv_layers) + + def forward(self, inp: Tensor) -> Tensor: + """ 
+ Parameters + ---------- + inp : Tensor + Low resolution features + + Returns + ------- + Tensor + High resolution features + """ + inp = self.upsample(inp) + return self.conv(inp) + + +class VaeEncoder(nn.Module): + """VAE encoder for microscopy data with 3D to 2D conversion.""" + + def __init__( + self, + backbone: Literal["resnet50", "convnext_tiny"] = "resnet50", + in_channels: int = 2, + in_stack_depth: int = 16, + latent_dim: int = 1024, + input_spatial_size: tuple[int, int] = (256, 256), + stem_kernel_size: tuple[int, int, int] = (2, 4, 4), + stem_stride: tuple[int, int, int] = (2, 4, 4), + drop_path_rate: float = 0.0, + pretrained: bool = True, + ): + super().__init__() + self.backbone = backbone + self.latent_dim = latent_dim + + encoder = timm.create_model( + backbone, + pretrained=pretrained, + features_only=True, + drop_path_rate=drop_path_rate, + ) + num_channels = encoder.feature_info.channels() + in_channels_encoder = num_channels[0] + out_channels_encoder = num_channels[-1] + + if "convnext" in backbone: + num_channels = encoder.feature_info.channels() + encoder.stem_0 = nn.Identity() + elif "resnet" in backbone: + encoder.conv1 = nn.Identity() + out_channels_encoder = num_channels[-1] + else: + raise ValueError( + f"Backbone {backbone} not supported. Use 'resnet50', 'convnext_tiny', or 'convnextv2_tiny'" + ) + + # Stem for 3d multichannel and to convert 3D to 2D + self.stem = StemDepthtoChannels( + in_channels=in_channels, + in_stack_depth=in_stack_depth, + in_channels_encoder=in_channels_encoder, + stem_kernel_size=stem_kernel_size, + stem_stride=stem_stride, + ) + self.encoder = encoder + self.num_channels = num_channels + self.in_channels_encoder = in_channels_encoder + self.out_channels_encoder = out_channels_encoder + + # Calculate spatial size after stem + stem_spatial_size_h = input_spatial_size[0] // stem_stride[1] + stem_spatial_size_w = input_spatial_size[1] // stem_stride[2] + + # Spatial size after backbone + backbone_reduction = 2 ** (len(num_channels) - 1) + final_spatial_size_h = stem_spatial_size_h // backbone_reduction + final_spatial_size_w = stem_spatial_size_w // backbone_reduction + + flattened_size = ( + out_channels_encoder * final_spatial_size_h * final_spatial_size_w + ) + + self.fc = nn.Linear(flattened_size, latent_dim) + self.fc_mu = nn.Linear(latent_dim, latent_dim) + self.fc_logvar = nn.Linear(latent_dim, latent_dim) + + # Store final spatial size for decoder (assuming square for simplicity) + self.encoder_spatial_size = final_spatial_size_h # Assuming square output + + def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor: + """Reparameterization trick: sample from N(mu, var) using N(0,1).""" + std = torch.exp(0.5 * logvar) + eps = torch.randn_like(std) + return mu + eps * std + + def forward(self, x: Tensor) -> SimpleNamespace: + """Forward pass returning VAE encoder outputs.""" + x = self.stem(x) + + features = self.encoder(x) + + # NOTE: taking the highest semantic features and flatten + # When features_only=False, encoder returns single tensor, not list + if isinstance(features, list): + x = features[-1] # [B, C, H, W] + else: + x = features # [B, C, H, W] + x_flat = x.flatten(1) # [B, C*H*W] - flatten from dim 1 onwards + + x_intermediate = self.fc(x_flat) + + mu = self.fc_mu(x_intermediate) + logvar = self.fc_logvar(x_intermediate) + z = self.reparameterize(mu, logvar) + + return SimpleNamespace(mean=mu, log_covariance=logvar, z=z) + + +class VaeDecoder(nn.Module): + """VAE decoder for microscopy data with 2D to 3D 
conversion.""" + + def __init__( + self, + decoder_channels: list[int] = [1024, 512, 256, 128], + latent_dim: int = 1024, + out_channels: int = 2, + out_stack_depth: int = 16, + head_expansion_ratio: int = 2, + strides: list[int] = [2, 2, 2, 1], + encoder_spatial_size: int = 16, + head_pool: bool = False, + upsample_mode: Literal["deconv", "pixelshuffle"] = "pixelshuffle", + conv_blocks: int = 2, + norm_name: Literal["batch", "instance"] = "batch", + upsample_pre_conv: Literal["default"] | Callable | None = None, + ): + super().__init__() + self.decoder_channels = decoder_channels + self.latent_dim = latent_dim + self.out_channels = out_channels + self.out_stack_depth = out_stack_depth + + self.spatial_size = encoder_spatial_size + self.spatial_channels = latent_dim // (self.spatial_size * self.spatial_size) + + self.latent_reshape = nn.Linear( + latent_dim, self.spatial_channels * self.spatial_size * self.spatial_size + ) + self.latent_proj = nn.Conv2d( + self.spatial_channels, decoder_channels[0], kernel_size=1 + ) + + # Build the decoder stages + self.decoder_stages = nn.ModuleList() + num_stages = len(self.decoder_channels) - 1 + for i in range(num_stages): + stage = VaeUpStage( + in_channels=self.decoder_channels[i], + out_channels=self.decoder_channels[i + 1], + scale_factor=strides[i], + mode=upsample_mode, + conv_blocks=conv_blocks, + norm_name=norm_name, + upsample_pre_conv=upsample_pre_conv, + ) + self.decoder_stages.append(stage) + + # Head to convert back to 3D + self.head = PixelToVoxelHead( + in_channels=decoder_channels[-1], + out_channels=self.out_channels, + out_stack_depth=self.out_stack_depth, + expansion_ratio=head_expansion_ratio, + pool=head_pool, + ) + + def forward(self, z: Tensor) -> Tensor: + """Forward pass converting latent to 3D output.""" + + batch_size = z.size(0) + + # Reshape 1D latent back to spatial format so we can reconstruct the 2.5D image + z_spatial = self.latent_reshape(z) # [batch, spatial_channels * H * W] + z_spatial = z_spatial.view( + batch_size, self.spatial_channels, self.spatial_size, self.spatial_size + ) + + # Project spatial latent to first decoder channels using 1x1 conv + x = self.latent_proj( + z_spatial + ) # [batch, decoder_channels[0], spatial_H, spatial_W] + + for stage in self.decoder_stages: + x = stage(x) + + output = self.head(x) + + return output + + +class BetaVae25D(nn.Module): + """2.5D Beta-VAE combining VaeEncoder and VaeDecoder.""" + + def __init__( + self, + backbone: Literal["resnet50", "convnext_tiny"] = "resnet50", + in_channels: int = 2, + in_stack_depth: int = 16, + out_stack_depth: int = 16, + latent_dim: int = 1024, + input_spatial_size: tuple[int, int] = (256, 256), + stem_kernel_size: tuple[int, int, int] = (2, 4, 4), + stem_stride: tuple[int, int, int] = (2, 4, 4), + drop_path_rate: float = 0.0, + decoder_stages: int = 4, + head_expansion_ratio: int = 2, + head_pool: bool = False, + upsample_mode: Literal["deconv", "pixelshuffle"] = "pixelshuffle", + conv_blocks: int = 2, + norm_name: Literal["batch", "instance"] = "batch", + upsample_pre_conv: Literal["default"] | Callable | None = None, + ): + super().__init__() + + self.encoder = VaeEncoder( + backbone=backbone, + in_channels=in_channels, + in_stack_depth=in_stack_depth, + latent_dim=latent_dim, + input_spatial_size=input_spatial_size, + stem_kernel_size=stem_kernel_size, + stem_stride=stem_stride, + drop_path_rate=drop_path_rate, + ) + + base_channels = self.encoder.num_channels[-1] + decoder_channels = [base_channels] + for i in range(decoder_stages - 1): 
+ decoder_channels.append(base_channels // (2 ** (i + 1))) + decoder_channels.append( + (out_stack_depth + 2) * in_channels * 2**2 * head_expansion_ratio + ) + + strides = [2] * decoder_stages + [1] + + self.decoder = VaeDecoder( + decoder_channels=decoder_channels, + latent_dim=latent_dim, + out_channels=in_channels, + out_stack_depth=out_stack_depth, + head_expansion_ratio=head_expansion_ratio, + head_pool=head_pool, + upsample_mode=upsample_mode, + conv_blocks=conv_blocks, + norm_name=norm_name, + upsample_pre_conv=upsample_pre_conv, + strides=strides, + encoder_spatial_size=self.encoder.encoder_spatial_size, + ) + + def forward(self, x: Tensor) -> SimpleNamespace: + """Forward pass returning VAE outputs.""" + encoder_output = self.encoder(x) + recon_x = self.decoder(encoder_output.z) + + return SimpleNamespace( + recon_x=recon_x, + mean=encoder_output.mean, + logvar=encoder_output.log_covariance, + z=encoder_output.z, + ) + + +class BetaVaeMonai(nn.Module): + """Beta-VAE with Monai architecture.""" + + def __init__( + self, + spatial_dims: int, + in_shape: Sequence[int], + out_channels: int, + latent_size: int, + channels: Sequence[int], + strides: Sequence[int] | Sequence[Sequence[int]], + kernel_size: Sequence[int] | int = 3, + up_kernel_size: Sequence[int] | int = 3, + num_res_units: int = 0, + use_sigmoid: bool = False, + norm: Literal["batch", "instance"] = "instance", + **kwargs, + ): + super().__init__() + + self.spatial_dims = spatial_dims + self.in_shape = in_shape + self.out_channels = out_channels + self.latent_size = latent_size + self.channels = channels + self.strides = strides + self.kernel_size = kernel_size + self.up_kernel_size = up_kernel_size + self.num_res_units = num_res_units + self.use_sigmoid = use_sigmoid + self.norm = norm + if self.norm not in ["batch", "instance"]: + raise ValueError("norm must be 'batch' or 'instance'") + if self.norm == "batch": + self.norm = Norm.BATCH + else: + self.norm = Norm.INSTANCE + + self.model = VarAutoEncoder( + spatial_dims=self.spatial_dims, + in_shape=self.in_shape, + out_channels=self.out_channels, + latent_size=self.latent_size, + channels=self.channels, + strides=self.strides, + kernel_size=self.kernel_size, + up_kernel_size=self.up_kernel_size, + num_res_units=self.num_res_units, + use_sigmoid=self.use_sigmoid, + norm=self.norm, + **kwargs, + ) + + def forward(self, x: Tensor) -> SimpleNamespace: + """Forward pass returning VAE encoder outputs.""" + recon_x, mu, logvar, z = self.model(x) + return SimpleNamespace(recon_x=recon_x, mean=mu, logvar=logvar, z=z) diff --git a/viscy/representation/vae_logging.py b/viscy/representation/vae_logging.py new file mode 100644 index 000000000..b06894f9a --- /dev/null +++ b/viscy/representation/vae_logging.py @@ -0,0 +1,359 @@ +from typing import Callable, Optional, Tuple + +import numpy as np +import torch +from torchvision.utils import make_grid + + +class BetaVaeLogger: + """ + Enhanced logging utilities for β-VAE training with TensorBoard. + + Provides comprehensive logging of β-VAE specific metrics, visualizations, + and latent space analysis for microscopy data. + """ + + def __init__(self, latent_dim: int = 128): + self.latent_dim = latent_dim + self.device = None + + def setup(self, device: str): + """Initialize device-dependent components.""" + self.device = device + + def log_enhanced_metrics( + self, lightning_module, model_output: dict, batch: dict, stage: str = "train" + ): + """ + Log enhanced β-VAE metrics. 
+ + Args: + lightning_module: Lightning module instance + model_output: VAE model output + batch: Input batch + stage: Training stage ("train" or "val") + """ + # Extract components + x = batch["anchor"] + + z = model_output["z"] + recon_x = model_output["recon_x"] + recon_loss = model_output["recon_loss"] + kl_loss = model_output["kl_loss"] + total_loss = model_output["total_loss"] + + # Get current β (scheduled value, not static) + beta = getattr( + lightning_module, + "_get_current_beta", + lambda: getattr(lightning_module, "beta", 1.0), + )() + + # Check for explosion and NaN/Inf + grad_diagnostics = self._compute_gradient_diagnostics(lightning_module) + nan_inf_diagnostics = self._check_nan_inf(recon_x, x, z) + + metrics = { + f"loss/{stage}/total": total_loss, + f"loss/{stage}/reconstruction": recon_loss, + f"loss/{stage}/kl": kl_loss, + f"beta/{stage}": beta, + } + + # Add diagnostic metrics + metrics.update(grad_diagnostics) + metrics.update(nan_inf_diagnostics) + + # Latent space statistics + latent_mean = torch.mean(z, dim=0) + latent_std = torch.std(z, dim=0) + + active_dims = torch.sum(torch.var(z, dim=0) > 0.01) + variances = torch.var(z, dim=0) + effective_dim = torch.sum(variances) ** 2 / torch.sum(variances**2) + + metrics.update( + { + # Consolidated latent statistics + f"latent_statistics/mean_avg/{stage}": torch.mean(latent_mean), + f"latent_statistics/std_avg/{stage}": torch.mean(latent_std), + f"latent_statistics/mean_max/{stage}": torch.max(latent_mean), + f"latent_statistics/std_max/{stage}": torch.max(latent_std), + f"latent_statistics/active_dims/{stage}": active_dims.float(), + f"latent_statistics/effective_dim/{stage}": effective_dim, + f"latent_statistics/utilization/{stage}": active_dims / self.latent_dim, + } + ) + + # Log all metrics + lightning_module.log_dict( + metrics, + on_step=False, + on_epoch=True, + logger=True, + sync_dist=True, + ) + + # Log latent dimension histograms (periodically) + if stage == "val" and lightning_module.current_epoch % 10 == 0: + self._log_latent_histograms(lightning_module, z, stage) + + def _compute_gradient_diagnostics(self, lightning_module): + """Compute gradient norms and parameter statistics for explosion detection.""" + grad_diagnostics = {} + + # Compute gradient norms for encoder and decoder + encoder_grad_norm = 0.0 + decoder_grad_norm = 0.0 + encoder_param_norm = 0.0 + decoder_param_norm = 0.0 + + for name, param in lightning_module.named_parameters(): + if param.grad is not None: + param_norm = param.grad.data.norm(2) + if "encoder" in name: + encoder_grad_norm += param_norm.item() ** 2 + elif "decoder" in name: + decoder_grad_norm += param_norm.item() ** 2 + + # Parameter magnitudes + if "encoder" in name: + encoder_param_norm += param.data.norm(2).item() ** 2 + elif "decoder" in name: + decoder_param_norm += param.data.norm(2).item() ** 2 + + grad_diagnostics.update( + { + "diagnostics/encoder_grad_norm": encoder_grad_norm**0.5, + "diagnostics/decoder_grad_norm": decoder_grad_norm**0.5, + "diagnostics/encoder_param_norm": encoder_param_norm**0.5, + "diagnostics/decoder_param_norm": decoder_param_norm**0.5, + } + ) + + return grad_diagnostics + + def _check_nan_inf(self, recon_x, x, z): + """Check for NaN/Inf values in tensors.""" + diagnostics = { + "diagnostics/recon_has_nan": torch.isnan(recon_x).any().float(), + "diagnostics/recon_has_inf": torch.isinf(recon_x).any().float(), + "diagnostics/input_has_nan": torch.isnan(x).any().float(), + "diagnostics/latent_has_nan": torch.isnan(z).any().float(), + 
"diagnostics/recon_max_val": torch.max(torch.abs(recon_x)), + "diagnostics/recon_min_val": torch.min(recon_x), + } + return diagnostics + + def _log_latent_histograms(self, lightning_module, z: torch.Tensor, stage: str): + """Log histograms of latent dimensions.""" + z_np = z.detach().cpu().numpy() + + # Log first 16 dimensions to avoid clutter + n_dims_to_log = min(16, z_np.shape[1]) + + for i in range(n_dims_to_log): + lightning_module.logger.experiment.add_histogram( + f"latent_distributions/dim_{i}_{stage}", + z_np[:, i], + lightning_module.current_epoch, + ) + + def log_latent_traversal( + self, + lightning_module, + n_dims: int = 8, + n_steps: int = 11, + range_vals: Tuple[float, float] = (-3, 3), + ): + """ + Log latent space traversal visualizations. + + Args: + lightning_module: Lightning module instance + n_dims: Number of latent dimensions to traverse + n_steps: Number of steps in traversal + range_vals: Range of values to traverse + """ + if not hasattr(lightning_module, "model"): + return + + lightning_module.model.eval() + + with torch.no_grad(): + # Sample a base latent vector + z_base = torch.randn(1, self.latent_dim, device=lightning_module.device) + + # Traverse each dimension + for dim in range(min(n_dims, self.latent_dim)): + traversal_images = [] + + for val in np.linspace(range_vals[0], range_vals[1], n_steps): + z_modified = z_base.clone() + z_modified[0, dim] = val + + # Generate reconstruction using lightning module's decoder + recon = lightning_module.decoder(z_modified) + + # Take middle z-slice for visualization + mid_z = recon.shape[2] // 2 + img_2d = recon[0, 0, mid_z].cpu() # First channel, middle z-slice + + # Normalize for visualization + img_2d = (img_2d - img_2d.min()) / ( + img_2d.max() - img_2d.min() + 1e-8 + ) + traversal_images.append(img_2d) + + # Create grid + grid = make_grid( + torch.stack(traversal_images).unsqueeze(1), + nrow=n_steps, + normalize=True, + ) + + lightning_module.logger.experiment.add_image( + f"latent_traversal/dim_{dim}", + grid, + lightning_module.current_epoch, + dataformats="CHW", + ) + + def log_latent_interpolation( + self, lightning_module, n_pairs: int = 3, n_steps: int = 11 + ): + """ + Log latent space interpolation between random pairs. 
+ + Args: + lightning_module: Lightning module instance + n_pairs: Number of interpolation pairs + n_steps: Number of interpolation steps + """ + if not hasattr(lightning_module, "model"): + return + + lightning_module.model.eval() + + with torch.no_grad(): + for pair_idx in range(n_pairs): + # Sample two random latent vectors + z1 = torch.randn(1, self.latent_dim, device=lightning_module.device) + z2 = torch.randn(1, self.latent_dim, device=lightning_module.device) + + interp_images = [] + + for alpha in np.linspace(0, 1, n_steps): + z_interp = alpha * z1 + (1 - alpha) * z2 + + # Generate reconstruction using lightning module's decoder + recon = lightning_module.decoder(z_interp) + + # Take middle z-slice for visualization + mid_z = recon.shape[2] // 2 + img_2d = recon[0, 0, mid_z].cpu() # First channel, middle z-slice + + # Normalize for visualization + img_2d = (img_2d - img_2d.min()) / ( + img_2d.max() - img_2d.min() + 1e-8 + ) + interp_images.append(img_2d) + + # Create grid + grid = make_grid( + torch.stack(interp_images).unsqueeze(1), + nrow=n_steps, + normalize=True, + ) + + lightning_module.logger.experiment.add_image( + f"latent_interpolation/pair_{pair_idx}", + grid, + lightning_module.current_epoch, + dataformats="CHW", + ) + + def log_factor_traversal_matrix( + self, lightning_module, n_dims: int = 8, n_steps: int = 7 + ): + """ + Log factor traversal matrix showing effect of each latent dimension. + + Args: + lightning_module: Lightning module instance + n_dims: Number of latent dimensions to show + n_steps: Number of steps per dimension + """ + if not hasattr(lightning_module, "model"): + return + + lightning_module.model.eval() + + with torch.no_grad(): + # Base latent vector + z_base = torch.randn(1, self.latent_dim, device=lightning_module.device) + + matrix_rows = [] + + for dim in range(min(n_dims, self.latent_dim)): + row_images = [] + + for step in range(n_steps): + val = -3 + 6 * step / (n_steps - 1) # Range [-3, 3] + z_mod = z_base.clone() + z_mod[0, dim] = val + + # Generate reconstruction using lightning module's decoder + recon = lightning_module.decoder(z_mod) + + # Take middle z-slice for visualization + mid_z = recon.shape[2] // 2 + img_2d = recon[0, 0, mid_z].cpu() # First channel, middle z-slice + + # Normalize for visualization + img_2d = (img_2d - img_2d.min()) / ( + img_2d.max() - img_2d.min() + 1e-8 + ) + row_images.append(img_2d) + + matrix_rows.append(torch.stack(row_images)) + + # Create matrix grid + all_images = torch.cat(matrix_rows, dim=0) + grid = make_grid(all_images.unsqueeze(1), nrow=n_steps, normalize=True) + + lightning_module.logger.experiment.add_image( + "factor_traversal_matrix", + grid, + lightning_module.current_epoch, + dataformats="CHW", + ) + + def log_beta_schedule( + self, lightning_module, beta_schedule: Optional[Callable] = None + ): + """ + Log β annealing schedule. 
+ + Args: + lightning_module: Lightning module instance + beta_schedule: Function that returns β value for given epoch + """ + if beta_schedule is None: + # Default β schedule + max_epochs = lightning_module.trainer.max_epochs + epoch = lightning_module.current_epoch + + if epoch < max_epochs * 0.1: # Warm up + beta = 0.1 + elif epoch < max_epochs * 0.5: # Gradual increase + beta = 0.1 + (4.0 - 0.1) * (epoch - max_epochs * 0.1) / ( + max_epochs * 0.4 + ) + else: # Final β + beta = 4.0 + else: + beta = beta_schedule(lightning_module.current_epoch) + + lightning_module.log("beta_schedule", beta) + return beta diff --git a/viscy/transforms/__init__.py b/viscy/transforms/__init__.py index 7ce9f0f8b..78a9b2487 100644 --- a/viscy/transforms/__init__.py +++ b/viscy/transforms/__init__.py @@ -23,6 +23,7 @@ from viscy.transforms._redef import ( CenterSpatialCropd, Decollated, + NormalizeIntensityd, RandAdjustContrastd, RandAffined, RandFlipd, @@ -85,6 +86,7 @@ "Decollate", "Decollated", "NormalizeSampled", + "NormalizeIntensityd", "RandAdjustContrastd", "RandAffined", "RandFlipd", diff --git a/viscy/transforms/_redef.py b/viscy/transforms/_redef.py index 696c81abc..258d3a274 100644 --- a/viscy/transforms/_redef.py +++ b/viscy/transforms/_redef.py @@ -5,6 +5,7 @@ from monai.transforms import ( CenterSpatialCropd, Decollated, + NormalizeIntensityd, RandAdjustContrastd, RandAffined, RandFlipd, @@ -187,13 +188,17 @@ def __init__( ): super().__init__(keys=keys, roi_size=roi_size, **kwargs) + class RandFlipd(RandFlipd): + def __init__( + self, + keys: Sequence[str] | str, + prob: float, + spatial_axis: Sequence[int] | int, + **kwargs, + ): + super().__init__(keys=keys, prob=prob, spatial_axis=spatial_axis, **kwargs) -class RandFlipd(RandFlipd): - def __init__( - self, - keys: Sequence[str] | str, - prob: float, - spatial_axis: Sequence[int] | int, - **kwargs, - ): - super().__init__(keys=keys, prob=prob, spatial_axis=spatial_axis, **kwargs) + +class NormalizeIntensityd(NormalizeIntensityd): + def __init__(self, keys: Sequence[str] | str, **kwargs): + super().__init__(keys=keys, **kwargs)
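
Usage notes (not part of the patch). Below is a minimal smoke-test sketch for the new BetaVae25D module, assuming the default resnet50 backbone and the (B, C, D, H, W) input layout used by the rest of viscy; the dummy tensor shape and the printed expectations are illustrative only.

# Minimal smoke test for BetaVae25D (sketch; assumes default hyperparameters).
# Note: VaeEncoder loads pretrained timm weights by default, so the first run
# downloads the resnet50 checkpoint.
import torch

from viscy.representation.vae import BetaVae25D

model = BetaVae25D(
    backbone="resnet50",
    in_channels=2,
    in_stack_depth=16,
    out_stack_depth=16,
    latent_dim=1024,
    input_spatial_size=(256, 256),
)
model.eval()

# Dummy batch: (batch, channels, depth, height, width)
x = torch.randn(1, 2, 16, 256, 256)

with torch.no_grad():
    out = model(x)

# The forward pass returns a SimpleNamespace with the reconstruction and the
# latent statistics consumed by the beta-VAE loss.
print(out.recon_x.shape)  # reconstruction, expected to follow the input layout
print(out.mean.shape, out.logvar.shape, out.z.shape)  # (1, 1024) each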
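
BetaVaeLogger is driven from a LightningModule. The sketch below shows one way it could be wired into a training step; the VaeLightningModule name, the loss computation, and the optimizer are assumptions for illustration, while the dict keys ("z", "recon_x", "recon_loss", "kl_loss", "total_loss") and the "anchor" batch key are exactly what log_enhanced_metrics reads in this diff.

# Sketch of wiring BetaVaeLogger into a LightningModule (hypothetical module;
# only the keys consumed by log_enhanced_metrics are taken from the diff).
import torch
from lightning.pytorch import LightningModule

from viscy.representation.vae import BetaVae25D
from viscy.representation.vae_logging import BetaVaeLogger


class VaeLightningModule(LightningModule):  # hypothetical example module
    def __init__(self, beta: float = 1.0, latent_dim: int = 1024):
        super().__init__()
        self.model = BetaVae25D(latent_dim=latent_dim)
        self.beta = beta
        self.vae_logger = BetaVaeLogger(latent_dim=latent_dim)

    def on_fit_start(self):
        self.vae_logger.setup(self.device)

    def training_step(self, batch, batch_idx):
        x = batch["anchor"]
        out = self.model(x)
        # Standard beta-VAE objective: reconstruction + beta-weighted KL term.
        recon_loss = torch.nn.functional.mse_loss(out.recon_x, x)
        kl_loss = -0.5 * torch.mean(
            1 + out.logvar - out.mean.pow(2) - out.logvar.exp()
        )
        total_loss = recon_loss + self.beta * kl_loss
        model_output = {
            "z": out.z,
            "recon_x": out.recon_x,
            "recon_loss": recon_loss,
            "kl_loss": kl_loss,
            "total_loss": total_loss,
        }
        self.vae_logger.log_enhanced_metrics(self, model_output, batch, stage="train")
        return total_loss

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=1e-4)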
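
For reference, the default annealing used by log_beta_schedule when no beta_schedule callable is supplied can be restated as a standalone function (same logic as above, assuming max_epochs is known up front):

def default_beta_schedule(epoch: int, max_epochs: int) -> float:
    """Default beta annealing: warm-up at 0.1, linear ramp to 4.0, then constant."""
    if epoch < max_epochs * 0.1:  # first 10% of training
        return 0.1
    if epoch < max_epochs * 0.5:  # linear ramp over the next 40%
        return 0.1 + (4.0 - 0.1) * (epoch - max_epochs * 0.1) / (max_epochs * 0.4)
    return 4.0  # final beta for the remaining 50%

# e.g. with max_epochs=100: epoch 5 -> 0.1, epoch 30 -> 2.05, epoch 80 -> 4.0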
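
The NormalizeIntensityd re-export mirrors the other MONAI dictionary transforms already wrapped in viscy.transforms._redef, so it can be referenced from YAML configs or constructed directly. A minimal usage sketch (the channel key and array shape are illustrative):

# Minimal usage of the re-exported NormalizeIntensityd wrapper.
import torch

from viscy.transforms import NormalizeIntensityd

transform = NormalizeIntensityd(keys=["Phase3D"])
sample = {"Phase3D": torch.randn(1, 15, 256, 256)}
normalized = transform(sample)  # z-scores the "Phase3D" entry of the sample dict
print(normalized["Phase3D"].mean(), normalized["Phase3D"].std())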