Skip to content

Commit

Permalink
[Misc] Improve type annotations for support_torch_compile (vllm-pro…
Browse files Browse the repository at this point in the history
…ject#10763)

Signed-off-by: DarkLight1337 <[email protected]>
  • Loading branch information
DarkLight1337 authored and weilong.yu committed Dec 13, 2024
1 parent b50993d commit f42c6c2
Showing 1 changed file with 29 additions and 9 deletions.
38 changes: 29 additions & 9 deletions vllm/compilation/decorators.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import inspect
from typing import Dict, List, Optional, Union
from typing import Callable, Dict, List, Optional, TypeVar, Union, overload

import torch
import torch.nn as nn

from vllm.compilation.counter import compilation_counter
from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
Expand All @@ -12,10 +13,27 @@

logger = init_logger(__name__)

_T = TypeVar("_T", bound=type[nn.Module])


@overload
def support_torch_compile(
*,
dynamic_arg_dims: Optional[Dict[str, Union[int, List[int]]]],
) -> Callable[[_T], _T]:
...


@overload
def support_torch_compile(cls: _T) -> _T:
...


def support_torch_compile(
cls: Optional[type] = None,
dynamic_arg_dims: Optional[Dict[str, Union[int, List[int]]]] = None):
cls: Optional[_T] = None,
*,
dynamic_arg_dims: Optional[Dict[str, Union[int, List[int]]]] = None,
) -> Union[Callable[[_T], _T], _T]:
"""
A decorator to add support for compiling the forward method of a class.
Expand Down Expand Up @@ -66,7 +84,7 @@ def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]):
computation graph.
"""

def cls_decorator_helper(cls: type):
def cls_decorator_helper(cls: _T) -> _T:
# helper to pass `dynamic_arg_dims`` to `_support_torch_compile``
# to avoid too much indentation for `_support_torch_compile``
if not hasattr(cls, 'forward'):
Expand Down Expand Up @@ -105,8 +123,10 @@ def cls_decorator_helper(cls: type):
return cls_decorator_helper


def _support_torch_compile(cls: type,
dynamic_arg_dims: Dict[str, Union[int, List[int]]]):
def _support_torch_compile(
cls: _T,
dynamic_arg_dims: Dict[str, Union[int, List[int]]],
) -> _T:
"""
A decorator to add support for compiling the forward method of a class.
"""
Expand All @@ -119,7 +139,7 @@ def _support_torch_compile(cls: type,
# other than TorchCompileWrapperWithCustomDispatcher
cls.__bases__ = cls.__bases__ + (TorchCompileWrapperWithCustomDispatcher, )

old_init = cls.__init__ # type: ignore
old_init = cls.__init__

def __init__(self, *, vllm_config: VllmConfig, prefix: str = '', **kwargs):
old_init(self, vllm_config=vllm_config, prefix=prefix, **kwargs)
Expand All @@ -135,7 +155,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = '', **kwargs):
TorchCompileWrapperWithCustomDispatcher.__init__(
self, compilation_level=vllm_config.compilation_config.level)

cls.__init__ = __init__ # type: ignore
cls.__init__ = __init__

def __call__(self, *args, **kwargs):
# torch.compiler.is_compiling() means we are inside the compilation
Expand Down Expand Up @@ -180,5 +200,5 @@ def __call__(self, *args, **kwargs):
model_output = self.forward(*args, **kwargs)
return model_output

cls.__call__ = __call__ # type: ignore
cls.__call__ = __call__
return cls

0 comments on commit f42c6c2

Please sign in to comment.