From 43f0b5246870fc3f95d9a434af9d0f8a2429f9b4 Mon Sep 17 00:00:00 2001
From: vgrau98
Date: Sat, 30 Dec 2023 17:47:31 +0100
Subject: [PATCH] transformer block local window attention

Signed-off-by: vgrau98
---
 monai/networks/blocks/transformerblock.py | 78 +++++++++++++++++++++++-
 1 file changed, 77 insertions(+), 1 deletion(-)

diff --git a/monai/networks/blocks/transformerblock.py b/monai/networks/blocks/transformerblock.py
index ddf959dad2..baad0780a7 100644
--- a/monai/networks/blocks/transformerblock.py
+++ b/monai/networks/blocks/transformerblock.py
@@ -11,7 +11,11 @@
 
 from __future__ import annotations
 
+from typing import Tuple
+
+import torch
 import torch.nn as nn
+import torch.nn.functional as F
 
 from monai.networks.blocks.mlp import MLPBlock
 from monai.networks.blocks.selfattention import SABlock
@@ -31,6 +35,7 @@ def __init__(
         dropout_rate: float = 0.0,
         qkv_bias: bool = False,
         save_attn: bool = False,
+        window_size: int = 0,
     ) -> None:
         """
         Args:
@@ -40,6 +45,10 @@ def __init__(
             dropout_rate (float, optional): fraction of the input units to drop. Defaults to 0.0.
             qkv_bias (bool, optional): apply bias term for the qkv linear layer. Defaults to False.
             save_attn (bool, optional): to make accessible the attention matrix. Defaults to False.
+            window_size (int): window size for local attention as used in Segment Anything
+                (https://arxiv.org/abs/2304.02643). If 0, global attention is used; only 2D inputs are
+                supported for local attention (window_size > 0). When local attention is used, the input
+                tensor must have shape [B, H, W, C] during the forward pass. See also: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py.
 
         """
 
@@ -55,8 +64,75 @@ def __init__(
         self.norm1 = nn.LayerNorm(hidden_size)
         self.attn = SABlock(hidden_size, num_heads, dropout_rate, qkv_bias, save_attn)
         self.norm2 = nn.LayerNorm(hidden_size)
+        self.window_size = window_size
 
     def forward(self, x):
-        x = x + self.attn(self.norm1(x))
+        shortcut = x
+        x = self.norm1(x)
+        # Window partition: [B, H, W, C] -> [B * num_windows, window_size, window_size, C]
+        if self.window_size > 0:
+            h, w = x.shape[1], x.shape[2]
+            x, pad_hw = window_partition(x, self.window_size)
+            # flatten each window into a token sequence, since SABlock expects [B, N, C]
+            x = x.reshape(x.shape[0], self.window_size * self.window_size, -1)
+
+        x = self.attn(x)
+        # Reverse window partition
+        if self.window_size > 0:
+            x = x.reshape(x.shape[0], self.window_size, self.window_size, -1)
+            x = window_unpartition(x, self.window_size, pad_hw, (h, w))
+
+        x = shortcut + x
         x = x + self.mlp(self.norm2(x))
         return x
+
+
+def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
+    """
+    Partition into non-overlapping windows with padding if needed. Supports 2D inputs only.
+    Args:
+        x (tensor): input tokens with [B, H, W, C].
+        window_size (int): window size.
+
+    Returns:
+        windows: windows after partition with [B * num_windows, window_size, window_size, C].
+        (Hp, Wp): padded height and width before partition.
+    """
+    batch, h, w, c = x.shape
+
+    # pad the bottom and right edges so that H and W are multiples of window_size
+    pad_h = (window_size - h % window_size) % window_size
+    pad_w = (window_size - w % window_size) % window_size
+    if pad_h > 0 or pad_w > 0:
+        x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
+    hp, wp = h + pad_h, w + pad_w
+
+    x = x.view(batch, hp // window_size, window_size, wp // window_size, window_size, c)
+    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, c)
+    return windows, (hp, wp)
+
+
+def window_unpartition(
+    windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int]
+) -> torch.Tensor:
+    """
+    Reverse window partition back into original sequences, removing any padding.
+    Args:
+        windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
+        window_size (int): window size.
+        pad_hw (Tuple): padded height and width (hp, wp).
+        hw (Tuple): original height and width (H, W) before padding.
+
+    Returns:
+        x: unpartitioned sequences with [B, H, W, C].
+    """
+    hp, wp = pad_hw
+    h, w = hw
+    batch = windows.shape[0] // (hp * wp // window_size // window_size)
+    x = windows.view(batch, hp // window_size, wp // window_size, window_size, window_size, -1)
+    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(batch, hp, wp, -1)
+
+    # crop away any padding added in window_partition
+    if hp > h or wp > w:
+        x = x[:, :h, :w, :].contiguous()
+    return x
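
Below is a minimal usage sketch, not part of the patch itself, exercising the windowed path once the change is applied. It assumes the patched module is importable as monai.networks.blocks.transformerblock; the hidden_size, mlp_dim, num_heads, and window_size values are arbitrary examples.

import torch

from monai.networks.blocks.transformerblock import TransformerBlock, window_partition, window_unpartition

# Local attention: the input must be channels-last, [B, H, W, C], with C == hidden_size.
block = TransformerBlock(hidden_size=96, mlp_dim=384, num_heads=3, window_size=7)
x = torch.randn(2, 30, 30, 96)  # 30 is not a multiple of 7, so the helpers must pad
out = block(x)
assert out.shape == x.shape  # the residual structure preserves [B, H, W, C]

# Round trip through the helpers alone: partition pads H and W up to multiples
# of window_size (30 -> 35), unpartition crops the padding away again.
windows, pad_hw = window_partition(x, 7)
assert windows.shape == (2 * 5 * 5, 7, 7, 96) and pad_hw == (35, 35)
restored = window_unpartition(windows, 7, pad_hw, (30, 30))
assert torch.equal(restored, x)

With the default window_size=0 the block behaves as before and still expects [B, N, C] token sequences.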