Skip to content

Commit

Permalink
[Platform][Refactor] Extract func get_default_attn_backend to `Plat…
Browse files Browse the repository at this point in the history
…form` (#10358)

Signed-off-by: Mengqing Cao <[email protected]>
  • Loading branch information
MengqingCao authored Nov 19, 2024
1 parent 7eb719d commit 8c1fb50
Show file tree
Hide file tree
Showing 14 changed files with 99 additions and 69 deletions.
19 changes: 11 additions & 8 deletions tests/kernels/test_attention_selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from tests.kernels.utils import override_backend_env_variable
from vllm.attention.selector import which_attn_to_use
from vllm.platforms import cpu, cuda, openvino, rocm
from vllm.utils import STR_FLASH_ATTN_VAL, STR_INVALID_VAL


Expand All @@ -19,26 +20,28 @@ def test_env(name: str, device: str, monkeypatch):
override_backend_env_variable(monkeypatch, name)

if device == "cpu":
with patch("vllm.attention.selector.current_platform.is_cpu",
return_value=True):
with patch("vllm.attention.selector.current_platform",
cpu.CpuPlatform()):
backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
False)
assert backend.name == "TORCH_SDPA"
elif device == "hip":
with patch("vllm.attention.selector.current_platform.is_rocm",
return_value=True):
with patch("vllm.attention.selector.current_platform",
rocm.RocmPlatform()):
backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
False)
assert backend.name == "ROCM_FLASH"
elif device == "openvino":
with patch("vllm.attention.selector.current_platform.is_openvino",
return_value=True):
with patch("vllm.attention.selector.current_platform",
openvino.OpenVinoPlatform()):
backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
False)
assert backend.name == "OPENVINO"
else:
backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
False)
with patch("vllm.attention.selector.current_platform",
cuda.CudaPlatform()):
backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
False)
assert backend.name == name


Expand Down
56 changes: 6 additions & 50 deletions vllm/attention/selector.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import enum
import os
from contextlib import contextmanager
from functools import lru_cache
Expand All @@ -9,26 +8,12 @@
import vllm.envs as envs
from vllm.attention.backends.abstract import AttentionBackend
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.platforms import _Backend, current_platform
from vllm.utils import STR_BACKEND_ENV_VAR

logger = init_logger(__name__)


class _Backend(enum.Enum):
FLASH_ATTN = enum.auto()
FLASH_ATTN_VLLM_V1 = enum.auto()
XFORMERS = enum.auto()
ROCM_FLASH = enum.auto()
TORCH_SDPA = enum.auto()
OPENVINO = enum.auto()
FLASHINFER = enum.auto()
HPU_ATTN = enum.auto()
PALLAS = enum.auto()
IPEX = enum.auto()
NO_ATTENTION = enum.auto()


def backend_name_to_enum(backend_name: str) -> _Backend:
assert backend_name is not None

Expand Down Expand Up @@ -216,40 +201,11 @@ def which_attn_to_use(head_size: int,
if backend_by_env_var is not None:
selected_backend = backend_name_to_enum(backend_by_env_var)

if current_platform.is_cpu():
if selected_backend != _Backend.TORCH_SDPA:
logger.info("Cannot use %s backend on CPU.", selected_backend)
return _Backend.TORCH_SDPA

if current_platform.is_openvino():
if selected_backend != _Backend.OPENVINO:
logger.info("Cannot use %s backend on OpenVINO.", selected_backend)
return _Backend.OPENVINO

if current_platform.is_xpu():
if selected_backend != _Backend.IPEX:
logger.info("Cannot use %s backend on XPU.", selected_backend)
return _Backend.IPEX

if current_platform.is_tpu():
if selected_backend != _Backend.PALLAS:
logger.info("Cannot use %s backend on TPU.", selected_backend)
return _Backend.PALLAS

if current_platform.is_rocm():
# AMD GPUs.
selected_backend = (_Backend.ROCM_FLASH if selected_backend
== _Backend.FLASH_ATTN else selected_backend)
if selected_backend == _Backend.ROCM_FLASH:
if not current_platform.has_device_capability(90):
# not Instinct series GPUs.
logger.info("flash_attn is not supported on NAVI GPUs.")
else:
logger.info("%s is not supported in AMD GPUs.", selected_backend)
return _Backend.ROCM_FLASH

if current_platform.is_hpu():
return _Backend.HPU_ATTN
# get device-specific default attn_backend
default_backend = current_platform.get_default_attn_backend(
selected_backend)
if default_backend is not None:
return default_backend

if use_v1:
return _Backend.FLASH_ATTN_VLLM_V1
Expand Down
2 changes: 1 addition & 1 deletion vllm/model_executor/models/molmo.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from transformers import PretrainedConfig

from vllm.attention import Attention, AttentionMetadata
from vllm.attention.selector import _Backend
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
Expand All @@ -38,6 +37,7 @@
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.multimodal.utils import cached_get_tokenizer
from vllm.platforms import _Backend
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
SequenceData)
from vllm.transformers_utils.processor import get_processor
Expand Down
2 changes: 1 addition & 1 deletion vllm/model_executor/models/qwen2_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@
make_batched_images, make_batched_videos, smart_resize)

from vllm.attention import AttentionMetadata
from vllm.attention.selector import _Backend
from vllm.config import VllmConfig
from vllm.distributed import get_pp_group, parallel_state
from vllm.distributed import utils as dist_utils
Expand All @@ -65,6 +64,7 @@
from vllm.multimodal.inputs import (MultiModalData, MultiModalDataDict,
MultiModalKwargs)
from vllm.multimodal.utils import cached_get_tokenizer
from vllm.platforms import _Backend
from vllm.sequence import IntermediateTensors, PoolerOutput, SequenceData
from vllm.transformers_utils.config import uses_mrope
from vllm.transformers_utils.processor import cached_get_processor
Expand Down
4 changes: 2 additions & 2 deletions vllm/model_executor/models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@
from transformers import PretrainedConfig

import vllm.envs as envs
from vllm.attention.selector import (_Backend, backend_name_to_enum,
from vllm.attention.selector import (backend_name_to_enum,
get_global_forced_attn_backend)
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.multimodal import MultiModalPlaceholderMap, NestedTensors
from vllm.platforms import current_platform
from vllm.platforms import _Backend, current_platform
from vllm.sequence import IntermediateTensors
from vllm.utils import is_pin_memory_available

Expand Down
1 change: 1 addition & 0 deletions vllm/platforms/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .interface import _Backend # noqa: F401
from .interface import Platform, PlatformEnum, UnspecifiedPlatform

current_platform: Platform
Expand Down
10 changes: 9 additions & 1 deletion vllm/platforms/cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@

from vllm.logger import init_logger

from .interface import Platform, PlatformEnum
from .interface import Platform, PlatformEnum, _Backend

logger = init_logger(__name__)

if TYPE_CHECKING:
from vllm.config import VllmConfig
Expand All @@ -22,6 +24,12 @@ class CpuPlatform(Platform):
def get_device_name(cls, device_id: int = 0) -> str:
return "cpu"

@classmethod
def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
if selected_backend != _Backend.TORCH_SDPA:
logger.info("Cannot use %s backend on CPU.", selected_backend)
return _Backend.TORCH_SDPA

@classmethod
def get_device_total_memory(cls, device_id: int = 0) -> int:
return psutil.virtual_memory().total
Expand Down
6 changes: 5 additions & 1 deletion vllm/platforms/hpu.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
import torch

from .interface import Platform, PlatformEnum
from .interface import Platform, PlatformEnum, _Backend


class HpuPlatform(Platform):
_enum = PlatformEnum.HPU

@classmethod
def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
return _Backend.HPU_ATTN

@staticmethod
def inference_mode():
return torch.no_grad()
19 changes: 19 additions & 0 deletions vllm/platforms/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,20 @@
VllmConfig = None


class _Backend(enum.Enum):
FLASH_ATTN = enum.auto()
FLASH_ATTN_VLLM_V1 = enum.auto()
XFORMERS = enum.auto()
ROCM_FLASH = enum.auto()
TORCH_SDPA = enum.auto()
OPENVINO = enum.auto()
FLASHINFER = enum.auto()
HPU_ATTN = enum.auto()
PALLAS = enum.auto()
IPEX = enum.auto()
NO_ATTENTION = enum.auto()


class PlatformEnum(enum.Enum):
CUDA = enum.auto()
ROCM = enum.auto()
Expand Down Expand Up @@ -71,6 +85,11 @@ def is_cuda_alike(self) -> bool:
"""Stateless version of :func:`torch.cuda.is_available`."""
return self._enum in (PlatformEnum.CUDA, PlatformEnum.ROCM)

@classmethod
def get_default_attn_backend(cls, selected_backend: _Backend):
"""Get the default attention backend of a device."""
return None

@classmethod
def get_device_capability(
cls,
Expand Down
8 changes: 7 additions & 1 deletion vllm/platforms/openvino.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,20 @@
import vllm.envs as envs
from vllm.logger import init_logger

from .interface import Platform, PlatformEnum
from .interface import Platform, PlatformEnum, _Backend

logger = init_logger(__name__)


class OpenVinoPlatform(Platform):
_enum = PlatformEnum.OPENVINO

@classmethod
def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
if selected_backend != _Backend.OPENVINO:
logger.info("Cannot use %s backend on OpenVINO.", selected_backend)
return _Backend.OPENVINO

@classmethod
def get_device_name(self, device_id: int = 0) -> str:
return "openvino"
Expand Down
14 changes: 13 additions & 1 deletion vllm/platforms/rocm.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from vllm.logger import init_logger

from .interface import DeviceCapability, Platform, PlatformEnum
from .interface import DeviceCapability, Platform, PlatformEnum, _Backend

logger = init_logger(__name__)

Expand All @@ -19,6 +19,18 @@
class RocmPlatform(Platform):
_enum = PlatformEnum.ROCM

@classmethod
def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
selected_backend = (_Backend.ROCM_FLASH if selected_backend
== _Backend.FLASH_ATTN else selected_backend)
if selected_backend == _Backend.ROCM_FLASH:
if not cls.has_device_capability(90):
# not Instinct series GPUs.
logger.info("flash_attn is not supported on NAVI GPUs.")
else:
logger.info("%s is not supported in AMD GPUs.", selected_backend)
return _Backend.ROCM_FLASH

@classmethod
@lru_cache(maxsize=8)
def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:
Expand Down
12 changes: 11 additions & 1 deletion vllm/platforms/tpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,27 @@

import torch

from .interface import Platform, PlatformEnum
from vllm.logger import init_logger

from .interface import Platform, PlatformEnum, _Backend

if TYPE_CHECKING:
from vllm.config import VllmConfig
else:
VllmConfig = None

logger = init_logger(__name__)


class TpuPlatform(Platform):
_enum = PlatformEnum.TPU

@classmethod
def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
if selected_backend != _Backend.PALLAS:
logger.info("Cannot use %s backend on TPU.", selected_backend)
return _Backend.PALLAS

@classmethod
def get_device_name(cls, device_id: int = 0) -> str:
raise NotImplementedError
Expand Down
12 changes: 11 additions & 1 deletion vllm/platforms/xpu.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,21 @@
import torch

from .interface import DeviceCapability, Platform, PlatformEnum
from vllm.logger import init_logger

from .interface import DeviceCapability, Platform, PlatformEnum, _Backend

logger = init_logger(__name__)


class XPUPlatform(Platform):
_enum = PlatformEnum.XPU

@classmethod
def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
if selected_backend != _Backend.IPEX:
logger.info("Cannot use %s backend on XPU.", selected_backend)
return _Backend.IPEX

@staticmethod
def get_device_capability(device_id: int = 0) -> DeviceCapability:
major, minor, *_ = torch.xpu.get_device_capability(
Expand Down
3 changes: 2 additions & 1 deletion vllm/worker/enc_dec_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from vllm.attention.backends.abstract import (AttentionBackend,
AttentionMetadata)
from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.attention.selector import (_Backend, get_env_variable_attn_backend,
from vllm.attention.selector import (get_env_variable_attn_backend,
get_global_forced_attn_backend)
from vllm.config import VllmConfig
from vllm.forward_context import set_forward_context
Expand All @@ -18,6 +18,7 @@
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs,
MultiModalRegistry)
from vllm.platforms import _Backend
from vllm.sampling_params import SamplingParams
from vllm.sequence import (IntermediateTensors, PoolerOutput,
SequenceGroupMetadata)
Expand Down

0 comments on commit 8c1fb50

Please sign in to comment.