Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
Signed-off-by: wangxiyuan <[email protected]>
  • Loading branch information
wangxiyuan committed Dec 23, 2024
1 parent d6f685a commit 75eedc0
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 42 deletions.
10 changes: 1 addition & 9 deletions vllm/attention/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
from vllm.platforms import _Backend, current_platform
from vllm.utils import direct_register_custom_op

CUSTOM_OPS_REGISTERED = False


class Attention(nn.Module):
"""Attention layer.
Expand Down Expand Up @@ -131,7 +129,6 @@ def forward(
attn_metadata: AttentionMetadata,
attn_type: str = AttentionType.DECODER,
) -> torch.Tensor:
_register_custom_ops()

if self.use_direct_call:
return self.impl.forward(query,
Expand Down Expand Up @@ -301,12 +298,8 @@ def unified_attention_with_output_fake(
return


def _register_custom_ops():
def register_custom_ops():
"""Register custom ops for attention."""
global CUSTOM_OPS_REGISTERED
if CUSTOM_OPS_REGISTERED:
return

direct_register_custom_op(
op_name="unified_attention",
op_func=unified_attention,
Expand All @@ -321,4 +314,3 @@ def _register_custom_ops():
fake_impl=unified_attention_with_output_fake,
dispatch_key=current_platform.dispatch_key,
)
CUSTOM_OPS_REGISTERED = True
61 changes: 30 additions & 31 deletions vllm/platforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,28 +2,6 @@
from .interface import CpuArchEnum, Platform, PlatformEnum, UnspecifiedPlatform
from .registry import PlatformRegistry, detect_current_platform

_current_platform: Platform = UnspecifiedPlatform()


def initialize_current_platform():
"""Initialize the current platform. This function is called when loading
the vllm plugin."""
global _current_platform
# Get the current platform from the registry first. If the current platform
# is not set, try to detect the current platform.
if PlatformRegistry.current_platform is not None:
_current_platform = PlatformRegistry.get_current_platform_cls()
else:
_current_platform = detect_current_platform()


def update_current_platform(device_name: str):
"""Update the current platform. This function is used by users to set the
current platform by hand."""
global _current_platform
PlatformRegistry.set_current_platform(device_name)
_current_platform = PlatformRegistry.get_current_platform_cls()


class CurrentPlatform:
"""A wrapper that provides an interface to the current platform.
Expand All @@ -37,17 +15,38 @@ class CurrentPlatform:
"""

def __init__(self):
self.platform = _current_platform

def _refresh_current_platform(self):
"""Refresh the current platform dynamically."""
global _current_platform
if _current_platform is not self.platform:
self.platform = _current_platform
self.platform: Platform = UnspecifiedPlatform()

def initialize_current_platform(self):
"""Initialize the current platform. This function is called when loading
the vllm plugin."""
# Get the current platform from the registry first. If the current
# platform is not set, try to detect the current platform.
if PlatformRegistry.current_platform is not None:
self.platform = PlatformRegistry.get_current_platform_cls()
else:
self.platform = detect_current_platform()

# Register custom ops for the current platform.
from vllm.attention.layer import register_custom_ops
register_custom_ops()

def update_current_platform(self, device_name: str):
"""Update the current platform. This function is used by users to set
the current platform by hand."""
PlatformRegistry.set_current_platform(device_name)
self.platform = PlatformRegistry.get_current_platform_cls()

def __getattr__(self, name):
"""Go pass to the current platform."""
self._refresh_current_platform()
"""Get the attribute. If the attribute is not found, go pass to the
current platform."""
if name == 'platform':
return self.platform
if name == 'initialize_current_platform':
return self.initialize_current_platform
if name == 'update_current_platform':
return self.update_current_platform
# Go pass to the current platform.
return getattr(self.platform, name)


Expand Down
4 changes: 2 additions & 2 deletions vllm/plugins/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import torch

import vllm.envs as envs
from vllm.platforms import current_platform, initialize_current_platform
from vllm.platforms import current_platform

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -52,7 +52,7 @@ def load_general_plugins():
plugin.name)
# initialize current platform should be called after all plugins are
# loaded.
initialize_current_platform()
current_platform.initialize_current_platform()

plugins_loaded = True

Expand Down

0 comments on commit 75eedc0

Please sign in to comment.