27 changes: 25 additions & 2 deletions tests/conftest.py
@@ -22,6 +22,7 @@ def _build_hcs(
max_value: int | float,
sharded: bool = False,
multiscales: bool = False,
num_timepoints: int = 2,
):
dataset = open_ome_zarr(
path,
@@ -37,10 +38,17 @@
pos.create_image(
"0",
(
np.random.rand(2, len(channel_names), *zyx_shape) * max_value
np.random.rand(num_timepoints, len(channel_names), *zyx_shape)
* max_value
).astype(dtype),
chunks=(1, 1, 1, *zyx_shape[1:]),
shards_ratio=(2, len(channel_names), zyx_shape[0], 1, 1)
shards_ratio=(
num_timepoints,
len(channel_names),
zyx_shape[0],
1,
1,
)
if sharded
else None,
)
@@ -116,6 +124,21 @@ def tracks_hcs_dataset(tmp_path_factory: TempPathFactory) -> Path:
return dataset_path


@fixture(scope="function")
def temporal_hcs_dataset(tmp_path_factory: TempPathFactory) -> Path:
"""Provides a temporal HCS OME-Zarr dataset with multiple timepoints."""
dataset_path = tmp_path_factory.mktemp("temporal.zarr")
_build_hcs(
dataset_path,
channel_names[:2], # Use first 2 channels
(10, 50, 50),
np.uint16,
65535,
num_timepoints=5,
)
return dataset_path


@fixture(scope="function")
def tracks_with_gaps_dataset(tmp_path_factory: TempPathFactory) -> Path:
"""Provides a HCS OME-Zarr dataset with tracking results with gaps in time."""
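For reference, the new temporal_hcs_dataset fixture builds an HCS store whose image arrays carry five timepoints. A minimal sketch (not part of this diff) of opening the fixture output and confirming the layout, assuming the first two entries of the conftest channel_names list are used as in the fixture:

from iohub import open_ome_zarr


def check_temporal_fixture(dataset_path):
    # Open the fixture store read-only and inspect the first position.
    with open_ome_zarr(dataset_path, mode="r") as plate:
        _, position = next(plate.positions())
        # _build_hcs writes a (T, C, Z, Y, X) array: 5 timepoints,
        # 2 channels, and the (10, 50, 50) ZYX shape from the fixture.
        assert position["0"].shape == (5, 2, 10, 50, 50)
        assert position["0"].dtype == "uint16"
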
218 changes: 218 additions & 0 deletions tests/utils/test_meta_utils.py
@@ -0,0 +1,218 @@
from pathlib import Path

import numpy as np
from iohub import open_ome_zarr

from viscy.utils.meta_utils import (
_grid_sample_timepoint,
generate_normalization_metadata,
)


def test_grid_sample_timepoint(temporal_hcs_dataset: Path):
"""Test that _grid_sample_timepoint samples the correct timepoint."""
plate = open_ome_zarr(temporal_hcs_dataset, mode="r")

# Get first position
position_keys = list(plate.positions())
_, position = position_keys[0]

grid_spacing = 10
channel_index = 0
num_workers = 2
num_timepoints = position["0"].shape[0]

# Sample different timepoints
samples_per_timepoint = []
for timepoint_index in range(num_timepoints):
samples = _grid_sample_timepoint(
position=position,
grid_spacing=grid_spacing,
channel_index=channel_index,
timepoint_index=timepoint_index,
num_workers=num_workers,
)
samples_per_timepoint.append(samples)

# Verify shape is correct (should be 3D: Z, Y_sampled, X_sampled)
assert len(samples.shape) == 3, (
f"Expected 3D samples, got shape {samples.shape}"
)

# Verify that samples from different timepoints differ
# (due to random data generation)
means = [np.mean(s) for s in samples_per_timepoint]
unique_means = len(set(means))
assert unique_means > 1, "All timepoint samples are identical, expected variation"

plate.close()


def test_generate_normalization_metadata_structure(temporal_hcs_dataset: Path):
"""Test that generate_normalization_metadata creates correct metadata structure."""
num_timepoints = 5 # As specified in the temporal_hcs_dataset fixture

# Generate normalization metadata
generate_normalization_metadata(
str(temporal_hcs_dataset), num_workers=2, channel_ids=-1, grid_spacing=10
)

# Reopen and check metadata
plate = open_ome_zarr(temporal_hcs_dataset, mode="r")

# Check plate-level metadata
assert "normalization" in plate.zattrs, "Normalization field not found in metadata"

for channel_name in plate.channel_names:
assert channel_name in plate.zattrs["normalization"], (
f"Channel {channel_name} not found in normalization metadata"
)

channel_norm = plate.zattrs["normalization"][channel_name]

# Check that dataset statistics exist
assert "dataset_statistics" in channel_norm, (
"dataset_statistics not found in metadata"
)

# Check that timepoint statistics exist
assert "timepoint_statistics" in channel_norm, (
"timepoint_statistics not found in metadata"
)

timepoint_stats = channel_norm["timepoint_statistics"]

# Check that all timepoints are present
for t in range(num_timepoints):
assert str(t) in timepoint_stats, (
f"Timepoint {t} not found in timepoint_statistics"
)

# Verify statistics structure
t_stats = timepoint_stats[str(t)]
required_keys = ["mean", "std", "median", "iqr", "p5", "p95", "p95_p5"]
for key in required_keys:
assert key in t_stats, f"{key} not found in timepoint {t} statistics"

# Check position-level metadata
position_keys = list(plate.positions())
for _, position in position_keys:
assert "normalization" in position.zattrs, (
"Normalization field not found in position metadata"
)

for channel_name in plate.channel_names:
assert channel_name in position.zattrs["normalization"], (
f"Channel {channel_name} not found in position normalization metadata"
)

pos_channel_norm = position.zattrs["normalization"][channel_name]

# Check that all three types of statistics exist
assert "dataset_statistics" in pos_channel_norm, (
"dataset_statistics not found in position metadata"
)
assert "fov_statistics" in pos_channel_norm, (
"fov_statistics not found in position metadata"
)
assert "timepoint_statistics" in pos_channel_norm, (
"timepoint_statistics not found in position metadata"
)

plate.close()


def test_generate_normalization_timepoint_values_differ(temporal_hcs_dataset: Path):
"""Test that per-timepoint statistics have different values across timepoints."""
num_timepoints = 5

# Generate normalization metadata
generate_normalization_metadata(
str(temporal_hcs_dataset), num_workers=2, channel_ids=0, grid_spacing=10
)

# Reopen and check metadata
plate = open_ome_zarr(temporal_hcs_dataset, mode="r")

channel_name = plate.channel_names[0]
timepoint_stats = plate.zattrs["normalization"][channel_name][
"timepoint_statistics"
]

# Extract median values for each timepoint
medians = [timepoint_stats[str(t)]["median"] for t in range(num_timepoints)]

# Since data is randomly generated per timepoint,
# medians should vary across timepoints
unique_medians = len(set(medians))
assert unique_medians > 1, "All timepoint medians are identical, expected variation"

# All medians should be positive floats
for t, median in enumerate(medians):
assert isinstance(median, (int, float)), f"Timepoint {t} median is not numeric"
assert median >= 0, f"Timepoint {t} median is negative"

plate.close()


def test_generate_normalization_single_channel(temporal_hcs_dataset: Path):
"""Test normalization metadata generation for a single channel."""

# Generate normalization for only channel 0
generate_normalization_metadata(
str(temporal_hcs_dataset), num_workers=2, channel_ids=0, grid_spacing=15
)

# Reopen and check metadata
plate = open_ome_zarr(temporal_hcs_dataset, mode="r")

# Only channel 0 should have normalization metadata
assert "normalization" in plate.zattrs, (
"Normalization metadata not created at plate level"
)
assert len(plate.zattrs["normalization"]) == 1, (
"Expected only one channel in normalization metadata"
)

channel_name = plate.channel_names[0]
assert channel_name in plate.zattrs["normalization"], (
f"Channel {channel_name} not found in metadata"
)

plate.close()


def test_grid_sample_timepoint_shape(temporal_hcs_dataset: Path):
"""Test that _grid_sample_timepoint returns correctly shaped array."""
plate = open_ome_zarr(temporal_hcs_dataset, mode="r")

# Get first position
position_keys = list(plate.positions())
_, position = position_keys[0]

grid_spacing = 10
channel_index = 0
timepoint_index = 0
num_workers = 2

samples = _grid_sample_timepoint(
position=position,
grid_spacing=grid_spacing,
channel_index=channel_index,
timepoint_index=timepoint_index,
num_workers=num_workers,
)

# Expected shape: (Z, Y//grid_spacing, X//grid_spacing)
expected_z = position["0"].shape[2]
expected_y = (position["0"].shape[3] + grid_spacing - 1) // grid_spacing
expected_x = (position["0"].shape[4] + grid_spacing - 1) // grid_spacing

assert samples.shape[0] == expected_z, (
f"Expected Z={expected_z}, got {samples.shape[0]}"
)
# Y and X dimensions might be slightly different due to grid sampling
assert samples.shape[1] <= expected_y, "Y dimension larger than expected"
assert samples.shape[2] <= expected_x, "X dimension larger than expected"

plate.close()
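
For intuition on the shape assertions in test_grid_sample_timepoint_shape: slicing with ::grid_spacing keeps every grid_spacing-th pixel, i.e. ceil(size / grid_spacing) samples per axis, which is what the ceiling division above computes. A toy NumPy sketch (illustrative only):

import numpy as np

# A toy (Z, Y, X) volume with the fixture's spatial shape and the
# grid spacing used in the tests above.
volume = np.random.rand(10, 50, 50)
grid_spacing = 10

sampled = volume[:, ::grid_spacing, ::grid_spacing]
# ceil(50 / 10) == 5 samples along Y and X; Z is kept in full.
assert sampled.shape == (10, 5, 5)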
67 changes: 62 additions & 5 deletions viscy/utils/meta_utils.py
@@ -69,21 +69,55 @@ def _grid_sample(
)


def _grid_sample_timepoint(
position: ngff.Position,
grid_spacing: int,
channel_index: int,
timepoint_index: int,
num_workers: int,
):
"""
Sample a specific timepoint from a position using grid sampling.

:param Position position: NGFF position node object
:param int grid_spacing: distance between points in sampling grid
:param int channel_index: index of channel to sample
:param int timepoint_index: index of timepoint to sample
:param int num_workers: number of cpu workers for multiprocessing
:return: sampled values for the specified timepoint
"""
return (
position["0"]
.tensorstore(
context=tensorstore.Context(
{"data_copy_concurrency": {"limit": num_workers}}
)
)[timepoint_index, channel_index, :, ::grid_spacing, ::grid_spacing]
.read()
.result()
)


def generate_normalization_metadata(
zarr_dir, num_workers=4, channel_ids=-1, grid_spacing=32
):
"""
Generate pixel intensity metadata to be later used in on-the-fly normalization
during training and inference. Sampling is used for efficient estimation of median
and interquartile range for intensity values on both a dataset and field-of-view
level.
and interquartile range for intensity values at the dataset, field-of-view,
and timepoint levels.

Normalization values are recorded in the image-level metadata in the corresponding
position of each zarr_dir store. Format of metadata is as follows:
{
channel_idx : {
dataset_statistics: dataset level normalization values (positive float),
fov_statistics: field-of-view level normalization values (positive float)
fov_statistics: field-of-view level normalization values (positive float),
timepoint_statistics: {
"0": timepoint 0 normalization values (positive float),
"1": timepoint 1 normalization values (positive float),
...
}
},
.
.
@@ -104,6 +138,11 @@ def generate_normalization_metadata(
elif isinstance(channel_ids, int):
channel_ids = [channel_ids]

# Get number of timepoints from first position
_, first_position = position_map[0]
num_timepoints = first_position["0"].shape[0]
print(f"Detected {num_timepoints} timepoints in dataset")

# get arguments for multiprocessed grid sampling
mp_grid_sampler_args = []
for _, position in position_map:
@@ -126,17 +165,35 @@ def generate_normalization_metadata(
dataset_statistics = {
"dataset_statistics": get_val_stats(np.stack(dataset_sample_values)),
}

# Compute per-timepoint statistics across all FOVs
print(f"Computing per-timepoint statistics for channel {channel_name}")
timepoint_statistics = {}
for t in tqdm(range(num_timepoints), desc="Timepoints"):
timepoint_samples = []
for _, pos in position_map:
t_samples = _grid_sample_timepoint(
pos, grid_spacing, channel_index, t, num_workers
)
timepoint_samples.append(t_samples)
timepoint_statistics[str(t)] = get_val_stats(np.stack(timepoint_samples))

# Write plate-level metadata (dataset + timepoint statistics)
write_meta_field(
position=plate,
metadata=dataset_statistics,
metadata=dataset_statistics
| {"timepoint_statistics": timepoint_statistics},
field_name="normalization",
subfield_name=channel_name,
)

# Write position-level metadata (dataset + FOV + timepoint statistics)
for pos, position_statistics in position_and_statistics:
write_meta_field(
position=pos,
metadata=dataset_statistics | position_statistics,
metadata=dataset_statistics
| position_statistics
| {"timepoint_statistics": timepoint_statistics},
field_name="normalization",
subfield_name=channel_name,
)
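As a usage sketch (not part of this diff), the per-timepoint statistics written by generate_normalization_metadata could be read back from the plate metadata as below. The key names follow the format documented in the docstring above; the dataset path is hypothetical:

from iohub import open_ome_zarr

from viscy.utils.meta_utils import generate_normalization_metadata

# Compute dataset-, FOV-, and timepoint-level statistics for all channels.
generate_normalization_metadata("plate.zarr", num_workers=4, channel_ids=-1)

with open_ome_zarr("plate.zarr", mode="r") as plate:
    channel = plate.channel_names[0]
    t_stats = plate.zattrs["normalization"][channel]["timepoint_statistics"]
    # One entry per timepoint, keyed by the stringified timepoint index.
    print(t_stats["0"]["median"], t_stats["0"]["iqr"])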