Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update ViT for feature extraction, support elative positional embedding and layer scale #733

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions configs/vit/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,9 @@ Our reproduced model performance on ImageNet-1K is reported as follows.

| Model | Context | Top-1 (%) | Top-5 (%) | Params (M) | Recipe | Download |
|--------------|----------|-----------|-----------|------------|-----------------------------------------------------------------------------------------------|--------------------------------------------------------------------------------|
| vit_b_32_224 | D910x8-G | 75.86 | 92.08 | 87.46 | [yaml](https://github.com/mindspore-lab/mindcv/blob/main/configs/vit/vit_b32_224_ascend.yaml) | [weights](https://download.mindspore.cn/toolkits/mindcv/vit/vit_b_32_224-7553218f.ckpt) |
| vit_l_16_224 | D910x8-G | 76.34 | 92.79 | 303.31 | [yaml](https://github.com/mindspore-lab/mindcv/blob/main/configs/vit/vit_l16_224_ascend.yaml) | [weights](https://download.mindspore.cn/toolkits/mindcv/vit/vit_l_16_224-f02b2487.ckpt) |
| vit_l_32_224 | D910x8-G | 73.71 | 90.92 | 305.52 | [yaml](https://github.com/mindspore-lab/mindcv/blob/main/configs/vit/vit_l32_224_ascend.yaml) | [weights](https://download.mindspore.cn/toolkits/mindcv/vit/vit_l_32_224-3a961018.ckpt) |
| vit_b_32_224 | D910x8-G | 75.86 | 92.08 | 87.46 | [yaml](https://github.com/mindspore-lab/mindcv/blob/main/configs/vit/vit_b32_224_ascend.yaml) | [weights](https://download.mindspore.cn/toolkits/mindcv/vit/vit_b_32_224-f50866e8.ckpt) |
| vit_l_16_224 | D910x8-G | 76.34 | 92.79 | 303.31 | [yaml](https://github.com/mindspore-lab/mindcv/blob/main/configs/vit/vit_l16_224_ascend.yaml) | [weights](https://download.mindspore.cn/toolkits/mindcv/vit/vit_l_16_224-97d0fdbc.ckpt) |
| vit_l_32_224 | D910x8-G | 73.71 | 90.92 | 305.52 | [yaml](https://github.com/mindspore-lab/mindcv/blob/main/configs/vit/vit_l32_224_ascend.yaml) | [weights](https://download.mindspore.cn/toolkits/mindcv/vit/vit_l_32_224-b80441df.ckpt) |

</div>

Expand Down
51 changes: 51 additions & 0 deletions mindcv/models/layers/pos_embed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
"""positional embedding"""
from typing import Tuple

import numpy as np

import mindspore as ms
from mindspore import Parameter, Tensor, nn, ops


class RelativePositionBiasWithCLS(nn.Cell):
def __init__(
self,
window_size: Tuple[int],
num_heads: int
):
super(RelativePositionBiasWithCLS, self).__init__()
self.window_size = window_size
self.num_tokens = window_size[0] * window_size[1]

num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
# 3: cls to token, token to cls, cls to cls
self.relative_position_bias_table = Parameter(
Tensor(np.zeros((num_relative_distance, num_heads)), dtype=ms.float16)
)
coords_h = np.arange(window_size[0]).reshape(window_size[0], 1).repeat(window_size[1], 1).reshape(1, -1)
coords_w = np.arange(window_size[1]).reshape(1, window_size[1]).repeat(window_size[0], 0).reshape(1, -1)
coords_flatten = np.concatenate([coords_h, coords_w], axis=0) # [2, Wh * Ww]

relative_coords = coords_flatten[:, :, np.newaxis] - coords_flatten[:, np.newaxis, :] # [2, Wh * Ww, Wh * Ww]
relative_coords = relative_coords.transpose(1, 2, 0) # [Wh * Ww, Wh * Ww, 2]
relative_coords[:, :, 0] += window_size[0] - 1
relative_coords[:, :, 1] += window_size[1] - 1
relative_coords[:, :, 0] *= 2 * window_size[0] - 1

relative_position_index = np.zeros((self.num_tokens + 1, self.num_tokens + 1),
dtype=relative_coords.dtype) # [Wh * Ww + 1, Wh * Ww + 1]
relative_position_index[1:, 1:] = relative_coords.sum(-1)
relative_position_index[0, 0:] = num_relative_distance - 3
relative_position_index[0:, 0] = num_relative_distance - 2
relative_position_index[0, 0] = num_relative_distance - 1
relative_position_index = Tensor(relative_position_index.reshape(-1))

self.one_hot = nn.OneHot(axis=-1, depth=num_relative_distance, dtype=ms.float16)
self.relative_position_index = Parameter(self.one_hot(relative_position_index), requires_grad=False)

def construct(self):
out = ops.matmul(self.relative_position_index, self.relative_position_bias_table)
out = ops.reshape(out, (self.num_tokens + 1, self.num_tokens + 1, -1))
out = ops.transpose(out, (2, 0, 1))
out = ops.expand_dims(out, 0)
return out
Loading