
Commit

refacto
Signed-off-by: vgrau98 <[email protected]>
vgrau98 committed Apr 28, 2024
1 parent 7d82d8a commit 0132e6b
Showing 3 changed files with 195 additions and 186 deletions.
163 changes: 163 additions & 0 deletions monai/networks/blocks/attention_utils.py
@@ -15,6 +15,10 @@
import torch.nn.functional as F
from torch import nn

from monai.utils import optional_import

rearrange, _ = optional_import("einops", name="rearrange")


def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
"""
@@ -126,3 +130,162 @@ def add_decomposed_rel_pos(
).view(batch, q_h * q_w * q_d, k_h * k_w * k_d)

return attn


def window_partition(x: torch.Tensor, window_size: int, input_size: Tuple = ()) -> Tuple[torch.Tensor, Tuple]:
"""
Partition into non-overlapping windows, with padding if needed. Supports 2D and 3D.
Args:
x (tensor): input tokens with [B, s_dim_1 * ... * s_dim_n, C], with n = 1...len(input_size).
window_size (int): window size.
input_size (Tuple): input spatial dimensions: (H, W) or (H, W, D).
Returns:
windows: windows after partition with [B * num_windows, window_size_1 * ... * window_size_n, C],
with n = 1...len(input_size) and window_size_i == window_size.
(S_DIM_1p, ..., S_DIM_np): padded spatial dimensions before partition, with n = 1...len(input_size).
"""
if x.shape[1] != int(torch.prod(torch.tensor(input_size))):
raise ValueError(f"Input tensor spatial dimension {x.shape[1]} should be equal to {input_size} product")

if len(input_size) == 2:
x = rearrange(x, "b (h w) c -> b h w c", h=input_size[0], w=input_size[1])
x, pad_hw = window_partition_2d(x, window_size)
x = rearrange(x, "b h w c -> b (h w) c", h=window_size, w=window_size)
return x, pad_hw
elif len(input_size) == 3:
x = rearrange(x, "b (h w d) c -> b h w d c", h=input_size[0], w=input_size[1], d=input_size[2])
x, pad_hwd = window_partition_3d(x, window_size)
x = rearrange(x, "b h w d c -> b (h w d) c", h=window_size, w=window_size, d=window_size)
return x, pad_hwd
else:
raise ValueError(f"input_size cannot be length {len(input_size)}. It can be composed of 2 or 3 elements only. ")


def window_partition_2d(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
"""
Partition into non-overlapping windows, with padding if needed. 2D implementation.
Args:
x (tensor): input tokens with [B, H, W, C].
window_size (int): window size.
Returns:
windows: windows after partition with [B * num_windows, window_size, window_size, C].
(Hp, Wp): padded height and width before partition
"""
batch, h, w, c = x.shape

pad_h = (window_size - h % window_size) % window_size
pad_w = (window_size - w % window_size) % window_size
if pad_h > 0 or pad_w > 0:
x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
hp, wp = h + pad_h, w + pad_w

x = x.view(batch, hp // window_size, window_size, wp // window_size, window_size, c)
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, c)
return windows, (hp, wp)
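# Worked example (illustrative, not part of the original file): for h = w = 50 and window_size = 16,
# pad_h = pad_w = (16 - 50 % 16) % 16 = 14, so hp = wp = 64 and each batch element is split into
# (64 // 16) * (64 // 16) = 16 windows of shape [16, 16, C].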


def window_partition_3d(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int, int]]:
"""
Partition into non-overlapping windows, with padding if needed. 3D implementation.
Args:
x (tensor): input tokens with [B, H, W, D, C].
window_size (int): window size.
Returns:
windows: windows after partition with [B * num_windows, window_size, window_size, window_size, C].
(Hp, Wp, Dp): padded height, width and depth before partition
"""
batch, h, w, d, c = x.shape

pad_h = (window_size - h % window_size) % window_size
pad_w = (window_size - w % window_size) % window_size
pad_d = (window_size - d % window_size) % window_size
if pad_h > 0 or pad_w > 0 or pad_d > 0:
x = F.pad(x, (0, 0, 0, pad_d, 0, pad_w, 0, pad_h))
hp, wp, dp = h + pad_h, w + pad_w, d + pad_d

x = x.view(batch, hp // window_size, window_size, wp // window_size, window_size, dp // window_size, window_size, c)
windows = x.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous().view(-1, window_size, window_size, window_size, c)
return windows, (hp, wp, dp)


def window_unpartition(windows: torch.Tensor, window_size: int, pad: Tuple, spatial_dims: Tuple) -> torch.Tensor:
"""
Window unpartition into original sequences, removing padding. Supports 2D and 3D.
Args:
windows (tensor): input tokens with [B * num_windows, window_size_1 * ... * window_size_n, C],
with n = 1...len(spatial_dims) and window_size_i == window_size.
window_size (int): window size.
pad (Tuple): padded spatial dimensions (H, W) or (H, W, D).
spatial_dims (Tuple): original spatial dimensions - (H, W) or (H, W, D) - before padding.
Returns:
x: unpartitioned sequences with [B, s_dim_1 * ... * s_dim_n, C].
x: torch.Tensor
if len(spatial_dims) == 2:
x = rearrange(windows, "b (h w) c -> b h w c", h=window_size, w=window_size)
x = window_unpartition_2d(x, window_size, pad, spatial_dims)
x = rearrange(x, "b h w c -> b (h w) c", h=spatial_dims[0], w=spatial_dims[1])
return x
elif len(spatial_dims) == 3:
x = rearrange(windows, "b (h w d) c -> b h w d c", h=window_size, w=window_size, d=window_size)
x = window_unpartition_3d(x, window_size, pad, spatial_dims)
x = rearrange(x, "b h w d c -> b (h w d) c", h=spatial_dims[0], w=spatial_dims[1], d=spatial_dims[2])
return x
else:
raise ValueError(f"spatial_dims cannot have length {len(spatial_dims)}. It must contain 2 or 3 elements.")


def window_unpartition_2d(
windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int]
) -> torch.Tensor:
"""
Window unpartition into original sequences, removing padding. 2D implementation.
Args:
windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
window_size (int): window size.
pad_hw (Tuple): padded height and width (hp, wp).
hw (Tuple): original height and width (H, W) before padding.
Returns:
x: unpartitioned sequences with [B, H, W, C].
"""
hp, wp = pad_hw
h, w = hw
batch = windows.shape[0] // (hp * wp // window_size // window_size)
x = windows.view(batch, hp // window_size, wp // window_size, window_size, window_size, -1)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(batch, hp, wp, -1)

if hp > h or wp > w:
x = x[:, :h, :w, :].contiguous()
return x
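# Note (illustrative, not part of the original file): continuing the 50 x 50 example, windows.shape[0]
# equals batch * 16 while hp * wp // window_size // window_size = 64 * 64 // 16 // 16 = 16, so the
# original batch size is recovered, the windows are stitched back into a [batch, hp, wp, C] grid,
# and the padding is cropped away to return [batch, h, w, C].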


def window_unpartition_3d(
windows: torch.Tensor, window_size: int, pad_hwd: Tuple[int, int, int], hwd: Tuple[int, int, int]
) -> torch.Tensor:
"""
Window unpartition into original sequences, removing padding. 3D implementation.
Args:
windows (tensor): input tokens with [B * num_windows, window_size, window_size, window_size, C].
window_size (int): window size.
pad_hwd (Tuple): padded height, width and depth (hp, wp, dp).
hwd (Tuple): original height, width and depth (H, W, D) before padding.
Returns:
x: unpartitioned sequences with [B, H, W, D, C].
"""
hp, wp, dp = pad_hwd
h, w, d = hwd
batch = windows.shape[0] // (hp * wp * dp // window_size // window_size // window_size)
x = windows.view(
batch, hp // window_size, wp // window_size, dp // window_size, window_size, window_size, window_size, -1
)
x = x.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous().view(batch, hp, wp, dp, -1)

if hp > h or wp > w or dp > d:
x = x[:, :h, :w, :d, :].contiguous()
return x
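
Not part of the commit: a minimal round-trip sketch of the helpers added above, assuming einops is installed and the functions are imported from monai.networks.blocks.attention_utils; shapes and printed values are illustrative.

import torch

from monai.networks.blocks.attention_utils import window_partition, window_unpartition

# Flattened 2D token grid: batch of 2, a 50 x 50 spatial grid, 96 channels.
x = torch.randn(2, 50 * 50, 96)

# Partition into 16 x 16 windows; the grid is padded from 50 x 50 to 64 x 64 first.
windows, pad_hw = window_partition(x, window_size=16, input_size=(50, 50))
print(windows.shape)  # torch.Size([32, 256, 96]), i.e. 2 * 16 windows of 16 * 16 tokens
print(pad_hw)  # (64, 64)

# Undo the partition and the padding to recover the original flattened sequence.
x_back = window_unpartition(windows, window_size=16, pad=pad_hw, spatial_dims=(50, 50))
print(x_back.shape)  # torch.Size([2, 2500, 96])
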
33 changes: 28 additions & 5 deletions monai/networks/blocks/selfattention.py
@@ -18,6 +18,7 @@
import torch.nn as nn

from monai.networks.blocks.attention_utils import window_partition, window_unpartition
from monai.networks.layers.utils import get_rel_pos_embedding_layer
from monai.utils import optional_import

xops, has_xformers = optional_import("xformers.ops")
@@ -26,9 +27,14 @@

class SABlock(nn.Module):
"""
A self-attention block, based on: "Dosovitskiy et al.,
An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>"
One can setup relative positional embedding as described in <https://arxiv.org/abs/2112.01526>
A self-attention block, based on: "Dosovitskiy et al.,
An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>"
One can setup relative positional embedding as described in <https://arxiv.org/abs/2112.01526>
and some additional features:
- local window attention
"""

def __init__(
@@ -43,6 +49,7 @@ def __init__(
causal: bool = False,
sequence_length: int | None = None,
use_flash_attention: bool = False,
window_size: int = 0,
) -> None:
"""
Args:
@@ -53,11 +60,13 @@
rel_pos_embedding (str, optional): Add relative positional embeddings to the attention map.
For now only "decomposed" is supported (see https://arxiv.org/abs/2112.01526). 2D and 3D are supported.
input_size (tuple(spatial_dim), optional): Input resolution for calculating the relative
positional parameter size.
positional parameter size. Has to be set if local window attention is used.
causal (bool): whether to use causal attention. If true, `sequence_length` has to be set.
sequence_length (int, optional): if causal is True, it is necessary to specify the sequence length.
save_attn (bool, optional): make the attention matrix accessible. Defaults to False.
window_size (int): Window size for local attention as used in Segment Anything https://arxiv.org/abs/2304.02643.
If 0, global attention is used.
See https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py.
"""

super().__init__()
@@ -81,6 +90,10 @@ def __init__(

if use_flash_attention and not has_xformers:
raise ValueError("use_flash_attention is True but xformers is not installed.")
if window_size > 0 and (input_size is None or len(input_size) not in [2, 3]):
raise ValueError(
"If local window attention is used (window_size > 0), input_size should be specified: (h, w) or (h, w, d)"
)

self.num_heads = num_heads
self.out_proj = nn.Linear(hidden_size, hidden_size)
@@ -101,6 +114,7 @@ def __init__(
if rel_pos_embedding is not None
else None
)
self.window_size = window_size
self.input_size = input_size

if causal and sequence_length is not None:
@@ -119,6 +133,10 @@ def forward(self, x: torch.Tensor):
Return:
torch.Tensor: B x (s_dim_1 * ... * s_dim_n) x C
"""

if self.window_size > 0:
x, pad = window_partition(x, self.window_size, self.input_size)

_, t, _ = x.size()
output = self.input_rearrange(self.qkv(x)) # 3 x B x (s_dim_1 * ... * s_dim_n) x h x C/h
q, k, v = output[0], output[1], output[2]
@@ -156,4 +174,9 @@ def forward(self, x: torch.Tensor):
x = self.out_rearrange(x)
x = self.out_proj(x)
x = self.drop_output(x)

# Reverse window partition
if self.window_size > 0:
x = window_unpartition(x, self.window_size, pad, self.input_size)

return x
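
Not part of the commit: a minimal usage sketch of the new local window attention option, assuming SABlock keeps its existing defaults for the constructor arguments not shown in this diff; the hidden size, head count and grid resolution below are illustrative.

import torch

from monai.networks.blocks.selfattention import SABlock

# 8 x 8 local attention windows over a 32 x 32 token grid (no padding needed since 32 % 8 == 0).
block = SABlock(hidden_size=96, num_heads=4, input_size=(32, 32), window_size=8)

x = torch.randn(2, 32 * 32, 96)  # B x (s_dim_1 * s_dim_2) x C
out = block(x)
print(out.shape)  # torch.Size([2, 1024, 96])

# window_size=0 (the default) keeps the previous behaviour: global attention over the full sequence.
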
