Commit c01d4db: formatting

1 parent 932ac1d
megatron/core/datasets/data_samplers.py

Lines changed: 71 additions & 48 deletions
@@ -4,24 +4,27 @@
 
 
 import random
-import torch
+
 import numpy as np
+import torch
 from torch.utils.data import Dataset
-from megatron.training import get_args
+
 from megatron.core import mpu
 from megatron.core.datasets.utils import Split
+from megatron.training import get_args
 from megatron.training.dist_signal_handler import DistributedSignalHandler
 
+
 def build_pretraining_data_loader(dataset, consumed_samples):
     """Build dataloader given an input dataset."""
 
     if dataset is None:
         return None
     args = get_args()
-
-    if hasattr(dataset,'split'):
+
+    if hasattr(dataset, 'split'):
         split = dataset.split
-    elif hasattr(dataset,'index_split'):
+    elif hasattr(dataset, 'index_split'):
         split = dataset.index_split
     else:
         split = None
@@ -32,15 +35,17 @@ def build_pretraining_data_loader(dataset, consumed_samples):
             consumed_samples=0,
             micro_batch_size=args.micro_batch_size,
             data_parallel_rank=mpu.get_data_parallel_rank(),
-            data_parallel_size=mpu.get_data_parallel_world_size())
+            data_parallel_size=mpu.get_data_parallel_world_size(),
+        )
     elif args.dataloader_type == 'single':
         # Megatron sampler
         batch_sampler = MegatronPretrainingSampler(
             total_samples=len(dataset),
             consumed_samples=consumed_samples,
             micro_batch_size=args.micro_batch_size,
             data_parallel_rank=mpu.get_data_parallel_rank(),
-            data_parallel_size=mpu.get_data_parallel_world_size())
+            data_parallel_size=mpu.get_data_parallel_world_size(),
+        )
     elif args.dataloader_type == 'cyclic':
         batch_sampler = MegatronPretrainingRandomSampler(
             dataset,
@@ -49,51 +54,63 @@ def build_pretraining_data_loader(dataset, consumed_samples):
             micro_batch_size=args.micro_batch_size,
             data_parallel_rank=mpu.get_data_parallel_rank(),
             data_parallel_size=mpu.get_data_parallel_world_size(),
-            data_sharding=args.data_sharding)
+            data_sharding=args.data_sharding,
+        )
     elif args.dataloader_type == "external":
         # External dataloaders are passed through. User is expected to provide a
         # torch-compatible dataloader and define samplers, if needed.
         return dataset
     else:
-        raise Exception('{} dataloader type is not supported.'.format(
-            args.dataloader_type))
+        raise Exception('{} dataloader type is not supported.'.format(args.dataloader_type))
 
     def worker_init_fn(_):
         DistributedSignalHandler(args.exit_signal).__enter__()
-    maybe_worker_init_fn = worker_init_fn if args.exit_signal_handler and args.num_workers > 0 else None
+
+    maybe_worker_init_fn = (
+        worker_init_fn if args.exit_signal_handler and args.num_workers > 0 else None
+    )
     # Torch dataloader.
-    return torch.utils.data.DataLoader(dataset,
-                                       batch_sampler=batch_sampler,
-                                       num_workers=args.num_workers,
-                                       pin_memory=True,
-                                       persistent_workers=True if args.num_workers > 0 else False,
-                                       worker_init_fn=maybe_worker_init_fn,
-                                       )
+    return torch.utils.data.DataLoader(
+        dataset,
+        batch_sampler=batch_sampler,
+        num_workers=args.num_workers,
+        pin_memory=True,
+        persistent_workers=True if args.num_workers > 0 else False,
+        worker_init_fn=maybe_worker_init_fn,
+    )
+
 
 class MegatronPretrainingSampler:
 
-    def __init__(self, total_samples, consumed_samples, micro_batch_size,
-                 data_parallel_rank, data_parallel_size, drop_last=True):
+    def __init__(
+        self,
+        total_samples,
+        consumed_samples,
+        micro_batch_size,
+        data_parallel_rank,
+        data_parallel_size,
+        drop_last=True,
+    ):
         # Keep a copy of input params for later use.
         self.total_samples = total_samples
         self.consumed_samples = consumed_samples
         self.micro_batch_size = micro_batch_size
         self.data_parallel_rank = data_parallel_rank
-        self.micro_batch_times_data_parallel_size = \
-            self.micro_batch_size * data_parallel_size
+        self.micro_batch_times_data_parallel_size = self.micro_batch_size * data_parallel_size
        self.drop_last = drop_last
 
         # Sanity checks.
-        assert self.total_samples > 0, \
-            'no sample to consume: {}'.format(self.total_samples)
-        assert self.consumed_samples < self.total_samples, \
-            'no samples left to consume: {}, {}'.format(self.consumed_samples,
-                                                        self.total_samples)
+        assert self.total_samples > 0, 'no sample to consume: {}'.format(self.total_samples)
+        assert (
+            self.consumed_samples < self.total_samples
+        ), 'no samples left to consume: {}, {}'.format(self.consumed_samples, self.total_samples)
         assert self.micro_batch_size > 0
         assert data_parallel_size > 0
-        assert self.data_parallel_rank < data_parallel_size, \
-            'data_parallel_rank should be smaller than data size: {}, ' \
-            '{}'.format(self.data_parallel_rank, data_parallel_size)
+        assert (
+            self.data_parallel_rank < data_parallel_size
+        ), 'data_parallel_rank should be smaller than data size: {}, ' '{}'.format(
+            self.data_parallel_rank, data_parallel_size
+        )
 
     def __len__(self):
         return self.total_samples
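
Note: as a sanity check on the DataLoader wiring in the hunk above, here is a minimal, self-contained sketch of the same pattern (batch_sampler plus num_workers, pin_memory, persistent_workers, worker_init_fn). The toy dataset and batch sampler are hypothetical stand-ins for the Megatron ones, and the sizes are made up for illustration.

import torch
from torch.utils.data import Dataset


class ToyDataset(Dataset):
    # Hypothetical stand-in for a Megatron pretraining dataset: returns its index.
    def __init__(self, num_samples):
        self.num_samples = num_samples

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        return idx


class ToyBatchSampler:
    # Hypothetical stand-in for MegatronPretrainingSampler: yields one list of
    # indices per micro batch.
    def __init__(self, total_samples, micro_batch_size):
        self.total_samples = total_samples
        self.micro_batch_size = micro_batch_size

    def __iter__(self):
        for start in range(0, self.total_samples, self.micro_batch_size):
            yield list(range(start, start + self.micro_batch_size))


num_workers = 0  # illustrative; Megatron reads this from args.num_workers
loader = torch.utils.data.DataLoader(
    ToyDataset(16),
    batch_sampler=ToyBatchSampler(16, micro_batch_size=4),
    num_workers=num_workers,
    pin_memory=True,
    persistent_workers=True if num_workers > 0 else False,
    worker_init_fn=None,  # Megatron installs a DistributedSignalHandler hook here when enabled
)
print([batch.tolist() for batch in loader])  # [[0, 1, 2, 3], [4, 5, 6, 7], ...]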
@@ -143,8 +160,16 @@ def __getitem__(self, idx):
 
 class MegatronPretrainingRandomSampler:
 
-    def __init__(self, dataset, total_samples, consumed_samples, micro_batch_size,
-                 data_parallel_rank, data_parallel_size, data_sharding):
+    def __init__(
+        self,
+        dataset,
+        total_samples,
+        consumed_samples,
+        micro_batch_size,
+        data_parallel_rank,
+        data_parallel_size,
+        data_sharding,
+    ):
         # Keep a copy of input params for later use.
         self.dataset = dataset
         self.total_samples = total_samples
@@ -153,19 +178,18 @@ def __init__(self, dataset, total_samples, consumed_samples, micro_batch_size,
         self.data_parallel_rank = data_parallel_rank
         self.data_parallel_size = data_parallel_size
         self.data_sharding = data_sharding
-        self.micro_batch_times_data_parallel_size = \
-            self.micro_batch_size * data_parallel_size
-        self.last_batch_size = \
-            self.total_samples % self.micro_batch_times_data_parallel_size
+        self.micro_batch_times_data_parallel_size = self.micro_batch_size * data_parallel_size
+        self.last_batch_size = self.total_samples % self.micro_batch_times_data_parallel_size
 
         # Sanity checks.
-        assert self.total_samples > 0, \
-            'no sample to consume: {}'.format(self.total_samples)
+        assert self.total_samples > 0, 'no sample to consume: {}'.format(self.total_samples)
         assert self.micro_batch_size > 0
         assert data_parallel_size > 0
-        assert self.data_parallel_rank < data_parallel_size, \
-            'data_parallel_rank should be smaller than data size: {}, ' \
-            '{}'.format(self.data_parallel_rank, data_parallel_size)
+        assert (
+            self.data_parallel_rank < data_parallel_size
+        ), 'data_parallel_rank should be smaller than data size: {}, ' '{}'.format(
+            self.data_parallel_rank, data_parallel_size
+        )
 
     def __len__(self):
         return self.total_samples
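
A quick worked example of the bookkeeping set up in this constructor, with made-up sizes (these are not Megatron defaults):

total_samples = 10
micro_batch_size = 2
data_parallel_size = 4

# Samples consumed per global step across all data-parallel ranks.
micro_batch_times_data_parallel_size = micro_batch_size * data_parallel_size  # 8
# Size of the trailing, incomplete global batch.
last_batch_size = total_samples % micro_batch_times_data_parallel_size  # 2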
@@ -181,8 +205,9 @@ def __iter__(self):
 
         # data sharding and random sampling
         if self.data_sharding:
-            bucket_size = (self.total_samples // self.micro_batch_times_data_parallel_size) \
-                           * self.micro_batch_size
+            bucket_size = (
+                self.total_samples // self.micro_batch_times_data_parallel_size
+            ) * self.micro_batch_size
             bucket_offset = current_epoch_samples // self.data_parallel_size
             start_idx = self.data_parallel_rank * bucket_size
 
@@ -191,15 +216,13 @@ def __iter__(self):
             random_idx = torch.randperm(bucket_size, generator=g).tolist()
             idx_range = [start_idx + x for x in random_idx[bucket_offset:]]
         else:
-            full_bucket_size = (self.total_samples // self.micro_batch_size) \
-                                * self.micro_batch_size
+            full_bucket_size = (self.total_samples // self.micro_batch_size) * self.micro_batch_size
             full_bucket_offset = current_epoch_samples
             g = torch.Generator()
             g.manual_seed(self.epoch)
-            idx_range_total = \
-                torch.randperm(full_bucket_size, generator=g).tolist()
+            idx_range_total = torch.randperm(full_bucket_size, generator=g).tolist()
             idx_range_active = idx_range_total[full_bucket_offset:]
-            idx_range = idx_range_active[self.data_parallel_rank::self.data_parallel_size]
+            idx_range = idx_range_active[self.data_parallel_rank :: self.data_parallel_size]
 
         batch = []
         # Last batch if not complete will be dropped.
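
To make the index arithmetic in the last two hunks concrete, here is a small standalone sketch of both selection paths in MegatronPretrainingRandomSampler.__iter__ (data sharding on vs. off). The sizes, rank, and consumed-sample count are illustrative values only.

import torch

total_samples = 20
micro_batch_size = 2
data_parallel_size = 2
data_parallel_rank = 1
current_epoch_samples = 4  # samples already consumed in the current epoch
epoch = 0
micro_batch_times_data_parallel_size = micro_batch_size * data_parallel_size

g = torch.Generator()
g.manual_seed(epoch)

# Sharded path: each rank shuffles within its own contiguous bucket.
bucket_size = (total_samples // micro_batch_times_data_parallel_size) * micro_batch_size  # 10
bucket_offset = current_epoch_samples // data_parallel_size  # 2
start_idx = data_parallel_rank * bucket_size  # 10
random_idx = torch.randperm(bucket_size, generator=g).tolist()
sharded_idx_range = [start_idx + x for x in random_idx[bucket_offset:]]

# Non-sharded path: shuffle the full bucket, then stride-slice by rank.
full_bucket_size = (total_samples // micro_batch_size) * micro_batch_size  # 20
full_bucket_offset = current_epoch_samples
g.manual_seed(epoch)
idx_range_total = torch.randperm(full_bucket_size, generator=g).tolist()
idx_range_active = idx_range_total[full_bucket_offset:]
strided_idx_range = idx_range_active[data_parallel_rank::data_parallel_size]

print(sharded_idx_range)   # indices drawn only from this rank's bucket, [10, 20)
print(strided_idx_range)   # every data_parallel_size-th surviving shuffled index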
