Skip to content

Commit df252dd

Browse files
authored
feat: batch key selection (#32)
- expose `Dataset.get_batch(keys=...)`, mimicking `TensorDict.select` - bump `tensordict>=0.7.0`
1 parent 9585c74 commit df252dd

File tree

9 files changed

+241
-76
lines changed


.pre-commit-config.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ repos:
1111
- id: pyupgrade
1212

1313
- repo: https://github.com/astral-sh/ruff-pre-commit
14-
rev: v0.9.4
14+
rev: v0.9.5
1515
hooks:
1616
- id: ruff
1717
args: [--fix]

pyproject.toml

+8-10
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
[project]
22
name = "rbyte"
3-
version = "0.11.1"
3+
version = "0.12.0"
44
description = "Multimodal PyTorch dataset library"
55
authors = [{ name = "Evgenii Gorchakov", email = "[email protected]" }]
66
maintainers = [{ name = "Evgenii Gorchakov", email = "[email protected]" }]
77
dependencies = [
8-
"tensordict>=0.6.2",
8+
"tensordict>=0.7.0",
99
"torch",
1010
"numpy",
1111
"polars>=1.21.0",
@@ -39,7 +39,7 @@ repo = "https://github.com/yaak-ai/rbyte"
3939

4040
[project.optional-dependencies]
4141
build = ["hatchling>=1.25.0", "grpcio-tools>=1.62.0", "protoletariat==3.2.19"]
42-
visualize = ["rerun-sdk[notebook]>=0.21.0"]
42+
visualize = ["rerun-sdk[notebook]==0.21.0"]
4343
mcap = [
4444
"mcap>=1.2.1",
4545
"mcap-ros2-support>=0.5.5",
@@ -53,7 +53,7 @@ video = [
5353
"video-reader-rs>=0.2.2",
5454
]
5555
hdf5 = ["h5py>=3.12.1"]
56-
rrd = ["rerun-sdk>=0.21.0", "pyarrow-stubs"]
56+
rrd = ["rerun-sdk==0.21.0", "pyarrow-stubs"]
5757

5858
[project.scripts]
5959
rbyte-visualize = 'rbyte.scripts.visualize:main'
@@ -66,20 +66,18 @@ requires = [
6666
]
6767
build-backend = "hatchling.build"
6868

69-
[tool.uv]
70-
dev-dependencies = [
69+
[dependency-groups]
70+
dev = [
7171
"wat-inspector>=0.4.3",
7272
"lovely-tensors>=0.1.18",
7373
"pudb>=2024.1.2",
74-
"ipython>=8.30.0",
74+
"ipython>=8.32.0",
7575
"ipython-autoimport>=0.5",
76-
"pytest>=8.3.3",
76+
"pytest>=8.3.4",
7777
"testbook>=0.4.2",
7878
"ipykernel>=6.29.5",
7979
]
8080

81-
[tool.uv.sources]
82-
8381
[tool.hatch.metadata]
8482
allow-direct-references = true
8583

src/rbyte/batch.py

+27
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
from typing import Literal
2+
3+
from tensordict import (
4+
NonTensorData, # pyright: ignore[reportAttributeAccessIssue, reportUnknownVariableType]
5+
TensorClass,
6+
TensorDict,
7+
)
8+
from torch import Tensor
9+
10+
11+
class BatchMeta(TensorClass, autocast=True): # pyright: ignore[reportGeneralTypeIssues, reportCallIssue]
12+
sample_idx: Tensor | None = None
13+
input_id: NonTensorData | None = None # pyright: ignore[reportUnknownVariableType]
14+
15+
16+
class Batch(TensorClass, autocast=True): # pyright: ignore[reportGeneralTypeIssues, reportCallIssue]
17+
data: TensorDict | None = None # pyright: ignore[reportIncompatibleMethodOverride]
18+
meta: BatchMeta | None = None
19+
20+
21+
type BatchKeys = frozenset[
22+
Literal["data", "meta"]
23+
| tuple[Literal["data"], str]
24+
| tuple[Literal["meta"], Literal["sample_idx", "input_id"]]
25+
]
26+
27+
BATCH_KEYS_DEFAULT = frozenset(("data", "meta"))

src/rbyte/batch/__init__.py

-3
This file was deleted.

src/rbyte/batch/batch.py

-18
This file was deleted.

src/rbyte/dataset.py

+108-41
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,19 @@
11
from collections.abc import Mapping, Sequence
22
from enum import StrEnum, unique
33
from functools import cache
4-
from typing import Annotated
4+
from typing import Annotated, Literal, override
55

66
import polars as pl
77
import torch
88
from hydra.utils import instantiate
99
from pipefunc import Pipeline
10-
from pydantic import Field, StringConstraints, validate_call
10+
from pydantic import ConfigDict, Field, StringConstraints, validate_call
1111
from structlog import get_logger
1212
from structlog.contextvars import bound_contextvars
1313
from tensordict import TensorDict
1414
from torch.utils.data import Dataset as TorchDataset
1515

16-
from rbyte.batch import Batch, BatchMeta
16+
from rbyte.batch import BATCH_KEYS_DEFAULT, Batch, BatchKeys, BatchMeta
1717
from rbyte.config import BaseModel, HydraConfig
1818
from rbyte.io.base import TensorSource
1919
from rbyte.utils.tensor import pad_sequence
@@ -53,7 +53,14 @@ class Column(StrEnum):
5353
source_index_column = "__source.index_column"
5454

5555

56-
class Dataset(TorchDataset[TensorDict]):
56+
class _ALL_TYPE: # noqa: N801
57+
pass
58+
59+
60+
_ALL = _ALL_TYPE()
61+
62+
63+
class Dataset(TorchDataset[Batch]):
5764
@validate_call(config=BaseModel.model_config)
5865
def __init__(
5966
self, inputs: Annotated[Mapping[Id, InputConfig], Field(min_length=1)]
@@ -136,52 +143,112 @@ def sources(self) -> pl.DataFrame:
136143
def _get_source(self, config: str) -> TensorSource: # noqa: PLR6301
137144
return HydraConfig[TensorSource].model_validate_json(config).instantiate()
138145

139-
def __getitems__(self, indexes: Sequence[int]) -> Batch: # noqa: PLW3201
140-
samples = self.samples[indexes]
141-
batch_size = [samples.height]
146+
@validate_call(
147+
config=ConfigDict(arbitrary_types_allowed=True, validate_default=False)
148+
)
149+
def get_batch(
150+
self,
151+
index: int | Sequence[int] | slice | range,
152+
*,
153+
keys: BatchKeys = BATCH_KEYS_DEFAULT,
154+
) -> Batch:
155+
subkeys: Mapping[Literal["data", "meta"], set[_ALL_TYPE | str]] = {
156+
"data": set(),
157+
"meta": set(),
158+
}
159+
for key in keys:
160+
match key:
161+
case "data" | "meta":
162+
subkeys[key].add(_ALL)
142163

143-
source_idx_cols = self._sources[Column.source_index_column].unique()
164+
case ("data" | "meta", _):
165+
subkeys[key[0]].add(key[1])
144166

145-
sources = (
146-
samples.lazy()
147-
.join(self.sources.lazy(), on=Column.input_id, how="left")
148-
.with_columns(
149-
pl.coalesce(
150-
pl.when(pl.col(Column.source_index_column) == idx_col).then(idx_col)
151-
for idx_col in source_idx_cols
152-
).alias(Column.source_idxs)
167+
for v in subkeys.values():
168+
if _ALL in v and len(v) > 1:
169+
v.remove(_ALL)
170+
171+
samples = self.samples[index]
172+
batch_size = [samples.height]
173+
174+
if subkeys_data := subkeys["data"]:
175+
source_idx_cols = self._sources[Column.source_index_column].unique()
176+
sources = (
177+
samples.lazy()
178+
.join(self.sources.lazy(), on=Column.input_id, how="left")
179+
.with_columns(
180+
pl.coalesce(
181+
pl.when(pl.col(Column.source_index_column) == idx_col).then(
182+
idx_col
183+
)
184+
for idx_col in source_idx_cols
185+
).alias(Column.source_idxs)
186+
)
187+
.group_by(Column.source_id)
188+
.agg(Column.source_config, Column.source_idxs)
189+
.filter(
190+
True
191+
if _ALL in subkeys_data
192+
else pl.col(Column.source_id).is_in(subkeys_data)
193+
)
153194
)
154-
.group_by(Column.source_id)
155-
.agg(Column.source_config, Column.source_idxs)
156-
)
157195

158-
tensor_data: Mapping[str, torch.Tensor] = {
159-
row[Column.source_id]: pad_sequence(
160-
[
161-
self._get_source(source)[idxs]
162-
for (source, idxs) in zip(
163-
row[Column.source_config], row[Column.source_idxs], strict=True
164-
)
165-
],
166-
dim=1,
167-
value=torch.nan,
196+
source_data = {
197+
row[Column.source_id]: pad_sequence(
198+
[
199+
self._get_source(source)[idxs]
200+
for (source, idxs) in zip(
201+
row[Column.source_config],
202+
row[Column.source_idxs],
203+
strict=True,
204+
)
205+
],
206+
dim=1,
207+
value=torch.nan,
208+
)
209+
for row in sources.collect().iter_rows(named=True)
210+
}
211+
212+
sample_data_cols = (
213+
pl.all()
214+
if _ALL in subkeys_data
215+
else pl.col(subkeys_data - source_data.keys()) # pyright: ignore[reportArgumentType]
216+
).exclude(Column.sample_idx, Column.input_id)
217+
218+
sample_data = samples.select(sample_data_cols.to_physical()).to_dict(
219+
as_series=False
168220
)
169-
for row in sources.collect().iter_rows(named=True)
170-
}
171221

172-
sample_data: Mapping[str, Sequence[object]] = samples.select(
173-
pl.exclude(Column.sample_idx, Column.input_id).to_physical()
174-
).to_dict(as_series=False)
222+
data = TensorDict(source_data | sample_data, batch_size=batch_size) # pyright: ignore[reportArgumentType]
223+
224+
else:
225+
data = None
226+
227+
if subkeys_meta := subkeys["meta"]:
228+
meta = BatchMeta(
229+
sample_idx=(
230+
samples[Column.sample_idx].to_torch()
231+
if _ALL in subkeys_meta or "sample_idx" in subkeys_meta
232+
else None
233+
),
234+
input_id=(
235+
samples[Column.input_id].to_list()
236+
if _ALL in subkeys_meta or "input_id" in subkeys_meta
237+
else None
238+
),
239+
batch_size=batch_size,
240+
)
241+
else:
242+
meta = None
175243

176-
data = TensorDict(tensor_data | sample_data, batch_size=batch_size) # pyright: ignore[reportArgumentType]
244+
return Batch(data=data, meta=meta, batch_size=batch_size)
177245

178-
meta = BatchMeta(
179-
sample_idx=samples[Column.sample_idx].to_torch(), # pyright: ignore[reportCallIssue]
180-
input_id=samples[Column.input_id].to_list(), # pyright: ignore[reportCallIssue]
181-
batch_size=batch_size, # pyright: ignore[reportCallIssue]
182-
)
246+
def __getitems__(self, index: Sequence[int]) -> Batch: # noqa: PLW3201
247+
return self.get_batch(index)
183248

184-
return Batch(data=data, meta=meta, batch_size=batch_size) # pyright: ignore[reportCallIssue]
249+
@override
250+
def __getitem__(self, index: int) -> Batch:
251+
return self.get_batch(index)
185252

186253
def __len__(self) -> int:
187254
return len(self.samples)

src/rbyte/io/yaak/idl-repo

src/rbyte/viz/loggers/rerun_logger.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -188,8 +188,8 @@ def _build_components(
188188

189189
@override
190190
def log(self, batch_idx: int, batch: Batch) -> None:
191-
for i, sample in enumerate(batch.data): # pyright: ignore[reportUnknownVariableType]
192-
with self._get_recording(batch.meta.input_id[i]): # pyright: ignore[reportUnknownArgumentType, reportIndexIssue]
191+
for i, sample in enumerate(batch.data): # pyright: ignore[reportArgumentType, reportUnknownVariableType]
192+
with self._get_recording(batch.meta.input_id[i]): # pyright: ignore[reportUnknownArgumentType, reportOptionalSubscript, reportUnknownMemberType, reportOptionalMemberAccess]
193193
times: Sequence[TimeColumn] = [
194194
column(
195195
timeline=timeline,

0 commit comments

Comments (0)