Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
233 changes: 6 additions & 227 deletions src/sedpack/io/dataset_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Base class for a dataset."""
import itertools
import json
import logging
from pathlib import Path
import random
import semver

from typing import Callable, Iterator, Union
from typing import Iterator, Union

import sedpack
from sedpack.io.shard_file_metadata import ShardInfo, ShardsList, ShardListInfo
from sedpack.io.types import SplitT
from sedpack.io.metadata import DatasetInfo, DatasetStructure, Metadata
from sedpack.io.shard_file_metadata import ShardInfo
from sedpack.io.shard_info_iterator import ShardInfoIterator
from sedpack.io.types import SplitT


class DatasetBase:
Expand Down Expand Up @@ -137,12 +135,6 @@ def dataset_structure(self) -> DatasetStructure:
def dataset_structure(self, value: DatasetStructure) -> None:
self._dataset_info.dataset_structure = value

@property
def logger(self) -> logging.Logger:
"""Get the logger.
"""
return self._logger

def shard_info_iterator(self, split: SplitT | None) -> Iterator[ShardInfo]:
"""Iterate all `ShardInfo` in the split.

Expand All @@ -160,221 +152,8 @@ def shard_info_iterator(self, split: SplitT | None) -> Iterator[ShardInfo]:
raise ValueError(f"There is no shard in {split}.")

return ShardInfoIterator(
dataset_path=self.path,
dataset_info=self.dataset_info,
split=split,
dataset=self,
repeat=False,
)


class ShardInfoIterator:
    """Iterate `ShardInfo` objects of a dataset.

    Iterates either a single split (when `split` is given) or all splits
    (when `split` is None). When `repeat` is True the iteration cycles
    indefinitely over the same shards.
    """

    def __init__(
        self,
        *,
        split: SplitT | None,
        dataset: DatasetBase,
        repeat: bool = False,
    ) -> None:
        """Initialize shard information iteration.

        Args:

          split (SplitT | None): Which split to iterate or all if set to None.

          dataset (DatasetBase): The dataset being iterated.

          repeat (bool): Should we cycle indefinitely?
        """
        self.split: SplitT | None = split
        self.dataset: DatasetBase = dataset
        self.repeat: bool = repeat

        # The currently active iterator; (re)created by each `__iter__` call.
        self._iterator: Iterator[ShardInfo] = iter([])

    def __len__(self) -> int:
        """Either return the number of ShardInfo objects iterated or raise a
        ValueError if infinite cycle.
        """
        if self.number_of_shards() == 0 or not self.repeat:
            return self.number_of_shards()
        raise ValueError("Infinite iteration")

    def number_of_shards(self) -> int:
        """Return the number of distinct shards that are iterated. When
        repeated this method still returns a finite answer.
        """
        if self.split is None:
            # Sum all splits.
            return sum(shard_list_info.number_of_shards for shard_list_info in
                       self.dataset.dataset_info.splits.values())

        # Unknown split -> nothing to iterate.
        if self.split not in self.dataset.dataset_info.splits:
            return 0

        return self.dataset.dataset_info.splits[self.split].number_of_shards

    def _shard_info_iterator(
            self, shard_list_info: ShardListInfo) -> Iterator[ShardInfo]:
        """Recursively yield `ShardInfo` from the whole directory tree.
        """
        shard_list: ShardsList = ShardsList.model_validate_json(
            (self.dataset.path /
             shard_list_info.shard_list_info_file.file_path).read_text())

        yield from shard_list.shard_files

        for child in shard_list.children_shard_lists:
            yield from self._shard_info_iterator(child)

    def _iterate_once(self) -> Iterator[ShardInfo]:
        """Return a single-pass iterator over the selected shards.

        Raises a KeyError when a specific split is requested but not present
        (same behavior as the historical `__iter__`).
        """
        if self.split is None:
            return itertools.chain.from_iterable(
                self._shard_info_iterator(shard_list_info) for shard_list_info
                in self.dataset.dataset_info.splits.values())
        return self._shard_info_iterator(
            self.dataset.dataset_info.splits[self.split])

    def __iter__(self) -> Iterator[ShardInfo]:
        """Return the shard information iterator (reentrant).

        When `repeat` is True and there is at least one shard the iterator
        cycles indefinitely.
        """
        if self.repeat and self.number_of_shards() > 0:
            # Bug fix: `repeat` used to be accepted (and made `__len__` raise
            # ValueError) but was silently ignored here, so iteration always
            # stopped after a single pass. Re-read the shard lists each cycle
            # instead of caching them.
            def cycle_forever() -> Iterator[ShardInfo]:
                while True:
                    yield from self._iterate_once()

            self._iterator = cycle_forever()
        else:
            self._iterator = self._iterate_once()

        return self._iterator

    def __next__(self) -> ShardInfo:
        """Get the next item.
        """
        return next(self._iterator)


class CachedShardInfoIterator(ShardInfoIterator):
    """Iterate shards of a dataset with the shard list cached in memory.

    Caching the full list of `ShardInfo` up front enables filtering,
    limiting, and reshuffling between epochs.
    """

    def __init__(
        self,
        *,
        split: SplitT | None,
        dataset: DatasetBase,
        repeat: bool = False,
        shards: int | None = None,
        custom_metadata_type_limit: int | None = None,
        shard_filter: Callable[[ShardInfo], bool] | None = None,
        shuffle: int = 0,
    ) -> None:
        """Initialize shard information iteration.

        Args:

          split (SplitT | None): Which split to iterate or all if set to None.

          dataset (DatasetBase): The dataset being iterated.

          repeat (bool): Should we cycle indefinitely?

          shards (int | None): If specified limits the dataset to the first
          `shards` shards.

          custom_metadata_type_limit (int | None): Ignored when None. If
          non-zero then limit the number of shards with different
          `custom_metadata`. Take only the first `custom_metadata_type_limit`
          shards with the concrete `custom_metadata`. This is best effort for
          different `custom_metadata` (`json.dumps` with `sort_keys`).

          shard_filter (Callable[[ShardInfo], bool] | None): If present this is
          a function taking the ShardInfo and returning True if the shard shall
          be used for traversal and False otherwise.

          shuffle (int): When set to 0 the iteration is deterministic otherwise
          shuffle the shards with a shuffle buffer of at least `shuffle`
          elements. Current implementation shuffles all shard information.
        """
        super().__init__(
            split=split,
            dataset=dataset,
            repeat=repeat,
        )

        self.shuffle: int = shuffle

        # Cache the list of shards (a fresh single-pass iterator, never
        # repeating, so the list is finite).
        shard_list: list[ShardInfo] = list(
            ShardInfoIterator(
                split=split,
                dataset=dataset,
                repeat=False,
            ))

        # Filter if needed.
        if shard_filter:
            shard_list = [
                shard_info for shard_info in shard_list
                if shard_filter(shard_info)
            ]

        # Log which custom metadata settings survived the filter. The key is
        # the canonical JSON serialization (sort_keys for stable identity).
        kept_metadata: set[str] = {
            json.dumps(
                s.custom_metadata,
                sort_keys=True,
            ) for s in shard_list
        }
        # NOTE(review): relies on `dataset.logger`; confirm `DatasetBase`
        # still exposes a `logger` attribute/property.
        self.dataset.logger.info(
            "Filtered shards with custom metadata: %s from split: %s",
            kept_metadata,
            split,
        )

        # Only use a limited amount of shards for each setting of
        # custom_metadata.
        if custom_metadata_type_limit:
            counts: dict[str, int] = {}
            old_shards_list = shard_list
            shard_list = []
            for shard_info in old_shards_list:
                k: str = json.dumps(
                    shard_info.custom_metadata,
                    sort_keys=True,
                )
                counts[k] = counts.get(k, 0) + 1
                if counts[k] <= custom_metadata_type_limit:
                    shard_list.append(shard_info)
            self.dataset.logger.info("Took %s shards total", len(shard_list))

        # Limit the number of shards.
        if shards:
            shard_list = shard_list[:shards]

        # Initial shuffling.
        if shuffle:
            random.shuffle(shard_list)

        # Cached shards.
        self._index: int = -1  # The last returned element.
        self._shards: list[ShardInfo] = shard_list

    def number_of_shards(self) -> int:
        """Return the number of distinct shards that are iterated. When
        repeated this method still returns a finite answer.
        """
        return len(self._shards)

    def __iter__(self) -> Iterator[ShardInfo]:
        """Return the shard information iterator (reentrant).
        """
        return self

    def __next__(self) -> ShardInfo:
        """Get the next item.

        Raises:

          StopIteration: When the pass is finished and `repeat` is False, or
          when there are no shards at all (even with `repeat=True`, since
          there is nothing to cycle).
        """
        if not self._shards:
            # Bug fix: previously an empty shard list together with
            # repeat=True reset the index and then raised IndexError on
            # `self._shards[0]` instead of a clean StopIteration.
            raise StopIteration

        self._index += 1

        if self._index >= len(self._shards):
            if not self.repeat:
                raise StopIteration
            # Start a new epoch, optionally reshuffling the cached shards.
            self._index = 0
            if self.shuffle:
                random.shuffle(self._shards)

        return self._shards[self._index]
9 changes: 6 additions & 3 deletions src/sedpack/io/dataset_iteration.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,14 @@
import numpy as np
import tensorflow as tf

from sedpack.io.dataset_base import CachedShardInfoIterator, DatasetBase
from sedpack.io.dataset_base import DatasetBase
from sedpack.io.flatbuffer import IterateShardFlatBuffer
from sedpack.io.itertools import LazyPool
from sedpack.io.itertools import round_robin, round_robin_async, shuffle_buffer
from sedpack.io.npz import IterateShardNP
from sedpack.io.shard import IterateShardBase
from sedpack.io.shard.iterate_shard_base import T
from sedpack.io.shard_info_iterator import CachedShardInfoIterator
from sedpack.io.shard_file_metadata import ShardInfo
from sedpack.io.tfrec import IterateShardTFRec
from sedpack.io.tfrec.tfdata import get_from_tfrecord
Expand Down Expand Up @@ -79,8 +80,9 @@ def shard_paths_dataset(
"""
shards_list: list[ShardInfo] = list(
CachedShardInfoIterator(
dataset_path=self.path,
dataset_info=self.dataset_info,
split=split,
dataset=self,
repeat=False,
shards=shards,
custom_metadata_type_limit=custom_metadata_type_limit,
Expand Down Expand Up @@ -625,7 +627,8 @@ def as_numpy_iterator(
`process_record` returns something else). No batching is done.
"""
shard_iterator: Iterable[ShardInfo] = CachedShardInfoIterator(
dataset=self,
dataset_path=self.path,
dataset_info=self.dataset_info,
split=split,
shards=shards,
custom_metadata_type_limit=custom_metadata_type_limit,
Expand Down
22 changes: 22 additions & 0 deletions src/sedpack/io/shard_info_iterator/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Iterating of shard information."""

from sedpack.io.shard_info_iterator.shard_info_iterator import ShardInfoIterator
from sedpack.io.shard_info_iterator.cached_shard_info_iterator import CachedShardInfoIterator

__all__ = [
"CachedShardInfoIterator",
"ShardInfoIterator",
]
Loading
Loading