- [2024/7] Megatron-Core v0.7 improves scalability and training resiliency and adds support for multimodal training (blog).
- [2024/6] Megatron-Core added support for Mamba-based models. Check out our paper An Empirical Study of Mamba-based Language Models and code example.
- [2024/1 Announcement] NVIDIA has released the core capabilities of Megatron-LM into Megatron-Core in this repository. Megatron-Core expands upon Megatron-LM's GPU-optimized techniques with more cutting-edge innovations on system-level optimizations, featuring composable and modular APIs. Explore the Megatron-Core intro for more details.
- Megatron Overview
- Training Speed and Scalability
- Setup
- Usage
- Training
- Evaluation and Tasks
- Datasets
- Reproducibility
- Projects using Megatron
This repository comprises two essential components: Megatron-LM and Megatron-Core. Megatron-LM serves as a research-oriented framework leveraging Megatron-Core for large language model (LLM) training. Megatron-Core, on the other hand, is a library of GPU-optimized training techniques that comes with formal product support, including versioned APIs and regular releases. You can use Megatron-Core alongside Megatron-LM or NVIDIA NeMo Framework for an end-to-end and cloud-native solution. Alternatively, you can integrate Megatron-Core's building blocks into your preferred training framework.
First introduced in 2019, Megatron (1, 2, and 3) sparked a wave of innovation in the AI community, enabling researchers and developers to build on the underpinnings of this library to further LLM advancements. Today, many of the most popular LLM developer frameworks have been inspired by, or built directly on, the open-source Megatron-LM library, spurring a wave of foundation models and AI startups. Some of the most popular LLM frameworks built on top of Megatron-LM include Colossal-AI, HuggingFace Accelerate, and NVIDIA NeMo Framework. A list of projects that have directly used Megatron can be found here.
Megatron-Core is an open-source PyTorch-based library that contains GPU-optimized techniques and cutting-edge system-level optimizations. It abstracts them into composable and modular APIs, allowing full flexibility for developers and model researchers to train custom transformers at-scale on NVIDIA accelerated computing infrastructure. This library is compatible with all NVIDIA Tensor Core GPUs, including FP8 acceleration support for NVIDIA Hopper architectures.
Megatron-Core offers core building blocks such as attention mechanisms, transformer blocks and layers, normalization layers, and embedding techniques. Additional functionality, such as activation recomputation and distributed checkpointing, is also natively built into the library. The building blocks and functionality are all GPU-optimized and can be combined with advanced parallelization strategies for optimal training speed and stability on NVIDIA Accelerated Computing Infrastructure. Another key component of the Megatron-Core library is its set of advanced model parallelism techniques (tensor, sequence, pipeline, context, and MoE expert parallelism).
Megatron-Core can be used with NVIDIA NeMo, an enterprise-grade AI platform. Alternatively, you can explore Megatron-Core with the native PyTorch training loop here. Visit Megatron-Core documentation to learn more.
Our codebase is capable of efficiently training large language models (i.e., models with hundreds of billions of parameters) with both model and data parallelism. To demonstrate how our software scales with multiple GPUs and model sizes, we consider GPT models ranging from 2 billion parameters to 462 billion parameters. All models use a vocabulary size of 131,072 and a sequence length of 4096. We vary hidden size, number of attention heads, and number of layers to arrive at a specific model size. As the model size increases, we also modestly increase batch size. Our experiments use up to 6144 H100 GPUs. We perform fine-grained overlapping of data-parallel (--overlap-grad-reduce --overlap-param-gather), tensor-parallel (--tp-comm-overlap), and pipeline-parallel communication (enabled by default) with computation to improve scalability. The reported throughputs are measured for end-to-end training and include all operations, including data loading, optimizer steps, communication, and even logging. Note that we did not train these models to convergence.
Our weak-scaling results show superlinear scaling (model FLOPs utilization, or MFU, increases from 41% for the smallest model considered to 47-48% for the largest models); this is because larger GEMMs have higher arithmetic intensity and are consequently more efficient to execute.
We also strong scaled the standard GPT-3 model (our version has slightly more than 175 billion parameters due to larger vocabulary size) from 96 H100 GPUs to 4608 GPUs, using the same batch size of 1152 sequences throughout. Communication becomes more exposed at larger scale, leading to a reduction in MFU from 47% to 42%.
We strongly recommend using the latest release of NGC's PyTorch container with DGX nodes. If you can't use this for some reason, use the latest PyTorch, CUDA, NCCL, and NVIDIA APEX releases. Data preprocessing requires NLTK, though this is not required for training, evaluation, or downstream tasks.
You can launch an instance of the PyTorch container and mount Megatron, your dataset, and checkpoints with the following Docker commands:
docker pull nvcr.io/nvidia/pytorch:xx.xx-py3
docker run --gpus all -it --rm -v /path/to/megatron:/workspace/megatron -v /path/to/dataset:/workspace/dataset -v /path/to/checkpoints:/workspace/checkpoints nvcr.io/nvidia/pytorch:xx.xx-py3
We have provided pretrained BERT-345M and GPT-345M checkpoints for evaluation or for finetuning on downstream tasks. To access these checkpoints, first sign up for and set up the NVIDIA GPU Cloud (NGC) Registry CLI. Further documentation for downloading models can be found in the NGC documentation.
Alternatively, you can directly download the checkpoints using:
BERT-345M-uncased: wget --content-disposition https://api.ngc.nvidia /v2/models/nvidia/megatron_bert_345m/versions/v0.1_uncased/zip -O megatron_bert_345m_v0.1_uncased.zip
BERT-345M-cased: wget --content-disposition https://api.ngc.nvidia /v2/models/nvidia/megatron_bert_345m/versions/v0.1_cased/zip -O megatron_bert_345m_v0.1_cased.zip
GPT-345M: wget --content-disposition https://api.ngc.nvidia /v2/models/nvidia/megatron_lm_345m/versions/v0.0/zip -O megatron_lm_345m_v0.0.zip
The models require vocabulary files to run. The BERT WordPiece vocab file can be extracted from Google's pretrained BERT models: uncased, cased. The GPT vocab file and merge table can be downloaded directly.
After installation, there are several possible workflows. The most comprehensive is:
- Data preprocessing
- Pretraining
- Finetuning (Optional for zero-shot tasks)
- Downstream task evaluation or text generation
However, the first two steps (data preprocessing and pretraining) can be replaced by using one of the pretrained models mentioned above.
We've provided several scripts for pretraining both BERT and GPT in the examples directory, as well as scripts for both zero-shot and fine-tuned downstream tasks, including MNLI, RACE, WikiText103, and LAMBADA evaluation. There is also a script for GPT interactive text generation.
The training data requires preprocessing. First, place your training data in a loose JSON format, with one JSON object containing a text sample per line. For example:
{"src": "nvidia", "text": "The quick brown fox", "type": "Eng", "id": "0", "title": "First Part"}
{"src": "The Internet", "text": "jumps over the lazy dog", "type": "Eng", "id": "42", "title": "Second Part"}
The name of the text field of the JSON can be changed by using the --json-key flag in preprocess_data.py. The other metadata are optional and are not used in training.
The loose JSON is then processed into a binary format for training. To convert the JSON into mmap format, use preprocess_data.py. An example script to prepare data for BERT training is:
python tools/preprocess_data.py \
       --input my-corpus.json \
       --output-prefix my-bert \
       --vocab-file bert-vocab.txt \
       --tokenizer-type BertWordPieceLowerCase \
       --split-sentences
The output will be two files named, in this case, my-bert_text_sentence.bin and my-bert_text_sentence.idx. The --data-path specified in later BERT training is the full path and new filename, but without the file extension.
For T5, use the same preprocessing as BERT, perhaps changing the output prefix to:
--output-prefix my-t5 \
Some minor modifications are required for GPT data preprocessing, namely, the addition of a merge table, an end-of-document token, removal of sentence splitting, and a change to the tokenizer type:
python tools/preprocess_data.py \
       --input my-corpus.json \
       --output-prefix my-gpt2 \
       --vocab-file gpt2-vocab.json \
       --tokenizer-type GPT2BPETokenizer \
       --merge-file gpt2-merges.txt \
       --append-eod
Here the output files are named my-gpt2_text_document.bin and my-gpt2_text_document.idx. As before, in GPT training, use the longer name without the extension as --data-path.
Further command line arguments are described in the source file preprocess_data.py.
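The file names above follow a simple pattern, which the hypothetical helper below reproduces so you can compute the --data-path value ahead of time. It assumes the pattern <output-prefix>_<json key>_<level> seen in the two examples, with level "sentence" when --split-sentences is used and "document" otherwise; check the files actually produced if you use a non-default --json-key.

```python
def indexed_dataset_prefix(output_prefix, json_key="text", split_sentences=False):
    """Return the --data-path prefix (no .bin/.idx extension) for a preprocessing run."""
    level = "sentence" if split_sentences else "document"
    return f"{output_prefix}_{json_key}_{level}"


def output_files(prefix):
    """The two files written for one prefix: token data and its index."""
    return [prefix + ".bin", prefix + ".idx"]


print(indexed_dataset_prefix("my-bert", split_sentences=True))  # my-bert_text_sentence
print(output_files(indexed_dataset_prefix("my-gpt2")))
# ['my-gpt2_text_document.bin', 'my-gpt2_text_document.idx']
```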
The examples/pretrain_bert.sh script runs single GPU 345M parameter BERT pretraining. Debugging is the primary use for single GPU training, as the code base and command line arguments are optimized for highly distributed training. Most of the arguments are fairly self-explanatory. By default, the learning rate decays linearly over the training iterations, starting at --lr and reaching a minimum set by --min-lr over --lr-decay-iters iterations. The fraction of training iterations used for warmup is set by --lr-warmup-fraction. While this is single GPU training, the batch size specified by --micro-batch-size is the batch size of a single forward-backward pass, and the code will perform gradient accumulation steps until it reaches --global-batch-size, which is the batch size per iteration. The data is partitioned into a 949:50:1 ratio for training/validation/test sets (default is 969:30:1). This partitioning happens on the fly, but is consistent across runs with the same random seed (1234 by default, or specified manually with --seed). We use --train-iters as the number of training iterations requested. Alternatively, one can provide --train-samples, which is the total number of samples to train on. If this option is present, then instead of providing --lr-decay-iters, one will need to provide --lr-decay-samples.
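To make the schedule concrete, here is a small sketch of warmup followed by decay from --lr to --min-lr. It is an illustration, not Megatron's scheduler: it assumes warmup starts from 0, that the warmup length is --lr-warmup-fraction times --lr-decay-iters, and it includes the cosine style used by the GPT script described below alongside the default linear style.

```python
import math


def learning_rate(step, max_lr, min_lr, decay_iters, warmup_fraction=0.0, style="linear"):
    """Learning rate at `step`: linear warmup, then linear or cosine decay to min_lr."""
    # Assumption: warmup length is a fraction of the decay horizon.
    warmup_steps = int(warmup_fraction * decay_iters)
    if warmup_steps > 0 and step <= warmup_steps:
        # Linear warmup from 0 (assumed starting value) up to max_lr.
        return max_lr * step / warmup_steps
    if step >= decay_iters:
        # Once the decay horizon is reached, hold the rate at min_lr.
        return min_lr
    # Fraction of the decay phase completed, in [0, 1).
    ratio = (step - warmup_steps) / (decay_iters - warmup_steps)
    if style == "linear":
        coeff = 1.0 - ratio
    elif style == "cosine":
        coeff = 0.5 * (math.cos(math.pi * ratio) + 1.0)
    else:
        raise ValueError(f"unknown decay style: {style}")
    return min_lr + coeff * (max_lr - min_lr)


for step in (0, 50, 100, 550, 1000):
    print(step, learning_rate(step, max_lr=1e-4, min_lr=1e-5, decay_iters=1000, warmup_fraction=0.1))
```

Halfway through the decay phase both styles give the midpoint between the two rates; cosine decays more slowly at the start and end of that phase.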
The logging, checkpoint-saving, and evaluation interval options are also specified in the script. Note that the --data-path now includes the additional _text_sentence suffix added in preprocessing, but does not include the file extensions.
Further command line arguments are described in the source file arguments.py.
To run examples/pretrain_bert.sh, make any desired modifications, including setting the environment variables for CHECKPOINT_PATH, VOCAB_FILE, and DATA_PATH. Make sure to set these variables to their paths in the container. Then launch the container with Megatron and necessary paths mounted (as explained in Setup) and run the example script.
The examples/pretrain_gpt.sh script runs single GPU 345M parameter GPT pretraining. As mentioned above, single GPU training is primarily intended for debugging purposes, as the code is optimized for distributed training.
It follows largely the same format as the previous BERT script, with a few notable differences: the tokenization scheme used is BPE (which requires a merge table and a JSON vocabulary file) instead of WordPiece, the model architecture allows for longer sequences (note that the max position embedding must be greater than or equal to the maximum sequence length), and the --lr-decay-style has been set to cosine decay. Note that the --data-path now includes the additional _text_document suffix added in preprocessing, but does not include the file extensions.
Further command line arguments are described in the source file arguments.py.
examples/pretrain_gpt.sh can be launched the same way as described for BERT. Set the env vars and make any other modifications, launch the container with appropriate mounts, and run the script.
Very similar to BERT and GPT, the examples/pretrain_t5.sh script runs single GPU "base" (~220M parameter) T5 pretraining. The primary difference from BERT and GPT is the addition of the following arguments to accommodate the T5 architecture:
- --kv-channels sets the inner dimension of the "key" and "value" matrices of all attention mechanisms in the model. For BERT and GPT this defaults to the hidden size divided by the number of attention heads, but can be configured for T5.
- --ffn-hidden-size sets the hidden size in the feed-forward networks within a transformer layer. For BERT and GPT this defaults to 4 times the transformer hidden size, but can be configured for T5.
- --encoder-seq-length and --decoder-seq-length set the sequence length for the encoder and decoder separately.
All of the other arguments remain as they were for BERT and GPT pretraining. Run this example with the same steps described above for the other scripts.
The examples/pretrain_{bert,gpt,t5}_distributed.sh scripts use the PyTorch distributed launcher for distributed training. As such, multi-node training can be achieved by properly setting environment variables. See the official PyTorch documentation for further description of these environment variables. By default, multi-node training uses the nccl distributed backend. A simple set of additional arguments and the use of the PyTorch distributed module with the torchrun elastic launcher (equivalent to python -m torch.distributed.run) are the only additional requirements to adopt distributed training. See any of examples/pretrain_{bert,gpt,t5}_distributed.sh for more details.
We use two types of parallelism: data and model parallelism. Our data parallelism implementation is in megatron/core/distributed, and supports overlapping of the gradient reduction with the backward pass when the --overlap-grad-reduce command-line option is used.
Second, we developed a simple and efficient two-dimensional model-parallel approach. To use the first dimension, tensor model parallelism (splitting execution of a single transformer module over multiple GPUs, see Section 3 of our paper), add the --tensor-model-parallel-size flag to specify the number of GPUs among which to split the model, along with the arguments passed to the distributed launcher as mentioned above. To use the second dimension, sequence parallelism, specify --sequence-parallel, which also requires tensor model parallelism to be enabled because it splits across the same GPUs (more details in Section 4.2.2 of our paper).
To use pipeline model parallelism (sharding the transformer modules into stages with an equal number of transformer modules on each stage, and then pipelining execution by breaking the batch into smaller microbatches, see Section 2.2 of our paper), use the --pipeline-model-parallel-size flag to specify the number of stages to split the model into (e.g., splitting a model with 24 transformer layers across 4 stages would give each stage 6 transformer layers).
Examples of how to use these two forms of model parallelism are provided in the example scripts ending in distributed_with_mp.sh.
Other than these minor changes, the distributed training is identical to the training on a single GPU.
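A quick way to sanity-check a launch configuration is to work out the data-parallel size and per-stage layer count before starting a job. The sketch below (parallel_layout is a hypothetical helper) assumes the GPUs are fully used by tensor, pipeline, and data parallelism, i.e. world size = TP × PP × DP, and ignores context and expert parallelism, which this section does not cover.

```python
def parallel_layout(world_size, tensor_parallel, pipeline_parallel, num_layers):
    """Derive data-parallel size and layers per pipeline stage from a launch config."""
    model_parallel = tensor_parallel * pipeline_parallel
    if world_size % model_parallel != 0:
        raise ValueError("world size must be divisible by TP * PP")
    if num_layers % pipeline_parallel != 0:
        raise ValueError("layers must split evenly across pipeline stages")
    return {
        # Remaining GPUs form replicas of the model-parallel group.
        "data_parallel": world_size // model_parallel,
        # Equal number of transformer layers on every stage.
        "layers_per_stage": num_layers // pipeline_parallel,
    }


# 24 layers on 4 stages (as in the example above), 2-way TP, 16 GPUs total.
print(parallel_layout(world_size=16, tensor_parallel=2, pipeline_parallel=4, num_layers=24))
# {'data_parallel': 2, 'layers_per_stage': 6}
```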
The interleaved pipelining schedule (more details in Section 2.2.2 of our paper) can be enabled using the --num-layers-per-virtual-pipeline-stage argument, which controls the number of transformer layers in a virtual stage (by default with the non-interleaved schedule, each GPU will execute a single virtual stage with NUM_LAYERS / PIPELINE_MP_SIZE transformer layers). The total number of layers in the transformer model should be divisible by this argument value. Additionally, the number of microbatches in the pipeline (computed as GLOBAL_BATCH_SIZE / (DATA_PARALLEL_SIZE * MICRO_BATCH_SIZE)) should be divisible by PIPELINE_MP_SIZE when using this schedule (this condition is checked in an assertion in the code). The interleaved schedule is not supported for pipelines with 2 stages (PIPELINE_MP_SIZE=2).
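The constraints above can be checked up front with a sketch like the following (interleaved_schedule_config is a hypothetical helper, not the assertion in Megatron's code). It additionally assumes that each pipeline stage's layers split evenly into virtual stages, which implies the total-layers condition stated above.

```python
def interleaved_schedule_config(num_layers, pipeline_mp_size, layers_per_virtual_stage,
                                global_batch_size, data_parallel_size, micro_batch_size):
    """Validate an interleaved-schedule setup.

    Returns (virtual stages per pipeline rank, number of microbatches).
    """
    if pipeline_mp_size == 2:
        raise ValueError("interleaved schedule is not supported with PIPELINE_MP_SIZE=2")
    if num_layers % pipeline_mp_size != 0:
        raise ValueError("NUM_LAYERS must divide evenly across pipeline stages")
    layers_per_stage = num_layers // pipeline_mp_size
    # Assumption: each stage's layers split evenly into virtual stages.
    if layers_per_stage % layers_per_virtual_stage != 0:
        raise ValueError("layers per stage must be divisible by layers per virtual stage")
    samples_per_microbatch = data_parallel_size * micro_batch_size
    if global_batch_size % samples_per_microbatch != 0:
        raise ValueError("GLOBAL_BATCH_SIZE must be divisible by DATA_PARALLEL_SIZE * MICRO_BATCH_SIZE")
    num_microbatches = global_batch_size // samples_per_microbatch
    if num_microbatches % pipeline_mp_size != 0:
        raise ValueError("number of microbatches must be divisible by PIPELINE_MP_SIZE")
    return layers_per_stage // layers_per_virtual_stage, num_microbatches


print(interleaved_schedule_config(24, 4, 2, global_batch_size=64,
                                  data_parallel_size=2, micro_batch_size=4))  # (3, 8)
```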
To reduce GPU memory usage when training a large model, we support various forms of activation checkpointing and recomputation. Instead of all activations being stored in memory to be used during backprop, as was traditionally the case in deep learning models, only activations at certain "checkpoints" in the model are retained (or stored) in memory, and the other activations are recomputed on the fly when needed for backprop. Note that this kind of checkpointing, activation checkpointing, is very different from the checkpointing of model parameters and optimizer state, which is mentioned elsewhere.
We support two levels of recompute granularity: selective and full. Selective recomputation is the default and is recommended in almost all cases. This mode retains in memory the activations that take less memory storage space and are more expensive to recompute, and recomputes the activations that take more memory storage space but are relatively inexpensive to recompute. See our paper for details. You should find that this mode maximizes performance while minimizing the memory required to store activations. To enable selective activation recompute, simply use --recompute-activations.
For cases where memory is very limited, full recompute saves just the inputs to a transformer layer, or to a group (block) of transformer layers, and recomputes everything else. To enable full activation recompute, use --recompute-granularity full. When using full activation recompute, there are two methods, uniform and block, chosen using the --recompute-method argument.
- The uniform method uniformly divides the transformer layers into groups of layers (each group of size --recompute-num-layers) and stores the input activations of each group in memory. The baseline group size is 1 and, in this case, the input activation of each transformer layer is stored. When the GPU memory is insufficient, increasing the number of layers per group reduces the memory usage, enabling a bigger model to be trained. For example, when --recompute-num-layers is set to 4, only the input activation of each group of 4 transformer layers is stored.
- The block method recomputes the input activations of a specific number (given by --recompute-num-layers) of individual transformer layers per pipeline stage and stores the input activations of the remaining layers in the pipeline stage. Reducing --recompute-num-layers results in storing the input activations to more transformer layers, which reduces the activation recomputation required in the backprop, thus improving training performance while increasing memory usage. For example, when we specify 5 layers to recompute of 8 layers per pipeline stage, the input activations of only the first 5 transformer layers are recomputed in the backprop step while the input activations for the final 3 layers are stored. --recompute-num-layers can be incrementally increased until the amount of memory storage space required is just small enough to fit in the available memory, thereby both maximally utilizing memory and maximizing performance.
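The difference between the two methods can be visualized with the illustrative model below. full_recompute_plan is a hypothetical helper that only mirrors the description above (0-based layer indices within one pipeline stage); it is not how Megatron implements recomputation internally.

```python
def full_recompute_plan(layers_per_stage, method, recompute_num_layers):
    """Which layers of one pipeline stage are recomputed under full recompute."""
    layers = list(range(layers_per_stage))
    n = recompute_num_layers
    if method == "uniform":
        # Each group of n layers keeps only its input activation in memory,
        # and every layer in the group is recomputed from it during backprop.
        groups = [layers[i:i + n] for i in range(0, layers_per_stage, n)]
        return {"checkpointed_groups": groups, "stored_layers": []}
    if method == "block":
        # The first n layers are checkpointed individually and recomputed;
        # the remaining layers keep their activations and are not recomputed.
        n = min(n, layers_per_stage)
        return {"checkpointed_groups": [[i] for i in layers[:n]], "stored_layers": layers[n:]}
    raise ValueError(f"unknown recompute method: {method}")


print(full_recompute_plan(8, "uniform", 4))  # two groups of 4, nothing stored uncheckpointed
print(full_recompute_plan(8, "block", 5))    # layers 0-4 recomputed, layers 5-7 stored
```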
Usage: --use-distributed-optimizer. Compatible with all model and data types.
The distributed optimizer is a memory savings technique, whereby the optimizer state is evenly distributed across data parallel ranks (versus the traditional method of replicating the optimizer state across data parallel ranks). As described in ZeRO: Memory Optimizations Toward Training Trillion Parameter Models, our implementation distributes all optimizer state that does not overlap with the model state. For example, when using fp16 model params, the distributed optimizer maintains its own separate copy of fp32 main params & grads, which are distributed across DP ranks. When using bf16 model params, however, the distributed optimizer's fp32 main grads are the same as the model's fp32 grads, and so the grads in this case are not distributed (although the fp32 main params are still distributed, as they are separate from the bf16 model params).
Theoretical memory savings vary depending on the combination of the model's param dtype and grad dtype. In our implementation, the theoretical number of bytes per parameter is (where 'd' is the data parallel size):
| Param / grad dtype | Non-distributed optim | Distributed optim |
|---|---|---|
| fp16 param, fp16 grads | 20 | 4 + 16/d |
| bf16 param, fp32 grads | 18 | 6 + 12/d |
| fp32 param, fp32 grads | 16 | 8 + 8/d |
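Turning the table into a memory estimate is simple arithmetic; the sketch below does this for a hypothetical 7B-parameter model with 8-way data parallelism. It counts only the per-parameter model and optimizer state from the table (no activations, buffers, or fragmentation) and assumes no tensor or pipeline parallelism, so every data-parallel rank holds all parameters.

```python
# Bytes per parameter from the table, split into the part replicated on every
# data-parallel rank and the part sharded across the d ranks.
BYTES_PER_PARAM = {
    "fp16 param, fp16 grads": {"non_distributed": 20, "replicated": 4, "sharded": 16},
    "bf16 param, fp32 grads": {"non_distributed": 18, "replicated": 6, "sharded": 12},
    "fp32 param, fp32 grads": {"non_distributed": 16, "replicated": 8, "sharded": 8},
}


def bytes_per_param(dtypes, data_parallel_size, distributed_optimizer=True):
    row = BYTES_PER_PARAM[dtypes]
    if not distributed_optimizer:
        return row["non_distributed"]
    return row["replicated"] + row["sharded"] / data_parallel_size


num_params = 7e9  # hypothetical model size
for use_dist in (False, True):
    gb = num_params * bytes_per_param("bf16 param, fp32 grads", 8, use_dist) / 1e9
    print(f"distributed_optimizer={use_dist}: {gb:.1f} GB of model + optimizer state per GPU")
# distributed_optimizer=False: 126.0 GB ...; distributed_optimizer=True: 52.5 GB ...
```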
As with regular data parallelism, overlapping of the gradient reduction (in this case, a reduce-scatter) with the backward pass can be enabled using the --overlap-grad-reduce flag. Additionally, the parameter all-gather can be overlapped with the forward pass using --overlap-param-gather.
Usage: --use-flash-attn. Supports attention head dimensions of at most 128.
FlashAttention is a fast and memory-efficient algorithm to compute exact attention. It speeds up model training and reduces memory requirements.
To install FlashAttention:
pip install flash-attn
In examples/pretrain_gpt3_175B.sh we have provided an example of how to configure Megatron to train GPT-3 with 175 billion parameters on 1024 GPUs. The script is designed for slurm with the pyxis plugin but can be easily adapted to any other scheduler. It uses 8-way tensor parallelism and 16-way pipeline parallelism. With options --global-batch-size 1536 and --rampup-batch-size 16 16 5859375, the training will start with global batch size 16 and linearly increase the global batch size to 1536 over 5,859,375 samples in increments of 16. The training dataset can be either a single dataset or multiple datasets combined with a set of weights.
With the full global batch size of 1536 on 1024 A100 GPUs, each iteration takes around 32 seconds, resulting in 138 teraFLOPs per GPU, which is 44% of the theoretical peak FLOPs.
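These two numbers can be reproduced approximately with the per-iteration FLOP estimate from Narayanan et al. (2021), which accounts for full activation recomputation: 96·B·s·l·h²·(1 + s/(6h) + V/(16·l·h)). The sketch assumes the standard GPT-3 175B shape (96 layers, hidden size 12288), a sequence length of 2048, a padded vocabulary of 51,200, and an A100 dense BF16/FP16 peak of 312 TFLOP/s; none of these values are stated in this section.

```python
def gpt_flops_per_iteration(batch, seq_len, num_layers, hidden, vocab):
    """Model FLOPs per training iteration, including full activation recomputation."""
    return 96 * batch * seq_len * num_layers * hidden**2 * (
        1 + seq_len / (6 * hidden) + vocab / (16 * num_layers * hidden))


flops = gpt_flops_per_iteration(batch=1536, seq_len=2048, num_layers=96, hidden=12288, vocab=51200)
num_gpus, seconds_per_iteration = 1024, 32        # from the text above
per_gpu_tflops = flops / (num_gpus * seconds_per_iteration) / 1e12
mfu = per_gpu_tflops / 312                        # assumed A100 dense BF16/FP16 peak, TFLOP/s
print(f"{per_gpu_tflops:.0f} TFLOP/s per GPU, {mfu:.0%} of peak")  # 138 TFLOP/s per GPU, 44% of peak
```

Because the iteration time is quoted as "around 32 seconds", treat the result as a consistency check rather than an exact reproduction.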
Retro (Borgeaud et al., 2022) is an autoregressive decoder-only language model (LM) pretrained with retrieval-augmentation. Retro features practical scalability to support large-scale pretraining from scratch by retrieving from trillions of tokens. Pretraining with retrieval provides a more efficient storage mechanism of factual knowledge, when compared to storing factual knowledge implicitly within the network's parameters, thus largely reducing model parameters while achieving lower perplexity than standard GPT. Retro also provides the flexibility to update the knowledge stored in LMs (Wang et al., 2023a) by updating the retrieval database without training LMs again.
InstructRetro (Wang et al., 2023b) further scales up the size of Retro to 48B, featuring the largest LLM pretrained with retrieval (as of December 2023). The obtained foundation model, Retro 48B, largely outperforms the GPT counterpart in terms of perplexity. With instruction tuning on Retro, InstructRetro demonstrates significant improvement over the instruction tuned GPT on downstream tasks in the zero-shot setting. Specifically, the average improvement of InstructRetro is 7% over its GPT counterpart across 8 short-form QA tasks, and 10% over GPT across 4 challenging long-form QA tasks. We also find that one can ablate the encoder from InstructRetro architecture and directly use the InstructRetro decoder backbone as GPT, while achieving comparable results.
In this repo, we provide an end-to-end reproduction guide to implement Retro and InstructRetro, covering:
- Retrieval database construction, which supports billions or even trillions of tokens as a large-scale retrieval database.
- Pretraining with retrieval, which supports pretraining from scratch and pretraining from a pretrained GPT model (Retro-fitting).
- Instruction tuning, where we provide an open-source instruction tuning dataset and the training recipe for instruction tuning on Retro.
- Downstream task evaluation, where we provide the text generation and evaluation scripts for zero-shot question answering tasks.
See tools/retro/README.md for a detailed overview.
See examples/mamba for details.
We provide several command line arguments, detailed in the scripts listed below, to handle various zero-shot and fine-tuned downstream tasks. However, you can also finetune your model from a pretrained checkpoint on other corpora as desired. To do so, simply add the --finetune flag and adjust the input files and training parameters within the original training script. The iteration count will be reset to zero, and the optimizer and internal state will be reinitialized. If the fine-tuning is interrupted for any reason, be sure to remove the --finetune flag before continuing; otherwise, the training will start again from the beginning.
Because evaluation requires substantially less memory than training, it may be advantageous to merge a model trained in parallel for use on fewer GPUs in downstream tasks. The following script accomplishes this. This example reads in a GPT model with 4-way tensor and 4-way pipeline model parallelism and writes out a model with 2-way tensor and 2-way pipeline model parallelism.
python tools/checkpoint/convert.py \
       --model-type GPT \
       --load-dir checkpoints/gpt3_tp4_pp4 \
       --save-dir checkpoints/gpt3_tp2_pp2 \
       --target-tensor-parallel-size 2 \
       --target-pipeline-parallel-size 2
Several downstream tasks are described for both GPT and BERT models below. They can be run in distributed and model parallel modes with the same changes used in the training scripts.
We have included a simple REST server to use for text generation in tools/run_text_generation_server.py. You run it much like you would start a pretraining job, specifying an appropriate pretrained checkpoint. There are also a few optional parameters: temperature, top-k and top-p. See --help or the source file for more information. See examples/run_text_generation_server_345M.sh for an example of how to run the server.
Once the server is running, you can use tools/text_generation_cli.py to query it. It takes one argument, which is the host the server is running on.
tools/text_generation_cli.py localhost:5000
You can also use curl or any other tool to query the server directly:
curl 'http://localhost:5000/api' -X 'PUT' -H 'Content-Type: application/json; charset=UTF-8' -d '{ "prompts":[ "Hello world" ], "tokens_to_generate":1}'
See megatron/inference/text_generation_server.py for more API options.
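From Python, the same PUT request as the curl command can be sent with only the standard library. This is a sketch: the payload fields are exactly those in the curl example, while the assumption that the server replies with a JSON body, and the helper names, are ours.

```python
import json
import urllib.request


def build_generation_request(host, prompts, tokens_to_generate):
    """Build the same PUT request as the curl example above."""
    payload = {"prompts": prompts, "tokens_to_generate": tokens_to_generate}
    return urllib.request.Request(
        url=f"http://{host}/api",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json; charset=UTF-8"},
        method="PUT",
    )


def generate(host, prompts, tokens_to_generate=1, timeout=60):
    """Send the request and decode the reply (assumed to be JSON)."""
    req = build_generation_request(host, prompts, tokens_to_generate)
    with urllib.request.urlopen(req, timeout=timeout) as resp:
        return json.loads(resp.read().decode("utf-8"))


# Requires a running server:
# print(generate("localhost:5000", ["Hello world"]))
```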
We include an example in examples/detxoify_lm/ to detoxify language models by leveraging the generative power of language models.
See examples/detxoify_lm/README.md for step-by-step tutorials on how to perform domain-adaptive training and detoxify LM using self-generated corpus.
We include example scripts for GPT evaluation: WikiText perplexity and LAMBADA cloze accuracy.
For even comparison with prior works, we evaluate perplexity on the word-level WikiText-103 test dataset, and appropriately compute perplexity given the change in tokens when using our subword tokenizer.
We use the following command to run WikiText-103 evaluation on a 345M parameter model.
TASK="WIKITEXT103"
VALID_DATA=<wikitext path>.txt
VOCAB_FILE=gpt2-vocab.json
MERGE_FILE=gpt2-merges.txt
CHECKPOINT_PATH=checkpoints/gpt2_345m
COMMON_TASK_ARGS="--num-layers 24 \
                  --hidden-size 1024 \
                  --num-attention-heads 16 \
                  --seq-length 1024 \
                  --max-position-embeddings 1024 \
                  --fp16 \
                  --vocab-file $VOCAB_FILE"
python tasks/main.py \
       --task $TASK \
       $COMMON_TASK_ARGS \
       --valid-data $VALID_DATA \
       --tokenizer-type GPT2BPETokenizer \
       --merge-file $MERGE_FILE \
       --load $CHECKPOINT_PATH \
       --micro-batch-size 8 \
       --log-interval 10 \
       --no-load-optim \
       --no-load-rng
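The idea behind the token-count adjustment can be sketched as follows: the summed negative log-likelihood of the whole text does not depend on how it is tokenized, so dividing it by the number of original words instead of subword tokens gives a word-level perplexity comparable with prior work. This is a simplified illustration with made-up numbers; the repository's evaluation code may differ in details such as off-by-one token counts.

```python
import math


def perplexities(total_nll, num_subword_tokens, num_words):
    """Return (subword-level, word-level) perplexity for one evaluation text.

    total_nll is the summed negative log-likelihood (in nats) over all subword tokens.
    """
    subword_ppl = math.exp(total_nll / num_subword_tokens)
    # Same total NLL, normalized by the original word count.
    word_ppl = math.exp(total_nll / num_words)
    return subword_ppl, word_ppl


sub, word = perplexities(total_nll=3000.0, num_subword_tokens=1200, num_words=1000)
print(round(sub, 2), round(word, 2))  # 12.18 20.09
```

Equivalently, word-level perplexity equals subword perplexity raised to the power num_subword_tokens / num_words.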
To compute LAMBADA cloze accuracy (the accuracy of predicting the last token given the preceding tokens), we utilize a detokenized, processed version of the LAMBADA dataset.
We use the following command to run LAMBADA evaluation on a 345M parameter model. Note that the --strict-lambada flag should be used to require whole word matching. Ensure that lambada is part of the file path.
TASK="LAMBADA"
VALID_DATA=<lambada path>.json
VOCAB_FILE=gpt2-vocab.json
MERGE_FILE=gpt2-merges.txt
CHECKPOINT_PATH=checkpoints/gpt2_345m
COMMON_TASK_ARGS=<same as those in WikiText Perplexity Evaluation above>
python tasks/main.py \
       --task $TASK \
       $COMMON_TASK_ARGS \
       --valid-data $VALID_DATA \
       --tokenizer-type GPT2BPETokenizer \
       --strict-lambada \
       --merge-file $MERGE_FILE \
       --load $CHECKPOINT_PATH \
       --micro-batch-size 8 \
       --log-interval 10 \
       --no-load-optim \
       --no-load-rng
Further command line arguments are described in the source file main.py.
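As an illustration of what whole word matching means for the metric, the sketch below counts an example as correct only when every subword token of the final word is predicted correctly. strict_cloze_accuracy is a hypothetical helper operating on token IDs you have already extracted; it does not model the non-strict mode or how tasks/main.py obtains the predictions.

```python
def strict_cloze_accuracy(predicted, targets):
    """Fraction of examples whose predicted final-word tokens all match the targets.

    `predicted` and `targets` are parallel lists; each element is the list of
    token IDs covering the last word of one example.
    """
    if len(predicted) != len(targets):
        raise ValueError("predictions and targets must be the same length")
    if not targets:
        return 0.0
    # A partially correct multi-token word counts as a miss under strict matching.
    correct = sum(1 for p, t in zip(predicted, targets) if list(p) == list(t))
    return correct / len(targets)


print(strict_cloze_accuracy([[5, 7], [9], [3, 4]], [[5, 7], [8], [3, 1]]))  # 0.3333333333333333
```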
The following script finetunes the BERT model for evaluation on the RACE dataset. The TRAIN_DATA and VALID_DATA directories contain the RACE dataset as separate .txt files. Note that for RACE, the batch size is the number of RACE queries to evaluate. Since each RACE query has four samples, the effective batch size passed through the model will be four times the batch size specified on the command line.
TRAIN_DATA="data/RACE/train/middle"
VALID_DATA="data/RACE/dev/middle \
            data/RACE/dev/high"
VOCAB_FILE=bert-vocab.txt
PRETRAINED_CHECKPOINT=checkpoints/bert_345m
CHECKPOINT_PATH=checkpoints/bert_345m_race
COMMON_TASK_ARGS="--num-layers 24 \
                  --hidden-size 1024 \
                  --num-attention-heads 16 \
                  --seq-length 512 \
                  --max-position-embeddings 512 \
                  --fp16 \
                  --vocab-file $VOCAB_FILE"
COMMON_TASK_ARGS_EXT="--train-data $TRAIN_DATA \
                      --valid-data $VALID_DATA \
                      --pretrained-checkpoint $PRETRAINED_CHECKPOINT \
                      --save-interval 10000 \
                      --save $CHECKPOINT_PATH \
                      --log-interval 100 \
                      --eval-interval 1000 \
                      --eval-iters 10 \
                      --weight-decay 1.0e-1"
python tasks/main.py \
       --task RACE \
       $COMMON_TASK_ARGS \
       $COMMON_TASK_ARGS_EXT \
       --tokenizer-type BertWordPieceLowerCase \
       --epochs 3 \
       --micro-batch-size 4 \
       --lr 1.0e-5 \
       --lr-warmup-fraction 0.06
The following script finetunes the BERT model for evaluation with the MultiNLI sentence pair corpus. Because the matching tasks are quite similar, the script can be quickly tweaked to work with the Quora Question Pairs (QQP) dataset as well.
TRAIN_DATA="data/glue_data/MNLI/train.tsv"
VALID_DATA="data/glue_data/MNLI/dev_matched.tsv \
            data/glue_data/MNLI/dev_mismatched.tsv"
PRETRAINED_CHECKPOINT=checkpoints/bert_345m
VOCAB_FILE=bert-vocab.txt
CHECKPOINT_PATH=checkpoints/bert_345m_mnli
COMMON_TASK_ARGS=<same as those in RACE Evaluation above>
COMMON_TASK_ARGS_EXT=<same as those in RACE Evaluation above>
python tasks/main.py \
       --task MNLI \
       $COMMON_TASK_ARGS \
       $COMMON_TASK_ARGS_EXT \
       --tokenizer-type BertWordPieceLowerCase \
       --epochs 5 \
       --micro-batch-size 8 \
       --lr 5.0e-5 \
       --lr-warmup-fraction 0.065
The Llama-2 family of models are an open-source set of pretrained and finetuned (for chat) models that have achieved strong results across a wide set of benchmarks. At the time of release, Llama-2 models achieved among the best results for open-source models, and were competitive with the closed-source GPT-3.5 model (see https://arxiv.org/pdf/2307.09288.pdf).
The Llama-2 checkpoints can be loaded into Megatron for inference and finetuning. See documentation here.
The Megatron-Core (MCore) GPTModel family supports advanced quantization algorithms and high-performance inference through TensorRT-LLM.
See Megatron Model Optimization and Deployment for llama2 and nemotron3 examples.
We do not host any datasets for GPT or BERT training, however, we detail their collection so that our results may be reproduced.
We recommend following the Wikipedia data extraction process specified by Google research: "the recommended pre-processing is to download the latest dump, extract the text with WikiExtractor.py, and then apply any necessary cleanup to convert it into plain text."
We recommend using the --json argument when using WikiExtractor, which will dump the Wikipedia data into loose JSON format (one JSON object per line), making it more manageable on the file system and also readily consumable by our codebase. We recommend further preprocessing this JSON dataset with nltk punctuation standardization. For BERT training, use the --split-sentences flag to preprocess_data.py as described above to include sentence breaks in the produced index. If you'd like to use Wikipedia data for GPT training, you should still clean it with nltk/spacy/ftfy, but do not use the --split-sentences flag.
We utilize the publicly available OpenWebText library from jcpeterson and eukaryote31's work to download URLs. We then filter, clean, and deduplicate all downloaded content according to the procedure described in our openwebtext directory. For Reddit URLs corresponding to content up to October 2018, we arrived at approximately 37GB of content.
Megatron training can be bitwise reproducible; to enable this mode, use --deterministic-mode. This means that the same training config run twice in the same hardware and software environment should produce identical model checkpoints, losses, and accuracy metric values (iteration time metrics may vary).
There are currently three known Megatron optimizations that break reproducibility whilst still producing almost identical training runs:
- The specific NCCL algorithm that is used during an all-reduce (as specified by the environment variable NCCL_ALGO) is important. We have tested the following: ^NVLS, Tree, Ring, CollnetDirect, CollnetChain. The code admits the use of ^NVLS, which allows NCCL the choice of non-NVLS algorithms; its choice seems to be stable.
- Flash attention is non-deterministic; do not use --use-flash-attn.
- If using Transformer Engine, you must also set the environment variable NVTE_ALLOW_NONDETERMINISTIC_ALGO=0.
In addition, determinism has only been verified in NGC PyTorch containers 23.12 and newer. If you observe nondeterminism in Megatron training under other circumstances, please open an issue.
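When launching jobs from a Python wrapper, the reproducibility settings above can be collected in one place. The sketch below (helper names are hypothetical) uses ^NVLS as the NCCL_ALGO value only because it is one of the tested options; substitute whichever algorithm you have validated. The Transformer Engine variable is harmless when Transformer Engine is not used.

```python
import os

# Environment settings named in the reproducibility notes above.
DETERMINISTIC_ENV = {
    "NCCL_ALGO": "^NVLS",                        # one of the tested NCCL_ALGO values
    "NVTE_ALLOW_NONDETERMINISTIC_ALGO": "0",     # needed when using Transformer Engine
}


def deterministic_env(base=None):
    """Copy of `base` (default: the current environment) with the settings above applied."""
    env = dict(os.environ if base is None else base)
    env.update(DETERMINISTIC_ENV)
    return env


def deterministic_args(args):
    """Drop --use-flash-attn (non-deterministic) and make sure --deterministic-mode is set."""
    out = [a for a in args if a != "--use-flash-attn"]
    if "--deterministic-mode" not in out:
        out.append("--deterministic-mode")
    return out


# Example: subprocess.run(["torchrun", ..., *deterministic_args(args)], env=deterministic_env())
print(deterministic_args(["--use-flash-attn", "--seed", "1234"]))
# ['--seed', '1234', '--deterministic-mode']
```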
Below are some of the projects where we have directly used Megatron:
- BERT and GPT Studies Using Megatron
- BioMegatron: Larger Biomedical Domain Language Model
- End-to-End Training of Neural Retrievers for Open-Domain Question Answering
- Large Scale Multi-Actor Generative Dialog Modeling
- Local Knowledge Powered Conversational Agents
- MEGATRON-CNTRL: Controllable Story Generation with External Knowledge Using Large-Scale Language Models
- RACE Reading Comprehension Dataset Leaderboard
- Training Question Answering Models From Synthetic Data
- Few-shot Instruction Prompts for Pretrained Language Models to Detect Social Biases
- Exploring the Limits of Domain-Adaptive Training for Detoxifying Large-Scale Language Models
- Using DeepSpeed and Megatron to Train Megatron-Turing NLG 530B, A Large-Scale Generative Language Model
- Multi-Stage Prompting for Knowledgeable Dialogue Generation
- Evaluating Parameter Efficient Learning for Generation
- Shall We Pretrain Autoregressive Language Models with Retrieval? A Comprehensive Study
- InstructRetro: Instruction Tuning post Retrieval-Augmented Pretraining
- An Empirical Study of Mamba-based Language Models