6676 port generative networks transformer (Project-MONAI#7300)
Towards Project-MONAI#6676.

### Description

Adds a simple decoder-only transformer architecture.
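
As a rough usage sketch (the hyperparameter values below are illustrative only, not taken from this PR), the new network can be used roughly as follows:

```python
import torch

from monai.networks.nets import DecoderOnlyTransformer

# Illustrative hyperparameters (not from this PR); attn_layers_dim must be
# divisible by attn_layers_heads.
model = DecoderOnlyTransformer(
    num_tokens=100,
    max_seq_len=16,
    attn_layers_dim=128,
    attn_layers_depth=2,
    attn_layers_heads=4,
)

# Input is a (batch, seq_len) tensor of token ids; output is a
# (batch, seq_len, num_tokens) tensor of logits for next-token prediction.
tokens = torch.randint(0, 100, (2, 16))
logits = model(tokens)
print(logits.shape)  # torch.Size([2, 16, 100])
```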

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [x] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [x] In-line docstrings updated.
- [x] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: Mark Graham <[email protected]>
marksgraham committed Jan 30, 2024
1 parent 20a57e7 commit 5e2466d
Showing 4 changed files with 393 additions and 0 deletions.
5 changes: 5 additions & 0 deletions docs/source/networks.rst
@@ -613,6 +613,11 @@ Nets
.. autoclass:: VarAutoEncoder
:members:

`DecoderOnlyTransformer`
~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: DecoderOnlyTransformer
:members:

`ViT`
~~~~~
.. autoclass:: ViT
1 change: 1 addition & 0 deletions monai/networks/nets/__init__.py
@@ -106,6 +106,7 @@
from .swin_unetr import PatchMerging, PatchMergingV2, SwinUNETR
from .torchvision_fc import TorchVisionFCModel
from .transchex import BertAttention, BertMixedLayer, BertOutput, BertPreTrainedModel, MultiModal, Pooler, Transchex
from .transformer import DecoderOnlyTransformer
from .unet import UNet, Unet
from .unetr import UNETR
from .varautoencoder import VarAutoEncoder
314 changes: 314 additions & 0 deletions monai/networks/nets/transformer.py
@@ -0,0 +1,314 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from monai.networks.blocks.mlp import MLPBlock
from monai.utils import optional_import

xops, has_xformers = optional_import("xformers.ops")
__all__ = ["DecoderOnlyTransformer"]


class _SABlock(nn.Module):
"""
NOTE This is a private block that we plan to merge with existing MONAI blocks in the future. Please do not make
use of this block as support is not guaranteed. For more information see:
https://github.com/Project-MONAI/MONAI/issues/7227
A self-attention block, based on: "Dosovitskiy et al.,
An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>"
Args:
hidden_size: dimension of hidden layer.
num_heads: number of attention heads.
dropout_rate: dropout ratio. Defaults to no dropout.
qkv_bias: bias term for the qkv linear layer.
causal: whether to use causal attention.
sequence_length: if causal is True, it is necessary to specify the sequence length.
with_cross_attention: Whether to use cross attention for conditioning.
use_flash_attention: if True, use flash attention for a memory efficient attention mechanism.
"""

def __init__(
self,
hidden_size: int,
num_heads: int,
dropout_rate: float = 0.0,
qkv_bias: bool = False,
causal: bool = False,
sequence_length: int | None = None,
with_cross_attention: bool = False,
use_flash_attention: bool = False,
) -> None:
super().__init__()
self.hidden_size = hidden_size
self.num_heads = num_heads
self.head_dim = hidden_size // num_heads
self.scale = 1.0 / math.sqrt(self.head_dim)
self.causal = causal
self.sequence_length = sequence_length
self.with_cross_attention = with_cross_attention
self.use_flash_attention = use_flash_attention

if not (0 <= dropout_rate <= 1):
raise ValueError("dropout_rate should be between 0 and 1.")
self.dropout_rate = dropout_rate

if hidden_size % num_heads != 0:
raise ValueError("hidden size should be divisible by num_heads.")

if causal and sequence_length is None:
raise ValueError("sequence_length is necessary for causal attention.")

if use_flash_attention and not has_xformers:
raise ValueError("use_flash_attention is True but xformers is not installed.")

# key, query, value projections
self.to_q = nn.Linear(hidden_size, hidden_size, bias=qkv_bias)
self.to_k = nn.Linear(hidden_size, hidden_size, bias=qkv_bias)
self.to_v = nn.Linear(hidden_size, hidden_size, bias=qkv_bias)

# regularization
self.drop_weights = nn.Dropout(dropout_rate)
self.drop_output = nn.Dropout(dropout_rate)

# output projection
self.out_proj = nn.Linear(hidden_size, hidden_size)

if causal and sequence_length is not None:
# causal mask to ensure that attention is only applied to the left in the input sequence
self.register_buffer(
"causal_mask",
torch.tril(torch.ones(sequence_length, sequence_length)).view(1, 1, sequence_length, sequence_length),
)
self.causal_mask: torch.Tensor

def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch.Tensor:
b, t, c = x.size() # batch size, sequence length, embedding dimensionality (hidden_size)

# calculate query, key, values for all heads in batch and move head forward to be the batch dim
query = self.to_q(x)

kv = context if context is not None else x
_, kv_t, _ = kv.size()
key = self.to_k(kv)
value = self.to_v(kv)

query = query.view(b, t, self.num_heads, c // self.num_heads) # (b, t, nh, hs)
key = key.view(b, kv_t, self.num_heads, c // self.num_heads) # (b, kv_t, nh, hs)
value = value.view(b, kv_t, self.num_heads, c // self.num_heads) # (b, kv_t, nh, hs)
y: torch.Tensor
if self.use_flash_attention:
query = query.contiguous()
key = key.contiguous()
value = value.contiguous()
y = xops.memory_efficient_attention(
query=query,
key=key,
value=value,
scale=self.scale,
p=self.dropout_rate,
attn_bias=xops.LowerTriangularMask() if self.causal else None,
)

else:
query = query.transpose(1, 2) # (b, nh, t, hs)
key = key.transpose(1, 2) # (b, nh, kv_t, hs)
value = value.transpose(1, 2) # (b, nh, kv_t, hs)

# manual implementation of attention
query = query * self.scale
attention_scores = query @ key.transpose(-2, -1)

if self.causal:
attention_scores = attention_scores.masked_fill(self.causal_mask[:, :, :t, :kv_t] == 0, float("-inf"))

attention_probs = F.softmax(attention_scores, dim=-1)
attention_probs = self.drop_weights(attention_probs)
y = attention_probs @ value # (b, nh, t, kv_t) x (b, nh, kv_t, hs) -> (b, nh, t, hs)

y = y.transpose(1, 2) # (b, nh, t, hs) -> (b, t, nh, hs)

y = y.contiguous().view(b, t, c) # re-assemble all head outputs side by side

y = self.out_proj(y)
y = self.drop_output(y)
return y


class _TransformerBlock(nn.Module):
"""
NOTE This is a private block that we plan to merge with existing MONAI blocks in the future. Please do not make
use of this block as support is not guaranteed. For more information see:
https://github.com/Project-MONAI/MONAI/issues/7227
A transformer block, based on: "Dosovitskiy et al.,
An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>"
Args:
hidden_size: dimension of hidden layer.
mlp_dim: dimension of feedforward layer.
num_heads: number of attention heads.
dropout_rate: fraction of the input units to drop.
qkv_bias: apply bias term for the qkv linear layer.
causal: whether to use causal attention.
sequence_length: if causal is True, it is necessary to specify the sequence length.
with_cross_attention: Whether to use cross attention for conditioning.
use_flash_attention: if True, use flash attention for a memory efficient attention mechanism.
"""

def __init__(
self,
hidden_size: int,
mlp_dim: int,
num_heads: int,
dropout_rate: float = 0.0,
qkv_bias: bool = False,
causal: bool = False,
sequence_length: int | None = None,
with_cross_attention: bool = False,
use_flash_attention: bool = False,
) -> None:
super().__init__()
self.with_cross_attention = with_cross_attention

if not (0 <= dropout_rate <= 1):
raise ValueError("dropout_rate should be between 0 and 1.")

if hidden_size % num_heads != 0:
raise ValueError("hidden_size should be divisible by num_heads.")

self.norm1 = nn.LayerNorm(hidden_size)
self.attn = _SABlock(
hidden_size=hidden_size,
num_heads=num_heads,
dropout_rate=dropout_rate,
qkv_bias=qkv_bias,
causal=causal,
sequence_length=sequence_length,
use_flash_attention=use_flash_attention,
)

if self.with_cross_attention:
self.norm2 = nn.LayerNorm(hidden_size)
self.cross_attn = _SABlock(
hidden_size=hidden_size,
num_heads=num_heads,
dropout_rate=dropout_rate,
qkv_bias=qkv_bias,
with_cross_attention=with_cross_attention,
causal=False,
use_flash_attention=use_flash_attention,
)
self.norm3 = nn.LayerNorm(hidden_size)
self.mlp = MLPBlock(hidden_size, mlp_dim, dropout_rate)

def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch.Tensor:
x = x + self.attn(self.norm1(x))
if self.with_cross_attention:
x = x + self.cross_attn(self.norm2(x), context=context)
x = x + self.mlp(self.norm3(x))
return x


class AbsolutePositionalEmbedding(nn.Module):
"""Absolute positional embedding.
Args:
max_seq_len: Maximum sequence length.
embedding_dim: Dimensionality of the embedding.
"""

def __init__(self, max_seq_len: int, embedding_dim: int) -> None:
super().__init__()
self.max_seq_len = max_seq_len
self.embedding_dim = embedding_dim
self.embedding = nn.Embedding(max_seq_len, embedding_dim)

def forward(self, x: torch.Tensor) -> torch.Tensor:
batch_size, seq_len = x.size()
positions = torch.arange(seq_len, device=x.device).repeat(batch_size, 1)
embedding: torch.Tensor = self.embedding(positions)
return embedding


class DecoderOnlyTransformer(nn.Module):
"""Decoder-only (Autoregressive) Transformer model.
Args:
num_tokens: Number of tokens in the vocabulary.
max_seq_len: Maximum sequence length.
attn_layers_dim: Dimensionality of the attention layers.
attn_layers_depth: Number of attention layers.
attn_layers_heads: Number of attention heads.
with_cross_attention: Whether to use cross attention for conditioning.
embedding_dropout_rate: Dropout rate for the embedding.
use_flash_attention: if True, use flash attention for a memory efficient attention mechanism.
"""

def __init__(
self,
num_tokens: int,
max_seq_len: int,
attn_layers_dim: int,
attn_layers_depth: int,
attn_layers_heads: int,
with_cross_attention: bool = False,
embedding_dropout_rate: float = 0.0,
use_flash_attention: bool = False,
) -> None:
super().__init__()
self.num_tokens = num_tokens
self.max_seq_len = max_seq_len
self.attn_layers_dim = attn_layers_dim
self.attn_layers_depth = attn_layers_depth
self.attn_layers_heads = attn_layers_heads
self.with_cross_attention = with_cross_attention

self.token_embeddings = nn.Embedding(num_tokens, attn_layers_dim)
self.position_embeddings = AbsolutePositionalEmbedding(max_seq_len=max_seq_len, embedding_dim=attn_layers_dim)
self.embedding_dropout = nn.Dropout(embedding_dropout_rate)

self.blocks = nn.ModuleList(
[
_TransformerBlock(
hidden_size=attn_layers_dim,
mlp_dim=attn_layers_dim * 4,
num_heads=attn_layers_heads,
dropout_rate=0.0,
qkv_bias=False,
causal=True,
sequence_length=max_seq_len,
with_cross_attention=with_cross_attention,
use_flash_attention=use_flash_attention,
)
for _ in range(attn_layers_depth)
]
)

self.to_logits = nn.Linear(attn_layers_dim, num_tokens)

def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch.Tensor:
tok_emb = self.token_embeddings(x)
pos_emb = self.position_embeddings(x)
x = self.embedding_dropout(tok_emb + pos_emb)

for block in self.blocks:
x = block(x, context=context)
logits: torch.Tensor = self.to_logits(x)
return logits
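
For the conditioned variant, a minimal sketch of how `with_cross_attention=True` and the `context` argument fit together; the shapes and values here are illustrative assumptions, not from this commit:

```python
import torch

from monai.networks.nets import DecoderOnlyTransformer

# with_cross_attention=True adds a cross-attention block per layer that
# attends from the token sequence to the conditioning context.
model = DecoderOnlyTransformer(
    num_tokens=100,
    max_seq_len=16,
    attn_layers_dim=128,
    attn_layers_depth=2,
    attn_layers_heads=4,
    with_cross_attention=True,
)

tokens = torch.randint(0, 100, (2, 16))  # (batch, seq_len) token ids
context = torch.randn(2, 10, 128)        # (batch, context_len, attn_layers_dim)
logits = model(tokens, context=context)  # (batch, seq_len, num_tokens)
```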