[Model] Remove transformers attention porting in VITs (vllm-project#10414)

Signed-off-by: Isotr0py <[email protected]>
Isotr0py authored Nov 18, 2024
1 parent 5be4e52 commit e7ebb66
Showing 7 changed files with 139 additions and 102 deletions.
66 changes: 36 additions & 30 deletions vllm/model_executor/models/blip.py
@@ -4,10 +4,11 @@

import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from transformers import Blip2VisionConfig, BlipVisionConfig
from transformers.models.blip.modeling_blip import BlipAttention

from vllm.attention.selector import _Backend
from vllm.config import ModelConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.inputs import DecoderOnlyInputs, token_inputs
@@ -21,11 +22,7 @@
repeat_and_pad_placeholder_tokens)
from vllm.sequence import SequenceData

try:
from xformers import ops as xops
USE_XFORMERS_OPS = True
except ImportError:
USE_XFORMERS_OPS = False
from .utils import get_vit_attn_backend


def get_blip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
@@ -168,7 +165,7 @@ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
return embeddings


class BlipParallelAttention(nn.Module):
class BlipAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""

def __init__(
@@ -208,6 +205,12 @@ def __init__(
self.tp_size = get_tensor_model_parallel_world_size()
self.num_heads_per_partition = divide(self.num_heads, self.tp_size)

# Detect attention implementation.
self.attn_backend = get_vit_attn_backend(support_fa=False)
if self.attn_backend not in {_Backend.TORCH_SDPA, _Backend.XFORMERS}:
raise RuntimeError(
f"BLIP does not support {self.attn_backend} backend now.")

def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads,
self.head_dim).transpose(1, 2).contiguous()
@@ -231,11 +234,26 @@ def forward(
self.num_heads_per_partition,
self.head_dim)

out = xops.memory_efficient_attention_forward(query_states,
key_states,
value_states,
p=self.dropout,
scale=self.scale)
if self.attn_backend == _Backend.XFORMERS:
from xformers import ops as xops

out = xops.memory_efficient_attention_forward(query_states,
key_states,
value_states,
p=self.dropout,
scale=self.scale)
elif self.attn_backend == _Backend.TORCH_SDPA:
query_states, key_states, value_states = (x.transpose(1, 2)
for x in (query_states,
key_states,
value_states))
out = F.scaled_dot_product_attention(query_states,
key_states,
value_states,
dropout_p=self.dropout,
scale=self.scale)
out = out.transpose(1, 2)

out = out.view(bsz, tgt_len, -1)
attn_output, _ = self.projection(out)

@@ -285,18 +303,11 @@ def __init__(
super().__init__()

# fallback to sdpa attention if tp unavailable
num_heads = config.num_attention_heads
tp_size = get_tensor_model_parallel_world_size()
if USE_XFORMERS_OPS and num_heads % tp_size == 0:
self.self_attn = BlipParallelAttention(
config,
quant_config=quant_config,
prefix=f"{prefix}.self_attn",
)
else:
# Blip doesn't have SDPA attention implemented in transformers
# use eager attention instead for cpu backend
self.self_attn = BlipAttention(config)
self.self_attn = BlipAttention(
config,
quant_config=quant_config,
prefix=f"{prefix}.self_attn",
)
self.layer_norm1 = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
self.mlp = BlipMLP(config,
@@ -374,11 +385,6 @@ def __init__(
prefix: str = "",
) -> None:
super().__init__()

tp_size = get_tensor_model_parallel_world_size()
num_heads = config.num_attention_heads
self.shard_weight = USE_XFORMERS_OPS and num_heads % tp_size == 0

self.config = config

self.embeddings = BlipVisionEmbeddings(config)
@@ -422,7 +428,7 @@ def load_weights(self, weights: Iterable[Tuple[str,
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
] if self.shard_weight else []
]
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
layer_count = len(self.encoder.layers)
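The pattern above — detect a backend once in __init__ via get_vit_attn_backend(support_fa=False), then branch per forward call — is repeated in clip.py and intern_vit.py below. The following is a minimal, self-contained sketch of that dispatch, not vLLM's code: the Backend enum and vit_attention helper are stand-ins introduced here for illustration (vLLM uses its own _Backend), and the scale kwarg of F.scaled_dot_product_attention requires PyTorch >= 2.1.

from enum import Enum, auto

import torch
import torch.nn.functional as F


class Backend(Enum):
    # Simplified stand-in for vllm.attention.selector._Backend.
    TORCH_SDPA = auto()
    XFORMERS = auto()


def vit_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                  scale: float, backend: Backend,
                  dropout: float = 0.0) -> torch.Tensor:
    """Dispatch between xFormers and torch SDPA for (B, S, H, D) tensors."""
    if backend == Backend.XFORMERS:
        # Imported lazily, as in the patched models, so xformers stays optional.
        from xformers import ops as xops
        out = xops.memory_efficient_attention_forward(q, k, v,
                                                      p=dropout, scale=scale)
    elif backend == Backend.TORCH_SDPA:
        # SDPA expects (B, H, S, D), so transpose in and back out.
        q, k, v = (x.transpose(1, 2) for x in (q, k, v))
        out = F.scaled_dot_product_attention(q, k, v,
                                             dropout_p=dropout, scale=scale)
        out = out.transpose(1, 2)
    else:
        raise RuntimeError(f"Unsupported ViT attention backend: {backend}")
    return out


if __name__ == "__main__":
    q = k = v = torch.randn(1, 16, 8, 64)
    out = vit_attention(q, k, v, scale=64 ** -0.5, backend=Backend.TORCH_SDPA)
    print(out.shape)  # torch.Size([1, 16, 8, 64])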
65 changes: 36 additions & 29 deletions vllm/model_executor/models/clip.py
@@ -5,10 +5,11 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from transformers import CLIPVisionConfig
from transformers.models.clip.modeling_clip import CLIPSdpaAttention

from vllm.attention.selector import _Backend
from vllm.config import ModelConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.inputs import DecoderOnlyInputs, token_inputs
@@ -23,11 +24,7 @@
repeat_and_pad_placeholder_tokens)
from vllm.sequence import SequenceData

try:
from xformers import ops as xops
USE_XFORMERS_OPS = True
except ImportError:
USE_XFORMERS_OPS = False
from .utils import get_vit_attn_backend


def get_clip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
@@ -197,7 +194,7 @@ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
return embeddings


class CLIPParallelAttention(nn.Module):
class CLIPAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""

def __init__(
@@ -237,6 +234,12 @@ def __init__(
self.tp_size = get_tensor_model_parallel_world_size()
self.num_heads_per_partition = divide(self.num_heads, self.tp_size)

# Detect attention implementation.
self.attn_backend = get_vit_attn_backend(support_fa=False)
if self.attn_backend not in {_Backend.TORCH_SDPA, _Backend.XFORMERS}:
raise RuntimeError(
f"CLIP does not support {self.attn_backend} backend now.")

def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads,
self.head_dim).transpose(1, 2).contiguous()
@@ -261,11 +264,26 @@ def forward(
self.num_heads_per_partition,
self.head_dim)

out = xops.memory_efficient_attention_forward(query_states,
key_states,
value_states,
p=self.dropout,
scale=self.scale)
if self.attn_backend == _Backend.XFORMERS:
from xformers import ops as xops

out = xops.memory_efficient_attention_forward(query_states,
key_states,
value_states,
p=self.dropout,
scale=self.scale)
elif self.attn_backend == _Backend.TORCH_SDPA:
query_states, key_states, value_states = (x.transpose(1, 2)
for x in (query_states,
key_states,
value_states))
out = F.scaled_dot_product_attention(query_states,
key_states,
value_states,
dropout_p=self.dropout,
scale=self.scale)
out = out.transpose(1, 2)

out = out.view(bsz, tgt_len, -1)
attn_output, _ = self.out_proj(out)

@@ -311,17 +329,11 @@ def __init__(
prefix: str = "",
) -> None:
super().__init__()

num_heads = config.num_attention_heads
tp_size = get_tensor_model_parallel_world_size()
if USE_XFORMERS_OPS and num_heads % tp_size == 0:
self.self_attn = CLIPParallelAttention(
config,
quant_config=quant_config,
prefix=f"{prefix}.self_attn",
)
else:
self.self_attn = CLIPSdpaAttention(config)
self.self_attn = CLIPAttention(
config,
quant_config=quant_config,
prefix=f"{prefix}.self_attn",
)
self.layer_norm1 = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
self.mlp = CLIPMLP(config,
@@ -461,11 +473,6 @@ def __init__(
prefix: str = "",
) -> None:
super().__init__()

tp_size = get_tensor_model_parallel_world_size()
num_heads = config.num_attention_heads
self.shard_weight = USE_XFORMERS_OPS and num_heads % tp_size == 0

self.vision_model = CLIPVisionTransformer(
config=config,
quant_config=quant_config,
@@ -490,7 +497,7 @@ def load_weights(self, weights: Iterable[Tuple[str,
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
] if self.shard_weight else []
]
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
layer_count = len(self.vision_model.encoder.layers)
32 changes: 22 additions & 10 deletions vllm/model_executor/models/intern_vit.py
@@ -12,6 +12,7 @@
import torch.nn.functional as F
from transformers import PretrainedConfig

from vllm.attention.selector import _Backend
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
split_tensor_along_last_dim,
@@ -24,11 +25,7 @@
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader

try:
from xformers import ops as xops
USE_XFORMERS_OPS = True
except ImportError:
USE_XFORMERS_OPS = False
from .utils import get_vit_attn_backend

NORM2FN = {
'rms_norm': RMSNorm,
@@ -186,6 +183,11 @@ def __init__(
prefix=f"{prefix}.proj",
)

self.attn_backend = get_vit_attn_backend(support_fa=False)
if self.attn_backend not in {_Backend.TORCH_SDPA, _Backend.XFORMERS}:
raise RuntimeError(
f"InternViT does not support {self.attn_backend} backend now.")

def _apply_qk_norm(self, q: torch.Tensor, k: torch.Tensor):
if self.tp_size > 1:
q = tensor_model_parallel_all_gather(q.contiguous())
@@ -211,11 +213,21 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
k = k.view(B, N, self.num_heads_per_partition, self.head_dim)
v = v.view(B, N, self.num_heads_per_partition, self.head_dim)

x = xops.memory_efficient_attention_forward(q, k, v, scale=self.scale)
x = x.view(B, N, -1)
if self.attn_backend == _Backend.XFORMERS:
from xformers import ops as xops

x, _ = self.proj(x)
return x
out = xops.memory_efficient_attention_forward(q,
k,
v,
scale=self.scale)
elif self.attn_backend == _Backend.TORCH_SDPA:
q, k, v = (x.transpose(1, 2) for x in (q, k, v))
out = F.scaled_dot_product_attention(q, k, v, scale=self.scale)
out = out.transpose(1, 2)

out = out.view(B, N, -1)
out, _ = self.proj(out)
return out


class InternSdpaAttention(nn.Module):
@@ -362,7 +374,7 @@ def _init_attn(
tp_size = get_tensor_model_parallel_world_size()
num_heads = config.num_attention_heads

if USE_XFORMERS_OPS and (num_heads + num_dummy_heads) % tp_size == 0:
if (num_heads + num_dummy_heads) % tp_size == 0:
return InternParallelAttention(config,
quant_config=quant_config,
num_dummy_heads=num_dummy_heads,
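In _init_attn above, the tensor-parallel attention is now chosen purely by head divisibility (the xformers-availability check is gone from the condition). A tiny sketch of that check, with illustrative numbers rather than any specific model's config:

def can_shard_heads(num_heads: int, num_dummy_heads: int, tp_size: int) -> bool:
    # Heads (including dummy padding heads) must split evenly across TP ranks;
    # otherwise InternVisionModel falls back to the single-rank SDPA attention.
    return (num_heads + num_dummy_heads) % tp_size == 0


assert can_shard_heads(25, 7, 8)        # 32 heads -> 4 per rank
assert not can_shard_heads(25, 0, 8)    # 25 heads cannot split across 8 ranks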
2 changes: 1 addition & 1 deletion vllm/model_executor/models/molmo.py
@@ -187,7 +187,7 @@ def __init__(
)

# Detect attention implementation.
self.attn_backend: _Backend = get_vit_attn_backend()
self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
if self.attn_backend not in {
_Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS
}:
2 changes: 1 addition & 1 deletion vllm/model_executor/models/qwen2_vl.py
@@ -260,7 +260,7 @@ def __init__(
prefix=f"{prefix}.proj")

# Detect attention implementation.
self.attn_backend: _Backend = get_vit_attn_backend()
self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
if self.attn_backend not in {
_Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS
}:
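The only change in molmo.py and qwen2_vl.py is the new support_fa flag: these vision encoders can run on flash-attn, so they opt in, while the BLIP/CLIP/InternViT encoders above pass support_fa=False and accept only SDPA or xFormers. The snippet below is a hypothetical illustration of that contract — pick_vit_backend and its arguments are made up here and are not vLLM's actual get_vit_attn_backend (which lives in vllm/model_executor/models/utils.py and involves more than these flags):

def pick_vit_backend(support_fa: bool, has_flash_attn: bool,
                     has_xformers: bool) -> str:
    # Only callers that opt in with support_fa=True can receive FLASH_ATTN.
    if support_fa and has_flash_attn:
        return "FLASH_ATTN"
    if has_xformers:
        return "XFORMERS"
    return "TORCH_SDPA"


print(pick_vit_backend(support_fa=False, has_flash_attn=True, has_xformers=True))
# -> XFORMERS: a ViT that passes support_fa=False never gets FLASH_ATTN back.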