# utils.py (115 lines, 91 loc, 2.88 KB) -- forked from chrisdonahue/LakhNES.
import numpy as np
import torch
import torch.nn.functional as F
def load_vocab(vocab_fp):
    """Load the event vocabulary from a text file.

    Every line of the file contributes one symbol: the text after the last
    comma on the stripped line (the whole stripped line when it contains no
    comma). Index 0 is reserved for the '<S>' symbol, so the symbol on the
    first line of the file gets index 1.

    Args:
        vocab_fp: Path of the vocabulary file to read.

    Returns:
        A tuple ``(idx2sym, sym2idx, wait_amts)`` where ``idx2sym`` is the
        list of symbols, ``sym2idx`` maps each symbol back to its index, and
        ``wait_amts`` lists the integer amounts of all 'WT_<n>' symbols in
        file order.

    Raises:
        ValueError: If a 'WT_' symbol is not followed by an integer.
    """
    idx2sym = ['<S>']
    wait_amts = []
    with open(vocab_fp, 'r') as f:
        for line in f:
            sym = line.strip().split(',')[-1]
            idx2sym.append(sym)
            # Inspect the parsed symbol rather than the raw line so that
            # wait tokens are still recognised when the line carries
            # comma-separated fields in front of the symbol.
            if sym.startswith('WT_'):
                wait_amts.append(int(sym[3:]))
    sym2idx = {s: i for i, s in enumerate(idx2sym)}
    return idx2sym, sym2idx, wait_amts
def quantize_wait_event(wait_event, wait_amts=None):
    """Snap a wait event to the nearest wait amount in the vocabulary.

    Args:
        wait_event: Event string of the form 'WT_<n>'; everything after the
            first three characters is parsed as an integer wait time.
        wait_amts: Allowed wait amounts (e.g. the third value returned by
            ``load_vocab``). When omitted, a module-level ``wait_amts``
            variable is used if one has been defined, which keeps the
            original one-argument call working.

    Returns:
        The string 'WT_<k>' where ``k`` is the allowed amount closest to the
        requested wait time. Ties resolve to the earliest entry.

    Raises:
        ValueError: If no wait amounts are available, or the wait time in
            ``wait_event`` is not an integer.
    """
    if wait_amts is None:
        # The parameter shadows the module-level name, so look the global up
        # explicitly instead of relying on an implicit name lookup.
        wait_amts = globals().get('wait_amts')
        if wait_amts is None:
            raise ValueError(
                'wait_amts was not passed and no module-level wait_amts is defined')

    candidates = list(wait_amts)
    if not candidates:
        raise ValueError('wait_amts is empty; cannot quantize {!r}'.format(wait_event))

    wait_time = int(wait_event[3:])
    # min() scans every candidate, so the result does not depend on the
    # list being sorted or free of duplicates.
    nearest = min(candidates, key=lambda amt: abs(wait_time - amt))
    return 'WT_{}'.format(nearest)
class TxlSimpleSampler:
    """Samples one token at a time from a model while carrying its memory.

    The wrapped model must provide ``eval()``,
    ``reset_length(tgt_len, ext_len, mem_len)`` and
    ``forward_generate(inp, *mems)``. The latter must return a sequence
    whose first element is the logits (indexed ``[timestep, batch, vocab]``)
    and whose remaining elements are the updated memory, which is passed
    back in on the next call.
    """

    def __init__(self, model, device, tgt_len=1, mem_len=896, ext_len=0):
        """Put the model in eval mode and configure it for 1-token steps.

        Args:
            model: Model object satisfying the interface described on the
                class.
            device: Device on which the input token tensor is created.
            tgt_len: Must be 1; the sampler feeds a single token per step.
            mem_len: Memory length forwarded to ``model.reset_length``.
            ext_len: Must be 0.

        Raises:
            ValueError: If ``tgt_len`` is not 1 or ``ext_len`` is not 0.
        """
        if tgt_len != 1:
            raise ValueError('tgt_len must be 1, got {}'.format(tgt_len))
        if ext_len != 0:
            raise ValueError('ext_len must be 0, got {}'.format(ext_len))

        self.model = model
        self.model.eval()
        self.model.reset_length(1, ext_len, mem_len)
        self.device = device
        self.reset()

    def reset(self):
        """Forget the model memory and the history of consumed tokens."""
        self.mems = []
        self.generated = []

    @torch.no_grad()
    def sample_next_token_updating_mem(self, last_token=None, temp=1., topk=None, exclude_eos=True):
        """Feed ``last_token`` to the model and sample the token after it.

        Args:
            last_token: Token index chosen on the previous step. Must be
                ``None`` or 0 on the first call after a reset and non-zero
                on every later call.
            temp: Sampling temperature. 0 selects the argmax
                deterministically; other values divide the logits.
            topk: If given, only the ``topk`` most probable tokens can be
                sampled. Values larger than the vocabulary are clamped.
            exclude_eos: If true, token 0 is given zero probability.

        Returns:
            ``(token, probs)``: the sampled token index as an ``int`` and
            the probability vector it was drawn from (full vocabulary size,
            index 0 zeroed when ``exclude_eos`` is set).

        Raises:
            Exception: If the first-call / later-call rule for
                ``last_token`` is violated.
            ValueError: If ``temp`` is negative or ``topk`` is below 1.
        """
        last_token = last_token if last_token is not None else 0

        # Ensure that user is always passing 0 on first call
        if len(self.generated) == 0:
            assert len(self.mems) == 0, 'memory must be empty before the first token'
            if last_token != 0:
                raise Exception('first call after reset() must pass last_token=None or 0')

        # Ensure that user isn't passing 0 after first call
        if last_token == 0 and len(self.generated) > 0:
            raise Exception('last_token=0 is only allowed on the first call after reset()')

        # Sanitize sampling params
        if temp < 0:
            raise ValueError('temp must be non-negative, got {}'.format(temp))
        if topk is not None and topk < 1:
            raise ValueError('topk must be at least 1, got {}'.format(topk))

        # Append last input token because we've officially selected it
        self.generated.append(last_token)

        # Create input of shape (timesteps=1, batch=1)
        inp = torch.tensor([[int(last_token)]], dtype=torch.long, device=self.device)

        # Evaluate the model, saving its memory.
        ret = self.model.forward_generate(inp, *self.mems)
        all_logits, self.mems = ret[0], ret[1:]

        # Select last timestep, only batch item
        logits = all_logits[-1, 0]

        if exclude_eos:
            logits = logits[1:]

        # Handle temp 0 (argmax) case
        if temp == 0:
            probs = torch.zeros_like(logits)
            probs[logits.argmax()] = 1.
        else:
            # Apply temperature out of place: `logits` is a view into the
            # tensor returned by the model and must not be modified.
            if temp != 1:
                logits = logits / temp

            # Compute softmax
            probs = F.softmax(logits, dim=-1)

        if exclude_eos:
            # Restore index 0 with zero probability so that indices line up
            # with the full vocabulary again.
            probs = F.pad(probs, [1, 0])

        # Select top-k if specified
        if topk is not None:
            # torch.topk raises when k exceeds the number of entries.
            k = min(topk, probs.numel())
            _, top_idx = torch.topk(probs, k)
            mask = torch.zeros_like(probs)
            mask[top_idx] = 1.
            probs = probs * mask
            probs = probs / probs.sum()

        # Sample from probabilities
        token = torch.multinomial(probs, 1)
        token = int(token.item())

        return token, probs