Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[3/N][torch.compile] consolidate custom op logging #10399

Merged
merged 3 commits into from
Nov 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
import warnings
from dataclasses import dataclass, field, replace
from pathlib import Path
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Dict, Final, List,
Literal, Mapping, Optional, Set, Tuple, Type, Union)
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Counter, Dict,
Final, List, Literal, Mapping, Optional, Set, Tuple, Type,
Union)

import torch
from pydantic import BaseModel, Field, PrivateAttr
Expand Down Expand Up @@ -2169,6 +2170,10 @@ class CompilationConfig(BaseModel):
compile_sizes: List[int] = PrivateAttr
capture_sizes: List[int] = PrivateAttr

# keep track of enabled and disabled custom ops
enabled_custom_ops: Counter[str] = PrivateAttr
disabled_custom_ops: Counter[str] = PrivateAttr

def model_post_init(self, __context: Any) -> None:
self.level = envs.VLLM_TORCH_COMPILE_LEVEL

Expand All @@ -2190,6 +2195,9 @@ def model_post_init(self, __context: Any) -> None:
func = __import__(module).__dict__[func_name]
self.inductor_compile_config[k] = func

self.enabled_custom_ops = Counter()
self.disabled_custom_ops = Counter()

def init_backend(self) -> Union[str, Callable]:
if self.level == CompilationLevel.NO_COMPILATION:
raise ValueError("No compilation level is set.")
Expand Down
9 changes: 6 additions & 3 deletions vllm/model_executor/custom_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,13 @@ def forward_hpu(self, *args, **kwargs):
def dispatch_forward(self):
# NOTE(woosuk): Here we assume that vLLM was built for only one
# specific backend. Currently, we do not support dynamic dispatching.

compilation_config = get_current_vllm_config().compilation_config
enabled = self.enabled()
logger.debug("custom op %s %s", self.__class__.name,
"enabled" if enabled else "disabled")
if enabled:
compilation_config.enabled_custom_ops.update([self.__class__.name])
else:
compilation_config.disabled_custom_ops.update(
[self.__class__.name])

if not enabled:
return self.forward_native
Expand Down
4 changes: 4 additions & 0 deletions vllm/plugins/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,10 @@ def set_current_vllm_config(vllm_config: VllmConfig):
_current_vllm_config = vllm_config
yield
finally:
logger.debug("enabled custom ops: %s",
vllm_config.compilation_config.enabled_custom_ops)
logger.debug("disabled custom ops: %s",
vllm_config.compilation_config.disabled_custom_ops)
_current_vllm_config = old_vllm_config


Expand Down