-
Notifications
You must be signed in to change notification settings - Fork 538
/
utils.py
59 lines (45 loc) · 1.61 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import pickle
def merge_maps(dict1, dict2):
    """Merge the keys of ``dict2`` into ``dict1`` (two word2id or two tag2id maps).

    Every key of ``dict2`` that ``dict1`` does not already contain is added to
    ``dict1`` with the next free id, ``len(dict1)`` at the moment of insertion.
    The ids stored in ``dict2`` are ignored; keys already in ``dict1`` keep
    their id.

    Args:
        dict1: target map; mutated in place.
        dict2: map whose keys are merged in.

    Returns:
        ``dict1`` itself (the same object, now extended).
    """
    for token in dict2.keys():
        # setdefault evaluates len(dict1) before inserting, so a new token
        # gets the current size as its id; existing tokens are untouched.
        dict1.setdefault(token, len(dict1))
    return dict1
def save_model(model, file_name):
    """Save a model by pickling it to ``file_name``.

    Args:
        model: any picklable object.
        file_name: path of the file to create or overwrite (opened in binary
            write mode).
    """
    with open(file_name, mode="wb") as handle:
        pickle.dump(model, handle)
def load_model(file_name):
    """Load a model previously saved with pickle.

    NOTE: unpickling can execute arbitrary code; only load files that come
    from a trusted source.

    Args:
        file_name: path of the pickle file (opened in binary read mode).

    Returns:
        The unpickled object.
    """
    with open(file_name, mode="rb") as handle:
        return pickle.load(handle)
# When training the LSTM model, <pad> and <unk> must be added to word2id and tag2id.
# An LSTM with a CRF layer additionally needs <start> and <end> (used during decoding).
def extend_maps(word2id, tag2id, for_crf=True):
    """Add special tokens to a word map and a tag map, in place.

    Adds ``<unk>`` and ``<pad>`` and, when ``for_crf`` is true, also
    ``<start>`` and ``<end>`` to both maps, in that order. A newly added token
    receives id ``len(map)`` at the time it is inserted. This assumes the
    existing ids form the contiguous range ``0 .. len(map) - 1``.

    Tokens already present keep their existing id, so the same dicts can be
    extended more than once (e.g. first with ``for_crf=False``, then with
    ``for_crf=True``) without producing duplicate ids.

    Args:
        word2id: dict mapping words to integer ids; mutated in place.
        tag2id: dict mapping tags to integer ids; mutated in place.
        for_crf: if true, also add ``<start>`` and ``<end>`` (needed by the
            BiLSTM+CRF model).

    Returns:
        The same ``(word2id, tag2id)`` objects that were passed in.
    """
    special_tokens = ['<unk>', '<pad>']
    # A BiLSTM with CRF additionally needs the <start> and <end> tokens.
    if for_crf:
        special_tokens += ['<start>', '<end>']

    for mapping in (word2id, tag2id):
        for token in special_tokens:
            # Only assign an id to absent tokens. Unconditionally re-assigning
            # an existing token to len(mapping) would not grow the map, so the
            # next new token would receive the very same id.
            mapping.setdefault(token, len(mapping))
    return word2id, tag2id
def prepocess_data_for_lstmcrf(word_lists, tag_lists, test=False):
    """Append an ``"<end>"`` token to every sentence (and its tags) in place.

    Args:
        word_lists: list of token lists; each inner list is extended in place.
        tag_lists: list of tag lists, parallel to ``word_lists``; each inner
            list is extended in place unless ``test`` is true.
        test: for test data the tag lists are left unchanged; only the word
            lists receive the end token.

    Returns:
        The same ``(word_lists, tag_lists)`` objects that were passed in.
    """
    assert len(word_lists) == len(tag_lists)
    for idx, words in enumerate(word_lists):
        words.append("<end>")
        # Test data does not need the end token on its tags.
        if not test:
            tag_lists[idx].append("<end>")
    return word_lists, tag_lists
def flatten_lists(lists):
    """Flatten one level of list nesting into a new list.

    Each item that is a list (including list subclasses) has its elements
    spliced into the result; any other item is appended unchanged. Only one
    level is flattened: lists nested inside inner lists are kept as-is.

    Args:
        lists: an iterable whose items are lists or non-list values.

    Returns:
        A new flat list.
    """
    flat = []
    for item in lists:
        # isinstance rather than an exact type() comparison, so list
        # subclasses are flattened too.
        if isinstance(item, list):
            flat.extend(item)
        else:
            flat.append(item)
    return flat