forked from yandexdataschool/Practical_RL
-
Notifications
You must be signed in to change notification settings - Fork 0
/
voc.py
72 lines (62 loc) · 2.52 KB
/
voc.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import numpy as np
class Vocab:
    """
    A special class that handles tokenizing and detokenizing.

    Maps tokens to integer ids and back. With ``sep == ''`` every character is
    a token; otherwise tokens are the non-empty ``sep``-separated substrings.
    Every tokenized line is wrapped in the ``bos`` / ``eos`` special tokens.
    """

    def __init__(self, tokens, bos="__BOS__", eos="__EOS__", sep=''):
        """
        :param tokens: sequence of tokens; a token's id is its index
            (if a token is repeated, the last index wins)
        :param bos: begin-of-sequence token, must be present in ``tokens``
        :param eos: end-of-sequence token, must be present in ``tokens``;
            also used as padding by ``to_matrix``
        :param sep: token separator; '' means character-level tokens
        :raises AssertionError: if ``bos`` or ``eos`` is missing from ``tokens``
        """
        # Both conditions must be inside the expression: in ``assert a, b``
        # the part after the comma is only the failure message.
        assert bos in tokens and eos in tokens, \
            "bos (%r) and eos (%r) must both be in tokens" % (bos, eos)
        self.tokens = tokens
        self.token_to_ix = {t: i for i, t in enumerate(tokens)}
        self.bos = bos
        self.bos_ix = self.token_to_ix[bos]
        self.eos = eos
        self.eos_ix = self.token_to_ix[eos]
        self.sep = sep

    def __len__(self):
        """Number of tokens in the vocabulary (including bos and eos)."""
        return len(self.tokens)

    @staticmethod
    def from_lines(lines, bos="__BOS__", eos="__EOS__", sep=''):
        """
        Build a vocabulary from every token occurring in ``lines``.

        :param lines: iterable of strings
        :param bos: begin-of-sequence token, always assigned id 0
        :param eos: end-of-sequence token, always assigned id 1
        :param sep: token separator; '' builds a character-level vocabulary
        :return: Vocab with tokens ``[bos, eos] + sorted(unique non-empty tokens)``
        """
        # Split each line exactly like ``tokenize`` does, and collect the
        # resulting tokens themselves (not the characters they are made of).
        unique = {
            tok
            for line in lines
            for tok in (line.split(sep) if sep else line)
        }
        # Empty strings come from repeated/leading/trailing separators.
        tokens = sorted(t for t in unique if t not in (bos, eos) and len(t) != 0)
        return Vocab([bos, eos] + tokens, bos, eos, sep)

    def tokenize(self, string):
        """
        Convert a string to a list of tokens wrapped in bos and eos.

        With a non-empty ``sep`` the string is split on it and empty pieces are
        dropped; with ``sep == ''`` each character becomes a token.
        """
        if self.sep != '':
            tokens = [tok for tok in string.split(self.sep) if len(tok)]
        else:
            tokens = list(string)
        return [self.bos] + tokens + [self.eos]

    def to_matrix(self, lines, max_len=None):
        """
        Convert variable length token sequences into a fixed size matrix.

        Each line is tokenized (wrapped in bos/eos), truncated to ``max_len``
        tokens, mapped to ids and right-padded with ``eos_ix``.

        :param lines: iterable of strings
        :param max_len: number of columns; if falsy, the length of the longest
            tokenized line (bos and eos included) is used
        :return: int32 matrix of shape [number of lines, max_len]
        :raises KeyError: if a kept token is not in the vocabulary

        example usage:
        >>> voc = Vocab.from_lines(["ab", "b"])
        >>> voc.to_matrix(["ab", "b"])
        array([[0, 2, 3, 1],
               [0, 3, 1, 1]], dtype=int32)
        """
        # Tokenize once; the default width must count tokens, not characters,
        # otherwise word-level vocabularies get heavily over-padded.
        tokenized = [self.tokenize(line) for line in lines]
        if not max_len:
            # tokenize() already added bos and eos; empty input -> width 0
            max_len = max((len(seq) for seq in tokenized), default=0)
        matrix = np.full((len(tokenized), max_len), self.eos_ix, dtype='int32')
        for i, seq in enumerate(tokenized):
            # Truncate before lookup so tokens past max_len never need an id.
            row_ix = [self.token_to_ix[tok] for tok in seq[:max_len]]
            matrix[i, :len(row_ix)] = row_ix
        return matrix

    def to_lines(self, matrix, crop=True):
        """
        Convert matrix of token ids into strings.

        :param matrix: matrix of token ids (e.g. int32), shape=[batch, time];
            any iterable of rows of ids works
        :param crop: if True, drops a leading BOS and everything from the
            first EOS onwards
        :return: list of strings, tokens joined with ``sep``
        """
        lines = []
        for row in matrix:
            line_ix = list(row)
            if crop:
                # Guard against empty rows (e.g. a matrix with 0 columns).
                if line_ix and line_ix[0] == self.bos_ix:
                    line_ix = line_ix[1:]
                if self.eos_ix in line_ix:
                    line_ix = line_ix[:line_ix.index(self.eos_ix)]
            lines.append(self.sep.join(self.tokens[i] for i in line_ix))
        return lines