Skip to content

Commit a61a037

Browse files
authored
feat: sample builder ergonomics (#28)
- add `gather_every` (~strides) and `length` to FixedWindowSampleBuilder - remove RollingWindowSampleBuilder
1 parent 34c2f16 commit a61a037

File tree

8 files changed

+18
-78
lines changed

8 files changed

+18
-78
lines changed

config/_templates/dataloader/torch.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,4 @@ collate_fn:
1010
num_workers: 1
1111
pin_memory: false
1212
persistent_workers: true
13-
multiprocessing_context: forkserver
13+
multiprocessing_context: spawn

config/_templates/dataset/yaak.yaml

+1-9
Original file line numberDiff line numberDiff line change
@@ -134,15 +134,7 @@ inputs:
134134
index_column: meta/ImageMetadata.(@=cameras[0]@)/frame_idx
135135
every: 6i
136136
period: 6i
137-
138-
- _target_: pipefunc.PipeFunc
139-
renames:
140-
input: samples
141-
output_name: samples_filtered
142-
func:
143-
_target_: rbyte.io.DataFrameFilter
144-
predicate: |
145-
array_length(`meta/ImageMetadata.(@=cameras[0]@)/time_stamp`) == 6
137+
length: 6
146138

147139
kwargs:
148140
meta:

pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "rbyte"
3-
version = "0.10.0"
3+
version = "0.10.1"
44
description = "Multimodal PyTorch dataset library"
55
authors = [{ name = "Evgenii Gorchakov", email = "[email protected]" }]
66
maintainers = [{ name = "Evgenii Gorchakov", email = "[email protected]" }]

src/rbyte/__init__.py

+2-7
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,8 @@
11
from importlib.metadata import version
22

33
from .dataset import Dataset
4-
from .sample import FixedWindowSampleBuilder, RollingWindowSampleBuilder
4+
from .sample import FixedWindowSampleBuilder
55

66
__version__ = version(__package__ or __name__)
77

8-
__all__ = [
9-
"Dataset",
10-
"FixedWindowSampleBuilder",
11-
"RollingWindowSampleBuilder",
12-
"__version__",
13-
]
8+
__all__ = ["Dataset", "FixedWindowSampleBuilder", "__version__"]

src/rbyte/sample/__init__.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
11
from .fixed_window import FixedWindowSampleBuilder
2-
from .rolling_window import RollingWindowSampleBuilder
32

4-
__all__ = ["FixedWindowSampleBuilder", "RollingWindowSampleBuilder"]
3+
__all__ = ["FixedWindowSampleBuilder"]

src/rbyte/sample/fixed_window.py

+12-5
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import polars as pl
66
from polars._typing import ClosedInterval
7-
from pydantic import validate_call
7+
from pydantic import PositiveInt, validate_call
88

99

1010
@final
@@ -19,18 +19,26 @@ class FixedWindowSampleBuilder:
1919
__name__ = __qualname__
2020

2121
@validate_call
22-
def __init__(
22+
def __init__( # noqa: PLR0913
2323
self,
2424
*,
2525
index_column: str,
2626
every: str | timedelta,
2727
period: str | timedelta | None = None,
2828
closed: ClosedInterval = "left",
29+
gather_every: PositiveInt = 1,
30+
length: PositiveInt | None = None,
2931
) -> None:
3032
self._index_column = pl.col(index_column)
3133
self._every = every
3234
self._period = period
3335
self._closed: ClosedInterval = closed
36+
self._gather_every = gather_every
37+
self._length_filter = (
38+
(self._index_column.list.len() > 0)
39+
if length is None
40+
else (self._index_column.list.len() == length)
41+
)
3442

3543
def __call__(self, input: pl.DataFrame) -> pl.DataFrame:
3644
return (
@@ -44,8 +52,7 @@ def __call__(self, input: pl.DataFrame) -> pl.DataFrame:
4452
label="datapoint",
4553
start_by="datapoint",
4654
)
47-
.agg(pl.all())
48-
.filter(self._index_column.list.len() > 0)
49-
.sort(_index_column)
55+
.agg(pl.all().gather_every(self._gather_every))
56+
.filter(self._length_filter)
5057
.drop(_index_column)
5158
)

src/rbyte/sample/rolling_window.py

-48
This file was deleted.

src/rbyte/scripts/visualize.py

-5
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
1-
from multiprocessing.context import ForkServerContext
21
from typing import Any, cast
32

43
import hydra
5-
import torch.multiprocessing as mp
64
from hydra.utils import instantiate
75
from omegaconf import DictConfig
86
from structlog import get_logger
@@ -19,9 +17,6 @@ def main(config: DictConfig) -> None:
1917
logger = cast(Logger[Any], instantiate(config.logger))
2018
dataloader = cast(DataLoader[Any], instantiate(config.dataloader))
2119

22-
if isinstance(dataloader.multiprocessing_context, ForkServerContext): # pyright: ignore[reportUnknownMemberType]
23-
mp.set_forkserver_preload(["rbyte"])
24-
2520
for batch_idx, batch in enumerate(tqdm(dataloader)):
2621
logger.log(batch_idx, batch)
2722

0 commit comments

Comments
 (0)