Implementation of Beta VAE for benchmarking #273
Pull Request Overview
This PR adds support for β-VAE (Beta Variational Autoencoder) models and enhances the representation learning capabilities of VisCy. It includes VAE architectures, logging utilities, evaluation metrics, and data handling improvements.
Key Changes:
- Added β-VAE model architectures (2.5D and MONAI-based) with encoder/decoder implementations
- Implemented comprehensive VAE logging utilities for training monitoring and latent space visualization
- Enhanced evaluation metrics including smoothness analysis, displacement computation, and GPU-accelerated distance calculations
- Added cell division triplet dataset for .npy file handling
- Improved data module with validation augmentation control
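For readers unfamiliar with the objective behind these models, the following is a minimal, framework-free sketch of a β-VAE loss with a linear β warmup. It is not the PR's PyTorch code: the function names, the clamp range, and the choice to average both terms over the batch are assumptions made for illustration. The commit notes for this PR mention a clamp on `logvar`, sum-reduced MSE, and KL normalization by batch size, which is what the sketch imitates.

```python
import math


def beta_schedule(step, warmup_steps, beta_max):
    """Linearly ramp beta from 0 to beta_max over warmup_steps (hypothetical helper)."""
    if warmup_steps <= 0:
        return beta_max
    return beta_max * min(1.0, step / warmup_steps)


def beta_vae_loss(x, x_hat, mu, logvar, beta, logvar_clamp=(-20.0, 20.0)):
    """Return (total, reconstruction, kl) for one batch.

    x, x_hat: lists of flattened samples; mu, logvar: lists of latent vectors.
    Assumption: both terms are summed over elements, then divided by batch size.
    """
    batch_size = len(x)
    lo, hi = logvar_clamp

    # Sum-reduced squared error between input and reconstruction.
    recon = sum(
        (a - b) ** 2
        for sample, sample_hat in zip(x, x_hat)
        for a, b in zip(sample, sample_hat)
    )

    # KL(q(z|x) || N(0, I)) for a diagonal Gaussian posterior.
    kl = 0.0
    for mu_row, logvar_row in zip(mu, logvar):
        for m, lv in zip(mu_row, logvar_row):
            lv = max(lo, min(hi, lv))  # clamp to keep exp() from overflowing
            kl += -0.5 * (1.0 + lv - m * m - math.exp(lv))

    recon /= batch_size
    kl /= batch_size
    return recon + beta * kl, recon, kl


# Usage: beta halfway through a 10-step warmup toward beta_max = 4.0.
beta = beta_schedule(step=5, warmup_steps=10, beta_max=4.0)
total, recon, kl = beta_vae_loss(
    x=[[1.0, 2.0]], x_hat=[[1.0, 0.0]], mu=[[1.0]], logvar=[[0.0]], beta=beta
)
```

With β = 1 this reduces to the standard VAE objective; larger β weights the KL term more heavily, trading reconstruction quality for a more regularized latent space.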
Reviewed Changes
Copilot reviewed 31 out of 42 changed files in this pull request and generated 12 comments.
| File | Description |
|---|---|
| viscy/transforms/_redef.py | Added NormalizeIntensityd import and class wrapper; moved RandFlipd inside another class (indentation issue) |
| viscy/transforms/__init__.py | Exported NormalizeIntensityd transform |
| viscy/representation/vae_logging.py | New comprehensive VAE logging utilities for metrics, visualizations, and diagnostics |
| viscy/representation/vae.py | New VAE model implementations (encoder, decoder, 2.5D and MONAI variants) |
| viscy/representation/engine.py | Added BetaVaeModule Lightning wrapper and removed log_embeddings parameter |
| viscy/representation/multi_modal.py | Added embedding_log_frequency parameter |
| viscy/representation/evaluation/smoothness.py | New smoothness metrics computation for embeddings |
| viscy/representation/evaluation/lca.py | Enhanced logistic regression with train_ratio parameter and stratified sampling |
| viscy/representation/evaluation/distance.py | Refactored displacement computation using pairwise distance matrix |
| viscy/representation/evaluation/dimensionality_reduction.py | Added scaling option and random_state to PHATE computation |
| viscy/representation/evaluation/clustering.py | Added GPU-accelerated pairwise distance computation with PyTorch |
| viscy/data/triplet.py | Added augment_validation parameter for controlled augmentation |
| viscy/data/cell_division_triplet.py | New dataset and data module for cell division .npy files |
| Application scripts | Various evaluation, visualization, and benchmarking scripts |
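The distance.py row describes computing displacement from a pairwise distance matrix. A rough sketch of that idea follows, in plain Python so it runs without a GPU or PyTorch. The function names and the per-lag averaging are illustrative assumptions, not the repository's API; cosine distance is used because the commit notes state it is the default, and the similarity is clipped to [-1, 1] in the spirit of the "clipping similarity" fix.

```python
import math


def cosine_distance(u, v):
    """1 - cosine similarity, with the similarity clipped to [-1, 1]."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    if norm_u == 0.0 or norm_v == 0.0:
        raise ValueError("cosine distance is undefined for zero vectors")
    similarity = max(-1.0, min(1.0, dot / (norm_u * norm_v)))
    return 1.0 - similarity


def pairwise_distances(embeddings):
    """Full N x N distance matrix for one track's embeddings."""
    n = len(embeddings)
    return [
        [cosine_distance(embeddings[i], embeddings[j]) for j in range(n)]
        for i in range(n)
    ]


def displacement_by_lag(embeddings):
    """Mean distance between frames separated by each time lag tau.

    Reads the tau-th off-diagonal of the pairwise matrix, so each
    distance is computed once and reused across lags.
    """
    dist = pairwise_distances(embeddings)
    n = len(embeddings)
    result = {}
    for tau in range(1, n):
        values = [dist[t][t + tau] for t in range(n - tau)]
        result[tau] = sum(values) / len(values)
    return result


# Usage: a three-frame track whose embedding rotates by 90 degrees per frame.
track = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]]
lags = displacement_by_lag(track)
```

Building the matrix once and indexing its off-diagonals avoids recomputing distances per lag, which is presumably the motivation for the refactor.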
Resolved review threads (marked outdated):
- applications/contrastive_phenotyping/evaluation/smoothness/compute_smoothness.py (2 threads)
- applications/pseudotime_analysis/evaluation/compare_dtw_embeddings_sam2.py (3 threads)
mattersoflight left a comment
Approving and will do integration test by trying to reduce a couple of panels.
* simulate different embeddings
* update the msd calculation to re-use cdist functions in the repo
* adding a test for the msd
* removing unused msd functions
* renaming msd to compute_track_displacement
* default to cosine distance
* adding the gradient attribution video.
* extend to training ratios
* demo beta_vae 2.5D
* improving the logging for readability and drop pythae baseclasses
* condense the logging to have fewer tabs.
* fix disentangle metrics
* fixing beta warmup bug
* renaming to loss
* updating architecture to flatten vs spatial VAE with convs
* changing to use mse with mean reduction and normalizing the kl loss by batch size.
* optuna proof of concept
* add normalized sampled into the transforms so we can use it with MONAI's vae
* update loss debugging code
* adding sync for disentanglement metrics
* adding the dataloader for rpe1 dataset and plotting utils
* cleanup the vae and add the monai to lightning. adding configs
* add saving hyperparameters
* fix hyperparameter logging
* add embedding logging to the CLIP version
* test and plot of monaivae
* handle monai_vae 2d
* redefining rotation augmentations
* adding optional scaling to phate
* adding alias and output 2d
* normalizing by also the latent dim and swapping to FP32 for forward pass to avoid overflow with log and exp
* update test for magnitudes
* expose the normalization for vae
* add sam 2 test
* refactor smoothness metrics
* revert to normalize kl wrt batch size and removing the beta min value
* commit dtwembeddings w sam
* added a clamp to logvar, switch to mse loss sum reduction like the original formulation.
* remove unnecessary vae logging losses.
* add a way to handle when using 'mean' reduction for proper scaling
* adding optional config for middle slice index for computing sam2 embeddings and dinov3
* converting latent stats active_dimensions parameter to float to remove warning
* ruff
* removing the optuna config
* numpy docstring
* fix compute smoothness script
* archiving old scripts
* reorg the pc features scripts
* embeddings for phase
* add smoothness (mean rand vs adj frame) to the csv
* archiving old beta vae code
* ruff
* fix format
* fix typo
* remove the archived unnecessary files
* remove the test run archived file
* adding normalizeintensity
* fixing the vae_logging typing and removing PC plotting from here
* fixing the compute_embedding_smoothness docstring
* simplify the distance metrics and removing deprecated functions and scripts
* remove deprecated functions from clustering.py
* add timelapse to grad_attr.py script
* refactoring the betavaemodule. removing the hyperparameter logging, adding the nn.Module as input for typing purposes and removing the fp32 custom fwd
* remove the optuna dependency
* deleting old msd test
* ruff format
* fix to explicitly stratify on fov level
* adding reference to dataset for rpe1
* fix pyproject.toml dev
* format and lint
* restore no-augmentation flag effect
* format tests
* rename the sam2 file
* removing unused arguments for logging embeddings.
* removing duplication in the lca
* remove disentanglement metrics
* vectorized the anchor filtering for celldivisiontriplet dataset
* map the channels to the rpe dataset convention
* fix logistic regression standardization
* update rpe classifier to include mitosis
* ruff
* remove unused logging
* datamodule agnostic
* cleaning up duplicated code in the benchmarking
* cleanup vae
* keeping it consistent and using residual units
* fix typings betavaemonai
* update smoothness to handle adata
* update clustering method and add test
* pre-commit
* Update viscy/data/cell_division_triplet.py Co-authored-by: Copilot <[email protected]>
* Update applications/benchmarking/DynaCLR/SAM2/sam2_visualizations.py Co-authored-by: Copilot <[email protected]>
* Update applications/pseudotime_analysis/evaluation/compare_dtw_embeddings_sam2.py Co-authored-by: Copilot <[email protected]>
* Update applications/pseudotime_analysis/evaluation/compare_dtw_embeddings_sam2.py Co-authored-by: Copilot <[email protected]>
* Update applications/contrastive_phenotyping/evaluation/smoothness/compute_smoothness.py Co-authored-by: Copilot <[email protected]>
* Update applications/pseudotime_analysis/evaluation/compare_dtw_embeddings_sam2.py Co-authored-by: Copilot <[email protected]>
* Update applications/contrastive_phenotyping/evaluation/smoothness/compute_smoothness.py Co-authored-by: Copilot <[email protected]>
* Update applications/contrastive_phenotyping/evaluation/archive/ALFI_MSD_v2.py Co-authored-by: Copilot <[email protected]>
* ValueError on the find peaks function
* add literal to the betavae25d normalization
* clipping similarity that was breaking the tests

---------

Co-authored-by: Ziwen Liu <[email protected]>
Co-authored-by: Copilot <[email protected]>
This PR adds the following: