# 1.pre_train_math_moe.py
# doc: https://docs.google.com/presentation/d/1FlkJ98hSGnZ4kjePaxCJvNkM9zvlkYKIZ6xu8TjJGXE/edit?usp=sharing
from tabulate import tabulate
import math
# Mixtral-8x7B
NHIDDEN=4096
NLAYERS=32
NHEAD=8
SEQ_LEN=4096
VOCAB_SIZE=32000
EXPERT_NUM=8
EXPERT_NUM_LIVE=2
NODE=1
GPU_PER_NODE=8
GPU_MEMORY=80
BATCH_SIZE=1
BTOKEN=0.7 # training tokens, in billions
TFLOPS=140 # [130-170]
Gradient_checkpointing=True # gradient checkpointing technique: https://medium.com/tensorflow/fitting-larger-networks-into-memory-583e3c758ff9
NGPU=GPU_PER_NODE*NODE
h=NHIDDEN
l=NLAYERS
s=SEQ_LEN
v=VOCAB_SIZE
b=BATCH_SIZE
a=NHEAD
expert_num=EXPERT_NUM
expert_num_live=EXPERT_NUM_LIVE
batch_expert_num_dict={'1':2,'2':4,'4':6, '8':8} # expected number of distinct experts hit per layer for a given batch size
expert_num_live_batch=batch_expert_num_dict[str(b)] # keyed by batch size
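# Example of the live-expert assumption used below (an estimate built into this script, not a
# measured value): with the default BATCH_SIZE=1 the lookup gives 2, i.e. top-2 routing is assumed
# to touch only 2 of the 8 expert MLPs per layer, so the "training-time" block is much smaller than
# the full block; larger batches are assumed to activate more distinct experts per layer.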
def main():
    print('-----------Model_Size and GPU_Mem-----------')
    emb=(v*h+s*h)/10**9 # vocab embedding (v*h) plus position embedding (s*h), in billions of params
    mlp=8*h**2 # parameters of one expert's feed-forward block, as modeled here (two 4h*h matrices)
    routing=expert_num*h # router/gating weights
    attn_and_norm=4*h**2+6*h # attention projections plus norm parameters
    blk=(mlp*expert_num+attn_and_norm+routing)/10**9 # one full transformer block (all experts), in billions
    blk_train=(mlp*expert_num_live_batch+attn_and_norm+routing)/10**9 # one block counting only the live experts
    extra_norm=h/10**9 # final norm
    model=l*blk+emb+extra_norm
    model_train=l*blk_train+emb+extra_norm
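    # Worked example with the defaults above (an estimate from this script's own formulas, not the
    # official Mixtral parameter count): mlp = 8*4096^2 ≈ 0.134B per expert, so one block is about
    # 8*0.134B + 0.067B (attn/norm) ≈ 1.14B; 32 blocks ≈ 36.5B, plus ≈ 0.15B of embeddings, giving
    # model ≈ 36.7B in total, while model_train counts only the live experts per block.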
    single_mem=GPU_MEMORY-1 # leave ~1GB of headroom per GPU
    row={"Model size/B": round(model, 2), "ratio(NHIDDEN/NLAYERS)":int(h/l), "Usable_mem_per_GPU/GB": round(single_mem, 2)}
    print(tabulate([row], headers="keys", tablefmt="pretty"))
    print('-----------With Mixed Precision(bf16)-----------')
    print(f'-----Memory_reference_indicator(Batch_size={b})-----')
    input=(b*s*h)/10**9 # input hidden states, in billions of values
    activation_per_MLP = 19*s*b*h # rough activation count for one expert MLP
    activation_per_layer = b*s*h*15 + activation_per_MLP*expert_num_live_batch + 5*s*s*b*a # attention/norm activations + live-expert MLPs + attention scores
    activation=math.sqrt((activation_per_layer*l)/10**9) if Gradient_checkpointing else (activation_per_layer*l)/10**9
    activation_b1=math.sqrt((activation_per_layer*l/b)/10**9) if Gradient_checkpointing else (activation_per_layer*l/b)/10**9
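    # The sqrt() above is a coarse heuristic for gradient (activation) checkpointing: instead of
    # keeping every layer's activations, only a subset of checkpoints is stored and the rest are
    # recomputed during the backward pass (see the Medium article linked at the top), so the stored
    # activation footprint is modeled here as the square root of the full amount.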
    input_all=input+activation
    train_memory_factor=2+2+4*3 # bytes per param: bf16 weights (2) + bf16 grads (2) + fp32 Adam states (4*3: master weights, momentum, variance)
    train_memory = 2*model + (2+4*3)*model_train # full weight copy in bf16; grads and optimizer states only for the live-expert params
    total_memory=round(train_memory+input_all*2, 2)
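    # Rough numbers under the defaults (a sketch based on the formulas above, assuming ~36.7B total
    # params): bf16 inference weights alone need about 2*36.7 ≈ 73GB, and full mixed-precision
    # training costs roughly 16 bytes per trainable parameter on the optimizer-holding ranks, which
    # is why the ZeRO partitioning strategies compared below matter.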
    list_of_dicts=[
        {'Module': 'emb', 'Size/B':round(emb, 2), 'Eval_memory/GB': round(emb*2, 2), 'Train_memory/GB': round(emb*train_memory_factor, 2)},
        {'Module': 'one_layer', 'Size/B':round(blk, 2), 'Eval_memory/GB': round(blk*2, 2), 'Train_memory/GB': round(blk*train_memory_factor, 2)},
        {'Module': 'input', 'Size/B':round(input, 2), 'Eval_memory/GB': round(input*2, 2), 'Train_memory/GB': round(input*2, 2)},
        {'Module': 'activation(batchsize=1)', 'Size/B':round(activation_b1, 2), 'Eval_memory/GB': round(activation_b1*2, 2), 'Train_memory/GB': round(activation_b1*2, 2)},
        {'Module': 'ALL', 'Size/B':round(model+input_all, 2), 'Eval_memory/GB': round(model*2+input_all*2, 2), 'Train_memory/GB': total_memory}
    ]
    print(tabulate(list_of_dicts, headers="keys", tablefmt="grid"))
    print(f'-----Strategy_reference_indicator(Batch_size={b})-----')
    # ZeRO-1 shards only the optimizer states, ZeRO-2 also shards gradients, ZeRO-3 additionally shards the weights.
    train_memory_factor_zero1=2+2+(4*3)/NGPU
    train_memory_zero1=2*model+(2+(4*3)/NGPU)*model_train
    train_memory_factor_zero2=2+(2+4*3)/NGPU
    train_memory_zero2=2*model+(2+4*3)/NGPU*model_train
    train_memory_factor_zero3=(2+2+4*3)/NGPU
    train_memory_zero3=2*model/NGPU+(2+(4*3))/NGPU*model_train
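    # Example of the per-GPU accounting (a sketch under the 16-bytes/param assumption): with NGPU=8,
    # ZeRO-1 keeps 2+2=4 bytes/param locally and shards the 12 bytes of optimizer state down to 1.5;
    # ZeRO-2 shards grads plus optimizer (14 -> 1.75) while keeping the 2-byte weights; ZeRO-3
    # shards everything (16 -> 2 bytes/param), at the cost of extra communication.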
    list_of_dicts=[
        {'Strategy': 'Zero1','Eval_memory_per_gpu/GB': round(model*2, 2), 'Train_memory_per_gpu/GB': round(train_memory_zero1+input_all*2, 2)},
        {'Strategy': 'Zero2','Eval_memory_per_gpu/GB': round(model*2, 2), 'Train_memory_per_gpu/GB': round(train_memory_zero2+input_all*2, 2)},
        {'Strategy': 'Zero3','Eval_memory_per_gpu/GB': round(model*2/NGPU, 2), 'Train_memory_per_gpu/GB': round(train_memory_zero3+input_all*2, 2)},
    ]
    print(tabulate(list_of_dicts, headers="keys", tablefmt="grid"))
    print('---------------------Strategy_Recommendation---------------------')
    training_days=round(BTOKEN*1e9*8*model*1e9/(NGPU*TFLOPS*1e12*60*60*24),2) # total FLOPs approximated as 8 * params * tokens (forward + backward + activation recomputation), see https://arxiv.org/pdf/2104.04473.pdf
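    # Worked example with the defaults (BTOKEN=0.7, model ≈ 36.7B, 8 GPUs at an assumed 140 TFLOPS
    # each): 0.7e9 tokens * 8 * 36.7e9 params ≈ 2.1e20 FLOPs, while 8 * 140e12 FLOP/s sustains about
    # 9.7e19 FLOPs per day, so the estimate comes out to roughly 2 days of training.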
    if total_memory>single_mem*NGPU:
        print(f'Minimal_memory_needed: {total_memory}GB, Available_memory: {single_mem}*{NGPU}={single_mem*NGPU}GB')
        print("You may try the ZeRO-Infinity strategy (Zero3 + offload_param(nvme) + offload_optimizer(nvme)). If that is still not enough, please add more GPUs.")
        return
    if abs(list_of_dicts[0]['Train_memory_per_gpu/GB']-single_mem)<3 or list_of_dicts[0]['Train_memory_per_gpu/GB']<single_mem:
        if list_of_dicts[0]['Train_memory_per_gpu/GB']<single_mem:
            print('Recommended_Strategy:')
            list_of_dicts=[
                {'Zero': 'Zero1','DP': NGPU, 'TP': 1, 'PP':1, 'Train_memory_per_gpu/GB': round(train_memory_zero1+input_all*2, 2), 'Training_days': training_days},
            ]
            print(tabulate(list_of_dicts, headers="keys", tablefmt="grid"))
            print('Please find the best batch size by adjusting BATCH_SIZE')
            return
        else:
            print('Recommended_Strategy:')
            list_of_dicts=[
                {'Zero': 'Zero1+offload','DP': NGPU, 'TP': 1, 'PP':1, 'Train_memory_per_gpu/GB': round(train_memory_zero1+input_all*2, 2), 'Training_days': training_days},
            ]
            print(tabulate(list_of_dicts, headers="keys", tablefmt="grid"))
            print('Please find the best batch size by adjusting BATCH_SIZE')
            return
    elif list_of_dicts[1]['Train_memory_per_gpu/GB']<single_mem:
        print('Recommended_Strategy:')
        list_of_dicts=[
            {'Zero': 'Zero2','DP': NGPU, 'TP': 1, 'PP':1, 'Train_memory_per_gpu/GB': round(train_memory_zero2+input_all*2, 2), 'Training_days': training_days},
        ]
        print(tabulate(list_of_dicts, headers="keys", tablefmt="grid"))
        print('Please find the best batch size by adjusting BATCH_SIZE')
        return
    print("A pure Zero1 or Zero2 strategy will not fit.")
    single_node_mem=GPU_PER_NODE*single_mem
    if total_memory < single_node_mem:
        TP=GPU_PER_NODE # keep tensor parallelism inside one node
        DP=NGPU//TP
        train_memory_factor_zero1=2+2+(4*3)/DP
        print('Recommended_Strategy:')
        list_of_dicts=[
            {'Zero': 'Zero1+TP','DP': DP, 'TP': TP, 'PP':1, 'Train_memory_per_gpu/GB': round(train_memory_zero1/TP+input_all*2/TP, 2), 'Training_days': training_days},
            {'Zero': 'Zero3+(offload)','DP': NGPU, 'TP': 1, 'PP':1, 'Train_memory_per_gpu/GB': round(train_memory_zero3+input_all*2, 2), 'Training_days': training_days},
        ]
        print(tabulate(list_of_dicts, headers="keys", tablefmt="grid"))
        print('Please find the best batch size by adjusting BATCH_SIZE')
        return
    else:
        PP=math.ceil(total_memory/single_node_mem) # enough pipeline stages to spread the model across nodes
        TP=GPU_PER_NODE
        DP=1 if NGPU/TP/PP < 1 else int(NGPU/TP/PP)
        if DP==1:
            PP=int(NGPU/TP)
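        # Hypothetical example of this sizing rule (numbers invented purely for illustration): with
        # single_node_mem = 8*79 = 632GB, a total_memory estimate of ~1200GB would give
        # PP = ceil(1200/632) = 2 pipeline stages, TP = 8 inside each node, and DP from whatever
        # GPUs remain (forced to at least 1).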
        train_memory_factor_zero1=2+2+(4*3)/DP
        print('Recommended_Strategy:')
        list_of_dicts=[
            {'Zero': 'Zero1+TP+PP','DP': DP, 'TP': TP, 'PP':PP, 'Train_memory_per_gpu/GB': round(train_memory_zero1/PP/TP+input_all*2/TP, 2), 'Training_days': training_days},
            {'Zero': 'Zero3+(offload)','DP': NGPU, 'TP': 1, 'PP':1, 'Train_memory_per_gpu/GB': round(train_memory_zero3+input_all*2, 2), 'Training_days': training_days},
        ]
        print(tabulate(list_of_dicts, headers="keys", tablefmt="grid"))
        print('Please find the best batch size by adjusting BATCH_SIZE')
        return
if __name__=="__main__":
    main()
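
# Usage (assuming the only third-party dependency, tabulate, is installed, e.g. `pip install tabulate`):
# edit the constants at the top for your model/cluster and run `python 1.pre_train_math_moe.py`.
# The script prints the estimated model size, eval/train memory per module, per-GPU memory under
# ZeRO-1/2/3, and a recommended parallelism strategy with a rough training-time estimate.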