-
Notifications
You must be signed in to change notification settings - Fork 170
Description
🐛 Describe the bug
import torch
import torchdata.nodes as tn
from torchdata.stateful_dataloader.sampler import StatefulDistributedSampler, RandomSampler
class BoringDataset(torch.utils.data.Dataset):
    """Minimal map-style dataset whose item ``i`` is simply the integer ``i``.

    Used only to give the sampler something of a known length to index.
    """

    def __init__(self, n: int):
        # Total number of items; drives __len__ and therefore the sampler.
        self.n = n

    def __len__(self) -> int:
        """Return the dataset size."""
        return self.n

    def __getitem__(self, idx):
        """Identity lookup: position ``idx`` maps to the value ``idx``."""
        return idx
# Reproduction script for SamplerWrapper checkpoint/restore bugs.
# Wrap a StatefulDistributedSampler so it can be iterated as a torchdata node
# and snapshotted/restored via state_dict()/reset(state_dict).
sampler = tn.SamplerWrapper(
    StatefulDistributedSampler(
        BoringDataset(10),
        num_replicas=2,  # 10 items sharded across 2 replicas -> 5 indices on this rank
        rank=0,
        shuffle=True,  # shuffled order is a function of (seed, epoch) via set_epoch
        drop_last=False,
        seed=42,
    )
)
sampler.reset()  # start a fresh iteration
next(sampler)  # consume one index, then snapshot mid-epoch
state_dict = sampler.state_dict()
print(list(sampler))  # drain the remainder of this epoch
sampler.reset(); list(sampler)  # plain reset + full pass (advances the wrapper's epoch)
sampler.reset(state_dict)  # restore the mid-epoch snapshot
# BUG (per Output below): wrapper epoch and inner sampler epoch disagree (0 vs 1)
# because loading the state_dict does not call set_epoch on the inner sampler.
print(sampler.epoch, sampler.sampler.epoch)
# BUG: the resumed order differs from the order the snapshot was taken from.
print(list(sampler))
state_dict = sampler.state_dict()  # snapshot again, now that the iterator is exhausted
sampler.reset()
print(list(sampler))
sampler.reset(state_dict)  # restore the exhausted-iterator state...
sampler.reset()  # ...then a plain reset should refresh it for a new epoch
# BUG: still prints [] — reset() after loading an exhausted iterator does not refresh.
print(list(sampler))
"""
Output
[1, 4, 0, 3] <- should be the same
0 1 <- should be 0 0
[9, 5, 6, 2] <- should be the same
[8, 9, 5, 6, 2]
[] <- should not be empty
"""- Loading a state_dict which is partially through iteration does not call set_epoch (nor is it saved in the underlying sampler's state_dict), leading to different outputs.
- After loading an exhausted iterator, calling reset() again should refresh it, but it doesn't.
Versions
torchdata == 0.11
Metadata
Metadata
Assignees
Labels
No labels