memnn.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from utils import to_var
import copy
import math
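# End-to-End Memory Network (MemN2N) in the style of Sukhbaatar et al. (2015),
# using adjacent weight tying plus optional temporal and position encodings
# (the "see 4.1" references below point to that paper).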
class MemNN(nn.Module):
    def __init__(self, vocab_size, embd_size, ans_size, max_story_len, hops=3, dropout=0.2, te=True, pe=True):
        super(MemNN, self).__init__()
        self.hops = hops
        self.embd_size = embd_size
        self.temporal_encoding = te
        self.position_encoding = pe

        init_rng = 0.1
        self.dropout = nn.Dropout(p=dropout)
        # Adjacent weight tying: hops+1 embeddings, so that C^k = A^{k+1} and B = A^1
        self.A = nn.ModuleList([nn.Embedding(vocab_size, embd_size) for _ in range(hops+1)])
        for i in range(len(self.A)):
            self.A[i].weight.data.normal_(0, init_rng)
            self.A[i].weight.data[0] = 0  # zero out the padding index (index 0)
        self.B = self.A[0]  # query encoder

        # Temporal Encoding: see 4.1
        if self.temporal_encoding:
            self.TA = nn.Parameter(torch.Tensor(1, max_story_len, embd_size).normal_(0, 0.1))
            self.TC = nn.Parameter(torch.Tensor(1, max_story_len, embd_size).normal_(0, 0.1))

    def forward(self, x, q):
        # x: (bs, story_len, s_sent_len)
        # q: (bs, q_sent_len)
        bs = x.size(0)
        story_len = x.size(1)
        s_sent_len = x.size(2)

        # Position Encoding: see 4.1, l_kj = (1 - j/J) - (k/d)(1 - 2j/J)
        if self.position_encoding:
            J = s_sent_len
            d = self.embd_size
            pe = to_var(torch.zeros(J, d))  # (s_sent_len, embd_size)
            for j in range(1, J+1):
                for k in range(1, d+1):
                    l_kj = (1 - j / J) - (k / d) * (1 - 2 * j / J)
                    pe[j-1][k-1] = l_kj
            pe = pe.unsqueeze(0).unsqueeze(0)    # (1, 1, s_sent_len, embd_size)
            pe = pe.repeat(bs, story_len, 1, 1)  # (bs, story_len, s_sent_len, embd_size)

        x = x.view(bs*story_len, -1)  # (bs*story_len, s_sent_len)

        u = self.dropout(self.B(q))  # (bs, q_sent_len, embd_size)
        u = torch.sum(u, 1)          # (bs, embd_size)

        # Adjacent weight tying
        for k in range(self.hops):
            m = self.dropout(self.A[k](x))             # (bs*story_len, s_sent_len, embd_size)
            m = m.view(bs, story_len, s_sent_len, -1)  # (bs, story_len, s_sent_len, embd_size)
            if self.position_encoding:
                m *= pe                                # (bs, story_len, s_sent_len, embd_size)
            m = torch.sum(m, 2)                        # (bs, story_len, embd_size)
            if self.temporal_encoding:
                m += self.TA.repeat(bs, 1, 1)[:, :story_len, :]

            c = self.dropout(self.A[k+1](x))           # (bs*story_len, s_sent_len, embd_size)
            c = c.view(bs, story_len, s_sent_len, -1)  # (bs, story_len, s_sent_len, embd_size)
            c = torch.sum(c, 2)                        # (bs, story_len, embd_size)
            if self.temporal_encoding:
                c += self.TC.repeat(bs, 1, 1)[:, :story_len, :]  # (bs, story_len, embd_size)

            p = torch.bmm(m, u.unsqueeze(2)).squeeze(2)  # match scores, (bs, story_len)
            p = F.softmax(p, -1).unsqueeze(1)            # attention weights, (bs, 1, story_len)
            o = torch.bmm(p, c).squeeze(1)               # weighted sum of output memories, (bs, embd_size)
            u = o + u                                    # (bs, embd_size)

        W = torch.t(self.A[-1].weight)  # (embd_size, vocab_size)
        out = torch.bmm(u.unsqueeze(1), W.unsqueeze(0).repeat(bs, 1, 1)).squeeze(1)  # (bs, vocab_size)
        return F.log_softmax(out, -1)
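

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only): the dimensions below are
# hypothetical, and `to_var` is assumed to behave as elsewhere in this file,
# i.e. simply wrapping a tensor (and moving it to the GPU when available).
if __name__ == '__main__':
    vocab_size, embd_size, max_story_len = 40, 20, 10
    bs, story_len, s_sent_len, q_sent_len = 4, 10, 6, 6

    model = MemNN(vocab_size, embd_size, ans_size=vocab_size,
                  max_story_len=max_story_len, hops=3)

    # Random word ids; index 0 is reserved for padding by the embeddings above.
    x = to_var(torch.LongTensor(bs, story_len, s_sent_len).random_(1, vocab_size))
    q = to_var(torch.LongTensor(bs, q_sent_len).random_(1, vocab_size))

    out = model(x, q)  # (bs, vocab_size) log-probabilities over candidate answers
    print(out.size())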