-
Notifications
You must be signed in to change notification settings - Fork 320
/
cait.py
508 lines (437 loc) · 19 KB
/
cait.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
# Copyright (c) 2021 PPViT Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
CaiT in Paddle
A Paddle Implementation of CaiT as described in:
"Going deeper with Image Transformers"
- Paper Link: https://arxiv.org/abs/2103.17239
"""
import paddle
import paddle.nn as nn
from droppath import DropPath
class Identity(nn.Layer):
    """Pass-through layer.

    Returns its input unchanged. It stands in for optional sub-layers (the
    drop-path layer when the rate is 0, or the classifier head when
    num_classes <= 0) so forward() methods need no `if` branches.
    """

    def forward(self, x):
        # Nothing to compute: hand the very same tensor back.
        output = x
        return output
class PatchEmbedding(nn.Layer):
    """Patch embedding.

    Cuts the image into non-overlapping square patches with a strided Conv2D
    (kernel size == stride == patch size) and flattens the resulting grid
    into a sequence of patch tokens.

    Args:
        image_size: int, input image height and width (square), default: 224
        patch_size: int, patch height and width (square), default: 4
        in_channels: int, number of input image channels, default: 3
        embed_dim: int, embedding dimension of every patch token, default: 96
    """

    def __init__(self, image_size=224, patch_size=4, in_channels=3, embed_dim=96):
        super().__init__()
        img_hw = (image_size, image_size)
        patch_hw = (patch_size, patch_size)
        grid_hw = [img_hw[0] // patch_hw[0], img_hw[1] // patch_hw[1]]

        self.image_size = img_hw
        self.patch_size = patch_hw
        self.patches_resolution = grid_hw
        self.num_patches = grid_hw[0] * grid_hw[1]
        self.in_channels = in_channels
        self.embed_dim = embed_dim
        # kernel == stride, so every output position covers exactly one patch
        self.patch_embed = nn.Conv2D(in_channels=in_channels,
                                     out_channels=embed_dim,
                                     kernel_size=patch_hw,
                                     stride=patch_hw)
        # NOTE: no normalization layer follows the embedding in CaiT.

    def forward(self, x):
        # assumes NCHW input (Conv2D default layout)
        # -> [batch, embed_dim, h, w], where (h, w) == patches_resolution
        tokens = self.patch_embed(x)
        # -> [batch, embed_dim, h*w], h*w == num_patches
        tokens = tokens.flatten(start_axis=2, stop_axis=-1)
        # -> [batch, h*w, embed_dim]
        return tokens.transpose([0, 2, 1])
class Mlp(nn.Layer):
    """Two-layer feed-forward block.

    Ops: fc1 -> GELU -> dropout -> fc2 -> dropout. A single nn.Dropout
    instance is used for both dropout steps.

    Args:
        in_features: int, input (and output) feature dimension
        hidden_features: int, hidden feature dimension
        dropout: float, dropout rate, default: 0.

    Attributes:
        fc1: nn.Linear, in_features -> hidden_features
        fc2: nn.Linear, hidden_features -> in_features
        act: nn.GELU
        dropout: nn.Dropout, applied after the activation and after fc2
    """

    def __init__(self, in_features, hidden_features, dropout=0.):
        super(Mlp, self).__init__()
        fc1_weight, fc1_bias = self._init_weights()
        self.fc1 = nn.Linear(in_features, hidden_features,
                             weight_attr=fc1_weight, bias_attr=fc1_bias)
        fc2_weight, fc2_bias = self._init_weights()
        self.fc2 = nn.Linear(hidden_features, in_features,
                             weight_attr=fc2_weight, bias_attr=fc2_bias)
        self.act = nn.GELU()
        self.dropout = nn.Dropout(dropout)

    def _init_weights(self):
        """Return a fresh (weight_attr, bias_attr) pair.

        Weights use a truncated normal with std=.02; biases start at zero.
        """
        initializers = paddle.nn.initializer
        return (paddle.ParamAttr(initializer=initializers.TruncatedNormal(std=.02)),
                paddle.ParamAttr(initializer=initializers.Constant(0.0)))

    def forward(self, x):
        hidden = self.dropout(self.act(self.fc1(x)))
        return self.dropout(self.fc2(hidden))
class ClassAttention(nn.Layer):
    """Class attention.

    Only the first token (the class token) issues a query. Keys and values
    are computed from every token. The result is the updated class token
    with shape [B, 1, C].

    Args:
        dim: int, total dimension across all heads
        num_heads: int, number of attention heads, default: 8
        qkv_bias: bool, if True, the q/k/v linear layers use bias, default: False
        qk_scale: float, attention scale; if None (or 0), dim // num_heads ** -0.5
            is used, default: None
        attention_dropout: float, dropout rate on attention weights, default: 0.
        dropout: float, dropout rate after the output projection, default: 0.
    """

    def __init__(self,
                 dim,
                 num_heads=8,
                 qkv_bias=False,
                 qk_scale=None,
                 attention_dropout=0.,
                 dropout=0.):
        super().__init__()
        self.num_heads = num_heads
        self.dim_head = dim // num_heads
        self.scale = qk_scale or self.dim_head ** -0.5

        self.q = nn.Linear(dim, dim, bias_attr=qkv_bias)
        self.k = nn.Linear(dim, dim, bias_attr=qkv_bias)
        self.v = nn.Linear(dim, dim, bias_attr=qkv_bias)
        self.attn_dropout = nn.Dropout(attention_dropout)

        self.proj = nn.Linear(dim, dim)
        self.proj_dropout = nn.Dropout(dropout)
        self.softmax = nn.Softmax(axis=-1)

    def forward(self, x):
        B, N, C = x.shape
        heads, head_dim = self.num_heads, self.dim_head

        # Query from the class token only; x[:, :1] keeps the token axis.
        # [B, 1, C] -> [B, heads, 1, head_dim]. A plain reshape is enough
        # because the token axis has length 1.
        query = self.q(x[:, :1, :]).reshape([B, heads, 1, head_dim])
        # Keys and values from all tokens: [B, N, C] -> [B, heads, N, head_dim]
        key = self.k(x).reshape([B, N, heads, head_dim]).transpose([0, 2, 1, 3])
        value = self.v(x).reshape([B, N, heads, head_dim]).transpose([0, 2, 1, 3])

        weights = paddle.matmul(query * self.scale, key, transpose_y=True)  # [B, heads, 1, N]
        weights = self.attn_dropout(self.softmax(weights))

        merged = paddle.matmul(weights, value)  # [B, heads, 1, head_dim]
        merged = merged.transpose([0, 2, 1, 3]).reshape([B, 1, C])
        return self.proj_dropout(self.proj(merged))
class TalkingHeadAttention(nn.Layer):
    """Talking-heads attention (https://arxiv.org/abs/2003.02436).

    Multi-head self-attention in which the attention maps are mixed across
    the head axis by a learned linear layer, once before the softmax
    (proj_l) and once after it (proj_w).

    Args:
        dim: int, total dimension across all heads
        num_heads: int, number of attention heads, default: 8
        qkv_bias: bool, if True, the fused qkv linear layer uses bias, default: False
        dropout: float, dropout rate after the output projection, default: 0.
        attention_dropout: float, dropout rate on attention weights, default: 0.
    """

    def __init__(self,
                 dim,
                 num_heads=8,
                 qkv_bias=False,
                 dropout=0.,
                 attention_dropout=0.):
        super().__init__()
        self.num_heads = num_heads
        self.dim = dim
        self.dim_head = dim // num_heads
        self.scale = self.dim_head ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias_attr=qkv_bias)
        self.attn_dropout = nn.Dropout(attention_dropout)
        self.softmax = nn.Softmax(axis=-1)
        self.proj = nn.Linear(dim, dim)
        self.proj_dropout = nn.Dropout(dropout)

        # talking-heads mixing layers: num_heads -> num_heads on the head axis
        self.proj_l = nn.Linear(num_heads, num_heads)
        self.proj_w = nn.Linear(num_heads, num_heads)

    def transpose_multihead(self, x):
        """Split the last axis into heads: [B, N, dim] -> [B, num_heads, N, dim_head]."""
        split = x.reshape(x.shape[:-1] + [self.num_heads, self.dim_head])
        return split.transpose([0, 2, 1, 3])

    @staticmethod
    def _mix_heads(attn, layer):
        """Apply `layer` along the head axis of attn ([B, heads, N, N])."""
        heads_last = attn.transpose([0, 2, 3, 1])  # [B, N, N, heads]
        mixed = layer(heads_last)
        return mixed.transpose([0, 3, 1, 2])  # [B, heads, N, N]

    def forward(self, x):
        B, N, C = x.shape  # N: number of tokens
        q, k, v = (self.transpose_multihead(part)
                   for part in self.qkv(x).chunk(3, axis=-1))  # each [B, heads, N, dim_head]

        attn = paddle.matmul(q * self.scale, k, transpose_y=True)  # [B, heads, N, N]
        attn = self._mix_heads(attn, self.proj_l)  # mix heads before softmax
        attn = self.softmax(attn)
        attn = self._mix_heads(attn, self.proj_w)  # mix heads after softmax
        attn = self.attn_dropout(attn)

        out = paddle.matmul(attn, v)  # [B, heads, N, dim_head]
        out = out.transpose([0, 2, 1, 3]).reshape([B, N, C])
        return self.proj_dropout(self.proj(out))
class LayerScaleBlockClassAttention(nn.Layer):
    """LayerScale block built on class attention.

    A pre-norm residual block (ClassAttention, then Mlp) that updates only
    the class token. Each residual branch is scaled per channel by a
    learnable vector (gamma_1 for attention, gamma_2 for the MLP) before
    drop path and the residual add.

    Args:
        dim: int, total dimension across all heads
        num_heads: int, number of attention heads
        mlp_ratio: float, mlp hidden dim = int(dim * mlp_ratio), default: 4.
        qkv_bias: bool, if True, the q/k/v linear layers use bias, default: False
        dropout: float, dropout rate for projections, default: 0.
        attention_dropout: float, dropout rate on attention weights, default: 0.
        droppath: float, drop path rate; 0 disables it (Identity), default: 0.
        init_values: float, initial value of gamma_1 and gamma_2, default: 1e-4
    """

    def __init__(self,
                 dim,
                 num_heads,
                 mlp_ratio=4.,
                 qkv_bias=False,
                 dropout=0.,
                 attention_dropout=0.,
                 droppath=0.,
                 init_values=1e-4):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim, epsilon=1e-6)
        self.attn = ClassAttention(dim,
                                   num_heads=num_heads,
                                   qkv_bias=qkv_bias,
                                   dropout=dropout,
                                   attention_dropout=attention_dropout)
        self.drop_path = DropPath(droppath) if droppath > 0. else Identity()
        self.norm2 = nn.LayerNorm(dim, epsilon=1e-6)
        self.mlp = Mlp(in_features=dim,
                       hidden_features=int(dim * mlp_ratio),
                       dropout=dropout)
        self.gamma_1 = self._make_gamma(dim, init_values)
        self.gamma_2 = self._make_gamma(dim, init_values)

    @staticmethod
    def _make_gamma(dim, init_values):
        """Create a learnable float32 per-channel scale of length dim."""
        return paddle.create_parameter(
            shape=[dim],
            dtype='float32',
            default_initializer=nn.initializer.Constant(init_values))

    def forward(self, x, x_cls):
        # Attention input is [class token; patch tokens], but only the
        # class token is updated and returned.
        tokens = paddle.concat([x_cls, x], axis=1)
        attn_out = self.attn(self.norm1(tokens))
        x_cls = x_cls + self.drop_path(self.gamma_1 * attn_out)

        mlp_out = self.mlp(self.norm2(x_cls))
        x_cls = x_cls + self.drop_path(self.gamma_2 * mlp_out)
        return x_cls
class LayerScaleBlock(nn.Layer):
    """LayerScale block built on talking-heads self-attention.

    A pre-norm residual block (TalkingHeadAttention, then Mlp). Each residual
    branch is scaled per channel by a learnable vector (gamma_1 for
    attention, gamma_2 for the MLP) before drop path and the residual add.

    Args:
        dim: int, total dimension across all heads
        num_heads: int, number of attention heads
        mlp_ratio: float, mlp hidden dim = int(dim * mlp_ratio), default: 4.
        qkv_bias: bool, if True, the qkv linear layer uses bias, default: False
        dropout: float, dropout rate for projections, default: 0.
        attention_dropout: float, dropout rate on attention weights, default: 0.
        droppath: float, drop path rate; 0 disables it (Identity), default: 0.
        init_values: float, initial value of gamma_1 and gamma_2, default: 1e-4
    """

    def __init__(self,
                 dim,
                 num_heads,
                 mlp_ratio=4.,
                 qkv_bias=False,
                 dropout=0.,
                 attention_dropout=0.,
                 droppath=0.,
                 init_values=1e-4):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim, epsilon=1e-6)
        self.attn = TalkingHeadAttention(dim,
                                         num_heads=num_heads,
                                         qkv_bias=qkv_bias,
                                         dropout=dropout,
                                         attention_dropout=attention_dropout)
        self.drop_path = DropPath(droppath) if droppath > 0. else Identity()
        self.norm2 = nn.LayerNorm(dim, epsilon=1e-6)
        self.mlp = Mlp(in_features=dim,
                       hidden_features=int(dim * mlp_ratio),
                       dropout=dropout)
        self.gamma_1 = self._make_gamma(dim, init_values)
        self.gamma_2 = self._make_gamma(dim, init_values)

    @staticmethod
    def _make_gamma(dim, init_values):
        """Create a learnable float32 per-channel scale of length dim."""
        return paddle.create_parameter(
            shape=[dim],
            dtype='float32',
            default_initializer=nn.initializer.Constant(init_values))

    def forward(self, x):
        # x: [B, num_patches, embed_dim]; the shape is preserved
        attn_branch = self.gamma_1 * self.attn(self.norm1(x))
        x = x + self.drop_path(attn_branch)

        mlp_branch = self.gamma_2 * self.mlp(self.norm2(x))
        x = x + self.drop_path(mlp_branch)
        return x
class Cait(nn.Layer):
    """CaiT model.

    Pipeline:
    1. Patch embedding plus learnable positional embedding.
    2. `depth` LayerScale self-attention blocks over the patch tokens.
    3. `depth_token_only` class-attention blocks that update only the
       class token.
    4. The normalized class token is passed to the classifier head.

    Args:
        image_size: int, input image size, default: 224
        in_channels: int, input image channels, default: 3
        num_classes: int, num of classes; <= 0 replaces the head with Identity, default: 1000
        patch_size: int, patch size for patch embedding, default: 16
        embed_dim: int, dim of each patch after patch embedding, default: 768
        depth: int, num of self-attention blocks, default: 12
        num_heads: int, num of attention heads, default: 12
        mlp_ratio: float, mlp hidden dim = mlp_ratio * mlp_in_dim, default: 4.
        qkv_bias: bool, if True, qkv projection is set with bias, default: True
        dropout: float, dropout rate for linear projections, default: 0.
        attention_dropout: float, dropout rate for attention, default: 0.
        droppath: float, drop path rate (same for every self-attention block), default: 0.
        init_values: initial value for layer scales, default: 1e-4
        mlp_ratio_class_token: float, mlp_ratio for mlp used in class attention blocks, default: 4.0
        depth_token_only: int, num of class attention blocks, default: 2
    """

    def __init__(self,
                 image_size=224,
                 in_channels=3,
                 num_classes=1000,
                 patch_size=16,
                 embed_dim=768,
                 depth=12,
                 num_heads=12,
                 mlp_ratio=4.,
                 qkv_bias=True,
                 dropout=0.,
                 attention_dropout=0.,
                 droppath=0,
                 init_values=1e-4,
                 mlp_ratio_class_token=4.0,
                 depth_token_only=2):
        super().__init__()
        self.num_classes = num_classes
        # convert image to patches
        self.patch_embed = PatchEmbedding(image_size=image_size,
                                          patch_size=patch_size,
                                          in_channels=in_channels,
                                          embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches
        # class token, joined with the patch tokens only in the class-attention stage
        self.cls_token = paddle.create_parameter(
            shape=[1, 1, embed_dim],
            dtype='float32',
            default_initializer=nn.initializer.Constant(0.0))
        # positional embeddings for patch positions (none for the class token)
        self.pos_embed = paddle.create_parameter(
            shape=[1, num_patches, embed_dim],
            dtype='float32',
            default_initializer=nn.initializer.Constant(0.0))
        self.pos_dropout = nn.Dropout(dropout)

        # self-attention (layer-scale) blocks over the patch tokens
        self.blocks = nn.LayerList([
            LayerScaleBlock(dim=embed_dim,
                            num_heads=num_heads,
                            mlp_ratio=mlp_ratio,
                            qkv_bias=qkv_bias,
                            dropout=dropout,
                            attention_dropout=attention_dropout,
                            droppath=droppath,
                            init_values=init_values)
            for _ in range(depth)])

        # class-attention blocks; dropout and drop path are disabled here
        self.blocks_token_only = nn.LayerList([
            LayerScaleBlockClassAttention(dim=embed_dim,
                                          num_heads=num_heads,
                                          mlp_ratio=mlp_ratio_class_token,
                                          qkv_bias=qkv_bias,
                                          dropout=0.,
                                          attention_dropout=0.,
                                          droppath=0.,
                                          init_values=init_values)
            for _ in range(depth_token_only)])

        self.norm = nn.LayerNorm(embed_dim, epsilon=1e-6)
        self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else Identity()

    def forward_features(self, x):
        """Return the normalized class-token feature, shape [B, embed_dim]."""
        x = self.patch_embed(x)  # [B, num_patches, embed_dim]
        cls_tokens = self.cls_token.expand([x.shape[0], -1, -1])  # [B, 1, embed_dim]
        x = self.pos_dropout(x + self.pos_embed)

        # Self-attention blocks (patch tokens only)
        for block in self.blocks:
            x = block(x)  # [B, num_patches, embed_dim]

        # Class-attention blocks (update the class token only)
        for block in self.blocks_token_only:
            cls_tokens = block(x, cls_tokens)  # [B, 1, embed_dim]

        # self.norm is a LayerNorm over the last axis (embed_dim), so each
        # token is normalized independently. Normalizing just the class token
        # equals normalizing concat([cls_tokens, x]) and taking index 0, and
        # it avoids building and normalizing the full [B, num_patches + 1, D]
        # tensor.
        cls_tokens = self.norm(cls_tokens)  # [B, 1, embed_dim]
        return cls_tokens[:, 0]

    def forward(self, x):
        x = self.forward_features(x)
        x = self.head(x)
        return x
def build_cait(config):
    """Build a Cait model from a config object.

    Reads config.DATA.IMAGE_SIZE / IMAGE_CHANNELS and the config.MODEL fields
    passed below.

    NOTE(review): MODEL.MLP_RATIO is reused for the class-attention blocks
    (mlp_ratio_class_token); no separate config key is read for them.
    """
    data_cfg = config.DATA
    model_cfg = config.MODEL
    return Cait(image_size=data_cfg.IMAGE_SIZE,
                patch_size=model_cfg.PATCH_SIZE,
                in_channels=data_cfg.IMAGE_CHANNELS,
                num_classes=model_cfg.NUM_CLASSES,
                embed_dim=model_cfg.EMBED_DIM,
                depth=model_cfg.DEPTH,
                num_heads=model_cfg.NUM_HEADS,
                mlp_ratio=model_cfg.MLP_RATIO,
                qkv_bias=model_cfg.QKV_BIAS,
                dropout=model_cfg.DROPOUT,
                attention_dropout=model_cfg.ATTENTION_DROPOUT,
                droppath=model_cfg.DROPPATH,
                init_values=model_cfg.INIT_VALUES,
                mlp_ratio_class_token=model_cfg.MLP_RATIO,
                depth_token_only=model_cfg.DEPTH_TOKEN_ONLY)