Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion docs/source/use_with_pytorch.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -250,10 +250,17 @@ Each node is assigned a chunk of data, e.g. rank 0 is given the first chunk of t

For iterable datasets:

If the dataset has a number of shards that is a factor of `world_size` (i.e. if `dataset.num_shards % world_size == 0`),
By default, if the dataset has a number of shards that is a multiple of `world_size` (i.e. if `dataset.num_shards % world_size == 0`),
then the shards are evenly assigned across the nodes, which is the most optimized.
Otherwise, each node keeps 1 example out of `world_size`, skipping the other examples.

When the number of physical shards is small or imbalanced relative to `world_size`, you can pass `force_sample_level=True` to force sample-level splitting even when shard counts divide evenly. Each node then keeps 1 example out of `world_size`, which avoids duplicating expensive transformations applied after `split_dataset_by_node` across nodes:

```python
ds = split_dataset_by_node(ds, rank=rank, world_size=world_size, force_sample_level=True)
ds = ds.map(tokenize_fn) # only runs on the examples this rank will actually consume
```

This can also be combined with a `torch.utils.data.DataLoader` if you want each node to use multiple workers to load the data.

> [!WARNING]
Expand Down
27 changes: 22 additions & 5 deletions src/datasets/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,32 @@
DatasetType = TypeVar("DatasetType", Dataset, IterableDataset)


def split_dataset_by_node(dataset: DatasetType, rank: int, world_size: int) -> DatasetType:
def split_dataset_by_node(
dataset: DatasetType,
rank: int,
world_size: int,
force_sample_level: bool = False,
) -> DatasetType:
"""
Split a dataset for the node at rank `rank` in a pool of nodes of size `world_size`.

For map-style datasets:

Each node is assigned a chunk of data, e.g. rank 0 is given the first chunk of the dataset.
To maximize data loading throughput, chunks are made of contiguous data on disk if possible.
The `force_sample_level` argument is ignored for map-style datasets.

For iterable datasets:

If the dataset has a number of shards that is a factor of `world_size` (i.e. if `dataset.num_shards % world_size == 0`),
then the shards are evenly assigned across the nodes, which is the most optimized.
Otherwise, each node keeps 1 example out of `world_size`, skipping the other examples.
By default, if the dataset has a number of shards that is a factor of `world_size` (i.e. if
`dataset.num_shards % world_size == 0`), then the shards are evenly assigned across the nodes,
which is the most optimized. Otherwise, each node keeps 1 example out of `world_size`, skipping
the other examples.

Pass `force_sample_level=True` to force sample-level splitting regardless of shard count
(each node keeps 1 example out of `world_size`). This is useful when the number of physical
shards is small or imbalanced relative to `world_size`, so that expensive transformations applied
after `split_dataset_by_node` are not duplicated across nodes.

> [!WARNING]
> If you shuffle your iterable dataset in a distributed setup, make sure to set a fixed `seed` in [`IterableDataset.shuffle`]
Expand All @@ -33,11 +45,16 @@ def split_dataset_by_node(dataset: DatasetType, rank: int, world_size: int) -> D
Rank of the current node.
world_size (`int`):
Total number of nodes.
force_sample_level (`bool`, defaults to `False`):
For iterable datasets, force sample-level splitting even when shards divide evenly across
nodes. Ignored for map-style datasets.

Returns:
[`Dataset`] or [`IterableDataset`]: The dataset to be used on the node at rank `rank`.
"""
if isinstance(dataset, Dataset):
return _split_by_node_map_style_dataset(dataset, rank=rank, world_size=world_size)
else:
return _split_by_node_iterable_dataset(dataset, rank=rank, world_size=world_size)
return _split_by_node_iterable_dataset(
dataset, rank=rank, world_size=world_size, force_sample_level=force_sample_level
)
46 changes: 33 additions & 13 deletions src/datasets/iterable_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2426,6 +2426,7 @@ def num_shards(self) -> int:
class DistributedConfig:
rank: int
world_size: int
force_sample_level: bool = False


def _maybe_add_torch_iterable_dataset_parent_class(cls):
Expand Down Expand Up @@ -2666,8 +2667,11 @@ def epoch(self) -> int:

@property
def num_shards(self) -> int:
if self._distributed and self._ex_iterable.num_shards % self._distributed.world_size == 0:
return self._ex_iterable.num_shards // self._distributed.world_size
if self._distributed:
world_size = self._distributed.world_size
divisible = self._ex_iterable.num_shards % world_size == 0
if divisible and not self._distributed.force_sample_level:
return self._ex_iterable.num_shards // world_size
return self._ex_iterable.num_shards

@property
Expand Down Expand Up @@ -2753,7 +2757,8 @@ def _prepare_ex_iterable_for_iteration(
if self._distributed:
rank = self._distributed.rank
world_size = self._distributed.world_size
if ex_iterable.num_shards % world_size == 0:
divisible = ex_iterable.num_shards % world_size == 0
if divisible and not self._distributed.force_sample_level:
if self._is_main_process():
num_shards_per_node = ex_iterable.num_shards // world_size
plural = "s" if num_shards_per_node > 1 else ""
Expand All @@ -2766,11 +2771,12 @@ def _prepare_ex_iterable_for_iteration(
logger.info(
f"Assigning 1 out of {world_size} examples of the dataset to each node. The others are skipped during the iteration."
)
logger.info(
f"It is more optimized to distribute the dataset shards (or data sources) across nodes. "
f"You can do that by using a dataset with number of shards that is a factor of world_size={world_size}. "
f"The current dataset has {ex_iterable.num_shards} which is not a factor of {world_size}"
)
if not self._distributed.force_sample_level and not divisible:
logger.info(
f"It is more optimized to distribute the dataset shards (or data sources) across nodes. "
f"You can do that by using a dataset with number of shards that is a factor of world_size={world_size}. "
f"The current dataset has {ex_iterable.num_shards} which is not a factor of {world_size}"
)
ex_iterable = StepExamplesIterable(ex_iterable, step=world_size, offset=rank)

if ex_iterable.iter_arrow:
Expand Down Expand Up @@ -5341,13 +5347,24 @@ def _interleave_iterable_datasets(
)


def _split_by_node_iterable_dataset(dataset: IterableDataset, rank: int, world_size: int) -> IterableDataset:
def _split_by_node_iterable_dataset(
dataset: IterableDataset,
rank: int,
world_size: int,
force_sample_level: bool = False,
) -> IterableDataset:
"""
Split an iterable dataset for the node at rank `rank` in a pool of nodes of size `world_size`.

If the dataset has a number of shards that is a factor of `world_size` (i.e. if `dataset.num_shards % world_size == 0`),
then the shards are evenly assigned across the nodes, which is the most optimized.
Otherwise, each node keeps 1 example out of `world_size`, skipping the other examples.
By default, if the dataset has a number of shards that is a factor of `world_size` (i.e. if
`dataset.num_shards % world_size == 0`), then the shards are evenly assigned across the nodes,
which is the most optimized. Otherwise, each node keeps 1 example out of `world_size`, skipping
the other examples.

Pass `force_sample_level=True` to force sample-level splitting regardless of shard count
(each node keeps 1 example out of `world_size`). This is useful when the number of physical
shards is small or imbalanced relative to `world_size`, so that expensive transformations applied
after `split_dataset_by_node` are not duplicated across nodes.

Args:
dataset ([`IterableDataset`]):
Expand All @@ -5356,14 +5373,17 @@ def _split_by_node_iterable_dataset(dataset: IterableDataset, rank: int, world_s
Rank of the current node.
world_size (`int`):
Total number of nodes.
force_sample_level (`bool`, defaults to `False`):
If `True`, force sample-level splitting even when shards divide evenly across nodes.

Returns:
[`IterableDataset`]: The iterable dataset to be used on the node at rank `rank`.
"""
if dataset._distributed:
rank = world_size * dataset._distributed.rank + rank
world_size = world_size * dataset._distributed.world_size
distributed = DistributedConfig(rank=rank, world_size=world_size)
force_sample_level = force_sample_level or dataset._distributed.force_sample_level
distributed = DistributedConfig(rank=rank, world_size=world_size, force_sample_level=force_sample_level)
return IterableDataset(
ex_iterable=dataset._ex_iterable,
info=dataset._info.copy(),
Expand Down
62 changes: 62 additions & 0 deletions tests/test_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,68 @@ def gen():
assert len({tuple(x.values()) for ds in datasets_per_rank_per_worker for x in ds}) == full_size


def test_split_dataset_by_node_iterable_force_sample_level_when_divisible():
    """Forcing sample-level splitting must win even when shards divide evenly across nodes."""

    def gen(shards):
        # Four examples per shard, each tagged with the shard it came from.
        for shard_name in shards:
            for i in range(4):
                yield {"i": i, "shard": shard_name}

    world_size = 2
    num_shards = 4  # divisible by world_size, so default behavior would be shard-level
    shard_names = [f"shard_{idx}.txt" for idx in range(num_shards)]
    full_ds = IterableDataset.from_generator(gen, gen_kwargs={"shards": shard_names})
    assert full_ds.num_shards == num_shards
    reference = list(full_ds)

    for rank in range(world_size):
        node_ds = split_dataset_by_node(full_ds, rank=rank, world_size=world_size, force_sample_level=True)
        # With the flag set, the shard count is reported unchanged instead of being
        # divided by world_size.
        assert node_ds.num_shards == num_shards
        # Each rank keeps one example out of world_size, starting at its own offset.
        assert list(node_ds) == reference[rank::world_size]


def test_split_dataset_by_node_map_style_ignores_force_sample_level():
    """For a map-style dataset, the flag must not change which rows a rank receives."""
    full_ds = Dataset.from_dict({"i": range(17)})
    world_size = 3
    for rank in range(world_size):
        baseline = split_dataset_by_node(full_ds, rank=rank, world_size=world_size)
        flagged = split_dataset_by_node(full_ds, rank=rank, world_size=world_size, force_sample_level=True)
        # Same rank, same world size: the split must be identical with and without the flag.
        assert list(baseline) == list(flagged)


def test_split_dataset_by_node_iterable_force_sample_level_chaining():
    """Check that a nested split inherits `force_sample_level` from the outer split.

    The source dataset has `world_size * num_workers` shards, so the shard count divides evenly at
    both split levels. Without the inherited flag the inner split would hand out whole shards;
    with it, every (rank, worker) pair must receive a strided slice of the examples.
    """

    def gen(shards):
        for shard in shards:
            yield from ({"i": i, "shard": shard} for i in range(10))

    world_size = 2
    num_workers = 3
    total_splits = world_size * num_workers
    # Divisible by world_size and by world_size * num_workers, so only the flag can
    # trigger sample-level splitting; a non-divisible shard count would hide a broken inheritance.
    num_shards = total_splits
    gen_kwargs = {"shards": [f"shard_{idx}.txt" for idx in range(num_shards)]}
    full_ds = IterableDataset.from_generator(gen, gen_kwargs=gen_kwargs)
    assert full_ds.num_shards == num_shards
    full_examples = list(full_ds)

    # Outer split sets the flag; inner split leaves it False and should inherit it.
    datasets_per_rank = [
        split_dataset_by_node(full_ds, rank=rank, world_size=world_size, force_sample_level=True)
        for rank in range(world_size)
    ]
    datasets_per_rank_per_worker = [
        split_dataset_by_node(ds, rank=worker, world_size=num_workers)
        for ds in datasets_per_rank
        for worker in range(num_workers)
    ]
    # Sample-level splitting reports the underlying shard count; shard-level splitting
    # would report num_shards // total_splits == 1 instead.
    assert [ds.num_shards for ds in datasets_per_rank_per_worker] == [num_shards] * total_splits
    assert sum(len(list(ds)) for ds in datasets_per_rank_per_worker) == len(full_examples)
    assert len({tuple(x.values()) for ds in datasets_per_rank_per_worker for x in ds}) == len(full_examples)
    # The combined rank is num_workers * rank + worker, which is exactly the position in the
    # flattened list, so each entry must see the stride starting at its own index.
    for combined_rank, ds in enumerate(datasets_per_rank_per_worker):
        assert list(ds) == full_examples[combined_rank::total_splits]


def test_distributed_shuffle_iterable():
def gen():
return ({"i": i} for i in range(17))
Expand Down