Skip to content

Commit

Permalink
add VitSAM
Browse files Browse the repository at this point in the history
  • Loading branch information
sageyou committed Oct 30, 2023
1 parent 5c87ac5 commit 078ef5b
Show file tree
Hide file tree
Showing 6 changed files with 621 additions and 27 deletions.
3 changes: 3 additions & 0 deletions mindcv/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
vgg,
visformer,
vit,
vit_sam,
volo,
xception,
xcit,
Expand Down Expand Up @@ -109,6 +110,7 @@
from .vgg import *
from .visformer import *
from .vit import *
from .vit_sam import *
from .volo import *
from .xception import *
from .xcit import *
Expand Down Expand Up @@ -168,6 +170,7 @@
__all__.extend(vgg.__all__)
__all__.extend(visformer.__all__)
__all__.extend(vit.__all__)
__all__.extend(vit_sam.__all__)
__all__.extend(volo.__all__)
__all__.extend(["Xception", "xception"])
__all__.extend(xcit.__all__)
4 changes: 0 additions & 4 deletions mindcv/models/layers/format.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from enum import Enum
from typing import Union

import mindspore

Expand All @@ -11,9 +10,6 @@ class Format(str, Enum):
NLC = 'NLC'


FormatT = Union[str, Format]


def nchw_to(x: mindspore.Tensor, fmt: Format):
if fmt == Format.NHWC:
x = x.permute(0, 2, 3, 1)
Expand Down
26 changes: 13 additions & 13 deletions mindcv/models/layers/patch_dropout.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,41 +8,41 @@ class PatchDropout(nn.Cell):
"""
https://arxiv.org/abs/2212.00794
"""

def __init__(
self,
prob: float = 0.5,
num_prefix_tokens: int = 1,
ordered: bool = False,
return_indices: bool = False,
self,
prob: float = 0.5,
num_prefix_tokens: int = 1,
ordered: bool = False,
return_indices: bool = False,
):
super().__init__()
assert 0 <= prob < 1.
assert 0 <= prob < 1.0
self.prob = prob
self.num_prefix_tokens = num_prefix_tokens # exclude CLS token (or other prefix tokens)
self.ordered = ordered
self.return_indices = return_indices
self.sort = ops.Sort()

def forward(self, x):
if not self.training or self.prob == 0.:
def construct(self, x):
if not self.training or self.prob == 0.0:
if self.return_indices:
return x, None
return x

if self.num_prefix_tokens:
prefix_tokens, x = x[:, :self.num_prefix_tokens], x[:, self.num_prefix_tokens:]
prefix_tokens, x = x[:, : self.num_prefix_tokens], x[:, self.num_prefix_tokens :]
else:
prefix_tokens = None

B = x.shape[0]
L = x.shape[1]
num_keep = max(1, int(L * (1. - self.prob)))
_, indices = self.sort(ms.Tensor(np.random.rand(B, L)).astype(ms.float32))
num_keep = max(1, int(L * (1.0 - self.prob)))
_, indices = ops.sort(ms.Tensor(np.random.rand(B, L)).astype(ms.float32))
keep_indices = indices[:, :num_keep]
if self.ordered:
# NOTE does not need to maintain patch order in typical transformer use,
# but possibly useful for debug / visualization
keep_indices, _ = self.sort(keep_indices)
keep_indices, _ = ops.sort(keep_indices)
keep_indices = ops.broadcast_to(ops.expand_dims(keep_indices, axis=-1), (-1, -1, x.shape[2]))
x = ops.gather_elements(x, dim=1, index=keep_indices)

Expand Down
11 changes: 6 additions & 5 deletions mindcv/models/layers/patch_embed.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ class PatchEmbed(nn.Cell):
embed_dim (int): Number of linear projection output channels. Default: 96.
norm_layer (nn.Cell, optional): Normalization layer. Default: None
"""

output_fmt: Format

def __init__(
Expand All @@ -37,11 +38,11 @@ def __init__(
self.patch_size = to_2tuple(patch_size)
if image_size is not None:
self.image_size = to_2tuple(image_size)
self.patches_resolution = tuple([s // p for s, p in zip(self.image_size, self.patch_size)])
self.num_patches = self.patches_resolution[0] * self.patches_resolution[1]
self.grid_size = tuple([s // p for s, p in zip(self.image_size, self.patch_size)])
self.num_patches = self.grid_size[0] * self.grid_size[1]
else:
self.image_size = None
self.patches_resolution = None
self.grid_size = None
self.num_patches = None

if output_fmt is not None:
Expand Down Expand Up @@ -86,8 +87,8 @@ def construct(self, x: Tensor) -> Tensor:
# FIXME look at relaxing size constraints
x = self.proj(x)
if self.flatten:
x = ops.Reshape()(x, (B, self.embed_dim, -1)) # B Ph*Pw C
x = ops.Transpose()(x, (0, 2, 1))
x = ops.reshape(x, (B, self.embed_dim, -1)) # B Ph*Pw C
x = ops.transpose(x, (0, 2, 1))
elif self.output_fmt != "NCHW":
x = nchw_to(x, self.output_fmt)
if self.norm is not None:
Expand Down
10 changes: 5 additions & 5 deletions mindcv/models/layers/pos_embed.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@


def resample_abs_pos_embed(
posemb,
new_size: List[int],
old_size: Optional[List[int]] = None,
num_prefix_tokens: int = 1,
interpolation: str = 'nearest',
posemb,
new_size: List[int],
old_size: Optional[List[int]] = None,
num_prefix_tokens: int = 1,
interpolation: str = 'nearest',
):
# sort out sizes, assume square if old size not provided
num_pos_tokens = posemb.shape[1]
Expand Down
Loading

0 comments on commit 078ef5b

Please sign in to comment.