-
Notifications
You must be signed in to change notification settings - Fork 21
/
export_pt.py
executable file
·206 lines (161 loc) · 6.71 KB
/
export_pt.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
#!/usr/bin/env python3
# ==============================================================================
#
# Copyright (C) 2023 Sophgo Technologies Inc. All rights reserved.
#
# TPU-MLIR is licensed under the 2-Clause BSD License except for the
# third-party components.
#
# ==============================================================================
import os
import torch
import argparse
from tqdm import tqdm
from transformers import AutoModelForCausalLM
# Export-only script: disable autograd globally so tracing never records gradients.
torch.set_grad_enabled(False)

parser = argparse.ArgumentParser(description='export onnx')
# Required: without it from_pretrained(None) fails much later with a confusing error.
parser.add_argument('-m', '--model_path', type=str, required=True, help='path to the torch model')
parser.add_argument('-s', '--seq_length', type=int, default=512, help="sequence length")
parser.add_argument('-d', '--device', type=str, choices=["cpu", "cuda"], default="cpu")
parser.add_argument('--save_dir', type=str, default="./tmp/onnx")
parser.add_argument('--guess_length', type=int, default=5)
args = parser.parse_args()

model_path = args.model_path
# Output directory for every exported file. Previously hard-coded to "./tmp/onnx",
# which silently ignored --save_dir (the default value is unchanged).
folder = args.save_dir
device = torch.device(args.device)

# Load in float32 and place on the selected device so the model lives where the
# dummy inputs built by the convert_* functions are placed.
origin_model = AutoModelForCausalLM.from_pretrained(
    model_path, trust_remote_code=True,
    torch_dtype=torch.float).to(device).eval()
origin_model.requires_grad_(False)

# Module-level handles read by the wrapper modules and converters below.
config = origin_model.config
transformer = origin_model.model
layers = transformer.layers

SEQ_LENGTH = args.seq_length
NUM_LAYERS = config.num_hidden_layers
HIDDEN_SIZE = config.hidden_size
NUM_ATTENTION_HEADS = config.num_attention_heads
# NOTE(review): assumes head_dim == hidden_size // num_attention_heads for the
# target config — confirm for models that set head_dim explicitly.
HEAD_DIM = HIDDEN_SIZE // NUM_ATTENTION_HEADS
VOCAB_SIZE = config.vocab_size
GUESS_LENGTH = args.guess_length
print(f'Layers: {NUM_LAYERS}\nHidden size: {HIDDEN_SIZE}\n')
class Embedding(torch.nn.Module):
    """Standalone wrapper around the loaded model's token embedding table.

    The embedding layer is looked up on the module-level ``transformer`` each
    time ``forward`` runs, so this module owns no parameters of its own.
    """

    def __init__(self):
        super().__init__()

    def forward(self, input_ids):
        """Embed ``input_ids`` and return the result cast to float32."""
        return transformer.embed_tokens(input_ids).float()
class QwenBlock(torch.nn.Module):
    """One decoder layer exported for a pass with no past key/value input.

    Args:
        layer_id: index into the module-level ``layers`` list.
    """

    def __init__(self, layer_id):
        super().__init__()
        self.layer_id = layer_id
        self.layer = layers[layer_id]

    def forward(self, hidden_states, position_ids, attention_mask):
        """Run the layer with ``use_cache=True``.

        Returns:
            ``(hidden_states, present_k, present_v)``, each cast to float32.
        """
        new_hidden, (key_states, value_states) = self.layer(
            hidden_states,
            position_ids=position_ids,
            attention_mask=attention_mask,
            use_cache=True,
        )
        return new_hidden.float(), key_states.float(), value_states.float()
class QwenBlockCache(torch.nn.Module):
    """One decoder layer exported for a step that also consumes past key/value tensors.

    Args:
        layer_id: index into the module-level ``layers`` list.
    """

    def __init__(self, layer_id):
        super().__init__()
        self.layer_id = layer_id
        self.layer = layers[layer_id]

    def forward(self, hidden_states, position_ids, attention_mask, past_k,
                past_v):
        """Run the layer on new tokens given the history ``(past_k, past_v)``.

        Returns:
            ``(hidden_states, present_k, present_v)``, each cast to float32.
        """
        history = (past_k, past_v)
        new_hidden, (key_states, value_states) = self.layer(
            hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=history,
            use_cache=True,
        )
        return new_hidden.float(), key_states.float(), value_states.float()
class LmHead(torch.nn.Module):
    """Final norm + LM head + sampling pre-processing, traced to TorchScript.

    ``forward`` applies, in order: a repetition penalty on the logits of the ids in
    ``input_ids``, top-k selection, temperature scaling and top-p filtering. It
    returns the renormalised probabilities over the top-k candidates together with
    their token ids.

    Args:
        top_k: number of highest-scoring candidates kept per row.
        min_tokens_to_keep: leading top-k candidates that top-p filtering never drops.
    """

    def __init__(self, top_k = 50, min_tokens_to_keep = 5):
        super().__init__()
        self.top_k = top_k
        self.min_tokens_to_keep = min_tokens_to_keep
        keep_matrix = torch.zeros((1, self.top_k), dtype=torch.bool)
        keep_matrix[0, :self.min_tokens_to_keep] = True
        # A (non-persistent) buffer so `LmHead().to(device)` moves it together with
        # the inputs; as a plain attribute it always stayed on the CPU.
        self.register_buffer("keep_matrix", keep_matrix, persistent=False)

    def forward(self, hidden_states, input_ids, top_p, temperature, penalty):
        """Return ``(probs, token)`` over the top-k candidates.

        Args:
            hidden_states: decoder output fed through ``transformer.norm``.
            input_ids: previously seen token ids (index along the vocab dim).
            top_p: nucleus threshold on the cumulative probability.
            temperature: divisor applied to the top-k logits.
            penalty: repetition penalty factor.
        """
        hidden_states = transformer.norm(hidden_states)
        m_logits = origin_model.lm_head(hidden_states)
        # repeat penalty: negative logits are multiplied, positive ones divided.
        # NOTE(review): gather/scatter index with input_ids, so only the rows of
        # m_logits covered by input_ids' first dim (row 0 for a single-row
        # input_ids) are penalised — confirm this is intended for multi-row input.
        logits = torch.gather(m_logits, 1, input_ids)
        logits = torch.where(logits < 0, logits * penalty, logits / penalty)
        m_logits.scatter_(1, input_ids, logits)
        # top_k
        logits, token = torch.topk(m_logits.float(), self.top_k)
        # temperature
        logits = logits / temperature
        # top_p: keep candidates below the cumulative threshold, plus the first
        # `min_tokens_to_keep` ones (bool `+` acts as a logical OR).
        cumulative_probs = logits.softmax(dim=1).cumsum(dim=1)
        mask = cumulative_probs < top_p
        mask = mask + self.keep_matrix
        # Build the fill value on the logits' device/dtype; a bare FloatTensor was
        # always a CPU tensor and broke tracing on cuda.
        fill_value = torch.tensor([-1000.], dtype=logits.dtype, device=logits.device)
        filtered_logits = torch.where(mask, logits, fill_value)
        probs = filtered_logits.softmax(dim=1)
        return probs, token
def convert_block(layer_id):
    """Export decoder layer ``layer_id`` (no past K/V input) as ONNX.

    The dummy inputs only fix the traced shapes (SEQ_LENGTH tokens of
    HIDDEN_SIZE features). The graph is written to ``{folder}/block_{layer_id}.onnx``.
    """
    block = QwenBlock(layer_id)
    dummy_states = torch.randn((1, SEQ_LENGTH, HIDDEN_SIZE)).to(device, torch.float)
    dummy_positions = torch.arange(SEQ_LENGTH, dtype=torch.long).unsqueeze(0).to(device)
    dummy_mask = torch.randn((1, 1, SEQ_LENGTH, SEQ_LENGTH)).to(device, torch.float)
    sample_inputs = (dummy_states, dummy_positions, dummy_mask)
    torch.onnx.export(
        block,
        sample_inputs,
        f'{folder}/block_{layer_id}.onnx',
        verbose=False,
        input_names=['input_states', 'position_ids', 'attention_mask'],
        output_names=['hidden_states', 'past_k', 'past_v'],
        do_constant_folding=True,
        opset_version=15,
    )
def convert_block_cache(layer_id):
    """Export the cached step of decoder layer ``layer_id`` as ONNX.

    The graph takes GUESS_LENGTH new tokens plus SEQ_LENGTH positions of history
    K/V and is written to ``{folder}/block_cache_{layer_id}.onnx``.
    """
    block = QwenBlockCache(layer_id)
    new_states = torch.randn((1, GUESS_LENGTH, HIDDEN_SIZE)).to(device, torch.float)
    new_positions = torch.arange(GUESS_LENGTH, dtype=torch.long).unsqueeze(0).to(device)
    # Last mask dim covers SEQ_LENGTH history positions plus GUESS_LENGTH new ones.
    full_mask = torch.ones(
        (1, 1, GUESS_LENGTH, SEQ_LENGTH + GUESS_LENGTH)).to(device, torch.float)
    # NOTE(review): presumably (batch, seq, heads, head_dim) as expected by this
    # layer implementation — confirm.
    kv_shape = (1, SEQ_LENGTH, NUM_ATTENTION_HEADS, HEAD_DIM)
    history_k, history_v = (torch.randn(kv_shape).to(device, torch.float) for _ in range(2))
    torch.onnx.export(
        block,
        (new_states, new_positions, full_mask, history_k, history_v),
        f'{folder}/block_cache_{layer_id}.onnx',
        verbose=False,
        input_names=[
            'input_states', 'position_ids', 'attention_mask', 'history_k',
            'history_v'
        ],
        output_names=['hidden_states', 'past_k', 'past_v'],
        do_constant_folding=True,
        opset_version=15,
    )
def convert_embedding():
    """Trace ``Embedding.forward`` on SEQ_LENGTH token ids and save it as TorchScript.

    Output file: ``{folder}/embedding.pt``.
    """
    embedder = Embedding()
    token_ids = torch.arange(SEQ_LENGTH).unsqueeze(0).to(device)
    traced = torch.jit.trace(embedder.forward, token_ids)
    torch.jit.save(traced, f'{folder}/embedding.pt')
def convert_lm_head():
    """Trace ``LmHead.forward`` and save it as TorchScript to ``{folder}/lm_head.pt``.

    Dummy inputs: GUESS_LENGTH hidden vectors, SEQ_LENGTH previous token ids and
    one-element sampling parameters (top_p, temperature, penalty). The module and
    every input are placed on ``device`` so tracing works for both cpu and cuda.
    """
    model = LmHead().to(device)
    # float32, matching the torch.float model weights and the float dummies used
    # for the other exports (it was bfloat16, inconsistent with both).
    hidden_states = torch.randn(GUESS_LENGTH, HIDDEN_SIZE).float().to(device)
    input_ids = torch.tensor([range(SEQ_LENGTH)]).to(device)
    top_p = torch.tensor([0.8]).to(device)
    temperature = torch.tensor([0.98]).to(device)
    penalty = torch.tensor([0.98]).to(device)
    module = torch.jit.trace(
        model.forward, (hidden_states, input_ids, top_p, temperature, penalty))
    torch.jit.save(module, f'{folder}/lm_head.pt')
# Create the output folder; exist_ok avoids the race of a separate exists() check.
os.makedirs(folder, exist_ok=True)

# Export every decoder layer (both variants), then the embedding and the LM head.
print('Convert block & block_cache')
for i in tqdm(range(NUM_LAYERS)):
    convert_block(i)
    convert_block_cache(i)

print('Convert embedding')
convert_embedding()

print('Convert lm_head')
convert_lm_head()