diff --git a/tests/plugins/vllm_add_dummy_platform/setup.py b/tests/plugins/vllm_add_dummy_platform/setup.py new file mode 100644 index 0000000000000..0b73d173040ab --- /dev/null +++ b/tests/plugins/vllm_add_dummy_platform/setup.py @@ -0,0 +1,9 @@ +from setuptools import setup + +setup(name='vllm_add_dummy_platform', + version='0.1', + packages=['vllm_add_dummy_platform'], + entry_points={ + 'vllm.general_plugins': + ["register_dummy_model = vllm_add_dummy_platform:register"] + }) diff --git a/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/__init__.py b/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/__init__.py new file mode 100644 index 0000000000000..29a35d142153a --- /dev/null +++ b/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/__init__.py @@ -0,0 +1,9 @@ +from vllm import PlatformRegistry + + +def register(): + # Register the dummy platform + PlatformRegistry.register( + "my_platform", "vllm_add_dummy_platform.my_platform:DummyPlatform") + # Set the current platform to the dummy platform + PlatformRegistry.set_current_platform("my_platform") diff --git a/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/my_platform.py b/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/my_platform.py new file mode 100644 index 0000000000000..06249deed2449 --- /dev/null +++ b/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/my_platform.py @@ -0,0 +1,10 @@ +from vllm.platforms import Platform + + +class DummyPlatform(Platform): + + def __init__(self): + super().__init__() + + def get_device_name(self) -> str: + return "dummy" diff --git a/vllm/__init__.py b/vllm/__init__.py index 45252b93e3d54..3454dec3bb0f0 100644 --- a/vllm/__init__.py +++ b/vllm/__init__.py @@ -12,6 +12,7 @@ EmbeddingRequestOutput, PoolingOutput, PoolingRequestOutput, RequestOutput, ScoringOutput, ScoringRequestOutput) +from vllm.platforms.registry import PlatformRegistry from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams @@ -22,6 +23,7 @@ "__version_tuple__", "LLM", "ModelRegistry", + "PlatformRegistry", "PromptType", "TextPrompt", "TokensPrompt", diff --git a/vllm/platforms/__init__.py b/vllm/platforms/__init__.py index 419237c252ffd..be981d49f7682 100644 --- a/vllm/platforms/__init__.py +++ b/vllm/platforms/__init__.py @@ -1,123 +1,20 @@ from .interface import _Backend # noqa: F401 from .interface import CpuArchEnum, Platform, PlatformEnum, UnspecifiedPlatform +from .registry import PlatformRegistry, detect_current_platform -current_platform: Platform +current_platform: Platform = UnspecifiedPlatform() -# NOTE: we don't use `torch.version.cuda` / `torch.version.hip` because -# they only indicate the build configuration, not the runtime environment. -# For example, people can install a cuda build of pytorch but run on tpu. -is_tpu = False -try: - # While it's technically possible to install libtpu on a non-TPU machine, - # this is a very uncommon scenario. Therefore, we assume that libtpu is - # installed if and only if the machine has TPUs. - import libtpu # noqa: F401 - is_tpu = True -except Exception: - pass +def initialize_current_platform(): + """Initialize the current platform. This function is called when loading + the vllm plugin.""" + global current_platform + # Get the current platform from the registry first. If the current platform + # is not set, try to detect the current platform. + if PlatformRegistry.current_platform is not None: + current_platform = PlatformRegistry.get_current_platform() + else: + current_platform = detect_current_platform() -is_cuda = False - -try: - import pynvml - pynvml.nvmlInit() - try: - if pynvml.nvmlDeviceGetCount() > 0: - is_cuda = True - finally: - pynvml.nvmlShutdown() -except Exception: - # CUDA is supported on Jetson, but NVML may not be. - import os - - def cuda_is_jetson() -> bool: - return os.path.isfile("/etc/nv_tegra_release") \ - or os.path.exists("/sys/class/tegra-firmware") - - if cuda_is_jetson(): - is_cuda = True - -is_rocm = False - -try: - import amdsmi - amdsmi.amdsmi_init() - try: - if len(amdsmi.amdsmi_get_processor_handles()) > 0: - is_rocm = True - finally: - amdsmi.amdsmi_shut_down() -except Exception: - pass - -is_hpu = False -try: - from importlib import util - is_hpu = util.find_spec('habana_frameworks') is not None -except Exception: - pass - -is_xpu = False - -try: - # installed IPEX if the machine has XPUs. - import intel_extension_for_pytorch # noqa: F401 - import oneccl_bindings_for_pytorch # noqa: F401 - import torch - if hasattr(torch, 'xpu') and torch.xpu.is_available(): - is_xpu = True -except Exception: - pass - -is_cpu = False -try: - from importlib.metadata import version - is_cpu = "cpu" in version("vllm") -except Exception: - pass - -is_neuron = False -try: - import transformers_neuronx # noqa: F401 - is_neuron = True -except ImportError: - pass - -is_openvino = False -try: - from importlib.metadata import version - is_openvino = "openvino" in version("vllm") -except Exception: - pass - -if is_tpu: - # people might install pytorch built with cuda but run on tpu - # so we need to check tpu first - from .tpu import TpuPlatform - current_platform = TpuPlatform() -elif is_cuda: - from .cuda import CudaPlatform - current_platform = CudaPlatform() -elif is_rocm: - from .rocm import RocmPlatform - current_platform = RocmPlatform() -elif is_hpu: - from .hpu import HpuPlatform - current_platform = HpuPlatform() -elif is_xpu: - from .xpu import XPUPlatform - current_platform = XPUPlatform() -elif is_cpu: - from .cpu import CpuPlatform - current_platform = CpuPlatform() -elif is_neuron: - from .neuron import NeuronPlatform - current_platform = NeuronPlatform() -elif is_openvino: - from .openvino import OpenVinoPlatform - current_platform = OpenVinoPlatform() -else: - current_platform = UnspecifiedPlatform() __all__ = ['Platform', 'PlatformEnum', 'current_platform', 'CpuArchEnum'] diff --git a/vllm/platforms/registry.py b/vllm/platforms/registry.py new file mode 100644 index 0000000000000..1b0aa6a551422 --- /dev/null +++ b/vllm/platforms/registry.py @@ -0,0 +1,169 @@ +from dataclasses import dataclass, field +from typing import Dict, Optional + +from vllm.platforms import Platform + +from .interface import UnspecifiedPlatform + +_VLLM_PLATFORMS = { + "cpu": "vllm.platforms.cpu:CpuPlatform", + "cuda": "vllm.platforms.cuda:CudaPlatform", + "hpu": "vllm.platforms.hpu:HpuPlatform", + "neuron": "vllm.platforms.neuron:NeuronPlatform", + "openvino": "vllm.platforms.openvino:OpenVinoPlatform", + "rocm": "vllm.platforms.rocm:RocmPlatform", + "tpu": "vllm.platforms.tpu:TpuPlatform", + "xpu": "vllm.platforms.xpu:XPUPlatform", +} + + +@dataclass +class _PlatformRegistry: + platforms: Dict[str, str] = field(default_factory=dict) + current_platform: Optional[str] = None + + def register(self, device_name: str, platform: str): + """Register a platform by device name. This function is called by the + platform plugin.""" + if device_name in self.platforms: + raise ValueError(f"Platform {device_name} already registered.") + self.platforms[device_name] = platform + + def load_platform_cls(self, device_name: str) -> Platform: + """Load a platform object by device name.""" + if device_name not in self.platforms: + raise ValueError( + f"Platform {device_name} not registered. " + f"Available platforms: {list(self.platforms.keys())}") + platform_cls_str = self.platforms[device_name] + module_name, cls_name = platform_cls_str.split(":") + module = __import__(module_name, fromlist=[cls_name]) + return getattr(module, cls_name) + + def set_current_platform(self, device_name: str): + """Set the current platform by device name.""" + if device_name not in self.platforms: + raise ValueError( + f"Platform {device_name} not registered. " + f"Available platforms: {list(self.platforms.keys())}") + self.current_platform = device_name + + def get_current_platform(self) -> Platform: + """Get the current platform object.""" + if self.current_platform is None: + raise ValueError("No current platform set.") + return self.load_platform_cls(self.current_platform) + + +PlatformRegistry = _PlatformRegistry({ + device_name: platform + for device_name, platform in _VLLM_PLATFORMS.items() +}) + + +def detect_current_platform() -> Platform: + """Detect the current platform by checking the installed packages.""" + CurrentPlatform: Optional[type[Platform]] = None + # NOTE: we don't use `torch.version.cuda` / `torch.version.hip` because + # they only indicate the build configuration, not the runtime environment. + # For example, people can install a cuda build of pytorch but run on tpu. + + # Load TPU Platform + try: + # While it's technically possible to install libtpu on a non-TPU + # machine, this is a very uncommon scenario. Therefore, we assume that + # libtpu is installed if and only if the machine has TPUs. + import libtpu # noqa: F401 + + from .tpu import TpuPlatform as CurrentPlatform + except Exception: + pass + + # Load CUDA Platform + if not CurrentPlatform: + try: + import pynvml + pynvml.nvmlInit() + try: + if pynvml.nvmlDeviceGetCount() > 0: + from .cuda import CudaPlatform as CurrentPlatform + finally: + pynvml.nvmlShutdown() + except Exception: + # CUDA is supported on Jetson, but NVML may not be. + import os + + def cuda_is_jetson() -> bool: + return os.path.isfile("/etc/nv_tegra_release") \ + or os.path.exists("/sys/class/tegra-firmware") + + if cuda_is_jetson(): + from .cuda import CudaPlatform as CurrentPlatform + + # Load ROCm Platform + if not CurrentPlatform: + try: + import amdsmi + amdsmi.amdsmi_init() + try: + if len(amdsmi.amdsmi_get_processor_handles()) > 0: + from .rocm import RocmPlatform as CurrentPlatform + finally: + amdsmi.amdsmi_shut_down() + except Exception: + pass + + # Load HPU Platform + if not CurrentPlatform: + try: + from importlib import util + assert util.find_spec('habana_frameworks') is not None + from .hpu import HpuPlatform as CurrentPlatform + except Exception: + pass + + # Load XPU Platform + if not CurrentPlatform: + try: + # installed IPEX if the machine has XPUs. + import intel_extension_for_pytorch # noqa: F401 + import oneccl_bindings_for_pytorch # noqa: F401 + import torch + if hasattr(torch, 'xpu') and torch.xpu.is_available(): + from .xpu import XPUPlatform as CurrentPlatform + except Exception: + pass + + # Load CPU Platform + if not CurrentPlatform: + try: + from importlib.metadata import version + assert "cpu" in version("vllm") + from .cpu import CpuPlatform as CurrentPlatform + except Exception: + pass + + # Load Neuron Platform + if not CurrentPlatform: + try: + import transformers_neuronx # noqa: F401 + + from .neuron import NeuronPlatform as CurrentPlatform + except ImportError: + pass + + # Load OpenVINO Platform + if not CurrentPlatform: + try: + from importlib.metadata import version + assert "openvino" in version("vllm") + from .openvino import OpenVinoPlatform as CurrentPlatform + except Exception: + pass + + if CurrentPlatform: + PlatformRegistry.set_current_platform(CurrentPlatform.device_name) + current_platform = CurrentPlatform() + else: + current_platform = UnspecifiedPlatform() + return current_platform diff --git a/vllm/plugins/__init__.py b/vllm/plugins/__init__.py index 17f604ea0e202..a51fc902eca66 100644 --- a/vllm/plugins/__init__.py +++ b/vllm/plugins/__init__.py @@ -4,7 +4,7 @@ import torch import vllm.envs as envs -from vllm.platforms import current_platform +from vllm.platforms import initialize_current_platform logger = logging.getLogger(__name__) @@ -17,10 +17,49 @@ def load_general_plugins(): processes. They should be designed in a way that they can be loaded multiple times without causing issues. """ + global plugins_loaded + if not plugins_loaded: + import sys + if sys.version_info < (3, 10): + from importlib_metadata import entry_points + else: + from importlib.metadata import entry_points + + allowed_plugins = envs.VLLM_PLUGINS + + discovered_plugins = entry_points(group='vllm.general_plugins') + if len(discovered_plugins) == 0: + logger.debug("No plugins found.") + else: + logger.info("Available plugins:") + for plugin in discovered_plugins: + logger.info("name=%s, value=%s, group=%s", plugin.name, + plugin.value, plugin.group) + if allowed_plugins is None: + logger.info("all available plugins will be loaded.") + logger.info("set environment variable VLLM_PLUGINS to control" + " which plugins to load.") + else: + logger.info("plugins to load: %s", allowed_plugins) + for plugin in discovered_plugins: + if allowed_plugins is None or plugin.name in allowed_plugins: + try: + func = plugin.load() + func() + logger.info("plugin %s loaded.", plugin.name) + except Exception: + logger.exception("Failed to load plugin %s", + plugin.name) + # initialize current platform should be called after all plugins are + # loaded. + initialize_current_platform() + + plugins_loaded = True # all processes created by vllm will load plugins, # and here we can inject some common environment variables # for all processes. + from vllm.platforms import current_platform # see https://github.com/vllm-project/vllm/issues/10480 os.environ['TORCHINDUCTOR_COMPILE_THREADS'] = '1' @@ -42,38 +81,3 @@ def load_general_plugins(): # requires enabling lazy collectives # see https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Inference_Using_HPU_Graphs.html # noqa: E501 os.environ['PT_HPU_ENABLE_LAZY_COLLECTIVES'] = 'true' - - global plugins_loaded - if plugins_loaded: - return - plugins_loaded = True - import sys - if sys.version_info < (3, 10): - from importlib_metadata import entry_points - else: - from importlib.metadata import entry_points - - allowed_plugins = envs.VLLM_PLUGINS - - discovered_plugins = entry_points(group='vllm.general_plugins') - if len(discovered_plugins) == 0: - logger.debug("No plugins found.") - return - logger.info("Available plugins:") - for plugin in discovered_plugins: - logger.info("name=%s, value=%s, group=%s", plugin.name, plugin.value, - plugin.group) - if allowed_plugins is None: - logger.info("all available plugins will be loaded.") - logger.info("set environment variable VLLM_PLUGINS to control" - " which plugins to load.") - else: - logger.info("plugins to load: %s", allowed_plugins) - for plugin in discovered_plugins: - if allowed_plugins is None or plugin.name in allowed_plugins: - try: - func = plugin.load() - func() - logger.info("plugin %s loaded.", plugin.name) - except Exception: - logger.exception("Failed to load plugin %s", plugin.name)