-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implementation of a Masked Autoencoder for representation learning (#…
…8152) This follows a previous PR (#7598). In the previous PR, the official implementation was under a non-compatible license. This is a clean-sheet implementation I developed. The code is fairly straightforward, involving a transformer, encoder, and decoder. The primary changes are in how masks are selected and how patches are organized as they pass through the model. In the official masked autoencoder implementation, noise is first generated and then sorted twice using `torch.argsort`. This rearranges the tokens and identifies which ones are retained, ultimately selecting only a subset of the shuffled indices. In our implementation, we use `torch.multinomial` to generate mask indices, followed by simple boolean indexing to manage the sub-selection of patches for encoding and the reordering with mask tokens in the decoder. **Let me know if you need a detailed, line-by-line explanation of the new code, including how it works and how it differs from the previous version.** ### Description Implementation of the Masked Autoencoder as described in the paper: [Masked Autoencoders Are Scalable Vision Learners](https://arxiv.org/pdf/2111.06377.pdf) from Kaiming et al. Its effectiveness has already been demonstrated in the literature for medical tasks in the paper [Self Pre-training with Masked Autoencoders for Medical Image Classification and Segmentation](https://arxiv.org/abs/2203.05573). The PR contains the architecture and associated unit tests. **Note:** The output includes the prediction, which is a tensor of size: ($BS$, $N_{tokens}$, $D$), and the associated mask ($BS$, $N_{tokens}$). The mask is used to apply loss only to masked patches, but I'm not sure it's the “best” output format, what do you think? ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [x] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Lucas Robinet <[email protected]> Signed-off-by: Lucas Robinet <[email protected]> Co-authored-by: YunLiu <[email protected]>
- Loading branch information
1 parent
e73257c
commit 20372f0
Showing
4 changed files
with
377 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,211 @@ | ||
# Copyright (c) MONAI Consortium | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from __future__ import annotations | ||
|
||
from collections.abc import Sequence | ||
|
||
import numpy as np | ||
import torch | ||
import torch.nn as nn | ||
|
||
from monai.networks.blocks.patchembedding import PatchEmbeddingBlock | ||
from monai.networks.blocks.pos_embed_utils import build_sincos_position_embedding | ||
from monai.networks.blocks.transformerblock import TransformerBlock | ||
from monai.networks.layers import trunc_normal_ | ||
from monai.utils import ensure_tuple_rep | ||
from monai.utils.module import look_up_option | ||
|
||
SUPPORTED_POS_EMBEDDING_TYPES = {"none", "learnable", "sincos"} | ||
|
||
__all__ = ["MaskedAutoEncoderViT"] | ||
|
||
|
||
class MaskedAutoEncoderViT(nn.Module): | ||
""" | ||
Masked Autoencoder (ViT), based on: "Kaiming et al., | ||
Masked Autoencoders Are Scalable Vision Learners <https://arxiv.org/abs/2111.06377>" | ||
Only a subset of the patches passes through the encoder. The decoder tries to reconstruct | ||
the masked patches, resulting in improved training speed. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
in_channels: int, | ||
img_size: Sequence[int] | int, | ||
patch_size: Sequence[int] | int, | ||
hidden_size: int = 768, | ||
mlp_dim: int = 512, | ||
num_layers: int = 12, | ||
num_heads: int = 12, | ||
masking_ratio: float = 0.75, | ||
decoder_hidden_size: int = 384, | ||
decoder_mlp_dim: int = 512, | ||
decoder_num_layers: int = 4, | ||
decoder_num_heads: int = 12, | ||
proj_type: str = "conv", | ||
pos_embed_type: str = "sincos", | ||
decoder_pos_embed_type: str = "sincos", | ||
dropout_rate: float = 0.0, | ||
spatial_dims: int = 3, | ||
qkv_bias: bool = False, | ||
save_attn: bool = False, | ||
) -> None: | ||
""" | ||
Args: | ||
in_channels: dimension of input channels or the number of channels for input. | ||
img_size: dimension of input image. | ||
patch_size: dimension of patch size | ||
hidden_size: dimension of hidden layer. Defaults to 768. | ||
mlp_dim: dimension of feedforward layer. Defaults to 512. | ||
num_layers: number of transformer blocks. Defaults to 12. | ||
num_heads: number of attention heads. Defaults to 12. | ||
masking_ratio: ratio of patches to be masked. Defaults to 0.75. | ||
decoder_hidden_size: dimension of hidden layer for decoder. Defaults to 384. | ||
decoder_mlp_dim: dimension of feedforward layer for decoder. Defaults to 512. | ||
decoder_num_layers: number of transformer blocks for decoder. Defaults to 4. | ||
decoder_num_heads: number of attention heads for decoder. Defaults to 12. | ||
proj_type: position embedding layer type. Defaults to "conv". | ||
pos_embed_type: position embedding layer type. Defaults to "sincos". | ||
decoder_pos_embed_type: position embedding layer type for decoder. Defaults to "sincos". | ||
dropout_rate: fraction of the input units to drop. Defaults to 0.0. | ||
spatial_dims: number of spatial dimensions. Defaults to 3. | ||
qkv_bias: apply bias to the qkv linear layer in self attention block. Defaults to False. | ||
save_attn: to make accessible the attention in self attention block. Defaults to False. | ||
Examples:: | ||
# for single channel input with image size of (96,96,96), and sin-cos positional encoding | ||
>>> net = MaskedAutoEncoderViT(in_channels=1, img_size=(96,96,96), patch_size=(16,16,16), | ||
pos_embed_type='sincos') | ||
# for 3-channel with image size of (128,128,128) and a learnable positional encoding | ||
>>> net = MaskedAutoEncoderViT(in_channels=3, img_size=128, patch_size=16, pos_embed_type='learnable') | ||
# for 3-channel with image size of (224,224) and a masking ratio of 0.25 | ||
>>> net = MaskedAutoEncoderViT(in_channels=3, img_size=(224,224), patch_size=(16,16), masking_ratio=0.25, | ||
spatial_dims=2) | ||
""" | ||
|
||
super().__init__() | ||
|
||
if not (0 <= dropout_rate <= 1): | ||
raise ValueError(f"dropout_rate should be between 0 and 1, got {dropout_rate}.") | ||
|
||
if hidden_size % num_heads != 0: | ||
raise ValueError("hidden_size should be divisible by num_heads.") | ||
|
||
if decoder_hidden_size % decoder_num_heads != 0: | ||
raise ValueError("decoder_hidden_size should be divisible by decoder_num_heads.") | ||
|
||
self.patch_size = ensure_tuple_rep(patch_size, spatial_dims) | ||
self.img_size = ensure_tuple_rep(img_size, spatial_dims) | ||
self.spatial_dims = spatial_dims | ||
for m, p in zip(self.img_size, self.patch_size): | ||
if m % p != 0: | ||
raise ValueError(f"patch_size={patch_size} should be divisible by img_size={img_size}.") | ||
|
||
self.decoder_hidden_size = decoder_hidden_size | ||
|
||
if masking_ratio <= 0 or masking_ratio >= 1: | ||
raise ValueError(f"masking_ratio should be in the range (0, 1), got {masking_ratio}.") | ||
|
||
self.masking_ratio = masking_ratio | ||
self.cls_token = nn.Parameter(torch.zeros(1, 1, hidden_size)) | ||
|
||
self.patch_embedding = PatchEmbeddingBlock( | ||
in_channels=in_channels, | ||
img_size=img_size, | ||
patch_size=patch_size, | ||
hidden_size=hidden_size, | ||
num_heads=num_heads, | ||
proj_type=proj_type, | ||
pos_embed_type=pos_embed_type, | ||
dropout_rate=dropout_rate, | ||
spatial_dims=self.spatial_dims, | ||
) | ||
blocks = [ | ||
TransformerBlock(hidden_size, mlp_dim, num_heads, dropout_rate, qkv_bias, save_attn) | ||
for _ in range(num_layers) | ||
] | ||
self.blocks = nn.Sequential(*blocks, nn.LayerNorm(hidden_size)) | ||
|
||
# decoder | ||
self.decoder_embed = nn.Linear(hidden_size, decoder_hidden_size) | ||
|
||
self.mask_tokens = nn.Parameter(torch.zeros(1, 1, decoder_hidden_size)) | ||
|
||
self.decoder_pos_embed_type = look_up_option(decoder_pos_embed_type, SUPPORTED_POS_EMBEDDING_TYPES) | ||
self.decoder_pos_embedding = nn.Parameter(torch.zeros(1, self.patch_embedding.n_patches, decoder_hidden_size)) | ||
|
||
decoder_blocks = [ | ||
TransformerBlock(decoder_hidden_size, decoder_mlp_dim, decoder_num_heads, dropout_rate, qkv_bias, save_attn) | ||
for _ in range(decoder_num_layers) | ||
] | ||
self.decoder_blocks = nn.Sequential(*decoder_blocks, nn.LayerNorm(decoder_hidden_size)) | ||
self.decoder_pred = nn.Linear(decoder_hidden_size, int(np.prod(self.patch_size)) * in_channels) | ||
|
||
self._init_weights() | ||
|
||
def _init_weights(self): | ||
""" | ||
similar to monai/networks/blocks/patchembedding.py for the decoder positional encoding and for mask and | ||
classification tokens | ||
""" | ||
if self.decoder_pos_embed_type == "none": | ||
pass | ||
elif self.decoder_pos_embed_type == "learnable": | ||
trunc_normal_(self.decoder_pos_embedding, mean=0.0, std=0.02, a=-2.0, b=2.0) | ||
elif self.decoder_pos_embed_type == "sincos": | ||
grid_size = [] | ||
for in_size, pa_size in zip(self.img_size, self.patch_size): | ||
grid_size.append(in_size // pa_size) | ||
|
||
self.decoder_pos_embedding = build_sincos_position_embedding( | ||
grid_size, self.decoder_hidden_size, self.spatial_dims | ||
) | ||
|
||
else: | ||
raise ValueError(f"decoder_pos_embed_type {self.decoder_pos_embed_type} not supported.") | ||
|
||
# initialize patch_embedding like nn.Linear (instead of nn.Conv2d) | ||
trunc_normal_(self.mask_tokens, mean=0.0, std=0.02, a=-2.0, b=2.0) | ||
trunc_normal_(self.cls_token, mean=0.0, std=0.02, a=-2.0, b=2.0) | ||
|
||
def _masking(self, x, masking_ratio: float | None = None): | ||
batch_size, num_tokens, _ = x.shape | ||
percentage_to_keep = 1 - masking_ratio if masking_ratio is not None else 1 - self.masking_ratio | ||
selected_indices = torch.multinomial( | ||
torch.ones(batch_size, num_tokens), int(percentage_to_keep * num_tokens), replacement=False | ||
) | ||
x_masked = x[torch.arange(batch_size).unsqueeze(1), selected_indices] # gather the selected tokens | ||
mask = torch.ones(batch_size, num_tokens, dtype=torch.int).to(x.device) | ||
mask[torch.arange(batch_size).unsqueeze(-1), selected_indices] = 0 | ||
|
||
return x_masked, selected_indices, mask | ||
|
||
def forward(self, x, masking_ratio: float | None = None): | ||
x = self.patch_embedding(x) | ||
x, selected_indices, mask = self._masking(x, masking_ratio=masking_ratio) | ||
|
||
cls_tokens = self.cls_token.expand(x.shape[0], -1, -1) | ||
x = torch.cat((cls_tokens, x), dim=1) | ||
|
||
x = self.blocks(x) | ||
|
||
# decoder | ||
x = self.decoder_embed(x) | ||
|
||
x_ = self.mask_tokens.repeat(x.shape[0], mask.shape[1], 1) | ||
x_[torch.arange(x.shape[0]).unsqueeze(-1), selected_indices] = x[:, 1:, :] # no cls token | ||
x_ = x_ + self.decoder_pos_embedding | ||
x = torch.cat([x[:, :1, :], x_], dim=1) | ||
x = self.decoder_blocks(x) | ||
x = self.decoder_pred(x) | ||
|
||
x = x[:, 1:, :] | ||
return x, mask |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,160 @@ | ||
# Copyright (c) MONAI Consortium | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from __future__ import annotations | ||
|
||
import unittest | ||
|
||
import torch | ||
from parameterized import parameterized | ||
|
||
from monai.networks import eval_mode | ||
from monai.networks.nets.masked_autoencoder_vit import MaskedAutoEncoderViT | ||
from tests.utils import skip_if_quick | ||
|
||
TEST_CASE_MaskedAutoEncoderViT = [] | ||
for masking_ratio in [0.5]: | ||
for dropout_rate in [0.6]: | ||
for in_channels in [4]: | ||
for hidden_size in [768]: | ||
for img_size in [96, 128]: | ||
for patch_size in [16]: | ||
for num_heads in [12]: | ||
for mlp_dim in [3072]: | ||
for num_layers in [4]: | ||
for decoder_hidden_size in [384]: | ||
for decoder_mlp_dim in [512]: | ||
for decoder_num_layers in [4]: | ||
for decoder_num_heads in [16]: | ||
for pos_embed_type in ["sincos", "learnable"]: | ||
for proj_type in ["conv", "perceptron"]: | ||
for nd in (2, 3): | ||
test_case = [ | ||
{ | ||
"in_channels": in_channels, | ||
"img_size": (img_size,) * nd, | ||
"patch_size": (patch_size,) * nd, | ||
"hidden_size": hidden_size, | ||
"mlp_dim": mlp_dim, | ||
"num_layers": num_layers, | ||
"decoder_hidden_size": decoder_hidden_size, | ||
"decoder_mlp_dim": decoder_mlp_dim, | ||
"decoder_num_layers": decoder_num_layers, | ||
"decoder_num_heads": decoder_num_heads, | ||
"pos_embed_type": pos_embed_type, | ||
"masking_ratio": masking_ratio, | ||
"decoder_pos_embed_type": pos_embed_type, | ||
"num_heads": num_heads, | ||
"proj_type": proj_type, | ||
"dropout_rate": dropout_rate, | ||
}, | ||
(2, in_channels, *([img_size] * nd)), | ||
( | ||
2, | ||
(img_size // patch_size) ** nd, | ||
in_channels * (patch_size**nd), | ||
), | ||
] | ||
if nd == 2: | ||
test_case[0]["spatial_dims"] = 2 # type: ignore | ||
TEST_CASE_MaskedAutoEncoderViT.append(test_case) | ||
|
||
TEST_CASE_ill_args = [ | ||
[{"in_channels": 1, "img_size": (128, 128, 128), "patch_size": (16, 16, 16), "dropout_rate": 5.0}], | ||
[{"in_channels": 1, "img_size": (128, 128, 128), "patch_size": (64, 64, 64), "pos_embed_type": "sin"}], | ||
[{"in_channels": 1, "img_size": (128, 128, 128), "patch_size": (64, 64, 64), "decoder_pos_embed_type": "sin"}], | ||
[{"in_channels": 1, "img_size": (32, 32, 32), "patch_size": (64, 64, 64)}], | ||
[{"in_channels": 1, "img_size": (128, 128, 128), "patch_size": (64, 64, 64), "num_layers": 12, "num_heads": 14}], | ||
[{"in_channels": 1, "img_size": (97, 97, 97), "patch_size": (16, 16, 16)}], | ||
[{"in_channels": 1, "img_size": (128, 128, 128), "patch_size": (64, 64, 64), "masking_ratio": 1.1}], | ||
[{"in_channels": 1, "img_size": (128, 128, 128), "patch_size": (64, 64, 64), "masking_ratio": -0.1}], | ||
] | ||
|
||
|
||
@skip_if_quick | ||
class TestMaskedAutoencoderViT(unittest.TestCase): | ||
|
||
@parameterized.expand(TEST_CASE_MaskedAutoEncoderViT) | ||
def test_shape(self, input_param, input_shape, expected_shape): | ||
net = MaskedAutoEncoderViT(**input_param) | ||
with eval_mode(net): | ||
result, _ = net(torch.randn(input_shape)) | ||
self.assertEqual(result.shape, expected_shape) | ||
|
||
def test_frozen_pos_embedding(self): | ||
net = MaskedAutoEncoderViT(in_channels=1, img_size=(96, 96, 96), patch_size=(16, 16, 16)) | ||
|
||
self.assertEqual(net.decoder_pos_embedding.requires_grad, False) | ||
|
||
@parameterized.expand(TEST_CASE_ill_args) | ||
def test_ill_arg(self, input_param): | ||
with self.assertRaises(ValueError): | ||
MaskedAutoEncoderViT(**input_param) | ||
|
||
def test_access_attn_matrix(self): | ||
# input format | ||
in_channels = 1 | ||
img_size = (96, 96, 96) | ||
patch_size = (16, 16, 16) | ||
in_shape = (1, in_channels, img_size[0], img_size[1], img_size[2]) | ||
|
||
# no data in the matrix | ||
no_matrix_acess_blk = MaskedAutoEncoderViT(in_channels=in_channels, img_size=img_size, patch_size=patch_size) | ||
no_matrix_acess_blk(torch.randn(in_shape)) | ||
assert isinstance(no_matrix_acess_blk.blocks[0].attn.att_mat, torch.Tensor) | ||
# no of elements is zero | ||
assert no_matrix_acess_blk.blocks[0].attn.att_mat.nelement() == 0 | ||
|
||
# be able to acess the attention matrix | ||
matrix_acess_blk = MaskedAutoEncoderViT( | ||
in_channels=in_channels, img_size=img_size, patch_size=patch_size, save_attn=True | ||
) | ||
matrix_acess_blk(torch.randn(in_shape)) | ||
|
||
assert matrix_acess_blk.blocks[0].attn.att_mat.shape == (in_shape[0], 12, 55, 55) | ||
|
||
def test_masking_ratio(self): | ||
# input format | ||
in_channels = 1 | ||
img_size = (96, 96, 96) | ||
patch_size = (16, 16, 16) | ||
in_shape = (1, in_channels, img_size[0], img_size[1], img_size[2]) | ||
|
||
# masking ratio 0.25 | ||
masking_ratio_blk = MaskedAutoEncoderViT( | ||
in_channels=in_channels, img_size=img_size, patch_size=patch_size, masking_ratio=0.25, save_attn=True | ||
) | ||
masking_ratio_blk(torch.randn(in_shape)) | ||
desired_num_tokens = int( | ||
(img_size[0] // patch_size[0]) | ||
* (img_size[1] // patch_size[1]) | ||
* (img_size[2] // patch_size[2]) | ||
* (1 - 0.25) | ||
) | ||
assert masking_ratio_blk.blocks[0].attn.att_mat.shape[-1] - 1 == desired_num_tokens | ||
|
||
# masking ratio 0.33 | ||
masking_ratio_blk = MaskedAutoEncoderViT( | ||
in_channels=in_channels, img_size=img_size, patch_size=patch_size, masking_ratio=0.33, save_attn=True | ||
) | ||
masking_ratio_blk(torch.randn(in_shape)) | ||
desired_num_tokens = int( | ||
(img_size[0] // patch_size[0]) | ||
* (img_size[1] // patch_size[1]) | ||
* (img_size[2] // patch_size[2]) | ||
* (1 - 0.33) | ||
) | ||
|
||
assert masking_ratio_blk.blocks[0].attn.att_mat.shape[-1] - 1 == desired_num_tokens | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |