Skip to content

Commit

Permalink
Add platform pluggable
Browse files Browse the repository at this point in the history
Signed-off-by: wangxiyuan <[email protected]>
  • Loading branch information
wangxiyuan committed Dec 18, 2024
1 parent f04e407 commit 1006d7c
Show file tree
Hide file tree
Showing 8 changed files with 295 additions and 156 deletions.
2 changes: 1 addition & 1 deletion docs/source/design/plugin_system.rst
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ Every plugin has three parts:
What Can Plugins Do?
--------------------

Currently, the primary use case for plugins is to register custom, out-of-the-tree models into vLLM. This is done by calling ``ModelRegistry.register_model`` to register the model. In the future, the plugin system may be extended to support more features, such as swapping in custom implementations for certain classes in vLLM.
Currently, the primary use case for plugins is to register custom, out-of-the-tree models or platforms into vLLM. This is done by calling ``ModelRegistry.register_model`` or ``PlatformRegistry.register_platform`` to register the model or platform, respectively. In the future, the plugin system may be extended to support more features, such as swapping in custom implementations for certain classes in vLLM.

Guidelines for Writing Plugins
------------------------------
Expand Down
9 changes: 9 additions & 0 deletions tests/plugins/vllm_add_dummy_platform/setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from setuptools import setup

# Packaging metadata for the dummy-platform test plugin. Installing this
# distribution exposes one entry point in the 'vllm.general_plugins' group.
setup(name='vllm_add_dummy_platform',
version='0.1',
packages=['vllm_add_dummy_platform'],
entry_points={
# Maps an entry-point name to the `register` callable of the
# `vllm_add_dummy_platform` package.
# NOTE(review): the entry-point name says "model" although this plugin is
# about a platform -- looks like a copy-paste leftover. Confirm nothing
# selects the plugin by this name before renaming it.
'vllm.general_plugins':
["register_dummy_model = vllm_add_dummy_platform:register"]
})
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from vllm import PlatformRegistry


def register():
    """Plugin entry point: register the dummy platform and select it.

    Registers ``DummyPlatform`` under the name ``"my_platform"`` and then
    asks the registry to make that name the current platform.
    """
    platform_name = "my_platform"
    platform_target = "vllm_add_dummy_platform.my_platform:DummyPlatform"

    # Make the platform known to the registry under its name ...
    PlatformRegistry.register_platform(platform_name, platform_target)
    # ... then mark that name as the platform in use.
    PlatformRegistry.set_current_platform(platform_name)
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from vllm.platforms import Platform


class DummyPlatform(Platform):
    """Minimal ``Platform`` subclass used by the plugin tests.

    It carries no state of its own and identifies itself as ``"dummy"``.
    """

    device_name = "dummy"

    def __init__(self):
        # Nothing to set up here; construction is delegated to the base class.
        super().__init__()

    def get_device_name(self) -> str:
        """Return the fixed device name of this platform."""
        # NOTE(review): the base-class signature of this method is not
        # visible here -- confirm this override matches it.
        return "dummy"
2 changes: 2 additions & 0 deletions vllm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
EmbeddingRequestOutput, PoolingOutput,
PoolingRequestOutput, RequestOutput, ScoringOutput,
ScoringRequestOutput)
from vllm.platforms.registry import PlatformRegistry
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams

Expand All @@ -22,6 +23,7 @@
"__version_tuple__",
"LLM",
"ModelRegistry",
"PlatformRegistry",
"PromptType",
"TextPrompt",
"TokensPrompt",
Expand Down
171 changes: 52 additions & 119 deletions vllm/platforms/__init__.py
Original file line number Diff line number Diff line change
@@ -1,123 +1,56 @@
from .interface import _Backend # noqa: F401
from .interface import CpuArchEnum, Platform, PlatformEnum, UnspecifiedPlatform

current_platform: Platform

# NOTE: we don't use `torch.version.cuda` / `torch.version.hip` because
# they only indicate the build configuration, not the runtime environment.
# For example, people can install a cuda build of pytorch but run on tpu.

# TPU probe: the flag is True exactly when `libtpu` can be imported.
try:
    # While it's technically possible to install libtpu on a non-TPU machine,
    # this is a very uncommon scenario. Therefore, we assume that libtpu is
    # installed if and only if the machine has TPUs.
    import libtpu  # noqa: F401
    is_tpu = True
except Exception:
    # Best-effort detection: any import failure means "not a TPU machine".
    is_tpu = False

# CUDA probe: ask NVML for the device count; if NVML cannot be used at all,
# fall back to looking for Jetson (Tegra) marker files.
is_cuda = False

try:
    import pynvml
    pynvml.nvmlInit()
    try:
        # At least one visible NVML device means CUDA is usable.
        is_cuda = pynvml.nvmlDeviceGetCount() > 0
    finally:
        # Always release NVML once it has been initialised.
        pynvml.nvmlShutdown()
except Exception:
    # CUDA is supported on Jetson, but NVML may not be.
    import os

    def cuda_is_jetson() -> bool:
        """Return True if either Tegra marker path is present."""
        if os.path.isfile("/etc/nv_tegra_release"):
            return True
        return os.path.exists("/sys/class/tegra-firmware")

    if cuda_is_jetson():
        is_cuda = True

# ROCm probe: the flag is True when amdsmi reports at least one processor
# handle. Any failure (library missing, init error, ...) leaves it False.
is_rocm = False

try:
    import amdsmi
    amdsmi.amdsmi_init()
    try:
        is_rocm = len(amdsmi.amdsmi_get_processor_handles()) > 0
    finally:
        # Always pair a successful init with a shutdown.
        amdsmi.amdsmi_shut_down()
except Exception:
    # Best-effort detection: keep whatever value the flag has at this point.
    pass

# HPU probe: the flag is True when the `habana_frameworks` package can be
# located. The package is only looked up, not imported.
try:
    from importlib import util
    is_hpu = util.find_spec('habana_frameworks') is not None
except Exception:
    # Best-effort detection: a failed lookup counts as "no HPU".
    is_hpu = False

# XPU probe: requires IPEX and the oneCCL bindings to import, and torch to
# report an available XPU backend.
is_xpu = False

try:
    # Assumption carried over from the original comment: IPEX is installed
    # when the machine has XPUs.
    import intel_extension_for_pytorch  # noqa: F401
    import oneccl_bindings_for_pytorch  # noqa: F401
    import torch
    is_xpu = bool(hasattr(torch, 'xpu') and torch.xpu.is_available())
except Exception:
    # Best-effort detection: keep whatever value the flag has at this point.
    pass

# CPU probe: the flag is True when the installed vllm distribution's version
# string contains "cpu".
try:
    from importlib.metadata import version
    is_cpu = "cpu" in version("vllm")
except Exception:
    # Best-effort detection: e.g. vllm metadata not found -> not a CPU build.
    is_cpu = False

# Neuron probe: the flag is True exactly when `transformers_neuronx` imports.
# Only ImportError is treated as "not present"; other errors propagate.
try:
    import transformers_neuronx  # noqa: F401
    is_neuron = True
except ImportError:
    is_neuron = False

# OpenVINO probe: the flag is True when the installed vllm distribution's
# version string contains "openvino".
try:
    from importlib.metadata import version
    is_openvino = "openvino" in version("vllm")
except Exception:
    # Best-effort detection: e.g. vllm metadata not found -> not OpenVINO.
    is_openvino = False

# Pick the platform implementation from the detection flags computed above.
# The first flag that is set wins, so the order of the branches is the
# priority order. Each platform module is imported only inside its own
# branch, so only the selected backend's module gets loaded.
if is_tpu:
# people might install pytorch built with cuda but run on tpu
# so we need to check tpu first
from .tpu import TpuPlatform
current_platform = TpuPlatform()
elif is_cuda:
from .cuda import CudaPlatform
current_platform = CudaPlatform()
elif is_rocm:
from .rocm import RocmPlatform
current_platform = RocmPlatform()
elif is_hpu:
from .hpu import HpuPlatform
current_platform = HpuPlatform()
elif is_xpu:
from .xpu import XPUPlatform
current_platform = XPUPlatform()
elif is_cpu:
from .cpu import CpuPlatform
current_platform = CpuPlatform()
elif is_neuron:
from .neuron import NeuronPlatform
current_platform = NeuronPlatform()
elif is_openvino:
from .openvino import OpenVinoPlatform
current_platform = OpenVinoPlatform()
else:
# No backend was detected: fall back to the placeholder platform.
current_platform = UnspecifiedPlatform()
from .registry import PlatformRegistry, detect_current_platform

# Module-private holder for the active platform. It starts as the
# placeholder platform and is rebound by `initialize_current_platform()` and
# `update_current_platform()`.
_current_platform: Platform = UnspecifiedPlatform()


def initialize_current_platform():
    """Set the module-level platform at plugin-load time.

    A platform already selected in ``PlatformRegistry`` takes precedence;
    only when none is selected does this fall back to
    ``detect_current_platform()``.
    """
    global _current_platform

    # Nothing selected in the registry: rely on detection.
    if PlatformRegistry.current_platform is None:
        _current_platform = detect_current_platform()
        return

    # NOTE(review): the helper is named `*_cls`; confirm whether it returns a
    # platform class or an instance -- not visible from this module.
    _current_platform = PlatformRegistry.get_current_platform_cls()


def update_current_platform(device_name: str):
    """Manually switch the current platform to ``device_name``.

    The name is first selected in ``PlatformRegistry``; the module-level
    platform is then rebound to whatever the registry resolves it to.
    """
    global _current_platform

    PlatformRegistry.set_current_platform(device_name)
    selected_platform = PlatformRegistry.get_current_platform_cls()
    _current_platform = selected_platform


class CurrentPlatform:
    """Proxy that forwards attribute access to the active platform.

    Many modules import ``current_platform`` when vLLM is first imported.
    Rebinding that name later would not reach those modules, because each
    of them already holds a reference to the old object. They hold this
    proxy instead: on every attribute lookup it re-reads the module-level
    ``_current_platform``, so a later platform change is seen everywhere.
    """

    def __init__(self):
        # Cache of the platform object attribute lookups are forwarded to.
        self.platform = _current_platform

    def _refresh_current_platform(self):
        """Re-bind ``self.platform`` if the module-level platform changed."""
        # Inspect the instance ``__dict__`` instead of reading
        # ``self.platform``: on an instance whose ``__init__`` never ran
        # (``copy`` and ``pickle`` build objects via ``__new__`` and then
        # probe for ``__setstate__``), a plain attribute read would re-enter
        # ``__getattr__`` and recurse until ``RecursionError``.
        state = self.__dict__
        if "platform" not in state or state["platform"] is not _current_platform:
            self.platform = _current_platform

    def __getattr__(self, name):
        """Forward a lookup that failed on the proxy to the active platform.

        Raises:
            AttributeError: if the active platform has no such attribute.
        """
        # After the refresh ``platform`` is guaranteed to be an instance
        # attribute, so ``self.platform`` below cannot re-enter this method.
        self._refresh_current_platform()
        return getattr(self.platform, name)

# The global variable for other modules to use. It is a `CurrentPlatform`
# proxy rather than a concrete platform object, so modules that imported it
# early still see a platform that is set later.
current_platform = CurrentPlatform()

# Names exported by `from vllm.platforms import *`.
__all__ = ['Platform', 'PlatformEnum', 'current_platform', 'CpuArchEnum']
Loading

0 comments on commit 1006d7c

Please sign in to comment.