modify intern vit
Signed-off-by: Isotr0py <[email protected]>
Isotr0py committed Nov 18, 2024
Parent: 301d21c · Commit: 07cb82a
Showing 4 changed files with 23 additions and 11 deletions.
vllm/model_executor/models/blip.py (2 additions, 1 deletion)
@@ -247,7 +247,8 @@ def forward(
         out = F.scaled_dot_product_attention(query_states,
                                              key_states,
                                              value_states,
-                                             dropout_p=0.0)
+                                             dropout_p=self.dropout,
+                                             scale=self.scale)
         out = out.transpose(1, 2)
 
         out = out.view(bsz, tgt_len, -1)
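For reference only (not part of this commit): a minimal standalone sketch of what the new scale argument does. F.scaled_dot_product_attention scales attention scores by 1/sqrt(head_dim) when scale is omitted, so passing scale=head_dim**-0.5 explicitly reproduces the default, while modules whose self.scale differs from that now have their own value applied. The shapes and tensor names below are illustrative assumptions, and the scale keyword requires PyTorch 2.1 or newer.

# Standalone sketch (assumed names/shapes, not vLLM code): an explicit scale
# equal to head_dim**-0.5 matches SDPA's default 1/sqrt(head_dim) scaling.
import torch
import torch.nn.functional as F

bsz, num_heads, seq_len, head_dim = 2, 8, 16, 64
q, k, v = (torch.randn(bsz, num_heads, seq_len, head_dim) for _ in range(3))

out_default = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0)
out_explicit = F.scaled_dot_product_attention(q, k, v,
                                              dropout_p=0.0,
                                              scale=head_dim**-0.5)

assert torch.allclose(out_default, out_explicit, atol=1e-6)

The same two-line change is applied to clip.py and siglip.py below.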
vllm/model_executor/models/clip.py (2 additions, 1 deletion)
@@ -277,7 +277,8 @@ def forward(
         out = F.scaled_dot_product_attention(query_states,
                                              key_states,
                                              value_states,
-                                             dropout_p=0.0)
+                                             dropout_p=self.dropout,
+                                             scale=self.scale)
         out = out.transpose(1, 2)
 
         out = out.view(bsz, tgt_len, -1)
vllm/model_executor/models/intern_vit.py (17 additions, 8 deletions)
@@ -12,6 +12,7 @@
 import torch.nn.functional as F
 from transformers import PretrainedConfig
 
+from vllm.attention.selector import _Backend
 from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
                               split_tensor_along_last_dim,
@@ -24,11 +25,7 @@
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
-try:
-    from xformers import ops as xops
-    USE_XFORMERS_OPS = True
-except ImportError:
-    USE_XFORMERS_OPS = False
+from .utils import get_vit_attn_backend
 
 NORM2FN = {
     'rms_norm': RMSNorm,
@@ -186,6 +183,8 @@ def __init__(
             prefix=f"{prefix}.proj",
         )
 
+        self.attn_backend = get_vit_attn_backend()
+
     def _apply_qk_norm(self, q: torch.Tensor, k: torch.Tensor):
         if self.tp_size > 1:
             q = tensor_model_parallel_all_gather(q.contiguous())
@@ -211,9 +210,19 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         k = k.view(B, N, self.num_heads_per_partition, self.head_dim)
         v = v.view(B, N, self.num_heads_per_partition, self.head_dim)
 
-        x = xops.memory_efficient_attention_forward(q, k, v, scale=self.scale)
-        x = x.view(B, N, -1)
+        if self.attn_backend in (_Backend.XFORMERS, _Backend.FLASH_ATTN):
+            from xformers import ops as xops
+
+            out = xops.memory_efficient_attention_forward(q,
+                                                          k,
+                                                          v,
+                                                          scale=self.scale)
+        elif self.attn_backend == _Backend.TORCH_SDPA:
+            q, k, v = (x.transpose(1, 2) for x in (q, k, v))
+            out = F.scaled_dot_product_attention(q, k, v, scale=self.scale)
+            out = out.transpose(1, 2)
 
+        x = out.view(B, N, -1)
         x, _ = self.proj(x)
         return x

@@ -362,7 +371,7 @@ def _init_attn(
         tp_size = get_tensor_model_parallel_world_size()
         num_heads = config.num_attention_heads
 
-        if USE_XFORMERS_OPS and (num_heads + num_dummy_heads) % tp_size == 0:
+        if (num_heads + num_dummy_heads) % tp_size == 0:
             return InternParallelAttention(config,
                                            quant_config=quant_config,
                                            num_dummy_heads=num_dummy_heads,
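Taken together, the intern_vit.py changes drop the hard dependency on xFormers: the attention backend is chosen once in __init__ via get_vit_attn_backend(), forward() dispatches to either xFormers memory-efficient attention or torch scaled_dot_product_attention, and _init_attn() no longer needs USE_XFORMERS_OPS in its head-divisibility check. A minimal sketch of that dispatch pattern outside of vLLM follows; the use_xformers flag and the vit_attention name are illustrative stand-ins for vLLM's _Backend enum and backend selector, not its actual API.

# Illustrative sketch of the backend dispatch above; the boolean flag stands in
# for vLLM's _Backend enum and get_vit_attn_backend().
import torch
import torch.nn.functional as F


def vit_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                  scale: float, use_xformers: bool) -> torch.Tensor:
    """q, k, v are (B, N, num_heads, head_dim); the same layout is returned."""
    if use_xformers:
        # xFormers consumes (B, N, H, D) directly.
        from xformers import ops as xops
        return xops.memory_efficient_attention_forward(q, k, v, scale=scale)
    # torch SDPA expects (B, H, N, D), so transpose in and back out.
    q, k, v = (t.transpose(1, 2) for t in (q, k, v))
    out = F.scaled_dot_product_attention(q, k, v, scale=scale)
    return out.transpose(1, 2)

Both branches compute the same scaled attention, so the two backends should agree up to numerical tolerance on identical inputs.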
vllm/model_executor/models/siglip.py (2 additions, 1 deletion)
@@ -328,7 +328,8 @@ def forward(
         out = F.scaled_dot_product_attention(query_states,
                                              key_states,
                                              value_states,
-                                             dropout_p=0.0)
+                                             dropout_p=self.dropout,
+                                             scale=self.scale)
         out = out.transpose(1, 2)
 
         out = out.view(batch_size, q_len, -1)
