diff --git a/tests/conftest.py b/tests/conftest.py index bdf3e420..7893b29d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,7 +6,7 @@ import numpy as np import pandas as pd from iohub import open_ome_zarr -from pytest import TempPathFactory, fixture +from pytest import FixtureRequest, TempPathFactory, fixture if TYPE_CHECKING: from numpy.typing import DTypeLike @@ -20,6 +20,7 @@ def _build_hcs( zyx_shape: tuple[int, int, int], dtype: DTypeLike, max_value: int | float, + sharded: bool = False, multiscales: bool = False, ): dataset = open_ome_zarr( @@ -27,6 +28,7 @@ def _build_hcs( layout="hcs", mode="w", channel_names=channel_names, + version="0.4" if not sharded else "0.5", ) for row in ("A", "B"): for col in ("1", "2"): @@ -37,6 +39,10 @@ def _build_hcs( ( np.random.rand(2, len(channel_names), *zyx_shape) * max_value ).astype(dtype), + chunks=(1, 1, 1, *zyx_shape[1:]), + shards_ratio=(2, len(channel_names), zyx_shape[0], 1, 1) + if sharded + else None, ) if multiscales: pos["1"] = pos["0"][::2, :, ::2, ::2, ::2] @@ -59,11 +65,15 @@ def preprocessed_hcs_dataset(tmp_path_factory: TempPathFactory) -> Path: return dataset_path -@fixture(scope="function") -def small_hcs_dataset(tmp_path_factory: TempPathFactory) -> Path: +@fixture(scope="function", params=[False, True]) +def small_hcs_dataset( + tmp_path_factory: TempPathFactory, request: FixtureRequest +) -> Path: """Provides a small, not preprocessed HCS OME-Zarr dataset.""" dataset_path = tmp_path_factory.mktemp("small.zarr") - _build_hcs(dataset_path, channel_names, (12, 64, 64), np.uint16, 1) + _build_hcs( + dataset_path, channel_names, (12, 64, 64), np.uint16, 1, sharded=request.param + ) return dataset_path diff --git a/tests/data/test_hcs.py b/tests/data/test_hcs.py index c71488c4..d45bd17e 100644 --- a/tests/data/test_hcs.py +++ b/tests/data/test_hcs.py @@ -1,33 +1,8 @@ -from pathlib import Path - from iohub import open_ome_zarr from monai.transforms import RandSpatialCropSamplesd from pytest import mark from viscy.data.hcs import HCSDataModule -from viscy.trainer import VisCyTrainer - - -@mark.parametrize("default_channels", [True, False]) -def test_preprocess(small_hcs_dataset: Path, default_channels: bool): - data_path = small_hcs_dataset - if default_channels: - channel_names = -1 - else: - with open_ome_zarr(data_path) as dataset: - channel_names = dataset.channel_names - trainer = VisCyTrainer(accelerator="cpu") - trainer.preprocess(data_path, channel_names=channel_names, num_workers=2) - with open_ome_zarr(data_path) as dataset: - channel_names = dataset.channel_names - for channel in channel_names: - assert "dataset_statistics" in dataset.zattrs["normalization"][channel] - for _, fov in dataset.positions(): - norm_metadata = fov.zattrs["normalization"] - for channel in channel_names: - assert channel in norm_metadata - assert "dataset_statistics" in norm_metadata[channel] - assert "fov_statistics" in norm_metadata[channel] @mark.parametrize("multi_sample_augmentation", [True, False]) diff --git a/tests/preprocessing/test_trainer_preprocess.py b/tests/preprocessing/test_trainer_preprocess.py new file mode 100644 index 00000000..a65fab17 --- /dev/null +++ b/tests/preprocessing/test_trainer_preprocess.py @@ -0,0 +1,28 @@ +from pathlib import Path + +from iohub import open_ome_zarr +from pytest import mark + +from viscy.trainer import VisCyTrainer + + +@mark.parametrize("default_channels", [True, False]) +def test_preprocess(small_hcs_dataset: Path, default_channels: bool): + data_path = small_hcs_dataset + if default_channels: + channel_names = -1 + else: + with open_ome_zarr(data_path) as dataset: + channel_names = dataset.channel_names + trainer = VisCyTrainer(accelerator="cpu") + trainer.preprocess(data_path, channel_names=channel_names, num_workers=2) + with open_ome_zarr(data_path) as dataset: + channel_names = dataset.channel_names + for channel in channel_names: + assert "dataset_statistics" in dataset.zattrs["normalization"][channel] + for _, fov in dataset.positions(): + norm_metadata = fov.zattrs["normalization"] + for channel in channel_names: + assert channel in norm_metadata + assert "dataset_statistics" in norm_metadata[channel] + assert "fov_statistics" in norm_metadata[channel]