[Attention block] relative positional embedding #7346

Merged: 13 commits, Jan 18, 2024
6 changes: 6 additions & 0 deletions docs/source/networks.rst
@@ -248,6 +248,12 @@ Blocks
.. autoclass:: monai.apps.reconstruction.networks.blocks.varnetblock.VarNetBlock
:members:

`Attention utilities`
~~~~~~~~~~~~~~~~~~~~~
.. automodule:: monai.networks.blocks.attention_utils
.. autofunction:: monai.networks.blocks.attention_utils.get_rel_pos
.. autofunction:: monai.networks.blocks.attention_utils.add_decomposed_rel_pos

N-Dim Fourier Transform
~~~~~~~~~~~~~~~~~~~~~~~~
.. automodule:: monai.networks.blocks.fft_utils_t
128 changes: 128 additions & 0 deletions monai/networks/blocks/attention_utils.py
@@ -0,0 +1,128 @@
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import Tuple

import torch
import torch.nn.functional as F
from torch import nn


def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
"""
Get relative positional embeddings according to the relative positions of
query and key sizes.

Args:
q_size (int): size of query q.
k_size (int): size of key k.
rel_pos (Tensor): relative position embeddings (L, C).

Returns:
Extracted positional embeddings according to relative positions.
"""
rel_pos_resized: torch.Tensor = torch.Tensor()
max_rel_dist = int(2 * max(q_size, k_size) - 1)
# Interpolate rel pos if needed.
if rel_pos.shape[0] != max_rel_dist:
# Interpolate rel pos.
rel_pos_resized = F.interpolate(
rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), size=max_rel_dist, mode="linear"
)
rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
else:
rel_pos_resized = rel_pos

# Scale the coordinates by the shorter length if q and k have different shapes.
q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)

return rel_pos_resized[relative_coords.long()]


def add_decomposed_rel_pos(
attn: torch.Tensor, q: torch.Tensor, rel_pos_lst: nn.ParameterList, q_size: Tuple, k_size: Tuple
) -> torch.Tensor:
r"""
Calculate decomposed relative positional embeddings, following the MViTv2 implementation:
https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py

Only 2D and 3D are supported.

This encodes the relative position of tokens in the attention matrix: pairs of tokens separated by the same
distance `d` share the same embedding value (unlike absolute positional embedding).

.. math::
Attn_{logits}(Q, K) = (QK^{T} + E_{rel})*scale

where

.. math::
E_{ij}^{(rel)} = Q_{i}.R_{p(i), p(j)}

with :math:`R_{p(i), p(j)} \in R^{dim}` and :math:`p(i)`, :math:`p(j)` the spatial
positions of elements :math:`i` and :math:`j`, respectively.

When using "decomposed" relative positional embedding, positional embedding is defined ("decomposed") as follow:

.. math::
R_{p(i), p(j)} = R^{d1}_{d1(i), d1(j)} + ... + R^{dn}_{dn(i), dn(j)}

with :math:`n` the number of spatial dimensions

Decomposed relative positional embedding reduces the complexity from :math:`\mathcal{O}(d1*...*dn)` to
:math:`\mathcal{O}(d1+...+dn)` compared with classical relative positional embedding.

Args:
attn (Tensor): attention map.
q (Tensor): query q in the attention layer with shape (B, s_dim_1 * ... * s_dim_n, C).
rel_pos_lst (ParameterList): relative position embeddings for each axis: rel_pos_lst[n] for nth axis.
q_size (Tuple): spatial sequence size of query q with (q_dim_1, ..., q_dim_n).
k_size (Tuple): spatial sequence size of key k with (k_dim_1, ..., k_dim_n).

Returns:
attn (Tensor): attention logits with added relative positional embeddings.
"""
rh = get_rel_pos(q_size[0], k_size[0], rel_pos_lst[0])
rw = get_rel_pos(q_size[1], k_size[1], rel_pos_lst[1])

batch, _, dim = q.shape

if len(rel_pos_lst) == 2:
q_h, q_w = q_size[:2]
k_h, k_w = k_size[:2]
r_q = q.reshape(batch, q_h, q_w, dim)
rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, rh)
rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, rw)

attn = (attn.view(batch, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]).view(
batch, q_h * q_w, k_h * k_w
)
elif len(rel_pos_lst) == 3:
q_h, q_w, q_d = q_size[:3]
k_h, k_w, k_d = k_size[:3]

rd = get_rel_pos(q_d, k_d, rel_pos_lst[2])

r_q = q.reshape(batch, q_h, q_w, q_d, dim)
rel_h = torch.einsum("bhwdc,hkc->bhwdk", r_q, rh)
rel_w = torch.einsum("bhwdc,wkc->bhwdk", r_q, rw)
rel_d = torch.einsum("bhwdc,dkc->bhwdk", r_q, rd)

attn = (
attn.view(batch, q_h, q_w, q_d, k_h, k_w, k_d)
+ rel_h[:, :, :, :, None, None]
+ rel_w[:, :, :, None, :, None]
+ rel_d[:, :, :, None, None, :]
).view(batch, q_h * q_w * q_d, k_h * k_w * k_d)

return attn
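A minimal usage sketch of the two helpers above (not part of the diff; all sizes are made up for illustration): get_rel_pos slices, and if needed interpolates, a per-axis table into a (q_size, k_size, C) lookup, and add_decomposed_rel_pos folds one such table per spatial axis into the attention logits.

import torch
from torch import nn

from monai.networks.blocks.attention_utils import add_decomposed_rel_pos, get_rel_pos

q_h, q_w, head_dim = 4, 4, 8  # hypothetical 4x4 token grid, 8 channels per head
# one (2 * dim_size - 1, head_dim) table per spatial axis
rel_pos_lst = nn.ParameterList(
    [nn.Parameter(torch.zeros(2 * q_h - 1, head_dim)), nn.Parameter(torch.zeros(2 * q_w - 1, head_dim))]
)

rh = get_rel_pos(q_h, q_h, rel_pos_lst[0])   # (4, 4, 8) lookup for the height axis
attn = torch.zeros(2, q_h * q_w, q_h * q_w)  # (B * num_heads, HW, HW) attention logits
q = torch.randn(2, q_h * q_w, head_dim)      # matching queries
attn = add_decomposed_rel_pos(attn, q, rel_pos_lst, (q_h, q_w), (q_h, q_w))
print(attn.shape)  # torch.Size([2, 16, 16])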
56 changes: 56 additions & 0 deletions monai/networks/blocks/rel_pos_embedding.py
@@ -0,0 +1,56 @@
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import Iterable, Tuple

import torch
from torch import nn

from monai.networks.blocks.attention_utils import add_decomposed_rel_pos
from monai.utils.misc import ensure_tuple_size


class DecomposedRelativePosEmbedding(nn.Module):
def __init__(self, s_input_dims: Tuple[int, int] | Tuple[int, int, int], c_dim: int, num_heads: int) -> None:
"""
Args:
s_input_dims (Tuple): input spatial dimension. (H, W) or (H, W, D)
c_dim (int): channel dimension
num_heads(int): number of attention heads
"""
super().__init__()

# validate inputs
if not isinstance(s_input_dims, Iterable) or len(s_input_dims) not in [2, 3]:
raise ValueError("s_input_dims must be set as follows: (H, W) or (H, W, D)")

self.s_input_dims = s_input_dims
self.c_dim = c_dim
self.num_heads = num_heads
self.rel_pos_arr = nn.ParameterList(
[nn.Parameter(torch.zeros(2 * dim_input_size - 1, c_dim)) for dim_input_size in s_input_dims]
)

def forward(self, x: torch.Tensor, att_mat: torch.Tensor, q: torch.Tensor) -> torch.Tensor:
""""""
batch = x.shape[0]
h, w, d = ensure_tuple_size(self.s_input_dims, 3, 1)

att_mat = add_decomposed_rel_pos(
att_mat.contiguous().view(batch * self.num_heads, h * w * d, h * w * d),
q.contiguous().view(batch * self.num_heads, h * w * d, -1),
self.rel_pos_arr,
(h, w) if d == 1 else (h, w, d),
(h, w) if d == 1 else (h, w, d),
)

att_mat = att_mat.reshape(batch, self.num_heads, h * w * d, h * w * d)
return att_mat
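A small sketch of how this module is meant to be called (assumed shapes only; in practice SABlock below wires it up internally). The forward pass reads only the batch size from x, flattens the per-head attention logits and queries, and delegates to add_decomposed_rel_pos.

import torch

from monai.networks.blocks.rel_pos_embedding import DecomposedRelativePosEmbedding

# hypothetical 2D setting: an 8x8 token grid, 16 channels per head, 4 heads
rel_pos = DecomposedRelativePosEmbedding(s_input_dims=(8, 8), c_dim=16, num_heads=4)
x = torch.randn(2, 64, 64)             # only x.shape[0] (the batch size) is read here
att_mat = torch.zeros(2, 4, 64, 64)    # (B, num_heads, N, N) attention logits
q = torch.randn(2, 4, 64, 16)          # per-head queries
out = rel_pos(x, att_mat, q)
print(out.shape)  # torch.Size([2, 4, 64, 64])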
33 changes: 31 additions & 2 deletions monai/networks/blocks/selfattention.py
@@ -11,9 +11,12 @@

from __future__ import annotations

from typing import Optional, Tuple

import torch
import torch.nn as nn

from monai.networks.layers.utils import get_rel_pos_embedding_layer
from monai.utils import optional_import

Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange")
@@ -23,6 +26,7 @@ class SABlock(nn.Module):
"""
A self-attention block, based on: "Dosovitskiy et al.,
An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>"
One can set up relative positional embedding as described in <https://arxiv.org/abs/2112.01526>
"""

def __init__(
@@ -32,13 +36,19 @@ def __init__(
dropout_rate: float = 0.0,
qkv_bias: bool = False,
save_attn: bool = False,
rel_pos_embedding: Optional[str] = None,
input_size: Optional[Tuple] = None,
) -> None:
"""
Args:
hidden_size (int): dimension of hidden layer.
num_heads (int): number of attention heads.
dropout_rate (float, optional): fraction of the input units to drop. Defaults to 0.0.
qkv_bias (bool, optional): bias term for the qkv linear layer. Defaults to False.
rel_pos_embedding (str, optional): Add relative positional embeddings to the attention map.
For now only "decomposed" is supported (see https://arxiv.org/abs/2112.01526); 2D and 3D inputs are supported.
input_size (tuple(spatial_dim), optional): Input resolution for calculating the relative
positional parameter size.
save_attn (bool, optional): to make accessible the attention matrix. Defaults to False.

"""
@@ -62,11 +72,30 @@ def __init__(
self.scale = self.head_dim**-0.5
self.save_attn = save_attn
self.att_mat = torch.Tensor()
self.rel_positional_embedding = (
get_rel_pos_embedding_layer(rel_pos_embedding, input_size, self.head_dim, self.num_heads)
if rel_pos_embedding is not None
else None
)
self.input_size = input_size

def forward(self, x: torch.Tensor):
"""
Args:
x (torch.Tensor): input tensor. B x (s_dim_1 * ... * s_dim_n) x C

Return:
torch.Tensor: B x (s_dim_1 * ... * s_dim_n) x C
"""
output = self.input_rearrange(self.qkv(x))
q, k, v = output[0], output[1], output[2]
att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale

# apply relative positional embedding if defined
att_mat = self.rel_positional_embedding(x, att_mat, q) if self.rel_positional_embedding is not None else att_mat

att_mat = att_mat.softmax(dim=-1)

if self.save_attn:
# no gradients and new tensor;
# https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html
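With the changes above, turning on relative positional embedding from SABlock reduces to the two new constructor arguments. A minimal sketch with assumed sizes (hidden_size must stay divisible by num_heads, and input_size must multiply to the token count):

import torch

from monai.networks.blocks.selfattention import SABlock

# hypothetical 8x8x8 patch grid flattened to 512 tokens
block = SABlock(hidden_size=360, num_heads=4, rel_pos_embedding="decomposed", input_size=(8, 8, 8))
x = torch.randn(2, 512, 360)
out = block(x)
print(out.shape)  # torch.Size([2, 512, 360])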
13 changes: 12 additions & 1 deletion monai/networks/layers/factories.py
@@ -70,7 +70,7 @@ def use_factory(fact_args):
from monai.networks.utils import has_nvfuser_instance_norm
from monai.utils import ComponentStore, look_up_option, optional_import

__all__ = ["LayerFactory", "Dropout", "Norm", "Act", "Conv", "Pool", "Pad", "split_args"]
__all__ = ["LayerFactory", "Dropout", "Norm", "Act", "Conv", "Pool", "Pad", "RelPosEmbedding", "split_args"]


class LayerFactory(ComponentStore):
@@ -201,6 +201,10 @@ def split_args(args):
Conv = LayerFactory(name="Convolution layers", description="Factory for creating convolution layers.")
Pool = LayerFactory(name="Pooling layers", description="Factory for creating pooling layers.")
Pad = LayerFactory(name="Padding layers", description="Factory for creating padding layers.")
RelPosEmbedding = LayerFactory(
name="Relative positional embedding layers",
description="Factory for creating relative positional embedding factory",
)


@Dropout.factory_function("dropout")
@@ -468,3 +472,10 @@ def constant_pad_factory(dim: int) -> type[nn.ConstantPad1d | nn.ConstantPad2d |
"""
types = (nn.ConstantPad1d, nn.ConstantPad2d, nn.ConstantPad3d)
return types[dim - 1]


@RelPosEmbedding.factory_function("decomposed")
def decomposed_rel_pos_embedding() -> type[nn.Module]:
from monai.networks.blocks.rel_pos_embedding import DecomposedRelativePosEmbedding

return DecomposedRelativePosEmbedding
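The new registry can also be queried directly, mirroring the other LayerFactory lookups in this file; a brief sketch with assumed sizes:

from monai.networks.layers.factories import RelPosEmbedding

embedding_cls = RelPosEmbedding["decomposed"]  # resolves to DecomposedRelativePosEmbedding
layer = embedding_cls(s_input_dims=(16, 16), c_dim=32, num_heads=4)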
15 changes: 14 additions & 1 deletion monai/networks/layers/utils.py
@@ -11,9 +11,11 @@

from __future__ import annotations

from typing import Optional

import torch.nn

from monai.networks.layers.factories import Act, Dropout, Norm, Pool, RelPosEmbedding, split_args
from monai.utils import has_option

__all__ = ["get_norm_layer", "get_act_layer", "get_dropout_layer", "get_pool_layer"]
@@ -124,3 +126,14 @@ def get_pool_layer(name: tuple | str, spatial_dims: int | None = 1):
pool_name, pool_args = split_args(name)
pool_type = Pool[pool_name, spatial_dims]
return pool_type(**pool_args)


def get_rel_pos_embedding_layer(name: tuple | str, s_input_dims: Optional[tuple], c_dim: int, num_heads: int):
embedding_name, embedding_args = split_args(name)
embedding_type = RelPosEmbedding[embedding_name]
# create a dictionary with the default values which can be overridden by embedding_args
kw_args = {"s_input_dims": s_input_dims, "c_dim": c_dim, "num_heads": num_heads, **embedding_args}
# filter out unused argument names
kw_args = {k: v for k, v in kw_args.items() if has_option(embedding_type, k)}

return embedding_type(**kw_args)
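The helper keeps the shape of the other get_*_layer utilities; a quick sketch with assumed sizes, equivalent to the call SABlock.__init__ now makes with name=rel_pos_embedding and c_dim=head_dim:

from monai.networks.layers.utils import get_rel_pos_embedding_layer

rel_pos = get_rel_pos_embedding_layer("decomposed", s_input_dims=(8, 8), c_dim=64, num_heads=4)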
21 changes: 15 additions & 6 deletions tests/test_selfattention.py
@@ -20,6 +20,7 @@

from monai.networks import eval_mode
from monai.networks.blocks.selfattention import SABlock
from monai.networks.layers.factories import RelPosEmbedding
from monai.utils import optional_import

einops, has_einops = optional_import("einops")
@@ -28,12 +29,20 @@
for dropout_rate in np.linspace(0, 1, 4):
for hidden_size in [360, 480, 600, 768]:
for num_heads in [4, 6, 8, 12]:
for rel_pos_embedding in [None, RelPosEmbedding.DECOMPOSED]:
for input_size in [(16, 32), (8, 8, 8)]:
test_case = [
{
"hidden_size": hidden_size,
"num_heads": num_heads,
"dropout_rate": dropout_rate,
"rel_pos_embedding": rel_pos_embedding,
"input_size": input_size,
},
(2, 512, hidden_size),
(2, 512, hidden_size),
]
TEST_CASE_SABLOCK.append(test_case)


class TestResBlock(unittest.TestCase):