4 changes: 3 additions & 1 deletion examples/configs/fit_example.yml
@@ -41,7 +41,8 @@ data:
     batch_size: 32
     num_workers: 16
     yx_patch_size: [256, 256]
-    normalizations:
+    array_key: "0"
+    normalizations:
       - class_path: viscy.transforms.NormalizeSampled
         init_args:
           keys: [source]
@@ -92,3 +93,4 @@ data:
           sigma_y: [0.25, 1.5]
          sigma_x: [0.25, 1.5]
     caching: false
+
1 change: 1 addition & 0 deletions examples/configs/predict_example.yml
@@ -66,5 +66,6 @@ predict:
       - 256
       - 256
     caching: false
+    array_key: "0"
   return_predictions: false
   ckpt_path: null
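
The value given for `array_key` has to name a multiscales level that actually exists in the input store. A quick way to check, using only the iohub calls that already appear in this PR (the `input.zarr` path is a placeholder, and level "1" is only present if the store was written with multiscales, as in the test fixture below):

from iohub import open_ome_zarr

# Placeholder path; point this at the store referenced by the config.
with open_ome_zarr("input.zarr") as dataset:
    _, pos = next(dataset.positions())
    for key in ("0", "1"):  # levels assumed to exist, as in the test fixture
        print(key, pos[key].shape)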
3 changes: 2 additions & 1 deletion examples/configs/test_example.yml
@@ -66,5 +66,6 @@ data:
       - 256
     caching: false
     ground_truth_masks: null
+    array_key: "0"
   ckpt_path: null
-  verbose: true
+  verbose: true
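
For anyone wiring this up outside the Lightning CLI, the three config additions map one-to-one onto the new `HCSDataModule` argument. A minimal sketch with placeholder path and channel names; `target_channel` and `z_window_size` are assumptions inferred from the surrounding project conventions rather than shown in this diff:

from viscy.data.hcs import HCSDataModule

dm = HCSDataModule(
    data_path="plate.zarr",                 # placeholder input store
    source_channel=["Phase"],               # hypothetical channel name
    target_channel=["Nuclei", "Membrane"],  # assumed, per the ChannelMap docstring example
    z_window_size=5,                        # matches the test below
    batch_size=32,
    num_workers=16,
    array_key="0",  # the new option: which multiscales level to read
)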
7 changes: 6 additions & 1 deletion tests/conftest.py
@@ -20,6 +20,7 @@ def _build_hcs(
     zyx_shape: tuple[int, int, int],
     dtype: DTypeLike,
     max_value: int | float,
+    multiscales: bool = False,
 ):
     dataset = open_ome_zarr(
         path,
@@ -37,13 +38,17 @@ def _build_hcs(
                 np.random.rand(2, len(channel_names), *zyx_shape) * max_value
             ).astype(dtype),
         )
+        if multiscales:
+            pos["1"] = pos["0"][::2, :, ::2, ::2, ::2]
 
 
 @fixture(scope="session")
 def preprocessed_hcs_dataset(tmp_path_factory: TempPathFactory) -> Path:
     """Provides a preprocessed HCS OME-Zarr dataset."""
     dataset_path = tmp_path_factory.mktemp("preprocessed.zarr")
-    _build_hcs(dataset_path, channel_names, (12, 256, 256), np.float32, 1.0)
+    _build_hcs(
+        dataset_path, channel_names, (12, 256, 256), np.float32, 1.0, multiscales=True
+    )
     # U[0, 1)
     expected = {"mean": 0.5, "std": 1 / np.sqrt(12), "median": 0.5, "iqr": 0.5}
     norm_meta = {channel: {"dataset_statistics": expected} for channel in channel_names}
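
The new fixture line builds level "1" by strided slicing rather than proper downsampling, which is enough for shape tests. A standalone sketch of what that slice does (C=4 is illustrative; the fixture's channel count is defined elsewhere):

import numpy as np

# Level "0" in the fixture is (T=2, C, Z=12, Y=256, X=256).
level_0 = np.zeros((2, 4, 12, 256, 256), dtype=np.float32)
# `[::2, :, ::2, ::2, ::2]` keeps every second element along T, Z, Y, and X
# and all channels, halving those four dimensions:
level_1 = level_0[::2, :, ::2, ::2, ::2]
assert level_1.shape == (1, 4, 6, 128, 128)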
11 changes: 9 additions & 2 deletions tests/translation/test_predict_writer.py
@@ -1,3 +1,4 @@
+import pytest
 from iohub import open_ome_zarr
 
 from viscy.data.hcs import HCSDataModule
@@ -13,12 +14,16 @@ def test_pad_shape():
     assert _pad_shape(full_shape, 5) == full_shape
 
 
-def test_predict_writer(preprocessed_hcs_dataset, tmp_path):
+@pytest.mark.parametrize("array_key", ["0", "1"])
+def test_predict_writer(preprocessed_hcs_dataset, tmp_path, array_key):
     z_window_size = 5
     data_path = preprocessed_hcs_dataset
     channel_split = 2
     with open_ome_zarr(data_path) as dataset:
         channel_names = dataset.channel_names
+        expected_shape = list(next(dataset.positions())[1][array_key].shape)
+        expected_shape[1] = len(channel_names) - channel_split
+        expected_shape = tuple(expected_shape)
     dm = HCSDataModule(
         data_path=data_path,
         source_channel=channel_names[:channel_split],
@@ -27,6 +32,7 @@ def test_predict_writer(preprocessed_hcs_dataset, tmp_path):
         target_2d=bool(z_window_size == 1),
         batch_size=2,
         num_workers=0,
+        array_key=array_key,
     )
 
     model = VSUNet(
@@ -59,4 +65,5 @@ def test_predict_writer(preprocessed_hcs_dataset, tmp_path):
     assert output_path.exists()
     with open_ome_zarr(output_path) as result:
         for _, pos in result.positions():
-            assert pos["0"][:].any()
+            assert pos[array_key][:].any()
+            assert pos[array_key].shape == expected_shape
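
To make the new shape assertion concrete: the writer keeps the spatial shape of whichever level was read and replaces the channel count with the number of target channels. Worked through for the fixture shapes (C=4 is an assumption for illustration; channel_split is 2 as in the test):

channel_split = 2
num_channels = 4  # hypothetical; the fixture defines the real channel list

# Level "0" and the 2x-subsampled level "1" from the fixture:
for key, shape in {"0": (2, num_channels, 12, 256, 256),
                   "1": (1, num_channels, 6, 128, 128)}.items():
    expected = list(shape)
    expected[1] = num_channels - channel_split  # only target channels are written
    print(key, tuple(expected))  # "0" -> (2, 2, 12, 256, 256), "1" -> (1, 2, 6, 128, 128)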
18 changes: 16 additions & 2 deletions viscy/data/hcs.py
@@ -104,16 +104,22 @@ class SlidingWindowDataset(Dataset):
     :param ChannelMap channels: source and target channel names,
         e.g. ``{'source': 'Phase', 'target': ['Nuclei', 'Membrane']}``
     :param int z_window_size: Z window size of the 2.5D U-Net, 1 for 2D
+    :param str array_key:
+        Name of the image arrays (multiscales level), by default "0"
     :param DictTransform | None transform:
         a callable that transforms data, defaults to None
+    :param bool load_normalization_metadata:
+        whether to load normalization metadata, defaults to True
     """
 
     def __init__(
         self,
         positions: list[Position],
         channels: ChannelMap,
         z_window_size: int,
+        array_key: str = "0",
         transform: DictTransform | None = None,
+        load_normalization_metadata: bool = True,
     ) -> None:
         super().__init__()
         self.positions = positions
@@ -128,7 +134,9 @@ def __init__(
         )
         self.z_window_size = z_window_size
         self.transform = transform
+        self.array_key = array_key
         self._get_windows()
+        self.load_normalization_metadata = load_normalization_metadata
 
     def _get_windows(self) -> None:
         """Count the sliding windows along T and Z,
@@ -138,7 +146,7 @@ def _get_windows(self) -> None:
         self.window_arrays = []
         self.window_norm_meta: list[NormMeta | None] = []
         for fov in self.positions:
-            img_arr: ImageArray = fov["0"]
+            img_arr: ImageArray = fov[str(self.array_key)]
             ts = img_arr.frames
             zs = img_arr.slices - self.z_window_size + 1
             if zs < 1:
@@ -225,10 +233,11 @@ def __getitem__(self, index: int) -> Sample:
         sample = {
             "index": sample_index,
             "source": self._stack_channels(sample_images, "source"),
-            "norm_meta": norm_meta,
         }
         if self.target_ch_idx is not None:
             sample["target"] = self._stack_channels(sample_images, "target")
+        if self.load_normalization_metadata:
+            sample["norm_meta"] = norm_meta
         return sample
 
 
@@ -326,6 +335,8 @@ class HCSDataModule(LightningDataModule):
     prefetch_factor : int or None, optional
         Number of samples loaded in advance by each worker during fitting,
         defaults to None (2 per PyTorch default).
+    array_key : str, optional
+        Name of the image arrays (multiscales level), by default "0"
     """
 
     def __init__(
@@ -345,6 +356,7 @@ def __init__(
         ground_truth_masks: Path | None = None,
         persistent_workers=False,
         prefetch_factor=None,
+        array_key: str = "0",
         pin_memory=False,
     ):
         super().__init__()
@@ -364,6 +376,7 @@ def __init__(
         self.prepare_data_per_node = True
         self.persistent_workers = persistent_workers
         self.prefetch_factor = prefetch_factor
+        self.array_key = array_key
         self.pin_memory = pin_memory
 
     @property
@@ -421,6 +434,7 @@ def _base_dataset_settings(self) -> dict[str, dict[str, list[str]] | int]:
         return {
             "channels": {"source": self.source_channel},
             "z_window_size": self.z_window_size,
+            "array_key": self.array_key,
         }
 
     def setup(self, stage: Literal["fit", "validate", "test", "predict"]):
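
Since _base_dataset_settings feeds every stage, a single array_key selects the level for fit, validate, test, and predict alike; the same PR also makes norm_meta optional on SlidingWindowDataset. A sketch combining both, reusing only names from this diff (the plate.zarr path and channel name are placeholders, and a source-only ChannelMap is assumed to be valid, as in the predict stage):

from iohub import open_ome_zarr
from viscy.data.hcs import SlidingWindowDataset

with open_ome_zarr("plate.zarr") as plate:  # placeholder store
    positions = [pos for _, pos in plate.positions()]
    dataset = SlidingWindowDataset(
        positions,
        channels={"source": ["Phase"]},     # hypothetical channel name
        z_window_size=5,
        array_key="1",                      # read the downsampled level
        load_normalization_metadata=False,  # __getitem__ then omits "norm_meta"
    )
    sample = dataset[0]
    assert "norm_meta" not in sample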