diff --git a/viscy/data/gpu_aug.py b/viscy/data/gpu_aug.py index 5eb200f3..7a36a572 100644 --- a/viscy/data/gpu_aug.py +++ b/viscy/data/gpu_aug.py @@ -203,8 +203,10 @@ class CachedOmeZarrDataModule(GPUTransformDataModule, SelectWell): Skip caching for this dataset, by default False include_wells : list[str], optional List of well names to include in the dataset, by default None (all) - include_wells : list[str], optional - List of well names to include in the dataset, by default None (all) + exclude_fovs : list[str], optional + List of fovs names to exclude from the dataset, by default None (none) + prefetch_factor : int | None, optional + Number of batches loaded in advance by each worker. """ def __init__( @@ -222,6 +224,7 @@ def __init__( skip_cache: bool = False, include_wells: list[str] | None = None, exclude_fovs: list[str] | None = None, + prefetch_factor: int | None = None, ): super().__init__() self.data_path = data_path @@ -237,6 +240,7 @@ def __init__( self.skip_cache = skip_cache self._include_wells = include_wells self._exclude_fovs = exclude_fovs + self.prefetch_factor = prefetch_factor @property def train_cpu_transforms(self) -> Compose: