forked from liucongg/GPT2-NewsTitle
-
Notifications
You must be signed in to change notification settings - Fork 0
/
generate_title.py
191 lines (175 loc) · 9.48 KB
/
generate_title.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
# -*- coding:utf-8 -*-
# @project: GPT2-NewsTitle
# @filename: generate_title.py
# @author: 刘聪NLP
# @contact: [email protected]
# @time: 2020/12/16 16:29
"""
文件说明:
根据训练好的模型,进行新闻标题生成,预测文件
"""
import torch
import os
import argparse
from model import GPT2LMHeadModel
from transformers import BertTokenizer
import torch.nn.functional as F
import copy
def set_args():
"""设置模型预测所需参数"""
parser = argparse.ArgumentParser()
parser.add_argument('--device', default='0', type=str, help='设置预测时使用的显卡,使用CPU设置成-1即可')
parser.add_argument('--model_path', default='output_dir/checkpoint-139805', type=str, help='模型文件路径')
parser.add_argument('--vocab_path', default='vocab/vocab.txt', type=str, help='词表,该词表为小词表,并增加了一些新的标记')
parser.add_argument('--batch_size', default=3, type=int, help='生成标题的个数')
parser.add_argument('--generate_max_len', default=32, type=int, help='生成标题的最大长度')
parser.add_argument('--repetition_penalty', default=1.2, type=float, help='重复处罚率')
parser.add_argument('--top_k', default=5, type=float, help='解码时保留概率最高的多少个标记')
parser.add_argument('--top_p', default=0.95, type=float, help='解码时保留概率累加大于多少的标记')
parser.add_argument('--max_len', type=int, default=512, help='输入模型的最大长度,要比config中n_ctx小')
return parser.parse_args()
def top_k_top_p_filtering(logits, top_k, top_p, filter_value=-float("Inf")):
"""
top_k或top_p解码策略,仅保留top_k个或累积概率到达top_p的标记,其他标记设为filter_value,后续在选取标记的过程中会取不到值设为无穷小。
Args:
logits: 预测结果,即预测成为词典中每个词的分数
top_k: 只保留概率最高的top_k个标记
top_p: 只保留概率累积达到top_p的标记
filter_value: 过滤标记值
Returns:
"""
# logits的维度必须为2,即size:[batch_size, vocab_size]
assert logits.dim() == 2
# 获取top_k和字典大小中较小的一个,也就是说,如果top_k大于字典大小,则取字典大小个标记
top_k = min(top_k, logits[0].size(-1))
# 如果top_k不为0,则将在logits中保留top_k个标记
if top_k > 0:
# 由于有batch_size个预测结果,因此对其遍历,选取每个预测结果的top_k标记
for logit in logits:
indices_to_remove = logit < torch.topk(logit, top_k)[0][..., -1, None]
logit[indices_to_remove] = filter_value
# 如果top_p不为0,则将在logits中保留概率值累积达到top_p的标记
if top_p > 0.0:
# 对logits进行递减排序
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
# 对排序后的结果使用softmax归一化,再获取累积概率序列
# 例如:原始序列[0.1, 0.2, 0.3, 0.4],则变为:[0.1, 0.3, 0.6, 1.0]
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
# 删除累积概率高于top_p的标记
sorted_indices_to_remove = cumulative_probs > top_p
# 将索引向右移动,使第一个标记也保持在top_p之上
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
for index, logit in enumerate(logits):
# 由于有batch_size个预测结果,因此对其遍历,选取每个预测结果的累积概率达到top_p的标记
indices_to_remove = sorted_indices[index][sorted_indices_to_remove[index]]
logit[indices_to_remove] = filter_value
return logits
def predict_one_sample(model, tokenizer, device, args, content):
"""
对单个样本进行预测
Args:
model: 模型
tokenizer: 分词器
device: 设备信息
args: 配置项信息
content: 新闻正文
Returns:
"""
# 对新闻正文进行预处理,并判断如果超长则进行截断
content_tokens = tokenizer.tokenize(content)
if len(content_tokens) > args.max_len - 3 - args.generate_max_len:
content_tokens = content_tokens[:args.max_len - 3 - args.generate_max_len]
# 获取content_id、title_id、unk_id、sep_id值
content_id = tokenizer.convert_tokens_to_ids("[Content]")
title_id = tokenizer.convert_tokens_to_ids("[Title]")
unk_id = tokenizer.convert_tokens_to_ids("[UNK]")
sep_id = tokenizer.convert_tokens_to_ids("[SEP]")
# 将tokens索引化,变成模型所需格式
content_tokens = ["[CLS]"] + content_tokens + ["[SEP]"]
input_ids = tokenizer.convert_tokens_to_ids(content_tokens)
# 将input_ids和token_type_ids进行扩充,扩充到需要预测标题的个数,即batch_size
input_ids = [copy.deepcopy(input_ids) for _ in range(args.batch_size)]
token_type_ids = [[content_id] * len(content_tokens) for _ in range(args.batch_size)]
# 将input_ids和token_type_ids变成tensor
input_tensors = torch.tensor(input_ids).long().to(device)
token_type_tensors = torch.tensor(token_type_ids).long().to(device)
next_token_type = torch.tensor([[title_id] for _ in range(args.batch_size)]).long().to(device)
# 用于存放每一步解码的结果
generated = []
# 用于存放,完成解码序列的序号
finish_set = set()
with torch.no_grad():
# 遍历生成标题最大长度
for _ in range(args.generate_max_len):
outputs = model(input_ids=input_tensors, token_type_ids=token_type_tensors)
# 获取预测结果序列的最后一个标记,next_token_logits size:[batch_size, vocab_size]
next_token_logits = outputs[0][:, -1, :]
# 对batch_size进行遍历,将词表中出现在序列中的词的概率进行惩罚
for index in range(args.batch_size):
for token_id in set([token_ids[index] for token_ids in generated]):
next_token_logits[index][token_id] /= args.repetition_penalty
# 对batch_size进行遍历,将词表中的UNK的值设为无穷小
for next_token_logit in next_token_logits:
next_token_logit[unk_id] = -float("Inf")
# 使用top_k_top_p_filtering函数,按照top_k和top_p的值,对预测结果进行筛选
filter_logits = top_k_top_p_filtering(next_token_logits, top_k=args.top_k, top_p=args.top_p)
# 对filter_logits的每一行做一次取值,输出结果是每一次取值时filter_logits对应行的下标,即词表位置(词的id)
# filter_logits中的越大的值,越容易被选中
next_tokens = torch.multinomial(F.softmax(filter_logits, dim=-1), num_samples=1)
# 判断如果哪个序列的预测标记为sep_id时,则加入到finish_set
for index, token_id in enumerate(next_tokens[:, 0]):
if token_id == sep_id:
finish_set.add(index)
# 判断,如果finish_set包含全部的序列序号,则停止预测;否则继续预测
finish_flag = True
for index in range(args.batch_size):
if index not in finish_set:
finish_flag = False
break
if finish_flag:
break
# 将预测标记添加到generated中
generated.append([token.item() for token in next_tokens[:, 0]])
# 将预测结果拼接到input_tensors和token_type_tensors上,继续下一次预测
input_tensors = torch.cat((input_tensors, next_tokens), dim=-1)
token_type_tensors = torch.cat((token_type_tensors, next_token_type), dim=-1)
# 用于存储预测结果
candidate_responses = []
# 对batch_size进行遍历,并将token_id变成对应汉字
for index in range(args.batch_size):
responses = []
for token_index in range(len(generated)):
# 判断,当出现sep_id时,停止在该序列中添加token
if generated[token_index][index] != sep_id:
responses.append(generated[token_index][index])
else:
break
# 将token_id序列变成汉字序列,去除"##",并将[Space]替换成空格
candidate_responses.append(
"".join(tokenizer.convert_ids_to_tokens(responses)).replace("##", "").replace("[space]", " "))
return candidate_responses
def main():
"""主函数"""
# 设置预测的配置参数
args = set_args()
# 获取设备信息
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICE"] = args.device
device = torch.device("cuda" if torch.cuda.is_available() and int(args.device) >= 0 else "cpu")
# 实例化tokenizer和model
tokenizer = BertTokenizer.from_pretrained(args.vocab_path, do_lower_case=True)
model = GPT2LMHeadModel.from_pretrained(args.model_path)
model.to(device)
model.eval()
print('开始对新闻生成标题,输入CTRL + Z,则退出')
try:
while True:
content = input("输入的新闻正文为:")
titles = predict_one_sample(model, tokenizer, device, args, content)
for i, title in enumerate(titles):
print("生成的第{}个标题为:{}".format(i + 1, title))
except:
pass
if __name__ == '__main__':
main()