diff --git a/tests/conftest.py b/tests/conftest.py
index bc3d440c1..3fe5315ed 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -22,6 +22,7 @@ def _build_hcs(
     max_value: int | float,
     sharded: bool = False,
     multiscales: bool = False,
+    num_timepoints: int = 2,
 ):
     dataset = open_ome_zarr(
         path,
@@ -37,10 +38,17 @@
     pos.create_image(
         "0",
         (
-            np.random.rand(2, len(channel_names), *zyx_shape) * max_value
+            np.random.rand(num_timepoints, len(channel_names), *zyx_shape)
+            * max_value
         ).astype(dtype),
         chunks=(1, 1, 1, *zyx_shape[1:]),
-        shards_ratio=(2, len(channel_names), zyx_shape[0], 1, 1)
+        shards_ratio=(
+            num_timepoints,
+            len(channel_names),
+            zyx_shape[0],
+            1,
+            1,
+        )
         if sharded
         else None,
     )
@@ -116,6 +124,21 @@ def tracks_hcs_dataset(tmp_path_factory: TempPathFactory) -> Path:
     return dataset_path
 
 
+@fixture(scope="function")
+def temporal_hcs_dataset(tmp_path_factory: TempPathFactory) -> Path:
+    """Provides a temporal HCS OME-Zarr dataset with multiple timepoints."""
+    dataset_path = tmp_path_factory.mktemp("temporal.zarr")
+    _build_hcs(
+        dataset_path,
+        channel_names[:2],  # Use first 2 channels
+        (10, 50, 50),
+        np.uint16,
+        65535,
+        num_timepoints=5,
+    )
+    return dataset_path
+
+
 @fixture(scope="function")
 def tracks_with_gaps_dataset(tmp_path_factory: TempPathFactory) -> Path:
     """Provides a HCS OME-Zarr dataset with tracking results with gaps in time."""
diff --git a/tests/utils/test_meta_utils.py b/tests/utils/test_meta_utils.py
new file mode 100644
index 000000000..ef30cea95
--- /dev/null
+++ b/tests/utils/test_meta_utils.py
@@ -0,0 +1,218 @@
+from pathlib import Path
+
+import numpy as np
+from iohub import open_ome_zarr
+
+from viscy.utils.meta_utils import (
+    _grid_sample_timepoint,
+    generate_normalization_metadata,
+)
+
+
+def test_grid_sample_timepoint(temporal_hcs_dataset: Path):
+    """Test that _grid_sample_timepoint samples the correct timepoint."""
+    plate = open_ome_zarr(temporal_hcs_dataset, mode="r")
+
+    # Get first position
+    position_keys = list(plate.positions())
+    _, position = position_keys[0]
+
+    grid_spacing = 10
+    channel_index = 0
+    num_workers = 2
+    num_timepoints = position["0"].shape[0]
+
+    # Sample different timepoints
+    samples_per_timepoint = []
+    for timepoint_index in range(num_timepoints):
+        samples = _grid_sample_timepoint(
+            position=position,
+            grid_spacing=grid_spacing,
+            channel_index=channel_index,
+            timepoint_index=timepoint_index,
+            num_workers=num_workers,
+        )
+        samples_per_timepoint.append(samples)
+
+        # Verify shape is correct (should be 3D: Z, Y_sampled, X_sampled)
+        assert len(samples.shape) == 3, (
+            f"Expected 3D samples, got shape {samples.shape}"
+        )
+
+    # Verify that samples from different timepoints differ
+    # (due to random data generation)
+    means = [np.mean(s) for s in samples_per_timepoint]
+    unique_means = len(set(means))
+    assert unique_means > 1, "All timepoint samples are identical, expected variation"
+
+    plate.close()
+
+
+def test_generate_normalization_metadata_structure(temporal_hcs_dataset: Path):
+    """Test that generate_normalization_metadata creates correct metadata structure."""
+    num_timepoints = 5  # As specified in the temporal_hcs_dataset fixture
+
+    # Generate normalization metadata
+    generate_normalization_metadata(
+        str(temporal_hcs_dataset), num_workers=2, channel_ids=-1, grid_spacing=10
+    )
+
+    # Reopen and check metadata
+    plate = open_ome_zarr(temporal_hcs_dataset, mode="r")
+
+    # Check plate-level metadata
+    assert "normalization" in plate.zattrs, "Normalization field not found in metadata"
+
+    for channel_name in plate.channel_names:
+        assert channel_name in plate.zattrs["normalization"], (
+            f"Channel {channel_name} not found in normalization metadata"
+        )
+
+        channel_norm = plate.zattrs["normalization"][channel_name]
+
+        # Check that dataset statistics exist
+        assert "dataset_statistics" in channel_norm, (
+            "dataset_statistics not found in metadata"
+        )
+
+        # Check that timepoint statistics exist
+        assert "timepoint_statistics" in channel_norm, (
+            "timepoint_statistics not found in metadata"
+        )
+
+        timepoint_stats = channel_norm["timepoint_statistics"]
+
+        # Check that all timepoints are present
+        for t in range(num_timepoints):
+            assert str(t) in timepoint_stats, (
+                f"Timepoint {t} not found in timepoint_statistics"
+            )
+
+            # Verify statistics structure
+            t_stats = timepoint_stats[str(t)]
+            required_keys = ["mean", "std", "median", "iqr", "p5", "p95", "p95_p5"]
+            for key in required_keys:
+                assert key in t_stats, f"{key} not found in timepoint {t} statistics"
+
+    # Check position-level metadata
+    position_keys = list(plate.positions())
+    for _, position in position_keys:
+        assert "normalization" in position.zattrs, (
+            "Normalization field not found in position metadata"
+        )
+
+        for channel_name in plate.channel_names:
+            assert channel_name in position.zattrs["normalization"], (
+                f"Channel {channel_name} not found in position normalization metadata"
+            )
+
+            pos_channel_norm = position.zattrs["normalization"][channel_name]
+
+            # Check that all three types of statistics exist
+            assert "dataset_statistics" in pos_channel_norm, (
+                "dataset_statistics not found in position metadata"
+            )
+            assert "fov_statistics" in pos_channel_norm, (
+                "fov_statistics not found in position metadata"
+            )
+            assert "timepoint_statistics" in pos_channel_norm, (
+                "timepoint_statistics not found in position metadata"
+            )
+
+    plate.close()
+
+
+def test_generate_normalization_timepoint_values_differ(temporal_hcs_dataset: Path):
+    """Test that per-timepoint statistics have different values across timepoints."""
+    num_timepoints = 5
+
+    # Generate normalization metadata
+    generate_normalization_metadata(
+        str(temporal_hcs_dataset), num_workers=2, channel_ids=0, grid_spacing=10
+    )
+
+    # Reopen and check metadata
+    plate = open_ome_zarr(temporal_hcs_dataset, mode="r")
+
+    channel_name = plate.channel_names[0]
+    timepoint_stats = plate.zattrs["normalization"][channel_name][
+        "timepoint_statistics"
+    ]
+
+    # Extract median values for each timepoint
+    medians = [timepoint_stats[str(t)]["median"] for t in range(num_timepoints)]
+
+    # Since data is randomly generated per timepoint,
+    # medians should vary across timepoints
+    unique_medians = len(set(medians))
+    assert unique_medians > 1, "All timepoint medians are identical, expected variation"
+
+    # All medians should be non-negative numbers
+    for t, median in enumerate(medians):
+        assert isinstance(median, (int, float)), f"Timepoint {t} median is not numeric"
+        assert median >= 0, f"Timepoint {t} median is negative"
+
+    plate.close()
+
+
+def test_generate_normalization_single_channel(temporal_hcs_dataset: Path):
+    """Test normalization metadata generation for a single channel."""
+
+    # Generate normalization for only channel 0
+    generate_normalization_metadata(
+        str(temporal_hcs_dataset), num_workers=2, channel_ids=0, grid_spacing=15
+    )
+
+    # Reopen and check metadata
+    plate = open_ome_zarr(temporal_hcs_dataset, mode="r")
+
+    # Only channel 0 should have normalization metadata
+    assert "normalization" in plate.zattrs, (
+        "Normalization metadata not created at plate level"
+    )
+    assert len(plate.zattrs["normalization"]) == 1, (
+        "Expected only one channel in normalization metadata"
+    )
+
+    channel_name = plate.channel_names[0]
+    assert channel_name in plate.zattrs["normalization"], (
+        f"Channel {channel_name} not found in metadata"
+    )
+
+    plate.close()
+
+
+def test_grid_sample_timepoint_shape(temporal_hcs_dataset: Path):
+    """Test that _grid_sample_timepoint returns correctly shaped array."""
+    plate = open_ome_zarr(temporal_hcs_dataset, mode="r")
+
+    # Get first position
+    position_keys = list(plate.positions())
+    _, position = position_keys[0]
+
+    grid_spacing = 10
+    channel_index = 0
+    timepoint_index = 0
+    num_workers = 2
+
+    samples = _grid_sample_timepoint(
+        position=position,
+        grid_spacing=grid_spacing,
+        channel_index=channel_index,
+        timepoint_index=timepoint_index,
+        num_workers=num_workers,
+    )
+
+    # Expected shape: (Z, ceil(Y / grid_spacing), ceil(X / grid_spacing))
+    expected_z = position["0"].shape[2]
+    expected_y = (position["0"].shape[3] + grid_spacing - 1) // grid_spacing
+    expected_x = (position["0"].shape[4] + grid_spacing - 1) // grid_spacing
+
+    assert samples.shape[0] == expected_z, (
+        f"Expected Z={expected_z}, got {samples.shape[0]}"
+    )
+    # Y and X dimensions might be slightly smaller due to grid sampling
+    assert samples.shape[1] <= expected_y, "Y dimension larger than expected"
+    assert samples.shape[2] <= expected_x, "X dimension larger than expected"
+
+    plate.close()
diff --git a/viscy/utils/meta_utils.py b/viscy/utils/meta_utils.py
index 2bd79081d..92fc1d1be 100644
--- a/viscy/utils/meta_utils.py
+++ b/viscy/utils/meta_utils.py
@@ -69,21 +69,55 @@ def _grid_sample(
     )
 
 
+def _grid_sample_timepoint(
+    position: ngff.Position,
+    grid_spacing: int,
+    channel_index: int,
+    timepoint_index: int,
+    num_workers: int,
+):
+    """
+    Sample a specific timepoint from a position using grid sampling.
+
+    :param Position position: NGFF position node object
+    :param int grid_spacing: distance between points in sampling grid
+    :param int channel_index: index of channel to sample
+    :param int timepoint_index: index of timepoint to sample
+    :param int num_workers: number of cpu workers for multiprocessing
+    :return: sampled values for the specified timepoint
+    """
+    return (
+        position["0"]
+        .tensorstore(
+            context=tensorstore.Context(
+                {"data_copy_concurrency": {"limit": num_workers}}
+            )
+        )[timepoint_index, channel_index, :, ::grid_spacing, ::grid_spacing]
+        .read()
+        .result()
+    )
+
+
 def generate_normalization_metadata(
     zarr_dir, num_workers=4, channel_ids=-1, grid_spacing=32
 ):
     """
     Generate pixel intensity metadata to be later used in on-the-fly normalization
     during training and inference. Sampling is used for efficient estimation of median
-    and interquartile range for intensity values on both a dataset and field-of-view
-    level.
+    and interquartile range for intensity values at the dataset, field-of-view,
+    and timepoint levels.
 
     Normalization values are recorded in the image-level metadata in the corresponding
     position of each zarr_dir store. Format of metadata is as follows:
     {
         channel_idx : {
             dataset_statistics: dataset level normalization values (positive float),
-            fov_statistics: field-of-view level normalization values (positive float)
+            fov_statistics: field-of-view level normalization values (positive float),
+            timepoint_statistics: {
+                "0": timepoint 0 normalization values (positive float),
+                "1": timepoint 1 normalization values (positive float),
+                ...
+            }
         },
         .
         .
@@ -104,6 +138,11 @@ def generate_normalization_metadata(
     elif isinstance(channel_ids, int):
         channel_ids = [channel_ids]
 
+    # Get number of timepoints from first position
+    _, first_position = position_map[0]
+    num_timepoints = first_position["0"].shape[0]
+    print(f"Detected {num_timepoints} timepoints in dataset")
+
     # get arguments for multiprocessed grid sampling
     mp_grid_sampler_args = []
     for _, position in position_map:
@@ -126,17 +165,35 @@ def generate_normalization_metadata(
     dataset_statistics = {
         "dataset_statistics": get_val_stats(np.stack(dataset_sample_values)),
     }
+
+    # Compute per-timepoint statistics across all FOVs
+    print(f"Computing per-timepoint statistics for channel {channel_name}")
+    timepoint_statistics = {}
+    for t in tqdm(range(num_timepoints), desc="Timepoints"):
+        timepoint_samples = []
+        for _, pos in position_map:
+            t_samples = _grid_sample_timepoint(
+                pos, grid_spacing, channel_index, t, num_workers
+            )
+            timepoint_samples.append(t_samples)
+        timepoint_statistics[str(t)] = get_val_stats(np.stack(timepoint_samples))
+
+    # Write plate-level metadata (dataset + timepoint statistics)
     write_meta_field(
         position=plate,
-        metadata=dataset_statistics,
+        metadata=dataset_statistics
+        | {"timepoint_statistics": timepoint_statistics},
        field_name="normalization",
         subfield_name=channel_name,
     )
 
+    # Write position-level metadata (dataset + FOV + timepoint statistics)
     for pos, position_statistics in position_and_statistics:
         write_meta_field(
             position=pos,
-            metadata=dataset_statistics | position_statistics,
+            metadata=dataset_statistics
+            | position_statistics
+            | {"timepoint_statistics": timepoint_statistics},
             field_name="normalization",
             subfield_name=channel_name,
         )
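
Usage sketch (illustrative, not part of the patch): a minimal example of generating and reading back the new per-timepoint statistics. The store path "plate.zarr" and channel name "Phase" are hypothetical placeholders; the "normalization" layout and the "median"/"iqr" keys follow the docstring and tests above.

    from iohub import open_ome_zarr

    from viscy.utils.meta_utils import generate_normalization_metadata

    # channel_ids=-1 computes statistics for every channel in the store.
    generate_normalization_metadata(
        "plate.zarr", num_workers=4, channel_ids=-1, grid_spacing=32
    )

    plate = open_ome_zarr("plate.zarr", mode="r")
    channel_stats = plate.zattrs["normalization"]["Phase"]  # "Phase" is hypothetical
    # Per-timepoint values are keyed by the stringified timepoint index.
    t0_stats = channel_stats["timepoint_statistics"]["0"]
    # Normalize a timepoint-0 volume with its own median/IQR instead of the
    # dataset-level values, e.g.: (volume - t0_stats["median"]) / t0_stats["iqr"]
    print(channel_stats["dataset_statistics"]["median"], t0_stats["median"])
    plate.close()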