Skip to content

Commit 34c2f16

Browse files
authored
feat: make sample builders optional (#27)
1 parent 304b5ef commit 34c2f16

19 files changed

+210
-193
lines changed

.pre-commit-config.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ repos:
1111
- id: pyupgrade
1212

1313
- repo: https://github.com/astral-sh/ruff-pre-commit
14-
rev: v0.8.0
14+
rev: v0.8.1
1515
hooks:
1616
- id: ruff
1717
args: [--fix]

config/_templates/dataset/carla.yaml

-9
Original file line numberDiff line numberDiff line change
@@ -127,13 +127,4 @@ inputs:
127127
_target_: rbyte.io.DataFrameFilter
128128
predicate: |
129129
`control.throttle` > 0.5
130-
131-
- _target_: pipefunc.PipeFunc
132-
renames:
133-
input: data_filtered
134-
output_name: samples
135-
func:
136-
_target_: rbyte.RollingWindowSampleBuilder
137-
index_column: _idx_
138-
period: 1i
139130
#@ end

config/_templates/dataset/mimicgen.yaml

-10
Original file line numberDiff line numberDiff line change
@@ -58,15 +58,5 @@ inputs:
5858
func:
5959
_target_: rbyte.io.DataFrameConcater
6060
method: vertical
61-
62-
- _target_: pipefunc.PipeFunc
63-
renames:
64-
input: data_concated
65-
output_name: samples
66-
func:
67-
_target_: rbyte.RollingWindowSampleBuilder
68-
index_column: _idx_
69-
period: 1i
70-
7161
#@ end
7262
#@ end

config/_templates/dataset/nuscenes/mcap.yaml

-9
Original file line numberDiff line numberDiff line change
@@ -105,13 +105,4 @@ inputs:
105105
_target_: rbyte.io.DataFrameFilter
106106
predicate: |
107107
`/odom/vel.x` >= 8
108-
109-
- _target_: pipefunc.PipeFunc
110-
renames:
111-
input: data_filtered
112-
output_name: samples
113-
func:
114-
_target_: rbyte.RollingWindowSampleBuilder
115-
index_column: (@=camera_topics.values()[0]@)/_idx_
116-
period: 1i
117108
#@ end

config/_templates/dataset/nuscenes/rrd.yaml

-9
Original file line numberDiff line numberDiff line change
@@ -99,13 +99,4 @@ inputs:
9999
_target_: rbyte.io.DataFrameFilter
100100
predicate: |
101101
`/world/ego_vehicle/CAM_FRONT/timestamp` between '2018-07-24 03:28:48' and '2018-07-24 03:28:50'
102-
103-
- _target_: pipefunc.PipeFunc
104-
renames:
105-
input: data_filtered
106-
output_name: samples
107-
func:
108-
_target_: rbyte.RollingWindowSampleBuilder
109-
index_column: (@=camera_entities.values()[0]@)/_idx_
110-
period: 1i
111102
#@ end

pyproject.toml

+17-7
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,20 @@
11
[project]
22
name = "rbyte"
3-
version = "0.9.1"
3+
version = "0.10.0"
44
description = "Multimodal PyTorch dataset library"
55
authors = [{ name = "Evgenii Gorchakov", email = "[email protected]" }]
66
maintainers = [{ name = "Evgenii Gorchakov", email = "[email protected]" }]
77
dependencies = [
88
"tensordict>=0.6.2",
99
"torch",
1010
"numpy",
11-
"polars>=1.15.0",
11+
"polars>=1.16.0",
1212
"pydantic>=2.10.2",
1313
"more-itertools>=10.5.0",
1414
"hydra-core>=1.3.2",
1515
"optree>=0.13.1",
1616
"cachetools>=5.5.0",
1717
"diskcache>=5.6.3",
18-
"jaxtyping>=0.2.34",
1918
"parse>=1.20.2",
2019
"structlog>=24.4.0",
2120
"xxhash>=3.5.0",
@@ -40,7 +39,7 @@ repo = "https://github.com/yaak-ai/rbyte"
4039

4140
[project.optional-dependencies]
4241
build = ["hatchling>=1.25.0", "grpcio-tools>=1.62.0", "protoletariat==3.2.19"]
43-
visualize = ["rerun-sdk[notebook]>=0.20.0"]
42+
visualize = ["rerun-sdk[notebook]>=0.20.2"]
4443
mcap = [
4544
"mcap>=1.2.1",
4645
"mcap-ros2-support>=0.5.5",
@@ -54,7 +53,7 @@ video = [
5453
"video-reader-rs>=0.2.1",
5554
]
5655
hdf5 = ["h5py>=3.12.1"]
57-
rrd = ["rerun-sdk>=0.20.0", "pyarrow-stubs"]
56+
rrd = ["rerun-sdk>=0.20.2", "pyarrow-stubs"]
5857

5958
[project.scripts]
6059
rbyte-visualize = 'rbyte.scripts.visualize:main'
@@ -72,7 +71,7 @@ dev-dependencies = [
7271
"wat-inspector>=0.4.3",
7372
"lovely-tensors>=0.1.18",
7473
"pudb>=2024.1.2",
75-
"ipython>=8.29.0",
74+
"ipython>=8.30.0",
7675
"ipython-autoimport>=0.5",
7776
"pytest>=8.3.3",
7877
"testbook>=0.4.2",
@@ -129,7 +128,18 @@ skip-magic-trailing-comma = true
129128
preview = true
130129
select = ["ALL"]
131130
fixable = ["ALL"]
132-
ignore = ["A001", "A002", "D", "CPY", "COM812", "F722", "PD901", "ISC001", "TD"]
131+
ignore = [
132+
"A001",
133+
"A002",
134+
"D",
135+
"CPY",
136+
"COM812",
137+
"F722",
138+
"PD901",
139+
"ISC001",
140+
"TD",
141+
"TC006",
142+
]
133143

134144
[tool.ruff.lint.isort]
135145
split-on-trailing-comma = false

src/rbyte/io/_mcap/tensor_source.py

+48-33
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,12 @@
22
from dataclasses import dataclass
33
from functools import cached_property
44
from mmap import ACCESS_READ, mmap
5-
from typing import IO, override
5+
from operator import itemgetter
6+
from typing import IO, final, override
67

78
import more_itertools as mit
89
import numpy.typing as npt
910
import torch
10-
from jaxtyping import Shaped
1111
from mcap.data_stream import ReadDataStream
1212
from mcap.decoder import DecoderFactory
1313
from mcap.opcode import Opcode
@@ -32,6 +32,7 @@ class MessageIndex:
3232
message_length: int
3333

3434

35+
@final
3536
class McapTensorSource(TensorSource):
3637
@validate_call(config=BaseModel.model_config)
3738
def __init__(
@@ -47,8 +48,8 @@ def __init__(
4748
with bound_contextvars(
4849
path=path.as_posix(), topic=topic, message_decoder_factory=decoder_factory
4950
):
50-
self._path: FilePath = path
51-
self._validate_crcs: bool = validate_crcs
51+
self._path = path
52+
self._validate_crcs = validate_crcs
5253

5354
summary = SeekingReader(
5455
stream=self._file, validate_crcs=self._validate_crcs
@@ -73,13 +74,14 @@ def __init__(
7374
logger.error(msg := "missing message decoder")
7475
raise RuntimeError(msg)
7576

76-
self._message_decoder: Callable[[bytes], object] = message_decoder
77-
self._chunk_indexes: tuple[ChunkIndex, ...] = tuple(
77+
self._message_decoder = message_decoder
78+
self._chunk_indexes = tuple(
7879
chunk_index
7980
for chunk_index in summary.chunk_indexes
8081
if self._channel.id in chunk_index.message_index_offsets
8182
)
82-
self._decoder: Callable[[bytes], npt.ArrayLike] = decoder
83+
self._decoder = decoder
84+
self._mmap = None
8385

8486
@property
8587
def _file(self) -> IO[bytes]:
@@ -89,42 +91,55 @@ def _file(self) -> IO[bytes]:
8991

9092
case None | mmap(closed=True):
9193
with self._path.open("rb") as f:
92-
self._mmap: mmap = mmap(
93-
fileno=f.fileno(), length=0, access=ACCESS_READ
94-
)
94+
self._mmap = mmap(fileno=f.fileno(), length=0, access=ACCESS_READ)
9595

9696
case _:
9797
raise RuntimeError
9898

9999
return self._mmap # pyright: ignore[reportReturnType]
100100

101101
@override
102-
def __getitem__(self, indexes: Iterable[int]) -> Shaped[Tensor, "b h w c"]:
103-
frames: Mapping[int, npt.ArrayLike] = {}
104-
105-
message_indexes_by_chunk_start_offset: Mapping[
106-
int, Iterable[tuple[int, MessageIndex]]
107-
] = mit.map_reduce(
108-
zip(indexes, (self._message_indexes[idx] for idx in indexes), strict=True),
109-
keyfunc=lambda x: x[1].chunk_start_offset,
110-
)
102+
def __getitem__(self, indexes: int | Iterable[int]) -> Tensor:
103+
match indexes:
104+
case Iterable():
105+
arrays: Mapping[int, npt.ArrayLike] = {}
106+
message_indexes = (self._message_indexes[idx] for idx in indexes)
107+
indexes_by_chunk_start_offset = mit.map_reduce(
108+
zip(indexes, message_indexes, strict=True),
109+
keyfunc=lambda x: x[1].chunk_start_offset,
110+
)
111+
112+
for chunk_start_offset, chunk_indexes in sorted(
113+
indexes_by_chunk_start_offset.items(), key=itemgetter(0)
114+
):
115+
_ = self._file.seek(chunk_start_offset + 1 + 8)
116+
chunk = Chunk.read(ReadDataStream(self._file))
117+
stream, _ = get_chunk_data_stream(
118+
chunk, validate_crc=self._validate_crcs
119+
)
120+
for index, message_index in sorted(
121+
chunk_indexes, key=lambda x: x[1].message_start_offset
122+
):
123+
stream.read(message_index.message_start_offset - stream.count) # pyright: ignore[reportUnusedCallResult]
124+
message = Message.read(stream, message_index.message_length)
125+
decoded_message = self._message_decoder(message.data)
126+
arrays[index] = self._decoder(decoded_message.data)
111127

112-
for (
113-
chunk_start_offset,
114-
chunk_message_indexes,
115-
) in message_indexes_by_chunk_start_offset.items():
116-
self._file.seek(chunk_start_offset + 1 + 8) # pyright: ignore[reportUnusedCallResult]
117-
chunk = Chunk.read(ReadDataStream(self._file))
118-
stream, _ = get_chunk_data_stream(chunk, validate_crc=self._validate_crcs)
119-
for frame_index, message_index in sorted(
120-
chunk_message_indexes, key=lambda x: x[1].message_start_offset
121-
):
122-
stream.read(message_index.message_start_offset - stream.count) # pyright: ignore[reportUnusedCallResult]
123-
message = Message.read(stream, message_index.message_length)
128+
tensors = [torch.from_numpy(arrays[idx]) for idx in indexes] # pyright: ignore[reportUnknownMemberType]
129+
130+
return torch.stack(tensors)
131+
132+
case _:
133+
message_index = self._message_indexes[indexes]
134+
_ = self._file.seek(message_index.chunk_start_offset + 1 + 8)
135+
chunk = Chunk.read(ReadDataStream(self._file))
136+
stream, _ = get_chunk_data_stream(chunk, self._validate_crcs)
137+
_ = stream.read(message_index.message_start_offset - stream.count)
138+
message = Message.read(stream, length=message_index.message_length)
124139
decoded_message = self._message_decoder(message.data)
125-
frames[frame_index] = self._decoder(decoded_message.data) # pyright: ignore[reportUnknownMemberType, reportUnknownArgumentType, reportAttributeAccessIssue]
140+
array = self._decoder(decoded_message.data)
126141

127-
return torch.stack([torch.from_numpy(frames[idx]) for idx in indexes]) # pyright: ignore[reportUnknownMemberType]
142+
return torch.from_numpy(array) # pyright: ignore[reportUnknownMemberType]
128143

129144
@override
130145
def __len__(self) -> int:

src/rbyte/io/_numpy/tensor_source.py

+18-15
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from functools import cached_property
33
from os import PathLike
44
from pathlib import Path
5-
from typing import TYPE_CHECKING, override
5+
from typing import final, override
66

77
import numpy as np
88
import torch
@@ -16,34 +16,37 @@
1616
from rbyte.io.base import TensorSource
1717
from rbyte.utils.tensor import pad_sequence
1818

19-
if TYPE_CHECKING:
20-
from types import EllipsisType
21-
2219

20+
@final
2321
class NumpyTensorSource(TensorSource):
2422
@validate_call(config=BaseModel.model_config)
2523
def __init__(
2624
self, path: PathLike[str], select: Sequence[str] | None = None
2725
) -> None:
2826
super().__init__()
2927

30-
self._path: Path = Path(path)
31-
self._select: Sequence[str] | EllipsisType = select or ...
28+
self._path = Path(path)
29+
self._select = select or ...
3230

3331
@cached_property
3432
def _path_posix(self) -> str:
3533
return self._path.resolve().as_posix()
3634

35+
def _getitem(self, index: object) -> Tensor:
36+
path = self._path_posix.format(index)
37+
array = structured_to_unstructured(np.load(path)[self._select]) # pyright: ignore[reportUnknownVariableType]
38+
return torch.from_numpy(np.ascontiguousarray(array)) # pyright: ignore[reportUnknownMemberType, reportUnknownArgumentType]
39+
3740
@override
38-
def __getitem__(self, indexes: Iterable[object]) -> Tensor:
39-
tensors: list[Tensor] = []
40-
for index in indexes:
41-
path = self._path_posix.format(index)
42-
array = structured_to_unstructured(np.load(path)[self._select]) # pyright: ignore[reportUnknownVariableType]
43-
tensor = torch.from_numpy(np.ascontiguousarray(array)) # pyright: ignore[reportUnknownMemberType, reportUnknownArgumentType]
44-
tensors.append(tensor)
45-
46-
return pad_sequence(list(tensors), dim=0, value=torch.nan)
41+
def __getitem__(self, indexes: object | Iterable[object]) -> Tensor:
42+
match indexes:
43+
case Iterable():
44+
tensors = map(self._getitem, indexes) # pyright: ignore[reportUnknownArgumentType]
45+
46+
return pad_sequence(list(tensors), dim=0, value=torch.nan)
47+
48+
case _:
49+
return self._getitem(indexes)
4750

4851
@override
4952
def __len__(self) -> int:

src/rbyte/io/base.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
1-
from collections.abc import Iterable
2-
from typing import Any, Protocol, runtime_checkable
1+
from collections.abc import Sequence
2+
from typing import Protocol, runtime_checkable
33

44
from torch import Tensor
55

66

77
@runtime_checkable
88
class TensorSource(Protocol):
9-
def __getitem__(self, indexes: Iterable[Any]) -> Tensor: ...
9+
def __getitem__[T](self, indexes: T | Sequence[T]) -> Tensor: ...
1010
def __len__(self) -> int: ...

src/rbyte/io/dataframe/aligner.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ class MergeConfig(BaseModel):
4040
columns: OrderedDict[str, ColumnMergeConfig] = Field(default_factory=OrderedDict)
4141

4242

43-
type Fields = MergeConfig | OrderedDict[str, "Fields"]
43+
type Fields = MergeConfig | OrderedDict[str, Fields]
4444

4545

4646
@final

src/rbyte/io/hdf5/dataframe_builder.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
)
1414
from pydantic import ConfigDict, validate_call
1515

16-
type Fields = Mapping[str, PolarsDataType | None] | Mapping[str, "Fields"]
16+
type Fields = Mapping[str, PolarsDataType | None] | Mapping[str, Fields]
1717

1818

1919
@final

src/rbyte/io/hdf5/tensor_source.py

+5-6
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,22 @@
1-
from collections.abc import Iterable
2-
from typing import cast, override
1+
from collections.abc import Sequence
2+
from typing import cast, final, override
33

44
import torch
55
from h5py import Dataset, File
6-
from jaxtyping import UInt8
76
from pydantic import FilePath, validate_call
87
from torch import Tensor
98

109
from rbyte.io.base import TensorSource
1110

1211

12+
@final
1313
class Hdf5TensorSource(TensorSource):
1414
@validate_call
1515
def __init__(self, path: FilePath, key: str) -> None:
16-
file = File(path)
17-
self._dataset: Dataset = cast(Dataset, file[key])
16+
self._dataset = cast(Dataset, File(path)[key])
1817

1918
@override
20-
def __getitem__(self, indexes: Iterable[int]) -> UInt8[Tensor, "b h w c"]:
19+
def __getitem__(self, indexes: int | Sequence[int]) -> Tensor:
2120
return torch.from_numpy(self._dataset[indexes]) # pyright: ignore[reportUnknownMemberType, reportUnknownArgumentType]
2221

2322
@override

0 commit comments

Comments
 (0)