From b6ef644d286bfbca70304c39cedad6d0f20aa4a2 Mon Sep 17 00:00:00 2001 From: adverbial03 Date: Sun, 9 Apr 2023 21:45:39 +0800 Subject: [PATCH] moe layer --- README.md | 36 +- .../models/cascade_mask_rcnn_swin_moe_fpn.py | 216 ++++ ...rain_480-800_giou_4conv1f_adamw_3x_coco.py | 144 +++ mmdet/models/backbones/__init__.py | 4 +- .../models/backbones/swin_transformer_moe.py | 927 ++++++++++++++++++ mmdet/models/detectors/two_stage.py | 14 +- 6 files changed, 1334 insertions(+), 7 deletions(-) create mode 100644 configs/_base_/models/cascade_mask_rcnn_swin_moe_fpn.py create mode 100644 configs/swin/cascade_mask_rcnn_swin_moe_tiny_patch4_window7_mstrain_480-800_giou_4conv1f_adamw_3x_coco.py create mode 100644 mmdet/models/backbones/swin_transformer_moe.py diff --git a/README.md b/README.md index 181dd284df0..a6d86bc34ae 100644 --- a/README.md +++ b/README.md @@ -133,9 +133,43 @@ optimizer_config = dict( > **Image Classification**: See [Swin Transformer for Image Classification](https://github.com/microsoft/Swin-Transformer). -> **Semantic Segmentation**: See [Swin Transformer for Semantic Segmentation](https://github.com/SwinTransformer/Swin-Transformer-Semantic-Segmentation). +> **Semantic Segmentation**: See [Swin Transformer for Semantic Segmentation](https://github.com/Transformer/Swin-Transformer-Semantic-Segmentation). > **Self-Supervised Learning**: See [MoBY with Swin Transformer](https://github.com/SwinTransformer/Transformer-SSL). > **Video Recognition**, See [Video Swin Transformer](https://github.com/SwinTransformer/Video-Swin-Transformer). +## swin-T moe + +I added Swin Transformer MoE (referred to as Swin-T MoE hereafter) to the backbone network. MoE is a method that expands the model parameters and improves the model performance. The implementation of Swin Transformer MoE used Microsoft's Tutel framework. + +### Install [Tutel](https://github.com/microsoft/tutel) + +``` +python3 -m pip uninstall tutel -y +python3 -m pip install --user --upgrade git+https://github.com/microsoft/tutel@main +``` + +You can check out Swin-T MoE at . + +``` +.\mmdet\models\backbones\swin_transformer_moe.py. +``` + +I provided the relevant configuration files for reference: + +contains the parameters for the Swin-T MoE backbone network: + +``` +.\configs\swin\cascade_mask_rcnn_swin_moe_tiny_patch4_window7_mstrain_480-800_giou_4conv1f_adamw_3x_coco.py +``` + +contains the modified configuration for the backbone network: + +``` +.\configs\swin\cascade_mask_rcnn_swin_moe_tiny_patch4_window7_mstrain_480-800_giou_4conv1f_adamw_3x_coco.py +``` + +As the output of Swin-T MoE is different from Swin-T, I modified the `extract_feat` function in `.\mmdet\models\detectors\two_stage.py`. + +You can change the config according to your needs. diff --git a/configs/_base_/models/cascade_mask_rcnn_swin_moe_fpn.py b/configs/_base_/models/cascade_mask_rcnn_swin_moe_fpn.py new file mode 100644 index 00000000000..93f9e760156 --- /dev/null +++ b/configs/_base_/models/cascade_mask_rcnn_swin_moe_fpn.py @@ -0,0 +1,216 @@ +# model settings +model = dict( + type='CascadeRCNN', + pretrained=None, + backbone=dict( + type='SwinTransformerMoE', + embed_dim=96, + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], + window_size=7, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.2, + ape=False, + patch_norm=True, + out_indices=(0, 1, 2, 3), + use_checkpoint=False, + moe_blocks=[ [ -1 ], [ -1 ], [ 1, 3, 5 ], [ 1 ] ], + # moe_blocks=[ [ -1 ], [ -1 ], [ -1 ], [ -1 ] ], + num_local_experts=1, + top_value= 1, + capacity_factor= 1.25, + is_gshard_loss= False, + moe_drop= 0.1, + aux_loss_weight= 0.01 + ), + neck=dict( + type='FPN', + in_channels=[96, 192, 384, 768], + out_channels=256, + num_outs=5), + rpn_head=dict( + type='RPNHead', + in_channels=256, + feat_channels=256, + anchor_generator=dict( + type='AnchorGenerator', + scales=[8], + ratios=[0.5, 1.0, 2.0], + strides=[4, 8, 16, 32, 64]), + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[.0, .0, .0, .0], + target_stds=[1.0, 1.0, 1.0, 1.0]), + loss_cls=dict( + type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0), + loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0)), + roi_head=dict( + type='CascadeRoIHead', + num_stages=3, + stage_loss_weights=[1, 0.5, 0.25], + bbox_roi_extractor=dict( + type='SingleRoIExtractor', + roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0), + out_channels=256, + featmap_strides=[4, 8, 16, 32]), + bbox_head=[ + dict( + type='Shared2FCBBoxHead', + in_channels=256, + fc_out_channels=1024, + roi_feat_size=7, + num_classes=80, + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[0., 0., 0., 0.], + target_stds=[0.1, 0.1, 0.2, 0.2]), + reg_class_agnostic=True, + loss_cls=dict( + type='CrossEntropyLoss', + use_sigmoid=False, + loss_weight=1.0), + loss_bbox=dict(type='SmoothL1Loss', beta=1.0, + loss_weight=1.0)), + dict( + type='Shared2FCBBoxHead', + in_channels=256, + fc_out_channels=1024, + roi_feat_size=7, + num_classes=80, + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[0., 0., 0., 0.], + target_stds=[0.05, 0.05, 0.1, 0.1]), + reg_class_agnostic=True, + loss_cls=dict( + type='CrossEntropyLoss', + use_sigmoid=False, + loss_weight=1.0), + loss_bbox=dict(type='SmoothL1Loss', beta=1.0, + loss_weight=1.0)), + dict( + type='Shared2FCBBoxHead', + in_channels=256, + fc_out_channels=1024, + roi_feat_size=7, + num_classes=80, + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[0., 0., 0., 0.], + target_stds=[0.033, 0.033, 0.067, 0.067]), + reg_class_agnostic=True, + loss_cls=dict( + type='CrossEntropyLoss', + use_sigmoid=False, + loss_weight=1.0), + loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0)) + ], + mask_roi_extractor=dict( + type='SingleRoIExtractor', + roi_layer=dict(type='RoIAlign', output_size=14, sampling_ratio=0), + out_channels=256, + featmap_strides=[4, 8, 16, 32]), + mask_head=dict( + type='FCNMaskHead', + num_convs=4, + in_channels=256, + conv_out_channels=256, + num_classes=80, + loss_mask=dict( + type='CrossEntropyLoss', use_mask=True, loss_weight=1.0))), + # model training and testing settings + train_cfg = dict( + rpn=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.7, + neg_iou_thr=0.3, + min_pos_iou=0.3, + match_low_quality=True, + ignore_iof_thr=-1), + sampler=dict( + type='RandomSampler', + num=256, + pos_fraction=0.5, + neg_pos_ub=-1, + add_gt_as_proposals=False), + allowed_border=0, + pos_weight=-1, + debug=False), + rpn_proposal=dict( + nms_across_levels=False, + nms_pre=2000, + nms_post=2000, + max_per_img=2000, + nms=dict(type='nms', iou_threshold=0.7), + min_bbox_size=0), + rcnn=[ + dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.5, + neg_iou_thr=0.5, + min_pos_iou=0.5, + match_low_quality=False, + ignore_iof_thr=-1), + sampler=dict( + type='RandomSampler', + num=512, + pos_fraction=0.25, + neg_pos_ub=-1, + add_gt_as_proposals=True), + mask_size=28, + pos_weight=-1, + debug=False), + dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.6, + neg_iou_thr=0.6, + min_pos_iou=0.6, + match_low_quality=False, + ignore_iof_thr=-1), + sampler=dict( + type='RandomSampler', + num=512, + pos_fraction=0.25, + neg_pos_ub=-1, + add_gt_as_proposals=True), + mask_size=28, + pos_weight=-1, + debug=False), + dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.7, + neg_iou_thr=0.7, + min_pos_iou=0.7, + match_low_quality=False, + ignore_iof_thr=-1), + sampler=dict( + type='RandomSampler', + num=512, + pos_fraction=0.25, + neg_pos_ub=-1, + add_gt_as_proposals=True), + mask_size=28, + pos_weight=-1, + debug=False) + ]), + test_cfg = dict( + rpn=dict( + nms_across_levels=False, + nms_pre=1000, + nms_post=1000, + max_per_img=1000, + nms=dict(type='nms', iou_threshold=0.7), + min_bbox_size=0), + rcnn=dict( + score_thr=0.05, + nms=dict(type='nms', iou_threshold=0.5), + max_per_img=100, + mask_thr_binary=0.5))) diff --git a/configs/swin/cascade_mask_rcnn_swin_moe_tiny_patch4_window7_mstrain_480-800_giou_4conv1f_adamw_3x_coco.py b/configs/swin/cascade_mask_rcnn_swin_moe_tiny_patch4_window7_mstrain_480-800_giou_4conv1f_adamw_3x_coco.py new file mode 100644 index 00000000000..8e2e068336b --- /dev/null +++ b/configs/swin/cascade_mask_rcnn_swin_moe_tiny_patch4_window7_mstrain_480-800_giou_4conv1f_adamw_3x_coco.py @@ -0,0 +1,144 @@ +_base_ = [ + '../_base_/models/cascade_mask_rcnn_swin_moe_fpn.py', + '../_base_/datasets/coco_instance.py', + '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py' +] + +model = dict( + pretrained=None, + backbone=dict( + embed_dim=96, + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], + window_size=7, + ape=False, + drop_path_rate=0.2, + patch_norm=True, + use_checkpoint=False + ), + neck=dict(in_channels=[96, 192, 384, 768]), + roi_head=dict( + bbox_head=[ + dict( + type='ConvFCBBoxHead', + num_shared_convs=4, + num_shared_fcs=1, + in_channels=256, + conv_out_channels=256, + fc_out_channels=1024, + roi_feat_size=7, + num_classes=80, + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[0., 0., 0., 0.], + target_stds=[0.1, 0.1, 0.2, 0.2]), + reg_class_agnostic=False, + reg_decoded_bbox=True, + # norm_cfg=dict(type='BN', requires_grad=True), + norm_cfg=dict(type='SyncBN', requires_grad=True), + loss_cls=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), + loss_bbox=dict(type='GIoULoss', loss_weight=10.0)), + dict( + type='ConvFCBBoxHead', + num_shared_convs=4, + num_shared_fcs=1, + in_channels=256, + conv_out_channels=256, + fc_out_channels=1024, + roi_feat_size=7, + num_classes=80, + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[0., 0., 0., 0.], + target_stds=[0.05, 0.05, 0.1, 0.1]), + reg_class_agnostic=False, + reg_decoded_bbox=True, + # norm_cfg=dict(type='BN', requires_grad=True), + norm_cfg=dict(type='SyncBN', requires_grad=True), + loss_cls=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), + loss_bbox=dict(type='GIoULoss', loss_weight=10.0)), + dict( + type='ConvFCBBoxHead', + num_shared_convs=4, + num_shared_fcs=1, + in_channels=256, + conv_out_channels=256, + fc_out_channels=1024, + roi_feat_size=7, + num_classes=80, + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[0., 0., 0., 0.], + target_stds=[0.033, 0.033, 0.067, 0.067]), + reg_class_agnostic=False, + reg_decoded_bbox=True, + # norm_cfg=dict(type='BN', requires_grad=True), + norm_cfg=dict(type='SyncBN', requires_grad=True), + loss_cls=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), + loss_bbox=dict(type='GIoULoss', loss_weight=10.0)) + ])) + +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) + +# augmentation strategy originates from DETR / Sparse RCNN +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations', with_bbox=True, with_mask=True), + dict(type='RandomFlip', flip_ratio=0.5), + dict(type='AutoAugment', + policies=[ + [ + dict(type='Resize', + img_scale=[(480, 1333), (512, 1333), (544, 1333), (576, 1333), + (608, 1333), (640, 1333), (672, 1333), (704, 1333), + (736, 1333), (768, 1333), (800, 1333)], + multiscale_mode='value', + keep_ratio=True) + ], + [ + dict(type='Resize', + img_scale=[(400, 1333), (500, 1333), (600, 1333)], + multiscale_mode='value', + keep_ratio=True), + dict(type='RandomCrop', + crop_type='absolute_range', + crop_size=(384, 600), + allow_negative_crop=True), + dict(type='Resize', + img_scale=[(480, 1333), (512, 1333), (544, 1333), + (576, 1333), (608, 1333), (640, 1333), + (672, 1333), (704, 1333), (736, 1333), + (768, 1333), (800, 1333)], + multiscale_mode='value', + override=True, + keep_ratio=True) + ] + ]), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size_divisor=32), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']), +] +data = dict(train=dict(pipeline=train_pipeline)) + +optimizer = dict(_delete_=True, type='AdamW', lr=0.0001, betas=(0.9, 0.999), weight_decay=0.05, + paramwise_cfg=dict(custom_keys={'absolute_pos_embed': dict(decay_mult=0.), + 'relative_position_bias_table': dict(decay_mult=0.), + 'norm': dict(decay_mult=0.)})) +lr_config = dict(step=[27, 33]) +runner = dict(type='EpochBasedRunnerAmp', max_epochs=36) + +# do not use mmdet version fp16 +fp16 = None +optimizer_config = dict( + type="DistOptimizerHook", + update_interval=1, + grad_clip=None, + coalesce=True, + bucket_size_mb=-1, + use_fp16=True, +) diff --git a/mmdet/models/backbones/__init__.py b/mmdet/models/backbones/__init__.py index 5d35e48ab7b..c92616bd693 100644 --- a/mmdet/models/backbones/__init__.py +++ b/mmdet/models/backbones/__init__.py @@ -11,9 +11,9 @@ from .ssd_vgg import SSDVGG from .trident_resnet import TridentResNet from .swin_transformer import SwinTransformer - +from .swin_transformer_moe import SwinTransformerMoE __all__ = [ 'RegNet', 'ResNet', 'ResNetV1d', 'ResNeXt', 'SSDVGG', 'HRNet', 'Res2Net', 'HourglassNet', 'DetectoRS_ResNet', 'DetectoRS_ResNeXt', 'Darknet', - 'ResNeSt', 'TridentResNet', 'SwinTransformer' + 'ResNeSt', 'TridentResNet', 'SwinTransformer','SwinTransformerMoE' ] diff --git a/mmdet/models/backbones/swin_transformer_moe.py b/mmdet/models/backbones/swin_transformer_moe.py new file mode 100644 index 00000000000..a02c9ad5ab2 --- /dev/null +++ b/mmdet/models/backbones/swin_transformer_moe.py @@ -0,0 +1,927 @@ +# -------------------------------------------------------- +# Swin Transformer +# Copyright (c) 2021 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Written by Ze Liu, Yutong Lin, Yixuan Wei +# -------------------------------------------------------- + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +import torch.distributed as dist +import numpy as np +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ + +from mmcv_custom import load_checkpoint +from mmdet.utils import get_root_logger +from ..builder import BACKBONES + +try: + from tutel import moe as tutel_moe +except: + tutel_moe = None + print("Tutel has not been installed. To use Swin-MoE, please install Tutel; otherwise, just ignore this.") + + +class Mlp(nn.Module): + """ Multilayer perceptron.""" + + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0., + mlp_fc2_bias=True): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features, bias=mlp_fc2_bias) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + +class MoEMlp(nn.Module): + def __init__(self, in_features, hidden_features, num_local_experts, top_value, capacity_factor=1.25, + cosine_router=False, normalize_gate=False, use_bpr=True, is_gshard_loss=True, + gate_noise=1.0, cosine_router_dim=256, cosine_router_init_t=0.5, moe_drop=0.0, init_std=0.02, + mlp_fc2_bias=True): + super().__init__() + + self.in_features = in_features + self.hidden_features = hidden_features + self.num_local_experts = num_local_experts + self.top_value = top_value + self.capacity_factor = capacity_factor + self.cosine_router = cosine_router + self.normalize_gate = normalize_gate + self.use_bpr = use_bpr + self.init_std = init_std + self.mlp_fc2_bias = mlp_fc2_bias + + self.dist_rank = dist.get_rank() + + self._dropout = nn.Dropout(p=moe_drop) + + _gate_type = {'type': 'cosine_top' if cosine_router else 'top', + 'k': top_value, 'capacity_factor': capacity_factor, + 'gate_noise': gate_noise, 'fp32_gate': True} + if cosine_router: + _gate_type['proj_dim'] = cosine_router_dim + _gate_type['init_t'] = cosine_router_init_t + self._moe_layer = tutel_moe.moe_layer( + gate_type=_gate_type, + model_dim=in_features, + experts={'type': 'ffn', 'count_per_node': num_local_experts, 'hidden_size_per_expert': hidden_features, + 'activation_fn': lambda x: self._dropout(F.gelu(x))}, + scan_expert_func=lambda name, param: setattr(param, 'skip_allreduce', True), + seeds=(1, self.dist_rank + 1, self.dist_rank + 1), + batch_prioritized_routing=use_bpr, + normalize_gate=normalize_gate, + is_gshard_loss=is_gshard_loss, + + ) + if not self.mlp_fc2_bias: + self._moe_layer.experts.batched_fc2_bias.requires_grad = False + + def forward(self, x): + x = self._moe_layer(x) + return x, x.l_aux + + def extra_repr(self) -> str: + return f'[Statistics-{self.dist_rank}] param count for MoE, ' \ + f'in_features = {self.in_features}, hidden_features = {self.hidden_features}, ' \ + f'num_local_experts = {self.num_local_experts}, top_value = {self.top_value}, ' \ + f'cosine_router={self.cosine_router} normalize_gate={self.normalize_gate}, use_bpr = {self.use_bpr}' + + def _init_weights(self): + if hasattr(self._moe_layer, "experts"): + trunc_normal_(self._moe_layer.experts.batched_fc1_w, std=self.init_std) + trunc_normal_(self._moe_layer.experts.batched_fc2_w, std=self.init_std) + nn.init.constant_(self._moe_layer.experts.batched_fc1_bias, 0) + nn.init.constant_(self._moe_layer.experts.batched_fc2_bias, 0) + + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + +# from swin-moe +class WindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + pretrained_window_size (tuple[int]): The height and width of the window in pre-training. + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., + pretrained_window_size=[0, 0]): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.pretrained_window_size = pretrained_window_size + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # mlp to generate continuous relative position bias + self.cpb_mlp = nn.Sequential(nn.Linear(2, 512, bias=True), + nn.ReLU(inplace=True), + nn.Linear(512, num_heads, bias=False)) + + # get relative_coords_table + relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32) + relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32) + relative_coords_table = torch.stack( + torch.meshgrid([relative_coords_h, + relative_coords_w])).permute(1, 2, 0).contiguous().unsqueeze(0) # 1, 2*Wh-1, 2*Ww-1, 2 + if pretrained_window_size[0] > 0: + relative_coords_table[:, :, :, 0] /= (pretrained_window_size[0] - 1) + relative_coords_table[:, :, :, 1] /= (pretrained_window_size[1] - 1) + else: + relative_coords_table[:, :, :, 0] /= (self.window_size[0] - 1) + relative_coords_table[:, :, :, 1] /= (self.window_size[1] - 1) + relative_coords_table *= 8 # normalize to -8, 8 + relative_coords_table = torch.sign(relative_coords_table) * torch.log2( + torch.abs(relative_coords_table) + 1.0) / np.log2(8) + + self.register_buffer("relative_coords_table", relative_coords_table) + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias_table = self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads) + relative_position_bias = relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + +#from swin-t +# class WindowAttention(nn.Module): +# """ Window based multi-head self attention (W-MSA) module with relative position bias. +# It supports both of shifted and non-shifted window. + +# Args: +# dim (int): Number of input channels. +# window_size (tuple[int]): The height and width of the window. +# num_heads (int): Number of attention heads. +# qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True +# qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set +# attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 +# proj_drop (float, optional): Dropout ratio of output. Default: 0.0 +# """ + +# def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + +# super().__init__() +# self.dim = dim +# self.window_size = window_size # Wh, Ww +# self.num_heads = num_heads +# head_dim = dim // num_heads +# self.scale = qk_scale or head_dim ** -0.5 + +# # define a parameter table of relative position bias +# self.relative_position_bias_table = nn.Parameter( +# torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH + +# # get pair-wise relative position index for each token inside the window +# coords_h = torch.arange(self.window_size[0]) +# coords_w = torch.arange(self.window_size[1]) +# coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww +# coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww +# relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww +# relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 +# relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 +# relative_coords[:, :, 1] += self.window_size[1] - 1 +# relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 +# relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww +# self.register_buffer("relative_position_index", relative_position_index) + +# self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) +# self.attn_drop = nn.Dropout(attn_drop) +# self.proj = nn.Linear(dim, dim) +# self.proj_drop = nn.Dropout(proj_drop) + +# trunc_normal_(self.relative_position_bias_table, std=.02) +# self.softmax = nn.Softmax(dim=-1) + +# def forward(self, x, mask=None): +# """ Forward function. + +# Args: +# x: input features with shape of (num_windows*B, N, C) +# mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None +# """ +# B_, N, C = x.shape +# qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) +# q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + +# q = q * self.scale +# attn = (q @ k.transpose(-2, -1)) + +# relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( +# self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH +# relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww +# attn = attn + relative_position_bias.unsqueeze(0) + +# if mask is not None: +# nW = mask.shape[0] +# attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) +# attn = attn.view(-1, self.num_heads, N, N) +# attn = self.softmax(attn) +# else: +# attn = self.softmax(attn) + +# attn = self.attn_drop(attn) + +# x = (attn @ v).transpose(1, 2).reshape(B_, N, C) +# x = self.proj(x) +# x = self.proj_drop(x) + +# return x + +class SwinTransformerBlock(nn.Module): + """ Swin Transformer Block. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + mlp_fc2_bias (bool): Whether to add bias in fc2 of Mlp. Default: True + init_std: Initialization std. Default: 0.02 + pretrained_window_size (int): Window size in pre-training. + is_moe (bool): If True, this block is a MoE block. + num_local_experts (int): number of local experts in each device (GPU). Default: 1 + top_value (int): the value of k in top-k gating. Default: 1 + capacity_factor (float): the capacity factor in MoE. Default: 1.25 + cosine_router (bool): Whether to use cosine router. Default: False + normalize_gate (bool): Whether to normalize the gating score in top-k gating. Default: False + use_bpr (bool): Whether to use batch-prioritized-routing. Default: True + is_gshard_loss (bool): If True, use Gshard balance loss. + If False, use the load loss and importance loss in "arXiv:1701.06538". Default: False + gate_noise (float): the noise ratio in top-k gating. Default: 1.0 + cosine_router_dim (int): Projection dimension in cosine router. + cosine_router_init_t (float): Initialization temperature in cosine router. + moe_drop (float): Dropout rate in MoE. Default: 0.0 + """ + + def __init__(self, dim, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm, mlp_fc2_bias=True, init_std=0.02, pretrained_window_size=0, + + is_moe=True, num_local_experts=1, top_value=1, capacity_factor=1.25, cosine_router=False, + normalize_gate=False, use_bpr=True, is_gshard_loss=True, gate_noise=1.0, + cosine_router_dim=256, cosine_router_init_t=0.5, moe_drop=0.0): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + self.is_moe = is_moe + self.capacity_factor = capacity_factor + self.top_value = top_value + + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, + pretrained_window_size=to_2tuple(pretrained_window_size)) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + if self.is_moe: + self.mlp = MoEMlp(in_features=dim, + hidden_features=mlp_hidden_dim, + num_local_experts=num_local_experts, + top_value=top_value, + capacity_factor=capacity_factor, + cosine_router=cosine_router, + normalize_gate=normalize_gate, + use_bpr=use_bpr, + is_gshard_loss=is_gshard_loss, + gate_noise=gate_noise, + cosine_router_dim=cosine_router_dim, + cosine_router_init_t=cosine_router_init_t, + moe_drop=moe_drop, + mlp_fc2_bias=mlp_fc2_bias, + init_std=init_std) + else: + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop, + mlp_fc2_bias=mlp_fc2_bias) + + self.H = None + self.W = None + def forward(self, x, mask_matrix): + """ Forward function. + + Args: + x: Input feature, tensor size (B, H*W, C). + H, W: Spatial resolution of the input feature. + mask_matrix: Attention mask for cyclic shift. + """ + B, L, C = x.shape + H, W = self.H, self.W + assert L == H * W, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # pad feature maps to multiples of window size + pad_l = pad_t = 0 + pad_r = (self.window_size - W % self.window_size) % self.window_size + pad_b = (self.window_size - H % self.window_size) % self.window_size + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hp, Wp, _ = x.shape + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + attn_mask = mask_matrix + else: + shifted_x = x + attn_mask = None + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA + attn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, window_size*window_size, C + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + + # FFN + shortcut = x + x = self.norm2(x) + if self.is_moe: + x, l_aux = self.mlp(x) + x = shortcut + self.drop_path(x) + return x, l_aux + else: + x = shortcut + self.drop_path(self.mlp(x)) + + return x + + +class PatchMerging(nn.Module): + """ Patch Merging Layer + + Args: + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + def __init__(self, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x, H, W): + """ Forward function. + + Args: + x: Input feature, tensor size (B, H*W, C). + H, W: Spatial resolution of the input feature. + """ + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + + x = x.view(B, H, W, C) + + # padding + pad_input = (H % 2 == 1) or (W % 2 == 1) + if pad_input: + x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2)) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + mlp_fc2_bias (bool): Whether to add bias in fc2 of Mlp. Default: True + init_std: Initialization std. Default: 0.02 + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + pretrained_window_size (int): Local window size in pre-training. + moe_blocks (tuple(int)): The index of each MoE block. + num_local_experts (int): number of local experts in each device (GPU). Default: 1 + top_value (int): the value of k in top-k gating. Default: 1 + capacity_factor (float): the capacity factor in MoE. Default: 1.25 + cosine_router (bool): Whether to use cosine router Default: False + normalize_gate (bool): Whether to normalize the gating score in top-k gating. Default: False + use_bpr (bool): Whether to use batch-prioritized-routing. Default: True + is_gshard_loss (bool): If True, use Gshard balance loss. + If False, use the load loss and importance loss in "arXiv:1701.06538". Default: False + gate_noise (float): the noise ratio in top-k gating. Default: 1.0 + cosine_router_dim (int): Projection dimension in cosine router. + cosine_router_init_t (float): Initialization temperature in cosine router. + moe_drop (float): Dropout rate in MoE. Default: 0.0 + """ + def __init__(self, dim, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, + mlp_fc2_bias=True, init_std=0.02, use_checkpoint=False, pretrained_window_size=0, + moe_block=[-1], num_local_experts=1, top_value=1, capacity_factor=1.25, cosine_router=False, + normalize_gate=False, use_bpr=True, is_gshard_loss=True, + cosine_router_dim=256, cosine_router_init_t=0.5, gate_noise=1.0, moe_drop=0.0): + + super().__init__() + self.dim = dim + self.window_size = window_size + self.depth = depth + self.use_checkpoint = use_checkpoint + self.window_size=window_size + self.shift_size = window_size // 2 + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock(dim=dim, + num_heads=num_heads, window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer, + mlp_fc2_bias=mlp_fc2_bias, + init_std=init_std, + pretrained_window_size=pretrained_window_size, + + is_moe=True if i in moe_block else False, + num_local_experts=num_local_experts, + top_value=top_value, + capacity_factor=capacity_factor, + cosine_router=cosine_router, + normalize_gate=normalize_gate, + use_bpr=use_bpr, + is_gshard_loss=is_gshard_loss, + gate_noise=gate_noise, + cosine_router_dim=cosine_router_dim, + cosine_router_init_t=cosine_router_init_t, + moe_drop=moe_drop) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x, H, W): + """ Forward function. + + Args: + x: Input feature, tensor size (B, H*W, C). + H, W: Spatial resolution of the input feature. + """ + + # calculate attention mask for SW-MSA + Hp = int(np.ceil(H / self.window_size)) * self.window_size + Wp = int(np.ceil(W / self.window_size)) * self.window_size + img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + l_aux = 0.0 + + for blk in self.blocks: + blk.H, blk.W = H, W + if self.use_checkpoint: + out = checkpoint.checkpoint(blk, x, attn_mask) + else: + out = blk(x, attn_mask) + if isinstance(out, tuple): + x = out[0] + cur_l_aux = out[1] + l_aux = cur_l_aux + l_aux + else: + x = out + if self.downsample is not None: + x_down = self.downsample(x, H, W) + Wh, Ww = (H + 1) // 2, (W + 1) // 2 + return x, H, W, x_down, Wh, Ww, l_aux + else: + return x, H, W, x, H, W, l_aux + + +class PatchEmbed(nn.Module): + """ Image to Patch Embedding + + Args: + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + patch_size = to_2tuple(patch_size) + self.patch_size = patch_size + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + """Forward function.""" + # padding + _, _, H, W = x.size() + if W % self.patch_size[1] != 0: + x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1])) + if H % self.patch_size[0] != 0: + x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0])) + + x = self.proj(x) # B C Wh Ww + if self.norm is not None: + Wh, Ww = x.size(2), x.size(3) + x = x.flatten(2).transpose(1, 2) + x = self.norm(x) + x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww) + + return x + + +@BACKBONES.register_module() +class SwinTransformerMoE(nn.Module): + """ + Args: + img_size (int | tuple(int)): Input image size. Default 224 + patch_size (int | tuple(int)): Patch size. Default: 4 + in_chans (int): Number of input image channels. Default: 3 + num_classes (int): Number of classes for classification head. Default: 1000 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each Swin Transformer layer. + num_heads (tuple(int)): Number of attention heads in different layers. + window_size (int): Window size. Default: 7 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None + drop_rate (float): Dropout rate. Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. Default: True + mlp_fc2_bias (bool): Whether to add bias in fc2 of Mlp. Default: True + init_std: Initialization std. Default: 0.02 + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False + pretrained_window_sizes (tuple(int)): Pretrained window sizes of each layer. + moe_blocks (tuple(tuple(int))): The index of each MoE block in each layer. + num_local_experts (int): number of local experts in each device (GPU). Default: 1 + top_value (int): the value of k in top-k gating. Default: 1 + capacity_factor (float): the capacity factor in MoE. Default: 1.25 + cosine_router (bool): Whether to use cosine router Default: False + normalize_gate (bool): Whether to normalize the gating score in top-k gating. Default: False + use_bpr (bool): Whether to use batch-prioritized-routing. Default: True + is_gshard_loss (bool): If True, use Gshard balance loss. + If False, use the load loss and importance loss in "arXiv:1701.06538". Default: False + gate_noise (float): the noise ratio in top-k gating. Default: 1.0 + cosine_router_dim (int): Projection dimension in cosine router. + cosine_router_init_t (float): Initialization temperature in cosine router. + moe_drop (float): Dropout rate in MoE. Default: 0.0 + aux_loss_weight (float): auxiliary loss weight. Default: 0.1 + """ + + + def __init__(self, pretrain_img_size=224, patch_size=4, in_chans=3, + embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], + window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, + norm_layer=nn.LayerNorm, ape=False, patch_norm=True, + out_indices=(0, 1, 2, 3), + frozen_stages=-1, + use_checkpoint=False, + + mlp_fc2_bias=True, init_std=0.02, pretrained_window_sizes=[0, 0, 0, 0], + moe_blocks=[[-1], [-1], [-1], [-1]], num_local_experts=1, top_value=1, capacity_factor=1.25, + cosine_router=False, normalize_gate=False, use_bpr=True, is_gshard_loss=True, gate_noise=1.0, + cosine_router_dim=256, cosine_router_init_t=0.5, moe_drop=0.0, aux_loss_weight=0.01, **kwargs): + + super().__init__() + + self._ddp_params_and_buffers_to_ignore = list() + self.pretrain_img_size = pretrain_img_size + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.patch_norm = patch_norm + self.out_indices = out_indices + self.frozen_stages = frozen_stages + self.init_std = init_std + self.aux_loss_weight = aux_loss_weight + self.num_local_experts = num_local_experts + + # torch.distributed.init_process_group('nccl',world_size=1,rank=0) + + self.global_experts = num_local_experts * dist.get_world_size() if num_local_experts > 0 \ + else dist.get_world_size() // (-num_local_experts) + self.sharded_count = (1.0 / num_local_experts) if num_local_experts > 0 else (-num_local_experts) + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + + # absolute position embedding + if self.ape: + pretrain_img_size = to_2tuple(pretrain_img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [pretrain_img_size[0] // patch_size[0], pretrain_img_size[1] // patch_size[1]] + + self.absolute_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1])) + trunc_normal_(self.absolute_pos_embed, std=.02) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + # build layers + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = BasicLayer( + dim=int(embed_dim * 2 ** i_layer), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], + norm_layer=norm_layer, + downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, + use_checkpoint=use_checkpoint, + mlp_fc2_bias=mlp_fc2_bias, + init_std=init_std, + pretrained_window_size=pretrained_window_sizes[i_layer], + + moe_block=moe_blocks[i_layer], + num_local_experts=num_local_experts, + top_value=top_value, + capacity_factor=capacity_factor, + cosine_router=cosine_router, + normalize_gate=normalize_gate, + use_bpr=use_bpr, + is_gshard_loss=is_gshard_loss, + gate_noise=gate_noise, + cosine_router_dim=cosine_router_dim, + cosine_router_init_t=cosine_router_init_t, + moe_drop=moe_drop) + self.layers.append(layer) + num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)] + self.num_features = num_features + + # add a norm layer for each output + for i_layer in out_indices: + layer = norm_layer(num_features[i_layer]) + layer_name = f'norm{i_layer}' + self.add_module(layer_name, layer) + + self._freeze_stages() + + def _freeze_stages(self): + if self.frozen_stages >= 0: + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + + if self.frozen_stages >= 1 and self.ape: + self.absolute_pos_embed.requires_grad = False + + if self.frozen_stages >= 2: + self.pos_drop.eval() + for i in range(0, self.frozen_stages - 1): + m = self.layers[i] + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def init_weights(self, pretrained=None): + """Initialize the weights in backbone. + + Args: + pretrained (str, optional): Path to pre-trained weights. + Defaults to None. + """ + + def _init_weights(m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, MoEMlp): + m._init_weights() + + if isinstance(pretrained, str): + self.apply(_init_weights) + logger = get_root_logger() + load_checkpoint(self, pretrained, strict=False, logger=logger) + elif pretrained is None: + self.apply(_init_weights) + else: + raise TypeError('pretrained must be a str or None') + + def forward(self, x): + """Forward function.""" + x = self.patch_embed(x) + + Wh, Ww = x.size(2), x.size(3) + if self.ape: + # interpolate the position embedding to the corresponding size + absolute_pos_embed = F.interpolate(self.absolute_pos_embed, size=(Wh, Ww), mode='bicubic') + x = (x + absolute_pos_embed).flatten(2).transpose(1, 2) # B Wh*Ww C + else: + x = x.flatten(2).transpose(1, 2) + x = self.pos_drop(x) + l_aux = 0.0 + outs = [] + for i in range(self.num_layers): + layer = self.layers[i] + x_out, H, W, x, Wh, Ww,cur_l_aux = layer(x, Wh, Ww) + l_aux = cur_l_aux + l_aux + + if i in self.out_indices: + norm_layer = getattr(self, f'norm{i}') + x_out = norm_layer(x_out) + + out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous() + outs.append(out) + k=tuple(outs) + return tuple(outs),l_aux * self.aux_loss_weight + + def add_param_to_skip_allreduce(self, param_name): + self._ddp_params_and_buffers_to_ignore.append(param_name) + + def train(self, mode=True): + """Convert the model into training mode while keep layers freezed.""" + # TraceStack() + super(SwinTransformerMoE, self).train(mode) + self._freeze_stages() +import sys +def TraceStack(): + print ("--------------------") + frame=sys._getframe(1) + while frame: + print (frame.f_code.co_name), + print (frame.f_code.co_filename), + print (frame.f_lineno) + frame=frame.f_back \ No newline at end of file diff --git a/mmdet/models/detectors/two_stage.py b/mmdet/models/detectors/two_stage.py index ba5bdde980d..bf0270b485d 100644 --- a/mmdet/models/detectors/two_stage.py +++ b/mmdet/models/detectors/two_stage.py @@ -2,8 +2,9 @@ import torch.nn as nn # from mmdet.core import bbox2result, bbox2roi, build_assigner, build_sampler -from ..builder import DETECTORS, build_backbone, build_head, build_neck from .base import BaseDetector +from ..builder import DETECTORS, build_backbone, build_head, build_neck +import time @DETECTORS.register_module() @@ -80,8 +81,12 @@ def init_weights(self, pretrained=None): def extract_feat(self, img): """Directly extract features from the backbone+neck.""" x = self.backbone(img) + is_l_aux=x[-1] if self.with_neck: - x = self.neck(x) + if(type(is_l_aux)==float or len(is_l_aux.shape)<4): + x = self.neck(x[0]) + else : + x= self.neck(x) return x def forward_dummy(self, img): @@ -144,6 +149,7 @@ def forward_train(self, losses = dict() # RPN forward and loss + if self.with_rpn: proposal_cfg = self.train_cfg.get('rpn_proposal', self.test_cfg.rpn) @@ -161,9 +167,9 @@ def forward_train(self, roi_losses = self.roi_head.forward_train(x, img_metas, proposal_list, gt_bboxes, gt_labels, gt_bboxes_ignore, gt_masks, - **kwargs) - losses.update(roi_losses) + **kwargs) + losses.update(roi_losses) return losses async def async_simple_test(self,