
Bug with loading state for SamplerWrapper + StatefulDistributedSampler #1513

@gorold

Description

🐛 Describe the bug

import torch
import torchdata.nodes as tn
from torchdata.stateful_dataloader.sampler import StatefulDistributedSampler


class BoringDataset(torch.utils.data.Dataset):
    def __init__(self, n: int):
        self.n = n

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        return idx


sampler = tn.SamplerWrapper(
    StatefulDistributedSampler(
        BoringDataset(10),
        num_replicas=2,
        rank=0,
        shuffle=True,
        drop_last=False,
        seed=42,
    )
)

# Advance one step, snapshot the state, then drain the remaining indices.
sampler.reset()
next(sampler)
state_dict = sampler.state_dict()
print(list(sampler))

# Exhaust one full epoch so the epoch counter advances.
sampler.reset(); list(sampler)

# Restore the mid-iteration snapshot: both epochs should be back to 0,
# and the remaining indices should match the first drain above.
sampler.reset(state_dict)
print(sampler.epoch, sampler.sampler.epoch)
print(list(sampler))

# Snapshot the now-exhausted iterator.
state_dict = sampler.state_dict()

sampler.reset()
print(list(sampler))

# Load the exhausted state, then reset() again: this should refresh
# the iterator, but it yields nothing.
sampler.reset(state_dict)
sampler.reset()
print(list(sampler))

"""
Output
[1, 4, 0, 3] <- should be the same
0 1 <- should be 0 0
[9, 5, 6, 2] <- should be the same
[8, 9, 5, 6, 2]
[] <- should not be empty
"""
  1. Loading a state_dict captured partway through iteration does not call set_epoch (nor is the epoch saved in the underlying sampler's state_dict), so the restored iterator yields a different order.
  2. After loading the state of an exhausted iterator, calling reset() again should refresh it, but it doesn't.
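The first point can be illustrated without torchdata. The sketch below is a toy stand-in for an epoch-seeded sampler (the real StatefulDistributedSampler derives its shuffle from seed + epoch; here a simple rotation keeps the example fully deterministic): when the epoch is omitted from the saved state, a restore silently replays a different order. The class and method names are illustrative, not the torchdata implementation.

```python
class ToyEpochSampler:
    """Toy epoch-dependent sampler: the iteration order is a rotation
    of range(n) by the current epoch (a deterministic stand-in for a
    seeded shuffle)."""

    def __init__(self, n):
        self.n = n
        self.epoch = 0

    def set_epoch(self, epoch):
        self.epoch = epoch

    def order(self):
        base = list(range(self.n))
        k = self.epoch % self.n
        return base[k:] + base[:k]

    def state_dict(self):
        # Bug analogue: the epoch is NOT included in the saved state.
        return {}

    def load_state_dict(self, state):
        # The epoch is silently left at its default after a restore.
        pass


s = ToyEpochSampler(5)
s.set_epoch(1)
expected = s.order()            # [1, 2, 3, 4, 0]

state = s.state_dict()
fresh = ToyEpochSampler(5)      # simulate a restart from a checkpoint
fresh.load_state_dict(state)    # epoch is lost: stays 0
assert fresh.order() != expected    # [0, 1, 2, 3, 4], a different order

# The fix analogue: persist the epoch in state_dict and re-apply it
# (via set_epoch) on load, so the restored order matches.
fresh.set_epoch(1)
assert fresh.order() == expected
```

This mirrors the `0 1` line in the output above: the wrapper and the underlying sampler disagree on the epoch after a restore, so the restored iteration order diverges from the one that was checkpointed.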

Versions

torchdata == 0.11
