diff --git a/configs/fastvit/README.md b/configs/fastvit/README.md new file mode 100644 index 000000000..ac7a4b123 --- /dev/null +++ b/configs/fastvit/README.md @@ -0,0 +1,91 @@ +# FastViT + + +> [A Fast Hybrid Vision Transformer using Structural Reparameterization](https://arxiv.org/abs/2303.14189) + +## Introduction + + +The recent amalgamation of transformer and convolutional designs has led to steady improvements in accuracy and efficiency of the models. In this work, we introduce FastViT, a hybrid vision transformer architecture that obtains the state-of-the-art latency-accuracy trade-off. To this end, we introduce a novel token mixing operator, RepMixer, a building block of FastViT, that uses structural reparameterization to lower the memory access cost by removing skip-connections in the network. We further apply train-time overparametrization and large kernel convolutions to boost accuracy and empirically show that these choices have minimal effect on latency. We show that - our model is 3.5x faster than CMT, a recent state-of-the-art hybrid transformer architecture, 4.9x faster than EfficientNet, and 1.9x faster than ConvNeXt on a mobile device for the same accuracy on the ImageNet dataset. At similar latency, our model obtains 4.2% better Top-1 accuracy on ImageNet than MobileOne. Our model consistently outperforms competing architectures across several tasks -- image classification, detection, segmentation and 3D mesh regression with significant improvement in latency on both a mobile device and a desktop GPU. Furthermore, our model is highly robust to out-of-distribution samples and corruptions, improving over competing robust models. + + + +## Results + + +Our reproduced model performance on ImageNet-1K is reported as follows. + +
+ +| Model | Context | Top-1 (%) | Top-5 (%) | Params (M) | Recipe | Download | +|-----------|----------|-----------|-----------|------------|-----------------------------------------------------------------------------------------------|-----------------------------------------------------------------------------------| +| FastViT-T8 | D910x8-G | 74.25 | 91.97 | 48 | [yaml](https://github.com/mindspore-lab/mindcv/blob/main/configs/fastvit/fastvit_t8_ascend.yaml) | + +
+ +#### Notes +- Context: Training context denoted as {device}x{pieces}-{MS mode}, where mindspore mode can be G - graph mode or F - pynative mode with ms function. For example, D910x8-G is for training on 8 pieces of Ascend 910 NPU using graph mode. +- Top-1 and Top-5: Accuracy reported on the validation set of ImageNet-1K. + + +## Quick Start +### Preparation + +#### Installation +Please refer to the [installation instruction](https://github.com/mindspore-lab/mindcv#installation) in MindCV. + +#### Dataset Preparation +Please download the [ImageNet-1K](https://www.image-net.org/challenges/LSVRC/2012/index.php) dataset for model training and validation. + +### Training + + +* Distributed Training + +It is easy to reproduce the reported results with the pre-defined training recipe. For distributed training on multiple Ascend 910 devices, please run + +```shell +# distributed training on multiple GPU/Ascend devices +mpirun -n 8 python train.py --config configs/fastvit/fastvit_t8_ascend.yaml --data_dir /path/to/imagenet +``` +> If the script is executed by the root user, the `--allow-run-as-root` parameter must be added to `mpirun`. + +Similarly, you can train the model on multiple GPU devices with the above `mpirun` command. + +For detailed illustration of all hyper-parameters, please refer to [config.py](https://github.com/mindspore-lab/mindcv/blob/main/config.py). + +**Note:** As the global batch size (batch_size x num_devices) is an important hyper-parameter, it is recommended to keep the global batch size unchanged for reproduction or adjust the learning rate linearly to a new global batch size. + +* Standalone Training + +If you want to train or finetune the model on a smaller dataset without distributed training, please run: + +```shell +# standalone training on a CPU/GPU/Ascend device +python train.py --config configs/fastvit/fastvit_t8_ascend.yaml --data_dir /path/to/dataset --distribute False +``` + +### Validation + +To validate the accuracy of the trained model, you can use `validate.py` and parse the checkpoint path with `--ckpt_path`. + +``` +python validate.py -c configs/fastvit/fastvit_t8_ascend.yaml --data_dir /path/to/imagenet --ckpt_path /path/to/ckpt +``` + +### Deployment + +To deploy online inference services with the trained model efficiently, please refer to the [deployment tutorial](https://mindspore-lab.github.io/mindcv/tutorials/deployment/). + +## References + + +[1] Vasu P K A, Gabriel J, Zhu J, et al. FastViT: A Fast Hybrid Vision Transformer using Structural Reparameterization[J]. arXiv preprint arXiv:2303.14189, 2023. 
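+
+For a quick functional check of a trained checkpoint outside the provided scripts, the model can also be built directly through the MindCV model registry. The snippet below is a minimal sketch: the checkpoint path is a placeholder, and it assumes the standard `create_model` factory picks up the `fastvit_t8` entry registered in `mindcv/models/fastvit.py`.
+
+```python
+import mindspore as ms
+from mindspore import ops
+
+from mindcv.models import create_model
+
+# build FastViT-T8 and load a trained checkpoint (placeholder path)
+network = create_model("fastvit_t8", num_classes=1000, checkpoint_path="/path/to/ckpt")
+network.set_train(False)
+
+# run a dummy batch at the training resolution used in fastvit_t8_ascend.yaml
+dummy = ops.ones((1, 3, 224, 224), ms.float32)
+logits = network(dummy)
+print(logits.shape)  # (1, 1000)
+```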
diff --git a/configs/fastvit/fastvit_t8_ascend.yaml b/configs/fastvit/fastvit_t8_ascend.yaml new file mode 100644 index 000000000..ad0abe128 --- /dev/null +++ b/configs/fastvit/fastvit_t8_ascend.yaml @@ -0,0 +1,60 @@ +# system +mode: 0 +distribute: False +num_parallel_workers: 8 +val_while_train: True +val_interval: 1 +log_interval: 100 + +# dataset +dataset: "imagenet" +data_dir: "/path/to/imagenet" +shuffle: True +dataset_download: False +batch_size: 128 + +# augmentation +image_resize: 224 +scale: [0.08, 1.0] +ratio: [0.75, 1.333] +hflip: 0.5 +vflip: 0.0 +interpolation: "bicubic" +re_prob: 0.1 +mixup: 0.8 +cutmix: 1.0 +color_jitter: 0.4 +auto_augment: "randaug-m7-mstd0.5" + +# model +model: "fastvit_t8" +num_classes: 1000 +pretrained: False +keep_checkpoint_max: 10 +ckpt_save_policy: "latest_k" +ckpt_save_interval: 1 +ckpt_save_dir: "./ckpt" +epoch_size: 300 +dataset_sink_mode: True +ema_decay: 0.9995 +amp_level: "O2" +loss_scale_type: 'auto' + +# loss +loss: "CE" +label_smoothing: 0.1 + +# lr scheduler +scheduler: "cosine_decay" +lr: 0.001 +min_lr: 0.0 +warmup_epochs: 5 +warmup_factor: 0.01 +decay_epochs: 295 + +# optimizer +opt: "adamw" +momentum: 0.9 +weight_decay: 0.05 +filter_bias_and_bn: True +use_nesterov: False diff --git a/mindcv/models/fastvit.py b/mindcv/models/fastvit.py new file mode 100644 index 000000000..9994f2c30 --- /dev/null +++ b/mindcv/models/fastvit.py @@ -0,0 +1,1651 @@ +"""Reference:https://github.com/apple/ml-fastvit""" +import copy +import math +import os +from collections import OrderedDict +from functools import partial +from typing import List, Optional, Tuple, Union + +import mindspore as ms +import mindspore.common.initializer as init +from mindspore import nn, ops +from mindspore.numpy import ones + +from mindcv.models.layers.pooling import GlobalAvgPooling +from mindcv.models.registry import register_model + +IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225) + + +def _cfg(url="", **kwargs): + return { + "url": url, + "num_classes": 1000, + "input_size": (3, 256, 256), + "pool_size": None, + "crop_pct": 0.95, + "interpolation": "bicubic", + "mean": IMAGENET_DEFAULT_MEAN, + "std": IMAGENET_DEFAULT_STD, + "classifier": "head", + **kwargs, + } + + +default_cfgs = { + "fastvit_t": _cfg(crop_pct=0.9), + "fastvit_s": _cfg(crop_pct=0.9), + "fastvit_m": _cfg(crop_pct=0.95), +} + + +def convolutional_stem( + in_channels: int, out_channels: int, inference_mode: bool = False +) -> nn.SequentialCell: + """Build convolutional stem with MobileOne blocks. + + Args: + in_channels: Number of input channels. + out_channels: Number of output channels. + inference_mode: Flag to instantiate model in inference mode. Default: ``False`` + + Returns: + nn.Sequential object with stem elements. + """ + return nn.SequentialCell( + MobileOneBlock( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + stride=2, + padding=1, + group=1, + inference_mode=inference_mode, + use_se=False, + num_conv_branches=1, + ), + MobileOneBlock( + in_channels=out_channels, + out_channels=out_channels, + kernel_size=3, + stride=2, + padding=1, + group=out_channels, + inference_mode=inference_mode, + use_se=False, + num_conv_branches=1, + ), + MobileOneBlock( + in_channels=out_channels, + out_channels=out_channels, + kernel_size=1, + stride=1, + padding=0, + group=1, + inference_mode=inference_mode, + use_se=False, + num_conv_branches=1, + ), + ) + + +class MHSA(nn.Cell): + """Multi-headed Self Attention module. 
+ + Source modified from: + https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py + """ + + def __init__( + self, + dim: int, + head_dim: int = 32, + qkv_bias: bool = False, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + ) -> None: + """Build MHSA module that can handle 3D or 4D input tensors. + + Args: + dim: Number of embedding dimensions. + head_dim: Number of hidden dimensions per head. Default: ``32`` + qkv_bias: Use bias or not. Default: ``False`` + attn_drop: Dropout rate for attention tensor. + proj_drop: Dropout rate for projection tensor. + """ + super(MHSA, self).__init__() + assert dim % head_dim == 0, "dim should be divisible by head_dim" + self.head_dim = head_dim + self.num_heads = dim // head_dim + self.scale = head_dim**-0.5 + + self.qkv = nn.Dense(dim, dim * 3, has_bias=qkv_bias) + self.attn_drop = nn.Dropout(p=attn_drop) + self.proj = nn.Dense(dim, dim) + self.proj_drop = nn.Dropout(p=proj_drop) + self.batch_matmul = ops.BatchMatMul() + + def construct(self, x: ms.Tensor) -> ms.Tensor: + shape = x.shape + B, C, H, W = shape + N = H * W + if len(shape) == 4: + x = nn.flatten(x, start_dim=2).transpose((0, -1, -2)) # (B, N, C) + qkv = ( + self.qkv(x) + .reshape((B, N, 3, self.num_heads, self.head_dim)) + .permute(2, 0, 3, 1, 4) + ) + q, k, v = ops.Unstack(axis=0)(qkv) + + # trick here to make q@k.t more stable + attn = self.batch_matmul(q*self.scale, k.transpose(0, 1, -1, -2)) + attn = nn.Softmax(axis=-1)(attn) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose((0, 2, 1, -1)).reshape((B, N, C)) + x = self.proj(x) + x = self.proj_drop(x) + if len(shape) == 4: + x = x.transpose((0, -1, -2)).reshape(B, C, H, W) + + return x + + +class PatchEmbed(nn.Cell): + """Convolutional patch embedding layer.""" + + def __init__( + self, + patch_size: int, + stride: int, + in_channels: int, + embed_dim: int, + inference_mode: bool = False, + ) -> None: + """Build patch embedding layer. + + Args: + patch_size: Patch size for embedding computation. + stride: Stride for convolutional embedding layer. + in_channels: Number of channels of input tensor. + embed_dim: Number of embedding dimensions. + inference_mode: Flag to instantiate model in inference mode. Default: ``False`` + """ + super().__init__() + self.layers = nn.CellList() + self.layers.append( + ReparamLargeKernelConv( + in_channels=in_channels, + out_channels=embed_dim, + kernel_size=patch_size, + stride=stride, + group=in_channels, + small_kernel=3, + inference_mode=inference_mode, + ) + ) + self.layers.append( + MobileOneBlock( + in_channels=embed_dim, + out_channels=embed_dim, + kernel_size=1, + stride=1, + padding=0, + group=1, + inference_mode=inference_mode, + use_se=False, + num_conv_branches=1, + ) + ) + + def construct(self, x: ms.Tensor) -> ms.Tensor: + for layer in self.layers: + x = layer(x) + return x + + +class RepMixer(nn.Cell): + """Reparameterizable token mixer. + + For more details, please refer to our paper: + FastViT: A Fast Hybrid Vision Transformer using Structural Reparameterization + """ + + def __init__( + self, + dim, + kernel_size=3, + use_layer_scale=True, + layer_scale_init_value=1e-5, + inference_mode: bool = False, + ): + """Build RepMixer Module. + + Args: + dim: Input feature map dimension. :math:`C_{in}` from an expected input of size :math:`(B, C_{in}, H, W)`. + kernel_size: Kernel size for spatial mixing. Default: 3 + use_layer_scale: If True, learnable layer scale is used. 
Default: ``True`` + layer_scale_init_value: Initial value for layer scale. Default: 1e-5 + inference_mode: If True, instantiates model in inference mode. Default: ``False`` + """ + super().__init__() + self.dim = dim + self.kernel_size = kernel_size + self.inference_mode = inference_mode + self.reparam_conv = None + + if inference_mode: + self.reparam_conv = nn.Conv2d( + in_channels=self.dim, + out_channels=self.dim, + kernel_size=self.kernel_size, + stride=1, + pad_mode='pad', + padding=self.kernel_size // 2, + group=self.dim, + has_bias=True, + ) + else: + self.norm = MobileOneBlock( + dim, + dim, + kernel_size, + padding=kernel_size // 2, + group=dim, + use_act=False, + use_scale_branch=False, + num_conv_branches=0, + ) + self.mixer = MobileOneBlock( + dim, + dim, + kernel_size, + padding=kernel_size // 2, + group=dim, + use_act=False, + ) + self.use_layer_scale = use_layer_scale + if use_layer_scale: + self.layer_scale = ms.Parameter( + layer_scale_init_value * ops.ones((dim, 1, 1), ms.float32), name='w', requires_grad=True + ) + + def construct(self, x: ms.Tensor) -> ms.Tensor: + if self.reparam_conv is not None: + x = self.reparam_conv(x) + return x + else: + if self.use_layer_scale: + x = x + self.layer_scale * (self.mixer(x) - self.norm(x)) + else: + x = x + self.mixer(x) - self.norm(x) + return x + + def reparameterize(self) -> None: + """Reparameterize mixer and norm into a single + convolutional layer for efficient inference. + """ + if self.inference_mode: + return + + self.mixer.reparameterize() + self.norm.reparameterize() + + if self.use_layer_scale: + w = self.mixer.id_tensor + ops.ExpandDims()(self.layer_scale, -1) * ( + self.mixer.reparam_conv.weight - self.norm.reparam_conv.weight + ) + b = ops.Squeeze()(self.layer_scale) * ( + self.mixer.reparam_conv.bias - self.norm.reparam_conv.bias + ) + else: + w = ( + self.mixer.id_tensor + + self.mixer.reparam_conv.weight + - self.norm.reparam_conv.weight + ) + b = self.mixer.reparam_conv.bias - self.norm.reparam_conv.bias + + self.reparam_conv = nn.Conv2d( + in_channels=self.dim, + out_channels=self.dim, + kernel_size=self.kernel_size, + stride=1, + pad_mode='pad', + padding=self.kernel_size // 2, + group=self.dim, + has_bias=True, + ) + self.reparam_conv.weight = w + self.reparam_conv.bias = b + + for para in self.get_parameters(): + para = ops.stop_gradient(para) + self.__delattr__("mixer") + self.__delattr__("norm") + if self.use_layer_scale: + self.__delattr__("layer_scale") + + +class ConvFFN(nn.Cell): + """Convolutional FFN Module.""" + + def __init__( + self, + in_channels: int, + hidden_channels: Optional[int] = None, + out_channels: Optional[int] = None, + act_layer: nn.Cell = nn.GELU, + drop: float = 0.0, + ) -> None: + """Build convolutional FFN module. + + Args: + in_channels: Number of input channels. + hidden_channels: Number of channels after expansion. Default: None + out_channels: Number of output channels. Default: None + act_layer: Activation layer. Default: ``GELU`` + drop: Dropout rate. Default: ``0.0``. 
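+
+        Note:
+            The block applies a 7x7 depthwise convolution followed by
+            BatchNorm before the two pointwise (1x1) convolutions.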
+ """ + super().__init__() + out_channels = out_channels or in_channels + hidden_channels = hidden_channels or in_channels + self.conv = nn.SequentialCell( + OrderedDict( + [("conv", nn.Conv2d(in_channels=in_channels, + out_channels=out_channels, + kernel_size=7, + pad_mode='pad', + padding=3, + group=in_channels, + has_bias=False,)), + ("bn", nn.BatchNorm2d(num_features=out_channels))])) + self.fc1 = nn.Conv2d(in_channels, hidden_channels, kernel_size=1) + self.act = act_layer() + self.fc2 = nn.Conv2d(hidden_channels, out_channels, kernel_size=1) + self.drop = nn.Dropout(p=drop) + self._init_weights() + + def _init_weights(self) -> None: + for _, cell in self.cells_and_names(): + if isinstance(cell, nn.Conv2d): + cell.weight.set_data(init.initializer(init.TruncatedNormal(sigma=0.02), + cell.weight.shape, + cell.weight.dtype)) + if cell.bias is not None: + cell.bias.set_data(init.initializer(init.Zero(), cell.bias.shape, cell.bias.dtype)) + + def construct(self, x: ms.Tensor) -> ms.Tensor: + x = self.conv(x) + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class RepCPE(nn.Cell): + """Implementation of conditional positional encoding. + + For more details refer to paper: + `Conditional Positional Encodings for Vision Transformers `_ + + In our implementation, we can reparameterize this module to eliminate a skip connection. + """ + + def __init__( + self, + in_channels: int, + embed_dim: int = 768, + spatial_shape: Union[int, Tuple[int, int]] = (7, 7), + inference_mode=False, + ) -> None: + """Build reparameterizable conditional positional encoding + + Args: + in_channels: Number of input channels. + embed_dim: Number of embedding dimensions. Default: 768 + spatial_shape: Spatial shape of kernel for positional encoding. Default: (7, 7) + inference_mode: Flag to instantiate block in inference mode. Default: ``False`` + """ + super(RepCPE, self).__init__() + if isinstance(spatial_shape, int): + spatial_shape = tuple([spatial_shape] * 2) + assert isinstance(spatial_shape, Tuple), ( + f'"spatial_shape" must by a sequence or int, ' + f"get {type(spatial_shape)} instead." + ) + assert len(spatial_shape) == 2, ( + f'Length of "spatial_shape" should be 2, ' + f"got {len(spatial_shape)} instead." 
+ ) + + self.spatial_shape = spatial_shape + self.embed_dim = embed_dim + self.in_channels = in_channels + self.group = embed_dim + self.reparam_conv = None + if inference_mode: + self.reparam_conv = nn.Conv2d( + in_channels=self.in_channels, + out_channels=self.embed_dim, + kernel_size=self.spatial_shape, + stride=1, + pad_mode='pad', + padding=int(self.spatial_shape[0] // 2), + group=self.embed_dim, + has_bias=True, + ) + else: + self.pe = nn.Conv2d( + in_channels, + embed_dim, + spatial_shape, + 1, + 'pad', + int(spatial_shape[0] // 2), + has_bias=True, + group=embed_dim, + ) + + def construct(self, x: ms.Tensor) -> ms.Tensor: + if self.reparam_conv is not None: + x = self.reparam_conv(x) + return x + else: + x = self.pe(x) + x + return x + + def reparameterize(self) -> None: + # Build equivalent Id tensor + input_dim = self.in_channels // self.group + kernel_value = ops.Zeros()( + ( + self.in_channels, + input_dim, + self.spatial_shape[0], + self.spatial_shape[1], + ), ms.float32 + ) + for i in range(self.in_channels): + kernel_value[ + i, + i % input_dim, + self.spatial_shape[0] // 2, + self.spatial_shape[1] // 2, + ] = 1 + id_tensor = kernel_value + + # Reparameterize Id tensor and conv + w_final = id_tensor + self.pe.weight + b_final = self.pe.bias + + # Introduce reparam conv + self.reparam_conv = nn.Conv2d( + in_channels=self.in_channels, + out_channels=self.embed_dim, + kernel_size=self.spatial_shape, + stride=1, + pad_mode='pad', + padding=int(self.spatial_shape[0] // 2), + group=self.embed_dim, + has_bias=True, + ) + self.reparam_conv.weight = w_final + self.reparam_conv.bias = b_final + + for para in self.get_parameters(): + para = ops.stop_gradient(para) + self.__delattr__("pe") + + +class RepMixerBlock(nn.Cell): + """Implementation of Metaformer block with RepMixer as token mixer. + + For more details on Metaformer structure, please refer to: + `MetaFormer Is Actually What You Need for Vision `_ + """ + + def __init__( + self, + dim: int, + kernel_size: int = 3, + mlp_ratio: float = 4.0, + act_layer: nn.Cell = nn.GELU, + drop: float = 0.0, + drop_path: float = 0.0, + use_layer_scale: bool = True, + layer_scale_init_value: float = 1e-5, + inference_mode: bool = False, + ): + """Build RepMixer Block. + + Args: + dim: Number of embedding dimensions. + kernel_size: Kernel size for repmixer. Default: 3 + mlp_ratio: MLP expansion ratio. Default: 4.0 + act_layer: Activation layer. Default: ``nn.GELU`` + drop: Dropout rate. Default: 0.0 + drop_path: Drop path rate. Default: 0.0 + use_layer_scale: Flag to turn on layer scale. Default: ``True`` + layer_scale_init_value: Layer scale value at initialization. Default: 1e-5 + inference_mode: Flag to instantiate block in inference mode. 
Default: ``False`` + """ + + super().__init__() + + self.token_mixer = RepMixer( + dim, + kernel_size=kernel_size, + use_layer_scale=use_layer_scale, + layer_scale_init_value=layer_scale_init_value, + inference_mode=inference_mode, + ) + + assert mlp_ratio > 0, "MLP ratio should be greater than 0, found: {}".format( + mlp_ratio + ) + mlp_hidden_dim = int(dim * mlp_ratio) + self.convffn = ConvFFN( + in_channels=dim, + hidden_channels=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + ) + + # Drop Path + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + # Layer Scale + self.use_layer_scale = use_layer_scale + if use_layer_scale: + self.layer_scale = ms.Parameter( + layer_scale_init_value * ops.ones((dim, 1, 1), ms.float32), requires_grad=True + ) + + def construct(self, x): + if self.use_layer_scale: + x = self.token_mixer(x) + x = x + self.drop_path(self.layer_scale * self.convffn(x)) + else: + x = self.token_mixer(x) + x = x + self.drop_path(self.convffn(x)) + return x + + +class AttentionBlock(nn.Cell): + """Implementation of metaformer block with MHSA as token mixer. + + For more details on Metaformer structure, please refer to: + `MetaFormer Is Actually What You Need for Vision `_ + """ + + def __init__( + self, + dim: int, + mlp_ratio: float = 4.0, + act_layer: nn.Cell = nn.GELU, + norm_layer: nn.Cell = nn.BatchNorm2d, + drop: float = 0.0, + drop_path: float = 0.0, + use_layer_scale: bool = True, + layer_scale_init_value: float = 1e-5, + ): + """Build Attention Block. + + Args: + dim: Number of embedding dimensions. + mlp_ratio: MLP expansion ratio. Default: 4.0 + act_layer: Activation layer. Default: ``nn.GELU`` + norm_layer: Normalization layer. Default: ``nn.BatchNorm2d`` + drop: Dropout rate. Default: 0.0 + drop_path: Drop path rate. Default: 0.0 + use_layer_scale: Flag to turn on layer scale. Default: ``True`` + layer_scale_init_value: Layer scale value at initialization. Default: 1e-5 + """ + + super().__init__() + + self.norm = norm_layer(dim) + self.token_mixer = MHSA(dim=dim) + + assert mlp_ratio > 0, "MLP ratio should be greater than 0, found: {}".format( + mlp_ratio + ) + mlp_hidden_dim = int(dim * mlp_ratio) + self.convffn = ConvFFN( + in_channels=dim, + hidden_channels=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + ) + + # Drop path + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + # Layer Scale + self.use_layer_scale = use_layer_scale + if use_layer_scale: + self.layer_scale_1 = ms.Parameter( + layer_scale_init_value * ops.ones((dim, 1, 1), ms.float32), requires_grad=True + ) + self.layer_scale_2 = ms.Parameter( + layer_scale_init_value * ops.ones((dim, 1, 1), ms.float32), requires_grad=True + ) + + def construct(self, x): + if self.use_layer_scale: + x = x + self.drop_path(self.layer_scale_1 * self.token_mixer(self.norm(x))) + x = x + self.drop_path(self.layer_scale_2 * self.convffn(x)) + else: + x = x + self.drop_path(self.token_mixer(self.norm(x))) + x = x + self.drop_path(self.convffn(x)) + return x + + +def basic_blocks( + dim: int, + block_index: int, + num_blocks: List[int], + token_mixer_type: str, + kernel_size: int = 3, + mlp_ratio: float = 4.0, + act_layer: nn.Cell = nn.GELU, + norm_layer: nn.Cell = nn.BatchNorm2d, + drop_rate: float = 0.0, + drop_path_rate: float = 0.0, + use_layer_scale: bool = True, + layer_scale_init_value: float = 1e-5, + inference_mode=False, +) -> nn.SequentialCell: + """Build FastViT blocks within a stage. + + Args: + dim: Number of embedding dimensions. 
+ block_index: block index. + num_blocks: List containing number of blocks per stage. + token_mixer_type: Token mixer type. + kernel_size: Kernel size for repmixer. + mlp_ratio: MLP expansion ratio. + act_layer: Activation layer. + norm_layer: Normalization layer. + drop_rate: Dropout rate. + drop_path_rate: Drop path rate. + use_layer_scale: Flag to turn on layer scale regularization. + layer_scale_init_value: Layer scale value at initialization. + inference_mode: Flag to instantiate block in inference mode. + + Returns: + nn.Sequential object of all the blocks within the stage. + """ + blocks = [] + for block_idx in range(num_blocks[block_index]): + block_dpr = ( + drop_path_rate + * (block_idx + sum(num_blocks[:block_index])) + / (sum(num_blocks) - 1) + ) + if token_mixer_type == "repmixer": + blocks.append( + RepMixerBlock( + dim, + kernel_size=kernel_size, + mlp_ratio=mlp_ratio, + act_layer=act_layer, + drop=drop_rate, + drop_path=block_dpr, + use_layer_scale=use_layer_scale, + layer_scale_init_value=layer_scale_init_value, + inference_mode=inference_mode, + ) + ) + elif token_mixer_type == "attention": + blocks.append( + AttentionBlock( + dim, + mlp_ratio=mlp_ratio, + act_layer=act_layer, + norm_layer=norm_layer, + drop=drop_rate, + drop_path=block_dpr, + use_layer_scale=use_layer_scale, + layer_scale_init_value=layer_scale_init_value, + ) + ) + else: + raise ValueError( + "Token mixer type: {} not supported".format(token_mixer_type) + ) + blocks = nn.SequentialCell(*blocks) + + return blocks + + +class FastViT(nn.Cell): + """ + This class implements `FastViT architecture `_ + """ + + def __init__( + self, + layers, + token_mixers: Tuple[str, ...], + embed_dims=None, + mlp_ratios=None, + downsamples=None, + repmixer_kernel_size=3, + norm_layer: nn.Cell = nn.BatchNorm2d, + act_layer: nn.Cell = nn.GELU, + num_classes=1000, + pos_embs=None, + down_patch_size=7, + down_stride=2, + drop_rate=0.0, + drop_path_rate=0.0, + use_layer_scale=True, + layer_scale_init_value=1e-5, + fork_feat=False, + init_cfg=None, + pretrained=None, + cls_ratio=2.0, + inference_mode=False, + **kwargs, + ) -> None: + + super().__init__() + + if not fork_feat: + self.num_classes = num_classes + self.fork_feat = fork_feat + + if pos_embs is None: + pos_embs = [None] * len(layers) + + # Convolutional stem + self.patch_embed = convolutional_stem(3, embed_dims[0], inference_mode) + + # Build the main stages of the network architecture + self.network = nn.CellList() + for i in range(len(layers)): + # Add position embeddings if requested + if pos_embs[i] is not None: + self.network.append( + pos_embs[i]( + embed_dims[i], embed_dims[i], inference_mode=inference_mode + ) + ) + stage = basic_blocks( + embed_dims[i], + i, + layers, + token_mixer_type=token_mixers[i], + kernel_size=repmixer_kernel_size, + mlp_ratio=mlp_ratios[i], + act_layer=act_layer, + norm_layer=norm_layer, + drop_rate=drop_rate, + drop_path_rate=drop_path_rate, + use_layer_scale=use_layer_scale, + layer_scale_init_value=layer_scale_init_value, + inference_mode=inference_mode, + ) + self.network.append(stage) + if i >= len(layers) - 1: + break + + # Patch merging/downsampling between stages. 
+ if downsamples[i] or embed_dims[i] != embed_dims[i + 1]: + self.network.append( + PatchEmbed( + patch_size=down_patch_size, + stride=down_stride, + in_channels=embed_dims[i], + embed_dim=embed_dims[i + 1], + inference_mode=inference_mode, + ) + ) + # For segmentation and detection, extract intermediate output + if self.fork_feat: + # add a norm layer for each output + self.out_indices = [0, 2, 4, 6] + for i_emb, i_layer in enumerate(self.out_indices): + if i_emb == 0 and os.environ.get("FORK_LAST3", None): + """For RetinaNet, `start_level=1`. The first norm layer will not used. + cmd: `FORK_LAST3=1 python -m torch.distributed.launch ...` + """ + layer = nn.Identity() + else: + layer = norm_layer(embed_dims[i_emb]) + layer_name = f"norm{i_layer}" + self.insert_child_to_cell(layer_name, layer) + else: + # Classifier head + self.gap = GlobalAvgPooling() + self.conv_exp = MobileOneBlock( + in_channels=embed_dims[-1], + out_channels=int(embed_dims[-1] * cls_ratio), + kernel_size=3, + stride=1, + padding=1, + group=embed_dims[-1], + inference_mode=inference_mode, + use_se=True, + num_conv_branches=1, + ) + self.head = ( + nn.Dense(int(embed_dims[-1] * cls_ratio), num_classes) + if num_classes > 0 + else nn.Identity() + ) + + self.cls_init_weights() + self.init_cfg = copy.deepcopy(init_cfg) + + def cls_init_weights(self) -> None: + """Init. for classification""" + for _, cell in self.cells_and_names(): + if isinstance(cell, nn.Dense): + cell.weight.set_data(init.initializer(init.TruncatedNormal(sigma=0.02), + cell.weight.shape, + cell.weight.dtype)) + if isinstance(cell, nn.Dense) and cell.bias is not None: + cell.bias.set_data(init.initializer(init.Zero(), cell.bias.shape, cell.bias.dtype)) + + def forward_embeddings(self, x: ms.Tensor) -> ms.Tensor: + x = self.patch_embed(x) + return x + + def forward_tokens(self, x: ms.Tensor) -> ms.Tensor: + outs = [] + for idx, block in enumerate(self.network): + x = block(x) + if self.fork_feat and idx in self.out_indices: + norm_layer = getattr(self, f"norm{idx}") + x_out = norm_layer(x) + outs.append(x_out) + if self.fork_feat: + # output the features of four stages for dense prediction + return outs + # output only the features of last layer for image classification + return x + + def construct(self, x: ms.Tensor) -> ms.Tensor: + # input embedding + x = self.forward_embeddings(x) + # through backbone + x = self.forward_tokens(x) + if self.fork_feat: + # output features of four stages for dense prediction + return x + # for image classification + x = self.conv_exp(x) + x = self.gap(x) + x = x.view((x.shape[0], -1)) + cls_out = self.head(x) + return cls_out + + +@register_model +def fastvit_t8(pretrained=False, **kwargs): + """Instantiate FastViT-T8 model variant.""" + layers = [2, 2, 4, 2] + embed_dims = [48, 96, 192, 384] + mlp_ratios = [3, 3, 3, 3] + downsamples = [True, True, True, True] + token_mixers = ("repmixer", "repmixer", "repmixer", "repmixer") + model = FastViT( + layers, + token_mixers=token_mixers, + embed_dims=embed_dims, + mlp_ratios=mlp_ratios, + downsamples=downsamples, + **kwargs, + ) + model.default_cfg = default_cfgs["fastvit_t"] + if pretrained: + raise ValueError("Functionality not implemented.") + return model + + +@register_model +def fastvit_t12(pretrained=False, **kwargs): + """Instantiate FastViT-T12 model variant.""" + layers = [2, 2, 6, 2] + embed_dims = [64, 128, 256, 512] + mlp_ratios = [3, 3, 3, 3] + downsamples = [True, True, True, True] + token_mixers = ("repmixer", "repmixer", "repmixer", "repmixer") + model = 
FastViT( + layers, + token_mixers=token_mixers, + embed_dims=embed_dims, + mlp_ratios=mlp_ratios, + downsamples=downsamples, + **kwargs, + ) + model.default_cfg = default_cfgs["fastvit_t"] + if pretrained: + raise ValueError("Functionality not implemented.") + return model + + +@register_model +def fastvit_s12(pretrained=False, **kwargs): + """Instantiate FastViT-S12 model variant.""" + layers = [2, 2, 6, 2] + embed_dims = [64, 128, 256, 512] + mlp_ratios = [4, 4, 4, 4] + downsamples = [True, True, True, True] + token_mixers = ("repmixer", "repmixer", "repmixer", "repmixer") + model = FastViT( + layers, + token_mixers=token_mixers, + embed_dims=embed_dims, + mlp_ratios=mlp_ratios, + downsamples=downsamples, + **kwargs, + ) + model.default_cfg = default_cfgs["fastvit_s"] + if pretrained: + raise ValueError("Functionality not implemented.") + return model + + +@register_model +def fastvit_sa12(pretrained=False, **kwargs): + """Instantiate FastViT-SA12 model variant.""" + layers = [2, 2, 6, 2] + embed_dims = [64, 128, 256, 512] + mlp_ratios = [4, 4, 4, 4] + downsamples = [True, True, True, True] + pos_embs = [None, None, None, partial(RepCPE, spatial_shape=(7, 7))] + token_mixers = ("repmixer", "repmixer", "repmixer", "attention") + model = FastViT( + layers, + token_mixers=token_mixers, + embed_dims=embed_dims, + pos_embs=pos_embs, + mlp_ratios=mlp_ratios, + downsamples=downsamples, + **kwargs, + ) + model.default_cfg = default_cfgs["fastvit_s"] + if pretrained: + raise ValueError("Functionality not implemented.") + return model + + +@register_model +def fastvit_sa24(pretrained=False, **kwargs): + """Instantiate FastViT-SA24 model variant.""" + layers = [4, 4, 12, 4] + embed_dims = [64, 128, 256, 512] + mlp_ratios = [4, 4, 4, 4] + downsamples = [True, True, True, True] + pos_embs = [None, None, None, partial(RepCPE, spatial_shape=(7, 7))] + token_mixers = ("repmixer", "repmixer", "repmixer", "attention") + model = FastViT( + layers, + token_mixers=token_mixers, + embed_dims=embed_dims, + pos_embs=pos_embs, + mlp_ratios=mlp_ratios, + downsamples=downsamples, + **kwargs, + ) + model.default_cfg = default_cfgs["fastvit_s"] + if pretrained: + raise ValueError("Functionality not implemented.") + return model + + +@register_model +def fastvit_sa36(pretrained=False, **kwargs): + """Instantiate FastViT-SA36 model variant.""" + layers = [6, 6, 18, 6] + embed_dims = [64, 128, 256, 512] + mlp_ratios = [4, 4, 4, 4] + downsamples = [True, True, True, True] + pos_embs = [None, None, None, partial(RepCPE, spatial_shape=(7, 7))] + token_mixers = ("repmixer", "repmixer", "repmixer", "attention") + model = FastViT( + layers, + embed_dims=embed_dims, + token_mixers=token_mixers, + pos_embs=pos_embs, + mlp_ratios=mlp_ratios, + downsamples=downsamples, + layer_scale_init_value=1e-6, + **kwargs, + ) + model.default_cfg = default_cfgs["fastvit_m"] + if pretrained: + raise ValueError("Functionality not implemented.") + return model + + +@register_model +def fastvit_ma36(pretrained=False, **kwargs): + """Instantiate FastViT-MA36 model variant.""" + layers = [6, 6, 18, 6] + embed_dims = [76, 152, 304, 608] + mlp_ratios = [4, 4, 4, 4] + downsamples = [True, True, True, True] + pos_embs = [None, None, None, partial(RepCPE, spatial_shape=(7, 7))] + token_mixers = ("repmixer", "repmixer", "repmixer", "attention") + model = FastViT( + layers, + embed_dims=embed_dims, + token_mixers=token_mixers, + pos_embs=pos_embs, + mlp_ratios=mlp_ratios, + downsamples=downsamples, + layer_scale_init_value=1e-6, + **kwargs, + ) + 
model.default_cfg = default_cfgs["fastvit_m"] + if pretrained: + raise ValueError("Functionality not implemented.") + return model + + +class DropPath(nn.Cell): + """DropPath (Stochastic Depth) regularization layers""" + + def __init__( + self, + drop_prob: float = 0.0, + scale_by_keep: bool = True, + ) -> None: + super().__init__() + self.keep_prob = 1.0 - drop_prob + self.scale_by_keep = scale_by_keep + self.dropout = nn.Dropout(p=drop_prob) + + def construct(self, x: ms.Tensor) -> ms.Tensor: + if self.keep_prob == 1.0 or not self.training: + return x + shape = (x.shape[0], ) + (1,) * (x.ndim - 1) + random_tensor = self.dropout(ones(shape)) + if not self.scale_by_keep: + random_tensor = ops.mul(random_tensor, self.keep_prob) + return x * random_tensor + + +class SEBlock(nn.Cell): + + def __init__(self, in_channels: int, rd_ratio: float = 0.0625) -> None: + """Construct a Squeeze and Excite Module. + + Args: + in_channels: Number of input channels. + rd_ratio: Input channel reduction ratio. + """ + super(SEBlock, self).__init__() + self.reduce = nn.Conv2d( + in_channels=in_channels, + out_channels=int(in_channels * rd_ratio), + kernel_size=1, + stride=1, + pad_mode='valid', + has_bias=True, + ) + self.expand = nn.Conv2d( + in_channels=int(in_channels * rd_ratio), + out_channels=in_channels, + kernel_size=1, + pad_mode='valid', + stride=1, + has_bias=True, + ) + + def construct(self, inputs: ms.Tensor) -> ms.Tensor: + """Apply forward pass.""" + b, c, h, w = inputs.shape + x = ops.AvgPool(pad_mode='valid', kernel_size=(h, w))(inputs) + x = self.reduce(x) + x = nn.ReLU()(x) + x = self.expand(x) + x = nn.Sigmoid()(x) + x = x.view((-1, c, 1, 1)) + return inputs * x + + +class MobileOneBlock(nn.Cell): + """MobileOne building block. + + This block has a multi-branched architecture at train-time + and plain-CNN style architecture at inference time + For more details, please refer to our paper: + `An Improved One millisecond Mobile Backbone` - + https://arxiv.org/pdf/2206.04040.pdf + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + padding: int = 0, + dilation: int = 1, + group: int = 1, + inference_mode: bool = False, + use_se: bool = False, + use_act: bool = True, + use_scale_branch: bool = True, + num_conv_branches: int = 1, + activation: nn.Cell = nn.GELU, + ) -> None: + """Construct a MobileOneBlock module. + + Args: + in_channels: Number of channels in the input. + out_channels: Number of channels produced by the block. + kernel_size: Size of the convolution kernel. + stride: Stride size. + padding: Zero-padding size. + dilation: Kernel dilation factor. + group: Group number. + inference_mode: If True, instantiates model in inference mode. + use_se: Whether to use SE-ReLU activations. + use_act: Whether to use activation. Default: ``True`` + use_scale_branch: Whether to use scale branch. Default: ``True`` + num_conv_branches: Number of linear conv branches. 
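+            activation: Activation layer. Default: ``nn.GELU``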
+ """ + super(MobileOneBlock, self).__init__() + self.inference_mode = inference_mode + self.group = group + self.stride = stride + self.padding = padding + self.dilation = dilation + self.kernel_size = kernel_size + self.in_channels = in_channels + self.out_channels = out_channels + self.num_conv_branches = num_conv_branches + + # Check if SE-ReLU is requested + if use_se: + self.se = SEBlock(out_channels) + else: + self.se = nn.Identity() + + if use_act: + self.activation = activation() + else: + self.activation = nn.Identity() + + if inference_mode: + self.reparam_conv = nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + pad_mode='pad', + padding=padding, + dilation=dilation, + group=group, + has_bias=True, + ) + else: + # Re-parameterizable skip connection + self.rbr_skip = ( + nn.BatchNorm2d(num_features=in_channels) + if out_channels == in_channels and stride == 1 + else None + ) + + # Re-parameterizable conv branches + if num_conv_branches > 0: + self.rbr_conv = nn.CellList() + for _ in range(self.num_conv_branches): + self.rbr_conv.append( + self._conv_bn(kernel_size=kernel_size, padding=padding) + ) + else: + self.rbr_conv = None + + # Re-parameterizable scale branch + self.rbr_scale = None + if (kernel_size > 1) and use_scale_branch: + self.rbr_scale = self._conv_bn(kernel_size=1, padding=0) + + def construct(self, x: ms.Tensor) -> ms.Tensor: + """Apply forward pass.""" + # Inference mode forward pass. + if self.inference_mode: + return self.activation(self.se(self.reparam_conv(x))) + + # Multi-branched train-time forward pass. + # Skip branch output + identity_out = 0 + if self.rbr_skip is not None: + identity_out = self.rbr_skip(x) + + # Scale branch output + scale_out = 0 + if self.rbr_scale is not None: + scale_out = self.rbr_scale(x) + + # Other branches + out = scale_out + identity_out + if self.rbr_conv is not None: + for ix in range(self.num_conv_branches): + out += self.rbr_conv[ix](x) + + return self.activation(self.se(out)) + + def reparameterize(self): + """Following works like `RepVGG: Making VGG-style ConvNets Great Again` - + https://arxiv.org/pdf/2101.03697.pdf. We re-parameterize multi-branched + architecture used at training time to obtain a plain CNN-like structure + for inference. + """ + if self.inference_mode: + return + kernel, bias = self._get_kernel_bias() + self.reparam_conv = nn.Conv2d( + in_channels=self.in_channels, + out_channels=self.out_channels, + kernel_size=self.kernel_size, + stride=self.stride, + pad_mode='pad', + padding=self.padding, + dilation=self.dilation, + group=self.group, + has_bias=True, + ) + self.reparam_conv.weight = kernel + self.reparam_conv.bias = bias + + # Delete un-used branches + for para in self.get_parameters(): + para = ops.stop_gradient(para) + self.__delattr__("rbr_conv") + self.__delattr__("rbr_scale") + if hasattr(self, "rbr_skip"): + self.__delattr__("rbr_skip") + + self.inference_mode = True + + def _get_kernel_bias(self) -> Tuple[ms.Tensor, ms.Tensor]: + """Method to obtain re-parameterized kernel and bias. + Reference: https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L83 + + Returns: + Tuple of (kernel, bias) after fusing branches. + """ + # get weights and bias of scale branch + kernel_scale = 0 + bias_scale = 0 + if self.rbr_scale is not None: + kernel_scale, bias_scale = self._fuse_bn_tensor(self.rbr_scale) + # Pad scale branch kernel to match conv branch kernel size. 
+ pad = self.kernel_size // 2 + pad_op = nn.Pad(paddings=((0, 0), (0, 0), (pad, pad), (pad, pad))) + kernel_scale = pad_op(kernel_scale) + + # get weights and bias of skip branch + kernel_identity = 0 + bias_identity = 0 + if self.rbr_skip is not None: + kernel_identity, bias_identity = self._fuse_bn_tensor(self.rbr_skip) + + # get weights and bias of conv branches + kernel_conv = 0 + bias_conv = 0 + if self.rbr_conv is not None: + for ix in range(self.num_conv_branches): + _kernel, _bias = self._fuse_bn_tensor(self.rbr_conv[ix]) + kernel_conv += _kernel + bias_conv += _bias + + kernel_final = kernel_conv + kernel_scale + kernel_identity + bias_final = bias_conv + bias_scale + bias_identity + return kernel_final, bias_final + + def _fuse_bn_tensor( + self, branch: Union[nn.SequentialCell, nn.BatchNorm2d] + ) -> Tuple[ms.Tensor, ms.Tensor]: + """Method to fuse batchnorm layer with preceeding conv layer. + Reference: https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L95 + + Args: + branch: Sequence of ops to be fused. + + Returns: + Tuple of (kernel, bias) after fusing batchnorm. + """ + if isinstance(branch, nn.SequentialCell): + kernel = branch.conv.weight + running_mean = branch.bn.moving_mean + running_var = branch.bn.moving_variance + gamma = branch.bn.gamma + beta = branch.bn.beta + eps = branch.bn.eps + else: + assert isinstance(branch, nn.BatchNorm2d) + if not hasattr(self, "id_tensor"): + input_dim = self.in_channels // self.group + kernel_value = ops.zeros( + (self.in_channels, input_dim, self.kernel_size, self.kernel_size), + ms.float32 + ) + for i in range(self.in_channels): + kernel_value[ + i, i % input_dim, self.kernel_size // 2, self.kernel_size // 2 + ] = 1 + self.id_tensor = kernel_value + kernel = self.id_tensor + running_mean = branch.moving_mean + running_var = branch.moving_variance + gamma = branch.gamma + beta = branch.beta + eps = branch.eps + std = (running_var + eps).sqrt() + t = (gamma / std).reshape(-1, 1, 1, 1) + return kernel * t, beta - running_mean * gamma / std + + def _conv_bn(self, kernel_size: int, padding: int) -> nn.SequentialCell: + """Helper method to construct conv-batchnorm layers. + + Args: + kernel_size: Size of the convolution kernel. + padding: Zero-padding size. + + Returns: + Conv-BN module. + """ + mod_list = nn.SequentialCell( + OrderedDict( + [("conv", nn.Conv2d( + in_channels=self.in_channels, + out_channels=self.out_channels, + kernel_size=kernel_size, + stride=self.stride, + pad_mode='pad', + padding=padding, + group=self.group, + has_bias=False, )), + ("bn", nn.BatchNorm2d(num_features=self.out_channels))])) + return mod_list + + +class ReparamLargeKernelConv(nn.Cell): + """Building Block of RepLKNet + + This class defines overparameterized large kernel conv block + introduced in `RepLKNet `_ + + Reference: https://github.com/DingXiaoH/RepLKNet-pytorch + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int, + group: int, + small_kernel: int, + inference_mode: bool = False, + activation: nn.Cell = nn.GELU, + ) -> None: + """Construct a ReparamLargeKernelConv module. + + Args: + in_channels: Number of input channels. + out_channels: Number of output channels. + kernel_size: Kernel size of the large kernel conv branch. + stride: Stride size. Default: 1 + groups: Group number. Default: 1 + small_kernel: Kernel size of small kernel conv branch. + inference_mode: If True, instantiates model in inference mode. Default: ``False`` + activation: Activation module. 
Default: ``nn.GELU`` + """ + super(ReparamLargeKernelConv, self).__init__() + + self.stride = stride + self.group = group + self.in_channels = in_channels + self.out_channels = out_channels + self.activation = activation() + + self.kernel_size = kernel_size + self.small_kernel = small_kernel + self.padding = kernel_size // 2 + self.lkb_reparam = None + self.small_conv = None + if inference_mode: + self.lkb_reparam = nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + pad_mode='pad', + padding=self.padding, + dilation=1, + group=group, + has_bias=True, + ) + else: + self.lkb_origin = self._conv_bn( + kernel_size=kernel_size, padding=self.padding + ) + if small_kernel is not None: + assert ( + small_kernel <= kernel_size + ), "The kernel size for re-param cannot be larger than the large kernel!" + self.small_conv = self._conv_bn( + kernel_size=small_kernel, padding=small_kernel // 2 + ) + + def construct(self, x: ms.Tensor) -> ms.Tensor: + """Apply forward pass.""" + if self.lkb_reparam is not None: + out = self.lkb_reparam(x) + else: + out = self.lkb_origin(x) + if self.small_conv is not None: + out += self.small_conv(x) + + self.activation(out) + return out + + def get_kernel_bias(self) -> Tuple[ms.Tensor, ms.Tensor]: + """Method to obtain re-parameterized kernel and bias. + Reference: https://github.com/DingXiaoH/RepLKNet-pytorch + + Returns: + Tuple of (kernel, bias) after fusing branches. + """ + eq_k, eq_b = self._fuse_bn(self.lkb_origin.conv, self.lkb_origin.bn) + if hasattr(self, "small_conv"): + small_k, small_b = self._fuse_bn(self.small_conv.conv, self.small_conv.bn) + eq_b += small_b + pad_op = nn.Pad(paddings=((0, 0), (0, 0), ((self.kernel_size - self.small_kernel) // 2, + (self.kernel_size - self.small_kernel) // 2), + ((self.kernel_size - self.small_kernel) // 2, + (self.kernel_size - self.small_kernel) // 2))) + eq_k += pad_op(small_k) + return eq_k, eq_b + + def reparameterize(self) -> None: + """ + Following works like `RepVGG: Making VGG-style ConvNets Great Again` - + https://arxiv.org/pdf/2101.03697.pdf. We re-parameterize multi-branched + architecture used at training time to obtain a plain CNN-like structure + for inference. + """ + eq_k, eq_b = self.get_kernel_bias() + self.lkb_reparam = nn.Conv2d( + in_channels=self.in_channels, + out_channels=self.out_channels, + kernel_size=self.kernel_size, + stride=self.stride, + pad_mode='pad', + padding=self.padding, + dilation=self.lkb_origin.conv.dilation, + group=self.group, + has_bias=True, + ) + + self.lkb_reparam.weight = eq_k + self.lkb_reparam.bias = eq_b + self.__delattr__("lkb_origin") + if hasattr(self, "small_conv"): + self.__delattr__("small_conv") + + @staticmethod + def _fuse_bn( + conv: ms.Tensor, bn: nn.BatchNorm2d + ) -> Tuple[ms.Tensor, ms.Tensor]: + """Method to fuse batchnorm layer with conv layer. + + Args: + conv: Convolutional kernel weights. + bn: Batchnorm 2d layer. + + Returns: + Tuple of (kernel, bias) after fusing batchnorm. + """ + kernel = conv.weight + running_mean = bn.moving_mean + running_var = bn.moving_variance + gamma = bn.gamma + beta = bn.beta + eps = bn.eps + std = (running_var + eps).sqrt() + t = (gamma / std).reshape(-1, 1, 1, 1) + return kernel * t, beta - running_mean * gamma / std + + def _conv_bn(self, kernel_size: int, padding: int = 0) -> nn.SequentialCell: + """Helper method to construct conv-batchnorm layers. + + Args: + kernel_size: Size of the convolution kernel. + padding: Zero-padding size. 
+ + Returns: + A nn.Sequential Conv-BN module. + """ + mod_list = nn.SequentialCell( + OrderedDict( + [("conv", nn.Conv2d( + in_channels=self.in_channels, + out_channels=self.out_channels, + kernel_size=kernel_size, + stride=self.stride, + pad_mode='pad', + padding=padding, + group=self.group, + has_bias=False,)), + ("bn", nn.BatchNorm2d(num_features=self.out_channels))])) + return mod_list + + +def reparameterize_model(model: nn.Cell) -> nn.Cell: + """Method returns a model where a multi-branched structure + used in training is re-parameterized into a single branch + for inference. + + Args: + model: MobileOne model in train mode. + + Returns: + MobileOne model in inference mode. + """ + # Avoid editing original graph + model = copy.deepcopy(model) + for _, cell in model.cells_and_names(): + if hasattr(cell, "reparameterize"): + cell.reparameterize() + return model + + +class CosineWDSchedule: + def __init__(self, optimizer, t_max, eta_min=0, last_epoch=-1): + self.last_epoch = last_epoch + self.base_wds = [group["weight_decay"] for group in optimizer.param_groups] + self.t_max = t_max + self.eta_min = eta_min + + def _get_wd(self, optimizer): + if self.last_epoch == 0: + return self.base_wds + elif (self.last_epoch - 1 - self.t_max) % (2 * self.t_max) == 0: + return [ + group["weight_decay"] + + (base_lr - self.eta_min) * (1 - math.cos(math.pi / self.t_max)) / 2 + for base_lr, group in zip(self.base_wds, optimizer.param_groups) + ] + return [ + (1 + math.cos(math.pi * self.last_epoch / self.t_max)) + / (1 + math.cos(math.pi * (self.last_epoch - 1) / self.t_max)) + * (group["weight_decay"] - self.eta_min) + + self.eta_min + for group in optimizer.param_groups + ] + + def update_weight_decay(self, optimizer): + self.last_epoch += 1 + values = self._get_wd(optimizer) + for i, data in enumerate(zip(optimizer.param_groups, values)): + param_group, wd = data + # Avoid updating weight decay of param_groups that should not be decayed. + if param_group["weight_decay"] > 0.0: + param_group["weight_decay"] = wd + + +class DistillationLoss(nn.Cell): + """ + This module wraps a standard criterion and adds an extra knowledge distillation loss by + taking a teacher model prediction and using it as additional supervision. + """ + + def __init__( + self, + base_criterion: nn.Cell, + teacher_model: nn.Cell, + distillation_type: str, + alpha: float, + tau: float, + ): + super(DistillationLoss, self).__init__() + self.base_criterion = base_criterion + self.teacher_model = teacher_model + assert distillation_type in ["none", "soft", "hard"] + self.distillation_type = distillation_type + self.alpha = alpha + self.tau = tau + + def construct(self, inputs, outputs, labels): + """ + Args: + inputs: The original inputs that are feed to the teacher model. + outputs: Output tensor from model being trained. + labels: the labels for the base criterion. + """ + base_loss = self.base_criterion(outputs, labels) + if self.distillation_type == "none": + return base_loss + teacher_outputs = self.teacher_model(inputs) + teacher_outputs = ops.stop_gradient(teacher_outputs) + if self.distillation_type == "soft": + T = self.tau + distillation_loss = ( + ops.KLDivLoss( + nn.LogSoftmax(outputs / T, axis=1), + nn.LogSoftmax(teacher_outputs / T, axis=1), + reduction="sum", + ) + * (T * T) + / ops.Size()(outputs) + ) + elif self.distillation_type == "hard": + distillation_loss = ops.cross_entropy(outputs, ops.Argmax(axis=1)(teacher_outputs)) + + loss = base_loss * (1 - self.alpha) + distillation_loss * self.alpha + return loss
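+
+
+if __name__ == "__main__":
+    # Minimal smoke test (an illustrative sketch, not part of the training pipeline;
+    # the FastViT-T8 variant and the 224x224 input size from fastvit_t8_ascend.yaml
+    # are arbitrary choices here). `reparameterize_model` above is intended to fold
+    # the train-time branches into single convolutions before deployment.
+    net = fastvit_t8(num_classes=1000)
+    net.set_train(False)
+    dummy_input = ops.ones((1, 3, 224, 224), ms.float32)
+    logits = net(dummy_input)
+    print(logits.shape)  # expected: (1, 1000)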