From a61a037c79ed24d0e388b347159a1e3441f5bab7 Mon Sep 17 00:00:00 2001 From: Evgenii Gorchakov Date: Tue, 3 Dec 2024 16:10:28 +0100 Subject: [PATCH] feat: sample builder ergonomics (#28) - add `gather_every` (~strides) and `length` to FixedWindowSampleBuilder - remove RollingWindowSampleBuilder --- config/_templates/dataloader/torch.yaml | 2 +- config/_templates/dataset/yaak.yaml | 10 +----- pyproject.toml | 2 +- src/rbyte/__init__.py | 9 ++--- src/rbyte/sample/__init__.py | 3 +- src/rbyte/sample/fixed_window.py | 17 ++++++--- src/rbyte/sample/rolling_window.py | 48 ------------------------- src/rbyte/scripts/visualize.py | 5 --- 8 files changed, 18 insertions(+), 78 deletions(-) delete mode 100644 src/rbyte/sample/rolling_window.py diff --git a/config/_templates/dataloader/torch.yaml b/config/_templates/dataloader/torch.yaml index 9839b8b..27c5922 100644 --- a/config/_templates/dataloader/torch.yaml +++ b/config/_templates/dataloader/torch.yaml @@ -10,4 +10,4 @@ collate_fn: num_workers: 1 pin_memory: false persistent_workers: true -multiprocessing_context: forkserver +multiprocessing_context: spawn diff --git a/config/_templates/dataset/yaak.yaml b/config/_templates/dataset/yaak.yaml index ca83b69..8a83a6e 100644 --- a/config/_templates/dataset/yaak.yaml +++ b/config/_templates/dataset/yaak.yaml @@ -134,15 +134,7 @@ inputs: index_column: meta/ImageMetadata.(@=cameras[0]@)/frame_idx every: 6i period: 6i - - - _target_: pipefunc.PipeFunc - renames: - input: samples - output_name: samples_filtered - func: - _target_: rbyte.io.DataFrameFilter - predicate: | - array_length(`meta/ImageMetadata.(@=cameras[0]@)/time_stamp`) == 6 + length: 6 kwargs: meta: diff --git a/pyproject.toml b/pyproject.toml index c4b2fa2..958fa44 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "rbyte" -version = "0.10.0" +version = "0.10.1" description = "Multimodal PyTorch dataset library" authors = [{ name = "Evgenii Gorchakov", email = "evgenii@yaak.ai" }] maintainers = [{ name = "Evgenii Gorchakov", email = "evgenii@yaak.ai" }] diff --git a/src/rbyte/__init__.py b/src/rbyte/__init__.py index 73f4243..dbfd726 100644 --- a/src/rbyte/__init__.py +++ b/src/rbyte/__init__.py @@ -1,13 +1,8 @@ from importlib.metadata import version from .dataset import Dataset -from .sample import FixedWindowSampleBuilder, RollingWindowSampleBuilder +from .sample import FixedWindowSampleBuilder __version__ = version(__package__ or __name__) -__all__ = [ - "Dataset", - "FixedWindowSampleBuilder", - "RollingWindowSampleBuilder", - "__version__", -] +__all__ = ["Dataset", "FixedWindowSampleBuilder", "__version__"] diff --git a/src/rbyte/sample/__init__.py b/src/rbyte/sample/__init__.py index 7f83bb6..a1dcb35 100644 --- a/src/rbyte/sample/__init__.py +++ b/src/rbyte/sample/__init__.py @@ -1,4 +1,3 @@ from .fixed_window import FixedWindowSampleBuilder -from .rolling_window import RollingWindowSampleBuilder -__all__ = ["FixedWindowSampleBuilder", "RollingWindowSampleBuilder"] +__all__ = ["FixedWindowSampleBuilder"] diff --git a/src/rbyte/sample/fixed_window.py b/src/rbyte/sample/fixed_window.py index ff26276..8ca8dda 100644 --- a/src/rbyte/sample/fixed_window.py +++ b/src/rbyte/sample/fixed_window.py @@ -4,7 +4,7 @@ import polars as pl from polars._typing import ClosedInterval -from pydantic import validate_call +from pydantic import PositiveInt, validate_call @final @@ -19,18 +19,26 @@ class FixedWindowSampleBuilder: __name__ = __qualname__ @validate_call - def __init__( + def __init__( # noqa: PLR0913 self, *, index_column: str, every: str | timedelta, period: str | timedelta | None = None, closed: ClosedInterval = "left", + gather_every: PositiveInt = 1, + length: PositiveInt | None = None, ) -> None: self._index_column = pl.col(index_column) self._every = every self._period = period self._closed: ClosedInterval = closed + self._gather_every = gather_every + self._length_filter = ( + (self._index_column.list.len() > 0) + if length is None + else (self._index_column.list.len() == length) + ) def __call__(self, input: pl.DataFrame) -> pl.DataFrame: return ( @@ -44,8 +52,7 @@ def __call__(self, input: pl.DataFrame) -> pl.DataFrame: label="datapoint", start_by="datapoint", ) - .agg(pl.all()) - .filter(self._index_column.list.len() > 0) - .sort(_index_column) + .agg(pl.all().gather_every(self._gather_every)) + .filter(self._length_filter) .drop(_index_column) ) diff --git a/src/rbyte/sample/rolling_window.py b/src/rbyte/sample/rolling_window.py deleted file mode 100644 index c8e78c2..0000000 --- a/src/rbyte/sample/rolling_window.py +++ /dev/null @@ -1,48 +0,0 @@ -from datetime import timedelta -from typing import final -from uuid import uuid4 - -import polars as pl -from polars._typing import ClosedInterval -from pydantic import validate_call - - -@final -class RollingWindowSampleBuilder: - """ - Build samples using rolling windows based on a temporal or integer column. - - https://docs.pola.rs/api/python/stable/reference/dataframe/api/polars.DataFrame.rolling - """ - - __name__ = __qualname__ - - @validate_call - def __init__( - self, - *, - index_column: str, - period: str | timedelta, - offset: str | timedelta | None = None, - closed: ClosedInterval = "right", - ) -> None: - self._index_column = pl.col(index_column) - self._period = period - self._offset = offset - self._closed: ClosedInterval = closed - - def __call__(self, input: pl.DataFrame) -> pl.DataFrame: - return ( - input.sort(self._index_column) - .with_columns(self._index_column.alias(_index_column := uuid4().hex)) - .rolling( - index_column=_index_column, - period=self._period, - offset=self._offset, - closed=self._closed, - ) - .agg(pl.all()) - .filter(self._index_column.list.len() > 0) - .sort(_index_column) - .drop(_index_column) - ) diff --git a/src/rbyte/scripts/visualize.py b/src/rbyte/scripts/visualize.py index acac24f..9db6ee5 100644 --- a/src/rbyte/scripts/visualize.py +++ b/src/rbyte/scripts/visualize.py @@ -1,8 +1,6 @@ -from multiprocessing.context import ForkServerContext from typing import Any, cast import hydra -import torch.multiprocessing as mp from hydra.utils import instantiate from omegaconf import DictConfig from structlog import get_logger @@ -19,9 +17,6 @@ def main(config: DictConfig) -> None: logger = cast(Logger[Any], instantiate(config.logger)) dataloader = cast(DataLoader[Any], instantiate(config.dataloader)) - if isinstance(dataloader.multiprocessing_context, ForkServerContext): # pyright: ignore[reportUnknownMemberType] - mp.set_forkserver_preload(["rbyte"]) - for batch_idx, batch in enumerate(tqdm(dataloader)): logger.log(batch_idx, batch)