diff --git a/tests/conftest.py b/tests/conftest.py
index 7893b29d5..bc3d440c1 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -114,3 +114,58 @@ def tracks_hcs_dataset(tmp_path_factory: TempPathFactory) -> Path:
     )
     fake_tracks.to_csv(dataset_path / fov_name / "tracks.csv", index=False)
     return dataset_path
+
+
+@fixture(scope="function")
+def tracks_with_gaps_dataset(tmp_path_factory: TempPathFactory) -> Path:
+    """Provide an HCS OME-Zarr dataset with tracking results that have gaps in time."""
+    dataset_path = tmp_path_factory.mktemp("tracks_gaps.zarr")
+    _build_hcs(dataset_path, ["nuclei_labels"], (1, 256, 256), np.uint16, 3)
+
+    # Define different track patterns for different FOVs
+    track_patterns = {
+        "A/1/0": [
+            # Track 0: complete sequence t=[0,1,2,3]
+            {"track_id": 0, "t": 0, "y": 128, "x": 128, "id": 0},
+            {"track_id": 0, "t": 1, "y": 128, "x": 128, "id": 1},
+            {"track_id": 0, "t": 2, "y": 128, "x": 128, "id": 2},
+            {"track_id": 0, "t": 3, "y": 128, "x": 128, "id": 3},
+            # Track 1: ends early at t=[0,1]
+            {"track_id": 1, "t": 0, "y": 100, "x": 100, "id": 4},
+            {"track_id": 1, "t": 1, "y": 100, "x": 100, "id": 5},
+        ],
+        "A/1/1": [
+            # Track 0: gap at t=2, has t=[0,1,3]
+            {"track_id": 0, "t": 0, "y": 128, "x": 128, "id": 0},
+            {"track_id": 0, "t": 1, "y": 128, "x": 128, "id": 1},
+            {"track_id": 0, "t": 3, "y": 128, "x": 128, "id": 2},
+            # Track 1: even timepoints only, t=[0,2,4]
+            {"track_id": 1, "t": 0, "y": 100, "x": 100, "id": 3},
+            {"track_id": 1, "t": 2, "y": 100, "x": 100, "id": 4},
+            {"track_id": 1, "t": 4, "y": 100, "x": 100, "id": 5},
+        ],
+        "A/2/0": [
+            # Track 0: single timepoint t=[0]
+            {"track_id": 0, "t": 0, "y": 128, "x": 128, "id": 0},
+            # Track 1: complete short sequence t=[0,1,2]
+            {"track_id": 1, "t": 0, "y": 100, "x": 100, "id": 1},
+            {"track_id": 1, "t": 1, "y": 100, "x": 100, "id": 2},
+            {"track_id": 1, "t": 2, "y": 100, "x": 100, "id": 3},
+        ],
+    }
+
+    for fov_name, _ in open_ome_zarr(dataset_path).positions():
+        if fov_name in track_patterns:
+            tracks_data = track_patterns[fov_name]
+        else:
+            # Default tracks for other FOVs
+            tracks_data = [
+                {"track_id": 0, "t": 0, "y": 128, "x": 128, "id": 0},
+            ]
+
+        tracks_df = pd.DataFrame(tracks_data)
+        tracks_df["parent_track_id"] = -1
+        tracks_df["parent_id"] = -1
+        tracks_df.to_csv(dataset_path / fov_name / "tracks.csv", index=False)
+
+    return dataset_path
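
Note: the fixture above encodes all of its gap structure in the `t` column of each per-FOV tracks.csv. As a minimal sketch of what the tests below assert (using only pandas, which conftest.py already imports; the `anchors_with_successor` helper is hypothetical and not part of the fixture):

import pandas as pd

def anchors_with_successor(tracks_df: pd.DataFrame, interval: int) -> pd.DataFrame:
    """Keep rows whose track also has an observation at t + interval."""
    def _filter(group: pd.DataFrame) -> pd.DataFrame:
        timepoints = set(group["t"])
        return group[group["t"].map(lambda t: t + interval in timepoints)]
    return tracks_df.groupby("track_id", group_keys=False).apply(_filter)

# Track 0 of FOV A/1/1 has t=[0, 1, 3]: only t=0 has a successor at t + 1.
df = pd.DataFrame({"track_id": [0, 0, 0], "t": [0, 1, 3]})
assert set(anchors_with_successor(df, interval=1)["t"]) == {0}
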
diff --git a/tests/data/test_triplet.py b/tests/data/test_triplet.py
index 3d40e1f99..3f763b96c 100644
--- a/tests/data/test_triplet.py
+++ b/tests/data/test_triplet.py
@@ -2,7 +2,7 @@
 from iohub import open_ome_zarr
 from pytest import mark
 
-from viscy.data.triplet import TripletDataModule
+from viscy.data.triplet import TripletDataModule, TripletDataset
 
 
 @mark.parametrize("include_wells", [None, ["A/1", "A/2", "B/1"]])
@@ -109,3 +109,244 @@ def test_datamodule_z_window_size(
         expected_z_shape,
         *yx_patch_size,
     )
+
+
+def test_filter_anchors_time_interval_any(
+    preprocessed_hcs_dataset, tracks_with_gaps_dataset
+):
+    """Test that time_interval='any' returns all tracks unchanged."""
+    with open_ome_zarr(preprocessed_hcs_dataset) as dataset:
+        channel_names = dataset.channel_names
+        positions = list(dataset.positions())
+
+    # Create dataset with time_interval="any"
+    tracks_tables = []
+    for fov_name, _ in positions:
+        tracks_df = pd.read_csv(
+            next((tracks_with_gaps_dataset / fov_name).glob("*.csv"))
+        ).astype(int)
+        tracks_tables.append(tracks_df)
+
+    total_tracks = sum(len(df) for df in tracks_tables)
+
+    ds = TripletDataset(
+        positions=[pos for _, pos in positions],
+        tracks_tables=tracks_tables,
+        channel_names=channel_names,
+        initial_yx_patch_size=(64, 64),
+        z_range=slice(4, 9),
+        fit=True,
+        time_interval="any",
+    )
+
+    # Should return all tracks
+    assert len(ds.valid_anchors) == total_tracks
+
+
+def test_filter_anchors_time_interval_1(
+    preprocessed_hcs_dataset, tracks_with_gaps_dataset
+):
+    """Test filtering with time_interval=1."""
+    with open_ome_zarr(preprocessed_hcs_dataset) as dataset:
+        channel_names = dataset.channel_names
+        positions = list(dataset.positions())
+
+    tracks_tables = []
+    for fov_name, _ in positions:
+        tracks_df = pd.read_csv(
+            next((tracks_with_gaps_dataset / fov_name).glob("*.csv"))
+        ).astype(int)
+        tracks_tables.append(tracks_df)
+
+    ds = TripletDataset(
+        positions=[pos for _, pos in positions],
+        tracks_tables=tracks_tables,
+        channel_names=channel_names,
+        initial_yx_patch_size=(64, 64),
+        z_range=slice(4, 9),
+        fit=True,
+        time_interval=1,
+    )
+
+    # Check expected anchors per FOV/track
+    valid_anchors = ds.valid_anchors
+
+    # FOV A/1/0, Track 0: t=[0,1,2,3] -> valid anchors at t=[0,1,2]
+    fov_a10_track0 = valid_anchors[
+        (valid_anchors["fov_name"] == "A/1/0") & (valid_anchors["track_id"] == 0)
+    ]
+    assert set(fov_a10_track0["t"]) == {0, 1, 2}
+
+    # FOV A/1/0, Track 1: t=[0,1] -> valid anchor at t=[0]
+    fov_a10_track1 = valid_anchors[
+        (valid_anchors["fov_name"] == "A/1/0") & (valid_anchors["track_id"] == 1)
+    ]
+    assert set(fov_a10_track1["t"]) == {0}
+
+    # FOV A/1/1, Track 0: t=[0,1,3] -> valid anchor at t=[0] only (t=1 has no t+1=2)
+    fov_a11_track0 = valid_anchors[
+        (valid_anchors["fov_name"] == "A/1/1") & (valid_anchors["track_id"] == 0)
+    ]
+    assert set(fov_a11_track0["t"]) == {0}
+
+    # FOV A/1/1, Track 1: t=[0,2,4] -> no valid anchors (gaps of 2, no consecutive t+1)
+    fov_a11_track1 = valid_anchors[
+        (valid_anchors["fov_name"] == "A/1/1") & (valid_anchors["track_id"] == 1)
+    ]
+    assert len(fov_a11_track1) == 0
+
+    # FOV A/2/0, Track 0: t=[0] -> no valid anchors (no t+1)
+    fov_a20_track0 = valid_anchors[
+        (valid_anchors["fov_name"] == "A/2/0") & (valid_anchors["track_id"] == 0)
+    ]
+    assert len(fov_a20_track0) == 0
+
+    # FOV A/2/0, Track 1: t=[0,1,2] -> valid anchors at t=[0,1]
+    fov_a20_track1 = valid_anchors[
+        (valid_anchors["fov_name"] == "A/2/0") & (valid_anchors["track_id"] == 1)
+    ]
+    assert set(fov_a20_track1["t"]) == {0, 1}
+
+
+def test_filter_anchors_time_interval_2(
+    preprocessed_hcs_dataset, tracks_with_gaps_dataset
+):
+    """Test filtering with time_interval=2."""
+    with open_ome_zarr(preprocessed_hcs_dataset) as dataset:
+        channel_names = dataset.channel_names
+        positions = list(dataset.positions())
+
+    tracks_tables = []
+    for fov_name, _ in positions:
+        tracks_df = pd.read_csv(
+            next((tracks_with_gaps_dataset / fov_name).glob("*.csv"))
+        ).astype(int)
+        tracks_tables.append(tracks_df)
+
+    ds = TripletDataset(
+        positions=[pos for _, pos in positions],
+        tracks_tables=tracks_tables,
+        channel_names=channel_names,
+        initial_yx_patch_size=(64, 64),
+        z_range=slice(4, 9),
+        fit=True,
+        time_interval=2,
+    )
+
+    valid_anchors = ds.valid_anchors
+
+    # FOV A/1/0, Track 0: t=[0,1,2,3] -> valid anchors at t=[0,1] (t+2 available)
+    fov_a10_track0 = valid_anchors[
+        (valid_anchors["fov_name"] == "A/1/0") & (valid_anchors["track_id"] == 0)
+    ]
+    assert set(fov_a10_track0["t"]) == {0, 1}
+
+    # FOV A/1/0, Track 1: t=[0,1] -> no valid anchors (no t+2)
+    fov_a10_track1 = valid_anchors[
+        (valid_anchors["fov_name"] == "A/1/0") & (valid_anchors["track_id"] == 1)
+    ]
+    assert len(fov_a10_track1) == 0
+
+    # FOV A/1/1, Track 0: t=[0,1,3] -> valid anchor at t=[1] (t=1+2=3 exists)
+    fov_a11_track0 = valid_anchors[
+        (valid_anchors["fov_name"] == "A/1/1") & (valid_anchors["track_id"] == 0)
+    ]
+    assert set(fov_a11_track0["t"]) == {1}
+
+    # FOV A/1/1, Track 1: t=[0,2,4] -> valid anchors at t=[0,2]
+    fov_a11_track1 = valid_anchors[
+        (valid_anchors["fov_name"] == "A/1/1") & (valid_anchors["track_id"] == 1)
+    ]
+    assert set(fov_a11_track1["t"]) == {0, 2}
+
+    # FOV A/2/0, Track 1: t=[0,1,2] -> valid anchor at t=[0]
+    fov_a20_track1 = valid_anchors[
+        (valid_anchors["fov_name"] == "A/2/0") & (valid_anchors["track_id"] == 1)
+    ]
+    assert set(fov_a20_track1["t"]) == {0}
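
Note: every expectation in the three tests above follows one rule: an anchor row (fov_name, track_id, t) is kept iff the same track in the same FOV also has a row at t + time_interval, and time_interval="any" disables filtering. The diff does not show the TripletDataset internals, so the following is only a hedged sketch of that rule (the `filter_valid_anchors` name is hypothetical):

import pandas as pd

def filter_valid_anchors(tracks: pd.DataFrame, time_interval) -> pd.DataFrame:
    """Drop anchors whose track has no observation at t + time_interval."""
    if time_interval == "any":
        return tracks
    keys = ["fov_name", "track_id", "t"]
    # A row at time t' licenses an anchor at t = t' - time_interval.
    successors = tracks[keys].copy()
    successors["t"] -= time_interval
    return tracks.merge(successors.drop_duplicates(), on=keys, how="inner")

tracks = pd.DataFrame(
    {"fov_name": ["A/1/1"] * 3, "track_id": [1, 1, 1], "t": [0, 2, 4]}
)
# Even timepoints only: valid anchors for time_interval=2 are t=[0, 2].
assert list(filter_valid_anchors(tracks, 2)["t"]) == [0, 2]
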
+
+
+def test_filter_anchors_cross_fov_independence(
+    preprocessed_hcs_dataset, tracks_with_gaps_dataset
+):
+    """Test that the same track_id in different FOVs is treated independently."""
+    with open_ome_zarr(preprocessed_hcs_dataset) as dataset:
+        channel_names = dataset.channel_names
+        positions = list(dataset.positions())
+
+    tracks_tables = []
+    for fov_name, _ in positions:
+        tracks_df = pd.read_csv(
+            next((tracks_with_gaps_dataset / fov_name).glob("*.csv"))
+        ).astype(int)
+        tracks_tables.append(tracks_df)
+
+    ds = TripletDataset(
+        positions=[pos for _, pos in positions],
+        tracks_tables=tracks_tables,
+        channel_names=channel_names,
+        initial_yx_patch_size=(64, 64),
+        z_range=slice(4, 9),
+        fit=True,
+        time_interval=1,
+    )
+
+    # Check global_track_id format and uniqueness
+    assert "global_track_id" in ds.tracks.columns
+    global_track_ids = ds.tracks["global_track_id"].unique()
+
+    # Verify the format: should be "fov_name_track_id"
+    for gid in global_track_ids:
+        assert "_" in gid
+        fov_part, track_id_part = gid.rsplit("_", 1)
+        assert "/" in fov_part  # FOV names contain slashes like "A/1/0"
+
+    # Track 0 exists in multiple FOVs (A/1/0, A/1/1, A/2/0) but should have
+    # different global_track_ids
+    track0_global_ids = ds.tracks[ds.tracks["track_id"] == 0][
+        "global_track_id"
+    ].unique()
+    assert len(track0_global_ids) >= 3  # At least 3 different FOVs with track_id=0
+
+    # Verify that filtering is independent per FOV:
+    # A/1/0 Track 0 (continuous) should have more valid anchors than
+    # A/1/1 Track 0 (with a gap)
+    valid_a10_track0 = ds.valid_anchors[
+        (ds.valid_anchors["fov_name"] == "A/1/0") & (ds.valid_anchors["track_id"] == 0)
+    ]
+    valid_a11_track0 = ds.valid_anchors[
+        (ds.valid_anchors["fov_name"] == "A/1/1") & (ds.valid_anchors["track_id"] == 0)
+    ]
+    # A/1/0 Track 0 has t=[0,1,2] valid (3 anchors)
+    # A/1/1 Track 0 has t=[0] valid (1 anchor, gap at t=2)
+    assert len(valid_a10_track0) == 3
+    assert len(valid_a11_track0) == 1
+
+
+def test_filter_anchors_predict_mode(
+    preprocessed_hcs_dataset, tracks_with_gaps_dataset
+):
+    """Test that predict mode (fit=False) returns all tracks regardless of time_interval."""
+    with open_ome_zarr(preprocessed_hcs_dataset) as dataset:
+        channel_names = dataset.channel_names
+        positions = list(dataset.positions())
+
+    tracks_tables = []
+    for fov_name, _ in positions:
+        tracks_df = pd.read_csv(
+            next((tracks_with_gaps_dataset / fov_name).glob("*.csv"))
+        ).astype(int)
+        tracks_tables.append(tracks_df)
+
+    total_tracks = sum(len(df) for df in tracks_tables)
+
+    ds = TripletDataset(
+        positions=[pos for _, pos in positions],
+        tracks_tables=tracks_tables,
+        channel_names=channel_names,
+        initial_yx_patch_size=(64, 64),
+        z_range=slice(4, 9),
+        fit=False,  # Predict mode
+        time_interval=1,
+    )
+
+    # Should return all tracks even with time_interval=1
+    assert len(ds.valid_anchors) == total_tracks
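
Note: the cross-FOV test only relies on `global_track_id` being unique per (FOV, track) pair and recoverable with rsplit("_", 1). A hedged sketch of how such a key can be built (the actual construction inside TripletDataset is not shown in this diff):

import pandas as pd

tracks = pd.DataFrame({"fov_name": ["A/1/0", "A/1/1"], "track_id": [0, 0]})
tracks["global_track_id"] = tracks["fov_name"] + "_" + tracks["track_id"].astype(str)

# Same track_id, distinct global keys: "A/1/0_0" vs "A/1/1_0".
assert tracks["global_track_id"].nunique() == len(tracks)
assert tracks["global_track_id"][0].rsplit("_", 1) == ["A/1/0", "0"]
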
diff --git a/viscy/representation/engine.py b/viscy/representation/engine.py
index 7a35d93f1..b7938013e 100644
--- a/viscy/representation/engine.py
+++ b/viscy/representation/engine.py
@@ -36,6 +36,7 @@ def __init__(
         log_batches_per_epoch: int = 8,
         log_samples_per_batch: int = 1,
         log_embeddings: bool = False,
+        log_negative_metrics_every_n_epochs: int = 2,
         example_input_array_shape: Sequence[int] = (1, 2, 15, 256, 256),
     ) -> None:
         super().__init__()
@@ -49,6 +50,7 @@ def __init__(
         self.training_step_outputs = []
         self.validation_step_outputs = []
         self.log_embeddings = log_embeddings
+        self.log_negative_metrics_every_n_epochs = log_negative_metrics_every_n_epochs
 
     def forward(self, x: Tensor) -> tuple[Tensor, Tensor]:
         """Return both features and projections.
@@ -94,8 +96,8 @@ def _log_metrics(
         cosine_sim_pos = F.cosine_similarity(anchor, positive, dim=1).mean()
         euclidean_dist_pos = F.pairwise_distance(anchor, positive).mean()
         log_metric_dict = {
-            f"metrics/cosine_similarity_positive/{stage}": cosine_sim_pos,
-            f"metrics/euclidean_distance_positive/{stage}": euclidean_dist_pos,
+            f"metrics/cosine_similarity/positive/{stage}": cosine_sim_pos,
+            f"metrics/euclidean_distance/positive/{stage}": euclidean_dist_pos,
         }
 
         if negative is not None:
@@ -107,6 +109,46 @@ def _log_metrics(
             log_metric_dict[f"metrics/euclidean_distance_negative/{stage}"] = (
                 euclidean_dist_neg
             )
+        elif isinstance(self.loss_function, NTXentLoss):
+            if self.current_epoch % self.log_negative_metrics_every_n_epochs == 0:
+                batch_size = anchor.size(0)
+
+                # Cosine similarity metrics
+                anchor_norm = F.normalize(anchor, dim=1)
+                positive_norm = F.normalize(positive, dim=1)
+                all_embeddings_norm = torch.cat([anchor_norm, positive_norm], dim=0)
+                sim_matrix = torch.mm(anchor_norm, all_embeddings_norm.t())
+
+                mask = torch.ones_like(sim_matrix, dtype=torch.bool)
+                mask[range(batch_size), range(batch_size)] = False  # Exclude self
+                mask[range(batch_size), range(batch_size, 2 * batch_size)] = (
+                    False  # Exclude positive
+                )
+
+                negative_sims = sim_matrix[mask].view(batch_size, -1)
+
+                mean_neg_sim = negative_sims.mean()
+                sum_neg_sim = negative_sims.sum(dim=1).mean()
+                margin_cosine = cosine_sim_pos - mean_neg_sim
+
+                all_embeddings = torch.cat([anchor, positive], dim=0)
+                dist_matrix = torch.cdist(anchor, all_embeddings, p=2)
+                negative_dists = dist_matrix[mask].view(batch_size, -1)
+
+                mean_neg_dist = negative_dists.mean()
+                sum_neg_dist = negative_dists.sum(dim=1).mean()
+                margin_euclidean = mean_neg_dist - euclidean_dist_pos
+
+                log_metric_dict.update(
+                    {
+                        f"metrics/cosine_similarity/negative_mean/{stage}": mean_neg_sim,
+                        f"metrics/cosine_similarity/negative_sum/{stage}": sum_neg_sim,
+                        f"metrics/margin_positive/negative/{stage}": margin_cosine,
+                        f"metrics/euclidean_distance/negative_mean/{stage}": mean_neg_dist,
+                        f"metrics/euclidean_distance/negative_sum/{stage}": sum_neg_dist,
+                        f"metrics/margin_euclidean_positive/negative/{stage}": margin_euclidean,
+                    }
+                )
 
         self.log_dict(
             log_metric_dict,
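
Note: the masking in the NTXentLoss branch leaves 2B - 2 negatives per anchor: of the 2B columns (B anchors followed by B positives), each anchor row excludes its own diagonal entry and its own positive. A self-contained check of that arithmetic (standalone PyTorch, mirroring the indexing used above):

import torch

batch_size = 4
sim_matrix = torch.rand(batch_size, 2 * batch_size)

mask = torch.ones_like(sim_matrix, dtype=torch.bool)
mask[range(batch_size), range(batch_size)] = False  # self-similarity
mask[range(batch_size), range(batch_size, 2 * batch_size)] = False  # own positive

# Boolean indexing flattens row-major, so each row keeps 2B - 2 entries.
negatives = sim_matrix[mask].view(batch_size, -1)
assert negatives.shape == (batch_size, 2 * batch_size - 2)
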