Skip to content

Commit

Permalink
Add platform pluggable
Browse files Browse the repository at this point in the history
Signed-off-by: wangxiyuan <[email protected]>
  • Loading branch information
wangxiyuan committed Dec 25, 2024
1 parent 3f3e92e commit fc4da42
Show file tree
Hide file tree
Showing 14 changed files with 366 additions and 162 deletions.
11 changes: 11 additions & 0 deletions .buildkite/test-pipeline.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,17 @@ steps:
- pytest -v -s entrypoints/test_chat_utils.py
- pytest -v -s entrypoints/offline_mode # Needs to avoid interference with other tests

# platform plugin test need a single pipeline, since it needs to install a new fake platform plugin
- label: Platform Plugin Test # < 1min
working_dir: "/vllm-workspace/tests"
fast_check: true
mirror_hardwares: [amd]
source_file_dependencies:
- vllm/
commands:
- pip install -e ./plugins/vllm_add_dummy_platform
- pytest -v -s platform/test_platform_plugin.py

- label: Distributed Tests (4 GPUs) # 10min
working_dir: "/vllm-workspace/tests"
num_gpus: 4
Expand Down
2 changes: 1 addition & 1 deletion docs/source/design/plugin_system.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ Every plugin has three parts:

## What Can Plugins Do?

Currently, the primary use case for plugins is to register custom, out-of-the-tree models into vLLM. This is done by calling `ModelRegistry.register_model` to register the model. In the future, the plugin system may be extended to support more features, such as swapping in custom implementations for certain classes in vLLM.
Currently, the primary use case for plugins is to register custom, out-of-the-tree models or platforms into vLLM. This is done by calling `ModelRegistry.register_model` or `PlatformRegistry.register_platform` to register the model or platform. In the future, the plugin system may be extended to support more features, such as swapping in custom implementations for certain classes in vLLM.

## Guidelines for Writing Plugins

Expand Down
10 changes: 10 additions & 0 deletions tests/platform/test_platform_plugin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from vllm.platforms import PlatformRegistry, current_platform


def test_current_platform_register():
# make sure the platform is registered
assert PlatformRegistry.current_platform == "my_platform"
# make sure the platform is loaded
assert current_platform.device_name == "dummy"
assert current_platform.is_async_output_supported(enforce_eager=True) \
is False
9 changes: 9 additions & 0 deletions tests/plugins/vllm_add_dummy_platform/setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from setuptools import setup

setup(name='vllm_add_dummy_platform',
version='0.1',
packages=['vllm_add_dummy_platform'],
entry_points={
'vllm.general_plugins':
["register_dummy_model = vllm_add_dummy_platform:register"]
})
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from vllm import PlatformRegistry


def register():
# Register the dummy platform
PlatformRegistry.register_platform(
"my_platform", "vllm_add_dummy_platform.my_platform.DummyPlatform")
# Set the current platform to the dummy platform
PlatformRegistry.set_current_platform("my_platform")
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
class DummyAttentionImpl:

def forward(self):
pass


class DummyAttentionBackend:

def __init__(self):
pass

def get_impl_cls(self):
return DummyAttentionImpl
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from my_attention import DummyAttentionBackend


class DummyModelRunner:

def __init__(self):
self.attn_backend = DummyAttentionBackend()
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from typing import Optional

from vllm.config import VllmConfig
from vllm.platforms import Platform, PlatformEnum


class DummyPlatform(Platform):
_enum = PlatformEnum.UNSPECIFIED
device_name = "dummy"

def __init__(self):
super().__init__()

@classmethod
def get_device_name(cls) -> str:
return "dummy"

@classmethod
def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
parallel_config = vllm_config.parallel_config
parallel_config.worker_cls = \
"vllm_add_dummy_platform.my_worker.DummyWorker"

@classmethod
def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
return False
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from typing import List

from my_model_runner import DummyModelRunner


class DummyCacheEngine:
pass


class DummyWorker:

def __init__(self):
self.cache_engine = List[DummyCacheEngine]
self.model_runner = DummyModelRunner()
7 changes: 7 additions & 0 deletions vllm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,23 @@
EmbeddingRequestOutput, PoolingOutput,
PoolingRequestOutput, RequestOutput, ScoringOutput,
ScoringRequestOutput)
from vllm.platforms.registry import PlatformRegistry
from vllm.plugins import load_general_plugins
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams

from .version import __version__, __version_tuple__

# Load general plugins first when the module is imported to make sure that all
# necessary global variables are set. Such as the `current_platform`.
load_general_plugins()

__all__ = [
"__version__",
"__version_tuple__",
"LLM",
"ModelRegistry",
"PlatformRegistry",
"PromptType",
"TextPrompt",
"TokensPrompt",
Expand Down
32 changes: 16 additions & 16 deletions vllm/attention/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,15 +264,6 @@ def unified_attention_fake(
return torch.empty_like(query).contiguous()


direct_register_custom_op(
op_name="unified_attention",
op_func=unified_attention,
mutates_args=["kv_cache"],
fake_impl=unified_attention_fake,
dispatch_key=current_platform.dispatch_key,
)


def unified_attention_with_output(
query: torch.Tensor,
key: torch.Tensor,
Expand Down Expand Up @@ -308,10 +299,19 @@ def unified_attention_with_output_fake(
return


direct_register_custom_op(
op_name="unified_attention_with_output",
op_func=unified_attention_with_output,
mutates_args=["kv_cache", "output"],
fake_impl=unified_attention_with_output_fake,
dispatch_key=current_platform.dispatch_key,
)
def register_custom_ops():
"""Register custom ops for attention."""
direct_register_custom_op(
op_name="unified_attention",
op_func=unified_attention,
mutates_args=["kv_cache"],
fake_impl=unified_attention_fake,
dispatch_key=current_platform.dispatch_key,
)
direct_register_custom_op(
op_name="unified_attention_with_output",
op_func=unified_attention_with_output,
mutates_args=["kv_cache", "output"],
fake_impl=unified_attention_with_output_fake,
dispatch_key=current_platform.dispatch_key,
)
142 changes: 33 additions & 109 deletions vllm/platforms/__init__.py
Original file line number Diff line number Diff line change
@@ -1,123 +1,47 @@
from typing import Any

from .interface import _Backend # noqa: F401
from .interface import CpuArchEnum, Platform, PlatformEnum, UnspecifiedPlatform
from .registry import PlatformRegistry, detect_current_platform

current_platform: Platform

# NOTE: we don't use `torch.version.cuda` / `torch.version.hip` because
# they only indicate the build configuration, not the runtime environment.
# For example, people can install a cuda build of pytorch but run on tpu.

is_tpu = False
try:
# While it's technically possible to install libtpu on a non-TPU machine,
# this is a very uncommon scenario. Therefore, we assume that libtpu is
# installed if and only if the machine has TPUs.
import libtpu # noqa: F401
is_tpu = True
except Exception:
pass

is_cuda = False

try:
import pynvml
pynvml.nvmlInit()
try:
if pynvml.nvmlDeviceGetCount() > 0:
is_cuda = True
finally:
pynvml.nvmlShutdown()
except Exception:
# CUDA is supported on Jetson, but NVML may not be.
import os

def cuda_is_jetson() -> bool:
return os.path.isfile("/etc/nv_tegra_release") \
or os.path.exists("/sys/class/tegra-firmware")

if cuda_is_jetson():
is_cuda = True
_current_platform: Platform = UnspecifiedPlatform()

is_rocm = False

try:
import amdsmi
amdsmi.amdsmi_init()
try:
if len(amdsmi.amdsmi_get_processor_handles()) > 0:
is_rocm = True
finally:
amdsmi.amdsmi_shut_down()
except Exception:
pass
def initialize_current_platform():
"""Initialize the current platform. This function is called when loading
the vllm plugin."""
# Get the current platform from the registry first. If the current
# platform is not set, try to detect the current platform.
global _current_platform
if PlatformRegistry.current_platform is not None:
_current_platform = PlatformRegistry.get_current_platform_cls()()
else:
_current_platform = detect_current_platform()

is_hpu = False
try:
from importlib import util
is_hpu = util.find_spec('habana_frameworks') is not None
except Exception:
pass
# Register custom ops for the current platform.
from vllm.attention.layer import register_custom_ops
register_custom_ops()

is_xpu = False

try:
# installed IPEX if the machine has XPUs.
import intel_extension_for_pytorch # noqa: F401
import oneccl_bindings_for_pytorch # noqa: F401
import torch
if hasattr(torch, 'xpu') and torch.xpu.is_available():
is_xpu = True
except Exception:
pass
class CurrentPlatform(Platform):
"""A wrapper that provides an interface to the current platform.
`current_platform` is imported to many modules once vLLM is imported.
Updating `current_platform` value directly will not work in those modules.
So it needs the wrapper here to provide a dynamic platform loading
mechanism.
is_cpu = False
try:
from importlib.metadata import version
is_cpu = "cpu" in version("vllm")
except Exception:
pass
This class can make sure that the `current_platform` is always up-to-date.
"""

is_neuron = False
try:
import transformers_neuronx # noqa: F401
is_neuron = True
except ImportError:
pass
def __getattribute__(self, name: str) -> Any:
"""If the attribute is not found, go pass to the current platform."""
global _current_platform
# Go pass to the current platform.
return _current_platform.__getattribute__(name)

is_openvino = False
try:
from importlib.metadata import version
is_openvino = "openvino" in version("vllm")
except Exception:
pass

if is_tpu:
# people might install pytorch built with cuda but run on tpu
# so we need to check tpu first
from .tpu import TpuPlatform
current_platform = TpuPlatform()
elif is_cuda:
from .cuda import CudaPlatform
current_platform = CudaPlatform()
elif is_rocm:
from .rocm import RocmPlatform
current_platform = RocmPlatform()
elif is_hpu:
from .hpu import HpuPlatform
current_platform = HpuPlatform()
elif is_xpu:
from .xpu import XPUPlatform
current_platform = XPUPlatform()
elif is_cpu:
from .cpu import CpuPlatform
current_platform = CpuPlatform()
elif is_neuron:
from .neuron import NeuronPlatform
current_platform = NeuronPlatform()
elif is_openvino:
from .openvino import OpenVinoPlatform
current_platform = OpenVinoPlatform()
else:
current_platform = UnspecifiedPlatform()
# The global variable for other modules to use.
current_platform = CurrentPlatform()

__all__ = ['Platform', 'PlatformEnum', 'current_platform', 'CpuArchEnum']
Loading

0 comments on commit fc4da42

Please sign in to comment.