-
Notifications
You must be signed in to change notification settings - Fork 7
/
model.py
318 lines (277 loc) · 11.3 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
import os
import math
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.container import ModuleList
from fairseq.modules import SinusoidalPositionalEmbedding
from abc import ABC
from utils import load_json
from transformers import BertConfig, BertModel
def clean_state_dict(state_dict):
    """
    Return a filtered / renamed copy of a checkpoint state dict.

    The entries 'fc.weight' and 'fc.bias' are left out, and every occurrence
    of the substring 'bert.' is removed from the remaining keys. Values are
    passed through untouched and the input mapping is not modified.

    Args:
        state_dict: mapping of parameter name -> value.

    Returns:
        A new dict holding the kept entries under their rewritten names,
        in the original iteration order.
    """
    skipped = ('fc.weight', 'fc.bias')
    return {
        name.replace('bert.', ''): tensor
        for name, tensor in state_dict.items()
        if name not in skipped
    }
def load_bert(bert_path, device):
    """
    Build a BertModel from the files in `bert_path` and load its weights.

    Two files are read from `bert_path`: 'config.json', whose content is
    expanded into BertConfig keyword arguments, and 'model.bin', which is
    read with torch.load and filtered through `clean_state_dict` before
    being loaded into the model.

    Args:
        bert_path: directory containing 'config.json' and 'model.bin'.
        device: device the model is moved to; also used as the
            `map_location` for the checkpoint tensors.

    Returns:
        The BertModel on `device` with the checkpoint weights loaded.
    """
    bert_config_path = os.path.join(bert_path, 'config.json')
    bert = BertModel(BertConfig(**load_json(bert_config_path))).to(device)
    bert_model_path = os.path.join(bert_path, 'model.bin')
    # Without map_location, torch.load restores every tensor onto the device
    # it was saved from, which fails for a CUDA checkpoint on a CPU-only host
    # (or a host with fewer GPUs). Remap onto the requested device instead.
    # NOTE: torch.load unpickles the file - only load trusted checkpoints.
    state_dict = torch.load(bert_model_path, map_location=device)
    bert.load_state_dict(clean_state_dict(state_dict))
    return bert
class TransformerBlock(nn.Module, ABC):
    """
    Multi-head attention sub-layer: layer-norm -> attention -> dropout -> add.

    Args:
        d_model: feature size of the inputs (LayerNorm size / attention embed dim).
        n_heads: number of attention heads.
        attn_dropout: dropout handed to nn.MultiheadAttention.
        res_dropout: dropout applied to the attention output before the sum.
    """

    def __init__(self, d_model, n_heads, attn_dropout, res_dropout):
        super().__init__()
        self.layer_norm = nn.LayerNorm(d_model)
        self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=attn_dropout)
        self.dropout = nn.Dropout(res_dropout)

    def forward(self, query, key, value, key_padding_mask=None, attn_mask=True):
        """
        Attend from `query` to `key` / `value` (sequence-first tensors).

        All three inputs go through the same LayerNorm before attention
        (pre-norm layout). The attention output is passed through dropout and
        added to the *normalized* query.
        NOTE(review): the sum uses the normalized query rather than the raw
        input - confirm this is intended.

        Args:
            query, key, value: attention inputs; dimension 0 is the sequence axis.
            key_padding_mask: forwarded unchanged to nn.MultiheadAttention.
            attn_mask: when truthy, a future mask from `get_future_mask`
                is applied; otherwise no attention mask is used.

        Returns:
            Tensor shaped like `query`.
        """
        normed_query = self.layer_norm(query)
        normed_key = self.layer_norm(key)
        normed_value = self.layer_norm(value)

        future_mask = None
        if attn_mask:
            future_mask = self.get_future_mask(normed_query, normed_key)

        attn_outputs = self.self_attn(
            normed_query,
            normed_key,
            normed_value,
            key_padding_mask=key_padding_mask,
            attn_mask=future_mask,
        )
        # nn.MultiheadAttention returns (output, weights); only the output is used
        attended = attn_outputs[0]
        return normed_query + self.dropout(attended)

    @staticmethod
    def get_future_mask(query, key=None):
        """
        Build an additive mask hiding key positions later than the query position.

        The mask has shape (query.shape[0], key.shape[0]); `key` defaults to
        `query`. Entries strictly above the main diagonal are -inf, all
        others are 0.

        :return: float32 mask on query.device
            ex) tensor([[0., -inf, -inf],
                        [0., 0., -inf],
                        [0., 0., 0.]])
        """
        n_query = query.shape[0]
        n_key = n_query if key is None else key.shape[0]
        blocked = torch.full(
            (n_query, n_key),
            float('-inf'),
            dtype=torch.float32,
            device=query.device,
        )
        # triu keeps the strictly-upper part (-inf) and zero-fills the rest
        return torch.triu(blocked, diagonal=1)
class FeedForwardBlock(nn.Module, ABC):
    """
    Position-wise feed-forward sub-layer: layer-norm -> linear -> relu -> linear.

    Args:
        d_model: input / output feature size.
        d_feedforward: hidden size between the two linear layers.
        res_dropout: dropout applied to the projected output before the sum.
        relu_dropout: dropout applied right after the ReLU.
    """

    def __init__(self, d_model, d_feedforward, res_dropout, relu_dropout):
        super().__init__()
        self.layer_norm = nn.LayerNorm(d_model)
        self.linear1 = nn.Linear(d_model, d_feedforward)
        self.dropout1 = nn.Dropout(relu_dropout)
        self.linear2 = nn.Linear(d_feedforward, d_model)
        self.dropout2 = nn.Dropout(res_dropout)

    def forward(self, x):
        """
        Apply layer-norm first, then the two-layer feed-forward network.

        The network output is added to the *normalized* input.
        NOTE(review): the sum uses the normalized input rather than the raw
        `x` - confirm this is intended.
        """
        residual = self.layer_norm(x)
        hidden = self.linear1(residual)
        hidden = F.relu(hidden)
        hidden = self.dropout1(hidden)
        hidden = self.linear2(hidden)
        return residual + self.dropout2(hidden)
class TransformerEncoderBlock(nn.Module, ABC):
    """One encoder layer: an attention sub-layer followed by a feed-forward sub-layer."""

    def __init__(self, d_model, n_heads, d_feedforward, attn_dropout, res_dropout, relu_dropout):
        """
        Args:
            d_model: number of features expected in the input.
            n_heads: number of heads used by the multi-head attention.
            d_feedforward: hidden size of the feed-forward network.
            attn_dropout: dropout used inside multi-head attention.
            res_dropout: dropout used on the residual branches.
            relu_dropout: dropout used after the ReLU.
        """
        super().__init__()
        self.transformer = TransformerBlock(d_model, n_heads, attn_dropout, res_dropout)
        self.feedforward = FeedForwardBlock(d_model, d_feedforward, res_dropout, relu_dropout)

    def forward(self, x_query, x_key=None, key_mask=None, attn_mask=None):
        """
        Run attention and then the feed-forward network.

        Args:
            x_query: input of the encoder layer -> (L, B, d).
            x_key: optional source used as both key and value; when None the
                layer attends over `x_query` itself.
            key_mask: forwarded to the attention sub-layer as `key_padding_mask`.
            attn_mask: forwarded to the attention sub-layer as `attn_mask`.

        Returns:
            Output of the feed-forward sub-layer.
        """
        # no separate key given -> plain self-attention over the query
        source = x_query if x_key is None else x_key
        attended = self.transformer(
            x_query, source, source,
            key_padding_mask=key_mask,
            attn_mask=attn_mask,
        )
        return self.feedforward(attended)
class CrossmodalTransformer(nn.Module, ABC):
    """
    Stack of TransformerEncoderBlock layers with scaled inputs and sinusoidal positions.

    Args:
        n_layers: number of encoder layers (deep copies of one prototype layer).
        n_heads: attention heads per layer.
        d_model: feature size; the feed-forward hidden size is d_model * 4.
        attn_dropout: dropout inside multi-head attention.
        relu_dropout: dropout after the feed-forward ReLU.
        emb_dropout: dropout applied to the position-augmented inputs.
        res_dropout: dropout on the residual branches.
        attn_mask: handed to every layer as its `attn_mask` argument.
        scale_embedding: multiply inputs by sqrt(d_model) when True, by 1.0 otherwise.
    """

    def __init__(self,
                 n_layers,
                 n_heads,
                 d_model,
                 attn_dropout,
                 relu_dropout,
                 emb_dropout,
                 res_dropout,
                 attn_mask,
                 scale_embedding=True):
        super().__init__()
        self.attn_mask = attn_mask
        self.emb_scale = math.sqrt(d_model) if scale_embedding else 1.0
        self.pos_emb = SinusoidalPositionalEmbedding(d_model, 0, init_size=128)
        self.dropout = nn.Dropout(emb_dropout)
        prototype = TransformerEncoderBlock(
            d_model=d_model,
            n_heads=n_heads,
            d_feedforward=d_model * 4,
            attn_dropout=attn_dropout,
            res_dropout=res_dropout,
            relu_dropout=relu_dropout,
        )
        self.layers = _get_clones(prototype, n_layers)

    def _embed(self, x):
        """Scale `x`, add positional embeddings, apply dropout, swap the first two axes."""
        # The first feature channel is what the positional embedding receives
        # as its input tensor (presumably so that fairseq can tell padded
        # positions apart via padding index 0 - confirm against fairseq).
        positions = self.pos_emb(x[:, :, 0])
        embedded = self.emb_scale * x + positions
        return self.dropout(embedded).transpose(0, 1)

    def forward(self, x_query, x_key=None, key_mask=None):
        """
        Encode `x_query`, optionally attending to `x_key` in every layer.

        Args:
            x_query: 3-D tensor; axes 0 and 1 are swapped before the layers run.
            x_key: optional 3-D tensor used as key / value source by every
                layer; the same embedded key is reused for all layers.
            key_mask: forwarded to every layer as `key_mask`.

        Returns:
            Output of the last layer (axes 0 and 1 swapped relative to `x_query`).
        """
        # query first, then key - same order as the dropout calls need
        x_query = self._embed(x_query)
        if x_key is not None:
            x_key = self._embed(x_key)

        for encoder_layer in self.layers:
            x_query = encoder_layer(
                x_query, x_key,
                key_mask=key_mask,
                attn_mask=self.attn_mask,
            )
        return x_query
class MultimodalTransformer(nn.Module, ABC):
    """
    Audio / text classifier built from crossmodal and self-attention transformers.

    Three modes, chosen by the flags:
      * both modalities (default): conv projection -> crossmodal attention in
        both directions -> per-modality self-attention -> masked mean pooling
        -> residual MLP -> classifier;
      * only_audio: self-attention over the raw audio features -> mean over
        time -> residual MLP -> classifier;
      * only_text: the vector at sequence position 0 of `x_text` -> dropout
        -> classifier.
    If both flags are set, the audio-only path is used.

    Args:
        n_layers: layers per transformer stack.
        n_heads: attention heads per layer.
        n_classes: size of the output layer.
        only_audio: use the audio-only path.
        only_text: use the text-only path.
        d_audio_orig: feature size of the audio input.
        d_text_orig: feature size of the text input.
        d_model: shared feature size when both modalities are used; replaced
            by d_audio_orig / d_text_orig in the single-modality modes.
        attn_dropout, relu_dropout, emb_dropout, res_dropout: dropout rates
            handed to the transformer stacks.
        out_dropout: dropout used in the output head.
        attn_mask: handed to the transformer stacks as their `attn_mask`.
    """

    def __init__(self,
                 n_layers=4,
                 n_heads=8,
                 n_classes=7,
                 only_audio=False,
                 only_text=False,
                 d_audio_orig=40,
                 d_text_orig=768,
                 d_model=64,
                 attn_dropout=.25,
                 relu_dropout=.0,
                 emb_dropout=.3,
                 res_dropout=.0,
                 out_dropout=.1,
                 attn_mask=True):
        super(MultimodalTransformer, self).__init__()
        self.only_audio = only_audio
        self.only_text = only_text
        self.use_both = not (self.only_audio or self.only_text)

        # single-modality modes work directly on the original feature size
        if self.only_audio:
            d_model = d_audio_orig
        elif self.only_text:
            d_model = d_text_orig
        combined_dim = d_model * 2 if self.use_both else d_model

        # temporal convolutional layers
        # (B, d_orig, L) => (B, d_model, L)
        if self.use_both:
            self.audio_encoder = nn.Conv1d(d_audio_orig, d_model, 3, padding=1, bias=False)
            self.text_encoder = nn.Conv1d(d_text_orig, d_model, 3, padding=1, bias=False)

        # kwargs for crossmodal transformers
        kwargs = {
            'n_layers': n_layers,
            'n_heads': n_heads,
            'd_model': d_model,
            'attn_dropout': attn_dropout,
            'relu_dropout': relu_dropout,
            'emb_dropout': emb_dropout,
            'res_dropout': res_dropout,
            'attn_mask': attn_mask
        }

        # crossmodal transformers
        if self.use_both:
            self.audio_with_text = self.get_network(**kwargs)
            self.text_with_audio = self.get_network(**kwargs)

        # self-attention layers
        if self.use_both or self.only_audio:
            self.audio_layers = self.get_network(**kwargs)
        if self.use_both:
            # we do not use this layer if self.only_text == True,
            # because we use just a pooler layer for prediction.
            self.text_layers = self.get_network(**kwargs)

        # Projection layers
        if self.use_both or self.only_audio:
            self.fc1 = nn.Linear(combined_dim, combined_dim)
            self.fc2 = nn.Linear(combined_dim, combined_dim)
        self.dropout = nn.Dropout(out_dropout)
        self.out_layer = nn.Linear(combined_dim, n_classes)

    def forward(self,
                x_audio,
                x_text,
                a_mask,   # (B, L_a)
                t_mask):  # (B, L_t)
        """
        Classify a batch.

        Args:
            x_audio: audio features, batch-first (B, L_a, d_audio_orig).
            x_text: text features, batch-first (B, L_t, d_text_orig).
            a_mask: boolean (B, L_a) mask used as key-padding mask; positions
                where it is True are excluded from attention and pooling.
            t_mask: boolean (B, L_t) mask, same convention as `a_mask`.

        Returns:
            (logits, features): output of the final linear layer and the
            pooled feature vector that fed the output head.
        """
        if self.use_both:
            # temporal convolution -> low-level features of each modality
            audio_low = self.audio_encoder(x_audio.transpose(1, 2)).transpose(1, 2)
            text_low = self.text_encoder(x_text.transpose(1, 2)).transpose(1, 2)

            # crossmodal attention
            # Both directions attend to the *low-level* features of the other
            # modality. Previously x_audio was overwritten by the audio->text
            # output before serving as key for text->audio, so the second
            # direction attended to already-fused (and re-embedded) audio.
            x_audio = self.audio_with_text(audio_low, text_low, t_mask).transpose(0, 1)
            x_text = self.text_with_audio(text_low, audio_low, a_mask).transpose(0, 1)

            # self-attention
            x_audio = self.audio_layers(x_audio, key_mask=a_mask)  # (L_a, B, D)
            x_text = self.text_layers(x_text, key_mask=t_mask)     # (L_t, B, D)

            # aggregation & prediction: mean over non-padded steps, per sample
            features = torch.cat(
                [self._masked_mean(x_audio, a_mask), self._masked_mean(x_text, t_mask)],
                dim=1,
            )
            out = features + self.fc2(self.dropout(F.relu(self.fc1(features))))
        elif self.only_audio:
            # NOTE(review): this is a plain mean over all time steps, padded
            # ones included, unlike the masked pooling used above.
            features = self.audio_layers(x_audio, key_mask=a_mask).mean(dim=0)
            out = features + self.fc2(self.dropout(F.relu(self.fc1(features))))
        else:
            # only_text: vector at sequence position 0 (presumably the BERT
            # [CLS] embedding - confirm with the caller)
            features = x_text[:, 0, :]
            out = self.dropout(features)

        return self.out_layer(out), features

    @staticmethod
    def _masked_mean(x, pad_mask):
        """Average x (L, B, D) over each sample's positions where pad_mask (B, L) is False."""
        pooled = [
            x[~sample_pad, idx, :].mean(dim=0)
            for idx, sample_pad in enumerate(pad_mask)
        ]
        return torch.stack(pooled, dim=0)

    @staticmethod
    def get_network(**kwargs):
        """Create a CrossmodalTransformer from the shared hyper-parameter dict."""
        return CrossmodalTransformer(
            n_layers=kwargs['n_layers'],
            n_heads=kwargs['n_heads'],
            d_model=kwargs['d_model'],
            attn_dropout=kwargs['attn_dropout'],
            relu_dropout=kwargs['relu_dropout'],
            emb_dropout=kwargs['emb_dropout'],
            res_dropout=kwargs['res_dropout'],
            attn_mask=kwargs['attn_mask'],
            scale_embedding=True
        )
def _get_clones(module, n):
return ModuleList([copy.deepcopy(module) for _ in range(n)])