#!/usr/bin/env python
# coding: utf-8
# Convert a Megatron-LM checkpoint into TAPE's pretrained-model format.
import argparse
from copy import deepcopy

import torch
parser = argparse.ArgumentParser(description="Convert a Megatron-LM checkpoint into TAPE's pretrained-model format.")
parser.add_argument('-src', '--src', type=str, help='path to the Megatron-LM checkpoint')
parser.add_argument('-dst', '--dst', type=str, help="path to TAPE's untrained checkpoint")
parser.add_argument('-out', '--out', type=str, help='path where the converted model is saved')
parser.add_argument('-dtp', '--dtype', type=str, help='data type of the destination checkpoint, default fp32; '
                                                      'to override it, pass an expression such as `torch.float64`')
parser.add_argument('-hidden', '--hidden_dim', type=int, default=1024, help='hidden size of the encoder layers')
parser.add_argument('-heads', '--num_heads', type=int, default=16, help='number of attention heads for each attention '
                                                                        'layer in the ProteinBert encoder')
parser.add_argument('-layers', '--num_layers', type=int, default=16, help='number of hidden layers in the ProteinBert encoder')
args = parser.parse_args()
heads = args.num_heads
hidden_dim = args.hidden_dim
heads_dim = hidden_dim // heads
layers = args.num_layers
# data type for the converted checkpoint
if args.dtype is None:
    DTYPE = torch.float32
else:
    # resolve the dtype name via getattr rather than eval(); accepts `float64` or `torch.float64`
    DTYPE = getattr(torch, args.dtype.split('.')[-1])
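# Example invocation (a sketch; the .pt file names are hypothetical placeholders):
#   python megatron-converter.py --src megatron_ckpt.pt --dst tape_init.pt \
#       --out converted.pt --dtype torch.float16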
# load the Megatron checkpoint; the weights live under ['model']['language_model']
lm = torch.load(args.src)['model']['language_model']
emb = lm['embedding']
trans = lm['transformer']
# TAPE state dict that receives the converted weights
tape = deepcopy(torch.load(args.dst))
def shape_check(dst, src):
    assert dst.shape == src.shape, f"shape mismatch: {tuple(dst.shape)} vs {tuple(src.shape)}"
# Part 1: embeddings
# 1.1 word embeddings
# Megatron vocab order: [PAD] [MASK] [CLS] [SEP] [UNK] [unused1] A ...
# TAPE has no [unused1] slot, so keep rows 0-4 and 6-30 and drop row 5
reserved_toks = emb['word_embeddings']['weight'][0:5]
protein_toks = emb['word_embeddings']['weight'][6:31]
toks = torch.cat((reserved_toks, protein_toks))
shape_check(tape['bert.embeddings.word_embeddings.weight'], toks)
tape['bert.embeddings.word_embeddings.weight'] = toks.clone()
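# Sanity check (a sketch): 5 reserved rows + 25 amino-acid rows = 30; slicing
# an undersized embedding table would otherwise shrink `toks` silently
assert toks.shape[0] == 30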
# 1.2 position embeddings
shape_check(tape['bert.embeddings.position_embeddings.weight'], emb['position_embeddings']['weight'])
tape['bert.embeddings.position_embeddings.weight'] = emb['position_embeddings']['weight'].clone()
# 1.3 token-type embeddings: Megatron provides none to transfer, so zero them
tape['bert.embeddings.token_type_embeddings.weight'] = torch.zeros_like(tape['bert.embeddings.token_type_embeddings.weight'])
# Part 2: the BERT encoder layers
for layer in range(layers):
    # input layernorm
    shape_check(tape[f"bert.encoder.layer.{layer}.input_ln.weight"], trans[f"layers.{layer}.input_layernorm.weight"])
    shape_check(tape[f"bert.encoder.layer.{layer}.input_ln.bias"], trans[f"layers.{layer}.input_layernorm.bias"])
    tape[f"bert.encoder.layer.{layer}.input_ln.weight"] = trans[f"layers.{layer}.input_layernorm.weight"].clone()
    tape[f"bert.encoder.layer.{layer}.input_ln.bias"] = trans[f"layers.{layer}.input_layernorm.bias"].clone()
    # attention: Megatron fuses Q, K, V per head along dim 0 as
    # [q_head0, k_head0, v_head0, q_head1, ...]; view + split undoes the fusion
    wq, wk, wv = trans[f"layers.{layer}.attention.query_key_value.weight"].clone().view(heads, heads_dim * 3, -1).split(heads_dim, dim=1)
    bq, bk, bv = trans[f"layers.{layer}.attention.query_key_value.bias"].clone().view(heads, heads_dim * 3).split(heads_dim, dim=1)
    shape_check(tape[f"bert.encoder.layer.{layer}.attention.self.query.weight"], wq.contiguous().view(hidden_dim, hidden_dim))
    shape_check(tape[f"bert.encoder.layer.{layer}.attention.self.query.bias"], bq.contiguous().view(-1))
    tape[f"bert.encoder.layer.{layer}.attention.self.query.weight"] = wq.contiguous().view(hidden_dim, hidden_dim).clone()
    tape[f"bert.encoder.layer.{layer}.attention.self.query.bias"] = bq.contiguous().view(-1).clone()
    tape[f"bert.encoder.layer.{layer}.attention.self.key.weight"] = wk.contiguous().view(hidden_dim, hidden_dim).clone()
    tape[f"bert.encoder.layer.{layer}.attention.self.key.bias"] = bk.contiguous().view(-1).clone()
    tape[f"bert.encoder.layer.{layer}.attention.self.value.weight"] = wv.contiguous().view(hidden_dim, hidden_dim).clone()
    tape[f"bert.encoder.layer.{layer}.attention.self.value.bias"] = bv.contiguous().view(-1).clone()
shape_check(tape[f"bert.encoder.layer.{layer}.attention.dense.weight"], trans[
f"layers.{layer}.attention.dense.weight"])
shape_check(tape[f"bert.encoder.layer.{layer}.post_attn_ln.weight"], trans[
f"layers.{layer}.post_attention_layernorm.weight"])
tape[f"bert.encoder.layer.{layer}.attention.dense.weight"] = trans[
f"layers.{layer}.attention.dense.weight"].clone()
tape[f"bert.encoder.layer.{layer}.attention.dense.bias"] = trans[
f"layers.{layer}.attention.dense.bias"].clone()
tape[f"bert.encoder.layer.{layer}.post_attn_ln.weight"] = trans[
f"layers.{layer}.post_attention_layernorm.weight"].clone()
tape[f"bert.encoder.layer.{layer}.post_attn_ln.bias"] = trans[
f"layers.{layer}.post_attention_layernorm.bias"].clone()
    # MLP (dense h -> 4h -> h), corresponding to Megatron's self.mlp
    shape_check(tape[f"bert.encoder.layer.{layer}.mlp.dense_h_4h.weight"], trans[f"layers.{layer}.mlp.dense_h_to_4h.weight"])
    shape_check(tape[f"bert.encoder.layer.{layer}.mlp.dense_4h_h.weight"], trans[f"layers.{layer}.mlp.dense_4h_to_h.weight"])
    tape[f"bert.encoder.layer.{layer}.mlp.dense_h_4h.weight"] = trans[f"layers.{layer}.mlp.dense_h_to_4h.weight"].clone()
    tape[f"bert.encoder.layer.{layer}.mlp.dense_h_4h.bias"] = trans[f"layers.{layer}.mlp.dense_h_to_4h.bias"].clone()
    tape[f"bert.encoder.layer.{layer}.mlp.dense_4h_h.weight"] = trans[f"layers.{layer}.mlp.dense_4h_to_h.weight"].clone()
    tape[f"bert.encoder.layer.{layer}.mlp.dense_4h_h.bias"] = trans[f"layers.{layer}.mlp.dense_4h_to_h.bias"].clone()
# Part 3: final layernorm, dtype cast, and save
shape_check(tape['bert.encoder.final_ln.weight'], trans['final_layernorm.weight'])
tape['bert.encoder.final_ln.weight'] = trans['final_layernorm.weight'].clone()
tape['bert.encoder.final_ln.bias'] = trans['final_layernorm.bias'].clone()
# cast floating-point tensors to the requested dtype before saving
tape = {k: v.to(DTYPE) if torch.is_tensor(v) and v.is_floating_point() else v for k, v in tape.items()}
torch.save(tape, args.out, _use_new_zipfile_serialization=False)
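# Optional load-back check, a minimal sketch; the 30-row vocab count follows
# from the 5 reserved + 25 protein rows concatenated in Part 1:
#   reloaded = torch.load(args.out)
#   assert reloaded['bert.embeddings.word_embeddings.weight'].shape[0] == 30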