-
Notifications
You must be signed in to change notification settings - Fork 483
/
decoder.py
135 lines (117 loc) · 4.96 KB
/
decoder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import Levenshtein as Lev
import torch
from six.moves import xrange
class Decoder(object):
    """
    Basic decoder class from which all other decoders inherit. Implements several
    helper functions. Subclasses should implement the decode() method.

    Arguments:
        labels (string): mapping from integers to characters; the character at
            position i is the label for class index i.
        blank_index (int, optional): index for the blank '_' character. Defaults to 0.
    """

    def __init__(self, labels, blank_index=0):
        # e.g. labels = "_'ABCDEFGHIJKLMNOPQRSTUVWXYZ#"
        self.labels = labels
        self.int_to_char = dict(enumerate(labels))
        self.blank_index = blank_index
        # NOTE: the space-index lookup that used to live here is disabled; this
        # class does not set a `space_index` attribute.

    def wer(self, s1, s2):
        """
        Computes the word-level edit distance between the two provided
        sentences after tokenizing them to words on whitespace.

        The value returned is the raw Levenshtein distance (number of word
        edits); it is not normalized by the sentence length here.

        Arguments:
            s1 (string): space-separated sentence
            s2 (string): space-separated sentence
        Returns:
            The edit distance between the two word sequences.
        """
        words1, words2 = s1.split(), s2.split()

        # build mapping of words to integers
        vocab = set(words1 + words2)
        word2char = dict(zip(vocab, range(len(vocab))))

        # map the words to a char array (Levenshtein package only accepts
        # strings), one character per distinct word
        w1 = "".join(chr(word2char[w]) for w in words1)
        w2 = "".join(chr(word2char[w]) for w in words2)

        return Lev.distance(w1, w2)

    def cer(self, s1, s2):
        """
        Computes the character-level edit distance between the two provided
        sentences, ignoring space characters.

        The value returned is the raw Levenshtein distance; it is not
        normalized by the sentence length here.

        Arguments:
            s1 (string): space-separated sentence
            s2 (string): space-separated sentence
        Returns:
            The edit distance between the two space-stripped strings.
        """
        s1, s2 = s1.replace(" ", ""), s2.replace(" ", "")
        return Lev.distance(s1, s2)

    def decode(self, probs, sizes=None):
        """
        Given a matrix of character probabilities, returns the decoder's
        best guess of the transcription.

        This base implementation is abstract: it always raises
        NotImplementedError, and the exact tensor layout expected for `probs`
        is defined by the implementing subclass.

        Arguments:
            probs: Tensor of character probabilities
            sizes(optional): Size of each sequence in the mini-batch
        Returns:
            string: sequence of the model's best guess for the transcription
        Raises:
            NotImplementedError: always, unless overridden by a subclass.
        """
        raise NotImplementedError
class GreedyDecoder(Decoder):
    """
    Decoder that picks the highest-scoring label at every time step, then
    drops blanks and (when decoding) consecutive repeated labels.

    Arguments:
        labels (string): mapping from integers to characters.
        blank_index (int, optional): index for the blank character. Defaults to 0.
    """

    def __init__(self, labels, blank_index=0):
        super(GreedyDecoder, self).__init__(labels, blank_index)

    def convert_to_strings(
        self, sequences, sizes=None, remove_repetitions=False, return_offsets=False
    ):
        """
        Given a list of numeric sequences, returns the corresponding strings.

        Arguments:
            sequences: indexable collection of label-index sequences, one per
                utterance; each element must support `[i].item()`.
            sizes (optional): number of leading steps to read from each
                sequence. When None, the full length of each sequence is used.
            remove_repetitions (bool, optional): skip a label that equals the
                label of the immediately preceding step.
            return_offsets (bool, optional): also return the time-step offsets
                of the emitted characters.
        Returns:
            strings: list with one single-element list per sequence (only one
                path is returned per utterance).
            offsets: only when return_offsets is True; list with one
                single-element list per sequence holding an int tensor of
                time steps.
        """
        strings = []
        offsets = [] if return_offsets else None
        for idx, sequence in enumerate(sequences):
            seq_len = sizes[idx] if sizes is not None else len(sequence)
            string, string_offsets = self.process_string(
                sequence, seq_len, remove_repetitions
            )
            strings.append([string])  # We only return one path
            if return_offsets:
                offsets.append([string_offsets])
        if return_offsets:
            return strings, offsets
        return strings

    def process_string(self, sequence, size, remove_repetitions=False):
        """
        Converts the first `size` label indices of `sequence` to a string.

        Blank labels are always dropped. When remove_repetitions is True, a
        label identical to the one at the previous time step (blank or not)
        is dropped as well.

        Arguments:
            sequence: sequence of label indices; each element must support
                `.item()`.
            size: number of leading steps of `sequence` to process.
            remove_repetitions (bool, optional): collapse consecutive repeats.
        Returns:
            string: the decoded characters joined together.
            offsets: int tensor with the time step of every emitted character.
        Raises:
            KeyError: if blank_index or a label index is not in the label map.
        """
        blank_char = self.int_to_char[self.blank_index]
        chars = []
        offsets = []
        prev_char = None  # no previous step before i == 0
        for i in range(size):
            char = self.int_to_char[sequence[i].item()]
            # a repetition is judged against the raw previous step, so a blank
            # between two equal labels keeps both of them
            is_repetition = remove_repetitions and char == prev_char
            if char != blank_char and not is_repetition:
                chars.append(char)
                offsets.append(i)
            prev_char = char
        return "".join(chars), torch.tensor(offsets, dtype=torch.int)

    def decode(self, probs, sizes=None):
        """
        Returns the argmax decoding given the probability matrix. Removes
        repeated elements in the sequence, as well as blanks.

        Arguments:
            probs: Tensor of character probabilities from the network. Expected
                shape of batch x seq_length x output_dim
            sizes(optional): Size of each sequence in the mini-batch
        Returns:
            strings: sequences of the model's best guess for the transcription
                on inputs
            offsets: time step per character predicted
        """
        # torch.max over dim 2 returns (values, indices); only the indices of
        # the best label per time step are needed
        _, best_indices = torch.max(probs, 2)
        strings, offsets = self.convert_to_strings(
            # reshape to (batch, seq_length); a no-op for 3-D `probs`
            best_indices.view(best_indices.size(0), best_indices.size(1)),
            sizes,
            remove_repetitions=True,
            return_offsets=True,
        )
        return strings, offsets