From 8c1fb507052d385d94ac49a7388fd6db5d0069e7 Mon Sep 17 00:00:00 2001
From: Mengqing Cao
Date: Tue, 19 Nov 2024 11:22:26 +0800
Subject: [PATCH] [Platform][Refactor] Extract func `get_default_attn_backend`
 to `Platform` (#10358)

Signed-off-by: Mengqing Cao
---
 tests/kernels/test_attention_selector.py | 19 ++++----
 vllm/attention/selector.py               | 56 +++--------------------
 vllm/model_executor/models/molmo.py      |  2 +-
 vllm/model_executor/models/qwen2_vl.py   |  2 +-
 vllm/model_executor/models/utils.py      |  4 +-
 vllm/platforms/__init__.py               |  1 +
 vllm/platforms/cpu.py                    | 10 ++++-
 vllm/platforms/hpu.py                    |  6 ++-
 vllm/platforms/interface.py              | 19 ++++++++
 vllm/platforms/openvino.py               |  8 +++-
 vllm/platforms/rocm.py                   | 14 +++++-
 vllm/platforms/tpu.py                    | 12 ++++-
 vllm/platforms/xpu.py                    | 12 ++++-
 vllm/worker/enc_dec_model_runner.py      |  3 +-
 14 files changed, 99 insertions(+), 69 deletions(-)

diff --git a/tests/kernels/test_attention_selector.py b/tests/kernels/test_attention_selector.py
index 169ce040d370c..d37f95d48d5b2 100644
--- a/tests/kernels/test_attention_selector.py
+++ b/tests/kernels/test_attention_selector.py
@@ -5,6 +5,7 @@
 
 from tests.kernels.utils import override_backend_env_variable
 from vllm.attention.selector import which_attn_to_use
+from vllm.platforms import cpu, cuda, openvino, rocm
 from vllm.utils import STR_FLASH_ATTN_VAL, STR_INVALID_VAL
 
 
@@ -19,26 +20,28 @@ def test_env(name: str, device: str, monkeypatch):
     override_backend_env_variable(monkeypatch, name)
 
     if device == "cpu":
-        with patch("vllm.attention.selector.current_platform.is_cpu",
-                   return_value=True):
+        with patch("vllm.attention.selector.current_platform",
+                   cpu.CpuPlatform()):
             backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
                                         False)
         assert backend.name == "TORCH_SDPA"
     elif device == "hip":
-        with patch("vllm.attention.selector.current_platform.is_rocm",
-                   return_value=True):
+        with patch("vllm.attention.selector.current_platform",
+                   rocm.RocmPlatform()):
             backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
                                         False)
         assert backend.name == "ROCM_FLASH"
     elif device == "openvino":
-        with patch("vllm.attention.selector.current_platform.is_openvino",
-                   return_value=True):
+        with patch("vllm.attention.selector.current_platform",
+                   openvino.OpenVinoPlatform()):
             backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
                                         False)
         assert backend.name == "OPENVINO"
     else:
-        backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
-                                    False)
+        with patch("vllm.attention.selector.current_platform",
+                   cuda.CudaPlatform()):
+            backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
+                                        False)
         assert backend.name == name

diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py
index 664707e9dc65d..d263839705690 100644
--- a/vllm/attention/selector.py
+++ b/vllm/attention/selector.py
@@ -1,4 +1,3 @@
-import enum
 import os
 from contextlib import contextmanager
 from functools import lru_cache
@@ -9,26 +8,12 @@
 import vllm.envs as envs
 from vllm.attention.backends.abstract import AttentionBackend
 from vllm.logger import init_logger
-from vllm.platforms import current_platform
+from vllm.platforms import _Backend, current_platform
 from vllm.utils import STR_BACKEND_ENV_VAR
 
 logger = init_logger(__name__)
 
 
-class _Backend(enum.Enum):
-    FLASH_ATTN = enum.auto()
-    FLASH_ATTN_VLLM_V1 = enum.auto()
-    XFORMERS = enum.auto()
-    ROCM_FLASH = enum.auto()
-    TORCH_SDPA = enum.auto()
-    OPENVINO = enum.auto()
-    FLASHINFER = enum.auto()
-    HPU_ATTN = enum.auto()
-    PALLAS = enum.auto()
-    IPEX = enum.auto()
-    NO_ATTENTION = enum.auto()
-
-
 def backend_name_to_enum(backend_name: str) -> _Backend:
     assert backend_name is not None
@@ -216,40 +201,11 @@ def which_attn_to_use(head_size: int,
     if backend_by_env_var is not None:
         selected_backend = backend_name_to_enum(backend_by_env_var)
 
-    if current_platform.is_cpu():
-        if selected_backend != _Backend.TORCH_SDPA:
-            logger.info("Cannot use %s backend on CPU.", selected_backend)
-        return _Backend.TORCH_SDPA
-
-    if current_platform.is_openvino():
-        if selected_backend != _Backend.OPENVINO:
-            logger.info("Cannot use %s backend on OpenVINO.", selected_backend)
-        return _Backend.OPENVINO
-
-    if current_platform.is_xpu():
-        if selected_backend != _Backend.IPEX:
-            logger.info("Cannot use %s backend on XPU.", selected_backend)
-        return _Backend.IPEX
-
-    if current_platform.is_tpu():
-        if selected_backend != _Backend.PALLAS:
-            logger.info("Cannot use %s backend on TPU.", selected_backend)
-        return _Backend.PALLAS
-
-    if current_platform.is_rocm():
-        # AMD GPUs.
-        selected_backend = (_Backend.ROCM_FLASH if selected_backend
-                            == _Backend.FLASH_ATTN else selected_backend)
-        if selected_backend == _Backend.ROCM_FLASH:
-            if not current_platform.has_device_capability(90):
-                # not Instinct series GPUs.
-                logger.info("flash_attn is not supported on NAVI GPUs.")
-        else:
-            logger.info("%s is not supported in AMD GPUs.", selected_backend)
-        return _Backend.ROCM_FLASH
-
-    if current_platform.is_hpu():
-        return _Backend.HPU_ATTN
+    # get device-specific default attn_backend
+    default_backend = current_platform.get_default_attn_backend(
+        selected_backend)
+    if default_backend is not None:
+        return default_backend
 
     if use_v1:
         return _Backend.FLASH_ATTN_VLLM_V1

diff --git a/vllm/model_executor/models/molmo.py b/vllm/model_executor/models/molmo.py
index a7c90a3f5031b..2528f741864b3 100644
--- a/vllm/model_executor/models/molmo.py
+++ b/vllm/model_executor/models/molmo.py
@@ -13,7 +13,6 @@
 from transformers import PretrainedConfig
 
 from vllm.attention import Attention, AttentionMetadata
-from vllm.attention.selector import _Backend
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
@@ -38,6 +37,7 @@
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
 from vllm.multimodal.utils import cached_get_tokenizer
+from vllm.platforms import _Backend
 from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
                            SequenceData)
 from vllm.transformers_utils.processor import get_processor

diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py
index a929b9323b245..0ac81387b1bd8 100644
--- a/vllm/model_executor/models/qwen2_vl.py
+++ b/vllm/model_executor/models/qwen2_vl.py
@@ -39,7 +39,6 @@
                                                 make_batched_images,
                                                 make_batched_videos,
                                                 smart_resize)
 from vllm.attention import AttentionMetadata
-from vllm.attention.selector import _Backend
 from vllm.config import VllmConfig
 from vllm.distributed import get_pp_group, parallel_state
 from vllm.distributed import utils as dist_utils
@@ -65,6 +64,7 @@
 from vllm.multimodal.inputs import (MultiModalData, MultiModalDataDict,
                                     MultiModalKwargs)
 from vllm.multimodal.utils import cached_get_tokenizer
+from vllm.platforms import _Backend
 from vllm.sequence import IntermediateTensors, PoolerOutput, SequenceData
 from vllm.transformers_utils.config import uses_mrope
 from vllm.transformers_utils.processor import cached_get_processor

diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py
index 03226f42ee053..2ab9b19e22068 100644
--- a/vllm/model_executor/models/utils.py
+++ b/vllm/model_executor/models/utils.py
@@ -9,13 +9,13 @@
 from transformers import PretrainedConfig
 
 import vllm.envs as envs
-from vllm.attention.selector import (_Backend, backend_name_to_enum,
+from vllm.attention.selector import (backend_name_to_enum,
                                      get_global_forced_attn_backend)
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.multimodal import MultiModalPlaceholderMap, NestedTensors
-from vllm.platforms import current_platform
+from vllm.platforms import _Backend, current_platform
 from vllm.sequence import IntermediateTensors
 from vllm.utils import is_pin_memory_available

diff --git a/vllm/platforms/__init__.py b/vllm/platforms/__init__.py
index 9e740837381f8..1f68fc2e25df3 100644
--- a/vllm/platforms/__init__.py
+++ b/vllm/platforms/__init__.py
@@ -1,3 +1,4 @@
+from .interface import _Backend  # noqa: F401
 from .interface import Platform, PlatformEnum, UnspecifiedPlatform
 
 current_platform: Platform

diff --git a/vllm/platforms/cpu.py b/vllm/platforms/cpu.py
index 42bee31dfb0e9..f9a34a47959ec 100644
--- a/vllm/platforms/cpu.py
+++ b/vllm/platforms/cpu.py
@@ -5,7 +5,9 @@
 
 from vllm.logger import init_logger
 
-from .interface import Platform, PlatformEnum
+from .interface import Platform, PlatformEnum, _Backend
+
+logger = init_logger(__name__)
 
 if TYPE_CHECKING:
     from vllm.config import VllmConfig
@@ -22,6 +24,12 @@ class CpuPlatform(Platform):
     def get_device_name(cls, device_id: int = 0) -> str:
         return "cpu"
 
+    @classmethod
+    def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
+        if selected_backend != _Backend.TORCH_SDPA:
+            logger.info("Cannot use %s backend on CPU.", selected_backend)
+        return _Backend.TORCH_SDPA
+
     @classmethod
     def get_device_total_memory(cls, device_id: int = 0) -> int:
         return psutil.virtual_memory().total

diff --git a/vllm/platforms/hpu.py b/vllm/platforms/hpu.py
index 170cfff94f90d..1e0888a30ba96 100644
--- a/vllm/platforms/hpu.py
+++ b/vllm/platforms/hpu.py
@@ -1,11 +1,15 @@
 import torch
 
-from .interface import Platform, PlatformEnum
+from .interface import Platform, PlatformEnum, _Backend
 
 
 class HpuPlatform(Platform):
     _enum = PlatformEnum.HPU
 
+    @classmethod
+    def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
+        return _Backend.HPU_ATTN
+
     @staticmethod
     def inference_mode():
         return torch.no_grad()

diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py
index 970c0d1be617e..f4849fa2ccfb0 100644
--- a/vllm/platforms/interface.py
+++ b/vllm/platforms/interface.py
@@ -11,6 +11,20 @@
     VllmConfig = None
 
 
+class _Backend(enum.Enum):
+    FLASH_ATTN = enum.auto()
+    FLASH_ATTN_VLLM_V1 = enum.auto()
+    XFORMERS = enum.auto()
+    ROCM_FLASH = enum.auto()
+    TORCH_SDPA = enum.auto()
+    OPENVINO = enum.auto()
+    FLASHINFER = enum.auto()
+    HPU_ATTN = enum.auto()
+    PALLAS = enum.auto()
+    IPEX = enum.auto()
+    NO_ATTENTION = enum.auto()
+
+
 class PlatformEnum(enum.Enum):
     CUDA = enum.auto()
     ROCM = enum.auto()
@@ -71,6 +85,11 @@ def is_cuda_alike(self) -> bool:
         """Stateless version of :func:`torch.cuda.is_available`."""
         return self._enum in (PlatformEnum.CUDA, PlatformEnum.ROCM)
 
+    @classmethod
+    def get_default_attn_backend(cls, selected_backend: _Backend):
+        """Get the default attention backend of a device."""
+        return None
+
     @classmethod
     def get_device_capability(
         cls,

diff --git a/vllm/platforms/openvino.py b/vllm/platforms/openvino.py
index 31fe3f1fcbfe4..ad69ced5417b3 100644
--- a/vllm/platforms/openvino.py
+++ b/vllm/platforms/openvino.py
@@ -3,7 +3,7 @@
 import vllm.envs as envs
 from vllm.logger import init_logger
 
-from .interface import Platform, PlatformEnum
+from .interface import Platform, PlatformEnum, _Backend
 
 logger = init_logger(__name__)
 
@@ -11,6 +11,12 @@
 class OpenVinoPlatform(Platform):
     _enum = PlatformEnum.OPENVINO
 
+    @classmethod
+    def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
+        if selected_backend != _Backend.OPENVINO:
+            logger.info("Cannot use %s backend on OpenVINO.", selected_backend)
+        return _Backend.OPENVINO
+
     @classmethod
     def get_device_name(self, device_id: int = 0) -> str:
         return "openvino"

diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py
index fd8afc92b0f28..022256996f97b 100644
--- a/vllm/platforms/rocm.py
+++ b/vllm/platforms/rocm.py
@@ -5,7 +5,7 @@
 
 from vllm.logger import init_logger
 
-from .interface import DeviceCapability, Platform, PlatformEnum
+from .interface import DeviceCapability, Platform, PlatformEnum, _Backend
 
 logger = init_logger(__name__)
 
@@ -19,6 +19,18 @@
 class RocmPlatform(Platform):
     _enum = PlatformEnum.ROCM
 
+    @classmethod
+    def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
+        selected_backend = (_Backend.ROCM_FLASH if selected_backend
+                            == _Backend.FLASH_ATTN else selected_backend)
+        if selected_backend == _Backend.ROCM_FLASH:
+            if not cls.has_device_capability(90):
+                # not Instinct series GPUs.
+                logger.info("flash_attn is not supported on NAVI GPUs.")
+        else:
+            logger.info("%s is not supported in AMD GPUs.", selected_backend)
+        return _Backend.ROCM_FLASH
+
     @classmethod
     @lru_cache(maxsize=8)
     def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:

diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py
index 643db835c85ff..9057afb6514e4 100644
--- a/vllm/platforms/tpu.py
+++ b/vllm/platforms/tpu.py
@@ -3,17 +3,27 @@
 
 import torch
 
-from .interface import Platform, PlatformEnum
+from vllm.logger import init_logger
+
+from .interface import Platform, PlatformEnum, _Backend
 
 if TYPE_CHECKING:
     from vllm.config import VllmConfig
 else:
     VllmConfig = None
 
+logger = init_logger(__name__)
+
 
 class TpuPlatform(Platform):
     _enum = PlatformEnum.TPU
 
+    @classmethod
+    def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
+        if selected_backend != _Backend.PALLAS:
+            logger.info("Cannot use %s backend on TPU.", selected_backend)
+        return _Backend.PALLAS
+
     @classmethod
     def get_device_name(cls, device_id: int = 0) -> str:
         raise NotImplementedError

diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py
index 106e8eddf458f..d0b3dca9a4195 100644
--- a/vllm/platforms/xpu.py
+++ b/vllm/platforms/xpu.py
@@ -1,11 +1,21 @@
 import torch
 
-from .interface import DeviceCapability, Platform, PlatformEnum
+from vllm.logger import init_logger
+
+from .interface import DeviceCapability, Platform, PlatformEnum, _Backend
+
+logger = init_logger(__name__)
 
 
 class XPUPlatform(Platform):
     _enum = PlatformEnum.XPU
 
+    @classmethod
+    def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
+        if selected_backend != _Backend.IPEX:
+            logger.info("Cannot use %s backend on XPU.", selected_backend)
+        return _Backend.IPEX
+
     @staticmethod
     def get_device_capability(device_id: int = 0) -> DeviceCapability:
         major, minor, *_ = torch.xpu.get_device_capability(

diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py
index 82824faa6629a..687d2cc79360f 100644
--- a/vllm/worker/enc_dec_model_runner.py
+++ b/vllm/worker/enc_dec_model_runner.py
@@ -8,7 +8,7 @@
 from vllm.attention.backends.abstract import (AttentionBackend,
                                               AttentionMetadata)
 from vllm.attention.backends.utils import PAD_SLOT_ID
-from vllm.attention.selector import (_Backend, get_env_variable_attn_backend,
+from vllm.attention.selector import (get_env_variable_attn_backend,
                                      get_global_forced_attn_backend)
 from vllm.config import VllmConfig
 from vllm.forward_context import set_forward_context
@@ -18,6 +18,7 @@
 from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs,
                              MultiModalRegistry)
+from vllm.platforms import _Backend
 from vllm.sampling_params import SamplingParams
 from vllm.sequence import (IntermediateTensors, PoolerOutput,
                            SequenceGroupMetadata)
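
Note: a minimal, self-contained sketch of how the new hook composes, for readers following the refactor. It is not part of the patch; `MyPlatform` and the trimmed two-member `_Backend` enum are hypothetical stand-ins for the real classes in vllm/platforms/interface.py. A platform that needs a fixed backend overrides `get_default_attn_backend`, and the selector falls through to its generic (CUDA-style) logic whenever the base implementation returns `None`:

import enum
from typing import Optional


class _Backend(enum.Enum):
    # Trimmed stand-in for the full enum defined in the patch above.
    FLASH_ATTN = enum.auto()
    TORCH_SDPA = enum.auto()


class Platform:
    @classmethod
    def get_default_attn_backend(
            cls, selected_backend: _Backend) -> Optional[_Backend]:
        # Base hook: None means "no platform-specific default", so the
        # caller keeps going with its generic selection logic.
        return None


class MyPlatform(Platform):
    # Hypothetical out-of-tree platform that only supports torch SDPA.
    @classmethod
    def get_default_attn_backend(
            cls, selected_backend: _Backend) -> _Backend:
        return _Backend.TORCH_SDPA


def which_attn_to_use(current_platform: Platform,
                      selected_backend: _Backend) -> _Backend:
    # Mirrors the refactored selector: a platform default, when present,
    # short-circuits the rest of the backend-selection logic.
    default_backend = current_platform.get_default_attn_backend(
        selected_backend)
    if default_backend is not None:
        return default_backend
    return selected_backend


assert which_attn_to_use(MyPlatform(), _Backend.FLASH_ATTN) is _Backend.TORCH_SDPA
assert which_attn_to_use(Platform(), _Backend.FLASH_ATTN) is _Backend.FLASH_ATTN

This is also why the test changes above patch `current_platform` with a whole platform object (e.g. `cpu.CpuPlatform()`) instead of patching individual `is_cpu`/`is_rocm` predicates: the backend decision now lives on the platform class itself.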