model.py
import math

import torch
from torch import nn


class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding added to sequence-first embeddings."""

    def __init__(self, d_model, dropout=0.1, max_len=100):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        # Precompute a (max_len, d_model) table of sin/cos positional encodings.
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)  # even dimensions
        pe[:, 1::2] = torch.cos(position * div_term)  # odd dimensions
        # Reshape to (max_len, 1, d_model) so it broadcasts over the batch dimension.
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x: (seq_len, batch, d_model); add the encoding for the first seq_len positions.
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)
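
# Illustrative shape check (not part of the original file; d_model, batch size, and
# sequence length below are arbitrary assumptions chosen for illustration):
#     pe = PositionalEncoding(d_model=128)
#     y = pe(torch.zeros(30, 4, 128))  # -> torch.Size([30, 4, 128]); input is sequence-first
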
class TransformerModel(nn.Module):
    """Encoder-decoder Transformer over token sequences, built on nn.Transformer."""

    def __init__(self, intoken, outtoken, hidden, nlayers=3, dropout=0.1):
        super(TransformerModel, self).__init__()
        nhead = hidden // 64  # one attention head per 64 hidden units

        # Source / target token embeddings, each followed by positional encoding.
        self.encoder = nn.Embedding(intoken, hidden)
        self.pos_encoder = PositionalEncoding(hidden, dropout)
        self.decoder = nn.Embedding(outtoken, hidden)
        self.pos_decoder = PositionalEncoding(hidden, dropout)

        # Embedding scale factors (defined here but not applied in forward()).
        self.inscale = math.sqrt(intoken)
        self.outscale = math.sqrt(outtoken)

        self.transformer = nn.Transformer(d_model=hidden, nhead=nhead, num_encoder_layers=nlayers,
                                          num_decoder_layers=nlayers, dim_feedforward=hidden,
                                          dropout=dropout)
        self.fc_out = nn.Linear(hidden, outtoken)

        # Masks are built lazily in forward() and cached between calls.
        self.src_mask = None
        self.trg_mask = None
        self.memory_mask = None

    def generate_square_subsequent_mask(self, sz):
        # Causal mask: positions above the diagonal are blocked with -inf.
        mask = torch.triu(torch.ones(sz, sz), 1)
        mask = mask.masked_fill(mask == 1, float('-inf'))
        return mask

    def make_len_mask(self, inp):
        # Padding mask: True where the token id is 0 (PAD); shape (batch, seq_len).
        return (inp == 0).transpose(0, 1)

    def forward(self, src, trg):
        # Rebuild the causal target mask only when the target length changes.
        if self.trg_mask is None or self.trg_mask.size(0) != len(trg):
            self.trg_mask = self.generate_square_subsequent_mask(len(trg)).to(trg.device)

        src_pad_mask = self.make_len_mask(src)
        trg_pad_mask = self.make_len_mask(trg)

        # Embed and add positional encodings (inputs are sequence-first: (seq_len, batch)).
        src = self.encoder(src)
        src = self.pos_encoder(src)
        trg = self.decoder(trg)
        trg = self.pos_decoder(trg)

        output = self.transformer(src, trg, tgt_mask=self.trg_mask)
        # Variant that also applies the padding masks computed above:
        # output = self.transformer(src, trg, src_mask=self.src_mask, tgt_mask=self.trg_mask,
        #                           memory_mask=self.memory_mask,
        #                           src_key_padding_mask=src_pad_mask,
        #                           tgt_key_padding_mask=trg_pad_mask,
        #                           memory_key_padding_mask=src_pad_mask)
        output = self.fc_out(output)
        return output
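

# Minimal smoke test sketch, not part of the original module; the vocabulary sizes,
# hidden size, and sequence lengths below are arbitrary assumptions for illustration.
if __name__ == "__main__":
    model = TransformerModel(intoken=100, outtoken=80, hidden=128, nlayers=2)
    # nn.Transformer expects sequence-first tensors of token ids: (seq_len, batch).
    src = torch.randint(1, 100, (30, 4))
    trg = torch.randint(1, 80, (20, 4))
    logits = model(src, trg)
    print(logits.shape)  # expected: torch.Size([20, 4, 80])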