diff --git a/configs/vit/README.md b/configs/vit/README.md index e5d743fbb..f424e57e4 100644 --- a/configs/vit/README.md +++ b/configs/vit/README.md @@ -36,9 +36,9 @@ Our reproduced model performance on ImageNet-1K is reported as follows. | Model | Context | Top-1 (%) | Top-5 (%) | Params (M) | Recipe | Download | |--------------|----------|-----------|-----------|------------|-----------------------------------------------------------------------------------------------|--------------------------------------------------------------------------------| -| vit_b_32_224 | D910x8-G | 75.86 | 92.08 | 87.46 | [yaml](https://github.com/mindspore-lab/mindcv/blob/main/configs/vit/vit_b32_224_ascend.yaml) | [weights](https://download.mindspore.cn/toolkits/mindcv/vit/vit_b_32_224-7553218f.ckpt) | -| vit_l_16_224 | D910x8-G | 76.34 | 92.79 | 303.31 | [yaml](https://github.com/mindspore-lab/mindcv/blob/main/configs/vit/vit_l16_224_ascend.yaml) | [weights](https://download.mindspore.cn/toolkits/mindcv/vit/vit_l_16_224-f02b2487.ckpt) | -| vit_l_32_224 | D910x8-G | 73.71 | 90.92 | 305.52 | [yaml](https://github.com/mindspore-lab/mindcv/blob/main/configs/vit/vit_l32_224_ascend.yaml) | [weights](https://download.mindspore.cn/toolkits/mindcv/vit/vit_l_32_224-3a961018.ckpt) | +| vit_b_32_224 | D910x8-G | 75.86 | 92.08 | 87.46 | [yaml](https://github.com/mindspore-lab/mindcv/blob/main/configs/vit/vit_b32_224_ascend.yaml) | [weights](https://download.mindspore.cn/toolkits/mindcv/vit/vit_b_32_224-f50866e8.ckpt) | +| vit_l_16_224 | D910x8-G | 76.34 | 92.79 | 303.31 | [yaml](https://github.com/mindspore-lab/mindcv/blob/main/configs/vit/vit_l16_224_ascend.yaml) | [weights](https://download.mindspore.cn/toolkits/mindcv/vit/vit_l_16_224-97d0fdbc.ckpt) | +| vit_l_32_224 | D910x8-G | 73.71 | 90.92 | 305.52 | [yaml](https://github.com/mindspore-lab/mindcv/blob/main/configs/vit/vit_l32_224_ascend.yaml) | [weights](https://download.mindspore.cn/toolkits/mindcv/vit/vit_l_32_224-b80441df.ckpt) | diff --git a/mindcv/models/layers/pos_embed.py b/mindcv/models/layers/pos_embed.py new file mode 100644 index 000000000..c570c5c1b --- /dev/null +++ b/mindcv/models/layers/pos_embed.py @@ -0,0 +1,51 @@ +"""positional embedding""" +from typing import Tuple + +import numpy as np + +import mindspore as ms +from mindspore import Parameter, Tensor, nn, ops + + +class RelativePositionBiasWithCLS(nn.Cell): + def __init__( + self, + window_size: Tuple[int], + num_heads: int + ): + super(RelativePositionBiasWithCLS, self).__init__() + self.window_size = window_size + self.num_tokens = window_size[0] * window_size[1] + + num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 + # 3: cls to token, token to cls, cls to cls + self.relative_position_bias_table = Parameter( + Tensor(np.zeros((num_relative_distance, num_heads)), dtype=ms.float16) + ) + coords_h = np.arange(window_size[0]).reshape(window_size[0], 1).repeat(window_size[1], 1).reshape(1, -1) + coords_w = np.arange(window_size[1]).reshape(1, window_size[1]).repeat(window_size[0], 0).reshape(1, -1) + coords_flatten = np.concatenate([coords_h, coords_w], axis=0) # [2, Wh * Ww] + + relative_coords = coords_flatten[:, :, np.newaxis] - coords_flatten[:, np.newaxis, :] # [2, Wh * Ww, Wh * Ww] + relative_coords = relative_coords.transpose(1, 2, 0) # [Wh * Ww, Wh * Ww, 2] + relative_coords[:, :, 0] += window_size[0] - 1 + relative_coords[:, :, 1] += window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * window_size[0] - 1 + + relative_position_index = 
np.zeros((self.num_tokens + 1, self.num_tokens + 1), + dtype=relative_coords.dtype) # [Wh * Ww + 1, Wh * Ww + 1] + relative_position_index[1:, 1:] = relative_coords.sum(-1) + relative_position_index[0, 0:] = num_relative_distance - 3 + relative_position_index[0:, 0] = num_relative_distance - 2 + relative_position_index[0, 0] = num_relative_distance - 1 + relative_position_index = Tensor(relative_position_index.reshape(-1)) + + self.one_hot = nn.OneHot(axis=-1, depth=num_relative_distance, dtype=ms.float16) + self.relative_position_index = Parameter(self.one_hot(relative_position_index), requires_grad=False) + + def construct(self): + out = ops.matmul(self.relative_position_index, self.relative_position_bias_table) + out = ops.reshape(out, (self.num_tokens + 1, self.num_tokens + 1, -1)) + out = ops.transpose(out, (2, 0, 1)) + out = ops.expand_dims(out, 0) + return out diff --git a/mindcv/models/vit.py b/mindcv/models/vit.py index ac2c4c4c7..b26712f91 100644 --- a/mindcv/models/vit.py +++ b/mindcv/models/vit.py @@ -1,28 +1,30 @@ """ViT""" +import math from typing import List, Optional, Union import numpy as np -import mindspore as ms -from mindspore import Tensor, nn -from mindspore import ops -from mindspore import ops as P -from mindspore.common.initializer import Normal, initializer -from mindspore.common.parameter import Parameter +from mindspore import Parameter, Tensor, nn, ops +from mindspore.common.initializer import TruncatedNormal, initializer from .helpers import ConfigDict, load_pretrained from .layers.compatibility import Dropout +from .layers.drop_path import DropPath +from .layers.mlp import Mlp +from .layers.patch_embed import PatchEmbed +from .layers.pos_embed import RelativePositionBiasWithCLS from .registry import register_model __all__ = [ + "VisionTransformerEncoder", "ViT", "vit_b_16_224", "vit_b_16_384", - "vit_l_16_224", # train + "vit_l_16_224", # with pretrained weights "vit_l_16_384", - "vit_b_32_224", # train + "vit_b_32_224", # with pretrained weights "vit_b_32_384", - "vit_l_32_224", # train + "vit_l_32_224", # with pretrained weights ] @@ -32,7 +34,7 @@ def _cfg(url="", **kwargs): "num_classes": 1000, "input_size": (3, 224, 224), "first_conv": "patch_embed.proj", - "classifier": "classifier", + "classifier": "head.classifier", **kwargs, } @@ -42,62 +44,19 @@ def _cfg(url="", **kwargs): "vit_b_16_384": _cfg( url="", input_size=(3, 384, 384) ), - "vit_l_16_224": _cfg(url="https://download.mindspore.cn/toolkits/mindcv/vit/vit_l_16_224-f02b2487.ckpt"), + "vit_l_16_224": _cfg(url="https://download.mindspore.cn/toolkits/mindcv/vit/vit_l_16_224-97d0fdbc.ckpt"), "vit_l_16_384": _cfg( url="", input_size=(3, 384, 384) ), - "vit_b_32_224": _cfg(url="https://download.mindspore.cn/toolkits/mindcv/vit/vit_b_32_224-7553218f.ckpt"), + "vit_b_32_224": _cfg(url="https://download.mindspore.cn/toolkits/mindcv/vit/vit_b_32_224-f50866e8.ckpt"), "vit_b_32_384": _cfg( url="", input_size=(3, 384, 384) ), - "vit_l_32_224": _cfg(url="https://download.mindspore.cn/toolkits/mindcv/vit/vit_l_32_224-3a961018.ckpt"), + "vit_l_32_224": _cfg(url="https://download.mindspore.cn/toolkits/mindcv/vit/vit_l_32_224-b80441df.ckpt"), } -class PatchEmbedding(nn.Cell): - """ - Path embedding layer for ViT. First rearrange b c (h p) (w p) -> b (h w) (p p c). - - Args: - image_size (int): Input image size. Default: 224. - patch_size (int): Patch size of image. Default: 16. - embed_dim (int): The dimension of embedding. Default: 768. - input_channels (int): The number of input channel. Default: 3. 
- - Returns: - Tensor, output tensor. - - Examples: - >>> ops = PathEmbedding(224, 16, 768, 3) - """ - - MIN_NUM_PATCHES = 4 - - def __init__( - self, - image_size: int = 224, - patch_size: int = 16, - embed_dim: int = 768, - input_channels: int = 3, - ): - super().__init__() - self.image_size = image_size - self.patch_size = patch_size - self.num_patches = (image_size // patch_size) ** 2 - self.conv = nn.Conv2d(input_channels, embed_dim, kernel_size=patch_size, stride=patch_size, has_bias=True) - self.reshape = ops.Reshape() - self.transpose = ops.Transpose() - - def construct(self, x): - """Path Embedding construct.""" - x = self.conv(x) - b, c, h, w = x.shape - x = self.reshape(x, (b, c, h * w)) - x = self.transpose(x, (0, 2, 1)) - - return x - - +# TODO: Flash Attention class Attention(nn.Cell): """ Attention layer implementation, Rearrange Input -> B x N x hidden size. @@ -105,8 +64,11 @@ class Attention(nn.Cell): Args: dim (int): The dimension of input features. num_heads (int): The number of attention heads. Default: 8. - keep_prob (float): The keep rate, greater than 0 and less equal than 1. Default: 1.0. - attention_keep_prob (float): The keep rate for attention. Default: 1.0. + qkv_bias (bool): Specifies whether the linear layer uses a bias vector. Default: True. + qk_scale: (float): The user-defined factor to scale the product of q and k. Default: None. + attn_drop (float): The drop rate of attention, greater than 0 and less equal than 1. Default: 0.0. + proj_drop (float): The drop rate of output, greater than 0 and less equal than 1. Default: 0.0. + attn_head_dim (int): The user-defined dimension of attention head features. Default: None. Returns: Tensor, output tensor. @@ -114,23 +76,33 @@ class Attention(nn.Cell): Examples: >>> ops = Attention(768, 12) """ - def __init__( self, dim: int, num_heads: int = 8, - keep_prob: float = 1.0, - attention_keep_prob: float = 1.0, + qkv_bias: bool = True, + qk_scale: Optional[float] = None, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + attn_head_dim: Optional[int] = None, ): - super().__init__() + super(Attention, self).__init__() self.num_heads = num_heads head_dim = dim // num_heads - self.scale = Tensor(head_dim**-0.5) + if attn_head_dim is not None: + head_dim = attn_head_dim + all_head_dim = head_dim * num_heads + + if qk_scale: + self.scale = Tensor(qk_scale) + else: + self.scale = Tensor(head_dim ** -0.5) - self.qkv = nn.Dense(dim, dim * 3) - self.attn_drop = Dropout(p=1.0-attention_keep_prob) - self.out = nn.Dense(dim, dim) - self.out_drop = Dropout(p=1.0-keep_prob) + self.qkv = nn.Dense(dim, all_head_dim * 3, has_bias=qkv_bias) + + self.attn_drop = Dropout(attn_drop) + self.proj = nn.Dense(all_head_dim, dim) + self.proj_drop = Dropout(proj_drop) self.mul = ops.Mul() self.reshape = ops.Reshape() @@ -140,8 +112,7 @@ def __init__( self.q_matmul_k = ops.BatchMatMul(transpose_b=True) self.softmax = nn.Softmax(axis=-1) - def construct(self, x): - """Attention construct.""" + def construct(self, x, rel_pos_bias=None): b, n, c = x.shape qkv = self.qkv(x) qkv = self.reshape(qkv, (b, n, 3, self.num_heads, c // self.num_heads)) @@ -150,193 +121,245 @@ def construct(self, x): attn = self.q_matmul_k(q, k) attn = self.mul(attn, self.scale) + + if rel_pos_bias is not None: + attn = attn + rel_pos_bias + attn = self.softmax(attn) attn = self.attn_drop(attn) out = self.attn_matmul_v(attn, v) out = self.transpose(out, (0, 2, 1, 3)) out = self.reshape(out, (b, n, c)) - out = self.out(out) - out = self.out_drop(out) + out = self.proj(out) + 
out = self.proj_drop(out) return out -class FeedForward(nn.Cell): +class LayerScale(nn.Cell): """ - Feed Forward layer implementation. + Layer scale, help ViT improve the training dynamic, allowing for the training + of deeper high-capacity image transformers that benefit from depth Args: - in_features (int): The dimension of input features. - hidden_features (int): The dimension of hidden features. Default: None. - out_features (int): The dimension of output features. Default: None - activation (nn.Cell): Activation function which will be stacked on top of the - normalization layer (if not None), otherwise on top of the conv layer. Default: nn.GELU. - keep_prob (float): The keep rate, greater than 0 and less equal than 1. Default: 1.0. + dim (int): The output dimension of attnetion layer or mlp layer. + init_values (float): The scale factor. Default: 1e-5. Returns: Tensor, output tensor. Examples: - >>> ops = FeedForward(768, 3072) + >>> ops = LayerScale(768, 0.01) """ - def __init__( self, - in_features: int, - hidden_features: Optional[int] = None, - out_features: Optional[int] = None, - activation: nn.Cell = nn.GELU, - keep_prob: float = 1.0, + dim: int, + init_values: float = 1e-5 ): - super().__init__() - out_features = out_features or in_features - hidden_features = hidden_features or in_features - self.dense1 = nn.Dense(in_features, hidden_features) - self.activation = activation() - self.dense2 = nn.Dense(hidden_features, out_features) - self.dropout = Dropout(p=1.0-keep_prob) + super(LayerScale, self).__init__() + self.gamma = Parameter(initializer(init_values, dim)) def construct(self, x): - """Feed Forward construct.""" - x = self.dense1(x) - x = self.activation(x) - x = self.dropout(x) - x = self.dense2(x) - x = self.dropout(x) - - return x + return self.gamma * x -class ResidualCell(nn.Cell): +class TransformerBlock(nn.Cell): """ - Cell which implements Residual function: - - $$output = x + f(x)$$ + Transformer block implementation. Args: - cell (Cell): Cell needed to add residual block. + dim (int): The dimension of embedding. + num_heads (int): The number of attention heads. + qkv_bias (bool): Specifies whether the linear layer uses a bias vector. Default: True. + qk_scale: (float): The user-defined factor to scale the product of q and k. Default: None. + attn_drop (float): The drop rate of attention, greater than 0 and less equal than 1. Default: 0.0. + proj_drop (float): The drop rate of dense layer output, greater than 0 and less equal than 1. Default: 0.0. + attn_head_dim (int): The user-defined dimension of attention head features. Default: None. + mlp_ratio (float): The ratio used to scale the input dimensions to obtain the dimensions of the hidden layer. + drop_path (float): The drop rate for drop path. Default: 0.0. + act_layer (nn.Cell): Activation function which will be stacked on top of the + normalization layer (if not None), otherwise on top of the conv layer. Default: nn.GELU. + norm_layer (nn.Cell): Norm layer that will be stacked on top of the convolution + layer. Default: nn.LayerNorm. Returns: Tensor, output tensor. Examples: - >>> ops = ResidualCell(nn.Dense(3,4)) - """ - - def __init__(self, cell): - super().__init__() - self.cell = cell - - def construct(self, x): - """ResidualCell construct.""" - return self.cell(x) + x - - -class DropPath(nn.Cell): - """ - Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 
+ >>> ops = TransformerEncoder(768, 12, 12, 3072) """ + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + qk_scale: Optional[float] = None, + attn_drop: float = 0., + proj_drop: float = 0., + attn_head_dim: Optional[int] = None, + mlp_ratio: float = 4., + drop_path: float = 0., + init_values: Optional[float] = None, + act_layer: nn.Cell = nn.GELU, + norm_layer: nn.Cell = nn.LayerNorm, + ): + super(TransformerBlock, self).__init__() + self.norm1 = norm_layer((dim,)) + self.attn = Attention( + dim=dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, + attn_drop=attn_drop, proj_drop=proj_drop, attn_head_dim=attn_head_dim, + ) + self.ls1 = LayerScale(dim=dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() - def __init__(self, keep_prob=None, seed=0): - super().__init__() - self.keep_prob = 1 - keep_prob - seed = min(seed, 0) - self.rand = P.UniformReal(seed=seed) - self.shape = P.Shape() - self.floor = P.Floor() - - def construct(self, x): - if self.training: - x_shape = self.shape(x) - random_tensor = self.rand((x_shape[0], 1, 1)) - random_tensor = random_tensor + self.keep_prob - random_tensor = self.floor(random_tensor) - x = x / self.keep_prob - x = x * random_tensor + self.norm2 = norm_layer((dim,)) + self.mlp = Mlp( + in_features=dim, hidden_features=int(dim * mlp_ratio), + act_layer=act_layer, drop=proj_drop + ) + self.ls2 = LayerScale(dim=dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + def construct(self, x, rel_pos_bias=None): + x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x), rel_pos_bias))) + x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) return x -class TransformerEncoder(nn.Cell): - """ - TransformerEncoder implementation. - - Args: - dim (int): The dimension of embedding. - num_layers (int): The depth of transformer. - num_heads (int): The number of attention heads. - mlp_dim (int): The dimension of MLP hidden layer. - keep_prob (float): The keep rate, greater than 0 and less equal than 1. Default: 1.0. - attention_keep_prob (float): The keep rate for attention. Default: 1.0. - drop_path_keep_prob (float): The keep rate for drop path. Default: 1.0. - activation (nn.Cell): Activation function which will be stacked on top of the - normalization layer (if not None), otherwise on top of the conv layer. Default: nn.GELU. - norm (nn.Cell, optional): Norm layer that will be stacked on top of the convolution - layer. Default: nn.LayerNorm. - - Returns: - Tensor, output tensor. - - Examples: - >>> ops = TransformerEncoder(768, 12, 12, 3072) - """ - +class VisionTransformerEncoder(nn.Cell): + ''' + ViT encoder, which returns the feature encoded by transformer encoder. 
+ ''' def __init__( self, - dim: int, - num_layers: int, - num_heads: int, - mlp_dim: int, - keep_prob: float = 1.0, - attention_keep_prob: float = 1.0, - drop_path_keep_prob: float = 1.0, - activation: nn.Cell = nn.GELU, - norm: nn.Cell = nn.LayerNorm, + image_size: int = 224, + patch_size: int = 16, + in_channels: int = 3, + embed_dim: int = 768, + depth: int = 12, + num_heads: int = 12, + attn_head_dim: Optional[int] = None, + mlp_ratio: float = 4., + qkv_bias: bool = True, + qk_scale: Optional[float] = None, + pos_drop_rate: float = 0., + proj_drop_rate: float = 0., + attn_drop_rate: float = 0., + drop_path_rate: float = 0., + init_values: Optional[float] = 0.1, + act_layer: nn.Cell = nn.GELU, + norm_layer: nn.Cell = nn.LayerNorm, + use_rel_pos_emb: bool = False, + use_rel_pos_bias: bool = False, + use_shared_rel_pos_bias: bool = True, + **kwargs ): - super().__init__() - drop_path_rate = 1 - drop_path_keep_prob - dpr = [i.item() for i in np.linspace(0, drop_path_rate, num_layers)] - attn_seeds = [np.random.randint(1024) for _ in range(num_layers)] - mlp_seeds = [np.random.randint(1024) for _ in range(num_layers)] - - layers = [] - for i in range(num_layers): - normalization1 = norm((dim,)) - normalization2 = norm((dim,)) - attention = Attention(dim=dim, - num_heads=num_heads, - keep_prob=keep_prob, - attention_keep_prob=attention_keep_prob) - - feedforward = FeedForward(in_features=dim, - hidden_features=mlp_dim, - activation=activation, - keep_prob=keep_prob) - - if drop_path_rate > 0: - layers.append( - nn.SequentialCell([ - ResidualCell(nn.SequentialCell([normalization1, - attention, - DropPath(dpr[i], attn_seeds[i])])), - ResidualCell(nn.SequentialCell([normalization2, - feedforward, - DropPath(dpr[i], mlp_seeds[i])]))])) - else: - layers.append( - nn.SequentialCell([ - ResidualCell(nn.SequentialCell([normalization1, - attention])), - ResidualCell(nn.SequentialCell([normalization2, - feedforward])) - ]) + super(VisionTransformerEncoder, self).__init__() + self.embed_dim = embed_dim + self.patch_embed = PatchEmbed(image_size=image_size, patch_size=patch_size, + in_chans=in_channels, embed_dim=embed_dim) + self.num_patches = self.patch_embed.num_patches + + self.cls_token = Parameter(initializer(TruncatedNormal(0.02), (1, 1, embed_dim))) + + self.pos_embed = Parameter(initializer(TruncatedNormal(0.02), + (1, self.num_patches + 1, embed_dim))) if not use_rel_pos_emb else None + self.pos_drop = Dropout(pos_drop_rate) + + if use_shared_rel_pos_bias: + self.rel_pos_bias = RelativePositionBiasWithCLS( + window_size=self.patch_embed.patches_resolution, + num_heads=num_heads, + ) + elif use_rel_pos_bias: + self.rel_pos_bias = nn.CellList([ + RelativePositionBiasWithCLS(window_size=self.patch_embed.patches_resolution, + num_heads=num_heads) for _ in range(depth) + ]) + else: + self.rel_pos_bias = None + + dpr = [x.item() for x in np.linspace(0, drop_path_rate, depth)] + self.blocks = nn.CellList([ + TransformerBlock( + dim=embed_dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, + attn_drop=attn_drop_rate, proj_drop=proj_drop_rate, attn_head_dim=attn_head_dim, + mlp_ratio=mlp_ratio, drop_path=dpr[i], init_values=init_values, + act_layer=act_layer, norm_layer=norm_layer + ) for i in range(depth) + ]) + + self._init_weights() + self._fix_init_weights() + + def get_num_layers(self): + return len(self.blocks) + + def no_weight_decay(self): + return {'pos_embed', 'cls_token'} + + def _init_weights(self): + for _, cell in self.cells_and_names(): + if isinstance(cell, nn.Dense): + 
cell.weight.set_data( + initializer(TruncatedNormal(0.02), cell.weight.shape, cell.weight.dtype) + ) + if cell.bias is not None: + cell.bias.set_data( + initializer('zeros', cell.bias.shape, cell.bias.dtype) + ) + elif isinstance(cell, nn.LayerNorm): + cell.gamma.set_data( + initializer('ones', cell.gamma.shape, cell.gamma.dtype) + ) + cell.beta.set_data( + initializer('zeros', cell.beta.shape, cell.beta.dtype) ) - self.layers = nn.SequentialCell(layers) + elif isinstance(cell, nn.Conv2d): + cell.weight.set_data( + initializer(TruncatedNormal(0.02), cell.weight.shape, cell.weight.dtype) + ) + if cell.bias is not None: + cell.bias.set_data( + initializer('zeros', cell.bias.shape, cell.bias.dtype) + ) + + def _fix_init_weights(self): + for i, block in enumerate(self.blocks): + block.attn.proj.weight.set_data( + ops.div(block.attn.proj.weight, math.sqrt(2.0 * (i + 1))) + ) + block.mlp.fc2.weight.set_data( + ops.div(block.mlp.fc2.weight, math.sqrt(2.0 * (i + 1))) + ) + + def forward_features(self, x): + x = self.patch_embed(x) + bsz = x.shape[0] + + cls_tokens = ops.broadcast_to(self.cls_token, (bsz, -1, -1)) + cls_tokens = cls_tokens.astype(x.dtype) + x = ops.concat((cls_tokens, x), axis=1) + + if self.pos_embed is not None: + x = x + self.pos_embed + x = self.pos_drop(x) + + if isinstance(self.rel_pos_bias, nn.CellList): + for i, blk in enumerate(self.blocks): + rel_pos_bias = self.rel_pos_bias[i]() + x = blk(x, rel_pos_bias) + else: + rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None + for blk in self.blocks: + x = blk(x, rel_pos_bias) + + return x def construct(self, x): - """Transformer construct.""" - return self.layers(x) + x = self.forward_features(x) + return x class DenseHead(nn.Cell): @@ -424,238 +447,143 @@ def construct(self, x): return x -class BaseClassifier(nn.Cell): - """ - generate classifier to combine the backbone and head - """ - - def __init__(self, backbone, neck=None, head=None): - super().__init__() - self.backbone = backbone - if neck: - self.neck = neck - self.with_neck = True - else: - self.with_neck = False - if head: - self.head = head - self.with_head = True - else: - self.with_head = False - - def forward_features(self, x: Tensor) -> Tensor: - x = self.backbone(x) - return x - - def forward_head(self, x: Tensor) -> Tensor: - x = self.head(x) - return x - - def construct(self, x): - x = self.forward_features(x) - if self.with_neck: - x = self.neck(x) - if self.with_head: - x = self.forward_head(x) - return x - - -def init(init_type, shape, dtype, name, requires_grad): - initial = initializer(init_type, shape, dtype).init_data() - return Parameter(initial, name=name, requires_grad=requires_grad) - - -class ViT(nn.Cell): - """ - Vision Transformer architecture implementation. - - Args: - image_size (int): Input image size. Default: 224. - input_channels (int): The number of input channel. Default: 3. - patch_size (int): Patch size of image. Default: 16. - embed_dim (int): The dimension of embedding. Default: 768. - num_layers (int): The depth of transformer. Default: 12. - num_heads (int): The number of attention heads. Default: 12. - mlp_dim (int): The dimension of MLP hidden layer. Default: 3072. - keep_prob (float): The keep rate, greater than 0 and less equal than 1. Default: 1.0. - attention_keep_prob (float): The keep rate for attention layer. Default: 1.0. - drop_path_keep_prob (float): The keep rate for drop path. Default: 1.0. 
- activation (nn.Cell): Activation function which will be stacked on top of the - normalization layer (if not None), otherwise on top of the conv layer. Default: nn.GELU. - norm (nn.Cell, optional): Norm layer that will be stacked on top of the convolution - layer. Default: nn.LayerNorm. - pool (str): The method of pooling. Default: 'cls'. - - Inputs: - - **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`. - - Outputs: - Tensor of shape :math:`(N, 768)` - - Raises: - ValueError: If `split` is not 'train', 'test' or 'infer'. - - Supported Platforms: - ``GPU`` - - Examples: - >>> net = ViT() - >>> x = ms.Tensor(np.ones([1, 3, 224, 224]), ms.float32) - >>> output = net(x) - >>> print(output.shape) - (1, 768) - - About ViT: - - Vision Transformer (ViT) shows that a pure transformer applied directly to sequences of image - patches can perform very well on image classification tasks. When pre-trained on large amounts - of data and transferred to multiple mid-sized or small image recognition benchmarks (ImageNet, - CIFAR-100, VTAB, etc.), Vision Transformer (ViT) attains excellent results compared to state-of-the-art - convolutional networks while requiring substantially fewer computational resources to train. - - Citation: - - .. code-block:: - - @article{2020An, - title={An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale}, - author={Dosovitskiy, A. and Beyer, L. and Kolesnikov, A. and Weissenborn, D. and Houlsby, N.}, - year={2020}, - } - """ - +class ViT(VisionTransformerEncoder): def __init__( self, image_size: int = 224, - input_channels: int = 3, patch_size: int = 16, + in_channels: int = 3, embed_dim: int = 768, - num_layers: int = 12, + depth: int = 12, num_heads: int = 12, - mlp_dim: int = 3072, - keep_prob: float = 1.0, - attention_keep_prob: float = 1.0, - drop_path_keep_prob: float = 1.0, - activation: nn.Cell = nn.GELU, - norm: Optional[nn.Cell] = nn.LayerNorm, - pool: str = "cls", - ) -> None: - super().__init__() - - self.patch_embedding = PatchEmbedding(image_size=image_size, - patch_size=patch_size, - embed_dim=embed_dim, - input_channels=input_channels) - num_patches = self.patch_embedding.num_patches - - if pool == "cls": - self.cls_token = init(init_type=Normal(sigma=1.0), - shape=(1, 1, embed_dim), - dtype=ms.float32, - name="cls", - requires_grad=True) - self.pos_embedding = init(init_type=Normal(sigma=1.0), - shape=(1, num_patches + 1, embed_dim), - dtype=ms.float32, - name="pos_embedding", - requires_grad=True) - self.concat = ops.Concat(axis=1) - else: - self.pos_embedding = init(init_type=Normal(sigma=1.0), - shape=(1, num_patches, embed_dim), - dtype=ms.float32, - name="pos_embedding", - requires_grad=True) - self.mean = ops.ReduceMean(keep_dims=False) - - self.pool = pool - self.pos_dropout = Dropout(p=1.0-keep_prob) - self.norm = norm((embed_dim,)) - self.tile = ops.Tile() - self.transformer = TransformerEncoder( - dim=embed_dim, - num_layers=num_layers, + attn_head_dim: Optional[int] = None, + mlp_ratio: float = 4., + qkv_bias: bool = True, + qk_scale: Optional[float] = None, + pos_drop_rate: float = 0., + proj_drop_rate: float = 0., + attn_drop_rate: float = 0., + drop_path_rate: float = 0., + init_values: Optional[float] = 0.1, + act_layer: nn.Cell = nn.GELU, + norm_layer: nn.Cell = nn.LayerNorm, + use_rel_pos_emb: bool = False, + use_rel_pos_bias: bool = False, + use_shared_rel_pos_bias: bool = True, + use_cls: bool = True, + representation_size: Optional[int] = None, + num_classes: int = 1000, + **kwargs + ): + super(ViT, 
self).__init__( + image_size=image_size, + patch_size=patch_size, + in_channels=in_channels, + embed_dim=embed_dim, + depth=depth, num_heads=num_heads, - mlp_dim=mlp_dim, - keep_prob=keep_prob, - attention_keep_prob=attention_keep_prob, - drop_path_keep_prob=drop_path_keep_prob, - activation=activation, - norm=norm, + attn_head_dim=attn_head_dim, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + pos_drop_rate=pos_drop_rate, + proj_drop_rate=proj_drop_rate, + attn_drop_rate=attn_drop_rate, + drop_path_rate=drop_path_rate, + init_values=init_values, + act_layer=act_layer, + norm_layer=norm_layer, + use_rel_pos_emb=use_rel_pos_emb, + use_rel_pos_bias=use_rel_pos_bias, + use_shared_rel_pos_bias=use_shared_rel_pos_bias, + **kwargs ) + self.use_cls = use_cls + self.norm = norm_layer((embed_dim,)) + + if representation_size: + self.head = MultilayerDenseHead( + input_channel=embed_dim, + num_classes=num_classes, + mid_channel=[representation_size], + activation=["tanh", None], + keep_prob=[1.0, 1.0], + ) + else: + self.head = DenseHead(input_channel=embed_dim, num_classes=num_classes) def construct(self, x): - """ViT construct.""" - x = self.patch_embedding(x) - - if self.pool == "cls": - cls_tokens = self.tile(self.cls_token, (x.shape[0], 1, 1)) - x = self.concat((cls_tokens, x)) - x += self.pos_embedding - else: - x += self.pos_embedding - x = self.pos_dropout(x) - x = self.transformer(x) + x = self.forward_features(x) x = self.norm(x) - if self.pool == "cls": + if self.use_cls: x = x[:, 0] else: - x = self.mean(x, (1, )) # (1,) or (1,2) + x = x[:, 1:].mean(axis=1) + + x = self.head(x) return x def vit( - image_size: int, - input_channels: int, - patch_size: int, - embed_dim: int, - num_layers: int, - num_heads: int, - num_classes: int, - mlp_dim: int, - dropout: float = 0.0, - attention_dropout: float = 0.0, - drop_path_rate: float = 0.0, - activation: nn.Cell = nn.GELU, - norm: nn.Cell = nn.LayerNorm, - pool: str = "cls", + image_size: int = 224, + patch_size: int = 16, + in_channels: int = 3, + embed_dim: int = 768, + depth: int = 12, + num_heads: int = 12, + attn_head_dim: Optional[int] = None, + mlp_ratio: float = 4., + qkv_bias: bool = True, + qk_scale: Optional[float] = None, + pos_drop_rate: float = 0., + proj_drop_rate: float = 0., + attn_drop_rate: float = 0., + drop_path_rate: float = 0., + init_values: Optional[float] = None, + act_layer: nn.Cell = nn.GELU, + norm_layer: nn.Cell = nn.LayerNorm, + use_rel_pos_emb: bool = False, + use_rel_pos_bias: bool = False, + use_shared_rel_pos_bias: bool = False, + use_cls: bool = True, representation_size: Optional[int] = None, + num_classes: int = 1000, pretrained: bool = False, url_cfg: dict = None, ) -> ViT: + """Vision Transformer architecture.""" - backbone = ViT( + + model = ViT( image_size=image_size, - input_channels=input_channels, patch_size=patch_size, + in_channels=in_channels, embed_dim=embed_dim, - num_layers=num_layers, + depth=depth, num_heads=num_heads, - mlp_dim=mlp_dim, - keep_prob=1.0 - dropout, - attention_keep_prob=1.0 - attention_dropout, - drop_path_keep_prob=1.0 - drop_path_rate, - activation=activation, - norm=norm, - pool=pool, + attn_head_dim=attn_head_dim, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + pos_drop_rate=pos_drop_rate, + proj_drop_rate=proj_drop_rate, + attn_drop_rate=attn_drop_rate, + drop_path_rate=drop_path_rate, + init_values=init_values, + act_layer=act_layer, + norm_layer=norm_layer, + use_rel_pos_emb=use_rel_pos_emb, + use_rel_pos_bias=use_rel_pos_bias, + 
use_shared_rel_pos_bias=use_shared_rel_pos_bias, + use_cls=use_cls, + representation_size=representation_size, + num_classes=num_classes ) - if representation_size: - head = MultilayerDenseHead( - input_channel=embed_dim, - num_classes=num_classes, - mid_channel=[representation_size], - activation=["tanh", None], - keep_prob=[1.0, 1.0], - ) - else: - head = DenseHead(input_channel=embed_dim, num_classes=num_classes) - - model = BaseClassifier(backbone=backbone, head=head) if pretrained: # Download the pre-trained checkpoint file from url, and load ckpt file. - load_pretrained(model, url_cfg, num_classes=num_classes, in_channels=input_channels) + load_pretrained(model, url_cfg, num_classes=num_classes, in_channels=in_channels) return model @@ -665,60 +593,25 @@ def vit_b_16_224( pretrained: bool = False, num_classes: int = 1000, in_channels: int = 3, - image_size: int = 224, has_logits: bool = False, drop_rate: float = 0.0, - # attention-dropout: float = 0.0, drop_path_rate: float = 0.0, -) -> ViT: - """ - Constructs a vit_b_16 architecture from - `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale `_. - - Args: - pretrained (bool): Whether to download and load the pre-trained model. Default: False. - num_classes (int): The number of classification. Default: 1000. - in_channels (int): The number of input channels. Default: 3. - image_size (int): The input image size. Default: 224 for ImageNet. - has_logits (bool): Whether has logits or not. Default: False. - drop_rate (float): The drop out rate. Default: 0.0.s - drop_path_rate (float): The stochastic depth rate. Default: 0.0. - - Returns: - ViT network, MindSpore.nn.Cell - - Inputs: - - **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`. - - Examples: - >>> net = vit_b_16_224() - >>> x = ms.Tensor(np.ones([1, 3, 224, 224]), ms.float32) - >>> output = net(x) - >>> print(output.shape) - (1, 1000) - - Outputs: - Tensor of shape :math:`(N, CLASSES_{out})` - - Supported Platforms: - ``GPU`` - """ +): config = ConfigDict() - config.image_size = image_size - config.num_classes = num_classes + config.image_size = 224 config.patch_size = 16 + config.in_channels = in_channels config.embed_dim = 768 - config.mlp_dim = 3072 + config.depth = 12 config.num_heads = 12 - config.num_layers = 12 - config.dropout = drop_rate - config.attention_dropout = drop_rate # attention-dropout + config.pos_drop_rate = drop_rate + config.proj_drop_rate = drop_rate + config.attn_drop_rate = drop_rate config.drop_path_rate = drop_path_rate - config.pretrained = pretrained - config.input_channels = in_channels - config.pool = "cls" config.representation_size = 768 if has_logits else None + config.num_classes = num_classes + config.pretrained = pretrained config.url_cfg = default_cfgs["vit_b_16_224"] return vit(**config) @@ -729,29 +622,25 @@ def vit_b_16_384( pretrained: bool = False, num_classes: int = 1000, in_channels: int = 3, - image_size: int = 384, has_logits: bool = False, drop_rate: float = 0.0, - # attention-dropout: float = 0.0, drop_path_rate: float = 0.0, -) -> ViT: - """construct and return a ViT network""" +): config = ConfigDict() - config.image_size = image_size - config.num_classes = num_classes + config.image_size = 384 config.patch_size = 16 + config.in_channels = in_channels config.embed_dim = 768 - config.mlp_dim = 3072 + config.depth = 12 config.num_heads = 12 - config.num_layers = 12 - config.dropout = drop_rate - config.attention_dropout = drop_rate # attention-dropout + config.pos_drop_rate = drop_rate + 
config.proj_drop_rate = drop_rate + config.attn_drop_rate = drop_rate config.drop_path_rate = drop_path_rate - config.pretrained = pretrained - config.input_channels = in_channels - config.pool = "cls" config.representation_size = 768 if has_logits else None + config.num_classes = num_classes + config.pretrained = pretrained config.url_cfg = default_cfgs["vit_b_16_384"] return vit(**config) @@ -762,30 +651,25 @@ def vit_l_16_224( pretrained: bool = False, num_classes: int = 1000, in_channels: int = 3, - image_size: int = 224, has_logits: bool = False, drop_rate: float = 0.0, - # attention-dropout: float = 0.0, drop_path_rate: float = 0.0, -) -> ViT: - """construct and return a ViT network""" - +): config = ConfigDict() - config.image_size = image_size - config.num_classes = num_classes + config.image_size = 224 config.patch_size = 16 + config.in_channels = in_channels config.embed_dim = 1024 - config.mlp_dim = 4096 + config.depth = 24 config.num_heads = 16 - config.num_layers = 24 - config.dropout = drop_rate - config.attention_dropout = drop_rate # attention-dropout + config.pos_drop_rate = drop_rate + config.proj_drop_rate = drop_rate + config.attn_drop_rate = drop_rate config.drop_path_rate = drop_path_rate - config.input_channels = in_channels - config.pool = "cls" - config.pretrained = pretrained config.representation_size = 1024 if has_logits else None + config.num_classes = num_classes + config.pretrained = pretrained config.url_cfg = default_cfgs["vit_l_16_224"] return vit(**config) @@ -796,30 +680,25 @@ def vit_l_16_384( pretrained: bool = False, num_classes: int = 1000, in_channels: int = 3, - image_size: int = 384, has_logits: bool = False, drop_rate: float = 0.0, - # attention-dropout: float = 0.0, drop_path_rate: float = 0.0, -) -> ViT: - """construct and return a ViT network""" - +): config = ConfigDict() - config.image_size = image_size - config.num_classes = num_classes + config.image_size = 384 config.patch_size = 16 + config.in_channels = in_channels config.embed_dim = 1024 - config.mlp_dim = 4096 + config.depth = 24 config.num_heads = 16 - config.num_layers = 24 - config.dropout = drop_rate - config.attention_dropout = drop_rate # attention-dropout + config.pos_drop_rate = drop_rate + config.proj_drop_rate = drop_rate + config.attn_drop_rate = drop_rate config.drop_path_rate = drop_path_rate - config.input_channels = in_channels - config.pool = "cls" - config.pretrained = pretrained config.representation_size = 1024 if has_logits else None + config.num_classes = num_classes + config.pretrained = pretrained config.url_cfg = default_cfgs["vit_l_16_384"] return vit(**config) @@ -830,29 +709,25 @@ def vit_b_32_224( pretrained: bool = False, num_classes: int = 1000, in_channels: int = 3, - image_size: int = 224, has_logits: bool = False, drop_rate: float = 0.0, - # attention-dropout: float = 0.0, drop_path_rate: float = 0.0, -) -> ViT: - """construct and return a ViT network""" +): config = ConfigDict() - config.image_size = image_size - config.num_classes = num_classes + config.image_size = 224 config.patch_size = 32 + config.in_channels = in_channels config.embed_dim = 768 - config.mlp_dim = 3072 + config.depth = 12 config.num_heads = 12 - config.num_layers = 12 - config.dropout = drop_rate - config.attention_dropout = drop_rate # attention-dropout + config.pos_drop_rate = drop_rate + config.proj_drop_rate = drop_rate + config.attn_drop_rate = drop_rate config.drop_path_rate = drop_path_rate - config.pretrained = pretrained - config.input_channels = in_channels - 
config.pool = "cls" config.representation_size = 768 if has_logits else None + config.num_classes = num_classes + config.pretrained = pretrained config.url_cfg = default_cfgs["vit_b_32_224"] return vit(**config) @@ -863,29 +738,25 @@ def vit_b_32_384( pretrained: bool = False, num_classes: int = 1000, in_channels: int = 3, - image_size: int = 384, has_logits: bool = False, drop_rate: float = 0.0, - # attention_dropout: float = 0.0, drop_path_rate: float = 0.0, -) -> ViT: - """construct and return a ViT network""" +): config = ConfigDict() - config.image_size = image_size - config.num_classes = num_classes + config.image_size = 384 config.patch_size = 32 + config.in_channels = in_channels config.embed_dim = 768 - config.mlp_dim = 3072 + config.depth = 12 config.num_heads = 12 - config.num_layers = 12 - config.dropout = drop_rate - config.attention_dropout = drop_rate # attention_dropout + config.pos_drop_rate = drop_rate + config.proj_drop_rate = drop_rate + config.attn_drop_rate = drop_rate config.drop_path_rate = drop_path_rate - config.pretrained = pretrained - config.input_channels = in_channels - config.pool = "cls" config.representation_size = 768 if has_logits else None + config.num_classes = num_classes + config.pretrained = pretrained config.url_cfg = default_cfgs["vit_b_32_384"] return vit(**config) @@ -896,29 +767,25 @@ def vit_l_32_224( pretrained: bool = False, num_classes: int = 1000, in_channels: int = 3, - image_size: int = 224, has_logits: bool = False, drop_rate: float = 0.0, - # attention-dropout: float = 0.0, drop_path_rate: float = 0.0, -) -> ViT: - """construct and return a ViT network""" +): config = ConfigDict() - config.image_size = image_size - config.num_classes = num_classes + config.image_size = 224 config.patch_size = 32 + config.in_channels = in_channels config.embed_dim = 1024 - config.mlp_dim = 4096 + config.depth = 24 config.num_heads = 16 - config.num_layers = 24 - config.dropout = drop_rate - config.attention_dropout = drop_rate # attention-dropout + config.pos_drop_rate = drop_rate + config.proj_drop_rate = drop_rate + config.attn_drop_rate = drop_rate config.drop_path_rate = drop_path_rate - config.pretrained = pretrained - config.input_channels = in_channels - config.pool = "cls" config.representation_size = 1024 if has_logits else None + config.num_classes = num_classes + config.pretrained = pretrained config.url_cfg = default_cfgs["vit_l_32_224"] return vit(**config)
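

For reference, a minimal usage sketch of the refactored entry points, assuming MindSpore is installed and the module path mindcv.models.vit matches this diff. It only illustrates that the classification head now lives on ViT itself ("head.classifier", matching the updated "classifier" key in _cfg) instead of a separate BaseClassifier wrapper, and that the factory configs now use depth/mlp_ratio rather than num_layers/mlp_dim.

import numpy as np
import mindspore as ms
from mindcv.models.vit import vit_b_16_224  # module path as introduced in this diff

# Build the refactored model; the head is attached directly to ViT.
net = vit_b_16_224(pretrained=False, num_classes=1000)

# Forward pass on a dummy ImageNet-sized batch.
x = ms.Tensor(np.ones([1, 3, 224, 224]), ms.float32)
logits = net(x)
print(logits.shape)  # expected: (1, 1000)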
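The new pos_embed.py precomputes the relative position index once, stores it as a one-hot matrix, and turns the bias lookup into a single matmul against the learnable bias table. A small sketch of the expected output shape follows, assuming a backend that supports the float16 matmul used here; window_size (14, 14) corresponds to a 224x224 input with 16x16 patches (196 tokens plus the class token).

from mindcv.models.layers.pos_embed import RelativePositionBiasWithCLS

rel_pos_bias = RelativePositionBiasWithCLS(window_size=(14, 14), num_heads=12)

# one-hot index [197*197, L] @ table [L, num_heads], then reshape/transpose
bias = rel_pos_bias()
print(bias.shape)  # expected: (1, 12, 197, 197)

# In VisionTransformerEncoder.forward_features this tensor is passed to each
# TransformerBlock and added to the q*k attention logits before the softmax.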