Skip to content

Commit

Permalink
Add platform pluggable
Browse files Browse the repository at this point in the history
Signed-off-by: wangxiyuan <[email protected]>
  • Loading branch information
wangxiyuan committed Dec 18, 2024
1 parent f04e407 commit 464e184
Show file tree
Hide file tree
Showing 8 changed files with 282 additions and 144 deletions.
2 changes: 1 addition & 1 deletion docs/source/design/plugin_system.rst
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ Every plugin has three parts:
What Can Plugins Do?
--------------------

Currently, the primary use case for plugins is to register custom, out-of-the-tree models into vLLM. This is done by calling ``ModelRegistry.register_model`` to register the model. In the future, the plugin system may be extended to support more features, such as swapping in custom implementations for certain classes in vLLM.
Currently, the primary use case for plugins is to register custom, out-of-the-tree models or platforms into vLLM. This is done by calling ``ModelRegistry.register_model`` or ``PlatformRegistry.register_platform`` to register the model or platform, respectively. In the future, the plugin system may be extended to support more features, such as swapping in custom implementations for certain classes in vLLM.
Guidelines for Writing Plugins
------------------------------
Expand Down
9 changes: 9 additions & 0 deletions tests/plugins/vllm_add_dummy_platform/setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from setuptools import setup

# Entry points published under the 'vllm.general_plugins' group. The single
# entry maps an entry-point name to the `register` callable of this package.
general_plugins = ["register_dummy_model = vllm_add_dummy_platform:register"]

setup(
    name='vllm_add_dummy_platform',
    version='0.1',
    packages=['vllm_add_dummy_platform'],
    entry_points={'vllm.general_plugins': general_plugins},
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from vllm import PlatformRegistry


def register():
    """Register the dummy platform and select it as the current platform."""
    platform_name = "my_platform"
    platform_cls_path = "vllm_add_dummy_platform.my_platform:DummyPlatform"

    # Make the dummy platform known to the registry under its name ...
    PlatformRegistry.register_platform(platform_name, platform_cls_path)
    # ... and then mark it as the platform to use.
    PlatformRegistry.set_current_platform(platform_name)
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from vllm.platforms import Platform


class DummyPlatform(Platform):
    """Minimal platform implementation used by the dummy platform plugin."""

    # Name under which this platform identifies its device.
    device_name = "dummy"

    # No `__init__` override: construction is inherited unchanged from
    # `Platform`.

    def get_device_name(self) -> str:
        """Return the fixed device name of the dummy platform."""
        return "dummy"
2 changes: 2 additions & 0 deletions vllm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
EmbeddingRequestOutput, PoolingOutput,
PoolingRequestOutput, RequestOutput, ScoringOutput,
ScoringRequestOutput)
from vllm.platforms.registry import PlatformRegistry
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams

Expand All @@ -22,6 +23,7 @@
"__version_tuple__",
"LLM",
"ModelRegistry",
"PlatformRegistry",
"PromptType",
"TextPrompt",
"TokensPrompt",
Expand Down
146 changes: 39 additions & 107 deletions vllm/platforms/__init__.py
Original file line number Diff line number Diff line change
@@ -1,123 +1,55 @@
from .interface import _Backend # noqa: F401
from .interface import CpuArchEnum, Platform, PlatformEnum, UnspecifiedPlatform
from .registry import PlatformRegistry, detect_current_platform

current_platform: Platform
# Module-level holder for the active platform. It starts out as an
# UnspecifiedPlatform and is rebound by initialize_current_platform() and
# update_current_platform().
_current_platform: Platform = UnspecifiedPlatform()

# NOTE: we don't use `torch.version.cuda` / `torch.version.hip` because
# they only indicate the build configuration, not the runtime environment.
# For example, people can install a cuda build of pytorch but run on tpu.

is_tpu = False
try:
# While it's technically possible to install libtpu on a non-TPU machine,
# this is a very uncommon scenario. Therefore, we assume that libtpu is
# installed if and only if the machine has TPUs.
import libtpu # noqa: F401
is_tpu = True
except Exception:
pass
def initialize_current_platform():
    """Resolve the active platform and store it in `_current_platform`.

    A platform that has already been selected in `PlatformRegistry` takes
    precedence; automatic detection is only used when none has been selected.
    This function is called when loading the vllm plugin.
    """
    global _current_platform
    if PlatformRegistry.current_platform is None:
        # Nothing was selected explicitly, so fall back to auto-detection.
        _current_platform = detect_current_platform()
        return
    _current_platform = PlatformRegistry.get_current_platform()

is_cuda = False

try:
import pynvml
pynvml.nvmlInit()
try:
if pynvml.nvmlDeviceGetCount() > 0:
is_cuda = True
finally:
pynvml.nvmlShutdown()
except Exception:
# CUDA is supported on Jetson, but NVML may not be.
import os
def update_current_platform(device_name: str):
    """Switch the current platform to the one registered as `device_name`.

    This function is used by users to set the current platform by hand.

    Args:
        device_name: Name under which the platform is registered in
            `PlatformRegistry`.

    Raises:
        ValueError: If no platform is registered under `device_name`.
    """
    global _current_platform
    # `PlatformRegistry.get_current_platform()` takes no arguments, so the
    # platform is selected in the registry first and resolved afterwards.
    PlatformRegistry.set_current_platform(device_name)
    _current_platform = PlatformRegistry.get_current_platform()

Check failure on line 24 in vllm/platforms/__init__.py

View workflow job for this annotation

GitHub Actions / mypy (3.9)

Too many arguments for "get_current_platform" of "_PlatformRegistry" [call-arg]

Check failure on line 24 in vllm/platforms/__init__.py

View workflow job for this annotation

GitHub Actions / mypy (3.10)

Too many arguments for "get_current_platform" of "_PlatformRegistry" [call-arg]

Check failure on line 24 in vllm/platforms/__init__.py

View workflow job for this annotation

GitHub Actions / mypy (3.11)

Too many arguments for "get_current_platform" of "_PlatformRegistry" [call-arg]

Check failure on line 24 in vllm/platforms/__init__.py

View workflow job for this annotation

GitHub Actions / mypy (3.12)

Too many arguments for "get_current_platform" of "_PlatformRegistry" [call-arg]

def cuda_is_jetson() -> bool:
return os.path.isfile("/etc/nv_tegra_release") \
or os.path.exists("/sys/class/tegra-firmware")

if cuda_is_jetson():
is_cuda = True
class CurrentPlatform:
    """A wrapper that provides an interface to the current platform.

    `current_platform` is imported to many modules once vLLM is imported.
    Updating `current_platform` value directly will not work in those modules.
    So it needs the wrapper here to provide a dynamic platform loading
    mechanism: attribute access on the wrapper is forwarded to whatever
    platform the module-level `_current_platform` holds at access time.
    """

    def __init__(self):
        self.platform = _current_platform

    def _refresh_current_platform(self):
        """Point `self.platform` at the module-level `_current_platform`."""
        self.platform = _current_platform

    def __getattr__(self, name):
        """Forward lookups of unknown attributes to the current platform.

        Raises:
            AttributeError: If the current platform has no attribute `name`.
        """
        # `__getattr__` only runs when normal lookup fails. `platform` is the
        # wrapper's own attribute and must never be delegated; raising here
        # also prevents unbounded recursion on instances that were created
        # without running `__init__` (as `copy` and `pickle` do).
        if name == "platform":
            raise AttributeError(name)
        self._refresh_current_platform()
        return getattr(self.platform, name)

try:
# installed IPEX if the machine has XPUs.
import intel_extension_for_pytorch # noqa: F401
import oneccl_bindings_for_pytorch # noqa: F401
import torch
if hasattr(torch, 'xpu') and torch.xpu.is_available():
is_xpu = True
except Exception:
pass

is_cpu = False
try:
from importlib.metadata import version
is_cpu = "cpu" in version("vllm")
except Exception:
pass

is_neuron = False
try:
import transformers_neuronx # noqa: F401
is_neuron = True
except ImportError:
pass

is_openvino = False
try:
from importlib.metadata import version
is_openvino = "openvino" in version("vllm")
except Exception:
pass

if is_tpu:
# people might install pytorch built with cuda but run on tpu
# so we need to check tpu first
from .tpu import TpuPlatform
current_platform = TpuPlatform()
elif is_cuda:
from .cuda import CudaPlatform
current_platform = CudaPlatform()
elif is_rocm:
from .rocm import RocmPlatform
current_platform = RocmPlatform()
elif is_hpu:
from .hpu import HpuPlatform
current_platform = HpuPlatform()
elif is_xpu:
from .xpu import XPUPlatform
current_platform = XPUPlatform()
elif is_cpu:
from .cpu import CpuPlatform
current_platform = CpuPlatform()
elif is_neuron:
from .neuron import NeuronPlatform
current_platform = NeuronPlatform()
elif is_openvino:
from .openvino import OpenVinoPlatform
current_platform = OpenVinoPlatform()
else:
current_platform = UnspecifiedPlatform()
# The global variable for other modules to use. It is a CurrentPlatform
# wrapper, so attribute access on it is forwarded to the platform that is
# active at the time of the access.
current_platform = CurrentPlatform()

__all__ = ['Platform', 'PlatformEnum', 'current_platform', 'CpuArchEnum']
172 changes: 172 additions & 0 deletions vllm/platforms/registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
from dataclasses import dataclass, field
from typing import Dict, Optional

from vllm.platforms import Platform

from .interface import UnspecifiedPlatform

# The list of supported in-tree platforms. Update this list when adding/removing
# platforms.
_VLLM_PLATFORMS = {
"cpu": "vllm.platforms.cpu:CpuPlatform",
"cuda": "vllm.platforms.cuda:CudaPlatform",
"hpu": "vllm.platforms.hpu:HpuPlatform",
"neuron": "vllm.platforms.neuron:NeuronPlatform",
"openvino": "vllm.platforms.openvino:OpenVinoPlatform",
"rocm": "vllm.platforms.rocm:RocmPlatform",
"tpu": "vllm.platforms.tpu:TpuPlatform",
"xpu": "vllm.platforms.xpu:XPUPlatform",
}


@dataclass
class _PlatformRegistry:
# The mapping from device name to platform class string.
platforms: Dict[str, str] = field(default_factory=dict)
# The current platform name.
current_platform: Optional[str] = None

def _load_platform_cls(self, device_name: str) -> Platform:
"""Load a platform object by device name."""
if device_name not in self.platforms:
raise ValueError(
f"Platform {device_name} not registered. "
f"Available platforms: {list(self.platforms.keys())}")
platform_cls_str = self.platforms[device_name]
module_name, cls_name = platform_cls_str.split(":")
module = __import__(module_name, fromlist=[cls_name])
return getattr(module, cls_name)

def register_platform(self, device_name: str, platform: str):
"""Register a platform by device name. This function is called by the
platform plugin."""
if device_name in self.platforms:
raise ValueError(f"Platform {device_name} already registered.")
self.platforms[device_name] = platform

def set_current_platform(self, device_name: str):
"""Set the current platform by device name."""
if device_name not in self.platforms:
raise ValueError(
f"Platform {device_name} not registered. "
f"Available platforms: {list(self.platforms.keys())}")
self.current_platform = device_name

def get_current_platform(self) -> Platform:
"""Get the current platform object."""
if self.current_platform is None:
raise ValueError("No current platform set.")
return self._load_platform_cls(self.current_platform)


# The shared registry instance. It is seeded with a copy of the in-tree
# platform table, so registering further platforms does not modify
# `_VLLM_PLATFORMS` itself.
PlatformRegistry = _PlatformRegistry(platforms=dict(_VLLM_PLATFORMS))


def detect_current_platform() -> Platform:
    """Detect the current platform by checking the installed packages.

    The probes run in a fixed order (TPU, CUDA, ROCm, HPU, XPU, CPU, Neuron,
    OpenVINO) and the first one that succeeds wins. The detected platform is
    also recorded in `PlatformRegistry` under its `device_name`.

    Returns:
        An instance of the detected platform class, or an
        `UnspecifiedPlatform` when no probe succeeds.
    """
    CurrentPlatform: Optional[type[Platform]] = None
    # NOTE: we don't use `torch.version.cuda` / `torch.version.hip` because
    # they only indicate the build configuration, not the runtime environment.
    # For example, people can install a cuda build of pytorch but run on tpu.

    # Load TPU Platform
    # People might install pytorch built with cuda but run on tpu, so TPU is
    # checked first.
    try:
        # While it's technically possible to install libtpu on a non-TPU
        # machine, this is a very uncommon scenario. Therefore, we assume that
        # libtpu is installed if and only if the machine has TPUs.
        import libtpu  # noqa: F401

        from .tpu import TpuPlatform as CurrentPlatform
    except Exception:
        pass

    # Load CUDA Platform
    if CurrentPlatform is None:
        try:
            import pynvml
            pynvml.nvmlInit()
            try:
                if pynvml.nvmlDeviceGetCount() > 0:
                    from .cuda import CudaPlatform as CurrentPlatform
            finally:
                pynvml.nvmlShutdown()
        except Exception:
            # CUDA is supported on Jetson, but NVML may not be.
            import os

            def cuda_is_jetson() -> bool:
                return os.path.isfile("/etc/nv_tegra_release") \
                    or os.path.exists("/sys/class/tegra-firmware")

            if cuda_is_jetson():
                from .cuda import CudaPlatform as CurrentPlatform

    # Load ROCm Platform
    if CurrentPlatform is None:
        try:
            import amdsmi
            amdsmi.amdsmi_init()
            try:
                if len(amdsmi.amdsmi_get_processor_handles()) > 0:
                    from .rocm import RocmPlatform as CurrentPlatform
            finally:
                amdsmi.amdsmi_shut_down()
        except Exception:
            pass

    # Load HPU Platform
    if CurrentPlatform is None:
        try:
            from importlib import util

            # An explicit check instead of `assert`: assert statements are
            # removed under `python -O`, which would make this probe succeed
            # unconditionally.
            if util.find_spec('habana_frameworks') is not None:
                from .hpu import HpuPlatform as CurrentPlatform
        except Exception:
            pass

    # Load XPU Platform
    if CurrentPlatform is None:
        try:
            # IPEX is expected to be installed if the machine has XPUs.
            import intel_extension_for_pytorch  # noqa: F401
            import oneccl_bindings_for_pytorch  # noqa: F401
            import torch
            if hasattr(torch, 'xpu') and torch.xpu.is_available():
                from .xpu import XPUPlatform as CurrentPlatform
        except Exception:
            pass

    # Load CPU Platform
    if CurrentPlatform is None:
        try:
            from importlib.metadata import version

            # Explicit check instead of `assert` (see the HPU probe).
            if "cpu" in version("vllm"):
                from .cpu import CpuPlatform as CurrentPlatform
        except Exception:
            pass

    # Load Neuron Platform
    if CurrentPlatform is None:
        try:
            import transformers_neuronx  # noqa: F401

            from .neuron import NeuronPlatform as CurrentPlatform
        except ImportError:
            pass

    # Load OpenVINO Platform
    if CurrentPlatform is None:
        try:
            from importlib.metadata import version

            # Explicit check instead of `assert` (see the HPU probe).
            if "openvino" in version("vllm"):
                from .openvino import OpenVinoPlatform as CurrentPlatform
        except Exception:
            pass

    if CurrentPlatform is not None:
        # Keep the registry in sync with the detection result.
        PlatformRegistry.set_current_platform(CurrentPlatform.device_name)
        return CurrentPlatform()

    return UnspecifiedPlatform()
Loading

0 comments on commit 464e184

Please sign in to comment.