Commit e00448b

Authored by edyoshikun, ziw-liu, and Copilot

Implementation of Beta VAE for benchmarking (#273)
* simulate different embeddings
* update the MSD calculation to re-use cdist functions in the repo
* adding a test for the MSD
* removing unused MSD functions
* renaming msd to compute_track_displacement
* default to cosine distance
* adding the gradient attribution video
* extend to training ratios
* demo beta_vae 2.5D
* improving the logging for readability and dropping pythae base classes
* condense the logging to have fewer tabs
* fix disentanglement metrics
* fixing beta warmup bug
* renaming to loss
* updating architecture to flatten vs spatial VAE with convs
* changing to use MSE with mean reduction and normalizing the KL loss by batch size
* Optuna proof of concept
* add normalized sampling into the transforms so we can use it with MONAI's VAE
* update loss debugging code
* adding sync for disentanglement metrics
* adding the dataloader for the RPE1 dataset and plotting utils
* clean up the VAE and add the MONAI VAE to Lightning; adding configs
* add saving hyperparameters
* fix hyperparameter logging
* add embedding logging to the CLIP version
* test and plot of MONAI VAE
* handle monai_vae 2D
* redefining rotation augmentations
* adding optional scaling to PHATE
* adding alias and output 2D
* normalizing also by the latent dim and swapping to FP32 for the forward pass to avoid overflow with log and exp
* update test for magnitudes
* expose the normalization for the VAE
* add SAM 2 test
* refactor smoothness metrics
* revert to normalizing KL w.r.t. batch size and removing the beta min value
* commit DTW embeddings with SAM
* added a clamp to logvar; switch to MSE loss with sum reduction like the original formulation
* remove unnecessary VAE logging losses
* add a way to handle when using 'mean' reduction for proper scaling
* adding optional config for the middle slice index when computing SAM2 and DINOv3 embeddings
* converting latent stats active_dimensions parameter to float to remove warning
* ruff
* removing the Optuna config
* NumPy docstring
* fix compute smoothness script
* archiving old scripts
* re-organize the PC features scripts
* embeddings for phase
* add smoothness (mean random vs adjacent frame) to the CSV
* archiving old Beta VAE code
* ruff
* fix format
* fix typo
* remove the archived unnecessary files
* remove the test run archived file
* adding NormalizeIntensity
* fixing the vae_logging typing and removing PC plotting from here
* fixing the compute_embedding_smoothness docstring
* simplify the distance metrics and remove deprecated functions and scripts
* remove deprecated functions from clustering.py
* add timelapse to the grad_attr.py script
* refactoring the BetaVAEModule: removing the hyperparameter logging, adding nn.Module as input for typing purposes, and removing the FP32 custom forward
* remove the Optuna dependency
* deleting old MSD test
* ruff format
* fix to explicitly stratify at the FOV level
* adding reference to the dataset for RPE1
* fix pyproject.toml dev
* format and lint
* restore the no-augmentation flag effect
* format tests
* rename the sam2 file
* removing unused arguments for logging embeddings
* removing duplication in the LCA
* remove disentanglement metrics
* vectorized the anchor filtering for the CellDivisionTriplet dataset
* map the channels to the RPE dataset convention
* fix logistic regression standardization
* update RPE classifier to include mitosis
* ruff
* remove unused logging
* datamodule agnostic
* cleaning up duplicated code in the benchmarking
* cleanup VAE
* keeping it consistent and using residual units
* fix typings in BetaVAEMonai
* update smoothness to handle AnnData
* update clustering method and add test
* pre-commit
* Update viscy/data/cell_division_triplet.py (Co-authored-by: Copilot)
* Update applications/benchmarking/DynaCLR/SAM2/sam2_visualizations.py (Co-authored-by: Copilot)
* Update applications/pseudotime_analysis/evaluation/compare_dtw_embeddings_sam2.py (Co-authored-by: Copilot)
* Update applications/contrastive_phenotyping/evaluation/smoothness/compute_smoothness.py (Co-authored-by: Copilot)
* Update applications/contrastive_phenotyping/evaluation/archive/ALFI_MSD_v2.py (Co-authored-by: Copilot)
* raise ValueError in the find_peaks function
* add Literal to the BetaVAE25D normalization
* clip similarity that was breaking the tests

---------

Co-authored-by: Ziwen Liu <67518483+ziw-liu@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent b56dc9f commit e00448b


42 files changed: +4668 −904 lines
Lines changed: 65 additions & 0 deletions
@@ -0,0 +1,65 @@
```yaml
datamodule_class: viscy.data.triplet.TripletDataModule
datamodule:
  data_path: /hpc/projects/intracellular_dashboard/organelle_dynamics/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/registered_test.zarr
  tracks_path: /hpc/projects/intracellular_dashboard/organelle_dynamics/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/track_test.zarr
  batch_size: 32
  final_yx_patch_size:
    - 256
    - 256
  include_fov_names: null
  include_track_ids: null
  initial_yx_patch_size:
    - 256
    - 256
  normalizations:
    - class_path: viscy.transforms.ScaleIntensityRangePercentilesd
      init_args:
        b_max: 1.0
        b_min: 0.0
        keys:
          - RFP
        lower: 50
        upper: 99
    - class_path: viscy.transforms.NormalizeIntensityd
      init_args:
        keys:
          - Phase3D
  num_workers: 10
  source_channel:
    - RFP
    - Phase3D
  z_range:
    - 15
    - 45

embedding:
  pca_kwargs:
    n_components: 8
  phate_kwargs:
    decay: 40
    knn: 5
    n_components: 2
    n_jobs: -1
    random_state: 42
  reductions:
    - PHATE
    - PCA

execution:
  overwrite: false
  save_config: true
  show_config: true

model:
  model_name: facebook/dinov3-convnext-tiny-pretrain-lvd1689m
  pooling_method: mean # Options: "mean", "max", "cls_token"
  middle_slice_index: 18 # Specific z-slice index (if null, uses D//2)
  channel_reduction_methods:
    Phase3D: middle_slice
    RFP: max
  channel_names:
    - RFP
    - Phase3D

paths:
  output_path: /hpc/mydata/eduardo.hirata/repos/viscy/applications/benchmarking/DynaCLR/DINOV3/embeddings_convnext_tiny_mean.zarr
```
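The `model` section of the config above is what `DINOv3Module.from_config` consumes. A minimal sketch of that lookup logic, using a plain dict in place of the parsed YAML so it stays self-contained (values copied from the config above; the real code passes them on to the module constructor):

```python
# Parsed config, as yaml.safe_load would return it for the file above
cfg = {
    "model": {
        "model_name": "facebook/dinov3-convnext-tiny-pretrain-lvd1689m",
        "pooling_method": "mean",
        "middle_slice_index": 18,
        "channel_reduction_methods": {"Phase3D": "middle_slice", "RFP": "max"},
        "channel_names": ["RFP", "Phase3D"],
    }
}

# Same .get() pattern as DINOv3Module.from_config: missing keys fall
# back to defaults instead of raising
model_config = cfg.get("model", {})
model_name = model_config.get("model_name", "facebook/dinov3-vitb16-pretrain-lvd1689m")
pooling_method = model_config.get("pooling_method", "mean")
middle_slice_index = model_config.get("middle_slice_index", None)

print(model_name, pooling_method, middle_slice_index)
```

Because every key is read with a default, an older config that predates `middle_slice_index` still loads (the module then falls back to `D//2`).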
Lines changed: 170 additions & 0 deletions
@@ -0,0 +1,170 @@
```python
import sys
from pathlib import Path
from typing import Dict, List, Literal, Optional

import numpy as np
import torch
from PIL import Image
from skimage.exposure import rescale_intensity
from transformers import AutoImageProcessor, AutoModel

sys.path.append(str(Path(__file__).parent.parent))

from base_embedding_module import BaseEmbeddingModule, create_embedding_cli


class DINOv3Module(BaseEmbeddingModule):
    def __init__(
        self,
        model_name: str = "facebook/dinov3-vitb16-pretrain-lvd1689m",
        channel_reduction_methods: Optional[
            Dict[str, Literal["middle_slice", "mean", "max"]]
        ] = None,
        channel_names: Optional[List[str]] = None,
        pooling_method: Literal["mean", "max", "cls_token"] = "mean",
        middle_slice_index: Optional[int] = None,
    ):
        super().__init__(channel_reduction_methods, channel_names, middle_slice_index)
        self.model_name = model_name
        self.pooling_method = pooling_method

        self.model = None
        self.processor = None

    @classmethod
    def from_config(cls, cfg):
        """Create model instance from configuration."""
        model_config = cfg.get("model", {})
        return cls(
            model_name=model_config.get(
                "model_name", "facebook/dinov3-vitb16-pretrain-lvd1689m"
            ),
            pooling_method=model_config.get("pooling_method", "mean"),
            channel_reduction_methods=model_config.get("channel_reduction_methods", {}),
            channel_names=model_config.get("channel_names", []),
            middle_slice_index=model_config.get("middle_slice_index", None),
        )

    def on_predict_start(self):
        if self.model is None:
            self.processor = AutoImageProcessor.from_pretrained(self.model_name)
            self.model = AutoModel.from_pretrained(self.model_name)
            self.model.eval()
            self.model.to(self.device)

    def _process_input(self, x: torch.Tensor):
        """Convert tensor to PIL Images for DINOv3 processing."""
        return self._convert_to_pil_images(x)

    def _extract_features(self, pil_images):
        """Extract features using the DINOv3 model."""
        inputs = self.processor(pil_images, return_tensors="pt")
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = self.model(**inputs)
            token_features = outputs.last_hidden_state
            features = self._pool_features(token_features)

        return features

    def _convert_to_pil_images(self, x: torch.Tensor) -> List[Image.Image]:
        """
        Convert tensor to list of PIL Images for DINOv3 processing.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor with shape (B, C, H, W).

        Returns
        -------
        list of PIL.Image.Image
            List of PIL Images ready for DINOv3 processing.
        """
        images = []

        for b in range(x.shape[0]):
            img_tensor = x[b]  # (C, H, W)

            if img_tensor.shape[0] == 1:
                # Single channel - convert to grayscale PIL
                img_array = img_tensor[0].cpu().numpy()
                # Normalize to 0-255
                img_normalized = (
                    (img_array - img_array.min())
                    / (img_array.max() - img_array.min())
                    * 255
                ).astype(np.uint8)
                pil_img = Image.fromarray(img_normalized, mode="L")

            elif img_tensor.shape[0] == 2:
                # Two channels - compose a synthetic RGB image
                img_array = img_tensor.cpu().numpy()
                rgb_array = np.zeros(
                    (img_array.shape[1], img_array.shape[2], 3), dtype=np.uint8
                )

                ch0_norm = rescale_intensity(img_array[0], out_range=(0, 255)).astype(
                    np.uint8
                )
                ch1_norm = rescale_intensity(img_array[1], out_range=(0, 255)).astype(
                    np.uint8
                )

                rgb_array[:, :, 0] = ch0_norm  # Red
                rgb_array[:, :, 1] = ch1_norm  # Green
                rgb_array[:, :, 2] = (ch0_norm + ch1_norm) // 2  # Blue

                pil_img = Image.fromarray(rgb_array, mode="RGB")

            elif img_tensor.shape[0] == 3:
                # Three channels - direct RGB
                img_array = img_tensor.cpu().numpy().transpose(1, 2, 0)  # HWC
                img_normalized = rescale_intensity(
                    img_array, out_range=(0, 255)
                ).astype(np.uint8)
                pil_img = Image.fromarray(img_normalized, mode="RGB")

            else:
                # More than 3 channels - use first 3
                img_array = img_tensor[:3].cpu().numpy().transpose(1, 2, 0)  # HWC
                img_normalized = rescale_intensity(
                    img_array, out_range=(0, 255)
                ).astype(np.uint8)
                pil_img = Image.fromarray(img_normalized, mode="RGB")

            images.append(pil_img)

        return images

    def _pool_features(self, features: torch.Tensor) -> torch.Tensor:
        """
        Pool spatial features from DINOv3 tokens.

        Parameters
        ----------
        features : torch.Tensor
            Token features with shape (B, num_tokens, hidden_dim).

        Returns
        -------
        torch.Tensor
            Pooled features with shape (B, hidden_dim).
        """
        if self.pooling_method == "cls_token":
            # For ViT models, the first token is usually the CLS token
            if "vit" in self.model_name.lower():
                return features[:, 0, :]  # CLS token
            else:
                # ConvNeXt has no CLS token; fall back to mean
                return features.mean(dim=1)

        elif self.pooling_method == "max":
            return features.max(dim=1)[0]
        else:  # mean pooling
            return features.mean(dim=1)


if __name__ == "__main__":
    main = create_embedding_cli(DINOv3Module, "DINOv3")
    main()
```
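All three pooling modes in `_pool_features` reduce the `(B, num_tokens, hidden_dim)` token tensor to `(B, hidden_dim)`. A NumPy sketch of the same reductions, on toy numbers rather than real DINOv3 features:

```python
import numpy as np

# Toy token features: batch of 2, 3 tokens, hidden_dim of 4
feats = np.arange(24, dtype=float).reshape(2, 3, 4)

mean_pooled = feats.mean(axis=1)  # average over the token axis -> (2, 4)
max_pooled = feats.max(axis=1)    # elementwise max over tokens  -> (2, 4)
cls_pooled = feats[:, 0, :]       # keep only the first (CLS) token -> (2, 4)

print(mean_pooled[0])  # [4. 5. 6. 7.]
```

Mean pooling is the safe default here because it works for both ViT and ConvNeXt token layouts, which is why `_pool_features` falls back to it when `cls_token` is requested on a model without a CLS token.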

applications/benchmarking/DynaCLR/OpenPhenom/config_template.yml

Lines changed: 3 additions & 2 deletions
```diff
@@ -2,8 +2,6 @@

 # Paths section
 paths:
-  data_path: /hpc/projects/intracellular_dashboard/organelle_dynamics/2024_11_07_A549_SEC61_ZIKV_DENV/2-assemble/2024_11_07_A549_SEC61_DENV.zarr
-  tracks_path: /hpc/projects/intracellular_dashboard/organelle_dynamics/2024_11_07_A549_SEC61_ZIKV_DENV/1-preprocess/label-free/4-track-gt/2024_11_07_A549_SEC61_ZIKV_DENV_2_cropped.zarr
   output_path: "/home/eduardo.hirata/repos/viscy/applications/benchmarking/DynaCLR/OpenPhenom/openphenom_sec61b_n_phase_3.zarr"

 # Model configuration
@@ -16,7 +14,10 @@ model:
     "raw GFP EX488 EM525-45": "max"

 # Data module configuration
+datamodule_class: viscy.data.triplet.TripletDataModule
 datamodule:
+  data_path: /hpc/projects/intracellular_dashboard/organelle_dynamics/2024_11_07_A549_SEC61_ZIKV_DENV/2-assemble/2024_11_07_A549_SEC61_DENV.zarr
+  tracks_path: /hpc/projects/intracellular_dashboard/organelle_dynamics/2024_11_07_A549_SEC61_ZIKV_DENV/1-preprocess/label-free/4-track-gt/2024_11_07_A549_SEC61_ZIKV_DENV_2_cropped.zarr
   source_channel:
     - Phase3D
     - "raw GFP EX488 EM525-45"
```
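This diff moves `data_path`/`tracks_path` under `datamodule` and adds a top-level `datamodule_class` key, so a generic embedding runner can stay datamodule-agnostic. One way such a runner might resolve the dotted class path (a hypothetical helper for illustration, not code from this commit):

```python
import importlib


def resolve_class(dotted_path: str) -> type:
    """Import and return the class named by a dotted path,
    e.g. 'viscy.data.triplet.TripletDataModule'."""
    module_path, _, class_name = dotted_path.rpartition(".")
    return getattr(importlib.import_module(module_path), class_name)


# Demonstrated with a stdlib class so the sketch runs anywhere;
# a runner would pass cfg["datamodule_class"] and then call
# resolve_class(...)(**cfg["datamodule"]).
OrderedDict = resolve_class("collections.OrderedDict")
print(OrderedDict.__name__)  # OrderedDict
```

The remaining `datamodule` mapping can then be splatted into the resolved class's constructor, which is why the keys in the diff mirror `TripletDataModule`'s parameters.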
