diff --git a/docs/source/design/plugin_system.rst b/docs/source/design/plugin_system.rst index 5a96cc8b3a464..60552e21cb71e 100644 --- a/docs/source/design/plugin_system.rst +++ b/docs/source/design/plugin_system.rst @@ -49,7 +49,7 @@ Every plugin has three parts: What Can Plugins Do? -------------------- -Currently, the primary use case for plugins is to register custom, out-of-the-tree models into vLLM. This is done by calling ``ModelRegistry.register_model`` to register the model. In the future, the plugin system may be extended to support more features, such as swapping in custom implementations for certain classes in vLLM. +Currently, the primary use case for plugins is to register custom, out-of-the-tree models or platform into vLLM. This is done by calling ``ModelRegistry.register_model`` or ``PlatformRegistry.register_platform`` to register the model or platform. In the future, the plugin system may be extended to support more features, such as swapping in custom implementations for certain classes in vLLM. Guidelines for Writing Plugins ------------------------------ diff --git a/tests/plugins/vllm_add_dummy_platform/setup.py b/tests/plugins/vllm_add_dummy_platform/setup.py new file mode 100644 index 0000000000000..0b73d173040ab --- /dev/null +++ b/tests/plugins/vllm_add_dummy_platform/setup.py @@ -0,0 +1,9 @@ +from setuptools import setup + +setup(name='vllm_add_dummy_platform', + version='0.1', + packages=['vllm_add_dummy_platform'], + entry_points={ + 'vllm.general_plugins': + ["register_dummy_model = vllm_add_dummy_platform:register"] + }) diff --git a/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/__init__.py b/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/__init__.py new file mode 100644 index 0000000000000..dd954b4066ab9 --- /dev/null +++ b/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/__init__.py @@ -0,0 +1,9 @@ +from vllm import PlatformRegistry + + +def register(): + # Register the dummy platform + PlatformRegistry.register_platform( + "my_platform", "vllm_add_dummy_platform.my_platform:DummyPlatform") + # Set the current platform to the dummy platform + PlatformRegistry.set_current_platform("my_platform") diff --git a/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/my_attention.py b/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/my_attention.py new file mode 100644 index 0000000000000..8c0df08fa29b8 --- /dev/null +++ b/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/my_attention.py @@ -0,0 +1,13 @@ +class DummyAttentionImpl: + + def forward(self): + pass + + +class DummyAttentionBackend: + + def __init__(self): + pass + + def get_impl_cls(self): + return DummyAttentionImpl diff --git a/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/my_model_runner.py b/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/my_model_runner.py new file mode 100644 index 0000000000000..1d9060b6b7e68 --- /dev/null +++ b/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/my_model_runner.py @@ -0,0 +1,7 @@ +from my_attention import DummyAttentionBackend + + +class DummyModelRunner: + + def __init__(self): + self.attn_backend = DummyAttentionBackend() diff --git a/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/my_platform.py b/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/my_platform.py new file mode 100644 index 0000000000000..77478fedf22ae --- /dev/null +++ b/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/my_platform.py @@ -0,0 +1,19 @@ +from vllm.config import VllmConfig +from vllm.platforms import Platform + + +class DummyPlatform(Platform): + device_name = "dummy" + + def __init__(self): + super().__init__() + + @classmethod + def get_device_name(cls) -> str: + return "dummy" + + @classmethod + def check_and_update_config(cls, vllm_config: VllmConfig) -> None: + parallel_config = vllm_config.parallel_config + parallel_config.worker_cls = \ + "vllm_add_dummy_platform.my_worker.DummyWorker" diff --git a/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/my_worker.py b/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/my_worker.py new file mode 100644 index 0000000000000..a144df2438b20 --- /dev/null +++ b/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/my_worker.py @@ -0,0 +1,14 @@ +from typing import List + +from my_model_runner import DummyModelRunner + + +class DummyCacheEngine: + pass + + +class DummyWorker: + + def __init__(self): + self.cache_engine = List[DummyCacheEngine] + self.model_runner = DummyModelRunner() diff --git a/vllm/__init__.py b/vllm/__init__.py index 45252b93e3d54..3454dec3bb0f0 100644 --- a/vllm/__init__.py +++ b/vllm/__init__.py @@ -12,6 +12,7 @@ EmbeddingRequestOutput, PoolingOutput, PoolingRequestOutput, RequestOutput, ScoringOutput, ScoringRequestOutput) +from vllm.platforms.registry import PlatformRegistry from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams @@ -22,6 +23,7 @@ "__version_tuple__", "LLM", "ModelRegistry", + "PlatformRegistry", "PromptType", "TextPrompt", "TokensPrompt", diff --git a/vllm/platforms/__init__.py b/vllm/platforms/__init__.py index 419237c252ffd..3b394b9661716 100644 --- a/vllm/platforms/__init__.py +++ b/vllm/platforms/__init__.py @@ -1,123 +1,57 @@ from .interface import _Backend # noqa: F401 from .interface import CpuArchEnum, Platform, PlatformEnum, UnspecifiedPlatform +from .registry import PlatformRegistry, detect_current_platform -current_platform: Platform +_current_platform: Platform = UnspecifiedPlatform() -# NOTE: we don't use `torch.version.cuda` / `torch.version.hip` because -# they only indicate the build configuration, not the runtime environment. -# For example, people can install a cuda build of pytorch but run on tpu. -is_tpu = False -try: - # While it's technically possible to install libtpu on a non-TPU machine, - # this is a very uncommon scenario. Therefore, we assume that libtpu is - # installed if and only if the machine has TPUs. - import libtpu # noqa: F401 - is_tpu = True -except Exception: - pass +def initialize_current_platform(): + """Initialize the current platform. This function is called when loading + the vllm plugin.""" + global _current_platform + # Get the current platform from the registry first. If the current platform + # is not set, try to detect the current platform. + if PlatformRegistry.current_platform is not None: + _current_platform = PlatformRegistry.get_current_platform_cls() + else: + _current_platform = detect_current_platform() -is_cuda = False -try: - import pynvml - pynvml.nvmlInit() - try: - if pynvml.nvmlDeviceGetCount() > 0: - is_cuda = True - finally: - pynvml.nvmlShutdown() -except Exception: - # CUDA is supported on Jetson, but NVML may not be. - import os +def update_current_platform(device_name: str): + """Update the current platform. This function is used by users to set the + current platform by hand.""" + global _current_platform + PlatformRegistry.set_current_platform(device_name) + _current_platform = PlatformRegistry.get_current_platform_cls() - def cuda_is_jetson() -> bool: - return os.path.isfile("/etc/nv_tegra_release") \ - or os.path.exists("/sys/class/tegra-firmware") - if cuda_is_jetson(): - is_cuda = True +class CurrentPlatform: + """A wrapper that provides an interface to the current platform. + + `current_platform` is imported to many modules once vLLM is imported. + Updating `current_platform` value directly will not work in those modules. + So it needs the wrapper here to provide a dynamic platform loading + mechanism. -is_rocm = False + This class can make sure that the `current_platform` is always up-to-date. + """ -try: - import amdsmi - amdsmi.amdsmi_init() - try: - if len(amdsmi.amdsmi_get_processor_handles()) > 0: - is_rocm = True - finally: - amdsmi.amdsmi_shut_down() -except Exception: - pass + def __init__(self): + self.platform = _current_platform -is_hpu = False -try: - from importlib import util - is_hpu = util.find_spec('habana_frameworks') is not None -except Exception: - pass + def _refresh_current_platform(self): + """Refresh the current platform dynamically.""" + global _current_platform + if _current_platform is not self.platform: + self.platform = _current_platform -is_xpu = False + def __getattr__(self, name): + """Go pass to the current platform.""" + self._refresh_current_platform() + return getattr(self.platform, name) -try: - # installed IPEX if the machine has XPUs. - import intel_extension_for_pytorch # noqa: F401 - import oneccl_bindings_for_pytorch # noqa: F401 - import torch - if hasattr(torch, 'xpu') and torch.xpu.is_available(): - is_xpu = True -except Exception: - pass -is_cpu = False -try: - from importlib.metadata import version - is_cpu = "cpu" in version("vllm") -except Exception: - pass - -is_neuron = False -try: - import transformers_neuronx # noqa: F401 - is_neuron = True -except ImportError: - pass - -is_openvino = False -try: - from importlib.metadata import version - is_openvino = "openvino" in version("vllm") -except Exception: - pass - -if is_tpu: - # people might install pytorch built with cuda but run on tpu - # so we need to check tpu first - from .tpu import TpuPlatform - current_platform = TpuPlatform() -elif is_cuda: - from .cuda import CudaPlatform - current_platform = CudaPlatform() -elif is_rocm: - from .rocm import RocmPlatform - current_platform = RocmPlatform() -elif is_hpu: - from .hpu import HpuPlatform - current_platform = HpuPlatform() -elif is_xpu: - from .xpu import XPUPlatform - current_platform = XPUPlatform() -elif is_cpu: - from .cpu import CpuPlatform - current_platform = CpuPlatform() -elif is_neuron: - from .neuron import NeuronPlatform - current_platform = NeuronPlatform() -elif is_openvino: - from .openvino import OpenVinoPlatform - current_platform = OpenVinoPlatform() -else: - current_platform = UnspecifiedPlatform() +# The global variable for other modules to use. +current_platform: CurrentPlatform = CurrentPlatform() __all__ = ['Platform', 'PlatformEnum', 'current_platform', 'CpuArchEnum'] diff --git a/vllm/platforms/registry.py b/vllm/platforms/registry.py new file mode 100644 index 0000000000000..bb17e3edf124e --- /dev/null +++ b/vllm/platforms/registry.py @@ -0,0 +1,172 @@ +from dataclasses import dataclass, field +from typing import Dict, Optional + +from vllm.platforms import Platform + +from .interface import UnspecifiedPlatform + +# The list of supported in-tree platforms. Update this list when adding/removing +# platforms. +_VLLM_PLATFORMS = { + "cpu": "vllm.platforms.cpu:CpuPlatform", + "cuda": "vllm.platforms.cuda:CudaPlatform", + "hpu": "vllm.platforms.hpu:HpuPlatform", + "neuron": "vllm.platforms.neuron:NeuronPlatform", + "openvino": "vllm.platforms.openvino:OpenVinoPlatform", + "rocm": "vllm.platforms.rocm:RocmPlatform", + "tpu": "vllm.platforms.tpu:TpuPlatform", + "xpu": "vllm.platforms.xpu:XPUPlatform", +} + + +@dataclass +class _PlatformRegistry: + # The mapping from device name to platform class string. + platforms: Dict[str, str] = field(default_factory=dict) + # The current platform name. + current_platform: Optional[str] = None + + def _load_platform_cls(self, device_name: str) -> Platform: + """Load a platform object by device name.""" + if device_name not in self.platforms: + raise ValueError( + f"Platform {device_name} not registered. " + f"Available platforms: {list(self.platforms.keys())}") + platform_cls_str = self.platforms[device_name] + module_name, cls_name = platform_cls_str.split(":") + module = __import__(module_name, fromlist=[cls_name]) + return getattr(module, cls_name) + + def register_platform(self, device_name: str, platform: str): + """Register a platform by device name. This function is called by the + platform plugin.""" + if device_name in self.platforms: + raise ValueError(f"Platform {device_name} already registered.") + self.platforms[device_name] = platform + + def set_current_platform(self, device_name: str): + """Set the current platform by device name.""" + if device_name not in self.platforms: + raise ValueError( + f"Platform {device_name} not registered. " + f"Available platforms: {list(self.platforms.keys())}") + self.current_platform = device_name + + def get_current_platform_cls(self) -> Platform: + """Get the current platform object.""" + if self.current_platform is None: + raise ValueError("No current platform set.") + return self._load_platform_cls(self.current_platform) + + +PlatformRegistry = _PlatformRegistry({ + device_name: platform + for device_name, platform in _VLLM_PLATFORMS.items() +}) + + +def detect_current_platform() -> Platform: + """Detect the current platform by checking the installed packages.""" + CurrentPlatform: Optional[type[Platform]] = None + # NOTE: we don't use `torch.version.cuda` / `torch.version.hip` because + # they only indicate the build configuration, not the runtime environment. + # For example, people can install a cuda build of pytorch but run on tpu. + + # Load TPU Platform + try: + # While it's technically possible to install libtpu on a non-TPU + # machine, this is a very uncommon scenario. Therefore, we assume that + # libtpu is installed if and only if the machine has TPUs. + import libtpu # noqa: F401 + + from .tpu import TpuPlatform as CurrentPlatform + except Exception: + pass + + # Load CUDA Platform + if not CurrentPlatform: + try: + import pynvml + pynvml.nvmlInit() + try: + if pynvml.nvmlDeviceGetCount() > 0: + from .cuda import CudaPlatform as CurrentPlatform + finally: + pynvml.nvmlShutdown() + except Exception: + # CUDA is supported on Jetson, but NVML may not be. + import os + + def cuda_is_jetson() -> bool: + return os.path.isfile("/etc/nv_tegra_release") \ + or os.path.exists("/sys/class/tegra-firmware") + + if cuda_is_jetson(): + from .cuda import CudaPlatform as CurrentPlatform + + # Load ROCm Platform + if not CurrentPlatform: + try: + import amdsmi + amdsmi.amdsmi_init() + try: + if len(amdsmi.amdsmi_get_processor_handles()) > 0: + from .rocm import RocmPlatform as CurrentPlatform + finally: + amdsmi.amdsmi_shut_down() + except Exception: + pass + + # Load HPU Platform + if not CurrentPlatform: + try: + from importlib import util + assert util.find_spec('habana_frameworks') is not None + from .hpu import HpuPlatform as CurrentPlatform + except Exception: + pass + + # Load XPU Platform + if not CurrentPlatform: + try: + # installed IPEX if the machine has XPUs. + import intel_extension_for_pytorch # noqa: F401 + import oneccl_bindings_for_pytorch # noqa: F401 + import torch + if hasattr(torch, 'xpu') and torch.xpu.is_available(): + from .xpu import XPUPlatform as CurrentPlatform + except Exception: + pass + + # Load CPU Platform + if not CurrentPlatform: + try: + from importlib.metadata import version + assert "cpu" in version("vllm") + from .cpu import CpuPlatform as CurrentPlatform + except Exception: + pass + + # Load Neuron Platform + if not CurrentPlatform: + try: + import transformers_neuronx # noqa: F401 + + from .neuron import NeuronPlatform as CurrentPlatform + except ImportError: + pass + + # Load OpenVINO Platform + if not CurrentPlatform: + try: + from importlib.metadata import version + assert "openvino" in version("vllm") + from .openvino import OpenVinoPlatform as CurrentPlatform + except Exception: + pass + + if CurrentPlatform: + PlatformRegistry.set_current_platform(CurrentPlatform.device_name) + return CurrentPlatform() + + return UnspecifiedPlatform() diff --git a/vllm/plugins/__init__.py b/vllm/plugins/__init__.py index 17f604ea0e202..0336b69a78376 100644 --- a/vllm/plugins/__init__.py +++ b/vllm/plugins/__init__.py @@ -4,7 +4,7 @@ import torch import vllm.envs as envs -from vllm.platforms import current_platform +from vllm.platforms import current_platform, initialize_current_platform logger = logging.getLogger(__name__) @@ -17,6 +17,44 @@ def load_general_plugins(): processes. They should be designed in a way that they can be loaded multiple times without causing issues. """ + global plugins_loaded + if not plugins_loaded: + import sys + if sys.version_info < (3, 10): + from importlib_metadata import entry_points + else: + from importlib.metadata import entry_points + + allowed_plugins = envs.VLLM_PLUGINS + + discovered_plugins = entry_points(group='vllm.general_plugins') + if len(discovered_plugins) == 0: + logger.debug("No plugins found.") + else: + logger.info("Available plugins:") + for plugin in discovered_plugins: + logger.info("name=%s, value=%s, group=%s", plugin.name, + plugin.value, plugin.group) + if allowed_plugins is None: + logger.info("all available plugins will be loaded.") + logger.info("set environment variable VLLM_PLUGINS to control" + " which plugins to load.") + else: + logger.info("plugins to load: %s", allowed_plugins) + for plugin in discovered_plugins: + if allowed_plugins is None or plugin.name in allowed_plugins: + try: + func = plugin.load() + func() + logger.info("plugin %s loaded.", plugin.name) + except Exception: + logger.exception("Failed to load plugin %s", + plugin.name) + # initialize current platform should be called after all plugins are + # loaded. + initialize_current_platform() + + plugins_loaded = True # all processes created by vllm will load plugins, # and here we can inject some common environment variables @@ -42,38 +80,3 @@ def load_general_plugins(): # requires enabling lazy collectives # see https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Inference_Using_HPU_Graphs.html # noqa: E501 os.environ['PT_HPU_ENABLE_LAZY_COLLECTIVES'] = 'true' - - global plugins_loaded - if plugins_loaded: - return - plugins_loaded = True - import sys - if sys.version_info < (3, 10): - from importlib_metadata import entry_points - else: - from importlib.metadata import entry_points - - allowed_plugins = envs.VLLM_PLUGINS - - discovered_plugins = entry_points(group='vllm.general_plugins') - if len(discovered_plugins) == 0: - logger.debug("No plugins found.") - return - logger.info("Available plugins:") - for plugin in discovered_plugins: - logger.info("name=%s, value=%s, group=%s", plugin.name, plugin.value, - plugin.group) - if allowed_plugins is None: - logger.info("all available plugins will be loaded.") - logger.info("set environment variable VLLM_PLUGINS to control" - " which plugins to load.") - else: - logger.info("plugins to load: %s", allowed_plugins) - for plugin in discovered_plugins: - if allowed_plugins is None or plugin.name in allowed_plugins: - try: - func = plugin.load() - func() - logger.info("plugin %s loaded.", plugin.name) - except Exception: - logger.exception("Failed to load plugin %s", plugin.name)