# beam_decoder.py
import re
from queue import PriorityQueue

import torch
import torch.nn.functional as F

from greedy_decoder import BatchSampler


class BeamSearchNode:
    """A class to represent a node during the beam search."""

    def __init__(self, hidd_state, prev_node, word_idx, log_prob, length):
        """
        Args:
            hidd_state: decoder hidden state
            prev_node: the previous node (parent)
            word_idx: the word index
            log_prob: the log probability
            length: length of the decoded sentence
        """
        self.h = hidd_state
        self.prevNode = prev_node
        self.wordid = word_idx
        self.logp = log_prob
        self.leng = length

    def eval(self, alpha=1):
        reward = 0
        # Add here a function for shaping a reward.
        # The log prob is normalized by the length of the sentence,
        # as defined by Wu et al.: https://arxiv.org/pdf/1609.08144.pdf
        return self.logp / float(self.leng - 1 + 1e-6) + alpha * reward
        # return self.logp / float(self.leng)**alpha
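        # A common alternative (a sketch only, not what this method uses):
        # the full GNMT length penalty from Wu et al. (2016), eq. 14,
        # with alpha typically around 0.6-0.7:
        #
        #   def gnmt_length_penalty(length, alpha=0.6):
        #       return ((5.0 + length) ** alpha) / ((5.0 + 1.0) ** alpha)
        #
        #   score = self.logp / gnmt_length_penalty(self.leng)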

    def __lt__(self, other):
        """Overriding less-than so that two nodes with the same score
        can still be compared and ordered inside the priority queue."""
        return self.logp < other.logp
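

# How the queue orders nodes (an illustrative sketch; the tensors and
# scores below are made up for demonstration):
#
#   a = BeamSearchNode(None, None, torch.tensor([1]), -0.5, 2)
#   b = BeamSearchNode(None, None, torch.tensor([2]), -1.5, 2)
#   q = PriorityQueue()
#   q.put((-a.eval(), a))
#   q.put((-b.eval(), b))
#   _, best = q.get()  # pops `a`, the node with the higher (less negative) score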


class BeamSampler(BatchSampler):
    """A subclass of BatchSampler that uses beam search for decoding."""

    def __init__(self, model, src_vocab_char,
                 src_vocab_word, trg_vocab_char,
                 src_labels_vocab, trg_labels_vocab,
                 trg_gender_vocab):

        super(BeamSampler, self).__init__(model, src_vocab_char,
                                          src_vocab_word, trg_vocab_char,
                                          src_labels_vocab, trg_labels_vocab,
                                          trg_gender_vocab)

    def beam_decode(self, sentence, trg_gender=None, topk=3, beam_width=5, max_len=512):
        """
        Args:
            sentence: the source sentence
            trg_gender: the target gender of the output sentence. Defaults to None
            topk: number of candidate sentences to generate from the beam search. Defaults to 3
            beam_width: the beam size. If 1, this reduces to greedy search. Defaults to 5
            max_len: the maximum length of the decoded sentence. Defaults to 512

        Returns:
            the best decoded sentence as a string
        """
        # vectorizing the src sentence on the char level and word level
        sentence = re.split(r'(\s+)', sentence)
        vectorized_src_sentence_char = [self.src_vocab_char.sos_idx]
        vectorized_src_sentence_word = [self.src_vocab_word.sos_idx]
        for word in sentence:
            for c in word:
                vectorized_src_sentence_char.append(self.src_vocab_char.lookup_token(c))
                # the word token is appended once per character to keep the
                # word-level sequence aligned with the char-level sequence
                vectorized_src_sentence_word.append(self.src_vocab_word.lookup_token(word))
        vectorized_src_sentence_word.append(self.src_vocab_word.eos_idx)
        vectorized_src_sentence_char.append(self.src_vocab_char.eos_idx)

        # getting sentence length
        src_sentence_length = [len(vectorized_src_sentence_char)]

        # vectorizing the trg gender
        if trg_gender:
            vectorized_trg_gender = self.trg_gender_vocab.lookup_token(trg_gender)
            vectorized_trg_gender = torch.tensor([vectorized_trg_gender], dtype=torch.long)
        else:
            vectorized_trg_gender = None

        # converting the lists to tensors
        vectorized_src_sentence_char = torch.tensor([vectorized_src_sentence_char], dtype=torch.long)
        vectorized_src_sentence_word = torch.tensor([vectorized_src_sentence_word], dtype=torch.long)
        src_sentence_length = torch.tensor(src_sentence_length, dtype=torch.long)

        # passing the src sentence to the encoder
        with torch.no_grad():
            encoder_outputs, encoder_h_t = self.model.encoder(vectorized_src_sentence_char,
                                                              vectorized_src_sentence_word,
                                                              src_sentence_length)

        # creating attention mask
        attention_mask = self.model.create_mask(vectorized_src_sentence_char, self.src_vocab_char.pad_idx)

        # initializing the first decoder_h_t to encoder_h_t
        decoder_hidden = encoder_h_t
        # decoder_hidden = torch.tanh(self.model.linear_map(encoder_h_t))

        # initializing the context vectors to zeros
        context_vectors = torch.zeros(1, self.model.encoder.rnn.hidden_size * 2)

        # if beam_width == 1, beam search reduces to greedy decoding;
        # topk is the number of candidates to generate and must be <= beam_width
        if topk > beam_width:
            raise ValueError("topk candidates must be <= beam_width")

        # starting input to the decoder is the <s> token
        decoder_input = torch.LongTensor([self.trg_vocab_char.sos_idx])

        # number of finished sentences to collect
        endnodes = []
        # endnodes is empty at this point, so this is simply topk
        number_required = min((topk + 1), topk - len(endnodes))

        # starting node - hidden vector, previous node, word id, logp, length
        node = BeamSearchNode(decoder_hidden, None, decoder_input, 0, 1)
        nodes = PriorityQueue()

        # start the queue. Each element is (-log_prob, beam_node); the score
        # is negated because PriorityQueue pops the smallest element first,
        # i.e. the node with the best (highest) score
        nodes.put((-node.eval(), node))
        qsize = 1

        # start beam search
        while max_len > 0:
            max_len -= 1

            # give up when decoding takes too long
            if qsize > 20000:
                break

            # fetch the best node (i.e. the node with the minimum negative log prob)
            score, n = nodes.get()
            decoder_input = n.wordid
            decoder_hidden = n.h

            # if we predict the </s> token, we finished decoding a sentence
            if n.wordid.item() == self.trg_vocab_char.eos_idx and n.prevNode is not None:
                endnodes.append((score, n))
                # if we reached the number of sentences required, stop the beam search
                if len(endnodes) >= number_required:
                    break
                else:
                    continue

            # decode for one step using the decoder
            with torch.no_grad():
                decoder_output, decoder_hidden, atten_scores, context_vectors = self.model.decoder(
                    trg_seqs=decoder_input,
                    encoder_outputs=encoder_outputs,
                    decoder_h_t=decoder_hidden,
                    context_vectors=context_vectors,
                    attention_mask=attention_mask,
                    trg_gender=vectorized_trg_gender)

            # obtaining log probs from the decoder predictions
            decoder_output = F.log_softmax(decoder_output, dim=1)

            # keep the beam_width most probable next tokens
            log_prob, indexes = torch.topk(decoder_output, beam_width)
            # indexes shape: [batch_size, beam_width]
            # log_prob shape: [batch_size, beam_width]

            # expanding the current beam (n)
            nextnodes = []
            for new_k in range(beam_width):
                decoded_t = indexes[0][new_k].unsqueeze(0)
                log_p = log_prob[0][new_k].item()
                node = BeamSearchNode(decoder_hidden, n, decoded_t, n.logp + log_p, n.leng + 1)
                score = -node.eval()
                nextnodes.append((score, node))

            # put the expanded beams in the queue
            for i in range(len(nextnodes)):
                score, nn = nextnodes[i]
                nodes.put((score, nn))

            # increase qsize (-1 accounts for the node popped at the top of the loop)
            qsize += len(nextnodes) - 1

        # choose the topk beams: if no sentence ended with </s>,
        # take the best topk candidates left in the queue
        if len(endnodes) == 0:
            endnodes = [nodes.get() for _ in range(topk)]

        # sorting the topk beams by their negative log probs (best first)
        endnodes = sorted(endnodes, key=lambda x: x[0])

        # decoding
        # TODO: decoding currently works for one sentence at a time;
        # it still needs to be made to work on a batch
        decoded_sentences = []
        for score, n in endnodes:
            decoded_sentence = [n.wordid.item()]
            # backtrack from the end node to the start node
            while n.prevNode is not None:
                n = n.prevNode
                decoded_sentence.append(n.wordid.item())
            # reversing the decoding
            decoded_sentence = decoded_sentence[::-1]
            decoded_sentences.append((score, decoded_sentence))

        # return only the best decoded sentence as a string
        str_decoded_sentence = self.get_str_sentence(decoded_sentences[0][1], self.trg_vocab_char)
        return str_decoded_sentence
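

# Example usage (a minimal sketch; `model` and the vocab objects come from
# the surrounding project, and the gender label 'F' is an assumption about
# trg_gender_vocab, not something this module defines):
#
#   sampler = BeamSampler(model, src_vocab_char, src_vocab_word,
#                         trg_vocab_char, src_labels_vocab,
#                         trg_labels_vocab, trg_gender_vocab)
#   best = sampler.beam_decode('some source sentence', trg_gender='F',
#                              topk=3, beam_width=5)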