Genai SA block integration #7720

Closed
78 changes: 64 additions & 14 deletions monai/networks/blocks/selfattention.py
@@ -12,13 +12,15 @@
from __future__ import annotations

from typing import Optional, Tuple
import warnings

import torch
import torch.nn as nn

from monai.networks.layers.utils import get_rel_pos_embedding_layer
from monai.utils import optional_import

xops, has_xformers = optional_import("xformers.ops")
Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange")


@@ -38,6 +40,9 @@ def __init__(
save_attn: bool = False,
rel_pos_embedding: Optional[str] = None,
input_size: Optional[Tuple] = None,
causal: bool = False,
sequence_length: int | None = None,
use_flash_attention: bool = False,
) -> None:
"""
Args:
@@ -49,6 +54,8 @@
For now only "decomposed" is supported (see https://arxiv.org/abs/2112.01526). 2D and 3D are supported.
input_size (tuple(spatial_dim), optional): Input resolution for calculating the relative
positional parameter size.
causal (bool): whether to use causal attention. If True, `sequence_length` has to be set.
sequence_length (int, optional): if causal is True, the sequence length must be specified.
save_attn (bool, optional): to make accessible the attention matrix. Defaults to False.

"""
@@ -61,15 +68,32 @@ def __init__(
if hidden_size % num_heads != 0:
raise ValueError("hidden size should be divisible by num_heads.")

if causal and sequence_length is None:
raise ValueError("sequence_length is necessary for causal attention.")

if use_flash_attention and rel_pos_embedding is not None:
self.use_flash_attention = False
warnings.warn(
"flash attention set to `False`: flash attention can't be used with relative position embedding. Set `rel_pos_embedding` to `None` to use flash attention"
)
else:
self.use_flash_attention = use_flash_attention

if use_flash_attention and not has_xformers:
raise ValueError("use_flash_attention is True but xformers is not installed.")

self.num_heads = num_heads
self.out_proj = nn.Linear(hidden_size, hidden_size)
self.qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias)
self.input_rearrange = Rearrange("b h (qkv l d) -> qkv b l h d", qkv=3, l=num_heads)
self.out_rearrange = Rearrange("b h l d -> b l (h d)")
self.dropout_rate = dropout_rate
self.drop_output = nn.Dropout(dropout_rate)
self.drop_weights = nn.Dropout(dropout_rate)
self.head_dim = hidden_size // num_heads
self.scale = self.head_dim**-0.5
self.causal = causal
self.sequence_length = sequence_length
self.save_attn = save_attn
self.att_mat = torch.Tensor()
self.rel_positional_embedding = (
@@ -79,6 +103,14 @@
)
self.input_size = input_size

if causal and sequence_length is not None:
# causal mask to ensure that attention is only applied to the left in the input sequence
self.register_buffer(
"causal_mask",
torch.tril(torch.ones(sequence_length, sequence_length)).view(1, 1, sequence_length, sequence_length),
)
self.causal_mask: torch.Tensor
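# e.g. with sequence_length = 4 the registered causal_mask is the lower-triangular matrix
# [[1, 0, 0, 0],
#  [1, 1, 0, 0],
#  [1, 1, 1, 0],
#  [1, 1, 1, 1]] (broadcast to 1 x 1 x 4 x 4); zeros mark future positions, which forward()
# fills with -inf before the softmax so they receive zero attention weight.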

def forward(self, x: torch.Tensor):
"""
Args:
@@ -87,22 +119,40 @@ def forward(self, x: torch.Tensor):
Return:
torch.Tensor: B x (s_dim_1 * ... * s_dim_n) x C
"""
output = self.input_rearrange(self.qkv(x))
_, t, _ = x.size()
output = self.input_rearrange(self.qkv(x)) # 3 x B x (s_dim_1 * ... * s_dim_n) x h x C/h
q, k, v = output[0], output[1], output[2]
att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale

# apply relative positional embedding if defined
att_mat = self.rel_positional_embedding(x, att_mat, q) if self.rel_positional_embedding is not None else att_mat

att_mat = att_mat.softmax(dim=-1)

if self.save_attn:
# no gradients and new tensor;
# https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html
self.att_mat = att_mat.detach()

att_mat = self.drop_weights(att_mat)
x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v)
if self.use_flash_attention:
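# xformers' memory_efficient_attention computes softmax(q @ k^T * scale) @ v (with attention
# dropout p) in a fused kernel, so the full attention matrix is never materialised and
# `save_attn` has no effect on this path; LowerTriangularMask reproduces the causal masking
# applied explicitly in the branch below.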
x = xops.memory_efficient_attention(
query=q.contiguous(),
key=k.contiguous(),
value=v.contiguous(),
scale=self.scale,
p=self.dropout_rate,
attn_bias=xops.LowerTriangularMask() if self.causal else None,
)
else:
att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale

# apply relative positional embedding if defined
att_mat = (
self.rel_positional_embedding(x, att_mat, q) if self.rel_positional_embedding is not None else att_mat
)
# apply causal mask if set
att_mat = (
att_mat.masked_fill(self.causal_mask[:, :, :t, :t] == 0, float("-inf")) if self.causal else att_mat
)

att_mat = att_mat.softmax(dim=-1)

if self.save_attn:
# no gradients and new tensor;
# https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html
self.att_mat = att_mat.detach()

att_mat = self.drop_weights(att_mat)
x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v)
x = self.out_rearrange(x)
x = self.out_proj(x)
x = self.drop_output(x)
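For reference, a minimal usage sketch of the options introduced above (a sketch only, assuming this PR's branch and xformers are installed; the shapes and hyperparameters are illustrative, not taken from the PR): one block uses causal masking, the other the flash attention path.

import torch

from monai.networks.blocks.selfattention import SABlock

# causal self-attention: sequence_length must cover the token dimension of the input
causal_block = SABlock(hidden_size=360, num_heads=4, dropout_rate=0.0, causal=True, sequence_length=512)
x = torch.randn(2, 512, 360)  # B x (s_dim_1 * ... * s_dim_n) x C
print(causal_block(x).shape)  # torch.Size([2, 512, 360]), same shape as the input

# flash attention path (requires xformers); cannot be combined with rel_pos_embedding
flash_block = SABlock(hidden_size=360, num_heads=4, dropout_rate=0.0, use_flash_attention=True)
print(flash_block(x).shape)  # torch.Size([2, 512, 360])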
56 changes: 44 additions & 12 deletions tests/test_selfattention.py
@@ -24,25 +24,29 @@
from monai.utils import optional_import

einops, has_einops = optional_import("einops")
xops, has_xformers = optional_import("xformers.ops")

TEST_CASE_SABLOCK = []
for dropout_rate in np.linspace(0, 1, 4):
for hidden_size in [360, 480, 600, 768]:
for num_heads in [4, 6, 8, 12]:
for rel_pos_embedding in [None, RelPosEmbedding.DECOMPOSED]:
for input_size in [(16, 32), (8, 8, 8)]:
test_case = [
{
"hidden_size": hidden_size,
"num_heads": num_heads,
"dropout_rate": dropout_rate,
"rel_pos_embedding": rel_pos_embedding,
"input_size": input_size,
},
(2, 512, hidden_size),
(2, 512, hidden_size),
]
TEST_CASE_SABLOCK.append(test_case)
for causal in [False, True]:
test_case = [
{
"hidden_size": hidden_size,
"num_heads": num_heads,
"dropout_rate": dropout_rate,
"rel_pos_embedding": rel_pos_embedding,
"input_size": input_size,
"causal": causal,
"sequence_length": 512,
},
(2, 512, hidden_size),
(2, 512, hidden_size),
]
TEST_CASE_SABLOCK.append(test_case)


class TestResBlock(unittest.TestCase):
@@ -54,6 +58,34 @@ def test_shape(self, input_param, input_shape, expected_shape):
result = net(torch.randn(input_shape))
self.assertEqual(result.shape, expected_shape)

@skipUnless(has_xformers, "Requires xformers")
def test_flash_attention(self):
hidden_size = 360
num_heads = 4
dropout_rate = 0
input_shape = (2, 512, hidden_size)
expected_shape = (2, 512, hidden_size)
flash_attention_block = SABlock(hidden_size, num_heads, dropout_rate, use_flash_attention=True)
# use_flash_attention is expected to fall back to False because a relative position embedding is set at the same time
no_flash_attention_block = SABlock(
hidden_size,
num_heads,
dropout_rate,
use_flash_attention=True,
rel_pos_embedding=RelPosEmbedding.DECOMPOSED,
sequence_length=512,
input_size=(16, 32),
)

self.assertFalse(no_flash_attention_block.use_flash_attention)

with eval_mode(flash_attention_block):
result = flash_attention_block(torch.randn(input_shape))
self.assertEqual(result.shape, expected_shape)
with eval_mode(no_flash_attention_block):
result = no_flash_attention_block(torch.randn(input_shape))
self.assertEqual(result.shape, expected_shape)

def test_ill_arg(self):
with self.assertRaises(ValueError):
SABlock(hidden_size=128, num_heads=12, dropout_rate=6.0)