@@ -4,24 +4,27 @@
 
 
 import random
-import torch
+
 import numpy as np
+import torch
 from torch.utils.data import Dataset
-from megatron.training import get_args
+
 from megatron.core import mpu
 from megatron.core.datasets.utils import Split
+from megatron.training import get_args
 from megatron.training.dist_signal_handler import DistributedSignalHandler
 
+
 def build_pretraining_data_loader(dataset, consumed_samples):
     """Build dataloader given an input dataset."""
 
     if dataset is None:
         return None
     args = get_args()
-
-    if hasattr(dataset,'split'):
+
+    if hasattr(dataset, 'split'):
         split = dataset.split
-    elif hasattr(dataset,'index_split'):
+    elif hasattr(dataset, 'index_split'):
         split = dataset.index_split
     else:
         split = None
@@ -32,15 +35,17 @@ def build_pretraining_data_loader(dataset, consumed_samples):
             consumed_samples=0,
             micro_batch_size=args.micro_batch_size,
             data_parallel_rank=mpu.get_data_parallel_rank(),
-            data_parallel_size=mpu.get_data_parallel_world_size())
+            data_parallel_size=mpu.get_data_parallel_world_size(),
+        )
     elif args.dataloader_type == 'single':
         # Megatron sampler
         batch_sampler = MegatronPretrainingSampler(
             total_samples=len(dataset),
             consumed_samples=consumed_samples,
             micro_batch_size=args.micro_batch_size,
             data_parallel_rank=mpu.get_data_parallel_rank(),
-            data_parallel_size=mpu.get_data_parallel_world_size())
+            data_parallel_size=mpu.get_data_parallel_world_size(),
+        )
     elif args.dataloader_type == 'cyclic':
         batch_sampler = MegatronPretrainingRandomSampler(
             dataset,
@@ -49,51 +54,63 @@ def build_pretraining_data_loader(dataset, consumed_samples):
             micro_batch_size=args.micro_batch_size,
             data_parallel_rank=mpu.get_data_parallel_rank(),
             data_parallel_size=mpu.get_data_parallel_world_size(),
-            data_sharding=args.data_sharding)
+            data_sharding=args.data_sharding,
+        )
     elif args.dataloader_type == "external":
         # External dataloaders are passed through. User is expected to provide a
         # torch-compatible dataloader and define samplers, if needed.
         return dataset
     else:
-        raise Exception('{} dataloader type is not supported.'.format(
-            args.dataloader_type))
+        raise Exception('{} dataloader type is not supported.'.format(args.dataloader_type))
 
     def worker_init_fn(_):
         DistributedSignalHandler(args.exit_signal).__enter__()
-    maybe_worker_init_fn = worker_init_fn if args.exit_signal_handler and args.num_workers > 0 else None
+
+    maybe_worker_init_fn = (
+        worker_init_fn if args.exit_signal_handler and args.num_workers > 0 else None
+    )
     # Torch dataloader.
-    return torch.utils.data.DataLoader(dataset,
-                                       batch_sampler=batch_sampler,
-                                       num_workers=args.num_workers,
-                                       pin_memory=True,
-                                       persistent_workers=True if args.num_workers > 0 else False,
-                                       worker_init_fn=maybe_worker_init_fn,
-                                       )
+    return torch.utils.data.DataLoader(
+        dataset,
+        batch_sampler=batch_sampler,
+        num_workers=args.num_workers,
+        pin_memory=True,
+        persistent_workers=True if args.num_workers > 0 else False,
+        worker_init_fn=maybe_worker_init_fn,
+    )
+
 
 class MegatronPretrainingSampler:
 
-    def __init__(self, total_samples, consumed_samples, micro_batch_size,
-                 data_parallel_rank, data_parallel_size, drop_last=True):
+    def __init__(
+        self,
+        total_samples,
+        consumed_samples,
+        micro_batch_size,
+        data_parallel_rank,
+        data_parallel_size,
+        drop_last=True,
+    ):
         # Keep a copy of input params for later use.
         self.total_samples = total_samples
         self.consumed_samples = consumed_samples
         self.micro_batch_size = micro_batch_size
         self.data_parallel_rank = data_parallel_rank
-        self.micro_batch_times_data_parallel_size = \
-            self.micro_batch_size * data_parallel_size
+        self.micro_batch_times_data_parallel_size = self.micro_batch_size * data_parallel_size
         self.drop_last = drop_last
 
         # Sanity checks.
-        assert self.total_samples > 0, \
-            'no sample to consume: {}'.format(self.total_samples)
-        assert self.consumed_samples < self.total_samples, \
-            'no samples left to consume: {}, {}'.format(self.consumed_samples,
-                                                        self.total_samples)
+        assert self.total_samples > 0, 'no sample to consume: {}'.format(self.total_samples)
+        assert (
+            self.consumed_samples < self.total_samples
+        ), 'no samples left to consume: {}, {}'.format(self.consumed_samples, self.total_samples)
         assert self.micro_batch_size > 0
         assert data_parallel_size > 0
-        assert self.data_parallel_rank < data_parallel_size, \
-            'data_parallel_rank should be smaller than data size: {}, ' \
-            '{}'.format(self.data_parallel_rank, data_parallel_size)
+        assert (
+            self.data_parallel_rank < data_parallel_size
+        ), 'data_parallel_rank should be smaller than data size: {}, ' '{}'.format(
+            self.data_parallel_rank, data_parallel_size
+        )
 
     def __len__(self):
         return self.total_samples
@@ -143,8 +160,16 @@ def __getitem__(self, idx):
 
 class MegatronPretrainingRandomSampler:
 
-    def __init__(self, dataset, total_samples, consumed_samples, micro_batch_size,
-                 data_parallel_rank, data_parallel_size, data_sharding):
+    def __init__(
+        self,
+        dataset,
+        total_samples,
+        consumed_samples,
+        micro_batch_size,
+        data_parallel_rank,
+        data_parallel_size,
+        data_sharding,
+    ):
         # Keep a copy of input params for later use.
         self.dataset = dataset
         self.total_samples = total_samples
@@ -153,19 +178,18 @@ def __init__(self, dataset, total_samples, consumed_samples, micro_batch_size,
         self.data_parallel_rank = data_parallel_rank
         self.data_parallel_size = data_parallel_size
         self.data_sharding = data_sharding
-        self.micro_batch_times_data_parallel_size = \
-            self.micro_batch_size * data_parallel_size
-        self.last_batch_size = \
-            self.total_samples % self.micro_batch_times_data_parallel_size
+        self.micro_batch_times_data_parallel_size = self.micro_batch_size * data_parallel_size
+        self.last_batch_size = self.total_samples % self.micro_batch_times_data_parallel_size
 
         # Sanity checks.
-        assert self.total_samples > 0, \
-            'no sample to consume: {}'.format(self.total_samples)
+        assert self.total_samples > 0, 'no sample to consume: {}'.format(self.total_samples)
         assert self.micro_batch_size > 0
         assert data_parallel_size > 0
-        assert self.data_parallel_rank < data_parallel_size, \
-            'data_parallel_rank should be smaller than data size: {}, ' \
-            '{}'.format(self.data_parallel_rank, data_parallel_size)
+        assert (
+            self.data_parallel_rank < data_parallel_size
+        ), 'data_parallel_rank should be smaller than data size: {}, ' '{}'.format(
+            self.data_parallel_rank, data_parallel_size
+        )
 
     def __len__(self):
         return self.total_samples
@@ -181,8 +205,9 @@ def __iter__(self):
 
         # data sharding and random sampling
         if self.data_sharding:
-            bucket_size = (self.total_samples // self.micro_batch_times_data_parallel_size) \
-                           * self.micro_batch_size
+            bucket_size = (
+                self.total_samples // self.micro_batch_times_data_parallel_size
+            ) * self.micro_batch_size
             bucket_offset = current_epoch_samples // self.data_parallel_size
             start_idx = self.data_parallel_rank * bucket_size
 
@@ -191,15 +216,13 @@ def __iter__(self):
             random_idx = torch.randperm(bucket_size, generator=g).tolist()
             idx_range = [start_idx + x for x in random_idx[bucket_offset:]]
         else:
-            full_bucket_size = (self.total_samples // self.micro_batch_size) \
-                                * self.micro_batch_size
+            full_bucket_size = (self.total_samples // self.micro_batch_size) * self.micro_batch_size
             full_bucket_offset = current_epoch_samples
             g = torch.Generator()
             g.manual_seed(self.epoch)
-            idx_range_total = \
-                torch.randperm(full_bucket_size, generator=g).tolist()
+            idx_range_total = torch.randperm(full_bucket_size, generator=g).tolist()
             idx_range_active = idx_range_total[full_bucket_offset:]
-            idx_range = idx_range_active[self.data_parallel_rank::self.data_parallel_size]
+            idx_range = idx_range_active[self.data_parallel_rank :: self.data_parallel_size]
 
         batch = []
         # Last batch if not complete will be dropped.
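
For readers skimming the diff, the sharding arithmetic in MegatronPretrainingRandomSampler.__iter__ is easier to follow outside the +/- noise. Below is a minimal standalone sketch of the data_sharding branch only; the function name, the explicit epoch argument, and the example values are illustrative assumptions and not part of this change (the real sampler derives the epoch and the per-epoch offset from consumed_samples).

# Standalone sketch (illustrative names) of the sharded shuffling done in
# MegatronPretrainingRandomSampler.__iter__ when data_sharding is enabled.
import torch

def sharded_shuffled_indices(total_samples, consumed_samples, micro_batch_size,
                             data_parallel_rank, data_parallel_size, epoch):
    """Return the shuffled sample indices owned by one data-parallel rank."""
    micro_batch_times_dp = micro_batch_size * data_parallel_size
    # Samples that do not fill a complete global batch are dropped.
    active_total_samples = total_samples - total_samples % micro_batch_times_dp
    current_epoch_samples = consumed_samples % active_total_samples

    # Each rank owns a contiguous bucket of the dataset and shuffles inside it.
    bucket_size = (total_samples // micro_batch_times_dp) * micro_batch_size
    bucket_offset = current_epoch_samples // data_parallel_size
    start_idx = data_parallel_rank * bucket_size

    g = torch.Generator()
    g.manual_seed(epoch)
    random_idx = torch.randperm(bucket_size, generator=g).tolist()
    # Skip what this rank already consumed in the current epoch, then shift
    # the remaining shuffled positions into the rank's bucket.
    return [start_idx + x for x in random_idx[bucket_offset:]]

# Example: 32 samples, micro-batch 4, 2 data-parallel ranks, fresh epoch.
print(sharded_shuffled_indices(32, 0, 4, data_parallel_rank=0,
                               data_parallel_size=2, epoch=0)[:8])

Because every rank seeds its generator with the same epoch, ranks produce consistent permutations over disjoint buckets, so no sample is visited by two ranks in the same epoch.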