55 changes: 55 additions & 0 deletions tests/conftest.py
@@ -114,3 +114,58 @@ def tracks_hcs_dataset(tmp_path_factory: TempPathFactory) -> Path:
)
fake_tracks.to_csv(dataset_path / fov_name / "tracks.csv", index=False)
return dataset_path


@fixture(scope="function")
def tracks_with_gaps_dataset(tmp_path_factory: TempPathFactory) -> Path:
"""Provides a HCS OME-Zarr dataset with tracking results with gaps in time."""
dataset_path = tmp_path_factory.mktemp("tracks_gaps.zarr")
_build_hcs(dataset_path, ["nuclei_labels"], (1, 256, 256), np.uint16, 3)

# Define different track patterns for different FOVs
track_patterns = {
"A/1/0": [
# Track 0: complete sequence t=[0,1,2,3]
{"track_id": 0, "t": 0, "y": 128, "x": 128, "id": 0},
{"track_id": 0, "t": 1, "y": 128, "x": 128, "id": 1},
{"track_id": 0, "t": 2, "y": 128, "x": 128, "id": 2},
{"track_id": 0, "t": 3, "y": 128, "x": 128, "id": 3},
# Track 1: ends early t=[0,1]
{"track_id": 1, "t": 0, "y": 100, "x": 100, "id": 4},
{"track_id": 1, "t": 1, "y": 100, "x": 100, "id": 5},
],
"A/1/1": [
# Track 0: gap at t=2, has t=[0,1,3]
{"track_id": 0, "t": 0, "y": 128, "x": 128, "id": 0},
{"track_id": 0, "t": 1, "y": 128, "x": 128, "id": 1},
{"track_id": 0, "t": 3, "y": 128, "x": 128, "id": 2},
# Track 1: even timepoints only t=[0,2,4]
{"track_id": 1, "t": 0, "y": 100, "x": 100, "id": 3},
{"track_id": 1, "t": 2, "y": 100, "x": 100, "id": 4},
{"track_id": 1, "t": 4, "y": 100, "x": 100, "id": 5},
],
"A/2/0": [
# Track 0: single timepoint t=[0]
{"track_id": 0, "t": 0, "y": 128, "x": 128, "id": 0},
# Track 1: complete short sequence t=[0,1,2]
{"track_id": 1, "t": 0, "y": 100, "x": 100, "id": 1},
{"track_id": 1, "t": 1, "y": 100, "x": 100, "id": 2},
{"track_id": 1, "t": 2, "y": 100, "x": 100, "id": 3},
],
}

for fov_name, _ in open_ome_zarr(dataset_path).positions():
if fov_name in track_patterns:
tracks_data = track_patterns[fov_name]
else:
# Default tracks for other FOVs
tracks_data = [
{"track_id": 0, "t": 0, "y": 128, "x": 128, "id": 0},
]

tracks_df = pd.DataFrame(tracks_data)
tracks_df["parent_track_id"] = -1
tracks_df["parent_id"] = -1
tracks_df.to_csv(dataset_path / fov_name / "tracks.csv", index=False)

return dataset_path
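
A minimal sketch of the anchor-filtering rule these fixtures exercise, inferred from the assertions in the tests below: with an integer time_interval, an anchor at time t is kept only when the same track also has an observation at t + time_interval. The helper below is hypothetical (not part of viscy) and only mirrors that rule.

def valid_anchor_timepoints(timepoints: list[int], interval: int) -> list[int]:
    """Return the timepoints that can serve as anchors for the given interval."""
    observed = set(timepoints)
    return sorted(t for t in observed if t + interval in observed)

# A/1/1, Track 1 is observed at t=[0, 2, 4]:
assert valid_anchor_timepoints([0, 2, 4], interval=1) == []
assert valid_anchor_timepoints([0, 2, 4], interval=2) == [0, 2]
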
243 changes: 242 additions & 1 deletion tests/data/test_triplet.py
@@ -2,7 +2,7 @@
from iohub import open_ome_zarr
from pytest import mark

-from viscy.data.triplet import TripletDataModule
+from viscy.data.triplet import TripletDataModule, TripletDataset


@mark.parametrize("include_wells", [None, ["A/1", "A/2", "B/1"]])
@@ -109,3 +109,244 @@ def test_datamodule_z_window_size(
expected_z_shape,
*yx_patch_size,
)


def test_filter_anchors_time_interval_any(
preprocessed_hcs_dataset, tracks_with_gaps_dataset
):
"""Test that time_interval='any' returns all tracks unchanged."""
with open_ome_zarr(preprocessed_hcs_dataset) as dataset:
channel_names = dataset.channel_names
positions = list(dataset.positions())

# Create dataset with time_interval="any"
tracks_tables = []
for fov_name, _ in positions:
tracks_df = pd.read_csv(
next((tracks_with_gaps_dataset / fov_name).glob("*.csv"))
).astype(int)
tracks_tables.append(tracks_df)

total_tracks = sum(len(df) for df in tracks_tables)

ds = TripletDataset(
positions=[pos for _, pos in positions],
tracks_tables=tracks_tables,
channel_names=channel_names,
initial_yx_patch_size=(64, 64),
z_range=slice(4, 9),
fit=True,
time_interval="any",
)

# Should return all tracks
assert len(ds.valid_anchors) == total_tracks


def test_filter_anchors_time_interval_1(
preprocessed_hcs_dataset, tracks_with_gaps_dataset
):
"""Test filtering with time_interval=1."""
with open_ome_zarr(preprocessed_hcs_dataset) as dataset:
channel_names = dataset.channel_names
positions = list(dataset.positions())

tracks_tables = []
for fov_name, _ in positions:
tracks_df = pd.read_csv(
next((tracks_with_gaps_dataset / fov_name).glob("*.csv"))
).astype(int)
tracks_tables.append(tracks_df)

ds = TripletDataset(
positions=[pos for _, pos in positions],
tracks_tables=tracks_tables,
channel_names=channel_names,
initial_yx_patch_size=(64, 64),
z_range=slice(4, 9),
fit=True,
time_interval=1,
)

# Check expected anchors per FOV/track
valid_anchors = ds.valid_anchors

# FOV A/1/0, Track 0: t=[0,1,2,3] -> valid anchors at t=[0,1,2]
fov_a10_track0 = valid_anchors[
(valid_anchors["fov_name"] == "A/1/0") & (valid_anchors["track_id"] == 0)
]
assert set(fov_a10_track0["t"]) == {0, 1, 2}

# FOV A/1/0, Track 1: t=[0,1] -> valid anchor at t=[0]
fov_a10_track1 = valid_anchors[
(valid_anchors["fov_name"] == "A/1/0") & (valid_anchors["track_id"] == 1)
]
assert set(fov_a10_track1["t"]) == {0}

# FOV A/1/1, Track 0: t=[0,1,3] -> valid anchor at t=[0] only (t=1 has no t+1=2)
fov_a11_track0 = valid_anchors[
(valid_anchors["fov_name"] == "A/1/1") & (valid_anchors["track_id"] == 0)
]
assert set(fov_a11_track0["t"]) == {0}

# FOV A/1/1, Track 1: t=[0,2,4] -> no valid anchors (gaps of 2, no consecutive t+1)
fov_a11_track1 = valid_anchors[
(valid_anchors["fov_name"] == "A/1/1") & (valid_anchors["track_id"] == 1)
]
assert len(fov_a11_track1) == 0

# FOV A/2/0, Track 0: t=[0] -> no valid anchors (no t+1)
fov_a20_track0 = valid_anchors[
(valid_anchors["fov_name"] == "A/2/0") & (valid_anchors["track_id"] == 0)
]
assert len(fov_a20_track0) == 0

# FOV A/2/0, Track 1: t=[0,1,2] -> valid anchors at t=[0,1]
fov_a20_track1 = valid_anchors[
(valid_anchors["fov_name"] == "A/2/0") & (valid_anchors["track_id"] == 1)
]
assert set(fov_a20_track1["t"]) == {0, 1}


def test_filter_anchors_time_interval_2(
preprocessed_hcs_dataset, tracks_with_gaps_dataset
):
"""Test filtering with time_interval=2."""
with open_ome_zarr(preprocessed_hcs_dataset) as dataset:
channel_names = dataset.channel_names
positions = list(dataset.positions())

tracks_tables = []
for fov_name, _ in positions:
tracks_df = pd.read_csv(
next((tracks_with_gaps_dataset / fov_name).glob("*.csv"))
).astype(int)
tracks_tables.append(tracks_df)

ds = TripletDataset(
positions=[pos for _, pos in positions],
tracks_tables=tracks_tables,
channel_names=channel_names,
initial_yx_patch_size=(64, 64),
z_range=slice(4, 9),
fit=True,
time_interval=2,
)

valid_anchors = ds.valid_anchors

# FOV A/1/0, Track 0: t=[0,1,2,3] -> valid anchors at t=[0,1] (t+2 available)
fov_a10_track0 = valid_anchors[
(valid_anchors["fov_name"] == "A/1/0") & (valid_anchors["track_id"] == 0)
]
assert set(fov_a10_track0["t"]) == {0, 1}

# FOV A/1/0, Track 1: t=[0,1] -> no valid anchors (no t+2)
fov_a10_track1 = valid_anchors[
(valid_anchors["fov_name"] == "A/1/0") & (valid_anchors["track_id"] == 1)
]
assert len(fov_a10_track1) == 0

# FOV A/1/1, Track 0: t=[0,1,3] -> valid anchor at t=[1] (t=1+2=3 exists)
fov_a11_track0 = valid_anchors[
(valid_anchors["fov_name"] == "A/1/1") & (valid_anchors["track_id"] == 0)
]
assert set(fov_a11_track0["t"]) == {1}

# FOV A/1/1, Track 1: t=[0,2,4] -> valid anchors at t=[0,2]
fov_a11_track1 = valid_anchors[
(valid_anchors["fov_name"] == "A/1/1") & (valid_anchors["track_id"] == 1)
]
assert set(fov_a11_track1["t"]) == {0, 2}

# FOV A/2/0, Track 1: t=[0,1,2] -> valid anchor at t=[0]
fov_a20_track1 = valid_anchors[
(valid_anchors["fov_name"] == "A/2/0") & (valid_anchors["track_id"] == 1)
]
assert set(fov_a20_track1["t"]) == {0}


def test_filter_anchors_cross_fov_independence(
preprocessed_hcs_dataset, tracks_with_gaps_dataset
):
"""Test that same track_id in different FOVs are treated independently."""
with open_ome_zarr(preprocessed_hcs_dataset) as dataset:
channel_names = dataset.channel_names
positions = list(dataset.positions())

tracks_tables = []
for fov_name, _ in positions:
tracks_df = pd.read_csv(
next((tracks_with_gaps_dataset / fov_name).glob("*.csv"))
).astype(int)
tracks_tables.append(tracks_df)

ds = TripletDataset(
positions=[pos for _, pos in positions],
tracks_tables=tracks_tables,
channel_names=channel_names,
initial_yx_patch_size=(64, 64),
z_range=slice(4, 9),
fit=True,
time_interval=1,
)

# Check global_track_id format and uniqueness
assert "global_track_id" in ds.tracks.columns
global_track_ids = ds.tracks["global_track_id"].unique()

# Verify format: should be "fov_name_track_id"
for gid in global_track_ids:
assert "_" in gid
fov_part, track_id_part = gid.rsplit("_", 1)
assert "/" in fov_part # FOV names contain slashes like "A/1/0"

# Track 0 exists in multiple FOVs (A/1/0, A/1/1, A/2/0) but should have different global_track_ids
track0_global_ids = ds.tracks[ds.tracks["track_id"] == 0][
"global_track_id"
].unique()
assert len(track0_global_ids) >= 3 # At least 3 different FOVs with track_id=0

# Verify that filtering is independent per FOV
# A/1/0 Track 0 (continuous) should have more valid anchors than A/1/1 Track 0 (with gap)
valid_a10_track0 = ds.valid_anchors[
(ds.valid_anchors["fov_name"] == "A/1/0") & (ds.valid_anchors["track_id"] == 0)
]
valid_a11_track0 = ds.valid_anchors[
(ds.valid_anchors["fov_name"] == "A/1/1") & (ds.valid_anchors["track_id"] == 0)
]
# A/1/0 Track 0 has t=[0,1,2] valid (3 anchors)
# A/1/1 Track 0 has t=[0] valid (1 anchor, gap at t=2)
assert len(valid_a10_track0) == 3
assert len(valid_a11_track0) == 1
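
The exact construction of global_track_id inside TripletDataset is not shown in this diff; a plausible sketch consistent with the format asserted above (fov_name, an underscore, then track_id) is:

import pandas as pd

tracks = pd.DataFrame({"fov_name": ["A/1/0", "A/1/1"], "track_id": [0, 0]})
# One unique key per (FOV, track) pair, e.g. "A/1/0_0" and "A/1/1_0",
# so identical track_ids in different FOVs never collide.
tracks["global_track_id"] = tracks["fov_name"] + "_" + tracks["track_id"].astype(str)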


def test_filter_anchors_predict_mode(
preprocessed_hcs_dataset, tracks_with_gaps_dataset
):
"""Test that predict mode (fit=False) returns all tracks regardless of time_interval."""
with open_ome_zarr(preprocessed_hcs_dataset) as dataset:
channel_names = dataset.channel_names
positions = list(dataset.positions())

tracks_tables = []
for fov_name, _ in positions:
tracks_df = pd.read_csv(
next((tracks_with_gaps_dataset / fov_name).glob("*.csv"))
).astype(int)
tracks_tables.append(tracks_df)

total_tracks = sum(len(df) for df in tracks_tables)

ds = TripletDataset(
positions=[pos for _, pos in positions],
tracks_tables=tracks_tables,
channel_names=channel_names,
initial_yx_patch_size=(64, 64),
z_range=slice(4, 9),
fit=False, # Predict mode
time_interval=1,
)

# Should return all tracks even with time_interval=1
assert len(ds.valid_anchors) == total_tracks
46 changes: 44 additions & 2 deletions viscy/representation/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def __init__(
log_batches_per_epoch: int = 8,
log_samples_per_batch: int = 1,
log_embeddings: bool = False,
log_negative_metrics_every_n_epochs: int = 2,
example_input_array_shape: Sequence[int] = (1, 2, 15, 256, 256),
) -> None:
super().__init__()
@@ -49,6 +50,7 @@
self.training_step_outputs = []
self.validation_step_outputs = []
self.log_embeddings = log_embeddings
self.log_negative_metrics_every_n_epochs = log_negative_metrics_every_n_epochs

def forward(self, x: Tensor) -> tuple[Tensor, Tensor]:
"""Return both features and projections.
@@ -94,8 +96,8 @@ def _log_metrics(
cosine_sim_pos = F.cosine_similarity(anchor, positive, dim=1).mean()
euclidean_dist_pos = F.pairwise_distance(anchor, positive).mean()
log_metric_dict = {
f"metrics/cosine_similarity_positive/{stage}": cosine_sim_pos,
f"metrics/euclidean_distance_positive/{stage}": euclidean_dist_pos,
f"metrics/cosine_similarity/positive/{stage}": cosine_sim_pos,
f"metrics/euclidean_distance/positive/{stage}": euclidean_dist_pos,
}

if negative is not None:
@@ -107,6 +109,46 @@
log_metric_dict[f"metrics/euclidean_distance_negative/{stage}"] = (
euclidean_dist_neg
)
elif isinstance(self.loss_function, NTXentLoss):
if self.current_epoch % self.log_negative_metrics_every_n_epochs == 0:
batch_size = anchor.size(0)

# Cosine similarity metrics
anchor_norm = F.normalize(anchor, dim=1)
positive_norm = F.normalize(positive, dim=1)
all_embeddings_norm = torch.cat([anchor_norm, positive_norm], dim=0)
sim_matrix = torch.mm(anchor_norm, all_embeddings_norm.t())
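# sim_matrix has shape (B, 2B): columns [0, B) compare each anchor with
# every anchor in the batch (diagonal = self-similarity), and columns
# [B, 2B) compare it with every positive (diagonal = its own positive).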

mask = torch.ones_like(sim_matrix, dtype=torch.bool)
mask[range(batch_size), range(batch_size)] = False # Exclude self
mask[range(batch_size), range(batch_size, 2 * batch_size)] = (
False # Exclude positive
)

negative_sims = sim_matrix[mask].view(batch_size, -1)

mean_neg_sim = negative_sims.mean()
sum_neg_sim = negative_sims.sum(dim=1).mean()
margin_cosine = cosine_sim_pos - mean_neg_sim

all_embeddings = torch.cat([anchor, positive], dim=0)
dist_matrix = torch.cdist(anchor, all_embeddings, p=2)
negative_dists = dist_matrix[mask].view(batch_size, -1)

mean_neg_dist = negative_dists.mean()
sum_neg_dist = negative_dists.sum(dim=1).mean()
margin_euclidean = mean_neg_dist - euclidean_dist_pos

log_metric_dict.update(
{
f"metrics/cosine_similarity/negative_mean/{stage}": mean_neg_sim,
f"metrics/cosine_similarity/negative_sum/{stage}": sum_neg_sim,
f"metrics/margin_positive/negative/{stage}": margin_cosine,
f"metrics/euclidean_distance/negative_mean/{stage}": mean_neg_dist,
f"metrics/euclidean_distance/negative_sum/{stage}": sum_neg_dist,
f"metrics/margin_euclidean_positive/negative/{stage}": margin_euclidean,
}
)

self.log_dict(
log_metric_dict,
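
For intuition, the mask above leaves 2B - 2 negatives per anchor: of the 2B columns in sim_matrix, the self column and the positive column are dropped. A self-contained sketch of just the masking step (arbitrary shapes, not viscy code):

import torch
import torch.nn.functional as F

B, D = 4, 8  # batch size and embedding dimension, chosen arbitrarily
anchor, positive = torch.randn(B, D), torch.randn(B, D)

all_norm = F.normalize(torch.cat([anchor, positive]), dim=1)
sim = torch.mm(F.normalize(anchor, dim=1), all_norm.t())  # (B, 2B)

mask = torch.ones_like(sim, dtype=torch.bool)
mask[range(B), range(B)] = False  # drop self-similarity
mask[range(B), range(B, 2 * B)] = False  # drop the positive pair

negatives = sim[mask].view(B, -1)
assert negatives.shape == (B, 2 * B - 2)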