# MyTransformer.py
# forked from moon-hotel/TransformerClassification
from torch.nn.init import xavier_uniform_
import torch.nn.functional as F
import torch.nn as nn
import copy
import torch
is_print_shape = False
class MyTransformer(nn.Module):
    def __init__(self, d_model=512, nhead=8, num_encoder_layers=6,
                 num_decoder_layers=6, dim_feedforward=2048, dropout=0.1):
        """
        :param d_model: dimension of the model's vectors (d_k = d_v = d_model/nhead = 64); the paper's default is 512
        :param nhead: number of heads in multi-head attention; the paper's default is 8
        :param num_encoder_layers: number of stacked encoder layers, i.e. N in the paper; default is 6
        :param num_decoder_layers: number of stacked decoder layers, i.e. N in the paper; default is 6
        :param dim_feedforward: hidden dimension of the feed-forward network; the paper's default is 2048
        :param dropout: dropout rate; the paper's default is 0.1
        """
        super(MyTransformer, self).__init__()
        # ================ Encoder ================
        encoder_layer = MyTransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout)
        encoder_norm = nn.LayerNorm(d_model)
        self.encoder = MyTransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
        # ================ Decoder ================
        decoder_layer = MyTransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout)
        decoder_norm = nn.LayerNorm(d_model)
        self.decoder = MyTransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm)
self._reset_parameters()
self.d_model = d_model
self.nhead = nhead
    def _reset_parameters(self):
        r"""Initialize the parameters of the transformer model with Xavier uniform."""
        for p in self.parameters():
            if p.dim() > 1:
                xavier_uniform_(p)
def forward(self, src, tgt, src_mask=None, tgt_mask=None,
memory_mask=None, src_key_padding_mask=None,
tgt_key_padding_mask=None, memory_key_padding_mask=None):
"""
:param src: [src_len,batch_size,embed_dim]
:param tgt: [tgt_len, batch_size, embed_dim]
        :param src_mask: attention mask for the encoder input, usually None
        :param tgt_mask: attention mask for the decoder input, shape [tgt_len, tgt_len]
        :param memory_mask: attention mask for the encoder-decoder interaction, usually None
:param src_key_padding_mask: [batch_size, src_len]
:param tgt_key_padding_mask: [batch_size, tgt_len]
:param memory_key_padding_mask: [batch_size, src_len]
:return: [tgt_len, batch_size, num_heads * kdim] <==> [tgt_len,batch_size,embed_dim]
"""
memory = self.encoder(src, mask=src_mask, src_key_padding_mask=src_key_padding_mask)
# [src_len, batch_size, num_heads * kdim] <==> [src_len,batch_size,embed_dim]
output = self.decoder(tgt=tgt, memory=memory, tgt_mask=tgt_mask, memory_mask=memory_mask,
tgt_key_padding_mask=tgt_key_padding_mask,
memory_key_padding_mask=memory_key_padding_mask)
return output # [tgt_len, batch_size, num_heads * kdim] <==> [tgt_len,batch_size,embed_dim]
def generate_square_subsequent_mask(self, sz):
r"""Generate a square mask for the sequence. The masked positions are filled with float('-inf').
Unmasked positions are filled with float(0.0).
"""
mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
return mask # [sz,sz]
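    # For example, generate_square_subsequent_mask(4) produces
    #   [[0., -inf, -inf, -inf],
    #    [0.,   0., -inf, -inf],
    #    [0.,   0.,   0., -inf],
    #    [0.,   0.,   0.,   0.]]
    # so that, once added to the attention scores, position i can only attend to positions <= i.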
class MyTransformerEncoderLayer(nn.Module):
    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1):
        """
        :param d_model: dimension of the model's vectors (d_k = d_v = d_model/nhead = 64); the paper's default is 512
        :param nhead: number of heads in multi-head attention; the paper's default is 8
        :param dim_feedforward: hidden dimension of the feed-forward network; the paper's default is 2048
        :param dropout: dropout rate; the paper's default is 0.1
        """
        super(MyTransformerEncoderLayer, self).__init__()
self.self_attn = MyMultiheadAttention(d_model, nhead, dropout=dropout)
# Implementation of Feedforward model
self.dropout1 = nn.Dropout(dropout)
self.norm1 = nn.LayerNorm(d_model)
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.activation = F.relu
self.dropout2 = nn.Dropout(dropout)
self.norm2 = nn.LayerNorm(d_model)
    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        """
        :param src: input to the encoder, shape [src_len, batch_size, embed_dim]
        :param src_mask: attention mask for the encoder input, usually None
        :param src_key_padding_mask: padding mask for the encoder input, shape [batch_size, src_len]
        :return: [src_len, batch_size, num_heads * kdim] <==> [src_len, batch_size, embed_dim]
        """
        src2 = self.self_attn(src, src, src, attn_mask=src_mask,
                              key_padding_mask=src_key_padding_mask)[0]  # multi-head self-attention
        # src2: [src_len, batch_size, num_heads * kdim], where num_heads * kdim = embed_dim
        src = src + self.dropout1(src2)  # residual connection
        src = self.norm1(src)  # [src_len, batch_size, num_heads * kdim]
        src2 = self.activation(self.linear1(src))  # [src_len, batch_size, dim_feedforward]
        src2 = self.linear2(self.dropout(src2))  # [src_len, batch_size, num_heads * kdim]
        src = src + self.dropout2(src2)  # residual connection
        src = self.norm2(src)
        return src  # [src_len, batch_size, num_heads * kdim] <==> [src_len, batch_size, embed_dim]
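    # In equation form, the forward pass above computes the two post-LayerNorm residual blocks of the paper:
    #   src = LayerNorm(src + Dropout(MultiHeadSelfAttention(src)))
    #   src = LayerNorm(src + Dropout(Linear2(Dropout(ReLU(Linear1(src))))))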
class MyTransformerEncoder(nn.Module):
    def __init__(self, encoder_layer, num_layers, norm=None):
        """
        :param encoder_layer: a single encoder layer containing multi-head attention
        :param num_layers: number of cloned encoder layers to stack; the paper's default is 6
        :param norm: normalization layer applied to the final output
        """
        super(MyTransformerEncoder, self).__init__()
        self.layers = _get_clones(encoder_layer, num_layers)  # clone the encoder layer num_layers times
        self.num_layers = num_layers
        self.norm = norm
    def forward(self, src, mask=None, src_key_padding_mask=None):
        """
        :param src: input to the encoder, shape [src_len, batch_size, embed_dim]
        :param mask: attention mask for the encoder input, usually None
        :param src_key_padding_mask: padding mask for the encoder input, shape [batch_size, src_len]
        :return: [src_len, batch_size, num_heads * kdim] <==> [src_len, batch_size, embed_dim]
        """
        output = src
        for mod in self.layers:
            output = mod(output, src_mask=mask,
                         src_key_padding_mask=src_key_padding_mask)  # forward pass through the stacked encoder layers
        if self.norm is not None:
            output = self.norm(output)
        return output  # [src_len, batch_size, num_heads * kdim] <==> [src_len, batch_size, embed_dim]
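    # Usage sketch (mirrors the commented-out test at the bottom of this file):
    #   layer = MyTransformerEncoderLayer(d_model=512, nhead=8)
    #   encoder = MyTransformerEncoder(layer, num_layers=6, norm=nn.LayerNorm(512))
    #   memory = encoder(src, src_key_padding_mask=padding_mask)  # [src_len, batch_size, 512]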
def _get_clones(module, N):
    # copy.deepcopy gives every cloned layer its own independent parameters; the clones share the
    # architecture but not the weights.
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
class MyTransformerDecoderLayer(nn.Module):
    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1):
        """
        :param d_model: dimension of the model's vectors (d_k = d_v = d_model/nhead = 64); the paper's default is 512
        :param nhead: number of heads in multi-head attention; the paper's default is 8
        :param dim_feedforward: hidden dimension of the feed-forward network; the paper's default is 2048
        :param dropout: dropout rate; the paper's default is 0.1
        """
        super(MyTransformerDecoderLayer, self).__init__()
        self.self_attn = MyMultiheadAttention(embed_dim=d_model, num_heads=nhead, dropout=dropout)
        # multi-head attention over the decoder's own input (the masked multi-head attention in the paper's figure)
        self.multihead_attn = MyMultiheadAttention(embed_dim=d_model, num_heads=nhead, dropout=dropout)
        # multi-head attention between the encoder output (memory) and the decoder
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)
        self.activation = F.relu
    def forward(self, tgt, memory, tgt_mask=None, memory_mask=None, tgt_key_padding_mask=None,
                memory_key_padding_mask=None):
        """
        :param tgt: input to the decoder, shape [tgt_len, batch_size, embed_dim]
        :param memory: output of the encoder, shape [src_len, batch_size, embed_dim]
        :param tgt_mask: attention mask used to hide information from positions after the current one, shape [tgt_len, tgt_len]
        :param memory_mask: attention mask for the encoder-decoder interaction, usually None
        :param tgt_key_padding_mask: padding mask for the decoder input, shape [batch_size, tgt_len]
        :param memory_key_padding_mask: padding mask for the encoder input, shape [batch_size, src_len]
        :return: [tgt_len, batch_size, num_heads * kdim] <==> [tgt_len, batch_size, embed_dim]
        """
        tgt2 = self.self_attn(tgt, tgt, tgt,  # [tgt_len, batch_size, embed_dim]
                              attn_mask=tgt_mask,
                              key_padding_mask=tgt_key_padding_mask)[0]
        # multi-head attention over the decoder's own input (the masked multi-head attention in the paper's figure)
        tgt = tgt + self.dropout1(tgt2)  # residual connection
        tgt = self.norm1(tgt)  # [tgt_len, batch_size, embed_dim]
        tgt2 = self.multihead_attn(tgt, memory, memory,  # [tgt_len, batch_size, embed_dim]
                                   attn_mask=memory_mask,
                                   key_padding_mask=memory_key_padding_mask)[0]
        # the decoder input, after self-attention, interacts with the encoder output (memory) through multi-head attention
        tgt = tgt + self.dropout2(tgt2)  # residual connection
        tgt = self.norm2(tgt)  # [tgt_len, batch_size, embed_dim]
        tgt2 = self.activation(self.linear1(tgt))  # [tgt_len, batch_size, dim_feedforward]
        tgt2 = self.linear2(self.dropout(tgt2))  # [tgt_len, batch_size, embed_dim]
        # the final two fully connected layers
        tgt = tgt + self.dropout3(tgt2)
        tgt = self.norm3(tgt)
        return tgt  # [tgt_len, batch_size, num_heads * kdim] <==> [tgt_len, batch_size, embed_dim]
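    # In equation form, the forward pass above is:
    #   tgt = LayerNorm(tgt + Dropout(MaskedSelfAttention(tgt)))
    #   tgt = LayerNorm(tgt + Dropout(CrossAttention(query=tgt, key=memory, value=memory)))
    #   tgt = LayerNorm(tgt + Dropout(Linear2(Dropout(ReLU(Linear1(tgt))))))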
class MyTransformerDecoder(nn.Module):
def __init__(self, decoder_layer, num_layers, norm=None):
super(MyTransformerDecoder, self).__init__()
self.layers = _get_clones(decoder_layer, num_layers)
self.num_layers = num_layers
self.norm = norm
    def forward(self, tgt, memory, tgt_mask=None, memory_mask=None, tgt_key_padding_mask=None,
                memory_key_padding_mask=None):
        """
        :param tgt: input to the decoder, shape [tgt_len, batch_size, embed_dim]
        :param memory: output of the encoder's last layer, shape [src_len, batch_size, embed_dim]
        :param tgt_mask: attention mask used to hide information from positions after the current one, shape [tgt_len, tgt_len]
        :param memory_mask: attention mask for the encoder-decoder interaction, usually None
        :param tgt_key_padding_mask: padding mask for the decoder input, shape [batch_size, tgt_len]
        :param memory_key_padding_mask: padding mask for the encoder input, shape [batch_size, src_len]
        :return: [tgt_len, batch_size, num_heads * kdim] <==> [tgt_len, batch_size, embed_dim]
        """
        output = tgt  # [tgt_len, batch_size, embed_dim]
        for mod in self.layers:  # self.layers holds the N stacked decoder layers
            output = mod(output, memory,
                         tgt_mask=tgt_mask,
                         memory_mask=memory_mask,
                         tgt_key_padding_mask=tgt_key_padding_mask,
                         memory_key_padding_mask=memory_key_padding_mask)
        if self.norm is not None:
            output = self.norm(output)
        return output  # [tgt_len, batch_size, num_heads * kdim] <==> [tgt_len, batch_size, embed_dim]
class MyMultiheadAttention(nn.Module):
"""
多头注意力机制的计算公式为(就是论文第5页的公式):
.. math::
\text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O
\text{where} head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)
"""
    def __init__(self, embed_dim, num_heads, dropout=0., bias=True):
        """
        :param embed_dim: dimension of the word embeddings, i.e. the d_model parameter above; the paper's default is 512
        :param num_heads: number of heads in multi-head attention, i.e. the nhead parameter above; the paper's default is 8
        :param dropout: dropout rate applied to the attention weights
        :param bias: whether to use a bias in the final linear projection of the (concatenated) multi-head output
        """
        super(MyMultiheadAttention, self).__init__()
        self.embed_dim = embed_dim  # the d_model parameter above
        self.head_dim = embed_dim // num_heads  # head_dim is d_k and d_v
        self.kdim = self.head_dim
        self.vdim = self.head_dim
        self.num_heads = num_heads  # number of attention heads
        self.dropout = dropout
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
        # this is the d_k = d_v = d_model / n_head constraint from the paper
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)  # W_q, embed_dim = kdim * num_heads
        # the second dimension is embed_dim because the num_heads W_q matrices are initialized jointly, i.e. stacked into one matrix
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)  # W_k, embed_dim = kdim * num_heads
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)  # W_v, embed_dim = vdim * num_heads
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        # combining all the heads' outputs is likewise done in a single projection, embed_dim = vdim * num_heads
        self._reset_parameters()
    def _reset_parameters(self):
        """
        Initialize the parameters with Xavier uniform.
        """
        for p in self.parameters():
            if p.dim() > 1:
                xavier_uniform_(p)
    def forward(self, query, key, value, attn_mask=None, key_padding_mask=None):
        """
        In the paper, query, key and value are all the same input during encoding, and likewise
        during decoder self-attention; in the encoder-decoder interaction, key and value are the
        memory while query is tgt.
        :param query: [tgt_len, batch_size, embed_dim], where tgt_len is the target sequence length
        :param key: [src_len, batch_size, embed_dim], where src_len is the source sequence length
        :param value: [src_len, batch_size, embed_dim], where src_len is the source sequence length
        :param attn_mask: [tgt_len, src_len] or [num_heads * batch_size, tgt_len, src_len]
            generally only used during decoding; since the whole decoder input is fed in at once for
            parallelism, this mask hides the positions after the current time step
        :param key_padding_mask: [batch_size, src_len], where src_len is the source sequence length
        :return:
            attn_output: [tgt_len, batch_size, embed_dim]
            attn_output_weights: [batch_size, tgt_len, src_len]
        """
        return multi_head_attention_forward(query, key, value, self.num_heads,
                                            self.dropout,
                                            out_proj=self.out_proj,
                                            training=self.training,
                                            key_padding_mask=key_padding_mask,
                                            q_proj=self.q_proj,
                                            k_proj=self.k_proj,
                                            v_proj=self.v_proj,
                                            attn_mask=attn_mask)
def multi_head_attention_forward(query, # [tgt_len,batch_size, embed_dim]
key, # [src_len, batch_size, embed_dim]
value, # [src_len, batch_size, embed_dim]
num_heads,
dropout_p,
out_proj, # [embed_dim = vdim * num_heads, embed_dim = vdim * num_heads]
training=True,
key_padding_mask=None, # [batch_size,src_len/tgt_len]
q_proj=None, # [embed_dim,kdim * num_heads]
k_proj=None, # [embed_dim, kdim * num_heads]
v_proj=None, # [embed_dim, vdim * num_heads]
attn_mask=None, # [tgt_len,src_len] or [num_heads*batch_size,tgt_len, src_len]
):
q = q_proj(query)
# [tgt_len,batch_size, embed_dim] x [embed_dim,kdim * num_heads] = [tgt_len,batch_size,kdim * num_heads]
k = k_proj(key)
# [src_len, batch_size, embed_dim] x [embed_dim, kdim * num_heads] = [src_len, batch_size, kdim * num_heads]
v = v_proj(value)
# [src_len, batch_size, embed_dim] x [embed_dim, vdim * num_heads] = [src_len, batch_size, vdim * num_heads]
    if is_print_shape:
        print("=" * 80)
        print("Entering the multi-head attention computation:")
        print(f"\t num_heads = {num_heads}, d_model={query.size(-1)}, d_k = d_v = d_model/num_heads={query.size(-1) // num_heads}")
        print(f"\t shape of query ([tgt_len, batch_size, embed_dim]): {query.shape}")
        print(f"\t shape of W_q ([embed_dim, kdim * num_heads]): {q_proj.weight.shape}")
        print(f"\t shape of Q ([tgt_len, batch_size, kdim * num_heads]): {q.shape}")
        print("\t" + "-" * 70)
        print(f"\t shape of key ([src_len, batch_size, embed_dim]): {key.shape}")
        print(f"\t shape of W_k ([embed_dim, kdim * num_heads]): {k_proj.weight.shape}")
        print(f"\t shape of K ([src_len, batch_size, kdim * num_heads]): {k.shape}")
        print("\t" + "-" * 70)
        print(f"\t shape of value ([src_len, batch_size, embed_dim]): {value.shape}")
        print(f"\t shape of W_v ([embed_dim, vdim * num_heads]): {v_proj.weight.shape}")
        print(f"\t shape of V ([src_len, batch_size, vdim * num_heads]): {v.shape}")
        print("\t" + "-" * 70)
        print("\t ***** Note: W_q, W_k and W_v compute all heads at once, so Q, K and V each contain the stacked q, k, v of every head. *****")
tgt_len, bsz, embed_dim = query.size() # [tgt_len,batch_size, embed_dim]
src_len = key.size(0)
head_dim = embed_dim // num_heads # num_heads * head_dim = embed_dim
scaling = float(head_dim) ** -0.5
q = q * scaling # [query_len,batch_size,kdim * num_heads]
    if attn_mask is not None:  # [tgt_len, src_len] or [num_heads * batch_size, tgt_len, src_len]
        if attn_mask.dim() == 2:
            attn_mask = attn_mask.unsqueeze(0)  # [1, tgt_len, src_len]
            if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
                raise RuntimeError('The size of the 2D attn_mask is not correct.')
        elif attn_mask.dim() == 3:
            if list(attn_mask.size()) != [bsz * num_heads, query.size(0), key.size(0)]:
                raise RuntimeError('The size of the 3D attn_mask is not correct.')
        # attn_mask is now 3D
    q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
    # [batch_size * num_heads, tgt_len, kdim]
    # since all num_heads heads are computed jointly, reshape here (and swap dims 0 and 1) to simplify the computation below
    k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)  # [batch_size * num_heads, src_len, kdim]
    v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)  # [batch_size * num_heads, src_len, vdim]
    attn_output_weights = torch.bmm(q, k.transpose(1, 2))
    # [batch_size * num_heads, tgt_len, kdim] x [batch_size * num_heads, kdim, src_len]
    # = [batch_size * num_heads, tgt_len, src_len], i.e. the QK^T attention scores of all num_heads heads
    if attn_mask is not None:
        attn_output_weights += attn_mask  # [batch_size * num_heads, tgt_len, src_len]
    if key_padding_mask is not None:
        attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
        # reshape to [batch_size, num_heads, tgt_len, src_len]
        attn_output_weights = attn_output_weights.masked_fill(
            key_padding_mask.unsqueeze(1).unsqueeze(2),  # expand dims from [batch_size, src_len] to [batch_size, 1, 1, src_len]
            float('-inf'))
        attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len,
                                                       src_len)  # [batch_size * num_heads, tgt_len, src_len]
    attn_output_weights = F.softmax(attn_output_weights, dim=-1)  # [batch_size * num_heads, tgt_len, src_len]
    attn_output_weights = F.dropout(attn_output_weights, p=dropout_p, training=training)
    attn_output = torch.bmm(attn_output_weights, v)
    # [batch_size * num_heads, tgt_len, src_len] x [batch_size * num_heads, src_len, vdim]
    # = [batch_size * num_heads, tgt_len, vdim]
    # i.e. the Attention(Q, K, V) results of all num_heads heads
    attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
    # first transpose to [tgt_len, batch_size * num_heads, vdim],
    # then view as [tgt_len, batch_size, num_heads * vdim]
    attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
    Z = out_proj(attn_output)
    # the heads' outputs are linearly combined into Z here, [tgt_len, batch_size, embed_dim]
    if is_print_shape:
        print(f"\t shape of the stacked multi-head output ([tgt_len, batch_size, num_heads * kdim]): {attn_output.shape}")
        print(f"\t shape of the final projection weight W_o ([num_heads * vdim, num_heads * vdim]): {out_proj.weight.shape}")
        print(f"\t shape after the final linear projection ([tgt_len, batch_size, embed_dim]): {Z.shape}")
    return Z, attn_output_weights.sum(dim=1) / num_heads  # average attention weights over heads
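# In equation form, the function above computes, for every head in parallel,
#   Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k) + attn_mask) V
# with padded key positions additionally set to -inf via key_padding_mask, followed by
# concatenating the heads and applying out_proj (W^O).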
if __name__ == '__main__':
src_len = 5
batch_size = 2
dmodel = 32
tgt_len = 6
num_head = 8
src = torch.rand((src_len, batch_size, dmodel)) # shape: [src_len, batch_size, embed_dim]
src_key_padding_mask = torch.tensor([[True, True, True, False, False],
[True, True, True, True, False]]) # shape: [batch_size, src_len]
tgt = torch.rand((tgt_len, batch_size, dmodel)) # shape: [tgt_len, batch_size, embed_dim]
tgt_key_padding_mask = torch.tensor([[True, True, True, False, False, False],
[True, True, True, True, False, False]]) # shape: [batch_size, tgt_len]
    # ============ Test MyMultiheadAttention ============
    # my_mh = MyMultiheadAttention(embed_dim=dmodel, num_heads=num_head)
    # r = my_mh(src, src, src, key_padding_mask=src_key_padding_mask)
    # ============ Test MyTransformerEncoderLayer ============
    # my_transformer_encoder_layer = MyTransformerEncoderLayer(d_model=dmodel, nhead=num_head)
    # r = my_transformer_encoder_layer(src=src, src_key_padding_mask=src_key_padding_mask)
    # ============ Test MyTransformerDecoder ============
    # my_transformer_encoder_layer = MyTransformerEncoderLayer(d_model=dmodel, nhead=num_head)
    # my_transformer_encoder = MyTransformerEncoder(encoder_layer=my_transformer_encoder_layer,
    #                                               num_layers=2,
    #                                               norm=nn.LayerNorm(dmodel))
    # memory = my_transformer_encoder(src=src, mask=None, src_key_padding_mask=src_key_padding_mask)
    # print(memory.shape)
    #
    # my_transformer_decoder_layer = MyTransformerDecoderLayer(d_model=dmodel, nhead=num_head)
    # my_transformer_decoder = MyTransformerDecoder(decoder_layer=my_transformer_decoder_layer,
    #                                               num_layers=1,
    #                                               norm=nn.LayerNorm(dmodel))
    # out = my_transformer_decoder(tgt=tgt, memory=memory, tgt_key_padding_mask=tgt_key_padding_mask,
    #                              memory_key_padding_mask=src_key_padding_mask)
    # print(out.shape)
    # ============ Test MyTransformer ============
my_transformer = MyTransformer(d_model=dmodel, nhead=num_head, num_encoder_layers=6,
num_decoder_layers=6, dim_feedforward=500)
src_mask = my_transformer.generate_square_subsequent_mask(src_len)
tgt_mask = my_transformer.generate_square_subsequent_mask(tgt_len)
out = my_transformer(src=src, tgt=tgt, tgt_mask=tgt_mask,
src_key_padding_mask=src_key_padding_mask,
tgt_key_padding_mask=tgt_key_padding_mask,
memory_key_padding_mask=src_key_padding_mask)
print(out.shape)
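    # ============ Sketch: inspecting MyMultiheadAttention directly ============
    # A small illustrative check of the returned shapes, reusing the tensors defined above;
    # shapes follow the docstrings of MyMultiheadAttention.forward.
    my_mh = MyMultiheadAttention(embed_dim=dmodel, num_heads=num_head)
    attn_out, attn_weights = my_mh(src, src, src, key_padding_mask=src_key_padding_mask)
    print(attn_out.shape)      # torch.Size([5, 2, 32])  -> [src_len, batch_size, embed_dim]
    print(attn_weights.shape)  # torch.Size([2, 5, 5])   -> [batch_size, src_len, src_len], averaged over heads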