@@ -1,19 +1,19 @@
 from collections.abc import Mapping, Sequence
 from enum import StrEnum, unique
 from functools import cache
-from typing import Annotated
+from typing import Annotated, Literal, override
 
 import polars as pl
 import torch
 from hydra.utils import instantiate
 from pipefunc import Pipeline
-from pydantic import Field, StringConstraints, validate_call
+from pydantic import ConfigDict, Field, StringConstraints, validate_call
 from structlog import get_logger
 from structlog.contextvars import bound_contextvars
 from tensordict import TensorDict
 from torch.utils.data import Dataset as TorchDataset
 
-from rbyte.batch import Batch, BatchMeta
+from rbyte.batch import BATCH_KEYS_DEFAULT, Batch, BatchKeys, BatchMeta
 from rbyte.config import BaseModel, HydraConfig
 from rbyte.io.base import TensorSource
 from rbyte.utils.tensor import pad_sequence
@@ -53,7 +53,14 @@ class Column(StrEnum):
     source_index_column = "__source.index_column"
 
 
-class Dataset(TorchDataset[TensorDict]):
+class _ALL_TYPE:  # noqa: N801
+    pass
+
+
+_ALL = _ALL_TYPE()
+
+
+class Dataset(TorchDataset[Batch]):
     @validate_call(config=BaseModel.model_config)
     def __init__(
         self, inputs: Annotated[Mapping[Id, InputConfig], Field(min_length=1)]
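(Note: `_ALL` is a module-private sentinel meaning "every subkey of this field". The normalization in `get_batch` below drops it whenever explicit subkeys were also requested; as a sketch, assuming `BatchKeys` admits these tuple forms and using a hypothetical "camera" source id, `keys=("data", ("data", "camera"))` normalizes to `subkeys == {"data": {"camera"}, "meta": set()}`.)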
@@ -136,52 +143,112 @@ def sources(self) -> pl.DataFrame:
     def _get_source(self, config: str) -> TensorSource:  # noqa: PLR6301
         return HydraConfig[TensorSource].model_validate_json(config).instantiate()
 
-    def __getitems__(self, indexes: Sequence[int]) -> Batch:  # noqa: PLW3201
-        samples = self.samples[indexes]
-        batch_size = [samples.height]
+    @validate_call(
+        config=ConfigDict(arbitrary_types_allowed=True, validate_default=False)
+    )
+    def get_batch(
+        self,
+        index: int | Sequence[int] | slice | range,
+        *,
+        keys: BatchKeys = BATCH_KEYS_DEFAULT,
+    ) -> Batch:
+        subkeys: Mapping[Literal["data", "meta"], set[_ALL_TYPE | str]] = {
+            "data": set(),
+            "meta": set(),
+        }
+        for key in keys:
+            match key:
+                case "data" | "meta":
+                    subkeys[key].add(_ALL)
 
-        source_idx_cols = self._sources[Column.source_index_column].unique()
+                case ("data" | "meta", _):
+                    subkeys[key[0]].add(key[1])
 
-        sources = (
-            samples.lazy()
-            .join(self.sources.lazy(), on=Column.input_id, how="left")
-            .with_columns(
-                pl.coalesce(
-                    pl.when(pl.col(Column.source_index_column) == idx_col).then(idx_col)
-                    for idx_col in source_idx_cols
-                ).alias(Column.source_idxs)
+        for v in subkeys.values():
+            if _ALL in v and len(v) > 1:
+                v.remove(_ALL)
+
+        samples = self.samples[index]
+        batch_size = [samples.height]
+
+        if subkeys_data := subkeys["data"]:
+            source_idx_cols = self._sources[Column.source_index_column].unique()
+            sources = (
+                samples.lazy()
+                .join(self.sources.lazy(), on=Column.input_id, how="left")
+                .with_columns(
+                    pl.coalesce(
+                        pl.when(pl.col(Column.source_index_column) == idx_col).then(
+                            idx_col
+                        )
+                        for idx_col in source_idx_cols
+                    ).alias(Column.source_idxs)
+                )
+                .group_by(Column.source_id)
+                .agg(Column.source_config, Column.source_idxs)
+                .filter(
+                    True
+                    if _ALL in subkeys_data
+                    else pl.col(Column.source_id).is_in(subkeys_data)
+                )
             )
-            .group_by(Column.source_id)
-            .agg(Column.source_config, Column.source_idxs)
-        )
 
-        tensor_data: Mapping[str, torch.Tensor] = {
-            row[Column.source_id]: pad_sequence(
-                [
-                    self._get_source(source)[idxs]
-                    for (source, idxs) in zip(
-                        row[Column.source_config], row[Column.source_idxs], strict=True
-                    )
-                ],
-                dim=1,
-                value=torch.nan,
+            source_data = {
+                row[Column.source_id]: pad_sequence(
+                    [
+                        self._get_source(source)[idxs]
+                        for (source, idxs) in zip(
+                            row[Column.source_config],
+                            row[Column.source_idxs],
+                            strict=True,
+                        )
+                    ],
+                    dim=1,
+                    value=torch.nan,
+                )
+                for row in sources.collect().iter_rows(named=True)
+            }
+
+            sample_data_cols = (
+                pl.all()
+                if _ALL in subkeys_data
+                else pl.col(subkeys_data - source_data.keys())  # pyright: ignore[reportArgumentType]
+            ).exclude(Column.sample_idx, Column.input_id)
+
+            sample_data = samples.select(sample_data_cols.to_physical()).to_dict(
+                as_series=False
             )
-            for row in sources.collect().iter_rows(named=True)
-        }
 
-        sample_data: Mapping[str, Sequence[object]] = samples.select(
-            pl.exclude(Column.sample_idx, Column.input_id).to_physical()
-        ).to_dict(as_series=False)
+            data = TensorDict(source_data | sample_data, batch_size=batch_size)  # pyright: ignore[reportArgumentType]
+
+        else:
+            data = None
+
+        if subkeys_meta := subkeys["meta"]:
+            meta = BatchMeta(
+                sample_idx=(
+                    samples[Column.sample_idx].to_torch()
+                    if _ALL in subkeys_meta or "sample_idx" in subkeys_meta
+                    else None
+                ),
+                input_id=(
+                    samples[Column.input_id].to_list()
+                    if _ALL in subkeys_meta or "input_id" in subkeys_meta
+                    else None
+                ),
+                batch_size=batch_size,
+            )
+        else:
+            meta = None
 
-        data = TensorDict(tensor_data | sample_data, batch_size=batch_size)  # pyright: ignore[reportArgumentType]
+        return Batch(data=data, meta=meta, batch_size=batch_size)
 
-        meta = BatchMeta(
-            sample_idx=samples[Column.sample_idx].to_torch(),  # pyright: ignore[reportCallIssue]
-            input_id=samples[Column.input_id].to_list(),  # pyright: ignore[reportCallIssue]
-            batch_size=batch_size,  # pyright: ignore[reportCallIssue]
-        )
+    def __getitems__(self, index: Sequence[int]) -> Batch:  # noqa: PLW3201
+        return self.get_batch(index)
 
-        return Batch(data=data, meta=meta, batch_size=batch_size)  # pyright: ignore[reportCallIssue]
+    @override
+    def __getitem__(self, index: int) -> Batch:
+        return self.get_batch(index)
 
     def __len__(self) -> int:
         return len(self.samples)
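A minimal usage sketch of the new `get_batch` API (assuming an already-constructed `dataset`; the "camera_front" source id is hypothetical). `__getitems__` stays in place as the batched-fetch hook that torch's default map-style fetcher prefers when present, and `__getitem__` now returns a `Batch` as well:

    # full batch: all data sources plus all meta leaves (the default keys)
    batch = dataset.get_batch([0, 1, 2])

    # meta only: skips source decoding and sample-column selection entirely
    meta_only = dataset.get_batch(range(8), keys=(("meta", "sample_idx"),))

    # one data source plus input ids
    subset = dataset.get_batch(
        slice(0, 4), keys=(("data", "camera_front"), ("meta", "input_id"))
    )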