"""
This code is extended from Jin-Hwa Kim, Jaehyun Jun, Byoung-Tak Zhang's repository.
https://github.com/jnhwkim/ban-vqa
This code is modified from ZCYang's repository.
https://github.com/zcyang/imageqa-san
"""
import torch
import torch.nn as nn
from torch.nn.utils.weight_norm import weight_norm
from bc import BCNet


# Bilinear Attention
class BiAttention(nn.Module):
def __init__(self, x_dim, y_dim, z_dim, glimpse, dropout=[.2,.5]):
super(BiAttention, self).__init__()
self.glimpse = glimpse
        self.logits = weight_norm(BCNet(x_dim, y_dim, z_dim, glimpse, dropout=dropout, k=3),
                                  name='h_mat', dim=None)

    def forward(self, v, q, v_mask=True):
        """
        v: [batch, k, vdim]
        q: [batch, q_len, qdim]
        """
        p, logits = self.forward_all(v, q, v_mask)
        return p, logits

    def forward_all(self, v, q, v_mask=True):
v_num = v.size(1)
q_num = q.size(1)
logits = self.logits(v, q) # b x g x v x q
        if v_mask:
            # Mask out padded image regions (all-zero feature vectors) before the softmax
            mask = (0 == v.abs().sum(2)).unsqueeze(1).unsqueeze(3).expand(logits.size())
            logits.data.masked_fill_(mask.data, -float('inf'))
p = nn.functional.softmax(logits.view(-1, self.glimpse, v_num * q_num), 2)
return p.view(-1, self.glimpse, v_num, q_num), logits
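
# Illustrative usage sketch (not from the original repository); the dimensions below are
# assumptions chosen only to show the expected tensor shapes:
#
#   att = BiAttention(x_dim=1024, y_dim=1024, z_dim=1024, glimpse=2)
#   v = torch.randn(8, 36, 1024)   # [batch, k, vdim] image region features
#   q = torch.randn(8, 12, 1024)   # [batch, q_len, qdim] question token features
#   p, logits = att(v, q)          # p: [8, 2, 36, 12] attention over (region, token) pairs
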
# Stacked Attention
class StackedAttention(nn.Module):
def __init__(self, num_stacks, img_feat_size, ques_feat_size, att_size, output_size, drop_ratio):
super(StackedAttention, self).__init__()
self.img_feat_size = img_feat_size
self.ques_feat_size = ques_feat_size
self.att_size = att_size
self.output_size = output_size
self.drop_ratio = drop_ratio
self.num_stacks = num_stacks
self.layers = nn.ModuleList()
self.dropout = nn.Dropout(drop_ratio)
self.tanh = nn.Tanh()
self.softmax = nn.Softmax(dim=1)
self.fc11 = nn.Linear(ques_feat_size, att_size, bias=True)
self.fc12 = nn.Linear(img_feat_size, att_size, bias=False)
self.fc13 = nn.Linear(att_size, 1, bias=True)
        # Each additional stack gets its own question/image projections and attention head
        for stack in range(num_stacks - 1):
self.layers.append(nn.Linear(att_size, att_size, bias=True))
self.layers.append(nn.Linear(img_feat_size, att_size, bias=False))
self.layers.append(nn.Linear(att_size, 1, bias=True))

    def forward(self, img_feat, ques_feat, v_mask=True):
# Batch size
B = ques_feat.size(0)
# Stack 1
ques_emb_1 = self.fc11(ques_feat)
img_emb_1 = self.fc12(img_feat)
# Compute attention distribution
h1 = self.tanh(ques_emb_1.view(B, 1, self.att_size) + img_emb_1)
h1_emb = self.fc13(self.dropout(h1))
        # Mask out padded image regions (all-zero embeddings) before computing the softmax
if v_mask:
mask = (0 == img_emb_1.abs().sum(2)).unsqueeze(2).expand(h1_emb.size())
h1_emb.data.masked_fill_(mask.data, -float('inf'))
p1 = self.softmax(h1_emb)
# Compute weighted sum
        img_att_1 = img_emb_1 * p1
weight_sum_1 = torch.sum(img_att_1, dim=1)
# Combine with question vector
u1 = ques_emb_1 + weight_sum_1
# Other stacks
us = []
ques_embs = []
img_embs = []
hs = []
        h_embs = []
ps = []
img_atts = []
weight_sums = []
us.append(u1)
for stack in range(self.num_stacks - 1):
ques_embs.append(self.layers[3 * stack + 0](us[-1]))
img_embs.append(self.layers[3 * stack + 1](img_feat))
# Compute attention distribution
hs.append(self.tanh(ques_embs[-1].view(B, -1, self.att_size) + img_embs[-1]))
h_embs.append(self.layers[3*stack + 2](self.dropout(hs[-1])))
            # Mask out padded image regions (all-zero embeddings) before computing the softmax
if v_mask:
mask = (0 == img_embs[-1].abs().sum(2)).unsqueeze(2).expand(h_embs[-1].size())
h_embs[-1].data.masked_fill_(mask.data, -float('inf'))
ps.append(self.softmax(h_embs[-1]))
# Compute weighted sum
img_atts.append(img_embs[-1] * ps[-1])
weight_sums.append(torch.sum(img_atts[-1], dim=1))
            # Combine with the query vector from the previous stack
            ux = us[-1] + weight_sums[-1]
            # Store the refined query vector for the next stack
            us.append(ux)
return us[-1]
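
# Illustrative usage sketch (not from the original repository); the dimensions below are
# assumptions chosen only to show the expected tensor shapes:
#
#   san = StackedAttention(num_stacks=2, img_feat_size=1024, ques_feat_size=1024,
#                          att_size=512, output_size=1024, drop_ratio=0.5)
#   img_feat = torch.randn(8, 36, 1024)   # [batch, k, img_feat_size]
#   ques_feat = torch.randn(8, 1024)      # [batch, ques_feat_size]
#   u = san(img_feat, ques_feat)          # [batch, att_size] attended joint representation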