diff --git a/examples/DynaCLR/vcp_tutorials/README.md b/examples/DynaCLR/vcp_tutorials/README.md
new file mode 100644
index 000000000..c4e8fb763
--- /dev/null
+++ b/examples/DynaCLR/vcp_tutorials/README.md
@@ -0,0 +1,17 @@
+# Virtual Cell Platform Tutorials
+
+This directory contains tutorial notebooks for the Virtual Cell Platform,
+available as both Python scripts and Jupyter notebooks.
+
+- [Quick Start](quickstart.ipynb):
+get started with model inference in Python with an A549 cell dataset.
+
+## Development
+
+Development happens in the Python scripts,
+which are converted to Jupyter notebooks with:
+
+```sh
+# TODO: change the file name at the end to be the script to convert
+jupytext --to ipynb --update-metadata '{"jupytext":{"cell_metadata_filter":"all"}}' --update quickstart.py
+```
diff --git a/examples/DynaCLR/vcp_tutorials/quickstart.ipynb b/examples/DynaCLR/vcp_tutorials/quickstart.ipynb
new file mode 100644
index 000000000..642825e26
--- /dev/null
+++ b/examples/DynaCLR/vcp_tutorials/quickstart.ipynb
@@ -0,0 +1,743 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "36b436bf",
+   "metadata": {
+    "cell_marker": "\"\"\""
+   },
+   "source": [
+    "# Quick Start: DynaCLR\n",
+    "## Cell Dynamics Contrastive Learning of Representations\n",
+    "\n",
+    "**Estimated time to complete:** 25-30 minutes"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "c002c086",
+   "metadata": {
+    "cell_marker": "\"\"\""
+   },
+   "source": [
+    "## Learning Goals\n",
+    "\n",
+    "* Download the DynaCLR model and run it on an example dataset\n",
+    "* Visualize the learned embeddings"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "2ca8c339",
+   "metadata": {
+    "cell_marker": "\"\"\""
+   },
+   "source": [
+    "## Prerequisites\n",
+    "- Python >= 3.11"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "1818081a",
+   "metadata": {
+    "cell_marker": "\"\"\""
+   },
+   "source": [
+    "## Introduction\n",
+    "\n",
+    "### Model\n",
+    "The DynaCLR model architecture consists of three main components designed to map 3D multi-channel patches of single cells to a temporally regularized embedding space.\n",
+    "\n",
+    "### Example Dataset\n",
+    "\n",
+    "The A549 example dataset used in this quick-start guide contains\n",
+    "quantitative phase images and paired fluorescence images of a viral sensor reporter.\n",
+    "It is stored in OME-Zarr format and can be downloaded from\n",
+    "[here](https://public.czbiohub.org/comp.micro/viscy/DynaCLR_data/DENV/test/20240204_A549_DENV_ZIKV_timelapse/registered_test.zarr/).\n",
+    "\n",
+    "It includes pre-computed normalization statistics, generated with the `viscy preprocess` CLI.\n",
+    "\n",
+    "Refer to our [preprint](https://arxiv.org/abs/2410.11281) for more details\n",
+    "about how the dataset and model were generated.\n",
+    "\n",
+    "### User Data\n",
+    "\n",
+    "The DynaCLR-DENV-VS+Ph model only requires label-free (quantitative phase) and fluorescence images for inference.\n",
+    "\n",
+    "To run inference on your own data (Experimental):\n",
+    "- Convert the images into the OME-Zarr data format using iohub or other\n",
+    "[tools](https://ngff.openmicroscopy.org/tools/index.html#file-conversion)\n",
+    "- Run [pre-processing](https://github.com/mehta-lab/VisCy/blob/main/docs/usage.md#preprocessing)\n",
+    "with the `viscy preprocess` CLI\n",
+    "- Generate pseudo-tracks or tracking data with [Ultrack](https://github.com/royerlab/ultrack)\n",
+    "\n",
+    "The first step is sketched below."
+   ]
+  },
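+  {
+   "cell_type": "markdown",
+   "id": "a1f0c9e1",
+   "metadata": {
+    "cell_marker": "\"\"\""
+   },
+   "source": [
+    "The sketch below illustrates the conversion step with iohub's NGFF API.\n",
+    "It is a minimal, illustrative example: the array shape, channel names, and\n",
+    "output path are placeholder assumptions, not part of this tutorial's data."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "a1f0c9e2",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Minimal sketch (not needed for this tutorial): write a (T, C, Z, Y, X)\n",
+    "# array to an OME-Zarr HCS store with iohub. All values are placeholders.\n",
+    "# Uncomment and run after installing the dependencies in the Setup section.\n",
+    "# import numpy as np\n",
+    "# from iohub import open_ome_zarr\n",
+    "#\n",
+    "# tczyx = np.zeros((2, 2, 30, 256, 256), dtype=np.float32)\n",
+    "# with open_ome_zarr(\n",
+    "#     \"example_user_data.zarr\",  # hypothetical output path\n",
+    "#     layout=\"hcs\",\n",
+    "#     mode=\"a\",\n",
+    "#     channel_names=[\"Phase3D\", \"RFP\"],\n",
+    "# ) as dataset:\n",
+    "#     position = dataset.create_position(\"A\", \"1\", \"0\")  # row, column, FOV\n",
+    "#     position.create_image(\"0\", tczyx)\n",
+    "# Then run `viscy preprocess` on the store to compute the normalization\n",
+    "# statistics used by the data module below."
+   ]
+  },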
+  {
+   "cell_type": "markdown",
+   "id": "ad63eb9e",
+   "metadata": {
+    "cell_marker": "\"\"\"",
+    "lines_to_next_cell": 0
+   },
+   "source": [
+    "### Setup\n",
+    "\n",
+    "The commands below will install the required packages and download the example dataset and model checkpoint.\n",
+    "\n",
+    "Setup notes:\n",
+    "\n",
+    "- **Setting up Google Colab**: To run this quickstart guide using Google Colab, choose the 'T4' GPU runtime from the 'Connect' dropdown menu in the upper-right corner of this notebook for faster execution.\n",
+    "Using a GPU significantly speeds up model inference, but a CPU can also be used.\n",
+    "\n",
+    "- **Google Colab Kaggle prompt**: When running `datamodule.setup(\"predict\")`, Colab may prompt for Kaggle credentials. This is a Colab-specific behavior triggered by certain file I/O patterns and can be safely dismissed by clicking \"Cancel\"; no Kaggle account is required for this tutorial.\n",
+    "\n",
+    "- **Setting up local environment**: The commands below assume a Unix-like shell with `wget` installed. On Windows, the files can be downloaded manually from the URLs.\n",
+    "\n",
+    "### Install VisCy"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "69b3b31b",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Install VisCy with the optional dependencies for this example\n",
+    "# See the [repository](https://github.com/mehta-lab/VisCy) for more details\n",
+    "!pip install \"viscy[metrics,visual,phate]==0.4.0a3\""
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "d860546d",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Restart the kernel if running in Google Colab\n",
+    "if \"get_ipython\" in globals():\n",
+    "    session = get_ipython()  # noqa: F821\n",
+    "    if \"google.colab\" in str(session):\n",
+    "        print(\"Shutting down colab session.\")\n",
+    "        session.kernel.do_shutdown(restart=True)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "8ea0587b",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Validate installation\n",
+    "!viscy --help"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "98cdb574",
+   "metadata": {
+    "cell_marker": "\"\"\"",
+    "lines_to_next_cell": 0
+   },
+   "source": [
+    "### Download example data and model checkpoint\n",
+    "Estimated download time: 15-20 minutes"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "6dec2a9e",
+   "metadata": {
+    "lines_to_next_cell": 2
+   },
+   "outputs": [],
+   "source": [
+    "# Download the example tracks data (5-8 minutes)\n",
+    "!wget -m -np -nH --cut-dirs=6 -R \"index.html*\" \"https://public.czbiohub.org/comp.micro/viscy/DynaCLR_data/DENV/test/20240204_A549_DENV_ZIKV_timelapse/track_test.zarr/\"\n",
+    "# Download the example registered timelapse data (5-10 minutes)\n",
+    "!wget -m -np -nH --cut-dirs=6 -R \"index.html*\" \"https://public.czbiohub.org/comp.micro/viscy/DynaCLR_data/DENV/test/20240204_A549_DENV_ZIKV_timelapse/registered_test_demo_crop.zarr/\"\n",
+    "# Download the model checkpoint (3 minutes)\n",
+    "!wget -m -np -nH --cut-dirs=5 -R \"index.html*\" \"https://public.czbiohub.org/comp.micro/viscy/DynaCLR_models/DynaCLR-DENV/VS_n_Ph/epoch=94-step=2375.ckpt\"\n",
+    "# Download the annotations for the infected state\n",
+    "!wget -m -np -nH --cut-dirs=6 -R \"index.html*\" \"https://public.czbiohub.org/comp.micro/viscy/DynaCLR_data/DENV/test/20240204_A549_DENV_ZIKV_timelapse/extracted_inf_state.csv\""
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "dc74d3e7",
+   "metadata": {
+    "cell_marker": "\"\"\"",
+    "lines_to_next_cell": 0
+   },
+   "source": [
+    "## Run Model Inference\n",
+    "\n",
+    "The following code will run inference on a single field of view (FOV) of the example dataset.\n",
+    "This can also be achieved with the VisCy CLI, as sketched below."
+   ]
+  },
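+  {
+   "cell_type": "markdown",
+   "id": "b2e1d0f3",
+   "metadata": {
+    "cell_marker": "\"\"\""
+   },
+   "source": [
+    "A batch-friendly sketch of the CLI route, assuming a Lightning-style\n",
+    "`predict` subcommand and a YAML config with the same paths as below;\n",
+    "the exact flags may differ, so check the VisCy usage docs:\n",
+    "\n",
+    "```sh\n",
+    "# hypothetical invocation; config fields mirror the Python API below\n",
+    "viscy predict -c predict_config.yaml\n",
+    "```"
+   ]
+  },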
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "7c5bbe59",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from pathlib import Path  # noqa: E402\n",
+    "\n",
+    "import matplotlib.pyplot as plt  # noqa: E402\n",
+    "import numpy as np  # noqa: E402\n",
+    "import pandas as pd  # noqa: E402\n",
+    "import seaborn as sns  # noqa: E402\n",
+    "from anndata import read_zarr  # noqa: E402\n",
+    "from iohub import open_ome_zarr  # noqa: E402\n",
+    "from torchview import draw_graph  # noqa: E402\n",
+    "\n",
+    "from viscy.data.triplet import TripletDataModule  # noqa: E402\n",
+    "from viscy.representation.embedding_writer import EmbeddingWriter  # noqa: E402\n",
+    "from viscy.representation.engine import (  # noqa: E402\n",
+    "    ContrastiveEncoder,\n",
+    "    ContrastiveModule,\n",
+    ")\n",
+    "from viscy.trainer import VisCyTrainer  # noqa: E402\n",
+    "from viscy.transforms import (  # noqa: E402\n",
+    "    NormalizeSampled,\n",
+    "    ScaleIntensityRangePercentilesd,\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "f2764122",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# NOTE: Nothing needs to be changed in this code block for the example to work.\n",
+    "# If using your own data, please modify the paths below.\n",
+    "\n",
+    "# TODO: set the download root; by default the working directory is used\n",
+    "root_dir = Path(\"\")\n",
+    "# TODO: modify the path to the input dataset\n",
+    "input_data_path = root_dir / \"registered_test_demo_crop.zarr\"\n",
+    "# TODO: modify the path to the track dataset\n",
+    "tracks_path = root_dir / \"track_test.zarr\"\n",
+    "# TODO: modify the path to the model checkpoint\n",
+    "model_ckpt_path = root_dir / \"epoch=94-step=2375.ckpt\"\n",
+    "# TODO: modify the path to the extracted infected-state annotations\n",
+    "annotations_path = root_dir / \"extracted_inf_state.csv\"\n",
+    "\n",
+    "# TODO: modify the path to save the predictions\n",
+    "output_path = root_dir / \"dynaclr_prediction.zarr\""
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "86121d5a",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Default parameters for the test dataset\n",
+    "z_range = [0, 30]\n",
+    "yx_patch_size = (160, 160)\n",
+    "channels_to_display = [\"Phase3D\", \"RFP\"]  # label-free and viral sensor"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "afd7a12e",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Configure the data module for loading example images in prediction mode.\n",
+    "# See the API documentation for how to use it with a different dataset,\n",
+    "# for example, view the documentation for the TripletDataModule class by running:\n",
+    "?TripletDataModule"
+   ]
+  },
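+  {
+   "cell_type": "markdown",
+   "id": "c3d2e1a4",
+   "metadata": {
+    "cell_marker": "\"\"\""
+   },
+   "source": [
+    "Before configuring the data module, the numpy sketch below shows what the\n",
+    "two normalizations do conceptually: `NormalizeSampled` z-scores the phase\n",
+    "channel with pre-computed FOV statistics, and `ScaleIntensityRangePercentilesd`\n",
+    "maps the 50th-99th percentile range of the fluorescence channel to [0, 1].\n",
+    "This is an illustrative re-implementation, not the library code."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "c3d2e1a5",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Conceptual sketch of the two normalizations (illustrative values only)\n",
+    "rng = np.random.default_rng(0)\n",
+    "phase = rng.normal(loc=100.0, scale=5.0, size=(8, 8))\n",
+    "rfp = rng.exponential(scale=50.0, size=(8, 8))\n",
+    "\n",
+    "# NormalizeSampled: z-score with statistics stored by `viscy preprocess`\n",
+    "fov_mean, fov_std = 100.0, 5.0  # placeholder FOV statistics\n",
+    "phase_norm = (phase - fov_mean) / fov_std\n",
+    "\n",
+    "# ScaleIntensityRangePercentilesd: map [p50, p99] to [b_min=0, b_max=1]\n",
+    "p_low, p_high = np.percentile(rfp, [50, 99])\n",
+    "rfp_norm = (rfp - p_low) / (p_high - p_low)\n",
+    "\n",
+    "print(f\"phase: mean={phase_norm.mean():.2f}, std={phase_norm.std():.2f}\")\n",
+    "print(f\"rfp median maps to {(np.median(rfp) - p_low) / (p_high - p_low):.2f}\")"
+   ]
+  },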
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "bd1a8063",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Set up the data module to use the example dataset\n",
+    "datamodule = TripletDataModule(\n",
+    "    data_path=input_data_path,\n",
+    "    tracks_path=tracks_path,\n",
+    "    source_channel=channels_to_display,\n",
+    "    z_range=z_range,\n",
+    "    initial_yx_patch_size=yx_patch_size,\n",
+    "    final_yx_patch_size=yx_patch_size,\n",
+    "    # predict_cells=True,\n",
+    "    batch_size=64,  # TODO: reduce this number if you see OOM errors when running the trainer\n",
+    "    num_workers=1,\n",
+    "    normalizations=[\n",
+    "        NormalizeSampled(\n",
+    "            [\"Phase3D\"],\n",
+    "            level=\"fov_statistics\",\n",
+    "            subtrahend=\"mean\",\n",
+    "            divisor=\"std\",\n",
+    "        ),\n",
+    "        ScaleIntensityRangePercentilesd(\n",
+    "            [\"RFP\"],\n",
+    "            lower=50,\n",
+    "            upper=99,\n",
+    "            b_min=0.0,\n",
+    "            b_max=1.0,\n",
+    "        ),\n",
+    "    ],\n",
+    ")\n",
+    "datamodule.setup(\"predict\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "6d6960dc",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Load the DynaCLR model from the downloaded checkpoint\n",
+    "# See these classes for options to configure the model:\n",
+    "\n",
+    "?ContrastiveModule\n",
+    "?ContrastiveEncoder"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "886229f2",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "dynaclr_model = ContrastiveModule.load_from_checkpoint(\n",
+    "    model_ckpt_path,  # checkpoint path\n",
+    "    encoder=ContrastiveEncoder(\n",
+    "        backbone=\"convnext_tiny\",\n",
+    "        in_channels=len(channels_to_display),\n",
+    "        in_stack_depth=z_range[1] - z_range[0],\n",
+    "        stem_kernel_size=(5, 4, 4),\n",
+    "        stem_stride=(5, 4, 4),\n",
+    "        embedding_dim=768,\n",
+    "        projection_dim=32,\n",
+    "        drop_path_rate=0.0,\n",
+    "    ),\n",
+    "    example_input_array_shape=(1, 2, 30, 256, 256),\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "892b5385",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Visualize the model graph\n",
+    "model_graph = draw_graph(\n",
+    "    dynaclr_model,\n",
+    "    dynaclr_model.example_input_array,\n",
+    "    graph_name=\"DynaCLR\",\n",
+    "    roll=True,\n",
+    "    depth=3,\n",
+    "    expand_nested=True,\n",
+    ")\n",
+    "\n",
+    "model_graph.visual_graph"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "c1cc8edb",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Set up the trainer for prediction.\n",
+    "# The trainer can be further configured to better utilize the available hardware,\n",
+    "# for example using GPUs and half precision.\n",
+    "# Callbacks can also be used to customize logging and prediction writing.\n",
+    "# See the API documentation for more details:\n",
+    "?VisCyTrainer"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "4477cd99",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Initialize the trainer\n",
+    "# The prediction writer callback will save the predictions to an OME-Zarr store\n",
+    "trainer = VisCyTrainer(\n",
+    "    callbacks=[\n",
+    "        EmbeddingWriter(\n",
+    "            output_path,\n",
+    "            pca_kwargs={\"n_components\": 8},\n",
+    "            phate_kwargs={\"knn\": 5, \"decay\": 40, \"n_jobs\": -1},\n",
+    "        )\n",
+    "    ]\n",
+    ")\n",
+    "\n",
+    "# Run prediction\n",
+    "trainer.predict(model=dynaclr_model, datamodule=datamodule, return_predictions=False)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "0b3f7a24",
+   "metadata": {
+    "cell_marker": "\"\"\"",
+    "lines_to_next_cell": 0
+   },
+   "source": [
+    "## Model Outputs\n",
+    "\n",
+    "The model outputs are stored as an AnnData object. The embeddings can then be visualized with a dimensionality-reduction method (e.g., UMAP, PHATE, or PCA)."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "907fe5df",
+   "metadata": {
+    "lines_to_next_cell": 2
+   },
+   "outputs": [],
+   "source": [
+    "# Load the computed embeddings and combine them with the infection-state annotations\n",
+    "features_anndata = read_zarr(output_path)\n",
+    "annotation = pd.read_csv(annotations_path)\n",
+    "ANNOTATION_COLUMN = \"infection_state\"\n",
+    "\n",
+    "# Strip whitespace from fov_name to match features\n",
+    "annotation[\"fov_name\"] = annotation[\"fov_name\"].str.strip()\n",
+    "\n",
+    "# Merge on (fov_name, track_id, t) as these uniquely identify each cell observation\n",
+    "annotation_indexed = annotation.set_index([\"fov_name\", \"track_id\", \"t\"])\n",
+    "mi = pd.MultiIndex.from_arrays(\n",
+    "    [\n",
+    "        features_anndata.obs[\"fov_name\"],\n",
+    "        features_anndata.obs[\"track_id\"],\n",
+    "        features_anndata.obs[\"t\"],\n",
+    "    ],\n",
+    "    names=[\"fov_name\", \"track_id\", \"t\"],\n",
+    ")\n",
+    "features_anndata.obs[\"annotations_infections_state\"] = annotation_indexed.reindex(mi)[\n",
+    "    ANNOTATION_COLUMN\n",
+    "].values\n",
+    "\n",
+    "# Plot the PCA and PHATE embeddings colored by infection state\n",
+    "# Map numeric labels to readable labels for the legend\n",
+    "infection_state_labels = {0: \"Unknown\", 1: \"Uninfected\", 2: \"Infected\"}\n",
+    "\n",
+    "plot_df = pd.DataFrame(\n",
+    "    {\n",
+    "        \"PC1\": features_anndata.obsm[\"X_pca\"][:, 0],\n",
+    "        \"PC2\": features_anndata.obsm[\"X_pca\"][:, 1],\n",
+    "        \"PHATE1\": features_anndata.obsm[\"X_phate\"][:, 0],\n",
+    "        \"PHATE2\": features_anndata.obsm[\"X_phate\"][:, 1],\n",
+    "        \"infection_state\": features_anndata.obs[\"annotations_infections_state\"]\n",
+    "        .fillna(0)\n",
+    "        .map(infection_state_labels),\n",
+    "    }\n",
+    ")\n",
+    "\n",
+    "# Define color palette (colorblind-friendly: blue for uninfected, orange for infected)\n",
+    "color_palette = {\n",
+    "    \"Unknown\": \"lightgray\",  # Unlabeled\n",
+    "    \"Uninfected\": \"cornflowerblue\",\n",
+    "    \"Infected\": \"darkorange\",\n",
+    "}\n",
+    "\n",
+    "# Create figure with two subplots\n",
+    "fig, axes = plt.subplots(1, 2, figsize=(14, 6))\n",
+    "\n",
+    "# Plot PCA\n",
+    "sns.scatterplot(\n",
+    "    data=plot_df,\n",
+    "    x=\"PC1\",\n",
+    "    y=\"PC2\",\n",
+    "    hue=\"infection_state\",\n",
+    "    palette=color_palette,\n",
+    "    ax=axes[0],\n",
+    "    alpha=0.6,\n",
+    "    s=20,\n",
+    ")\n",
+    "axes[0].set_title(\"PCA Embedding\")\n",
+    "axes[0].set_xlabel(\"PC1\")\n",
+    "axes[0].set_ylabel(\"PC2\")\n",
+    "\n",
+    "# Plot PHATE\n",
+    "sns.scatterplot(\n",
+    "    data=plot_df,\n",
+    "    x=\"PHATE1\",\n",
+    "    y=\"PHATE2\",\n",
+    "    hue=\"infection_state\",\n",
+    "    palette=color_palette,\n",
+    "    ax=axes[1],\n",
+    "    alpha=0.6,\n",
+    "    s=20,\n",
+    ")\n",
+    "axes[1].set_title(\"PHATE Embedding\")\n",
+    "axes[1].set_xlabel(\"PHATE 1\")\n",
+    "axes[1].set_ylabel(\"PHATE 2\")\n",
+    "\n",
+    "plt.tight_layout()\n",
+    "plt.show()"
+   ]
+  },
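+  {
+   "cell_type": "markdown",
+   "id": "d4c3b2a6",
+   "metadata": {
+    "cell_marker": "\"\"\""
+   },
+   "source": [
+    "As a quick sanity check on the annotations themselves (independent of the\n",
+    "embeddings), we can plot the fraction of cells annotated as infected at each\n",
+    "timepoint, assuming the 0/1/2 label encoding used above."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "d4c3b2a7",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Fraction of cells annotated as infected (label 2) per timepoint;\n",
+    "# unannotated cells (NaN) count as not infected here.\n",
+    "obs = features_anndata.obs\n",
+    "infected_fraction = (obs[\"annotations_infections_state\"] == 2).groupby(obs[\"t\"]).mean()\n",
+    "ax = infected_fraction.plot(marker=\"o\", figsize=(6, 3))\n",
+    "ax.set_xlabel(\"timepoint\")\n",
+    "ax.set_ylabel(\"fraction infected\")\n",
+    "ax.set_title(\"Annotated infection over time\")\n",
+    "plt.tight_layout()\n",
+    "plt.show()"
+   ]
+  },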
+  {
+   "cell_type": "markdown",
+   "id": "5c107401",
+   "metadata": {
+    "cell_marker": "\"\"\""
+   },
+   "source": [
+    "## Visualize Images Over Time\n",
+    "Below we show phase and fluorescence images of the uninfected and infected cells over time."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "934fcb12",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# NOTE: We have chosen these tracks to be representative of the data.\n",
+    "# Feel free to open the dataset and select other tracks.\n",
+    "fov_name_mock = \"A/3/9\"\n",
+    "track_id_mock = [19]\n",
+    "fov_name_inf = \"B/4/9\"\n",
+    "track_id_inf = [42]\n",
+    "\n",
+    "\n",
+    "# Show the images over time\n",
+    "def get_patch(data, cell_centroid, patch_size):\n",
+    "    \"\"\"Extract a patch centered on the cell centroid across all channels.\n",
+    "\n",
+    "    Parameters\n",
+    "    ----------\n",
+    "    data : ndarray\n",
+    "        Image data with shape (C, Y, X) or (Y, X)\n",
+    "    cell_centroid : tuple\n",
+    "        (y, x) coordinates of the cell centroid\n",
+    "    patch_size : int\n",
+    "        Size of the square patch to extract\n",
+    "\n",
+    "    Returns\n",
+    "    -------\n",
+    "    ndarray\n",
+    "        Extracted patch with shape (C, patch_size, patch_size) or\n",
+    "        (patch_size, patch_size); may be smaller at image borders\n",
+    "    \"\"\"\n",
+    "    y_centroid, x_centroid = cell_centroid\n",
+    "    x_start = max(0, x_centroid - patch_size // 2)\n",
+    "    x_end = min(data.shape[-1], x_centroid + patch_size // 2)\n",
+    "    y_start = max(0, y_centroid - patch_size // 2)\n",
+    "    y_end = min(data.shape[-2], y_centroid + patch_size // 2)\n",
+    "\n",
+    "    if data.ndim == 3:  # CYX format\n",
+    "        patch = data[:, int(y_start) : int(y_end), int(x_start) : int(x_end)]\n",
+    "    else:  # YX format\n",
+    "        patch = data[int(y_start) : int(y_end), int(x_start) : int(x_end)]\n",
+    "    return patch\n",
+    "\n",
+    "\n",
+    "# Open the dataset\n",
+    "plate = open_ome_zarr(input_data_path)\n",
+    "uninfected_position = plate[fov_name_mock]\n",
+    "infected_position = plate[fov_name_inf]\n",
+    "\n",
+    "# Get channel indices for the channels we want to display\n",
+    "channel_names = uninfected_position.channel_names\n",
+    "channels_to_display_idx = [channel_names.index(c) for c in channels_to_display]\n",
+    "\n",
+    "# Filter the centroids of these two tracks\n",
+    "filtered_centroid_mock = features_anndata.obs[\n",
+    "    (features_anndata.obs[\"fov_name\"] == fov_name_mock)\n",
+    "    & (features_anndata.obs[\"track_id\"].isin(track_id_mock))\n",
+    "].sort_values(\"t\")\n",
+    "filtered_centroid_inf = features_anndata.obs[\n",
+    "    (features_anndata.obs[\"fov_name\"] == fov_name_inf)\n",
+    "    & (features_anndata.obs[\"track_id\"].isin(track_id_inf))\n",
+    "].sort_values(\"t\")\n",
+    "\n",
+    "# Define patch size for visualization\n",
+    "patch_size = 160\n",
+    "\n",
+    "# Extract patches for the uninfected cell over time\n",
+    "uninfected_stack = []\n",
+    "for idx, row in filtered_centroid_mock.iterrows():\n",
+    "    t = int(row[\"t\"])\n",
+    "    # Load the image data for this timepoint (CZYX format), selecting only the required channels\n",
+    "    img_data = uninfected_position.data[\n",
+    "        t, channels_to_display_idx, z_range[0] : z_range[1]\n",
+    "    ]\n",
+    "    # For Phase3D take the middle slice; for fluorescence take a max projection\n",
+    "    cyx = []\n",
+    "    for ch_idx, ch_name in enumerate(channels_to_display):\n",
+    "        if ch_name == \"Phase3D\":\n",
+    "            # Take the middle Z slice for phase\n",
+    "            mid_z = img_data.shape[1] // 2\n",
+    "            cyx.append(img_data[ch_idx, mid_z, :, :])\n",
+    "        else:\n",
+    "            # Max projection for fluorescence\n",
+    "            cyx.append(img_data[ch_idx].max(axis=0))\n",
+    "    cyx = np.array(cyx)\n",
+    "    uninfected_stack.append(get_patch(cyx, (row[\"y\"], row[\"x\"]), patch_size))\n",
+    "uninfected_stack = np.array(uninfected_stack)\n",
+    "\n",
+    "# Extract patches for the infected cell over time\n",
+    "infected_stack = []\n",
+    "for idx, row in filtered_centroid_inf.iterrows():\n",
+    "    t = int(row[\"t\"])\n",
+    "    # Load the image data for this timepoint (CZYX format), selecting only the required channels\n",
+    "    img_data = infected_position.data[\n",
+    "        t, channels_to_display_idx, z_range[0] : z_range[1]\n",
+    "    ]\n",
+    "    # For Phase3D take the middle slice; for fluorescence take a max projection\n",
+    "    cyx = []\n",
+    "    for ch_idx, ch_name in enumerate(channels_to_display):\n",
+    "        if ch_name == \"Phase3D\":\n",
+    "            # Take the middle Z slice for phase\n",
+    "            mid_z = img_data.shape[1] // 2\n",
+    "            cyx.append(img_data[ch_idx, mid_z, :, :])\n",
+    "        else:\n",
+    "            # Max projection for fluorescence\n",
+    "            cyx.append(img_data[ch_idx].max(axis=0))\n",
+    "    cyx = np.array(cyx)\n",
+    "    infected_stack.append(get_patch(cyx, (row[\"y\"], row[\"x\"]), patch_size))\n",
+    "infected_stack = np.array(infected_stack)\n",
+    "\n",
+    "# Interactive visualization\n",
+    "# This creates an interactive widget to scrub through timepoints\n",
+    "try:\n",
+    "    from ipywidgets import IntSlider, interact\n",
+    "\n",
+    "    max_t = min(len(uninfected_stack), len(infected_stack))\n",
+    "\n",
+    "    def plot_timepoint(t):\n",
+    "        \"\"\"Plot both the uninfected and infected cell at a specific timepoint\"\"\"\n",
+    "        fig, axes = plt.subplots(2, 2, figsize=(10, 10))\n",
+    "        fig.suptitle(f\"Timepoint: {t}\", fontsize=16)\n",
+    "\n",
+    "        # Plot uninfected cell\n",
+    "        for channel_idx, channel_name in enumerate(channels_to_display):\n",
+    "            ax = axes[0, channel_idx]\n",
+    "            img = uninfected_stack[t, channel_idx, :, :]\n",
+    "            ax.imshow(img, cmap=\"gray\")\n",
+    "            ax.set_title(f\"Uninfected - {channel_name}\")\n",
+    "            ax.axis(\"off\")\n",
+    "\n",
+    "        # Plot infected cell\n",
+    "        for channel_idx, channel_name in enumerate(channels_to_display):\n",
+    "            ax = axes[1, channel_idx]\n",
+    "            img = infected_stack[t, channel_idx, :, :]\n",
+    "            ax.imshow(img, cmap=\"gray\")\n",
+    "            ax.set_title(f\"Infected - {channel_name}\")\n",
+    "            ax.axis(\"off\")\n",
+    "\n",
+    "        plt.tight_layout()\n",
+    "        plt.show()\n",
+    "\n",
+    "    # Create interactive slider\n",
+    "    interact(\n",
+    "        plot_timepoint,\n",
+    "        t=IntSlider(min=0, max=max_t - 1, step=1, value=0, description=\"Timepoint:\"),\n",
+    "    )\n",
+    "\n",
+    "except ImportError:\n",
+    "    # Fall back to static plots if ipywidgets is not available\n",
+    "    print(\"ipywidgets not available, showing static plots instead\")\n",
+    "\n",
+    "    # Plot 10 equally spaced timepoints\n",
+    "    n_timepoints = 10\n",
+    "    max_t = min(len(uninfected_stack), len(infected_stack))\n",
+    "    timepoint_indices = np.linspace(0, max_t - 1, n_timepoints, dtype=int)\n",
+    "\n",
+    "    # Create figure with 2 rows (channels) x 10 columns (timepoints) for uninfected\n",
+    "    fig, axes = plt.subplots(2, n_timepoints, figsize=(20, 4))\n",
+    "    fig.suptitle(\"Uninfected Cell Over Time\", fontsize=16, y=1.02)\n",
+    "    for channel_idx, channel_name in enumerate(channels_to_display):\n",
+    "        for col_idx, t_idx in enumerate(timepoint_indices):\n",
+    "            ax = axes[channel_idx, col_idx]\n",
+    "            img = uninfected_stack[t_idx, channel_idx, :, :]\n",
+    "            ax.imshow(img, cmap=\"gray\")\n",
+    "            # Hide ticks but keep labels visible (axis(\"off\") would hide them)\n",
+    "            ax.set_xticks([])\n",
+    "            ax.set_yticks([])\n",
+    "            if channel_idx == 0:\n",
+    "                ax.set_title(f\"t={t_idx}\", fontsize=10)\n",
+    "            if col_idx == 0:\n",
+    "                ax.set_ylabel(channel_name, fontsize=12)\n",
+    "\n",
+    "    plt.tight_layout()\n",
+    "    plt.show()\n",
+    "\n",
+    "    # Create figure with 2 rows (channels) x 10 columns (timepoints) for infected\n",
+    "    fig, axes = plt.subplots(2, n_timepoints, figsize=(20, 4))\n",
+    "    fig.suptitle(\"Infected Cell Over Time\", fontsize=16, y=1.02)\n",
+    "\n",
+    "    for channel_idx, channel_name in enumerate(channels_to_display):\n",
+    "        for col_idx, t_idx in enumerate(timepoint_indices):\n",
+    "            ax = axes[channel_idx, col_idx]\n",
+    "            img = infected_stack[t_idx, channel_idx, :, :]\n",
+    "            ax.imshow(img, cmap=\"gray\")\n",
+    "            ax.set_xticks([])\n",
+    "            ax.set_yticks([])\n",
+    "            if channel_idx == 0:\n",
+    "                ax.set_title(f\"t={t_idx}\", fontsize=10)\n",
+    "            if col_idx == 0:\n",
+    "                ax.set_ylabel(channel_name, fontsize=12)\n",
+    "\n",
+    "    plt.tight_layout()\n",
+    "    plt.show()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "85de10d5",
+   "metadata": {
+    "cell_marker": "\"\"\""
+   },
+   "source": [
+    "## Contact Information\n",
+    "For issues with this notebook, please contact eduardo.hirata@czbiohub.org.\n",
+    "\n",
+    "## Responsible Use\n",
+    "\n",
+    "We are committed to advancing the responsible development and use of artificial intelligence.\n",
+    "Please follow our [Acceptable Use Policy](https://virtualcellmodels.cziscience.com/acceptable-use-policy) when engaging with our services.\n",
+    "\n",
+    "Should you have any security or privacy issues or questions related to the services,\n",
+    "please reach out to our team at [security@chanzuckerberg.com](mailto:security@chanzuckerberg.com) or [privacy@chanzuckerberg.com](mailto:privacy@chanzuckerberg.com) respectively."
+   ]
+  }
+ ],
+ "metadata": {
+  "jupytext": {
+   "cell_metadata_filter": "all",
+   "main_language": "python"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/examples/DynaCLR/vcp_tutorials/quickstart.py b/examples/DynaCLR/vcp_tutorials/quickstart.py
new file mode 100644
index 000000000..8b8f53ce0
--- /dev/null
+++ b/examples/DynaCLR/vcp_tutorials/quickstart.py
@@ -0,0 +1,559 @@
+# %% [markdown]
+"""
+# Quick Start: DynaCLR
+## Cell Dynamics Contrastive Learning of Representations
+
+**Estimated time to complete:** 25-30 minutes
+"""
+
+# %% [markdown]
+"""
+## Learning Goals
+
+* Download the DynaCLR model and run it on an example dataset
+* Visualize the learned embeddings
+"""
+
+# %% [markdown]
+"""
+## Prerequisites
+- Python >= 3.11
+"""
+
+# %% [markdown]
+"""
+## Introduction
+
+### Model
+The DynaCLR model architecture consists of three main components designed to map 3D multi-channel patches of single cells to a temporally regularized embedding space.
+
+### Example Dataset
+
+The A549 example dataset used in this quick-start guide contains
+quantitative phase images and paired fluorescence images of a viral sensor reporter.
+It is stored in OME-Zarr format and can be downloaded from
+[here](https://public.czbiohub.org/comp.micro/viscy/DynaCLR_data/DENV/test/20240204_A549_DENV_ZIKV_timelapse/registered_test.zarr/).
+
+It includes pre-computed normalization statistics, generated with the `viscy preprocess` CLI.
+
+Refer to our [preprint](https://arxiv.org/abs/2410.11281) for more details
+about how the dataset and model were generated.
+
+### User Data
+
+The DynaCLR-DENV-VS+Ph model only requires label-free (quantitative phase) and fluorescence images for inference.
+
+To run inference on your own data (Experimental):
+- Convert the images into the OME-Zarr data format using iohub or other
+[tools](https://ngff.openmicroscopy.org/tools/index.html#file-conversion)
+- Run [pre-processing](https://github.com/mehta-lab/VisCy/blob/main/docs/usage.md#preprocessing)
+with the `viscy preprocess` CLI
+- Generate pseudo-tracks or tracking data with [Ultrack](https://github.com/royerlab/ultrack)
+
+The first step is sketched below.
+"""
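+
+# %% [markdown]
+"""
+The sketch below illustrates the conversion step with iohub's NGFF API.
+It is a minimal, illustrative example: the array shape, channel names, and
+output path are placeholder assumptions, not part of this tutorial's data.
+"""
+
+# %%
+# Minimal sketch (not needed for this tutorial): write a (T, C, Z, Y, X)
+# array to an OME-Zarr HCS store with iohub. All values are placeholders.
+# Uncomment and run after installing the dependencies in the Setup section.
+# import numpy as np
+# from iohub import open_ome_zarr
+#
+# tczyx = np.zeros((2, 2, 30, 256, 256), dtype=np.float32)
+# with open_ome_zarr(
+#     "example_user_data.zarr",  # hypothetical output path
+#     layout="hcs",
+#     mode="a",
+#     channel_names=["Phase3D", "RFP"],
+# ) as dataset:
+#     position = dataset.create_position("A", "1", "0")  # row, column, FOV
+#     position.create_image("0", tczyx)
+# Then run `viscy preprocess` on the store to compute the normalization
+# statistics used by the data module below.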
+
+# %% [markdown]
+"""
+### Setup
+
+The commands below will install the required packages and download the example dataset and model checkpoint.
+
+Setup notes:
+
+- **Setting up Google Colab**: To run this quickstart guide using Google Colab, choose the 'T4' GPU runtime from the 'Connect' dropdown menu in the upper-right corner of this notebook for faster execution.
+Using a GPU significantly speeds up model inference, but a CPU can also be used.
+
+- **Google Colab Kaggle prompt**: When running `datamodule.setup("predict")`, Colab may prompt for Kaggle credentials. This is a Colab-specific behavior triggered by certain file I/O patterns and can be safely dismissed by clicking "Cancel"; no Kaggle account is required for this tutorial.
+
+- **Setting up local environment**: The commands below assume a Unix-like shell with `wget` installed. On Windows, the files can be downloaded manually from the URLs.
+
+### Install VisCy
+"""
+# %%
+# Install VisCy with the optional dependencies for this example
+# See the [repository](https://github.com/mehta-lab/VisCy) for more details
+# !pip install "viscy[metrics,visual,phate]==0.4.0a3"
+
+# %%
+# Restart the kernel if running in Google Colab
+if "get_ipython" in globals():
+    session = get_ipython()  # noqa: F821
+    if "google.colab" in str(session):
+        print("Shutting down colab session.")
+        session.kernel.do_shutdown(restart=True)
+
+# %%
+# Validate installation
+# !viscy --help
+
+# %% [markdown]
+"""
+### Download example data and model checkpoint
+Estimated download time: 15-20 minutes
+"""
+# %%
+# Download the example tracks data (5-8 minutes)
+# !wget -m -np -nH --cut-dirs=6 -R "index.html*" "https://public.czbiohub.org/comp.micro/viscy/DynaCLR_data/DENV/test/20240204_A549_DENV_ZIKV_timelapse/track_test.zarr/"
+# Download the example registered timelapse data (5-10 minutes)
+# !wget -m -np -nH --cut-dirs=6 -R "index.html*" "https://public.czbiohub.org/comp.micro/viscy/DynaCLR_data/DENV/test/20240204_A549_DENV_ZIKV_timelapse/registered_test_demo_crop.zarr/"
+# Download the model checkpoint (3 minutes)
+# !wget -m -np -nH --cut-dirs=5 -R "index.html*" "https://public.czbiohub.org/comp.micro/viscy/DynaCLR_models/DynaCLR-DENV/VS_n_Ph/epoch=94-step=2375.ckpt"
+# Download the annotations for the infected state
+# !wget -m -np -nH --cut-dirs=6 -R "index.html*" "https://public.czbiohub.org/comp.micro/viscy/DynaCLR_data/DENV/test/20240204_A549_DENV_ZIKV_timelapse/extracted_inf_state.csv"
+
+
+# %% [markdown]
+"""
+## Run Model Inference
+
+The following code will run inference on a single field of view (FOV) of the example dataset.
+This can also be achieved with the VisCy CLI, as sketched below.
+"""
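+
+# %% [markdown]
+"""
+A batch-friendly sketch of the CLI route, assuming a Lightning-style
+`predict` subcommand and a YAML config with the same paths as below;
+the exact flags may differ, so check the VisCy usage docs:
+
+```sh
+# hypothetical invocation; config fields mirror the Python API below
+viscy predict -c predict_config.yaml
+```
+"""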
+# %%
+from pathlib import Path  # noqa: E402
+
+import matplotlib.pyplot as plt  # noqa: E402
+import numpy as np  # noqa: E402
+import pandas as pd  # noqa: E402
+import seaborn as sns  # noqa: E402
+from anndata import read_zarr  # noqa: E402
+from iohub import open_ome_zarr  # noqa: E402
+from torchview import draw_graph  # noqa: E402
+
+from viscy.data.triplet import TripletDataModule  # noqa: E402
+from viscy.representation.embedding_writer import EmbeddingWriter  # noqa: E402
+from viscy.representation.engine import (  # noqa: E402
+    ContrastiveEncoder,
+    ContrastiveModule,
+)
+from viscy.trainer import VisCyTrainer  # noqa: E402
+from viscy.transforms import (  # noqa: E402
+    NormalizeSampled,
+    ScaleIntensityRangePercentilesd,
+)
+
+# %%
+# NOTE: Nothing needs to be changed in this code block for the example to work.
+# If using your own data, please modify the paths below.
+
+# TODO: set the download root; by default the working directory is used
+root_dir = Path("")
+# TODO: modify the path to the input dataset
+input_data_path = root_dir / "registered_test_demo_crop.zarr"
+# TODO: modify the path to the track dataset
+tracks_path = root_dir / "track_test.zarr"
+# TODO: modify the path to the model checkpoint
+model_ckpt_path = root_dir / "epoch=94-step=2375.ckpt"
+# TODO: modify the path to the extracted infected-state annotations
+annotations_path = root_dir / "extracted_inf_state.csv"
+
+# TODO: modify the path to save the predictions
+output_path = root_dir / "dynaclr_prediction.zarr"
+
+# %%
+# Default parameters for the test dataset
+z_range = [0, 30]
+yx_patch_size = (160, 160)
+channels_to_display = ["Phase3D", "RFP"]  # label-free and viral sensor
+
+# %%
+# Configure the data module for loading example images in prediction mode.
+# See the API documentation for how to use it with a different dataset,
+# for example, view the documentation for the TripletDataModule class by running:
+# ?TripletDataModule
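+
+# %% [markdown]
+"""
+Before configuring the data module, the numpy sketch below shows what the
+two normalizations do conceptually: `NormalizeSampled` z-scores the phase
+channel with pre-computed FOV statistics, and `ScaleIntensityRangePercentilesd`
+maps the 50th-99th percentile range of the fluorescence channel to [0, 1].
+This is an illustrative re-implementation, not the library code.
+"""
+
+# %%
+# Conceptual sketch of the two normalizations (illustrative values only)
+rng = np.random.default_rng(0)
+phase = rng.normal(loc=100.0, scale=5.0, size=(8, 8))
+rfp = rng.exponential(scale=50.0, size=(8, 8))
+
+# NormalizeSampled: z-score with statistics stored by `viscy preprocess`
+fov_mean, fov_std = 100.0, 5.0  # placeholder FOV statistics
+phase_norm = (phase - fov_mean) / fov_std
+
+# ScaleIntensityRangePercentilesd: map [p50, p99] to [b_min=0, b_max=1]
+p_low, p_high = np.percentile(rfp, [50, 99])
+rfp_norm = (rfp - p_low) / (p_high - p_low)
+
+print(f"phase: mean={phase_norm.mean():.2f}, std={phase_norm.std():.2f}")
+print(f"rfp median maps to {(np.median(rfp) - p_low) / (p_high - p_low):.2f}")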
+
+# %%
+# Set up the data module to use the example dataset
+datamodule = TripletDataModule(
+    data_path=input_data_path,
+    tracks_path=tracks_path,
+    source_channel=channels_to_display,
+    z_range=z_range,
+    initial_yx_patch_size=yx_patch_size,
+    final_yx_patch_size=yx_patch_size,
+    # predict_cells=True,
+    batch_size=64,  # TODO: reduce this number if you see OOM errors when running the trainer
+    num_workers=1,
+    normalizations=[
+        NormalizeSampled(
+            ["Phase3D"],
+            level="fov_statistics",
+            subtrahend="mean",
+            divisor="std",
+        ),
+        ScaleIntensityRangePercentilesd(
+            ["RFP"],
+            lower=50,
+            upper=99,
+            b_min=0.0,
+            b_max=1.0,
+        ),
+    ],
+)
+datamodule.setup("predict")
+
+# %%
+# Load the DynaCLR model from the downloaded checkpoint
+# See these classes for options to configure the model:
+
+# ?ContrastiveModule
+# ?ContrastiveEncoder
+
+# %%
+dynaclr_model = ContrastiveModule.load_from_checkpoint(
+    model_ckpt_path,  # checkpoint path
+    encoder=ContrastiveEncoder(
+        backbone="convnext_tiny",
+        in_channels=len(channels_to_display),
+        in_stack_depth=z_range[1] - z_range[0],
+        stem_kernel_size=(5, 4, 4),
+        stem_stride=(5, 4, 4),
+        embedding_dim=768,
+        projection_dim=32,
+        drop_path_rate=0.0,
+    ),
+    example_input_array_shape=(1, 2, 30, 256, 256),
+)
+
+# %%
+# Visualize the model graph
+model_graph = draw_graph(
+    dynaclr_model,
+    dynaclr_model.example_input_array,
+    graph_name="DynaCLR",
+    roll=True,
+    depth=3,
+    expand_nested=True,
+)
+
+model_graph.visual_graph
+
+# %%
+# Set up the trainer for prediction.
+# The trainer can be further configured to better utilize the available hardware,
+# for example using GPUs and half precision.
+# Callbacks can also be used to customize logging and prediction writing.
+# See the API documentation for more details:
+# ?VisCyTrainer
+
+# %%
+# Initialize the trainer
+# The prediction writer callback will save the predictions to an OME-Zarr store
+trainer = VisCyTrainer(
+    callbacks=[
+        EmbeddingWriter(
+            output_path,
+            pca_kwargs={"n_components": 8},
+            phate_kwargs={"knn": 5, "decay": 40, "n_jobs": -1},
+        )
+    ]
+)
+
+# Run prediction
+trainer.predict(model=dynaclr_model, datamodule=datamodule, return_predictions=False)
+
+# %% [markdown]
+"""
+## Model Outputs
+
+The model outputs are stored as an AnnData object. The embeddings can then be visualized with a dimensionality-reduction method (e.g., UMAP, PHATE, or PCA).
+"""
+# %%
+# Load the computed embeddings and combine them with the infection-state annotations
+features_anndata = read_zarr(output_path)
+annotation = pd.read_csv(annotations_path)
+ANNOTATION_COLUMN = "infection_state"
+
+# Strip whitespace from fov_name to match features
+annotation["fov_name"] = annotation["fov_name"].str.strip()
+
+# Merge on (fov_name, track_id, t) as these uniquely identify each cell observation
+annotation_indexed = annotation.set_index(["fov_name", "track_id", "t"])
+mi = pd.MultiIndex.from_arrays(
+    [
+        features_anndata.obs["fov_name"],
+        features_anndata.obs["track_id"],
+        features_anndata.obs["t"],
+    ],
+    names=["fov_name", "track_id", "t"],
+)
+features_anndata.obs["annotations_infections_state"] = annotation_indexed.reindex(mi)[
+    ANNOTATION_COLUMN
+].values
+
+# Plot the PCA and PHATE embeddings colored by infection state
+# Map numeric labels to readable labels for the legend
+infection_state_labels = {0: "Unknown", 1: "Uninfected", 2: "Infected"}
+
+plot_df = pd.DataFrame(
+    {
+        "PC1": features_anndata.obsm["X_pca"][:, 0],
+        "PC2": features_anndata.obsm["X_pca"][:, 1],
+        "PHATE1": features_anndata.obsm["X_phate"][:, 0],
+        "PHATE2": features_anndata.obsm["X_phate"][:, 1],
+        "infection_state": features_anndata.obs["annotations_infections_state"]
+        .fillna(0)
+        .map(infection_state_labels),
+    }
+)
+
+# Define color palette (colorblind-friendly: blue for uninfected, orange for infected)
+color_palette = {
+    "Unknown": "lightgray",  # Unlabeled
+    "Uninfected": "cornflowerblue",
+    "Infected": "darkorange",
+}
+
+# Create figure with two subplots
+fig, axes = plt.subplots(1, 2, figsize=(14, 6))
+
+# Plot PCA
+sns.scatterplot(
+    data=plot_df,
+    x="PC1",
+    y="PC2",
+    hue="infection_state",
+    palette=color_palette,
+    ax=axes[0],
+    alpha=0.6,
+    s=20,
+)
+axes[0].set_title("PCA Embedding")
+axes[0].set_xlabel("PC1")
+axes[0].set_ylabel("PC2")
+
+# Plot PHATE
+sns.scatterplot(
+    data=plot_df,
+    x="PHATE1",
+    y="PHATE2",
+    hue="infection_state",
+    palette=color_palette,
+    ax=axes[1],
+    alpha=0.6,
+    s=20,
+)
+axes[1].set_title("PHATE Embedding")
+axes[1].set_xlabel("PHATE 1")
+axes[1].set_ylabel("PHATE 2")
+
+plt.tight_layout()
+plt.show()
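+
+# %% [markdown]
+"""
+As a quick sanity check on the annotations themselves (independent of the
+embeddings), we can plot the fraction of cells annotated as infected at each
+timepoint, assuming the 0/1/2 label encoding used above.
+"""
+
+# %%
+# Fraction of cells annotated as infected (label 2) per timepoint;
+# unannotated cells (NaN) count as not infected here.
+obs = features_anndata.obs
+infected_fraction = (obs["annotations_infections_state"] == 2).groupby(obs["t"]).mean()
+ax = infected_fraction.plot(marker="o", figsize=(6, 3))
+ax.set_xlabel("timepoint")
+ax.set_ylabel("fraction infected")
+ax.set_title("Annotated infection over time")
+plt.tight_layout()
+plt.show()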
+
+
+# %% [markdown]
+"""
+## Visualize Images Over Time
+Below we show phase and fluorescence images of the uninfected and infected cells over time.
+"""
+
+# %%
+# NOTE: We have chosen these tracks to be representative of the data.
+# Feel free to open the dataset and select other tracks.
+fov_name_mock = "A/3/9"
+track_id_mock = [19]
+fov_name_inf = "B/4/9"
+track_id_inf = [42]
+
+
+# Show the images over time
+def get_patch(data, cell_centroid, patch_size):
+    """Extract a patch centered on the cell centroid across all channels.
+
+    Parameters
+    ----------
+    data : ndarray
+        Image data with shape (C, Y, X) or (Y, X)
+    cell_centroid : tuple
+        (y, x) coordinates of the cell centroid
+    patch_size : int
+        Size of the square patch to extract
+
+    Returns
+    -------
+    ndarray
+        Extracted patch with shape (C, patch_size, patch_size) or
+        (patch_size, patch_size); may be smaller at image borders
+    """
+    y_centroid, x_centroid = cell_centroid
+    x_start = max(0, x_centroid - patch_size // 2)
+    x_end = min(data.shape[-1], x_centroid + patch_size // 2)
+    y_start = max(0, y_centroid - patch_size // 2)
+    y_end = min(data.shape[-2], y_centroid + patch_size // 2)
+
+    if data.ndim == 3:  # CYX format
+        patch = data[:, int(y_start) : int(y_end), int(x_start) : int(x_end)]
+    else:  # YX format
+        patch = data[int(y_start) : int(y_end), int(x_start) : int(x_end)]
+    return patch
+
+
+# Open the dataset
+plate = open_ome_zarr(input_data_path)
+uninfected_position = plate[fov_name_mock]
+infected_position = plate[fov_name_inf]
+
+# Get channel indices for the channels we want to display
+channel_names = uninfected_position.channel_names
+channels_to_display_idx = [channel_names.index(c) for c in channels_to_display]
+
+# Filter the centroids of these two tracks
+filtered_centroid_mock = features_anndata.obs[
+    (features_anndata.obs["fov_name"] == fov_name_mock)
+    & (features_anndata.obs["track_id"].isin(track_id_mock))
+].sort_values("t")
+filtered_centroid_inf = features_anndata.obs[
+    (features_anndata.obs["fov_name"] == fov_name_inf)
+    & (features_anndata.obs["track_id"].isin(track_id_inf))
+].sort_values("t")
+
+# Define patch size for visualization
+patch_size = 160
+
+# Extract patches for the uninfected cell over time
+uninfected_stack = []
+for idx, row in filtered_centroid_mock.iterrows():
+    t = int(row["t"])
+    # Load the image data for this timepoint (CZYX format), selecting only the required channels
+    img_data = uninfected_position.data[
+        t, channels_to_display_idx, z_range[0] : z_range[1]
+    ]
+    # For Phase3D take the middle slice; for fluorescence take a max projection
+    cyx = []
+    for ch_idx, ch_name in enumerate(channels_to_display):
+        if ch_name == "Phase3D":
+            # Take the middle Z slice for phase
+            mid_z = img_data.shape[1] // 2
+            cyx.append(img_data[ch_idx, mid_z, :, :])
+        else:
+            # Max projection for fluorescence
+            cyx.append(img_data[ch_idx].max(axis=0))
+    cyx = np.array(cyx)
+    uninfected_stack.append(get_patch(cyx, (row["y"], row["x"]), patch_size))
+uninfected_stack = np.array(uninfected_stack)
+
+# Extract patches for the infected cell over time
+infected_stack = []
+for idx, row in filtered_centroid_inf.iterrows():
+    t = int(row["t"])
+    # Load the image data for this timepoint (CZYX format), selecting only the required channels
+    img_data = infected_position.data[
+        t, channels_to_display_idx, z_range[0] : z_range[1]
+    ]
+    # For Phase3D take the middle slice; for fluorescence take a max projection
+    cyx = []
+    for ch_idx, ch_name in enumerate(channels_to_display):
+        if ch_name == "Phase3D":
+            # Take the middle Z slice for phase
+            mid_z = img_data.shape[1] // 2
+            cyx.append(img_data[ch_idx, mid_z, :, :])
+        else:
+            # Max projection for fluorescence
+            cyx.append(img_data[ch_idx].max(axis=0))
+    cyx = np.array(cyx)
+    infected_stack.append(get_patch(cyx, (row["y"], row["x"]), patch_size))
+infected_stack = np.array(infected_stack)
+
+# Interactive visualization
+# This creates an interactive widget to scrub through timepoints
+try:
+    from ipywidgets import IntSlider, interact
+
+    max_t = min(len(uninfected_stack), len(infected_stack))
+
+    def plot_timepoint(t):
+        """Plot both the uninfected and infected cell at a specific timepoint"""
+        fig, axes = plt.subplots(2, 2, figsize=(10, 10))
+        fig.suptitle(f"Timepoint: {t}", fontsize=16)
+
+        # Plot uninfected cell
+        for channel_idx, channel_name in enumerate(channels_to_display):
+            ax = axes[0, channel_idx]
+            img = uninfected_stack[t, channel_idx, :, :]
+            ax.imshow(img, cmap="gray")
+            ax.set_title(f"Uninfected - {channel_name}")
+            ax.axis("off")
+
+        # Plot infected cell
+        for channel_idx, channel_name in enumerate(channels_to_display):
+            ax = axes[1, channel_idx]
+            img = infected_stack[t, channel_idx, :, :]
+            ax.imshow(img, cmap="gray")
+            ax.set_title(f"Infected - {channel_name}")
+            ax.axis("off")
+
+        plt.tight_layout()
+        plt.show()
+
+    # Create interactive slider
+    interact(
+        plot_timepoint,
+        t=IntSlider(min=0, max=max_t - 1, step=1, value=0, description="Timepoint:"),
+    )
+
+except ImportError:
+    # Fall back to static plots if ipywidgets is not available
+    print("ipywidgets not available, showing static plots instead")
+
+    # Plot 10 equally spaced timepoints
+    n_timepoints = 10
+    max_t = min(len(uninfected_stack), len(infected_stack))
+    timepoint_indices = np.linspace(0, max_t - 1, n_timepoints, dtype=int)
+
+    # Create figure with 2 rows (channels) x 10 columns (timepoints) for uninfected
+    fig, axes = plt.subplots(2, n_timepoints, figsize=(20, 4))
+    fig.suptitle("Uninfected Cell Over Time", fontsize=16, y=1.02)
+    for channel_idx, channel_name in enumerate(channels_to_display):
+        for col_idx, t_idx in enumerate(timepoint_indices):
+            ax = axes[channel_idx, col_idx]
+            img = uninfected_stack[t_idx, channel_idx, :, :]
+            ax.imshow(img, cmap="gray")
+            # Hide ticks but keep labels visible (axis("off") would hide them)
+            ax.set_xticks([])
+            ax.set_yticks([])
+            if channel_idx == 0:
+                ax.set_title(f"t={t_idx}", fontsize=10)
+            if col_idx == 0:
+                ax.set_ylabel(channel_name, fontsize=12)
+
+    plt.tight_layout()
+    plt.show()
+
+    # Create figure with 2 rows (channels) x 10 columns (timepoints) for infected
+    fig, axes = plt.subplots(2, n_timepoints, figsize=(20, 4))
+    fig.suptitle("Infected Cell Over Time", fontsize=16, y=1.02)
+
+    for channel_idx, channel_name in enumerate(channels_to_display):
+        for col_idx, t_idx in enumerate(timepoint_indices):
+            ax = axes[channel_idx, col_idx]
+            img = infected_stack[t_idx, channel_idx, :, :]
+            ax.imshow(img, cmap="gray")
+            ax.set_xticks([])
+            ax.set_yticks([])
+            if channel_idx == 0:
+                ax.set_title(f"t={t_idx}", fontsize=10)
+            if col_idx == 0:
+                ax.set_ylabel(channel_name, fontsize=12)
+
+    plt.tight_layout()
+    plt.show()
+
+# %% [markdown]
+"""
+## Contact Information
+For issues with this notebook, please contact eduardo.hirata@czbiohub.org.
+
+## Responsible Use
+
+We are committed to advancing the responsible development and use of artificial intelligence.
+Please follow our [Acceptable Use Policy](https://virtualcellmodels.cziscience.com/acceptable-use-policy) when engaging with our services.
+
+Should you have any security or privacy issues or questions related to the services,
+please reach out to our team at [security@chanzuckerberg.com](mailto:security@chanzuckerberg.com) or [privacy@chanzuckerberg.com](mailto:privacy@chanzuckerberg.com) respectively.
+"""