11 changes: 7 additions & 4 deletions megatron/legacy/data/biencoder_dataset_utils.py
@@ -5,11 +5,14 @@
import numpy as np
import torch

from megatron.training import get_args, get_tokenizer, print_rank_0
from megatron.core import mpu, tensor_parallel
from megatron.legacy.data.dataset_utils import create_masked_lm_predictions, \
pad_and_convert_to_numpy
from megatron.legacy.data.data_samplers import MegatronPretrainingSampler
from megatron.legacy.data.dataset_utils import (
create_masked_lm_predictions,
pad_and_convert_to_numpy,
)
from megatron.training import get_args, get_tokenizer, print_rank_0
from megatron.training.datasets.data_samplers import MegatronPretrainingSampler


def make_attention_mask(source_block, target_block):
"""
14 changes: 8 additions & 6 deletions megatron/legacy/data/vit_dataset.py
@@ -1,15 +1,17 @@
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
import os
import random

import numpy as np
import torch
import torchvision.transforms as T
from PIL import Image, ImageFilter, ImageOps
from torchvision import datasets
from megatron.training import get_args
from megatron.legacy.data.image_folder import ImageFolder

from megatron.legacy.data.autoaugment import ImageNetPolicy
from megatron.legacy.data.data_samplers import RandomSeedDataset
from PIL import Image, ImageFilter, ImageOps
from megatron.legacy.data.image_folder import ImageFolder
from megatron.training import get_args
from megatron.training.datasets.data_samplers import RandomSeedDataset


class GaussianBlur(object):
@@ -236,14 +238,14 @@ def build_train_valid_datasets(data_path, image_size=224):
classes_fraction=args.classes_fraction,
data_per_class_fraction=args.data_per_class_fraction
)
train_data = RandomSeedDataset(train_data)
train_data = RandomSeedDataset(train_data, args.seed)

# validation dataset
val_data_path = data_path[1]
val_data = ImageFolder(
root=val_data_path,
transform=val_transform
)
val_data = RandomSeedDataset(val_data)
val_data = RandomSeedDataset(val_data, args.seed)

return train_data, val_data
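
For reference, a minimal usage sketch of the wrapper after this change. ToyDataset and the epoch loop are hypothetical, and the import path follows the new megatron.training.datasets.data_samplers location used above; the sketch only illustrates that the base seed is now passed in explicitly rather than read from get_args() inside the wrapper.

import torch
from torch.utils.data import Dataset

from megatron.training.datasets.data_samplers import RandomSeedDataset


class ToyDataset(Dataset):
    # Hypothetical dataset: any torch/numpy/random calls made while fetching a
    # sample see the per-sample seed installed by the wrapper.
    def __len__(self):
        return 4

    def __getitem__(self, idx):
        return torch.rand(2)


data = RandomSeedDataset(ToyDataset(), 1234)  # seed passed explicitly, as in the diff above
for epoch in range(2):
    data.set_epoch(epoch)                     # vary per-sample seeds between epochs
    samples = [data[i] for i in range(len(data))]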
7 changes: 5 additions & 2 deletions megatron/training/arguments.py
@@ -9,7 +9,6 @@
from pathlib import Path
import re
import types
import warnings

import torch
import torch.nn.functional as F
@@ -35,6 +34,7 @@
)
from megatron.core.activations import squared_relu
from megatron.core.fusions.fused_bias_geglu import quick_gelu
from megatron.training.dist_signal_handler import SIGNAL_MAP
from megatron.training.utils import (
get_device_arch_version,
update_use_dist_ckpt,
@@ -2159,7 +2159,10 @@ def _add_training_args(parser):
help='Exit the program after this many minutes.')
group.add_argument('--exit-signal-handler', action='store_true',
help='Dynamically save the checkpoint and shutdown the '
'training if SIGTERM is received')
'training if the exit signal is received')
group.add_argument('--exit-signal', type=str, default='SIGTERM',
choices=list(SIGNAL_MAP.keys()),
help='Signal to use for exit signal handler. If not specified, defaults to SIGTERM.')
group.add_argument('--tensorboard-dir', type=str, default=None,
help='Write TensorBoard logs to this directory.')
group.add_argument('--no-masked-softmax-fusion',
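For reference, a standalone sketch of how the new flag pair is meant to work: --exit-signal accepts one of the SIGNAL_MAP keys (SIGNAL_MAP is defined in dist_signal_handler.py below) and resolves to a concrete signal. The parse_args invocation with explicit values is hypothetical, and SIGUSR1/SIGUSR2 assume a POSIX platform.

import argparse
import signal

SIGNAL_MAP = {
    'SIGTERM': signal.SIGTERM,
    'SIGINT': signal.SIGINT,
    'SIGUSR1': signal.SIGUSR1,
    'SIGUSR2': signal.SIGUSR2,
}

parser = argparse.ArgumentParser()
parser.add_argument('--exit-signal-handler', action='store_true',
                    help='Dynamically save the checkpoint and shutdown the '
                         'training if the exit signal is received')
parser.add_argument('--exit-signal', type=str, default='SIGTERM',
                    choices=list(SIGNAL_MAP.keys()),
                    help='Signal to use for exit signal handler.')

# Hypothetical invocation: request a graceful exit on SIGUSR1 instead of SIGTERM.
args = parser.parse_args(['--exit-signal-handler', '--exit-signal', 'SIGUSR1'])
assert SIGNAL_MAP[args.exit_signal] is signal.SIGUSR1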
megatron/training/datasets/data_samplers.py
@@ -4,24 +4,28 @@


import random
import torch

import numpy as np
import torch
from torch.utils.data import Dataset
from megatron.training import get_args

from megatron.core import mpu
from megatron.core.datasets.utils import Split

from megatron.training import get_args
from megatron.training.dist_signal_handler import DistributedSignalHandler


def build_pretraining_data_loader(dataset, consumed_samples):
"""Build dataloader given an input dataset."""

if dataset is None:
return None
args = get_args()
if hasattr(dataset,'split'):

if hasattr(dataset, 'split'):
split = dataset.split
elif hasattr(dataset,'index_split'):
elif hasattr(dataset, 'index_split'):
split = dataset.index_split
else:
split = None
@@ -32,15 +36,17 @@ def build_pretraining_data_loader(dataset, consumed_samples):
consumed_samples=0,
micro_batch_size=args.micro_batch_size,
data_parallel_rank=mpu.get_data_parallel_rank(),
data_parallel_size=mpu.get_data_parallel_world_size())
data_parallel_size=mpu.get_data_parallel_world_size(),
)
elif args.dataloader_type == 'single':
# Megatron sampler
batch_sampler = MegatronPretrainingSampler(
total_samples=len(dataset),
consumed_samples=consumed_samples,
micro_batch_size=args.micro_batch_size,
data_parallel_rank=mpu.get_data_parallel_rank(),
data_parallel_size=mpu.get_data_parallel_world_size())
data_parallel_size=mpu.get_data_parallel_world_size(),
)
elif args.dataloader_type == 'cyclic':
batch_sampler = MegatronPretrainingRandomSampler(
dataset,
@@ -49,52 +55,82 @@ def __iter__(self):
micro_batch_size=args.micro_batch_size,
data_parallel_rank=mpu.get_data_parallel_rank(),
data_parallel_size=mpu.get_data_parallel_world_size(),
data_sharding=args.data_sharding)
data_sharding=args.data_sharding,
)
elif args.dataloader_type == "external":
# External dataloaders are passed through. User is expected to provide a
# torch-compatible dataloader and define samplers, if needed.
return dataset
else:
raise Exception('{} dataloader type is not supported.'.format(
args.dataloader_type))
raise Exception('{} dataloader type is not supported.'.format(args.dataloader_type))

def worker_init_fn(_):
DistributedSignalHandler(args.exit_signal).__enter__()

maybe_worker_init_fn = (
worker_init_fn if args.exit_signal_handler and args.num_workers > 0 else None
)
# Torch dataloader.
return torch.utils.data.DataLoader(dataset,
batch_sampler=batch_sampler,
num_workers=args.num_workers,
pin_memory=True,
persistent_workers=True if args.num_workers > 0 else False,
)
return torch.utils.data.DataLoader(
dataset,
batch_sampler=batch_sampler,
num_workers=args.num_workers,
pin_memory=True,
persistent_workers=True if args.num_workers > 0 else False,
worker_init_fn=maybe_worker_init_fn,
)
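
The worker_init_fn above installs the signal handler inside every DataLoader worker process. Below is a self-contained toy of that pattern; it uses a plain signal.signal handler rather than Megatron's DistributedSignalHandler, so it is only a sketch of the mechanism, and ToyDataset plus the no-op handler are assumptions.

import signal

import torch
from torch.utils.data import DataLoader, Dataset


class ToyDataset(Dataset):
    def __len__(self):
        return 8

    def __getitem__(self, idx):
        return torch.tensor(idx)


def toy_worker_init_fn(_):
    # Runs once in each freshly started worker process, so the handler it
    # installs stays active for that worker's lifetime.
    signal.signal(signal.SIGTERM, lambda signum, frame: None)


if __name__ == '__main__':
    loader = DataLoader(
        ToyDataset(),
        batch_size=2,
        num_workers=2,
        worker_init_fn=toy_worker_init_fn,
        persistent_workers=True,
    )
    for batch in loader:
        pass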


class MegatronPretrainingSampler:
"""
Sampler for Megatron pretraining dataloaders that divides data samples across
data parallel workers. Each worker receives a contiguous chunk of data determined by
its rank and the micro batch size. Supports dropping the last incomplete batch if
specified, and keeps track of total and consumed samples. Designed to work with
distributed training using Megatron's data parallelism.
"""

def __init__(self, total_samples, consumed_samples, micro_batch_size,
data_parallel_rank, data_parallel_size, drop_last=True):
def __init__(
self,
total_samples,
consumed_samples,
micro_batch_size,
data_parallel_rank,
data_parallel_size,
drop_last=True,
):
# Keep a copy of input params for later use.
self.total_samples = total_samples
self.consumed_samples = consumed_samples
self.micro_batch_size = micro_batch_size
self.data_parallel_rank = data_parallel_rank
self.micro_batch_times_data_parallel_size = \
self.micro_batch_size * data_parallel_size
self.micro_batch_times_data_parallel_size = self.micro_batch_size * data_parallel_size
self.drop_last = drop_last

# Sanity checks.
assert self.total_samples > 0, \
'no sample to consume: {}'.format(self.total_samples)
assert self.consumed_samples < self.total_samples, \
'no samples left to consume: {}, {}'.format(self.consumed_samples,
self.total_samples)
assert self.total_samples > 0, 'no sample to consume: {}'.format(self.total_samples)
assert (
self.consumed_samples < self.total_samples
), 'no samples left to consume: {}, {}'.format(self.consumed_samples, self.total_samples)
assert self.micro_batch_size > 0
assert data_parallel_size > 0
assert self.data_parallel_rank < data_parallel_size, \
'data_parallel_rank should be smaller than data size: {}, ' \
'{}'.format(self.data_parallel_rank, data_parallel_size)
assert (
self.data_parallel_rank < data_parallel_size
), 'data_parallel_rank should be smaller than data size: {}, ' '{}'.format(
self.data_parallel_rank, data_parallel_size
)

def __len__(self):
return self.total_samples

def get_start_end_idx(self):
"""
Calculate the start and end indices for the current data parallel worker's
chunk within a batch.

Returns:
tuple: (start_idx, end_idx) indicating the slice of the batch for this worker.
"""
start_idx = self.data_parallel_rank * self.micro_batch_size
end_idx = start_idx + self.micro_batch_size
return start_idx, end_idx
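
A quick worked example of the slicing above, with hypothetical sizes: micro_batch_size=4 and two data-parallel ranks, so each rank takes a disjoint, contiguous micro-batch-sized slice of the global batch.

micro_batch_size = 4
data_parallel_size = 2

for data_parallel_rank in range(data_parallel_size):
    start_idx = data_parallel_rank * micro_batch_size
    end_idx = start_idx + micro_batch_size
    print(data_parallel_rank, start_idx, end_idx)
# 0 0 4
# 1 4 8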
@@ -116,17 +152,37 @@ def __iter__(self):


class RandomSeedDataset(Dataset):
"""
A dataset wrapper that resets the random seed before each sample.

def __init__(self, dataset):
args = get_args()
self.base_seed = args.seed
self.curr_seed = args.seed
This ensures deterministic behavior per sample by setting the RNG state
for torch, numpy, and random before accessing each underlying data sample.
The base seed is retrieved from training arguments, and can be varied per epoch
using the set_epoch method to ensure different shuffling or augmentation each epoch.

Args:
dataset: The underlying dataset to wrap.

Methods:
set_epoch(epoch): Change the seed offset so each epoch produces different randomization.
__getitem__(idx): Sets the seed based on the sample index and current epoch.
"""

def __init__(self, dataset, seed):
self.base_seed = seed
self.curr_seed = seed
self.dataset = dataset

def __len__(self):
return len(self.dataset)

def set_epoch(self, epoch):
"""
Change the seed offset so each epoch produces different randomization.

Args:
epoch: The epoch number to use as the seed offset.
"""
self.curr_seed = self.base_seed + epoch

def __getitem__(self, idx):
@@ -138,9 +194,23 @@ def __getitem__(self, idx):


class MegatronPretrainingRandomSampler:
"""
Sampler for Megatron pretraining dataloaders that performs random sampling
across data parallel workers. Supports data sharding to divide the dataset
into buckets and shuffle within each bucket. Designed to work with distributed
training using Megatron's data parallelism.
"""

def __init__(self, dataset, total_samples, consumed_samples, micro_batch_size,
data_parallel_rank, data_parallel_size, data_sharding):
def __init__(
self,
dataset,
total_samples,
consumed_samples,
micro_batch_size,
data_parallel_rank,
data_parallel_size,
data_sharding,
):
# Keep a copy of input params for later use.
self.dataset = dataset
self.total_samples = total_samples
Expand All @@ -149,19 +219,18 @@ def __init__(self, dataset, total_samples, consumed_samples, micro_batch_size,
self.data_parallel_rank = data_parallel_rank
self.data_parallel_size = data_parallel_size
self.data_sharding = data_sharding
self.micro_batch_times_data_parallel_size = \
self.micro_batch_size * data_parallel_size
self.last_batch_size = \
self.total_samples % self.micro_batch_times_data_parallel_size
self.micro_batch_times_data_parallel_size = self.micro_batch_size * data_parallel_size
self.last_batch_size = self.total_samples % self.micro_batch_times_data_parallel_size

# Sanity checks.
assert self.total_samples > 0, \
'no sample to consume: {}'.format(self.total_samples)
assert self.total_samples > 0, 'no sample to consume: {}'.format(self.total_samples)
assert self.micro_batch_size > 0
assert data_parallel_size > 0
assert self.data_parallel_rank < data_parallel_size, \
'data_parallel_rank should be smaller than data size: {}, ' \
'{}'.format(self.data_parallel_rank, data_parallel_size)
assert (
self.data_parallel_rank < data_parallel_size
), 'data_parallel_rank should be smaller than data size: {}, ' '{}'.format(
self.data_parallel_rank, data_parallel_size
)

def __len__(self):
return self.total_samples
@@ -177,8 +246,9 @@ def __iter__(self):

# data sharding and random sampling
if self.data_sharding:
bucket_size = (self.total_samples // self.micro_batch_times_data_parallel_size) \
* self.micro_batch_size
bucket_size = (
self.total_samples // self.micro_batch_times_data_parallel_size
) * self.micro_batch_size
bucket_offset = current_epoch_samples // self.data_parallel_size
start_idx = self.data_parallel_rank * bucket_size

@@ -187,15 +257,13 @@ def __iter__(self):
random_idx = torch.randperm(bucket_size, generator=g).tolist()
idx_range = [start_idx + x for x in random_idx[bucket_offset:]]
else:
full_bucket_size = (self.total_samples // self.micro_batch_size) \
* self.micro_batch_size
full_bucket_size = (self.total_samples // self.micro_batch_size) * self.micro_batch_size
full_bucket_offset = current_epoch_samples
g = torch.Generator()
g.manual_seed(self.epoch)
idx_range_total = \
torch.randperm(full_bucket_size, generator=g).tolist()
idx_range_total = torch.randperm(full_bucket_size, generator=g).tolist()
idx_range_active = idx_range_total[full_bucket_offset:]
idx_range = idx_range_active[self.data_parallel_rank::self.data_parallel_size]
idx_range = idx_range_active[self.data_parallel_rank :: self.data_parallel_size]

batch = []
# Last batch if not complete will be dropped.
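
A standalone numeric sketch of the data-sharding branch above. The sizes (total_samples=100, micro_batch_size=4, data_parallel_size=2) are hypothetical, and current_epoch_samples/epoch are assumed to be zero because that bookkeeping is elided from this hunk; each rank owns a contiguous bucket of the dataset and shuffles only within it, seeded by the epoch.

import torch

total_samples = 100
micro_batch_size = 4
data_parallel_size = 2
current_epoch_samples = 0
epoch = 0

micro_batch_times_dp = micro_batch_size * data_parallel_size               # 8
bucket_size = (total_samples // micro_batch_times_dp) * micro_batch_size   # 48
bucket_offset = current_epoch_samples // data_parallel_size                # 0

for data_parallel_rank in range(data_parallel_size):
    start_idx = data_parallel_rank * bucket_size                           # 0 or 48
    g = torch.Generator()
    g.manual_seed(epoch)
    random_idx = torch.randperm(bucket_size, generator=g).tolist()
    idx_range = [start_idx + x for x in random_idx[bucket_offset:]]
    # rank 0 shuffles indices 0..47, rank 1 shuffles indices 48..95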
10 changes: 8 additions & 2 deletions megatron/training/dist_signal_handler.py
@@ -3,6 +3,12 @@

import torch

SIGNAL_MAP = {
'SIGTERM': signal.SIGTERM,
'SIGINT': signal.SIGINT,
'SIGUSR1': signal.SIGUSR1,
'SIGUSR2': signal.SIGUSR2
}

def get_world_size():
if torch.distributed.is_available() and torch.distributed.is_initialized():
@@ -49,8 +55,8 @@ def all_gather_item(item, dtype, group=None, async_op=False, local_rank=None):


class DistributedSignalHandler:
def __init__(self, sig=signal.SIGTERM):
self.sig = sig
def __init__(self, sig: str = 'SIGTERM'):
self.sig = SIGNAL_MAP.get(sig, signal.SIGTERM)

def signals_received(self):
all_received = all_gather_item(
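
A minimal usage sketch of the string-based constructor, assuming the handler keeps its existing context-manager interface (the worker_init_fn in data_samplers relies on __enter__) and that signals_received() returns one flag per rank; the training-loop body is hypothetical.

from megatron.training.dist_signal_handler import DistributedSignalHandler

handler = DistributedSignalHandler('SIGUSR1')  # unknown names fall back to SIGTERM via SIGNAL_MAP.get
with handler:
    for iteration in range(1000):
        ...  # hypothetical training step
        if any(handler.signals_received()):
            ...  # hypothetical: save a checkpoint, then exit cleanly
            break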