Commit: add VitSAM
sageyou committed Oct 24, 2023
1 parent d714673 commit 76a9170
Showing 6 changed files with 763 additions and 16 deletions.
3 changes: 3 additions & 0 deletions mindcv/models/__init__.py
@@ -51,6 +51,7 @@
vgg,
visformer,
vit,
vit_sam,
volo,
xception,
xcit,
@@ -107,6 +108,7 @@
from .vgg import *
from .visformer import *
from .vit import *
from .vit_sam import *
from .volo import *
from .xception import *
from .xcit import *
@@ -165,6 +167,7 @@
__all__.extend(vgg.__all__)
__all__.extend(visformer.__all__)
__all__.extend(vit.__all__)
__all__.extend(vit_sam.__all__)
__all__.extend(volo.__all__)
__all__.extend(["Xception", "xception"])
__all__.extend(xcit.__all__)
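
The registry hookup above is what makes the new ViT-SAM variants reachable through mindcv's factory functions. A minimal sketch of that usage, assuming mindcv's public list_models/create_model API; the concrete model names are defined in mindcv/models/vit_sam.py, which is not shown in this diff:

import mindcv

# Any model exported via vit_sam.__all__ is now registered; filter by name to find the variants.
sam_names = [n for n in mindcv.list_models() if "sam" in n.lower()]
print(sam_names)
if sam_names:
    # hypothetical: instantiate the first registered ViT-SAM variant
    model = mindcv.create_model(sam_names[0], num_classes=1000)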
14 changes: 13 additions & 1 deletion mindcv/models/layers/__init__.py
@@ -1,9 +1,21 @@
"""layers init"""
from . import activation, conv_norm_act, drop_path, identity, pooling, selective_kernel, squeeze_excite
from . import (
activation,
conv_norm_act,
drop_path,
format,
identity,
patch_dropout,
pooling,
selective_kernel,
squeeze_excite,
)
from .activation import *
from .conv_norm_act import *
from .drop_path import *
from .format import *
from .identity import *
from .patch_dropout import *
from .pooling import *
from .selective_kernel import *
from .squeeze_excite import *
30 changes: 30 additions & 0 deletions mindcv/models/layers/format.py
@@ -0,0 +1,30 @@
from enum import Enum

import mindspore


class Format(str, Enum):
NCHW = 'NCHW'
NHWC = 'NHWC'
NCL = 'NCL'
NLC = 'NLC'


def nchw_to(x: mindspore.Tensor, fmt: Format):
if fmt == Format.NHWC:
x = x.permute(0, 2, 3, 1)
elif fmt == Format.NLC:
x = x.flatten(start_dim=2).transpose((0, 2, 1))
elif fmt == Format.NCL:
x = x.flatten(start_dim=2)
return x


def nhwc_to(x: mindspore.Tensor, fmt: Format):
if fmt == Format.NCHW:
x = x.permute(0, 3, 1, 2)
elif fmt == Format.NLC:
x = x.flatten(start_dim=1, end_dim=2)
elif fmt == Format.NCL:
x = x.flatten(start_dim=1, end_dim=2).transpose((0, 2, 1))
return x
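
For reference, a minimal usage sketch of the helpers above, assuming a MindSpore 2.x runtime (the module relies on Tensor.permute and Tensor.flatten(start_dim=...)) and the import path mindcv.models.layers.format:

import numpy as np
import mindspore as ms

from mindcv.models.layers.format import Format, nchw_to, nhwc_to

x = ms.Tensor(np.random.rand(2, 64, 14, 14), ms.float32)  # NCHW feature map
tokens = nchw_to(x, Format.NLC)                           # (2, 196, 64): token sequence
nhwc = nchw_to(x, Format.NHWC)                            # (2, 14, 14, 64)
back = nhwc_to(nhwc, Format.NCHW)                         # (2, 64, 14, 14): round trip
print(tokens.shape, nhwc.shape, back.shape)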
54 changes: 54 additions & 0 deletions mindcv/models/layers/patch_dropout.py
@@ -0,0 +1,54 @@
import numpy as np

import mindspore as ms
from mindspore import nn, ops


class PatchDropout(nn.Cell):
"""
https://arxiv.org/abs/2212.00794
"""
def __init__(
self,
prob: float = 0.5,
num_prefix_tokens: int = 1,
ordered: bool = False,
return_indices: bool = False,
):
super().__init__()
assert 0 <= prob < 1.
self.prob = prob
self.num_prefix_tokens = num_prefix_tokens # exclude CLS token (or other prefix tokens)
self.ordered = ordered
self.return_indices = return_indices
self.sort = ops.Sort()

def construct(self, x):
if not self.training or self.prob == 0.:
if self.return_indices:
return x, None
return x

if self.num_prefix_tokens:
prefix_tokens, x = x[:, :self.num_prefix_tokens], x[:, self.num_prefix_tokens:]
else:
prefix_tokens = None

B = x.shape[0]
L = x.shape[1]
num_keep = max(1, int(L * (1. - self.prob)))
_, indices = self.sort(ms.Tensor(np.random.rand(B, L)).astype(ms.float32))
keep_indices = indices[:, :num_keep]
if self.ordered:
# NOTE does not need to maintain patch order in typical transformer use,
# but possibly useful for debug / visualization
keep_indices, _ = self.sort(keep_indices)
keep_indices = ops.broadcast_to(ops.expand_dims(keep_indices, axis=-1), (-1, -1, x.shape[2]))
x = ops.gather_elements(x, dim=1, index=keep_indices)

if prefix_tokens is not None:
x = ops.concat((prefix_tokens, x), axis=1)

if self.return_indices:
return x, keep_indices
return x
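
A minimal usage sketch of PatchDropout on a ViT-style token sequence; the shapes and import path are assumptions, and token dropout is only active in training mode:

import numpy as np
import mindspore as ms

from mindcv.models.layers.patch_dropout import PatchDropout

tokens = ms.Tensor(np.random.rand(2, 1 + 196, 768), ms.float32)  # [CLS] + 14x14 patch tokens
patch_drop = PatchDropout(prob=0.25, num_prefix_tokens=1)
patch_drop.set_train(True)   # dropout is skipped entirely in eval mode or when prob == 0
out = patch_drop(tokens)
print(out.shape)             # (2, 148, 768): the CLS token plus 147 of 196 kept patch tokens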
65 changes: 50 additions & 15 deletions mindcv/models/layers/patch_embed.py
@@ -4,6 +4,7 @@

from mindspore import Tensor, nn, ops

from .format import Format, nchw_to
from .helpers import to_2tuple


@@ -17,29 +18,45 @@ class PatchEmbed(nn.Cell):
embed_dim (int): Number of linear projection output channels. Default: 96.
norm_layer (nn.Cell, optional): Normalization layer. Default: None
"""
output_fmt: Format

def __init__(
self,
image_size: int = 224,
image_size: Optional[int] = 224,
patch_size: int = 4,
in_chans: int = 3,
embed_dim: int = 96,
norm_layer: Optional[nn.Cell] = None,
flatten: bool = True,
output_fmt: Optional[str] = None,
bias: bool = True,
strict_img_size: bool = True,
dynamic_img_pad: bool = False,
) -> None:
super().__init__()
image_size = to_2tuple(image_size)
patch_size = to_2tuple(patch_size)
patches_resolution = [image_size[0] // patch_size[0], image_size[1] // patch_size[1]]
self.image_size = image_size
self.patch_size = patch_size
self.patches_resolution = patches_resolution
self.num_patches = patches_resolution[0] * patches_resolution[1]

self.in_chans = in_chans
self.patch_size = to_2tuple(patch_size)
if image_size is not None:
self.image_size = to_2tuple(image_size)
self.grid_size = tuple([s // p for s, p in zip(self.image_size, self.patch_size)])
self.num_patches = self.grid_size[0] * self.grid_size[1]
else:
self.image_size = None
self.grid_size = None
self.num_patches = None

if output_fmt is not None:
self.flatten = False
self.output_fmt = Format(output_fmt)
else:
self.flatten = flatten
self.output_fmt = Format.NCHW

self.strict_img_size = strict_img_size
self.dynamic_img_pad = dynamic_img_pad
self.embed_dim = embed_dim

self.proj = nn.Conv2d(in_channels=in_chans, out_channels=embed_dim, kernel_size=patch_size, stride=patch_size,
pad_mode='pad', has_bias=True, weight_init="TruncatedNormal")
pad_mode='pad', has_bias=bias, weight_init="TruncatedNormal")

if norm_layer is not None:
if isinstance(embed_dim, int):
@@ -50,11 +67,29 @@ def __init__(

def construct(self, x: Tensor) -> Tensor:
"""docstring"""
B = x.shape[0]
# FIXME look at relaxing size constraints
x = ops.Reshape()(self.proj(x), (B, self.embed_dim, -1)) # B Ph*Pw C
x = ops.Transpose()(x, (0, 2, 1))
B, C, H, W = x.shape
if self.image_size is not None:
if self.strict_img_size:
if (H, W) != (self.image_size[0], self.image_size[1]):
raise ValueError(f"Input height and width ({H},{W}) doesn't match model ({self.image_size[0]},"
f"{self.image_size[1]}).")
elif not self.dynamic_img_pad:
if H % self.patch_size[0] != 0:
raise ValueError(f"Input height ({H}) should be divisible by patch size ({self.patch_size[0]}).")
if W % self.patch_size[1] != 0:
raise ValueError(f"Input width ({W}) should be divisible by patch size ({self.patch_size[1]}).")
if self.dynamic_img_pad:
pad_h = (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0]
pad_w = (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1]
x = ops.pad(x, (0, pad_w, 0, pad_h))

# FIXME look at relaxing size constraints
x = self.proj(x)
if self.flatten:
x = ops.Reshape()(x, (B, self.embed_dim, -1)) # B Ph*Pw C
x = ops.Transpose()(x, (0, 2, 1))
elif self.output_fmt != "NCHW":
x = nchw_to(x, self.output_fmt)
if self.norm is not None:
x = self.norm(x)
return x
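
A minimal sketch of the new PatchEmbed options, assuming the import path mindcv.models.layers.patch_embed and a MindSpore 2.x runtime: output_fmt keeps the projected patch grid in spatial layout instead of flattening it, and dynamic_img_pad pads inputs whose size is not a multiple of the patch size:

import numpy as np
import mindspore as ms

from mindcv.models.layers.patch_embed import PatchEmbed

# NHWC patch grid, as used by window-attention backbones such as ViT-SAM
embed = PatchEmbed(image_size=224, patch_size=16, in_chans=3, embed_dim=768, output_fmt="NHWC")
x = ms.Tensor(np.random.rand(1, 3, 224, 224), ms.float32)
print(embed(x).shape)        # (1, 14, 14, 768)

# Variable input size: skip the strict size check and pad up to a multiple of the patch size
embed_dyn = PatchEmbed(image_size=None, patch_size=16, in_chans=3, embed_dim=768,
                       strict_img_size=False, dynamic_img_pad=True)
y = ms.Tensor(np.random.rand(1, 3, 200, 200), ms.float32)
print(embed_dyn(y).shape)    # (1, 169, 768): 200x200 padded to 208x208 -> 13x13 patches, NLC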
