Update/add GPT/Llama universal checkpointing scripts (#391)
This PR adds a Llama universal checkpointing example to examples_deepspeed/universal_checkpointing.

It also includes updates to the README, some other minor changes, and an update to the TensorBoard analysis script.
lekurile authored Jul 29, 2024
1 parent 3afd267 commit 8822a5c
Showing 14 changed files with 486 additions and 26 deletions.
22 changes: 11 additions & 11 deletions examples_deepspeed/universal_checkpointing/README.md
@@ -10,12 +10,12 @@ This folder contains example scripts that demonstrate how to use Universal Check
For ZeRO stage 1, we provide bash scripts for bf16 and fp16 training examples corresponding to steps 1 and 3 above. The step 1 scripts launch a TP=PP=DP=2 training run of 200 iterations that creates a checkpoint every 100 iterations. The step 3 scripts load a universal checkpoint of iteration 100 and resume training with TP=PP=2 and DP=1 for an additional 100 iterations. Users can modify these scripts to try out other save and resume 3D combinations (e.g., save with TP=PP=DP=1 and resume with TP=PP=DP=2; see the sketch after the script list below). TensorBoard logs are created by both the step 1 and step 3 scripts to enable visual inspection of how well the loss curves of the initial and resumed training runs match, especially at iteration 101.

1. bf16:
- * run_bf16.sh: step 1
- * run_universal_bf16.sh: step 3
+ * megatron_gpt/run_bf16.sh: step 1
+ * megatron_gpt/run_universal_bf16.sh: step 3

2. fp16:
- * run_fp16.sh: step 1
- * run_universal_fp16.sh: step 3
+ * megatron_gpt/run_fp16.sh: step 1
+ * megatron_gpt/run_universal_fp16.sh: step 3
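
To try a different save/resume combination, edit the parallelism variables near the top of the step 1 and step 3 scripts before launching. A minimal sketch, assuming the GPT scripts expose the same `TP`/`PP`/`DP` and `LOAD_TP`/`LOAD_PP`/`LOAD_DP` variables as the Llama script added in this commit:

```bash
# Step 1 script: topology used to create the ZeRO checkpoint.
TP=1
PP=1
DP=1

# Step 3 script: new topology for the resumed run. The LOAD_* values must match
# the topology the checkpoint was saved with, since the load path is built from them.
TP=2
PP=2
DP=2
LOAD_TP=1
LOAD_PP=1
LOAD_DP=1
```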

Please note that these scripts should be run from the root folder of the repo (i.e., two levels above this README). For illustration, here are the commands for running the bf16 example.

@@ -41,22 +41,22 @@ NOTE: Make sure to update your `BASE_DATA_PATH` path in the `run_[bf16/fp16].sh`

### Step 1: Create ZeRO checkpoint
```bash
- bash examples_deepspeed/universal_checkpointing/run_bf16.sh
+ bash examples_deepspeed/universal_checkpointing/megatron_gpt/run_bf16.sh
```
- By default the script will create the checkpoints in folder `z1_uni_ckpt/checkpoints/gpt2/z1/bf16/tp2_pp2_dp2_toy`
+ By default the script will create the checkpoints in folder `z1_uni_ckpt/checkpoints/gpt2/z1/bf16/tp2_pp2_dp2_sp1_toy`

### Step 2: Convert ZeRO checkpoint of iteration 100 to Universal format
Assuming the DeepSpeed source code is cloned into the home folder, the following command will generate a universal checkpoint for iteration 100.

```bash
python ${HOME}/DeepSpeed/deepspeed/checkpoint/ds_to_universal.py \
- --input_folder z1_uni_ckpt/checkpoints/gpt2/z1/bf16/tp2_pp2_dp2_toy/global_step100 \
- --output_folder z1_uni_ckpt/checkpoints/gpt2/z1/bf16/tp2_pp2_dp2_toy/global_step100_universal
+ --input_folder z1_uni_ckpt/checkpoints/gpt2/z1/bf16/tp2_pp2_dp2_sp1_toy/global_step100 \
+ --output_folder z1_uni_ckpt/checkpoints/gpt2/z1/bf16/tp2_pp2_dp2_sp1_toy/global_step100_universal
```
Note that we chose to create the universal checkpoint in the same checkpoint folder as the ZeRO checkpoint. This maintains the normal checkpoint folder structure expected by the Megatron-DeepSpeed code, which makes it easy to load universal checkpoints with little or no script or code changes. For clarity, we show below the contents of the checkpoint folder after creation of the universal checkpoint. Note that the conversion script creates the `global_step100_universal` folder and the `latest_universal` file.

```bash
- ls -l z1_uni_ckpt/checkpoints/gpt2/z1/bf16/tp2_pp2_dp2_toy/
+ ls -l z1_uni_ckpt/checkpoints/gpt2/z1/bf16/tp2_pp2_dp2_sp1_toy/
total 48
drwxr-xr-x 2 user group 4096 Oct 21 08:51 global_step100
drwxr-xr-x 3 user group 4096 Oct 21 09:28 global_step100_universal
@@ -69,7 +69,7 @@ drwxr-xr-x 2 user group 4096 Oct 21 09:01 global_step200

### Step 3: Resume training with Universal checkpoint of iteration 100
```bash
- bash examples_deepspeed/universal_checkpointing/run_universal_bf16.sh
+ bash examples_deepspeed/universal_checkpointing/megatron_gpt/run_universal_bf16.sh
```
This resumption script loads the universal checkpoint rather than the ZeRO checkpoint in the folder by passing the `--universal-checkpoint` command-line flag to the main training script (i.e., `pretrain_gpt.py`).
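
A hedged sketch of how the flag reaches the training script (other options elided; argument handling follows the Llama script later in this commit):

```bash
# Append the flag to the option string that is passed through to pretrain_gpt.py.
options="${options} \
    --universal-checkpoint"
deepspeed pretrain_gpt.py ${options}
```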

@@ -85,7 +85,7 @@ The Universal Checkpointing example includes a TensorBoard analysis script that

After Step 3 is completed, the script may be executed as follows:
```bash
- bash examples_deepspeed/universal_checkpointing/run_tb_analysis.sh z1_uni_ckpt
+ bash examples_deepspeed/universal_checkpointing/megatron_gpt/run_tb_analysis_gpt.sh z1_uni_ckpt
```
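
The wrapper invokes `tb_analysis_script.py` (see the Llama analysis wrapper later in this commit); it can also be called directly for a custom event key or plot title. A sketch using only the flags visible in the wrapper scripts in this commit:

```bash
python3 examples_deepspeed/universal_checkpointing/tb_analysis/tb_analysis_script.py \
    --tb_dir z1_uni_ckpt \
    --tb_event_key "lm-loss-training/lm loss" \
    --plot_name "uc_char_training_loss.png" \
    --plot_title "Universal Checkpointing - Training Loss"
```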

The script will output the following `csv` files:
175 changes: 175 additions & 0 deletions examples_deepspeed/universal_checkpointing/llama/run_llama_bf16.sh
@@ -0,0 +1,175 @@
#!/bin/bash
set -ex

DIR=`pwd`
######################################
# Change the below configurations here
BASE_PATH=dataset
DS_CONFIG=${BASE_PATH}/deepspeed.json
DATASET=${BASE_PATH}/my-gpt2_text_document
TOKENIZER_PATH=${BASE_PATH}/llama-7b/tokenizer.model # official llama tokenizer.model

GPUS_PER_NODE=8
MASTER_ADDR=localhost
MASTER_PORT=6000
NNODES=1
NODE_RANK=0

HIDDEN_SIZE=2048 # e.g. llama-13b: 5120
FFN_HIDDEN_SIZE=5504 # e.g. llama-13b: 13824
NUM_LAYERS=24 # e.g. llama-13b: 40
NUM_HEADS=16 # e.g. llama-13b: 40
SEQ=2048

LR_WARMUP_STEPS=2000
WEIGHT_DECAY=0.1
GRAD_CLIP=1

## Activation checkpointing saves GPU memory, but reduces training speed
# activation_checkpoint="true"
activation_checkpoint="false"

ZERO_STAGE=1
DTYPE="bf16"

# 3D parallelism of training
TP=2
PP=2
DP=2
SP=1
WORLD_SIZE=$((TP*PP*DP*SP))
GLOBAL_BATCH=32
MICRO_BATCH=$((GLOBAL_BATCH/WORLD_SIZE))
TRAIN_ITERS=250000
LR=3e-4
MIN_LR=3e-5

# Debug
DEBUG_MODE=1
if [[ $DEBUG_MODE == 1 ]]; then
    EXIT_INTERVAL=200
    SIZE_TAG="toy"
else
    EXIT_INTERVAL=$TRAIN_ITERS
    SIZE_TAG="big"
fi

# 3D parallelism of checkpoint to load
LOAD_TP=$TP
LOAD_PP=$PP
LOAD_DP=$DP
LOAD_SP=$SP
RUN_TAG="save"


EXP_DIR="z${ZERO_STAGE}_uni_ckpt"
CHECKPOINT_PATH=${EXP_DIR}/checkpoints/llama/z${ZERO_STAGE}/$DTYPE/tp${TP}_pp${PP}_dp${DP}_sp${SP}_${SIZE_TAG}
LOAD_CHECKPOINT_PATH=${EXP_DIR}/checkpoints/llama/z${ZERO_STAGE}/$DTYPE/tp${LOAD_TP}_pp${LOAD_PP}_dp${LOAD_DP}_sp${LOAD_SP}_${SIZE_TAG}
LOG_DIR="${EXP_DIR}/tensorboard/llama/$DTYPE/tp${TP}_pp${PP}_dp${DP}_sp${SP}_hd${HIDDEN_SIZE}_nl${NUM_LAYERS}_gbsz${GLOBAL_BATCH}_mbsz${MICRO_BATCH}_z${ZERO_STAGE}_LR_${LR}_${MIN_LR}_${DTYPE}_${SIZE_TAG}_${RUN_TAG}"
mkdir -p $LOG_DIR

# Below configuration required for llama model as per llama paper
# --no-query-key-layer-scaling \
# --attention-dropout 0 \
# --hidden-dropout 0 \
# --use-rotary-position-embeddings \
# --untie-embeddings-and-output-weights \
# --swiglu \
# --normalization rmsnorm \
# --disable-bias-linear \
######################################

cat <<EOT > $DS_CONFIG
{
  "train_batch_size" : $GLOBAL_BATCH,
  "train_micro_batch_size_per_gpu": $MICRO_BATCH,
  "steps_per_print": 1,
  "zero_optimization": {
    "stage": $ZERO_STAGE
  },
  "bf16": {
    "enabled": true
  },
  "wall_clock_breakdown" : false
}
EOT

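# Assemble the DeepSpeed-related command line arguments.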
ds_args=""
ds_args=" --deepspeed ${ds_args}"
ds_args=" --deepspeed_config=$DS_CONFIG ${ds_args}"
ds_args=" --zero-stage=$ZERO_STAGE ${ds_args}"

if [ "${activation_checkpoint}" = "true" ]; then
ds_args="--deepspeed-activation-checkpointing ${ds_args}"

## old argument for recomputing the transformer layer
# ds_args="--checkpoint-activations ${ds_args}"

## new argument for recomputing the transformer layer
ds_args="--recompute-granularity full --recompute-method uniform ${ds_args}"
## new argument for recomputing only the attention layer
# ds_args="--recompute-granularity selective ${ds_args}"
fi

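# Pipeline parallelism is not supported with ZeRO stages above 1, so disable it.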
if [[ ${ZERO_STAGE} -gt 1 ]]; then
    ds_args="${ds_args} \
        --no-pipeline-parallel"
fi

options="\
--tensor-model-parallel-size $TP \
--pipeline-model-parallel-size $PP \
--ds-sequence-parallel-size $SP \
--num-layers $NUM_LAYERS \
--hidden-size $HIDDEN_SIZE \
--ffn-hidden-size $FFN_HIDDEN_SIZE \
--num-attention-heads $NUM_HEADS \
--micro-batch-size $MICRO_BATCH \
--global-batch-size $GLOBAL_BATCH \
--seq-length $SEQ \
--max-position-embeddings $SEQ \
--train-iters $TRAIN_ITERS \
--save ${CHECKPOINT_PATH} \
--load ${LOAD_CHECKPOINT_PATH} \
--data-path $DATASET \
--data-impl mmap \
--tokenizer-type GPTSentencePieceTokenizer \
--tokenizer-model $TOKENIZER_PATH \
--split 949,50,1 \
--distributed-backend nccl \
--lr $LR \
--lr-decay-style cosine \
--min-lr $MIN_LR \
--weight-decay $WEIGHT_DECAY \
--clip-grad $GRAD_CLIP \
--lr-warmup-iters $LR_WARMUP_STEPS \
--optimizer adam \
--adam-beta1 0.9 \
--adam-beta2 0.95 \
--log-interval 1 \
--save-interval 100 \
--eval-interval 10 \
--eval-iters 40 \
--exit-interval ${EXIT_INTERVAL} \
--${DTYPE} \
--no-query-key-layer-scaling \
--attention-dropout 0 \
--hidden-dropout 0 \
--use-rotary-position-embeddings \
--untie-embeddings-and-output-weights \
--swiglu \
--normalization rmsnorm \
--disable-bias-linear \
--tensorboard-dir $LOG_DIR \
$ds_args
"

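# Single-node launch through the DeepSpeed launcher; Llama is trained via pretrain_gpt.py using the flags above.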
WORKER_STR="--num_nodes 1 --num_gpus $WORLD_SIZE"
run_cmd="deepspeed --master_port 29700 $WORKER_STR ${DIR}/pretrain_gpt.py $@ ${options}"

echo ${options}
echo ${run_cmd}
eval ${run_cmd}
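
As with the GPT examples, this script is assumed to be launched from the repository root, e.g.:

```bash
bash examples_deepspeed/universal_checkpointing/llama/run_llama_bf16.sh
```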
@@ -0,0 +1,27 @@
#!/bin/bash
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

OUTPUT_PATH=$1

if [ "$OUTPUT_PATH" == "" ]; then
OUTPUT_PATH="z1_uni_ckpt"
fi

# Training Loss
python3 examples_deepspeed/universal_checkpointing/tb_analysis/tb_analysis_script.py \
    --tb_dir $OUTPUT_PATH \
    --tb_event_key "lm-loss-training/lm loss" \
    --plot_name "uc_char_training_loss.png" \
    --plot_title "Llama 7B Universal Checkpointing - Training Loss"

# Validation Loss
python3 examples_deepspeed/universal_checkpointing/tb_analysis/tb_analysis_script.py \
    --tb_dir $OUTPUT_PATH \
    --tb_event_key "lm-loss-validation/lm loss validation" \
    --csv_name "val_" \
    --plot_name "uc_char_validation_loss.png" \
    --plot_title "Llama 7B Universal Checkpointing - Validation Loss" \
    --plot_y_label "Validation LM Loss"