Skip to content

Commit

Permalink
Add platform pluggable
Browse files Browse the repository at this point in the history
  • Loading branch information
wangxiyuan committed Dec 16, 2024
1 parent 69ba344 commit 65b83e4
Show file tree
Hide file tree
Showing 6 changed files with 194 additions and 114 deletions.
9 changes: 9 additions & 0 deletions tests/plugins/vllm_add_dummy_platform/setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from setuptools import setup

setup(name='vllm_add_dummy_platform',
version='0.1',
packages=['vllm_add_dummy_platform'],
entry_points={
'vllm.general_plugins':
["register_dummy_model = vllm_add_dummy_platform:register"]
})
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from vllm import PlatformRegistry


def register():
# Register the dummy platform
PlatformRegistry.register(
"my_platform", "vllm_add_dummy_platform.my_platform:DummyPlatform")
# Set the current platform to the dummy platform
PlatformRegistry.set_current_platform("my_platform")
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from vllm.platforms import Platform


class DummyPlatform(Platform):
def __init__(self):
super().__init__()

def get_device_name(self) -> str:
return "dummy"
2 changes: 2 additions & 0 deletions vllm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
EmbeddingRequestOutput, PoolingOutput,
PoolingRequestOutput, RequestOutput, ScoringOutput,
ScoringRequestOutput)
from vllm.platforms.registry import PlatformRegistry
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams

Expand All @@ -22,6 +23,7 @@
"__version_tuple__",
"LLM",
"ModelRegistry",
"PlatformRegistry",
"PromptType",
"TextPrompt",
"TokensPrompt",
Expand Down
220 changes: 106 additions & 114 deletions vllm/platforms/__init__.py
Original file line number Diff line number Diff line change
@@ -1,123 +1,115 @@
from .interface import _Backend # noqa: F401
from .interface import CpuArchEnum, Platform, PlatformEnum, UnspecifiedPlatform
from .registry import PlatformRegistry # noqa: F401

current_platform: Platform

# NOTE: we don't use `torch.version.cuda` / `torch.version.hip` because
# they only indicate the build configuration, not the runtime environment.
# For example, people can install a cuda build of pytorch but run on tpu.

is_tpu = False
try:
# While it's technically possible to install libtpu on a non-TPU machine,
# this is a very uncommon scenario. Therefore, we assume that libtpu is
# installed if and only if the machine has TPUs.
import libtpu # noqa: F401
is_tpu = True
except Exception:
pass

is_cuda = False

try:
import pynvml
pynvml.nvmlInit()
try:
if pynvml.nvmlDeviceGetCount() > 0:
is_cuda = True
finally:
pynvml.nvmlShutdown()
except Exception:
# CUDA is supported on Jetson, but NVML may not be.
import os

def cuda_is_jetson() -> bool:
return os.path.isfile("/etc/nv_tegra_release") \
or os.path.exists("/sys/class/tegra-firmware")

if cuda_is_jetson():
is_cuda = True

is_rocm = False
# Get the current platform from the registry first. If the current platform is
# not set, try to detect the current platform.
if PlatformRegistry.current_platform is not None:
current_platform = PlatformRegistry.get_current_platform()
else:
CurrentPlatform = None
# NOTE: we don't use `torch.version.cuda` / `torch.version.hip` because
# they only indicate the build configuration, not the runtime environment.
# For example, people can install a cuda build of pytorch but run on tpu.

try:
import amdsmi
amdsmi.amdsmi_init()
# Load TPU Platform
try:
if len(amdsmi.amdsmi_get_processor_handles()) > 0:
is_rocm = True
finally:
amdsmi.amdsmi_shut_down()
except Exception:
pass

is_hpu = False
try:
from importlib import util
is_hpu = util.find_spec('habana_frameworks') is not None
except Exception:
pass

is_xpu = False

try:
# installed IPEX if the machine has XPUs.
import intel_extension_for_pytorch # noqa: F401
import oneccl_bindings_for_pytorch # noqa: F401
import torch
if hasattr(torch, 'xpu') and torch.xpu.is_available():
is_xpu = True
except Exception:
pass

is_cpu = False
try:
from importlib.metadata import version
is_cpu = "cpu" in version("vllm")
except Exception:
pass

is_neuron = False
try:
import transformers_neuronx # noqa: F401
is_neuron = True
except ImportError:
pass

is_openvino = False
try:
from importlib.metadata import version
is_openvino = "openvino" in version("vllm")
except Exception:
pass

if is_tpu:
# people might install pytorch built with cuda but run on tpu
# so we need to check tpu first
from .tpu import TpuPlatform
current_platform = TpuPlatform()
elif is_cuda:
from .cuda import CudaPlatform
current_platform = CudaPlatform()
elif is_rocm:
from .rocm import RocmPlatform
current_platform = RocmPlatform()
elif is_hpu:
from .hpu import HpuPlatform
current_platform = HpuPlatform()
elif is_xpu:
from .xpu import XPUPlatform
current_platform = XPUPlatform()
elif is_cpu:
from .cpu import CpuPlatform
current_platform = CpuPlatform()
elif is_neuron:
from .neuron import NeuronPlatform
current_platform = NeuronPlatform()
elif is_openvino:
from .openvino import OpenVinoPlatform
current_platform = OpenVinoPlatform()
else:
current_platform = UnspecifiedPlatform()
# While it's technically possible to install libtpu on a non-TPU
# machine, this is a very uncommon scenario. Therefore, we assume that
# libtpu is installed if and only if the machine has TPUs.
import libtpu # noqa: F401
from .tpu import TpuPlatform as CurrentPlatform

Check failure on line 23 in vllm/platforms/__init__.py

View workflow job for this annotation

GitHub Actions / mypy (3.9)

Incompatible import of "CurrentPlatform" (imported name has type "type[TpuPlatform]", local name has type "object") [assignment]

Check failure on line 23 in vllm/platforms/__init__.py

View workflow job for this annotation

GitHub Actions / mypy (3.10)

Incompatible import of "CurrentPlatform" (imported name has type "type[TpuPlatform]", local name has type "object") [assignment]

Check failure on line 23 in vllm/platforms/__init__.py

View workflow job for this annotation

GitHub Actions / mypy (3.12)

Incompatible import of "CurrentPlatform" (imported name has type "type[TpuPlatform]", local name has type "object") [assignment]

Check failure on line 23 in vllm/platforms/__init__.py

View workflow job for this annotation

GitHub Actions / mypy (3.11)

Incompatible import of "CurrentPlatform" (imported name has type "type[TpuPlatform]", local name has type "object") [assignment]
except Exception:
pass

# Load CUDA Platform
if not CurrentPlatform:
try:
import pynvml
pynvml.nvmlInit()
try:
if pynvml.nvmlDeviceGetCount() > 0:
from .cuda import CudaPlatform as CurrentPlatform

Check failure on line 34 in vllm/platforms/__init__.py

View workflow job for this annotation

GitHub Actions / mypy (3.9)

Incompatible import of "CurrentPlatform" (imported name has type "type[CudaPlatformBase]", local name has type "object") [assignment]

Check failure on line 34 in vllm/platforms/__init__.py

View workflow job for this annotation

GitHub Actions / mypy (3.10)

Incompatible import of "CurrentPlatform" (imported name has type "type[CudaPlatformBase]", local name has type "object") [assignment]

Check failure on line 34 in vllm/platforms/__init__.py

View workflow job for this annotation

GitHub Actions / mypy (3.12)

Incompatible import of "CurrentPlatform" (imported name has type "type[CudaPlatformBase]", local name has type "object") [assignment]

Check failure on line 34 in vllm/platforms/__init__.py

View workflow job for this annotation

GitHub Actions / mypy (3.11)

Incompatible import of "CurrentPlatform" (imported name has type "type[CudaPlatformBase]", local name has type "object") [assignment]
finally:
pynvml.nvmlShutdown()
except Exception:
# CUDA is supported on Jetson, but NVML may not be.
import os

def cuda_is_jetson() -> bool:
return os.path.isfile("/etc/nv_tegra_release") \
or os.path.exists("/sys/class/tegra-firmware")

if cuda_is_jetson():
from .cuda import CudaPlatform as CurrentPlatform

Check failure on line 46 in vllm/platforms/__init__.py

View workflow job for this annotation

GitHub Actions / mypy (3.9)

Incompatible import of "CurrentPlatform" (imported name has type "type[CudaPlatformBase]", local name has type "object") [assignment]

Check failure on line 46 in vllm/platforms/__init__.py

View workflow job for this annotation

GitHub Actions / mypy (3.10)

Incompatible import of "CurrentPlatform" (imported name has type "type[CudaPlatformBase]", local name has type "object") [assignment]

Check failure on line 46 in vllm/platforms/__init__.py

View workflow job for this annotation

GitHub Actions / mypy (3.12)

Incompatible import of "CurrentPlatform" (imported name has type "type[CudaPlatformBase]", local name has type "object") [assignment]

Check failure on line 46 in vllm/platforms/__init__.py

View workflow job for this annotation

GitHub Actions / mypy (3.11)

Incompatible import of "CurrentPlatform" (imported name has type "type[CudaPlatformBase]", local name has type "object") [assignment]

# Load ROCm Platform
if not CurrentPlatform:
try:
import amdsmi
amdsmi.amdsmi_init()
try:
if len(amdsmi.amdsmi_get_processor_handles()) > 0:
from .rocm import RocmPlatform as CurrentPlatform

Check failure on line 55 in vllm/platforms/__init__.py

View workflow job for this annotation

GitHub Actions / mypy (3.9)

Incompatible import of "CurrentPlatform" (imported name has type "type[RocmPlatform]", local name has type "object") [assignment]

Check failure on line 55 in vllm/platforms/__init__.py

View workflow job for this annotation

GitHub Actions / mypy (3.10)

Incompatible import of "CurrentPlatform" (imported name has type "type[RocmPlatform]", local name has type "object") [assignment]

Check failure on line 55 in vllm/platforms/__init__.py

View workflow job for this annotation

GitHub Actions / mypy (3.12)

Incompatible import of "CurrentPlatform" (imported name has type "type[RocmPlatform]", local name has type "object") [assignment]

Check failure on line 55 in vllm/platforms/__init__.py

View workflow job for this annotation

GitHub Actions / mypy (3.11)

Incompatible import of "CurrentPlatform" (imported name has type "type[RocmPlatform]", local name has type "object") [assignment]
finally:
amdsmi.amdsmi_shut_down()
except Exception:
pass

# Load HPU Platform
if not CurrentPlatform:
try:
from importlib import util
assert util.find_spec('habana_frameworks') is not None
from.hpu import HpuPlatform as CurrentPlatform

Check failure on line 66 in vllm/platforms/__init__.py

View workflow job for this annotation

GitHub Actions / mypy (3.9)

Incompatible import of "CurrentPlatform" (imported name has type "type[HpuPlatform]", local name has type "object") [assignment]

Check failure on line 66 in vllm/platforms/__init__.py

View workflow job for this annotation

GitHub Actions / mypy (3.10)

Incompatible import of "CurrentPlatform" (imported name has type "type[HpuPlatform]", local name has type "object") [assignment]

Check failure on line 66 in vllm/platforms/__init__.py

View workflow job for this annotation

GitHub Actions / mypy (3.12)

Incompatible import of "CurrentPlatform" (imported name has type "type[HpuPlatform]", local name has type "object") [assignment]

Check failure on line 66 in vllm/platforms/__init__.py

View workflow job for this annotation

GitHub Actions / mypy (3.11)

Incompatible import of "CurrentPlatform" (imported name has type "type[HpuPlatform]", local name has type "object") [assignment]
except Exception:
pass

# Load XPU Platform
if not CurrentPlatform:
try:
# installed IPEX if the machine has XPUs.
import intel_extension_for_pytorch # noqa: F401
import oneccl_bindings_for_pytorch # noqa: F401
import torch
if hasattr(torch, 'xpu') and torch.xpu.is_available():
from.xpu import XPUPlatform as CurrentPlatform

Check failure on line 78 in vllm/platforms/__init__.py

View workflow job for this annotation

GitHub Actions / mypy (3.9)

Incompatible import of "CurrentPlatform" (imported name has type "type[XPUPlatform]", local name has type "object") [assignment]

Check failure on line 78 in vllm/platforms/__init__.py

View workflow job for this annotation

GitHub Actions / mypy (3.10)

Incompatible import of "CurrentPlatform" (imported name has type "type[XPUPlatform]", local name has type "object") [assignment]

Check failure on line 78 in vllm/platforms/__init__.py

View workflow job for this annotation

GitHub Actions / mypy (3.12)

Incompatible import of "CurrentPlatform" (imported name has type "type[XPUPlatform]", local name has type "object") [assignment]

Check failure on line 78 in vllm/platforms/__init__.py

View workflow job for this annotation

GitHub Actions / mypy (3.11)

Incompatible import of "CurrentPlatform" (imported name has type "type[XPUPlatform]", local name has type "object") [assignment]
except Exception:
pass

# Load CPU Platform
if not CurrentPlatform:
try:
from importlib.metadata import version
assert "cpu" in version("vllm")
from.cpu import CpuPlatform as CurrentPlatform

Check failure on line 87 in vllm/platforms/__init__.py

View workflow job for this annotation

GitHub Actions / mypy (3.9)

Incompatible import of "CurrentPlatform" (imported name has type "type[CpuPlatform]", local name has type "object") [assignment]

Check failure on line 87 in vllm/platforms/__init__.py

View workflow job for this annotation

GitHub Actions / mypy (3.10)

Incompatible import of "CurrentPlatform" (imported name has type "type[CpuPlatform]", local name has type "object") [assignment]

Check failure on line 87 in vllm/platforms/__init__.py

View workflow job for this annotation

GitHub Actions / mypy (3.12)

Incompatible import of "CurrentPlatform" (imported name has type "type[CpuPlatform]", local name has type "object") [assignment]

Check failure on line 87 in vllm/platforms/__init__.py

View workflow job for this annotation

GitHub Actions / mypy (3.11)

Incompatible import of "CurrentPlatform" (imported name has type "type[CpuPlatform]", local name has type "object") [assignment]
except Exception:
pass

# Load Neuron Platform
if not CurrentPlatform:
try:
import transformers_neuronx # noqa: F401
from.neuron import NeuronPlatform as CurrentPlatform

Check failure on line 95 in vllm/platforms/__init__.py

View workflow job for this annotation

GitHub Actions / mypy (3.9)

Incompatible import of "CurrentPlatform" (imported name has type "type[NeuronPlatform]", local name has type "object") [assignment]

Check failure on line 95 in vllm/platforms/__init__.py

View workflow job for this annotation

GitHub Actions / mypy (3.10)

Incompatible import of "CurrentPlatform" (imported name has type "type[NeuronPlatform]", local name has type "object") [assignment]

Check failure on line 95 in vllm/platforms/__init__.py

View workflow job for this annotation

GitHub Actions / mypy (3.12)

Incompatible import of "CurrentPlatform" (imported name has type "type[NeuronPlatform]", local name has type "object") [assignment]

Check failure on line 95 in vllm/platforms/__init__.py

View workflow job for this annotation

GitHub Actions / mypy (3.11)

Incompatible import of "CurrentPlatform" (imported name has type "type[NeuronPlatform]", local name has type "object") [assignment]
except ImportError:
pass

# Load OpenVINO Platform
if not CurrentPlatform:
try:
from importlib.metadata import version
assert "openvino" in version("vllm")
from.openvino import OpenVinoPlatform as CurrentPlatform

Check failure on line 104 in vllm/platforms/__init__.py

View workflow job for this annotation

GitHub Actions / mypy (3.9)

Incompatible import of "CurrentPlatform" (imported name has type "type[OpenVinoPlatform]", local name has type "object") [assignment]

Check failure on line 104 in vllm/platforms/__init__.py

View workflow job for this annotation

GitHub Actions / mypy (3.10)

Incompatible import of "CurrentPlatform" (imported name has type "type[OpenVinoPlatform]", local name has type "object") [assignment]

Check failure on line 104 in vllm/platforms/__init__.py

View workflow job for this annotation

GitHub Actions / mypy (3.12)

Incompatible import of "CurrentPlatform" (imported name has type "type[OpenVinoPlatform]", local name has type "object") [assignment]

Check failure on line 104 in vllm/platforms/__init__.py

View workflow job for this annotation

GitHub Actions / mypy (3.11)

Incompatible import of "CurrentPlatform" (imported name has type "type[OpenVinoPlatform]", local name has type "object") [assignment]
except Exception:
pass

if CurrentPlatform:
device_name = CurrentPlatform.get_device_name()
PlatformRegistry.set_current_platform(device_name)
current_platform = PlatformRegistry.load_platfrom_cls(device_name)
else:
current_platform = UnspecifiedPlatform()

__all__ = ['Platform', 'PlatformEnum', 'current_platform', 'CpuArchEnum']
59 changes: 59 additions & 0 deletions vllm/platforms/registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
from dataclasses import dataclass, field
from typing import Dict, Optional, Type

Check failure on line 2 in vllm/platforms/registry.py

View workflow job for this annotation

GitHub Actions / ruff (3.12)

Ruff (F401)

vllm/platforms/registry.py:2:36: F401 `typing.Type` imported but unused

from vllm.platforms import Platform


_VLLM_PLATFORMS = {
"cpu": "vllm.platforms.cpu:CpuPlatform",
"cuda": "vllm.platforms.cuda:CudaPlatform",
"hpu": "vllm.platforms.hpu:HpuPlatform",
"neuron": "vllm.platforms.neuron:NeuronPlatform",
"openvino": "vllm.platforms.openvino:OpenVinoPlatform",
"rocm": "vllm.platforms.rocm:RocmPlatform",
"tpu": "vllm.platforms.tpu:TpuPlatform",
"xpu": "vllm.platforms.xpu:XPUPlatform",
}


@dataclass
class _PlatformRegistry:
platforms: Dict[str, str] = field(default_factory=dict)
current_platform: Optional[str] = None

def register(self, device_name: str, platform: str):
"""Register a platform by device name. This function is called by the
platform plugin."""
if device_name in self.platforms:
raise ValueError(f"Platform {device_name} already registered.")
self.platforms[device_name] = platform

def load_platfrom_cls(self, device_name: str) -> Platform:
"""Load a platform object by device name."""
if device_name not in self.platforms:
raise ValueError(
f"Platform {device_name} not registered. "
f"Available platforms: {list(self.platforms.keys())}")
platform_cls_str = self.platforms[device_name]
module_name, cls_name = platform_cls_str.split(":")
module = __import__(module_name, fromlist=[cls_name])
return getattr(module, cls_name)

def set_current_platform(self, device_name: str):
"""Set the current platform by device name."""
if device_name not in self.platforms:
raise ValueError(
f"Platform {device_name} not registered. "
f"Available platforms: {list(self.platforms.keys())}")
self.current_platform = device_name

def get_current_platform(self) -> Platform:
"""Get the current platform object."""
if self.current_platform is None:
raise ValueError("No current platform set.")
return self.load_platfrom_cls(self.current_platform)

PlatformRegistry = _PlatformRegistry({
device_name: platform
for device_name, platform in _VLLM_PLATFORMS.items()
})

0 comments on commit 65b83e4

Please sign in to comment.