Merak is a distributed deep learning training framework with automated 3D parallelism. It can automatically slice, allocate and train a DNN model, making the development of giant models fast and easy. The current version of Merak is adapted to PyTorch.
With the rapid growth in the size of DNN models, sophisticated distributed training solutions for giant models are required. However, the state-of-the-art technique for giant model pretraining, 3D parallelism (data parallelism, tensor model parallelism and pipeline model parallelism), requires considerable expertise and model rewriting.
The motivation of Merak is to simplify the use of 3D parallelism, so that users need to add only as little code as they would with the popular Hugging Face transformers trainer to achieve complex 3D parallelism.
To install Merak:
# Ensure PyTorch >= 1.10 is installed first, since it requires an extra index URL
# (check https://pytorch.org/get-started/locally/)
# Ensure pybind11 is installed
git clone http://hpdl-group/Merak.git
cd Merak
pip install .
To use Merak, make the following modifications to your programs:
- Import Merak before importing transformers and torch.
- Set the degrees of data parallelism, tensor model parallelism and pipeline model parallelism, then run Merak.init(dp, tp, pp) to initialize Merak.
- Set the training arguments with MerakArguments, a replacement for the transformers trainer arguments.
- Set up MerakTrainer, a replacement for the transformers trainer.
Example usage (see the Merak examples directory for full training examples):
import Merak
from Merak import MerakArguments, MerakTrainer
# Init Merak with degrees of 3D parallelism.
dp = 2
tp = 1
pp = 2
Merak.init(dp, tp, pp)
# Set training arguments with MerakArguments.
training_args = MerakArguments(
    output_dir='./path/to/save'
)
# Set our Trainer
trainer = MerakTrainer(
    do_train=...,
    model=...,
    args=training_args,
    train_data=...,
    eval_data=...,
)
# Do train
trainer.train()
For more details, refer to our API documentation. For more detailed usage, see the transformers tutorials and its trainer examples.
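As a rough illustration of how the three degrees combine, the example above (dp = 2, tp = 1, pp = 2) implies dp * tp * pp = 4 processes in the usual 3D-parallel layout. The sketch below maps each global rank to (data, pipeline, tensor) coordinates and derives the process groups. The rank ordering (tensor-parallel ranks contiguous, then data parallel, then pipeline, similar to Megatron-LM's convention) and the helper names build_3d_grid and groups are assumptions for illustration, not Merak's documented API.

```python
from typing import Dict, List, Tuple

Coord = Tuple[int, int, int]  # (dp_rank, pp_rank, tp_rank)


def build_3d_grid(dp: int, tp: int, pp: int) -> Dict[int, Coord]:
    """Map every global rank to its (dp, pp, tp) coordinates.

    Assumed ordering: tensor-parallel ranks are contiguous (tp varies fastest),
    then data parallel, then pipeline, i.e.
    rank = pp_rank * (dp * tp) + dp_rank * tp + tp_rank.
    """
    grid = {}
    for rank in range(dp * tp * pp):
        tp_rank = rank % tp
        dp_rank = (rank // tp) % dp
        pp_rank = rank // (dp * tp)
        grid[rank] = (dp_rank, pp_rank, tp_rank)
    return grid


def groups(grid: Dict[int, Coord], axis: int) -> List[List[int]]:
    """Ranks that differ only along `axis` (0=dp, 1=pp, 2=tp) form one group."""
    found: Dict[Tuple[int, ...], List[int]] = {}
    for rank, coord in grid.items():
        key = tuple(c for i, c in enumerate(coord) if i != axis)
        found.setdefault(key, []).append(rank)
    return sorted(found.values())


grid = build_3d_grid(dp=2, tp=1, pp=2)
print(groups(grid, axis=1))  # pipeline groups: [[0, 2], [1, 3]]
print(groups(grid, axis=0))  # data-parallel groups: [[0, 1], [2, 3]]
```

Under this assumed ordering, ranks 0 and 2 form one pipeline (stage 0 and stage 1), and ranks 0 and 1 hold data-parallel replicas of the same stage.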
In Merak's pipeline model parallelism, we use torch.fx and transformers.utils.fx to trace a model into a GraphModule. A graph sharding algorithm then splits the traced graph evenly into a sequence of GraphModules. For example, in the GPT model, each attention block and MLP block becomes an individual module. Next, a high-performance runtime engine allocates the module sequence and executes the training procedure.
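The even split of a traced graph can be pictured as a contiguous partition problem: given an estimated cost for each module in the sequence (for example its parameter count), cut the sequence into pp stages so that the most expensive stage is as cheap as possible. The sketch below uses the classic linear-partition dynamic program as a stand-in; it is not Merak's own sharding algorithm, and the function name partition_balanced and the cost values are hypothetical.

```python
from typing import List


def partition_balanced(costs: List[float], num_stages: int) -> List[List[int]]:
    """Split module indices into `num_stages` contiguous, non-empty stages,
    minimizing the largest per-stage cost (linear-partition DP)."""
    n = len(costs)
    if not 1 <= num_stages <= n:
        raise ValueError("need 1 <= num_stages <= number of modules")
    prefix = [0.0]
    for c in costs:
        prefix.append(prefix[-1] + c)
    inf = float("inf")
    # best[k][i]: minimal max-stage cost when the first i modules use k stages
    best = [[inf] * (n + 1) for _ in range(num_stages + 1)]
    cut = [[0] * (n + 1) for _ in range(num_stages + 1)]
    best[0][0] = 0.0
    for k in range(1, num_stages + 1):
        for i in range(k, n + 1):
            for j in range(k - 1, i):  # the last stage holds modules j..i-1
                cand = max(best[k - 1][j], prefix[i] - prefix[j])
                if cand < best[k][i]:
                    best[k][i] = cand
                    cut[k][i] = j
    # Walk the cut table backwards to recover the stage boundaries.
    stages, i = [], n
    for k in range(num_stages, 0, -1):
        j = cut[k][i]
        stages.append(list(range(j, i)))
        i = j
    return stages[::-1]


stages = partition_balanced([1, 2, 3, 4, 5, 6, 7, 8, 9], num_stages=3)
print(stages)  # [[0, 1, 2, 3, 4], [5, 6], [7, 8]]
```

In this example the stage costs are 15, 13 and 17, so no stage exceeds 17; an exact DP is cheap here because a traced model yields only a few hundred modules at most.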
For tensor model parallelism, we use a feature dict to map parameters to Megatron-LM's ColumnParallelLinear and RowParallelLinear. We provide default feature dicts for common models in the transformers package. In addition, users can easily define a feature dict through our API to achieve tensor model parallelism.
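A feature dict can be thought of as a lookup from module names to a parallel style. The sketch below assumes a hypothetical dict keyed by module-name suffixes (the names FEATURE_DICT, parallel_style and shard_weight are illustrative, not Merak's defaults or API) and shows how a weight in the torch.nn.Linear layout [out_features][in_features] would be sliced for one tensor-parallel rank: ColumnParallelLinear splits the output dimension, RowParallelLinear splits the input dimension. Plain Python lists stand in for tensors.

```python
from typing import Dict, List, Optional

Matrix = List[List[float]]

# Hypothetical feature dict: module-name suffix -> Megatron-LM style.
# The names are illustrative only; they are not Merak's default dicts.
FEATURE_DICT: Dict[str, str] = {
    "attn.qkv_proj": "col",  # ColumnParallelLinear: shard output features
    "attn.out_proj": "row",  # RowParallelLinear: shard input features
    "mlp.fc_in": "col",
    "mlp.fc_out": "row",
}


def parallel_style(module_name: str, feature_dict: Dict[str, str]) -> Optional[str]:
    """Return 'col', 'row', or None (keep the Linear replicated)."""
    for suffix, style in feature_dict.items():
        if module_name.endswith(suffix):
            return style
    return None


def shard_weight(weight: Matrix, style: Optional[str], tp: int, rank: int) -> Matrix:
    """Slice a [out_features][in_features] weight for one tensor-parallel rank."""
    out_features, in_features = len(weight), len(weight[0])
    if style == "col":  # split the output dimension (rows of the weight)
        if out_features % tp:
            raise ValueError("out_features must be divisible by tp")
        size = out_features // tp
        return [row[:] for row in weight[rank * size:(rank + 1) * size]]
    if style == "row":  # split the input dimension (columns of the weight)
        if in_features % tp:
            raise ValueError("in_features must be divisible by tp")
        size = in_features // tp
        return [row[rank * size:(rank + 1) * size] for row in weight]
    return [row[:] for row in weight]  # not in the dict: full replica
```

Pairing a column-parallel layer with a following row-parallel layer (as in the attention and MLP entries) is the usual Megatron-LM pattern, since the column-sharded outputs feed the row-sharded inputs directly.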
- As easy to use as single-GPU training.
For giant models in transformers: our implementation is based on the transformers.trainer class. With a few lines of code to set the parallel degrees, training a model with 3D parallelism can be as easy as single-GPU training.
For models not in transformers: as long as a model is traceable by torch.fx and trainable by transformers.trainer, it can be trained by Merak as well.
- Sharding a giant model on a single worker.
Nowadays, training, or even just loading, a DNN model on a single GPU can easily exceed the device's memory capacity. Before the model is initialized in memory, we create proxy layers for torch.nn.Linear layers. Proxy layers do not own parameters but can participate in model tracing and graph sharding as usual. This makes it possible for a single worker to hold a whole giant DNN model and execute graph sharding swiftly.
- Auto dataloader for 3D parallelism.
When training a model with pipeline parallelism, different stages require different data, and some stages do not load data at all. We therefore try to make each stage get only the data it needs, without loading the full dataset.
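The per-stage data requirement can be sketched with a simple rule, assumed here for illustration rather than taken from Merak: the first pipeline stage consumes the model inputs, the last stage consumes the labels for computing the loss, and middle stages only receive activations from their neighbours. The field names input_ids and labels and the helpers fields_for_stage and filter_batch are hypothetical; real models may need extra fields (such as attention masks) on more stages.

```python
from typing import Dict, List, Set

Batch = Dict[str, List[int]]


def fields_for_stage(stage: int, num_stages: int) -> Set[str]:
    """Assumed rule: inputs go to the first stage, labels to the last one."""
    needed: Set[str] = set()
    if stage == 0:
        needed.add("input_ids")  # the first stage runs the embedding
    if stage == num_stages - 1:
        needed.add("labels")  # the last stage computes the loss
    return needed  # middle stages get an empty set: they load no data


def filter_batch(batch: Batch, stage: int, num_stages: int) -> Batch:
    """Keep only the fields this pipeline stage actually consumes."""
    keep = fields_for_stage(stage, num_stages)
    return {name: value for name, value in batch.items() if name in keep}
```

With four stages, stage 0 keeps only input_ids, stages 1 and 2 get an empty batch, and stage 3 keeps only labels; with a single stage both fields are kept.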
To further boost training performance, our efficient 3D parallel runtime engine introduces several novel techniques for better integration of training resources.
- Shifted critical path schedule.
We introduce a shifted critical path pipeline schedule to reduce pipeline bubbles. The critical path is the sequence of operations that determines the pipeline latency. Our schedule shortens the critical path by dropping redundant recomputation and adjusting the order and start times of operations.
- Stage-aware recomputation.
In addition, we observe that memory can be used more efficiently by applying activation recomputation in a fine-grained way. Hence, we develop a stage-aware recomputation method to make better use of worker memory: according to the pipeline stage rank and pipeline depth, it uses idle memory to reduce recomputation, thereby speeding up training.
- Sub-pipelined TMP.
Furthermore, we improve the concurrency of communication and computation in tensor model parallelism (TMP) with our sub-pipelined TMP approach, which splits each microbatch into sub-microbatches and pipelines them so that communication and computation in TMP overlap.
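The intuition behind stage-aware recomputation can be illustrated with a simplified memory model, assumed here rather than taken from Merak: under a one-forward-one-backward (1F1B) schedule, stage s of a pipeline with depth p keeps activations for up to p - s microbatches, so later stages have spare memory. Each layer costs act_full memory per microbatch if its activations are kept, or act_ckpt if only its input is checkpointed and the layer is recomputed. The sketch picks, per stage, the largest number of layers that can skip recomputation within a memory budget; the function name, parameters and numbers are all hypothetical.

```python
def layers_without_recompute(stage: int, depth: int, num_layers: int,
                             mem_budget: float, act_full: float,
                             act_ckpt: float) -> int:
    """Largest number of layers on `stage` that can keep full activations.

    Simplified 1F1B model (an assumption): stage `stage` of a pipeline with
    `depth` stages holds activations for `depth - stage` microbatches at once.
    """
    in_flight = depth - stage
    for kept in range(num_layers, -1, -1):
        per_microbatch = kept * act_full + (num_layers - kept) * act_ckpt
        if in_flight * per_microbatch <= mem_budget:
            return kept
    raise ValueError("budget too small even with full recomputation")


plan = [layers_without_recompute(s, depth=4, num_layers=8, mem_budget=40.0,
                                 act_full=2.0, act_ckpt=0.5) for s in range(4)]
print(plan)  # [4, 6, 8, 8]
```

With the same budget on every worker, the first stage must recompute half of its layers while the last two stages recompute none, which is the kind of stage-dependent trade-off the method exploits.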
Please refer to our paper for more technical details and experiment results.
The Merak source code is based on the transformers trainer, DeepSpeed and Megatron-LM repositories.
@article{lai2022merak,
author={Lai, Zhiquan and Li, Shengwei and Tang, Xudong and Ge, Keshi and Liu, Weijie and Duan, Yabo and Qiao, Linbo and Li, Dongsheng},
journal={IEEE Transactions on Parallel and Distributed Systems},
title={Merak: An Efficient Distributed DNN Training Framework With Automated 3D Parallelism for Giant Foundation Models},
year={2023},
volume={34},
number={5},
pages={1466-1478},
doi={10.1109/TPDS.2023.3247001}
}
Pipeline partition algorithm AutoPipe
@INPROCEEDINGS{liu2022autopipe,
author={Liu, Weijie and Lai, Zhiquan and Li, Shengwei and Duan, Yabo and Ge, Keshi and Li, Dongsheng},
booktitle={2022 IEEE International Conference on Cluster Computing (CLUSTER)},
title={AutoPipe: A Fast Pipeline Parallelism Approach with Balanced Partitioning and Micro-batch Slicing},
year={2022},
volume={},
number={},
pages={301-312},
doi={10.1109/CLUSTER51413.2022.00042}
}