-
Notifications
You must be signed in to change notification settings - Fork 144
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
6 changed files
with
763 additions
and
16 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
from enum import Enum | ||
|
||
import mindspore | ||
|
||
|
||
class Format(str, Enum): | ||
NCHW = 'NCHW' | ||
NHWC = 'NHWC' | ||
NCL = 'NCL' | ||
NLC = 'NLC' | ||
|
||
|
||
def nchw_to(x: mindspore.Tensor, fmt: Format): | ||
if fmt == Format.NHWC: | ||
x = x.permute(0, 2, 3, 1) | ||
elif fmt == Format.NLC: | ||
x = x.flatten(start_dim=2).transpose((0, 2, 1)) | ||
elif fmt == Format.NCL: | ||
x = x.flatten(start_dim=2) | ||
return x | ||
|
||
|
||
def nhwc_to(x: mindspore.Tensor, fmt: Format): | ||
if fmt == Format.NCHW: | ||
x = x.permute(0, 3, 1, 2) | ||
elif fmt == Format.NLC: | ||
x = x.flatten(start_dim=1, end_dim=2) | ||
elif fmt == Format.NCL: | ||
x = x.flatten(start_dim=1, end_dim=2).transpose((0, 2, 1)) | ||
return x |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
import numpy as np | ||
|
||
import mindspore as ms | ||
from mindspore import nn, ops | ||
|
||
|
||
class PatchDropout(nn.Cell): | ||
""" | ||
https://arxiv.org/abs/2212.00794 | ||
""" | ||
def __init__( | ||
self, | ||
prob: float = 0.5, | ||
num_prefix_tokens: int = 1, | ||
ordered: bool = False, | ||
return_indices: bool = False, | ||
): | ||
super().__init__() | ||
assert 0 <= prob < 1. | ||
self.prob = prob | ||
self.num_prefix_tokens = num_prefix_tokens # exclude CLS token (or other prefix tokens) | ||
self.ordered = ordered | ||
self.return_indices = return_indices | ||
self.sort = ops.Sort() | ||
|
||
def forward(self, x): | ||
if not self.training or self.prob == 0.: | ||
if self.return_indices: | ||
return x, None | ||
return x | ||
|
||
if self.num_prefix_tokens: | ||
prefix_tokens, x = x[:, :self.num_prefix_tokens], x[:, self.num_prefix_tokens:] | ||
else: | ||
prefix_tokens = None | ||
|
||
B = x.shape[0] | ||
L = x.shape[1] | ||
num_keep = max(1, int(L * (1. - self.prob))) | ||
_, indices = self.sort(ms.Tensor(np.random.rand(B, L)).astype(ms.float32)) | ||
keep_indices = indices[:, :num_keep] | ||
if self.ordered: | ||
# NOTE does not need to maintain patch order in typical transformer use, | ||
# but possibly useful for debug / visualization | ||
keep_indices, _ = self.sort(keep_indices) | ||
keep_indices = ops.broadcast_to(ops.expand_dims(keep_indices, axis=-1), (-1, -1, x.shape[2])) | ||
x = ops.gather_elements(x, dim=1, index=keep_indices) | ||
|
||
if prefix_tokens is not None: | ||
x = ops.concat((prefix_tokens, x), axis=1) | ||
|
||
if self.return_indices: | ||
return x, keep_indices | ||
return x |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.