[Platform] Add platform pluggable framework #11222

Closed · wants to merge 2 commits
24 changes: 19 additions & 5 deletions .buildkite/test-pipeline.yaml
@@ -106,17 +106,32 @@ steps:
source_file_dependencies:
- vllm/
commands:
- pip install -e ./plugins/vllm_add_dummy_model
- pytest -v -s entrypoints/llm --ignore=entrypoints/llm/test_lazy_outlines.py --ignore=entrypoints/llm/test_generate.py --ignore=entrypoints/llm/test_generate_multiple_loras.py --ignore=entrypoints/llm/test_guided_generate.py
- pytest -v -s entrypoints/llm/test_lazy_outlines.py # it needs a clean process
- pytest -v -s entrypoints/llm/test_generate.py # it needs a clean process
- pytest -v -s entrypoints/llm/test_generate_multiple_loras.py # it needs a clean process
- pytest -v -s entrypoints/llm/test_guided_generate.py # it needs a clean process
- pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_oot_registration.py
- pytest -v -s entrypoints/openai/test_oot_registration.py # it needs a clean process
- pytest -v -s entrypoints/openai
- pytest -v -s entrypoints/test_chat_utils.py
- pytest -v -s entrypoints/offline_mode # Needs to avoid interference with other tests

# Plugin tests need their own pipeline step, since they install a new fake platform plugin.
- label: Generic Plugin Test
working_dir: "/vllm-workspace/tests"
fast_check: true
mirror_hardwares: [amd]
source_file_dependencies:
- vllm/
- tests/models
commands:
# Test the model plugin first, since installing the platform plugin below switches to a fake platform.
- pip install -e ./plugins/vllm_add_dummy_model
- VLLM_PLUGINS='' pytest -v -s plugins/test_model_plugin_disabled.py
- pytest -v -s plugins/test_model_plugin.py
# test platform plugin second.
- pip install -e ./plugins/vllm_add_dummy_platform
- pytest -v -s plugins/test_platform_plugin.py

- label: Distributed Tests (4 GPUs) # 10min
working_dir: "/vllm-workspace/tests"
num_gpus: 4
@@ -333,8 +348,6 @@ steps:
- vllm/
- tests/models
commands:
- pip install -e ./plugins/vllm_add_dummy_model
- pytest -v -s models/test_oot_registration.py # it needs a clean process
- pytest -v -s models/test_registry.py
- pytest -v -s models/test_initialization.py

@@ -469,6 +482,7 @@ steps:
- pytest models/encoder_decoder/vision_language/test_broadcast.py -v -s -m 'distributed(num_gpus=2)'
- pytest models/decoder_only/vision_language/test_models.py -v -s -m 'distributed(num_gpus=2)'
- pytest -v -s spec_decode/e2e/test_integration_dist_tp2.py
# This distributed test needs 2 GPUs; move it to the plugin tests once those run on 2 GPUs.
- pip install -e ./plugins/vllm_add_dummy_model
- pytest -v -s distributed/test_distributed_oot.py
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py
2 changes: 1 addition & 1 deletion docs/source/design/plugin_system.md
@@ -43,7 +43,7 @@ Every plugin has three parts:

## What Can Plugins Do?

Currently, the primary use case for plugins is to register custom, out-of-the-tree models into vLLM. This is done by calling `ModelRegistry.register_model` to register the model. In the future, the plugin system may be extended to support more features, such as swapping in custom implementations for certain classes in vLLM.
Currently, the primary use case for plugins is to register custom, out-of-the-tree models or platforms into vLLM. This is done by calling `ModelRegistry.register_model` or `PlatformRegistry.register_platform` to register the model or platform. In the future, the plugin system may be extended to support more features, such as swapping in custom implementations for certain classes in vLLM.
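For example (a minimal sketch; `my_plugin`, the model class, and the platform class paths are placeholders rather than part of this PR), a plugin's entry-point function could register both:

```python
from vllm import ModelRegistry, PlatformRegistry


def register():
    # Register an out-of-tree model under its architecture name.
    ModelRegistry.register_model("MyOPTForCausalLM",
                                 "my_plugin.my_opt:MyOPTForCausalLM")
    # Register an out-of-tree platform and make it the active one.
    PlatformRegistry.register_platform("my_platform",
                                       "my_plugin.my_platform.MyPlatform")
    PlatformRegistry.set_current_platform("my_platform")
```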

## Guidelines for Writing Plugins

3 changes: 1 addition & 2 deletions tests/distributed/test_distributed_oot.py
@@ -1,5 +1,4 @@
from ..entrypoints.openai.test_oot_registration import (
run_and_test_dummy_opt_api_server)
from ..plugins.test_model_plugin import run_and_test_dummy_opt_api_server


def test_distributed_oot(dummy_opt_path: str):
42 changes: 0 additions & 42 deletions tests/entrypoints/openai/test_oot_registration.py

This file was deleted.

Empty file added tests/plugins/__init__.py
@@ -1,19 +1,10 @@
import os

import pytest

from vllm import LLM, SamplingParams
from vllm.assets.image import ImageAsset

from ..utils import fork_new_process_for_each_test


@fork_new_process_for_each_test
def test_plugin(dummy_opt_path):
os.environ["VLLM_PLUGINS"] = ""
with pytest.raises(Exception) as excinfo:
LLM(model=dummy_opt_path, load_format="dummy")
assert "are not supported for now" in str(excinfo.value)
from ..utils import (VLLM_PATH, RemoteOpenAIServer,
fork_new_process_for_each_test)


@fork_new_process_for_each_test
@@ -78,3 +69,45 @@ def test_oot_registration_multimodal(dummy_llava_path):
# make sure only the first token is generated
rest = generated_text.replace(first_token, "")
assert rest == ""


chatml_jinja_path = VLLM_PATH / "examples/template_chatml.jinja"
assert chatml_jinja_path.exists()


def run_and_test_dummy_opt_api_server(model, tp=1):
# the model is registered through the plugin
server_args = [
"--gpu-memory-utilization",
"0.10",
"--dtype",
"float32",
"--chat-template",
str(chatml_jinja_path),
"--load-format",
"dummy",
"-tp",
f"{tp}",
]
with RemoteOpenAIServer(model, server_args) as server:
client = server.get_client()
completion = client.chat.completions.create(
model=model,
messages=[{
"role": "system",
"content": "You are a helpful assistant."
}, {
"role": "user",
"content": "Hello!"
}],
temperature=0,
)
generated_text = completion.choices[0].message.content
assert generated_text is not None
# make sure only the first token is generated
rest = generated_text.replace("<s>", "")
assert rest == ""


def test_oot_registration_for_api_server(dummy_opt_path: str):
run_and_test_dummy_opt_api_server(dummy_opt_path)
12 changes: 12 additions & 0 deletions tests/plugins/test_model_plugin_disabled.py
@@ -0,0 +1,12 @@
import pytest

from vllm import LLM

# The test in this file should be run with env VLLM_PLUGINS='', for example:
# VLLM_PLUGINS='' pytest -v -s test_model_plugin_disabled.py


def test_plugin_disabled(dummy_opt_path):
with pytest.raises(Exception) as excinfo:
LLM(model=dummy_opt_path, load_format="dummy")
assert "are not supported for now" in str(excinfo.value)
10 changes: 10 additions & 0 deletions tests/plugins/test_platform_plugin.py
@@ -0,0 +1,10 @@
from vllm.platforms import PlatformRegistry, current_platform


def test_current_platform_register():
# make sure the platform is registered
assert PlatformRegistry.current_platform == "my_platform"
# make sure the platform is loaded
assert current_platform.device_name == "dummy"
assert current_platform.is_async_output_supported(enforce_eager=True) \
is False
9 changes: 9 additions & 0 deletions tests/plugins/vllm_add_dummy_platform/setup.py
@@ -0,0 +1,9 @@
from setuptools import setup

setup(name='vllm_add_dummy_platform',
version='0.1',
packages=['vllm_add_dummy_platform'],
entry_points={
'vllm.general_plugins':
["register_dummy_model = vllm_add_dummy_platform:register"]
})
@@ -0,0 +1,9 @@
from vllm import PlatformRegistry


def register():
# Register the dummy platform
PlatformRegistry.register_platform(
"my_platform", "vllm_add_dummy_platform.my_platform.DummyPlatform")
# Set the current platform to the dummy platform
PlatformRegistry.set_current_platform("my_platform")
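The `PlatformRegistry` used here is introduced by this PR in `vllm/platforms/registry.py`, which is not shown in this excerpt. Conceptually it only needs to map platform names to import paths and remember which one is current; the sketch below illustrates that idea with a hypothetical `resolve_current_platform_cls` helper rather than the PR's exact API:

```python
from dataclasses import dataclass, field
from importlib import import_module
from typing import Dict, Optional


@dataclass
class SketchPlatformRegistry:
    platforms: Dict[str, str] = field(default_factory=dict)
    current_platform: Optional[str] = None

    def register_platform(self, name: str, path: str) -> None:
        # Map a platform name to a "package.module.ClassName" path.
        self.platforms[name] = path

    def set_current_platform(self, name: str) -> None:
        assert name in self.platforms, f"unknown platform {name!r}"
        self.current_platform = name

    def resolve_current_platform_cls(self):
        # Import lazily so plugins can register before any device code runs.
        module_name, cls_name = self.platforms[
            self.current_platform].rsplit(".", 1)
        return getattr(import_module(module_name), cls_name)
```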
@@ -0,0 +1,13 @@
class DummyAttentionImpl:

def forward(self):
pass


class DummyAttentionBackend:

def __init__(self):
pass

def get_impl_cls(self):
return DummyAttentionImpl
@@ -0,0 +1,7 @@
from vllm_add_dummy_platform.my_attention import DummyAttentionBackend


class DummyModelRunner:

def __init__(self):
self.attn_backend = DummyAttentionBackend()
@@ -0,0 +1,26 @@
from typing import Optional

from vllm.config import VllmConfig
from vllm.platforms import Platform, PlatformEnum


class DummyPlatform(Platform):
_enum = PlatformEnum.UNSPECIFIED
device_name = "dummy"

def __init__(self):
super().__init__()

@classmethod
def get_device_name(cls) -> str:
return "dummy"

@classmethod
def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
parallel_config = vllm_config.parallel_config
parallel_config.worker_cls = \
"vllm_add_dummy_platform.my_worker.DummyWorker"

@classmethod
def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
return False
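For a quick sanity check, the class-level hooks above can be exercised directly, mirroring `tests/plugins/test_platform_plugin.py`:

```python
from vllm_add_dummy_platform.my_platform import DummyPlatform

assert DummyPlatform.get_device_name() == "dummy"
assert DummyPlatform.is_async_output_supported(enforce_eager=True) is False
```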
@@ -0,0 +1,14 @@
from typing import List

from vllm_add_dummy_platform.my_model_runner import DummyModelRunner


class DummyCacheEngine:
pass


class DummyWorker:

def __init__(self):
self.cache_engine: List[DummyCacheEngine] = []
self.model_runner = DummyModelRunner()
7 changes: 7 additions & 0 deletions vllm/__init__.py
@@ -12,16 +12,23 @@
EmbeddingRequestOutput, PoolingOutput,
PoolingRequestOutput, RequestOutput, ScoringOutput,
ScoringRequestOutput)
from vllm.platforms.registry import PlatformRegistry
from vllm.plugins import load_general_plugins
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams

from .version import __version__, __version_tuple__

# Load general plugins as soon as the module is imported to make sure that
# all necessary global variables, such as `current_platform`, are set.
load_general_plugins()

__all__ = [
"__version__",
"__version_tuple__",
"LLM",
"ModelRegistry",
"PlatformRegistry",
"PromptType",
"TextPrompt",
"TokensPrompt",
32 changes: 16 additions & 16 deletions vllm/attention/layer.py
@@ -264,15 +264,6 @@ def unified_attention_fake(
return torch.empty_like(query).contiguous()


direct_register_custom_op(
op_name="unified_attention",
op_func=unified_attention,
mutates_args=["kv_cache"],
fake_impl=unified_attention_fake,
dispatch_key=current_platform.dispatch_key,
)


def unified_attention_with_output(
query: torch.Tensor,
key: torch.Tensor,
@@ -308,10 +299,19 @@ def unified_attention_with_output_fake(
return


direct_register_custom_op(
op_name="unified_attention_with_output",
op_func=unified_attention_with_output,
mutates_args=["kv_cache", "output"],
fake_impl=unified_attention_with_output_fake,
dispatch_key=current_platform.dispatch_key,
)
def register_custom_ops():
"""Register custom ops for attention."""
direct_register_custom_op(
op_name="unified_attention",
op_func=unified_attention,
mutates_args=["kv_cache"],
fake_impl=unified_attention_fake,
dispatch_key=current_platform.dispatch_key,
)
direct_register_custom_op(
op_name="unified_attention_with_output",
op_func=unified_attention_with_output,
mutates_args=["kv_cache", "output"],
fake_impl=unified_attention_with_output_fake,
dispatch_key=current_platform.dispatch_key,
)
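Moving the two `direct_register_custom_op` calls into `register_custom_ops()` defers registration until `current_platform` (and therefore `dispatch_key`) is final, which matters once platform plugins can replace the platform at import time. A possible call site, assuming the engine invokes it after plugin loading (the actual call site is not shown in this diff):

```python
from vllm.attention import layer

# Register the unified attention custom ops only after plugins have had a
# chance to install an out-of-tree platform (and thus its dispatch key).
layer.register_custom_ops()
```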