Implementation of a Masked Autoencoder for representation learning (#8152)

This follows a previous PR (#7598).

In the previous PR, the official implementation was under an
incompatible license, so this is a clean-sheet implementation I
developed. The code is fairly straightforward, involving a transformer
encoder and decoder. The primary changes are in how masks are selected
and how patches are organized as they pass through the model.

In the official masked autoencoder implementation, random noise is
generated and then sorted twice with `torch.argsort`: once to shuffle
the tokens and once to recover the restore order. Only a subset of the
shuffled indices is retained for the encoder.

In our implementation, we use `torch.multinomial` to generate the mask
indices, followed by plain tensor indexing to handle the sub-selection
of patches for encoding and the reordering with mask tokens in the
decoder.
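
A minimal sketch of this selection step (a standalone helper with uniform sampling weights; the function name and shapes are illustrative, not the PR's exact code):

```python
import torch

def random_keep(x: torch.Tensor, masking_ratio: float = 0.75):
    """Keep a random subset of patch tokens; mask is 1 for dropped (masked) tokens, 0 for kept ones."""
    batch_size, num_tokens, _ = x.shape
    num_keep = int((1 - masking_ratio) * num_tokens)
    # sample `num_keep` distinct token indices per batch element, with uniform weights
    kept = torch.multinomial(torch.ones(batch_size, num_tokens), num_keep, replacement=False)
    x_kept = x[torch.arange(batch_size).unsqueeze(1), kept]  # gather the kept tokens for the encoder
    mask = torch.ones(batch_size, num_tokens, dtype=torch.int, device=x.device)
    mask[torch.arange(batch_size).unsqueeze(1), kept] = 0
    return x_kept, kept, mask

x = torch.randn(2, 216, 768)          # e.g. (96 // 16) ** 3 = 216 patch tokens
x_kept, kept, mask = random_keep(x)   # x_kept: (2, 54, 768), kept: (2, 54), mask: (2, 216)
```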

**Let me know if you need a detailed, line-by-line explanation of the
new code, including how it works and how it differs from the previous
version.**
### Description

Implementation of the Masked Autoencoder as described in the paper:
[Masked Autoencoders Are Scalable Vision
Learners](https://arxiv.org/pdf/2111.06377.pdf) by He et al.

Its effectiveness has already been demonstrated in the literature for
medical tasks in the paper [Self Pre-training with Masked Autoencoders
for Medical Image Classification and
Segmentation](https://arxiv.org/abs/2203.05573).
The PR contains the architecture and associated unit tests.

**Note:** The output includes the prediction, a tensor of size
($BS$, $N_{tokens}$, $D$), and the associated mask of size ($BS$, $N_{tokens}$).
The mask is used to apply the loss only to masked patches. I'm not sure
this is the best output format, though; what do you think?
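
For context, a hedged sketch of how the returned mask could drive the loss, restricting the reconstruction error to masked patches only (`patchify` and the per-patch MSE are illustrative assumptions, not part of this PR):

```python
import torch

def masked_reconstruction_loss(pred: torch.Tensor, target: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """pred, target: (BS, N_tokens, D); mask: (BS, N_tokens) with 1 = masked, 0 = visible."""
    loss = ((pred - target) ** 2).mean(dim=-1)  # per-patch mean squared error: (BS, N_tokens)
    return (loss * mask).sum() / mask.sum()     # average over masked patches only

# pred, mask = net(images)                                         # net: a MaskedAutoEncoderViT instance
# loss = masked_reconstruction_loss(pred, patchify(images), mask)  # patchify: hypothetical helper
```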

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [x] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [x] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [x] In-line docstrings updated.
- [x] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: Lucas Robinet <[email protected]>
Signed-off-by: Lucas Robinet <[email protected]>
Co-authored-by: YunLiu <[email protected]>
Lucas-rbnt and KumoLiu authored Nov 27, 2024
1 parent e73257c commit 20372f0
Showing 4 changed files with 377 additions and 0 deletions.
5 changes: 5 additions & 0 deletions docs/source/networks.rst
@@ -630,6 +630,11 @@ Nets
.. autoclass:: ViTAutoEnc
:members:

`MaskedAutoEncoderViT`
~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: MaskedAutoEncoderViT
:members:

`FullyConnectedNet`
~~~~~~~~~~~~~~~~~~~
.. autoclass:: FullyConnectedNet
1 change: 1 addition & 0 deletions monai/networks/nets/__init__.py
@@ -53,6 +53,7 @@
from .generator import Generator
from .highresnet import HighResBlock, HighResNet
from .hovernet import Hovernet, HoVernet, HoVerNet, HoverNet
from .masked_autoencoder_vit import MaskedAutoEncoderViT
from .mednext import (
MedNeXt,
MedNext,
211 changes: 211 additions & 0 deletions monai/networks/nets/masked_autoencoder_vit.py
@@ -0,0 +1,211 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from collections.abc import Sequence

import numpy as np
import torch
import torch.nn as nn

from monai.networks.blocks.patchembedding import PatchEmbeddingBlock
from monai.networks.blocks.pos_embed_utils import build_sincos_position_embedding
from monai.networks.blocks.transformerblock import TransformerBlock
from monai.networks.layers import trunc_normal_
from monai.utils import ensure_tuple_rep
from monai.utils.module import look_up_option

SUPPORTED_POS_EMBEDDING_TYPES = {"none", "learnable", "sincos"}

__all__ = ["MaskedAutoEncoderViT"]


class MaskedAutoEncoderViT(nn.Module):
"""
Masked Autoencoder (ViT), based on: "He et al.,
Masked Autoencoders Are Scalable Vision Learners <https://arxiv.org/abs/2111.06377>"
Only a subset of the patches passes through the encoder. The decoder tries to reconstruct
the masked patches, resulting in improved training speed.
"""

def __init__(
self,
in_channels: int,
img_size: Sequence[int] | int,
patch_size: Sequence[int] | int,
hidden_size: int = 768,
mlp_dim: int = 512,
num_layers: int = 12,
num_heads: int = 12,
masking_ratio: float = 0.75,
decoder_hidden_size: int = 384,
decoder_mlp_dim: int = 512,
decoder_num_layers: int = 4,
decoder_num_heads: int = 12,
proj_type: str = "conv",
pos_embed_type: str = "sincos",
decoder_pos_embed_type: str = "sincos",
dropout_rate: float = 0.0,
spatial_dims: int = 3,
qkv_bias: bool = False,
save_attn: bool = False,
) -> None:
"""
Args:
in_channels: number of input channels.
img_size: dimension of input image.
patch_size: dimension of patch size.
hidden_size: dimension of hidden layer. Defaults to 768.
mlp_dim: dimension of feedforward layer. Defaults to 512.
num_layers: number of transformer blocks. Defaults to 12.
num_heads: number of attention heads. Defaults to 12.
masking_ratio: ratio of patches to be masked. Defaults to 0.75.
decoder_hidden_size: dimension of hidden layer for decoder. Defaults to 384.
decoder_mlp_dim: dimension of feedforward layer for decoder. Defaults to 512.
decoder_num_layers: number of transformer blocks for decoder. Defaults to 4.
decoder_num_heads: number of attention heads for decoder. Defaults to 12.
proj_type: patch embedding layer type. Defaults to "conv".
pos_embed_type: position embedding layer type. Defaults to "sincos".
decoder_pos_embed_type: position embedding layer type for decoder. Defaults to "sincos".
dropout_rate: fraction of the input units to drop. Defaults to 0.0.
spatial_dims: number of spatial dimensions. Defaults to 3.
qkv_bias: apply bias to the qkv linear layer in self attention block. Defaults to False.
save_attn: to make accessible the attention in self attention block. Defaults to False.
Examples::
# for single channel input with image size of (96,96,96), and sin-cos positional encoding
>>> net = MaskedAutoEncoderViT(in_channels=1, img_size=(96,96,96), patch_size=(16,16,16),
pos_embed_type='sincos')
# for 3-channel with image size of (128,128,128) and a learnable positional encoding
>>> net = MaskedAutoEncoderViT(in_channels=3, img_size=128, patch_size=16, pos_embed_type='learnable')
# for 3-channel with image size of (224,224) and a masking ratio of 0.25
>>> net = MaskedAutoEncoderViT(in_channels=3, img_size=(224,224), patch_size=(16,16), masking_ratio=0.25,
spatial_dims=2)
"""

super().__init__()

if not (0 <= dropout_rate <= 1):
raise ValueError(f"dropout_rate should be between 0 and 1, got {dropout_rate}.")

if hidden_size % num_heads != 0:
raise ValueError("hidden_size should be divisible by num_heads.")

if decoder_hidden_size % decoder_num_heads != 0:
raise ValueError("decoder_hidden_size should be divisible by decoder_num_heads.")

self.patch_size = ensure_tuple_rep(patch_size, spatial_dims)
self.img_size = ensure_tuple_rep(img_size, spatial_dims)
self.spatial_dims = spatial_dims
for m, p in zip(self.img_size, self.patch_size):
if m % p != 0:
raise ValueError(f"patch_size={patch_size} should be divisible by img_size={img_size}.")

self.decoder_hidden_size = decoder_hidden_size

if masking_ratio <= 0 or masking_ratio >= 1:
raise ValueError(f"masking_ratio should be in the range (0, 1), got {masking_ratio}.")

self.masking_ratio = masking_ratio
self.cls_token = nn.Parameter(torch.zeros(1, 1, hidden_size))

self.patch_embedding = PatchEmbeddingBlock(
in_channels=in_channels,
img_size=img_size,
patch_size=patch_size,
hidden_size=hidden_size,
num_heads=num_heads,
proj_type=proj_type,
pos_embed_type=pos_embed_type,
dropout_rate=dropout_rate,
spatial_dims=self.spatial_dims,
)
blocks = [
TransformerBlock(hidden_size, mlp_dim, num_heads, dropout_rate, qkv_bias, save_attn)
for _ in range(num_layers)
]
self.blocks = nn.Sequential(*blocks, nn.LayerNorm(hidden_size))

# decoder
self.decoder_embed = nn.Linear(hidden_size, decoder_hidden_size)

self.mask_tokens = nn.Parameter(torch.zeros(1, 1, decoder_hidden_size))

self.decoder_pos_embed_type = look_up_option(decoder_pos_embed_type, SUPPORTED_POS_EMBEDDING_TYPES)
self.decoder_pos_embedding = nn.Parameter(torch.zeros(1, self.patch_embedding.n_patches, decoder_hidden_size))

decoder_blocks = [
TransformerBlock(decoder_hidden_size, decoder_mlp_dim, decoder_num_heads, dropout_rate, qkv_bias, save_attn)
for _ in range(decoder_num_layers)
]
self.decoder_blocks = nn.Sequential(*decoder_blocks, nn.LayerNorm(decoder_hidden_size))
self.decoder_pred = nn.Linear(decoder_hidden_size, int(np.prod(self.patch_size)) * in_channels)

self._init_weights()

def _init_weights(self):
"""
Initialize the decoder positional encoding and the mask and classification tokens,
similar to monai/networks/blocks/patchembedding.py.
"""
if self.decoder_pos_embed_type == "none":
pass
elif self.decoder_pos_embed_type == "learnable":
trunc_normal_(self.decoder_pos_embedding, mean=0.0, std=0.02, a=-2.0, b=2.0)
elif self.decoder_pos_embed_type == "sincos":
grid_size = []
for in_size, pa_size in zip(self.img_size, self.patch_size):
grid_size.append(in_size // pa_size)

self.decoder_pos_embedding = build_sincos_position_embedding(
grid_size, self.decoder_hidden_size, self.spatial_dims
)

else:
raise ValueError(f"decoder_pos_embed_type {self.decoder_pos_embed_type} not supported.")

# initialize the mask and classification tokens with a truncated normal distribution
trunc_normal_(self.mask_tokens, mean=0.0, std=0.02, a=-2.0, b=2.0)
trunc_normal_(self.cls_token, mean=0.0, std=0.02, a=-2.0, b=2.0)

def _masking(self, x, masking_ratio: float | None = None):
"""Randomly keep a subset of patch tokens; return the kept tokens, their indices and a binary mask (1 = masked, 0 = kept)."""
batch_size, num_tokens, _ = x.shape
percentage_to_keep = 1 - masking_ratio if masking_ratio is not None else 1 - self.masking_ratio
selected_indices = torch.multinomial(
torch.ones(batch_size, num_tokens), int(percentage_to_keep * num_tokens), replacement=False
)
x_masked = x[torch.arange(batch_size).unsqueeze(1), selected_indices] # gather the selected tokens
mask = torch.ones(batch_size, num_tokens, dtype=torch.int).to(x.device)
mask[torch.arange(batch_size).unsqueeze(-1), selected_indices] = 0

return x_masked, selected_indices, mask

def forward(self, x, masking_ratio: float | None = None):
x = self.patch_embedding(x)
x, selected_indices, mask = self._masking(x, masking_ratio=masking_ratio)

cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
x = torch.cat((cls_tokens, x), dim=1)

x = self.blocks(x)

# decoder
x = self.decoder_embed(x)

x_ = self.mask_tokens.repeat(x.shape[0], mask.shape[1], 1)  # a mask token at every patch position
x_[torch.arange(x.shape[0]).unsqueeze(-1), selected_indices] = x[:, 1:, :]  # scatter encoded tokens back to their original positions (no cls token)
x_ = x_ + self.decoder_pos_embedding
x = torch.cat([x[:, :1, :], x_], dim=1)
x = self.decoder_blocks(x)
x = self.decoder_pred(x)

x = x[:, 1:, :]
return x, mask
160 changes: 160 additions & 0 deletions tests/test_masked_autoencoder_vit.py
@@ -0,0 +1,160 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import unittest

import torch
from parameterized import parameterized

from monai.networks import eval_mode
from monai.networks.nets.masked_autoencoder_vit import MaskedAutoEncoderViT
from tests.utils import skip_if_quick

TEST_CASE_MaskedAutoEncoderViT = []
for masking_ratio in [0.5]:
for dropout_rate in [0.6]:
for in_channels in [4]:
for hidden_size in [768]:
for img_size in [96, 128]:
for patch_size in [16]:
for num_heads in [12]:
for mlp_dim in [3072]:
for num_layers in [4]:
for decoder_hidden_size in [384]:
for decoder_mlp_dim in [512]:
for decoder_num_layers in [4]:
for decoder_num_heads in [16]:
for pos_embed_type in ["sincos", "learnable"]:
for proj_type in ["conv", "perceptron"]:
for nd in (2, 3):
test_case = [
{
"in_channels": in_channels,
"img_size": (img_size,) * nd,
"patch_size": (patch_size,) * nd,
"hidden_size": hidden_size,
"mlp_dim": mlp_dim,
"num_layers": num_layers,
"decoder_hidden_size": decoder_hidden_size,
"decoder_mlp_dim": decoder_mlp_dim,
"decoder_num_layers": decoder_num_layers,
"decoder_num_heads": decoder_num_heads,
"pos_embed_type": pos_embed_type,
"masking_ratio": masking_ratio,
"decoder_pos_embed_type": pos_embed_type,
"num_heads": num_heads,
"proj_type": proj_type,
"dropout_rate": dropout_rate,
},
(2, in_channels, *([img_size] * nd)),
(
2,
(img_size // patch_size) ** nd,
in_channels * (patch_size**nd),
),
]
if nd == 2:
test_case[0]["spatial_dims"] = 2 # type: ignore
TEST_CASE_MaskedAutoEncoderViT.append(test_case)

TEST_CASE_ill_args = [
[{"in_channels": 1, "img_size": (128, 128, 128), "patch_size": (16, 16, 16), "dropout_rate": 5.0}],
[{"in_channels": 1, "img_size": (128, 128, 128), "patch_size": (64, 64, 64), "pos_embed_type": "sin"}],
[{"in_channels": 1, "img_size": (128, 128, 128), "patch_size": (64, 64, 64), "decoder_pos_embed_type": "sin"}],
[{"in_channels": 1, "img_size": (32, 32, 32), "patch_size": (64, 64, 64)}],
[{"in_channels": 1, "img_size": (128, 128, 128), "patch_size": (64, 64, 64), "num_layers": 12, "num_heads": 14}],
[{"in_channels": 1, "img_size": (97, 97, 97), "patch_size": (16, 16, 16)}],
[{"in_channels": 1, "img_size": (128, 128, 128), "patch_size": (64, 64, 64), "masking_ratio": 1.1}],
[{"in_channels": 1, "img_size": (128, 128, 128), "patch_size": (64, 64, 64), "masking_ratio": -0.1}],
]


@skip_if_quick
class TestMaskedAutoencoderViT(unittest.TestCase):

@parameterized.expand(TEST_CASE_MaskedAutoEncoderViT)
def test_shape(self, input_param, input_shape, expected_shape):
net = MaskedAutoEncoderViT(**input_param)
with eval_mode(net):
result, _ = net(torch.randn(input_shape))
self.assertEqual(result.shape, expected_shape)

def test_frozen_pos_embedding(self):
net = MaskedAutoEncoderViT(in_channels=1, img_size=(96, 96, 96), patch_size=(16, 16, 16))

self.assertEqual(net.decoder_pos_embedding.requires_grad, False)

@parameterized.expand(TEST_CASE_ill_args)
def test_ill_arg(self, input_param):
with self.assertRaises(ValueError):
MaskedAutoEncoderViT(**input_param)

def test_access_attn_matrix(self):
# input format
in_channels = 1
img_size = (96, 96, 96)
patch_size = (16, 16, 16)
in_shape = (1, in_channels, img_size[0], img_size[1], img_size[2])

# the attention matrix is not populated when save_attn is False (default)
no_matrix_access_blk = MaskedAutoEncoderViT(in_channels=in_channels, img_size=img_size, patch_size=patch_size)
no_matrix_access_blk(torch.randn(in_shape))
assert isinstance(no_matrix_access_blk.blocks[0].attn.att_mat, torch.Tensor)
# number of elements is zero
assert no_matrix_access_blk.blocks[0].attn.att_mat.nelement() == 0

# the attention matrix is accessible when save_attn is True
matrix_access_blk = MaskedAutoEncoderViT(
in_channels=in_channels, img_size=img_size, patch_size=patch_size, save_attn=True
)
matrix_access_blk(torch.randn(in_shape))

assert matrix_access_blk.blocks[0].attn.att_mat.shape == (in_shape[0], 12, 55, 55)

def test_masking_ratio(self):
# input format
in_channels = 1
img_size = (96, 96, 96)
patch_size = (16, 16, 16)
in_shape = (1, in_channels, img_size[0], img_size[1], img_size[2])

# masking ratio 0.25
masking_ratio_blk = MaskedAutoEncoderViT(
in_channels=in_channels, img_size=img_size, patch_size=patch_size, masking_ratio=0.25, save_attn=True
)
masking_ratio_blk(torch.randn(in_shape))
desired_num_tokens = int(
(img_size[0] // patch_size[0])
* (img_size[1] // patch_size[1])
* (img_size[2] // patch_size[2])
* (1 - 0.25)
)
assert masking_ratio_blk.blocks[0].attn.att_mat.shape[-1] - 1 == desired_num_tokens

# masking ratio 0.33
masking_ratio_blk = MaskedAutoEncoderViT(
in_channels=in_channels, img_size=img_size, patch_size=patch_size, masking_ratio=0.33, save_attn=True
)
masking_ratio_blk(torch.randn(in_shape))
desired_num_tokens = int(
(img_size[0] // patch_size[0])
* (img_size[1] // patch_size[1])
* (img_size[2] // patch_size[2])
* (1 - 0.33)
)

assert masking_ratio_blk.blocks[0].attn.att_mat.shape[-1] - 1 == desired_num_tokens


if __name__ == "__main__":
unittest.main()
