402 changes: 402 additions & 0 deletions examples/DynaCLR/evaluation/vae_reconstruction_explorer.py

Large diffs are not rendered by default.

124 changes: 94 additions & 30 deletions examples/transforms/batched_transforms.ipynb

Large diffs are not rendered by default.

208 changes: 208 additions & 0 deletions tests/utils/test_scheduler.py
@@ -0,0 +1,208 @@
import pytest

from viscy.utils.scheduler import ParameterScheduler


def test_scheduler_constant():
"""Test constant schedule returns target value at all epochs."""
scheduler = ParameterScheduler(
param_name="test_param",
initial_value=0.1,
target_value=1.0,
warmup_epochs=10,
schedule_type="constant",
)

assert scheduler.get_value(0) == 1.0
assert scheduler.get_value(5) == 1.0
assert scheduler.get_value(10) == 1.0
assert scheduler.get_value(100) == 1.0


def test_scheduler_linear_warmup():
"""Test linear schedule interpolates from initial to target."""
scheduler = ParameterScheduler(
param_name="test_param",
initial_value=0.0,
target_value=1.0,
warmup_epochs=10,
schedule_type="linear",
)

# At epoch 0, should be at initial
assert abs(scheduler.get_value(0) - 0.0) < 0.01

# At epoch 5 (halfway), should be approximately 0.5
assert abs(scheduler.get_value(5) - 0.5) < 0.01

# At epoch 10 (end), should be at target
assert abs(scheduler.get_value(10) - 1.0) < 0.01

# After warmup, should stay at target
assert abs(scheduler.get_value(20) - 1.0) < 0.01


def test_scheduler_cosine_warmup():
"""Test cosine schedule increases smoothly from initial to target."""
scheduler = ParameterScheduler(
param_name="test_param",
initial_value=0.0,
target_value=1.0,
warmup_epochs=10,
schedule_type="cosine",
)

# At epoch 0, should be close to initial
assert scheduler.get_value(0) < 0.1

# At halfway through warmup
val_halfway = scheduler.get_value(5)
assert 0.0 < val_halfway < 1.0

# At end of warmup, should be at target
assert abs(scheduler.get_value(10) - 1.0) < 0.01

# After warmup, should stay at target
assert abs(scheduler.get_value(20) - 1.0) < 0.01


def test_scheduler_warmup_step():
"""Test warmup (step function) schedule."""
scheduler = ParameterScheduler(
param_name="test_param",
initial_value=0.1,
target_value=1.0,
warmup_epochs=10,
schedule_type="warmup",
)

# Before warmup completes, should be at initial
assert abs(scheduler.get_value(0) - 0.1) < 0.01
assert abs(scheduler.get_value(5) - 0.1) < 0.01
assert abs(scheduler.get_value(9) - 0.1) < 0.01

# At and after warmup, should jump to target
assert abs(scheduler.get_value(10) - 1.0) < 0.01
assert abs(scheduler.get_value(20) - 1.0) < 0.01


def test_scheduler_min_value_clipping():
"""Test that values are clipped to min_value."""
scheduler = ParameterScheduler(
param_name="test_param",
initial_value=0.0,
target_value=0.001,
warmup_epochs=10,
schedule_type="linear",
min_value=0.01, # Higher than target
)

# All values should be clipped to min_value
assert scheduler.get_value(0) >= 0.01
assert scheduler.get_value(5) >= 0.01
assert scheduler.get_value(10) >= 0.01


def test_scheduler_negative_warmup_epochs_raises():
"""Test that negative warmup_epochs raises ValueError."""
with pytest.raises(ValueError, match="warmup_epochs must be >= 0"):
ParameterScheduler(
param_name="test_param",
initial_value=0.0,
target_value=1.0,
warmup_epochs=-5,
schedule_type="linear",
)


def test_scheduler_invalid_schedule_type_raises():
"""Test that invalid schedule_type raises ValueError."""
with pytest.raises(ValueError, match="Invalid schedule_type"):
ParameterScheduler(
param_name="test_param",
initial_value=0.0,
target_value=1.0,
warmup_epochs=10,
schedule_type="invalid_schedule", # type: ignore
)


def test_scheduler_repr():
"""Test scheduler string representation."""
scheduler = ParameterScheduler(
param_name="beta",
initial_value=0.1,
target_value=1.0,
warmup_epochs=50,
schedule_type="linear",
)

repr_str = repr(scheduler)
assert "beta" in repr_str
assert "0.1" in repr_str
assert "1.0" in repr_str
assert "50" in repr_str
assert "linear" in repr_str


def test_scheduler_zero_warmup_epochs():
"""Test scheduler with zero warmup epochs."""
scheduler = ParameterScheduler(
param_name="test_param",
initial_value=0.0,
target_value=1.0,
warmup_epochs=0,
schedule_type="linear",
)

# With zero warmup, should immediately be at target
assert abs(scheduler.get_value(0) - 1.0) < 0.01
assert abs(scheduler.get_value(10) - 1.0) < 0.01


def test_scheduler_decreasing_schedule():
"""Test scheduler can decrease from high to low value."""
scheduler = ParameterScheduler(
param_name="test_param",
initial_value=1.0,
target_value=0.1,
warmup_epochs=10,
schedule_type="linear",
)

# At epoch 0, should be at initial (high)
assert abs(scheduler.get_value(0) - 1.0) < 0.01

# At halfway, should be decreasing
assert 0.1 < scheduler.get_value(5) < 1.0

# At end, should be at target (low)
assert abs(scheduler.get_value(10) - 0.1) < 0.01


def test_scheduler_linear_exact_values():
"""Test linear schedule produces exact expected values."""
scheduler = ParameterScheduler(
param_name="test_param",
initial_value=0.0,
target_value=10.0,
warmup_epochs=10,
schedule_type="linear",
)

# Test exact values at key points
expected_values = [
(0, 0.0),
(1, 1.0),
(2, 2.0),
(5, 5.0),
(9, 9.0),
(10, 10.0),
(15, 10.0),
]

for epoch, expected in expected_values:
actual = scheduler.get_value(epoch)
assert abs(actual - expected) < 0.01, (
f"At epoch {epoch}, expected {expected}, got {actual}"
)
65 changes: 52 additions & 13 deletions viscy/representation/embedding_writer.py
@@ -91,6 +91,8 @@ def write_embedding_dataset(
features: np.ndarray,
index_df: pd.DataFrame,
projections: Optional[np.ndarray] = None,
logvar: Optional[np.ndarray] = None,
z: Optional[np.ndarray] = None,
umap_kwargs: Optional[Dict[str, Any]] = None,
phate_kwargs: Optional[Dict[str, Any]] = None,
pca_kwargs: Optional[Dict[str, Any]] = None,
@@ -105,10 +107,18 @@
Path to the zarr store.
features : np.ndarray
Array of shape (n_samples, n_features) containing the embeddings.
For VAE models, this is mu. For contrastive models, this is features.
index_df : pd.DataFrame
DataFrame containing the index information for each embedding.
projections : np.ndarray, optional
Array of shape (n_samples, n_projections) containing projection values, by default None.
Only used for contrastive models.
logvar : np.ndarray, optional
Array of shape (n_samples, n_features) containing log variance values, by default None.
Only used for VAE models to store uncertainty information.
z : np.ndarray, optional
Array of shape (n_samples, latent_dim) containing sampled latent codes, by default None.
Only used for VAE models. Can be used to regenerate reconstructions later.
umap_kwargs : Dict[str, Any], optional
Keyword arguments passed to UMAP, by default None (i.e. UMAP is not computed)
Common parameters include:
@@ -149,8 +159,14 @@
n_samples = len(features)

adata = ad.AnnData(X=features, obs=ultrack_indices)

# Store model-specific outputs
if projections is not None:
adata.obsm["X_projections"] = projections
if logvar is not None:
adata.obsm["X_logvar"] = logvar
if z is not None:
adata.obsm["X_z"] = z

# Set up default kwargs for each method
if umap_kwargs:
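
The new model-specific outputs are stored as `obsm` entries on the AnnData object. A hedged example of reading them back, assuming the output zarr store is written in AnnData's native zarr format and using a hypothetical path:

```python
# Illustrative only: assumes the embedding store is an AnnData zarr store.
import anndata as ad

adata = ad.read_zarr("embeddings.zarr")        # hypothetical path
features = adata.X                             # primary embedding (mu for VAE models)
logvar = adata.obsm.get("X_logvar")            # per-sample log variance, if written
z = adata.obsm.get("X_z")                      # sampled latent codes, if written
projections = adata.obsm.get("X_projections")  # contrastive projections, if written
```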
@@ -253,23 +269,46 @@ def write_on_epoch_end(
trainer : Trainer
Placeholder, ignored.
pl_module : LightningModule
Placeholder, ignored.
Lightning module (ContrastiveModule or BetaVaeModule).
predictions : Sequence[ContrastivePrediction]
Sequence of output from the prediction steps.
batch_indices : Sequence[int]
Placeholder, ignored.
"""
features = _move_and_stack_embeddings(predictions, "features")
projections = _move_and_stack_embeddings(predictions, "projections")
from viscy.representation.engine import BetaVaeModule

ultrack_indices = pd.concat([pd.DataFrame(p["index"]) for p in predictions])

write_embedding_dataset(
output_path=self.output_path,
features=features,
index_df=ultrack_indices,
projections=projections,
umap_kwargs=self.umap_kwargs,
phate_kwargs=self.phate_kwargs,
pca_kwargs=self.pca_kwargs,
overwrite=self.overwrite,
)
# Detect model type and extract appropriate embeddings
if isinstance(pl_module, BetaVaeModule):
# VAE model: mu is primary embedding, logvar and z are additional info
features = _move_and_stack_embeddings(predictions, "mu")
logvar = _move_and_stack_embeddings(predictions, "logvar")
z = _move_and_stack_embeddings(predictions, "z")

write_embedding_dataset(
output_path=self.output_path,
features=features,
index_df=ultrack_indices,
logvar=logvar,
z=z,
umap_kwargs=self.umap_kwargs,
phate_kwargs=self.phate_kwargs,
pca_kwargs=self.pca_kwargs,
overwrite=self.overwrite,
)
else:
# Contrastive model: features and projections
features = _move_and_stack_embeddings(predictions, "features")
projections = _move_and_stack_embeddings(predictions, "projections")

write_embedding_dataset(
output_path=self.output_path,
features=features,
index_df=ultrack_indices,
projections=projections,
umap_kwargs=self.umap_kwargs,
phate_kwargs=self.phate_kwargs,
pca_kwargs=self.pca_kwargs,
overwrite=self.overwrite,
)
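
For this dispatch to work, each prediction dict must carry the keys the writer stacks. A hedged sketch of the expected payloads follows; the key names mirror the writer code above, while the shapes, the torch dependency, and the index fields are assumptions for illustration:

```python
# Illustrative payloads only; key names mirror the writer code above,
# shapes and index fields are assumptions.
import torch

batch, latent_dim = 4, 32
vae_prediction = {
    "mu": torch.zeros(batch, latent_dim),      # posterior means (primary embedding)
    "logvar": torch.zeros(batch, latent_dim),  # posterior log variances
    "z": torch.zeros(batch, latent_dim),       # sampled latent codes
    "index": {"fov_name": ["A/1/0"] * batch, "t": list(range(batch))},  # hypothetical index fields
}

contrastive_prediction = {
    "features": torch.zeros(batch, 768),       # encoder features
    "projections": torch.zeros(batch, 128),    # projection-head output
    "index": {"fov_name": ["A/1/0"] * batch, "t": list(range(batch))},
}
```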