transformer block local window attention
Signed-off-by: vgrau98 <[email protected]>
vgrau98 committed Dec 30, 2023
1 parent b3d7a48 commit 43f0b52
Showing 1 changed file with 72 additions and 1 deletion: monai/networks/blocks/transformerblock.py
@@ -11,7 +11,11 @@

from __future__ import annotations

from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from monai.networks.blocks.mlp import MLPBlock
from monai.networks.blocks.selfattention import SABlock
@@ -31,6 +35,7 @@ def __init__(
        dropout_rate: float = 0.0,
        qkv_bias: bool = False,
        save_attn: bool = False,
        window_size: int = 0,
    ) -> None:
        """
        Args:
@@ -40,6 +45,10 @@
            dropout_rate (float, optional): fraction of the input units to drop. Defaults to 0.0.
            qkv_bias (bool, optional): apply bias term for the qkv linear layer. Defaults to False.
            save_attn (bool, optional): to make accessible the attention matrix. Defaults to False.
            window_size (int): window size for local attention as used in Segment Anything https://arxiv.org/abs/2304.02643.
                If 0, global attention is used. Only 2D inputs are supported for local attention (window_size > 0).
                If local attention is used, the input tensor should have shape [B, H, W, C] during the forward pass.
                See https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py.
        """

@@ -55,8 +64,70 @@
        self.norm1 = nn.LayerNorm(hidden_size)
        self.attn = SABlock(hidden_size, num_heads, dropout_rate, qkv_bias, save_attn)
        self.norm2 = nn.LayerNorm(hidden_size)
        self.window_size = window_size

    def forward(self, x):
-        x = x + self.attn(self.norm1(x))
        shortcut = x
        x = self.norm1(x)
        # Window partition
        if self.window_size > 0:
            h, w = x.shape[1], x.shape[2]
            x, pad_hw = window_partition(x, self.window_size)

        x = self.attn(x)
        # Reverse window partition
        if self.window_size > 0:
            x = window_unpartition(x, self.window_size, pad_hw, (h, w))

        x = shortcut + x
        x = x + self.mlp(self.norm2(x))
        return x
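
For orientation, a minimal sketch (not part of the commit) of the windowed residual path that the new forward implements, with an identity stand-in for self.attn so the [B, H, W, C] shape contract stays visible. The helper name and the attn argument are hypothetical; window_partition and window_unpartition are the helpers added below in this file.

# Illustration only: mirrors forward above for the window_size > 0 branch.
def _windowed_residual_sketch(x: torch.Tensor, window_size: int, attn=lambda t: t) -> torch.Tensor:
    shortcut = x                                            # [B, H, W, C]
    h, w = x.shape[1], x.shape[2]
    x, pad_hw = window_partition(x, window_size)            # [B * num_windows, ws, ws, C]
    x = attn(x)                                             # attention restricted to each window
    x = window_unpartition(x, window_size, pad_hw, (h, w))  # back to [B, H, W, C]
    return shortcut + x

# e.g. _windowed_residual_sketch(torch.randn(1, 10, 10, 8), window_size=4) keeps shape [1, 10, 10, 8]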


def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
"""
Partition into non-overlapping windows with padding if needed. Support only 2D.
Args:
x (tensor): input tokens with [B, H, W, C].
window_size (int): window size.
Returns:
windows: windows after partition with [B * num_windows, window_size, window_size, C].
(Hp, Wp): padded height and width before partition
"""
    batch, h, w, c = x.shape

    pad_h = (window_size - h % window_size) % window_size
    pad_w = (window_size - w % window_size) % window_size
    if pad_h > 0 or pad_w > 0:
        # F.pad pads from the last dim backwards: (C_left, C_right, W_left, W_right, H_left, H_right),
        # so only the right of W and the bottom of H are padded.
        x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
    hp, wp = h + pad_h, w + pad_w

    x = x.view(batch, hp // window_size, window_size, wp // window_size, window_size, c)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, c)
    return windows, (hp, wp)
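
To make the padding arithmetic concrete, a small worked example (illustrative only, using arbitrary sizes): a 10 x 10 feature map with window_size 4 is padded to 12 x 12 and split into 3 x 3 = 9 windows per image.

# Illustration only: pad_h = pad_w = (4 - 10 % 4) % 4 = 2, so Hp = Wp = 12
# and each image yields (12 // 4) * (12 // 4) = 9 windows.
x = torch.randn(2, 10, 10, 8)                     # [B, H, W, C]
windows, (hp, wp) = window_partition(x, window_size=4)
print(windows.shape)                              # torch.Size([18, 4, 4, 8]) = [B * 9, ws, ws, C]
print(hp, wp)                                     # 12 12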


def window_unpartition(
    windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int]
) -> torch.Tensor:
"""
Window unpartition into original sequences and removing padding.
Args:
windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
window_size (int): window size.
pad_hw (Tuple): padded height and width (hp, wp).
hw (Tuple): original height and width (H, W) before padding.
Returns:
x: unpartitioned sequences with [B, H, W, C].
"""
    hp, wp = pad_hw
    h, w = hw
    batch = windows.shape[0] // (hp * wp // window_size // window_size)
    x = windows.view(batch, hp // window_size, wp // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(batch, hp, wp, -1)

    if hp > h or wp > w:
        # Crop away the zero padding added in window_partition.
        x = x[:, :h, :w, :].contiguous()
    return x
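
A quick round-trip check (illustrative only, not part of the commit): window_unpartition is the exact inverse of window_partition, including removal of the zero padding, which is what lets forward add the result straight back onto the [B, H, W, C] shortcut.

# Illustration only: partition followed by unpartition recovers the input exactly.
x = torch.randn(2, 10, 10, 8)                     # H and W are not multiples of the window size
windows, pad_hw = window_partition(x, window_size=4)
x_rec = window_unpartition(windows, window_size=4, pad_hw=pad_hw, hw=(10, 10))
assert x_rec.shape == x.shape
assert torch.equal(x_rec, x)                      # values preserved, padding cropped away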
