From 65b83e470549ac16bebd21529624faa1cd9165db Mon Sep 17 00:00:00 2001 From: wangxiyuan Date: Mon, 16 Dec 2024 10:22:39 +0800 Subject: [PATCH] Add platform pluggable --- .../plugins/vllm_add_dummy_platform/setup.py | 9 + .../vllm_add_dummy_platform/__init__.py | 9 + .../vllm_add_dummy_platform/my_platform.py | 9 + vllm/__init__.py | 2 + vllm/platforms/__init__.py | 220 +++++++++--------- vllm/platforms/registry.py | 59 +++++ 6 files changed, 194 insertions(+), 114 deletions(-) create mode 100644 tests/plugins/vllm_add_dummy_platform/setup.py create mode 100644 tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/__init__.py create mode 100644 tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/my_platform.py create mode 100644 vllm/platforms/registry.py diff --git a/tests/plugins/vllm_add_dummy_platform/setup.py b/tests/plugins/vllm_add_dummy_platform/setup.py new file mode 100644 index 0000000000000..0b73d173040ab --- /dev/null +++ b/tests/plugins/vllm_add_dummy_platform/setup.py @@ -0,0 +1,9 @@ +from setuptools import setup + +setup(name='vllm_add_dummy_platform', + version='0.1', + packages=['vllm_add_dummy_platform'], + entry_points={ + 'vllm.general_plugins': + ["register_dummy_model = vllm_add_dummy_platform:register"] + }) diff --git a/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/__init__.py b/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/__init__.py new file mode 100644 index 0000000000000..29a35d142153a --- /dev/null +++ b/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/__init__.py @@ -0,0 +1,9 @@ +from vllm import PlatformRegistry + + +def register(): + # Register the dummy platform + PlatformRegistry.register( + "my_platform", "vllm_add_dummy_platform.my_platform:DummyPlatform") + # Set the current platform to the dummy platform + PlatformRegistry.set_current_platform("my_platform") diff --git a/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/my_platform.py 
b/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/my_platform.py new file mode 100644 index 0000000000000..0eb37481fd416 --- /dev/null +++ b/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/my_platform.py @@ -0,0 +1,9 @@ +from vllm.platforms import Platform + + +class DummyPlatform(Platform): + def __init__(self): + super().__init__() + + def get_device_name(self) -> str: + return "dummy" diff --git a/vllm/__init__.py b/vllm/__init__.py index 45252b93e3d54..3454dec3bb0f0 100644 --- a/vllm/__init__.py +++ b/vllm/__init__.py @@ -12,6 +12,7 @@ EmbeddingRequestOutput, PoolingOutput, PoolingRequestOutput, RequestOutput, ScoringOutput, ScoringRequestOutput) +from vllm.platforms.registry import PlatformRegistry from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams @@ -22,6 +23,7 @@ "__version_tuple__", "LLM", "ModelRegistry", + "PlatformRegistry", "PromptType", "TextPrompt", "TokensPrompt", diff --git a/vllm/platforms/__init__.py b/vllm/platforms/__init__.py index 419237c252ffd..db9ef3e9a3901 100644 --- a/vllm/platforms/__init__.py +++ b/vllm/platforms/__init__.py @@ -1,123 +1,115 @@ from .interface import _Backend # noqa: F401 from .interface import CpuArchEnum, Platform, PlatformEnum, UnspecifiedPlatform +from .registry import PlatformRegistry # noqa: F401 current_platform: Platform -# NOTE: we don't use `torch.version.cuda` / `torch.version.hip` because -# they only indicate the build configuration, not the runtime environment. -# For example, people can install a cuda build of pytorch but run on tpu. - -is_tpu = False -try: - # While it's technically possible to install libtpu on a non-TPU machine, - # this is a very uncommon scenario. Therefore, we assume that libtpu is - # installed if and only if the machine has TPUs. 
- import libtpu # noqa: F401 - is_tpu = True -except Exception: - pass - -is_cuda = False - -try: - import pynvml - pynvml.nvmlInit() - try: - if pynvml.nvmlDeviceGetCount() > 0: - is_cuda = True - finally: - pynvml.nvmlShutdown() -except Exception: - # CUDA is supported on Jetson, but NVML may not be. - import os - - def cuda_is_jetson() -> bool: - return os.path.isfile("/etc/nv_tegra_release") \ - or os.path.exists("/sys/class/tegra-firmware") - - if cuda_is_jetson(): - is_cuda = True - -is_rocm = False +# Get the current platform from the registry first. If the current platform is +# not set, try to detect the current platform. +if PlatformRegistry.current_platform is not None: + current_platform = PlatformRegistry.get_current_platform() +else: + CurrentPlatform = None + # NOTE: we don't use `torch.version.cuda` / `torch.version.hip` because + # they only indicate the build configuration, not the runtime environment. + # For example, people can install a cuda build of pytorch but run on tpu. -try: - import amdsmi - amdsmi.amdsmi_init() + # Load TPU Platform try: - if len(amdsmi.amdsmi_get_processor_handles()) > 0: - is_rocm = True - finally: - amdsmi.amdsmi_shut_down() -except Exception: - pass - -is_hpu = False -try: - from importlib import util - is_hpu = util.find_spec('habana_frameworks') is not None -except Exception: - pass - -is_xpu = False - -try: - # installed IPEX if the machine has XPUs. 
- import intel_extension_for_pytorch # noqa: F401 - import oneccl_bindings_for_pytorch # noqa: F401 - import torch - if hasattr(torch, 'xpu') and torch.xpu.is_available(): - is_xpu = True -except Exception: - pass - -is_cpu = False -try: - from importlib.metadata import version - is_cpu = "cpu" in version("vllm") -except Exception: - pass - -is_neuron = False -try: - import transformers_neuronx # noqa: F401 - is_neuron = True -except ImportError: - pass - -is_openvino = False -try: - from importlib.metadata import version - is_openvino = "openvino" in version("vllm") -except Exception: - pass - -if is_tpu: - # people might install pytorch built with cuda but run on tpu - # so we need to check tpu first - from .tpu import TpuPlatform - current_platform = TpuPlatform() -elif is_cuda: - from .cuda import CudaPlatform - current_platform = CudaPlatform() -elif is_rocm: - from .rocm import RocmPlatform - current_platform = RocmPlatform() -elif is_hpu: - from .hpu import HpuPlatform - current_platform = HpuPlatform() -elif is_xpu: - from .xpu import XPUPlatform - current_platform = XPUPlatform() -elif is_cpu: - from .cpu import CpuPlatform - current_platform = CpuPlatform() -elif is_neuron: - from .neuron import NeuronPlatform - current_platform = NeuronPlatform() -elif is_openvino: - from .openvino import OpenVinoPlatform - current_platform = OpenVinoPlatform() -else: - current_platform = UnspecifiedPlatform() + # While it's technically possible to install libtpu on a non-TPU + # machine, this is a very uncommon scenario. Therefore, we assume that + # libtpu is installed if and only if the machine has TPUs. 
+        import libtpu  # noqa: F401
+        from .tpu import TpuPlatform as CurrentPlatform
+    except Exception:
+        pass
+
+    # Load CUDA Platform
+    if not CurrentPlatform:
+        try:
+            import pynvml
+            pynvml.nvmlInit()
+            try:
+                if pynvml.nvmlDeviceGetCount() > 0:
+                    from .cuda import CudaPlatform as CurrentPlatform
+            finally:
+                pynvml.nvmlShutdown()
+        except Exception:
+            # CUDA is supported on Jetson, but NVML may not be.
+            import os
+
+            def cuda_is_jetson() -> bool:
+                return os.path.isfile("/etc/nv_tegra_release") \
+                    or os.path.exists("/sys/class/tegra-firmware")
+
+            if cuda_is_jetson():
+                from .cuda import CudaPlatform as CurrentPlatform
+
+    # Load ROCm Platform
+    if not CurrentPlatform:
+        try:
+            import amdsmi
+            amdsmi.amdsmi_init()
+            try:
+                if len(amdsmi.amdsmi_get_processor_handles()) > 0:
+                    from .rocm import RocmPlatform as CurrentPlatform
+            finally:
+                amdsmi.amdsmi_shut_down()
+        except Exception:
+            pass
+
+    # Load HPU Platform
+    if not CurrentPlatform:
+        try:
+            from importlib import util
+            assert util.find_spec('habana_frameworks') is not None
+            from .hpu import HpuPlatform as CurrentPlatform
+        except Exception:
+            pass
+
+    # Load XPU Platform
+    if not CurrentPlatform:
+        try:
+            # installed IPEX if the machine has XPUs.
+            import intel_extension_for_pytorch  # noqa: F401
+            import oneccl_bindings_for_pytorch  # noqa: F401
+            import torch
+            if hasattr(torch, 'xpu') and torch.xpu.is_available():
+                from .xpu import XPUPlatform as CurrentPlatform
+        except Exception:
+            pass
+
+    # Load CPU Platform
+    if not CurrentPlatform:
+        try:
+            from importlib.metadata import version
+            assert "cpu" in version("vllm")
+            from .cpu import CpuPlatform as CurrentPlatform
+        except Exception:
+            pass
+
+    # Load Neuron Platform
+    if not CurrentPlatform:
+        try:
+            import transformers_neuronx  # noqa: F401
+            from .neuron import NeuronPlatform as CurrentPlatform
+        except ImportError:
+            pass
+
+    # Load OpenVINO Platform
+    if not CurrentPlatform:
+        try:
+            from importlib.metadata import version
+            assert "openvino" in version("vllm")
+            from .openvino import OpenVinoPlatform as CurrentPlatform
+        except Exception:
+            pass
+
+    if CurrentPlatform:
+        device_name = CurrentPlatform.get_device_name()  # NOTE(review): verify this returns a registry key like "cuda", not a hardware model name
+        PlatformRegistry.set_current_platform(device_name)
+        current_platform = PlatformRegistry.get_current_platform()
+    else:
+        current_platform = UnspecifiedPlatform()
 
 __all__ = ['Platform', 'PlatformEnum', 'current_platform', 'CpuArchEnum']
diff --git a/vllm/platforms/registry.py b/vllm/platforms/registry.py
new file mode 100644
index 0000000000000..58e4683da89b3
--- /dev/null
+++ b/vllm/platforms/registry.py
@@ -0,0 +1,59 @@
+from dataclasses import dataclass, field
+from typing import Dict, Optional, Type
+
+from .interface import Platform
+
+
+_VLLM_PLATFORMS = {
+    "cpu": "vllm.platforms.cpu:CpuPlatform",
+    "cuda": "vllm.platforms.cuda:CudaPlatform",
+    "hpu": "vllm.platforms.hpu:HpuPlatform",
+    "neuron": "vllm.platforms.neuron:NeuronPlatform",
+    "openvino": "vllm.platforms.openvino:OpenVinoPlatform",
+    "rocm": "vllm.platforms.rocm:RocmPlatform",
+    "tpu": "vllm.platforms.tpu:TpuPlatform",
+    "xpu": "vllm.platforms.xpu:XPUPlatform",
+}
+
+
+@dataclass
+class _PlatformRegistry:
+    platforms: Dict[str, str] = field(default_factory=dict)
+    current_platform: Optional[str] = None
+
+    def register(self, device_name: str, platform: str):
+        """Register a platform by device name. This function is called by the
+        platform plugin."""
+        if device_name in self.platforms:
+            raise ValueError(f"Platform {device_name} already registered.")
+        self.platforms[device_name] = platform
+
+    def load_platform_cls(self, device_name: str) -> Type[Platform]:
+        """Load a platform class by device name."""
+        if device_name not in self.platforms:
+            raise ValueError(
+                f"Platform {device_name} not registered. "
+                f"Available platforms: {list(self.platforms.keys())}")
+        platform_cls_str = self.platforms[device_name]
+        module_name, cls_name = platform_cls_str.split(":")
+        module = __import__(module_name, fromlist=[cls_name])
+        return getattr(module, cls_name)
+
+    def set_current_platform(self, device_name: str):
+        """Set the current platform by device name."""
+        if device_name not in self.platforms:
+            raise ValueError(
+                f"Platform {device_name} not registered. "
+                f"Available platforms: {list(self.platforms.keys())}")
+        self.current_platform = device_name
+
+    def get_current_platform(self) -> Platform:
+        """Get an instance of the current platform."""
+        if self.current_platform is None:
+            raise ValueError("No current platform set.")
+        return self.load_platform_cls(self.current_platform)()
+
+PlatformRegistry = _PlatformRegistry({
+    device_name: platform
+    for device_name, platform in _VLLM_PLATFORMS.items()
+})