diff --git a/vllm/model_executor/models/molmo.py b/vllm/model_executor/models/molmo.py
index b2f0f5ea6953a..7369de79f5083 100644
--- a/vllm/model_executor/models/molmo.py
+++ b/vllm/model_executor/models/molmo.py
@@ -1,4 +1,3 @@
-import logging
 import math
 import re
 from array import array
@@ -14,10 +13,8 @@
 from torch.nn import functional as F
 from transformers import PretrainedConfig
 
-import vllm.envs as envs
 from vllm.attention import Attention, AttentionMetadata
-from vllm.attention.selector import (_Backend, backend_name_to_enum,
-                                     get_global_forced_attn_backend)
+from vllm.attention.selector import _Backend
 from vllm.config import CacheConfig, MultiModalConfig
 from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
@@ -43,12 +40,11 @@
 from vllm.model_executor.models.interfaces import SupportsMultiModal
 from vllm.model_executor.models.utils import make_layers
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalInputs
-from vllm.platforms import current_platform
 from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
                            SequenceData)
 from vllm.transformers_utils.processor import get_processor
 
-log = logging.getLogger(__name__)
+from .utils import get_vit_attn_backend
 
 # TODO: hard-coded for now. Consider making it configurable.
 VIT_LAYERS = [-2, -9]
@@ -190,35 +186,12 @@ def __init__(
         )
 
         # Detect attention implementation.
-        selected_backend: Optional[_Backend] = get_global_forced_attn_backend()
-        if selected_backend is None:
-            backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND
-            if backend_by_env_var is not None:
-                selected_backend = backend_name_to_enum(backend_by_env_var)
-        if selected_backend is None:
-            # For Volta and Turing GPUs, use xformers instead.
-            device_available = current_platform.get_device_capability()[0] >= 8
-            if device_available:
-                from transformers.utils import is_flash_attn_2_available
-                if is_flash_attn_2_available():
-                    self._use_flash_attn = True
-                else:
-                    log.warning(
-                        "Current Molmo implementation has a bug with "
-                        "`vllm-flash-attn` inside vision module, so we use "
-                        "xformers backend instead. You can run `pip install "
-                        "flash-attn to use flash-attention backend.")
-                    self._use_flash_attn = False
-            else:
-                self._use_flash_attn = False
-        else:
-            if selected_backend == _Backend.FLASH_ATTN:
-                self._use_flash_attn = True
-            elif selected_backend == _Backend.XFORMERS:
-                self._use_flash_attn = False
-            else:
-                raise RuntimeError(
-                    f"Molmo does not support {selected_backend} backend now.")
+        self.attn_backend: _Backend = get_vit_attn_backend()
+        if self.attn_backend not in {
+                _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS
+        }:
+            raise RuntimeError(
+                f"Molmo does not support {self.attn_backend} backend now.")
 
     def forward(self,
                 inputs_q: torch.Tensor,
@@ -240,10 +213,15 @@
         xk = xk.view(*kv_shape)
         xv = xv.view(*kv_shape)
 
-        if self._use_flash_attn:
+        if self.attn_backend == _Backend.FLASH_ATTN:
             from flash_attn import flash_attn_func
             output = flash_attn_func(xq, xk, xv, dropout_p=0.0, causal=False)
-        else:
+        elif self.attn_backend == _Backend.TORCH_SDPA:
+            xq, xk, xv = (rearrange(x, "b s h d -> b h s d")
+                          for x in (xq, xk, xv))
+            output = F.scaled_dot_product_attention(xq, xk, xv)
+            output = rearrange(output, "b h s d -> b s h d ")
+        elif self.attn_backend == _Backend.XFORMERS:
             from xformers import ops as xops
             output = xops.memory_efficient_attention_forward(xq, xk, xv, p=0)
 
diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py
index 94c7d65077701..f7d632a83cc33 100644
--- a/vllm/model_executor/models/qwen2_vl.py
+++ b/vllm/model_executor/models/qwen2_vl.py
@@ -39,10 +39,8 @@
 from transformers.models.qwen2_vl.image_processing_qwen2_vl import (
     make_batched_images, make_batched_videos, smart_resize)
 
-import vllm.envs as envs
 from vllm.attention import AttentionMetadata
-from vllm.attention.selector import (_Backend, backend_name_to_enum,
-                                     get_global_forced_attn_backend)
+from vllm.attention.selector import _Backend
 from vllm.config import CacheConfig, MultiModalConfig
 from vllm.distributed import get_pp_group, parallel_state
 from vllm.distributed import utils as dist_utils
@@ -63,14 +61,13 @@
                              MultiModalInputs)
 from vllm.multimodal.base import MultiModalData
 from vllm.multimodal.image import cached_get_image_processor
-from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors, SequenceData
 from vllm.transformers_utils.config import uses_mrope
 from vllm.transformers_utils.processor import get_processor
-from vllm.utils import is_cpu
 
 from .interfaces import SupportsMultiModal, SupportsPP
-from .utils import (PPMissingLayer, is_pp_missing_parameter,
+from .utils import (PPMissingLayer, get_vit_attn_backend,
+                    is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory)
 
 logger = init_logger(__name__)
@@ -215,37 +212,12 @@ def __init__(
                                       quant_config=quant_config)
 
         # Detect attention implementation.
-        selected_backend: Optional[_Backend] = get_global_forced_attn_backend()
-        if selected_backend is None:
-            backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND
-            if backend_by_env_var is not None:
-                selected_backend = backend_name_to_enum(backend_by_env_var)
-        if selected_backend is None:
-            # For Volta and Turing GPUs, use xformers instead.
-            device_available = current_platform.has_device_capability(80)
-            if device_available:
-                from transformers.utils import is_flash_attn_2_available
-
-                if is_flash_attn_2_available():
-                    self._use_flash_attn = True
-                else:
-                    logger.warning(
-                        "Current Qwen2-VL implementation has a bug with "
-                        "`vllm-flash-attn` inside vision module, so we use "
-                        "xformers backend instead. You can run `pip install "
-                        "flash-attn to use flash-attention backend.")
-                    self._use_flash_attn = False
-            else:
-                self._use_flash_attn = False
-        else:
-            if selected_backend == _Backend.FLASH_ATTN:
-                self._use_flash_attn = True
-            elif selected_backend == _Backend.XFORMERS:
-                self._use_flash_attn = False
-            else:
-                raise RuntimeError(
-                    f"Qwen2-VL does not support {selected_backend} backend now."
-                )
+        self.attn_backend: _Backend = get_vit_attn_backend()
+        if self.attn_backend not in {
+                _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS
+        }:
+            raise RuntimeError(
+                f"Qwen2-VL does not support {self.attn_backend} backend now.")
 
     def forward(
         self,
@@ -274,7 +246,7 @@
         q = apply_rotary_pos_emb_vision(q, rotary_pos_emb)
         k = apply_rotary_pos_emb_vision(k, rotary_pos_emb)
 
-        if self._use_flash_attn:
+        if self.attn_backend == _Backend.FLASH_ATTN:
             # from vllm_flash_attn.flash_attn_interface import (
             #   flash_attn_varlen_func)
             from flash_attn import flash_attn_varlen_func
@@ -295,7 +267,7 @@
             context_layer = rearrange(output,
                                       "(b s) ... -> b s ...",
                                       b=batch_size)
-        elif is_cpu():
+        elif self.attn_backend == _Backend.TORCH_SDPA:
             seq_length = q.size(1)
             q, k, v = [rearrange(x, "b s h d -> b h s d") for x in [q, k, v]]
             attention_mask = torch.zeros([1, seq_length, seq_length],
@@ -310,7 +282,7 @@
                                                     attention_mask,
                                                     dropout_p=0.0)
             context_layer = rearrange(output, "b h s d -> b s h d ")
-        else:
+        elif self.attn_backend == _Backend.XFORMERS:
             from xformers import ops as xops
             from xformers.ops.fmha.attn_bias import BlockDiagonalMask
 
diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py
index 8aac9c0eb3a0e..9e2f5476f3aff 100644
--- a/vllm/model_executor/models/utils.py
+++ b/vllm/model_executor/models/utils.py
@@ -8,15 +8,22 @@
 from torch.func import functional_call
 from transformers import PretrainedConfig
 
+import vllm.envs as envs
+from vllm.attention.selector import (_Backend, backend_name_to_enum,
+                                     get_global_forced_attn_backend)
 from vllm.config import (CacheConfig, LoRAConfig, MultiModalConfig,
                          SchedulerConfig)
+from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.model_loader.loader import build_model
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models import ModelRegistry
 from vllm.multimodal.base import NestedTensors
+from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors
-from vllm.utils import is_pin_memory_available
+from vllm.utils import is_cpu, is_pin_memory_available
+
+logger = init_logger(__name__)
 
 WeightsMapping = Mapping[str, Optional[str]]
 """If a key maps to a value of `None`, the corresponding weight is ignored."""
@@ -487,3 +494,29 @@ def __getattr__(self, key: str):
     def __call__(self, *args: Any, **kwargs: Any) -> Any:
         llm = super().__getattr__(self.model_name)
         return llm(*args, **kwargs)
+
+
+def get_vit_attn_backend() -> _Backend:
+    selected_backend: Optional[_Backend] = get_global_forced_attn_backend()
+    if selected_backend is None:
+        backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND
+        if backend_by_env_var is not None:
+            selected_backend = backend_name_to_enum(backend_by_env_var)
+    if selected_backend is None:
+        # For Volta and Turing GPUs, use xformers instead.
+        device_available = current_platform.has_device_capability(80)
+        if device_available:
+            from transformers.utils import is_flash_attn_2_available
+            if is_flash_attn_2_available():
+                selected_backend = _Backend.FLASH_ATTN
+            else:
+                logger.warning(
+                    "Current `vllm-flash-attn` has a bug inside vision module, "
+                    "so we use xformers backend instead. You can run "
+                    "`pip install flash-attn` to use flash-attention backend.")
+                selected_backend = _Backend.XFORMERS
+        elif is_cpu():
+            selected_backend = _Backend.TORCH_SDPA
+        else:
+            selected_backend = _Backend.XFORMERS
+    return selected_backend
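
For reference (not part of the diff): a minimal sketch of how the new get_vit_attn_backend() helper is meant to be consumed, mirroring the backend dispatch this change adds to Molmo's MultiHeadDotProductAttention and Qwen2-VL's vision attention. The standalone vit_attention() wrapper and its tensor-layout comment are illustrative only; in the models above the backend is selected once in __init__ and dispatched on in forward().

import torch
from einops import rearrange
from torch.nn import functional as F

from vllm.attention.selector import _Backend
from vllm.model_executor.models.utils import get_vit_attn_backend


def vit_attention(q: torch.Tensor, k: torch.Tensor,
                  v: torch.Tensor) -> torch.Tensor:
    # Non-causal attention over [batch, seq, heads, head_dim] tensors,
    # dispatching on the backend chosen by get_vit_attn_backend().
    backend = get_vit_attn_backend()
    if backend == _Backend.FLASH_ATTN:
        from flash_attn import flash_attn_func
        return flash_attn_func(q, k, v, dropout_p=0.0, causal=False)
    if backend == _Backend.TORCH_SDPA:
        q, k, v = (rearrange(x, "b s h d -> b h s d") for x in (q, k, v))
        out = F.scaled_dot_product_attention(q, k, v)
        return rearrange(out, "b h s d -> b s h d")
    if backend == _Backend.XFORMERS:
        from xformers import ops as xops
        return xops.memory_efficient_attention_forward(q, k, v, p=0)
    raise RuntimeError(f"Unsupported ViT attention backend: {backend}")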