Changes from all commits
25 commits
4cf3f4b
Modifying loss checking to support bf16. (#92)
jomayeri Nov 3, 2022
238ab6d
Updated to Curated acpt env and removed deepspeed install from github…
savitamittal1 Nov 7, 2022
c685fb5
Fix the bug of FusedLayerNorm on ROCm (#96)
hubertlu-tw Nov 18, 2022
f7ebcad
Fix generate_text.sh Megatron text-generation example working w/ DS i…
lekurile Dec 7, 2022
5e8d578
DeepSpeed Data Efficiency Library pretraining examples (#100)
conglongli Dec 12, 2022
515798f
fix floating point in script (#101)
conglongli Dec 14, 2022
9a52a2e
added unit test for megatron (#102)
mrwyattii Dec 19, 2022
789f2a9
fix deprecated numpy types (#103)
mrwyattii Dec 20, 2022
c240204
fix script typo
conglongli Dec 23, 2022
178beeb
Use accelerator for supporting other device (#104)
inkcherry Jan 30, 2023
57e6439
data efficiency example update (#113)
conglongli Feb 21, 2023
7bbd7f0
fix torch six import error (#119)
ajindal1 Mar 15, 2023
d3b401d
fix(uitls.py): add () after device_name (#130)
gfzum Apr 29, 2023
798b303
Remove duplicate content in README (#131)
sbmaruf May 1, 2023
3ed9f4f
fix a bug when run on bf16+pp (#134)
ys950902 May 5, 2023
ec70e81
Add a space to fix tab error for loss_scale (#136)
ys950902 May 6, 2023
1f640c0
fix(training.py): logical bug in eval_iters_calculation (#138)
2catycm May 10, 2023
6951ba2
passing num_experts to BERT and T5 language models (#142)
Jun 8, 2023
d5c822e
fix integer overflow in the cpp implementation of build_sample_idx (#…
jiayulu Jun 8, 2023
7491937
Fixed an enum value comparision error (#132)
rraminen Jun 8, 2023
958b47f
scr: initial integration
adammoody Oct 27, 2022
a7ef874
set variables for well-defined MPI environment
adammoody Oct 27, 2022
5fcfb1b
scr: enable scr.should_exit() to stop before allocation ends
adammoody Feb 7, 2023
fd304a5
add --scr-current to specify a checkpoint tag to load
adammoody Feb 16, 2023
3610f96
scr: improve comments
adammoody Jun 21, 2023
1 change: 1 addition & 0 deletions MANIFEST.in
@@ -1,2 +1,3 @@
include megatron/data/Makefile
include megatron/data/helpers.cpp
recursive-include megatron/fused_kernels *.cpp *.h *.cu *.tr *.cuh *.cc
1 change: 0 additions & 1 deletion README.md
@@ -36,7 +36,6 @@ All the cases from 1 billion to 1 trillion parameters achieve more than 43% half
* [Data Preprocessing](#data-preprocessing)
* [BERT Pretraining](#bert-pretraining)
* [GPT Pretraining](#gpt-pretraining)
-* [GPT Pretraining](#gpt-pretraining)
* [T5 Pretraining](#t5-pretraining)
* [Distributed Pretraining](#distributed-pretraining)
* [GPT-3 Example](#gpt-3-example)
5 changes: 5 additions & 0 deletions examples/README.md
@@ -12,9 +12,14 @@ If you have a custom infrastructure (e.g. HPC clusters) or Azure VM and VMSS bas

Please see the ```MoE``` folder for different training recipes and scripts for Mixture-of-Experts based models and dense models. These recipes are for GPT-style NLG models.

### Data Efficiency

The ```data_efficiency``` folder includes GPT-3 and BERT pretraining examples for the DeepSpeed Data Efficiency Library. Please refer to the detailed tutorials in data_efficiency/README.md.

### Curriculum Learning

Curriculum learning recipes are in the ```curriculum_learning``` folder. Please refer to the detailed tutorials linked inside. These recipes are for GPT-style NLG models.
Note that the DeepSpeed Data Efficiency Library above includes more general curriculum learning support. This legacy curriculum learning feature is still compatible, but we recommend using the DeepSpeed Data Efficiency Library instead.

### Model Compression

13 changes: 2 additions & 11 deletions examples/azureml/Dockerfile.dockerfile
@@ -1,14 +1,5 @@
-FROM mcr.microsoft.com/azureml/aifx/stable-ubuntu2004-cu115-py38-torch1110

+FROM mcr.microsoft.com/azureml/curated/acpt-pytorch-1.11-py38-cuda11.5-gpu
USER root:root

-RUN pip install pybind11
-
-RUN pip install git+https://github.com/microsoft/DeepSpeed.git
-
-# add a100-topo.xml
-RUN mkdir -p /opt/microsoft/
-RUN wget -O /opt/microsoft/a100-topo.xml https://hpcbenchmarks.blob.core.windows.net/bookcorpus/data/a100-topo.xml
-
-# to use on A100, enable env var below in your job
-ENV NCCL_TOPO_FILE="/opt/microsoft/a100-topo.xml"
+RUN pip install regex
18 changes: 10 additions & 8 deletions examples/azureml/aml_submit.py
@@ -8,6 +8,7 @@
from azureml.core.compute import ComputeTarget, AmlCompute
from azureml.core.compute_target import ComputeTargetException
from azureml.core.runconfig import PyTorchConfiguration
+from azureml.core.environment import DockerBuildContext

# Check core SDK version number
print("SDK version:", azureml.core.VERSION)
@@ -64,15 +65,9 @@
#-------------------------------------------------------------------------------
# Setup training environment
#-------------------------------------------------------------------------------
-megatron_ds_env = Environment.from_dockerfile(name='megatron-ds-ptca', dockerfile='Dockerfile.dockerfile')
-megatron_ds_env.register(ws).build(ws).wait_for_completion() # Comment this out if environment already exists

-megatron_ds_env.environment_variables['NCCL_DEBUG'] = 'WARN'
-megatron_ds_env.environment_variables['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
-megatron_ds_env.environment_variables['NCCL_SOCKET_IFNAME'] = 'eth0'
-megatron_ds_env.environment_variables['NCCL_IB_PCI_RELAXED_ORDERING']='1'
-megatron_ds_env.environment_variables['UCX_TLS']='tcp'
-megatron_ds_env.environment_variables['UCX_NET_DEVICES']='eth0'
+megatron_ds_env = Environment.from_docker_build_context(name='megatron-ds-curated-acpt', docker_build_context=DockerBuildContext.from_local_directory(workspace = ws, path = '.', dockerfile_path='Dockerfile.dockerfile'))
+megatron_ds_env.register(ws).build(ws).wait_for_completion() # Comment this out if environment already exists

#-------------------------------------------------------------------------------
# Training Settings and Arguments
@@ -187,6 +182,13 @@
environment=megatron_ds_env,
distributed_job_config=distr_config)

+megatron_ds_src.run_config.environment_variables['NCCL_DEBUG'] = 'WARN'
+megatron_ds_src.run_config.environment_variables['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
+megatron_ds_src.run_config.environment_variables['NCCL_SOCKET_IFNAME'] = 'eth0'
+megatron_ds_src.run_config.environment_variables['NCCL_IB_PCI_RELAXED_ORDERING']='1'
+megatron_ds_src.run_config.environment_variables['UCX_TLS']='tcp'
+megatron_ds_src.run_config.environment_variables['UCX_NET_DEVICES']='eth0'

#-------------------------------------------------------------------------------
# Submit experiment
#-------------------------------------------------------------------------------
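For readers following this file's changes: the job now builds its image from the local Dockerfile via a Docker build context (instead of `Environment.from_dockerfile`) and sets the NCCL/UCX variables on the run configuration rather than on the environment. Below is a minimal, hedged sketch of that overall flow using the AzureML SDK v1 calls that appear in this file; the workspace lookup, compute target, entry script, arguments, and process counts are placeholders rather than values from this PR.

```python
# Minimal sketch (not the full aml_submit.py): build the curated-ACPT-based
# environment from the local Dockerfile and attach per-run env vars to the
# ScriptRunConfig. Workspace/compute/script names are placeholders.
from azureml.core import Workspace, Experiment, Environment, ScriptRunConfig
from azureml.core.environment import DockerBuildContext
from azureml.core.runconfig import PyTorchConfiguration

ws = Workspace.from_config()  # assumes a local config.json for the workspace

megatron_ds_env = Environment.from_docker_build_context(
    name='megatron-ds-curated-acpt',
    docker_build_context=DockerBuildContext.from_local_directory(
        workspace=ws, path='.', dockerfile_path='Dockerfile.dockerfile'))
megatron_ds_env.register(ws).build(ws).wait_for_completion()  # skip if already built

distr_config = PyTorchConfiguration(process_count=16, node_count=2)  # placeholder sizes

megatron_ds_src = ScriptRunConfig(source_directory='../../',
                                  script='pretrain_gpt.py',        # placeholder entry script
                                  arguments=[],                    # Megatron/DeepSpeed args go here
                                  compute_target='a100-cluster',   # placeholder compute target
                                  environment=megatron_ds_env,
                                  distributed_job_config=distr_config)

# Per-run NCCL/UCX settings, mirroring the lines added in the diff above.
megatron_ds_src.run_config.environment_variables.update({
    'NCCL_DEBUG': 'WARN',
    'CUDA_DEVICE_ORDER': 'PCI_BUS_ID',
    'NCCL_SOCKET_IFNAME': 'eth0',
    'NCCL_IB_PCI_RELAXED_ORDERING': '1',
    'UCX_TLS': 'tcp',
    'UCX_NET_DEVICES': 'eth0',
})

run = Experiment(ws, name='megatron-deepspeed').submit(megatron_ds_src)
print(run.get_portal_url())
```

The actual aml_submit.py derives its sizes and arguments from its own settings and passes a full Megatron/DeepSpeed argument list; the sketch only shows how the pieces changed in this diff fit together.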
4 changes: 2 additions & 2 deletions examples/bert_with_pile/prepare_pile_data.py
@@ -2,9 +2,9 @@
import sys
import time
import os

-import sys
-sys.path.insert(1, '../../')
+sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
+os.path.pardir,os.path.pardir)))
from megatron.data import indexed_dataset

def pile_download(download_url, file_path, i):
23 changes: 23 additions & 0 deletions examples/data_efficiency/README.md
@@ -0,0 +1,23 @@
This directory includes GPT-3 and BERT pretraining example scripts for the DeepSpeed Data Efficiency Library techniques (curriculum learning, random-LTD, and the two composed together).

You need to install an updated DeepSpeed version (>=0.8.0), which contains the DeepSpeed Data Efficiency Library.

An additional tutorial can be found on the [DeepSpeed website](https://www.deepspeed.ai/tutorials/data-efficiency/).

Additional technical details can be found in our [random-LTD paper](https://arxiv.org/abs/2211.11586) and [data efficiency paper](https://arxiv.org/abs/2212.03597).

## GPT-3 pretraining and evaluation
Inside the ``gpt`` folder, ``ds_analyze_gpt_data_map.sh`` and ``ds_analyze_gpt_data_reduce.sh`` are first used for curriculum learning's offline data analysis and indexing.

``gpt/pretrain`` includes the pretraining example scripts. You can choose a setup to run by uncommenting one block in ``ds_pretrain_gpt_1.3B_dense_run.sh``. One thing to note is that in our [random-LTD paper](https://arxiv.org/abs/2211.11586) we did not scale the peak learning rate when using less than 100% of the data, while in our later [data efficiency paper](https://arxiv.org/abs/2212.03597) we found that scaling the LR based on the percentage of data used helps improve model quality.

``gpt/eval`` includes the zero-/few-shot evaluation example scripts. ``ds_evalharness_parallel_run.sh`` is for zero-shot, and ``ds_evalharness_parallel_run_10shot.sh`` is for 10-shot.

## BERT pretraining and finetuning
Inside the ``bert`` folder, ``pile_data_download_preprocess.py`` can first be used to download and preprocess the public Pile dataset.

The ``ds_analyze_bert_data_map.sh`` and ``ds_analyze_bert_data_reduce.sh`` are used for curriculum learning's offline data analysis and indexing.

``bert/pretrain`` includes the pretraining example scripts. You can choose a setup to run by uncommenting one block in ``ds_pretrain_bert_336M_run.sh``. One thing to note is that in our [random-LTD paper](https://arxiv.org/abs/2211.11586) we did not scale the peak learning rate when using less than 100% of the data, while in our later [data efficiency paper](https://arxiv.org/abs/2212.03597) we found that scaling the LR based on the percentage of data used helps improve model quality.

``bert/finetune`` includes the MNLI/QQP/RACE finetuning example scripts following the [Megatron-LM paper](https://arxiv.org/abs/1909.08053). However, we found that the RACE task's accuracy is not very stable and that the Megatron-LM paper used an unnecessarily large number of epochs for MNLI/QQP. We therefore added the capability of finetuning other GLUE tasks and switched to the hyperparameters of the [original BERT paper](https://arxiv.org/abs/1810.04805). The corresponding scripts are at ``bert/finetune_glue``, which we recommend using instead of ``bert/finetune``. Our [data efficiency paper](https://arxiv.org/abs/2212.03597) also uses the scripts under ``bert/finetune_glue`` for GLUE finetuning.
239 changes: 239 additions & 0 deletions examples/data_efficiency/analyze_data.py
@@ -0,0 +1,239 @@
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

'''
Copyright 2022 The Microsoft DeepSpeed Team
'''

import os
import time
import sys
import math
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
os.path.pardir,os.path.pardir)))
from datetime import datetime
import numpy as np
import torch

from deepspeed.runtime.data_pipeline.data_sampling.data_analyzer \
import DataAnalyzer
from deepspeed.runtime.data_pipeline.data_sampling.indexed_dataset \
import MMapIndexedDataset

from megatron import get_args
from megatron import print_rank_0
from megatron.initialize import initialize_megatron

def get_tasks_args(parser):
"""Provide extra arguments required for data analyzing."""
group = parser.add_argument_group(title='data_analyzing')

group.add_argument('--analyzing-task', type=str, required=True,
default=None,
choices=['map',
'reduce'],
help='What type of analyzing task to perform.')
group.add_argument('--analyzing-data-type', type=str, required=True,
default=None,
choices=['BERT',
'GPT'],
help='What type of data.')
group.add_argument('--analyzing-metric', type=str, nargs='+', default=[],
help='What kinds of metrics to analyze.')
group.add_argument('--analyzing-num-workers', type=int, default=1,
help='Number of workers. Each worker could be a single CPU node.')
group.add_argument('--analyzing-worker-id', type=int, default=0,
help='Worker id of current node.')
group.add_argument('--analyzing-num-threads', type=int, default=1,
help='Number of threads for each worker.')
group.add_argument('--analyzing-num-threads-reduce', type=int, default=1,
help='Number of threads for each worker during the reduce step.')
group.add_argument('--analyzing-specific-threads', type=int, nargs='+', default=[],
help='Which specific threads to run. Helpful when specific threads failed in a previous run.')
return parser

def train_valid_test_datasets_provider_gpt():
"""Build train, valid, and test datasets."""
args = get_args()

print_rank_0('> building train, validation, and test datasets '
'for GPT ...')
from megatron.data.gpt_dataset import build_train_valid_test_datasets
train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
data_prefix=args.data_path,
data_impl=args.data_impl,
splits_string=args.split,
train_valid_test_num_samples=[1,1,1], # Just dummy numbers since we assume args.train_data_exact_num_epochs will override them
seq_length=args.seq_length,
seed=args.seed,
skip_warmup=(not args.mmap_warmup))
print_rank_0("> finished creating GPT datasets ...")

return train_ds, valid_ds, test_ds

def train_valid_test_datasets_provider_bert():
"""Build train, valid, and test datasets."""
args = get_args()

print_rank_0('> building train, validation, and test datasets '
'for BERT ...')
from megatron.data.dataset_utils import build_train_valid_test_datasets
train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
data_prefix=args.data_path,
data_impl=args.data_impl,
splits_string=args.split,
train_valid_test_num_samples=[1,1,1], # Just dummy numbers since we assume args.train_data_exact_num_epochs will override them
max_seq_length=args.seq_length,
masked_lm_prob=args.mask_prob,
short_seq_prob=args.short_seq_prob,
seed=args.seed,
skip_warmup=(not args.mmap_warmup),
binary_head=args.bert_binary_head)
print_rank_0("> finished creating BERT datasets ...")

return train_ds, valid_ds, test_ds

def metric_seqlen(data):
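    # Per-sample sequence length: number of non-padded tokens. Only meaningful for
    # BERT-style data; run_map() below asserts that GPT data (fixed seqlen) does not
    # request this metric.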
metric = torch.count_nonzero(data['padding_mask'], dim=1)
return metric

def metric_total_vocab_freq(data):
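    # Dataset-wide token frequency: counts occurrences of each vocab id across all
    # samples (for BERT, the padding_mask weights exclude padded positions).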
args = get_args()
if args.analyzing_data_type == 'BERT':
frequency = torch.bincount(data['text'].view(-1),
minlength=args.padded_vocab_size+1,
weights=data['padding_mask'].view(-1))
elif args.analyzing_data_type == 'GPT':
frequency = torch.bincount(data['text'].view(-1),
minlength=args.padded_vocab_size+1)
return frequency

def metric_vocab_rarity(data):
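    # Per-sample rarity: sum over the sample's tokens of -log(relative token
    # frequency). args.total_vocab_freq holds those -log values, loaded in run_map()
    # from a prior total_vocab_freq map + reduce run.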
args = get_args()
if args.analyzing_data_type == 'BERT':
rarity = torch.sum(data['padding_mask'] * \
args.total_vocab_freq[data['text']], dim=1).to(torch.long)
elif args.analyzing_data_type == 'GPT':
rarity = []
# Do one by one to avoid too high memory consumption
for row in range(data['text'].size()[0]):
rarity.append(int(torch.sum(args.total_vocab_freq[data['text'][row]]).item()))
rarity = torch.tensor(rarity, dtype=torch.long)
print(f"rarity min {min(rarity)}, max {max(rarity)}, len {len(rarity)}, avg {sum(rarity)/len(rarity)}")
return rarity

def metric_seqlen_vocab_rarity(data):
args = get_args()
metric = torch.count_nonzero(data['padding_mask'], dim=1).to(torch.long) * args.seqlen_coeff
metric += torch.sum(data['padding_mask'] * \
args.total_vocab_freq[data['text']], dim=1).to(torch.long)
print(f"metric min {min(metric)}, max {max(metric)}, len {len(metric)}, avg {sum(metric)/len(metric)}")
return metric

def get_metric_function(metric_name):
if metric_name == 'seqlen':
return metric_seqlen
if metric_name == 'total_vocab_freq':
return metric_total_vocab_freq
if metric_name == 'vocab_rarity':
return metric_vocab_rarity
if metric_name == 'seqlen_vocab_rarity':
return metric_seqlen_vocab_rarity

def get_metric_type(metric_name):
if metric_name == 'seqlen':
return 'single_value_per_sample'
if metric_name == 'total_vocab_freq':
return 'accumulate_value_over_samples'
if metric_name == 'vocab_rarity':
return 'single_value_per_sample'
if metric_name == 'seqlen_vocab_rarity':
return 'single_value_per_sample'

def run_map():
args = get_args()
if args.analyzing_data_type == 'BERT':
args.mask_prob = 0 # When analyzing data, we don't want any mask.
train_ds, _, _ = train_valid_test_datasets_provider_bert()
elif args.analyzing_data_type == 'GPT':
train_ds, _, _ = train_valid_test_datasets_provider_gpt()
assert 'seqlen' not in args.analyzing_metric, 'GPT data has fixed seqlen, thus unnecessary to analyze seqlen metric.'
assert 'seqlen_vocab_rarity' not in args.analyzing_metric, 'GPT data has fixed seqlen, thus unnecessary to analyze seqlen metric.'
if 'vocab_rarity' in args.analyzing_metric or 'seqlen_vocab_rarity' in args.analyzing_metric:
total_vocab_freq_fname = f"{args.save}/total_vocab_freq/total_vocab_freq_metric_value"
assert os.path.isfile(f"{total_vocab_freq_fname}.bin") and os.path.isfile(f"{total_vocab_freq_fname}.idx"), "To analyze vocab rarity, first need to analyze the total vocab freq."
total_vocab_freq = MMapIndexedDataset(total_vocab_freq_fname, skip_warmup=True)
total_vocab_freq = np.copy(total_vocab_freq[0])
total_vocab_freq[total_vocab_freq == 0] = 1 # Avoid log(0) error
total_vocab_freq = np.log(total_vocab_freq/sum(total_vocab_freq)) * -1
args.total_vocab_freq = torch.tensor(total_vocab_freq, dtype=torch.double)
if 'seqlen_vocab_rarity' in args.analyzing_metric:
# Use a large coefficient so that seqlen dominates vocab_rarity
max_possible_rarity = args.seq_length * torch.max(args.total_vocab_freq).item()
args.seqlen_coeff = 10 ** (math.ceil(math.log(max_possible_rarity, 10)) + 1)
print(f"Metric seqlen_vocab_rarity: using {args.seqlen_coeff} as coefficient for seqlen.")
metric_functions = [get_metric_function(x) for x in args.analyzing_metric]
metric_types = [get_metric_type(x) for x in args.analyzing_metric]
# For metric_dtypes we use int64 by default since it could be hard to estimate
# the appropriate dtype before the mapping analysis. During the reduce step, where
# we merge the analysis results, the DataAnalyzer will automatically choose
# the dtype of the merged result file as the smallest one that meets the range
# requirement.
metric_dtypes = [np.int64 for x in args.analyzing_metric]
start = time.time()
data_analyzer = DataAnalyzer(train_ds,
num_workers=args.analyzing_num_workers,
worker_id=args.analyzing_worker_id,
num_threads=args.analyzing_num_threads,
specific_threads=args.analyzing_specific_threads,
batch_size=args.global_batch_size, metric_names=args.analyzing_metric,
metric_functions=metric_functions, metric_types=metric_types,
metric_dtypes=metric_dtypes, save_path=args.save)
data_analyzer.run_map()
duration = (time.time() - start) / 3600.0
print(f"map job finished in {duration} hr.")

def run_reduce():
args = get_args()
if args.analyzing_data_type == 'BERT':
args.mask_prob = 0 # When analyzing data, we don't want any mask.
train_ds, _, _ = train_valid_test_datasets_provider_bert()
elif args.analyzing_data_type == 'GPT':
train_ds, _, _ = train_valid_test_datasets_provider_gpt()
metric_functions = [get_metric_function(x) for x in args.analyzing_metric]
metric_types = [get_metric_type(x) for x in args.analyzing_metric]
metric_dtypes = [np.int64 for x in args.analyzing_metric]
start = time.time()
data_analyzer = DataAnalyzer(train_ds,
num_workers=args.analyzing_num_workers,
num_threads=args.analyzing_num_threads,
num_threads_reduce=args.analyzing_num_threads_reduce,
batch_size=args.global_batch_size, metric_names=args.analyzing_metric,
metric_functions=metric_functions, metric_types=metric_types,
metric_dtypes=metric_dtypes, save_path=args.save)
data_analyzer.run_reduce()
duration = (time.time() - start) / 3600.0
print(f"reduce job finished in {duration} hr.")

if __name__ == "__main__":
initialize_megatron(extra_args_provider=get_tasks_args, allow_no_cuda=True)
args = get_args()
if args.analyzing_task == 'map':
run_map()
elif args.analyzing_task == 'reduce':
run_reduce()
else:
raise NotImplementedError('Task {} is not implemented.'.format(
args.analyzing_task))
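As a usage note for the output of the analysis above: for the ``total_vocab_freq`` metric, the merged result of a map plus reduce run is an indexed dataset at ``<save>/total_vocab_freq/total_vocab_freq_metric_value``, the same path that ``run_map()`` reads back before computing ``vocab_rarity``. The sketch below (assuming such an analysis has already been run; the output directory is a placeholder) loads that file and applies the same negative-log-frequency transform used above.

```python
# Sketch: inspect the merged total_vocab_freq result of a prior map + reduce run.
# The save path is a placeholder for the --save directory passed to analyze_data.py.
import numpy as np
from deepspeed.runtime.data_pipeline.data_sampling.indexed_dataset \
    import MMapIndexedDataset

save_path = "/data/analysis_output"  # placeholder
fname = f"{save_path}/total_vocab_freq/total_vocab_freq_metric_value"

freq = np.copy(MMapIndexedDataset(fname, skip_warmup=True)[0]).astype(np.float64)
freq[freq == 0] = 1                      # same log(0) guard as run_map()
neg_log_p = -np.log(freq / freq.sum())   # per-vocab-id rarity weight

print(f"padded vocab size: {len(freq)}")
print(f"rarity weight range: [{neg_log_p.min():.3f}, {neg_log_p.max():.3f}]")
```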