
[torch.compile] fast inductor #11108

Status: Merged (40 commits, merged Dec 17, 2024)
Changes shown from 16 commits

Commits:
02876a0
show draft
youkaichao Dec 11, 2024
de61c66
Merge branch 'main' into fast_inductor
youkaichao Dec 11, 2024
d59d2d8
polish
youkaichao Dec 11, 2024
9f599cc
fix
youkaichao Dec 11, 2024
fc0f60f
change read
youkaichao Dec 11, 2024
3a8e678
use InductorHashCache class
youkaichao Dec 11, 2024
9515e63
rename key to hash_str
youkaichao Dec 11, 2024
82d60e6
comment
youkaichao Dec 11, 2024
84aa0b3
python-style
youkaichao Dec 11, 2024
56a03c7
comments
youkaichao Dec 11, 2024
e3f0a14
add comments
youkaichao Dec 11, 2024
51a1efb
fix serialize
youkaichao Dec 11, 2024
b20435c
fix
youkaichao Dec 11, 2024
c41a8d4
fix
youkaichao Dec 11, 2024
eb0dba2
fix high-order ops
youkaichao Dec 12, 2024
f48a4f6
fix splitting ops
youkaichao Dec 12, 2024
955989c
move file writing inside InductorHashCache
youkaichao Dec 13, 2024
4f1c4a0
add comments
youkaichao Dec 13, 2024
6c325d9
rename
youkaichao Dec 13, 2024
37d744d
give error message
youkaichao Dec 13, 2024
516db43
move to another file
youkaichao Dec 13, 2024
28b98fb
add comments
youkaichao Dec 13, 2024
2a7f729
add more factors to consider
youkaichao Dec 13, 2024
8104cfa
Update vllm/config.py
youkaichao Dec 13, 2024
75cf1f5
typo
youkaichao Dec 13, 2024
75da0b6
Merge branch 'main' into fast_inductor
youkaichao Dec 13, 2024
d7946ab
merge conflict
youkaichao Dec 13, 2024
ee60692
consider all factors
youkaichao Dec 13, 2024
5c5eb2b
add vllm version
youkaichao Dec 13, 2024
b346bd9
bugfix
youkaichao Dec 13, 2024
f58b566
add disable
youkaichao Dec 13, 2024
a264175
redirect inductor
youkaichao Dec 13, 2024
76fcc99
add more logging
youkaichao Dec 13, 2024
59365c4
add more logging
youkaichao Dec 13, 2024
aacf7c8
fix shape 1
youkaichao Dec 13, 2024
37829dd
fix tests
youkaichao Dec 14, 2024
c4bc393
fix inductor
youkaichao Dec 15, 2024
4b31b4f
Merge branch 'main' into fast_inductor
youkaichao Dec 15, 2024
40f1355
add warning
youkaichao Dec 16, 2024
c4478c8
add info logging
youkaichao Dec 16, 2024
112 changes: 109 additions & 3 deletions vllm/compilation/backends.py
@@ -21,6 +21,44 @@
logger = init_logger(__name__)


class AlwaysHitShapeEnv:
"""
Why do we need this class:

For normal `torch.compile` usage, every compilation will have
one Dynamo bytecode compilation and one Inductor compilation.
The Inductor compilation happens under the context of the
Dynamo bytecode compilation, and that context is used to
determine the dynamic shape information, etc.

For our use case, we only run Dynamo bytecode compilation once,
and run Inductor compilation multiple times with different shapes
plus a general shape. The compilation for specific shapes happens
outside of the context of the Dynamo bytecode compilation. At that
time, we don't have a shape environment to provide to Inductor, and
it will fail the Inductor code cache lookup.

By providing a dummy shape environment that always hits, we can
make the Inductor code cache lookup always hit, and we can
compile the graph for different shapes as needed.

The following dummy methods were obtained by trial and error
until the cache lookup works.
"""

def __init__(self) -> None:
self.guards: List[Any] = []

def evaluate_guards_expression(self, *args, **kwargs):
return True

def get_pruned_guards(self, *args, **kwargs):
return []

def produce_guards_expression(self, *args, **kwargs):
return ""
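
To make the docstring above concrete, here is a condensed sketch of the pattern it enables. It simply mirrors the cache-lookup hunk added to wrap_inductor below; FxGraphCache._get_shape_env and _lookup_graph are private Inductor APIs, so the exact call signature is copied from this diff rather than from documented behavior.

from unittest.mock import patch

def lookup_cached_inductor_graph(hash_str, example_inputs):
    # Look up a previously compiled Inductor artifact by its hash,
    # outside of any Dynamo compilation context, by pretending the
    # shape-environment guards always pass.
    from torch._inductor.codecache import FxGraphCache
    with patch("torch._inductor.codecache.FxGraphCache._get_shape_env",
               lambda *args, **kwargs: AlwaysHitShapeEnv()):
        # signature taken from the hunk below in this file
        return FxGraphCache._lookup_graph(hash_str, example_inputs, True, False)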


def wrap_inductor(graph,
example_inputs,
additional_inductor_config,
@@ -55,9 +93,71 @@ def wrap_inductor(graph,
# inductor can inplace modify the graph, so we need to copy it
# see https://github.com/pytorch/pytorch/issues/138980
graph = copy.deepcopy(graph)
compiled_graph = compile_fx(graph,
example_inputs,
config_patches=current_config)

cache_data = compilation_config.inductor_hash_cache
if (runtime_shape, graph_index) in cache_data:
# we compiled this graph before
# so we can directly look up the compiled graph via its hash
hash_str = cache_data[(runtime_shape, graph_index)]
from torch._inductor.codecache import FxGraphCache
with patch("torch._inductor.codecache.FxGraphCache._get_shape_env",
lambda *args, **kwargs: AlwaysHitShapeEnv()):
inductor_compiled_graph = FxGraphCache._lookup_graph(
hash_str, example_inputs, True, False)

# Inductor calling convention (function signature):
# f(list) -> tuple
# Dynamo calling convention (function signature):
# f(*args) -> Any

# need to know if the graph returns a tuple
from torch._inductor.compile_fx import graph_returns_tuple
returns_tuple = graph_returns_tuple(graph)

# this is the graph we return to Dynamo to run
def compiled_graph(*args):
# convert args to list
list_args = list(args)
graph_output = inductor_compiled_graph(list_args)
# unpack the tuple if needed
if returns_tuple:
return graph_output
else:
return graph_output[0]
else:
# it's the first time we compile this graph
# the assumption is that we don't have nested Inductor compilation.
# compiled_fx_graph_hash will only be called once, and we can hook
# it to get the hash of the compiled graph directly.
from torch._inductor.codecache import compiled_fx_graph_hash

def mocked_compiled_fx_graph_hash(*args, **kwargs):
out = compiled_fx_graph_hash(*args, **kwargs)
# store the hash in the cache
nonlocal cache_data
cache_data[(runtime_shape, graph_index)] = out[0]
return out

def _check_can_cache(*args, **kwargs):
# no error means it can be cached
# vLLM computation graph can always be cached
return

def _get_shape_env():
return AlwaysHitShapeEnv()

with patch(# for hijacking the hash of the compiled graph
"torch._inductor.codecache.compiled_fx_graph_hash",
mocked_compiled_fx_graph_hash), \
patch(# for providing a dummy shape environment
"torch._inductor.codecache.FxGraphCache._get_shape_env",
_get_shape_env), \
patch(# for forcing the graph to be cached
"torch._inductor.codecache.FxGraphCache._check_can_cache",
_check_can_cache):
compiled_graph = compile_fx(graph,
example_inputs,
config_patches=current_config)

# after compiling the last graph, record the end time
if graph_index == num_graphs - 1:
@@ -457,6 +557,12 @@ def __call__(self, *args) -> Any:

# finished compilations for all required shapes
if self.is_last_graph and not self.to_be_compiled_sizes:

# save the hash of the inductor graph for the next run
with open(self.compilation_config.inductor_hash_cache_path,
"w") as f:
f.write(self.compilation_config.inductor_hash_cache.
serialize())
end_monitoring_torch_compile(self.vllm_config)

if not entry.use_cudagraph:
108 changes: 104 additions & 4 deletions vllm/config.py
@@ -3,7 +3,10 @@
import enum
import hashlib
import json
import os
import pprint
import warnings
from collections import defaultdict
from contextlib import contextmanager
from dataclasses import dataclass, field, replace
from pathlib import Path
@@ -2212,6 +2215,53 @@ class CompilationLevel:
PIECEWISE = 3


class InductorHashCache:
Member Author (youkaichao): I tried to place this class into vllm.compilation.backends, but then it needs to be lazily imported, and pydantic will complain.

Contributor: Why not put it into a separate file?

"""
Disk format: a Python list of tuples, each tuple is
(runtime_shape, graph_index, hash_str)
We use a list of tuples for readability.

In-memory format: a defaultdict of dicts, where the key is
runtime_shape and the value is a dict mapping graph_index to hash_str.
"""

def __init__(self, cache_file_path: str):
self.cache_file_path = cache_file_path
self.cache: defaultdict = defaultdict(dict)
if os.path.exists(self.cache_file_path):
with open(self.cache_file_path) as f:
self.deserialize(f.read())

def deserialize(self, data: str):
# we use ast.literal_eval to parse the data
# because it is a safe way to parse Python literals.
# do not use eval(), it is unsafe.
list_data = ast.literal_eval(data)
for runtime_shape, graph_index, hash_str in list_data:
self.cache[runtime_shape][graph_index] = hash_str

def serialize(self) -> str:
data = []
for runtime_shape, graph_index_to_hash_str in self.cache.items():
for graph_index, hash_str in graph_index_to_hash_str.items():
data.append((runtime_shape, graph_index, hash_str))
printer = pprint.PrettyPrinter(indent=4)
return printer.pformat(data)

def __contains__(self, key: Tuple[Optional[int], int]) -> bool:
runtime_shape, graph_index = key
return runtime_shape in self.cache and graph_index in self.cache[
runtime_shape]

def __getitem__(self, key: Tuple[Optional[int], int]) -> str:
runtime_shape, graph_index = key
return self.cache[runtime_shape][graph_index]

def __setitem__(self, key: Tuple[Optional[int], int], value: str):
runtime_shape, graph_index = key
self.cache[runtime_shape][graph_index] = value
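
For readers skimming the diff, a minimal usage sketch of the class above (the temporary file path is hypothetical; the calls follow the methods shown in this hunk):

import os
import tempfile

# hypothetical location; vLLM itself derives this path from cache_dir
cache_path = os.path.join(tempfile.mkdtemp(), "inductor_hash_cache.py")

cache = InductorHashCache(cache_path)
cache[(None, 0)] = "hash_of_general_shape_graph"  # runtime_shape=None, graph_index=0
cache[(16, 0)] = "hash_of_batch_16_graph"
assert (16, 0) in cache and cache[(16, 0)] == "hash_of_batch_16_graph"

# serialize() produces a pretty-printed Python list of tuples, e.g.
# [(None, 0, 'hash_of_general_shape_graph'), (16, 0, 'hash_of_batch_16_graph')]
with open(cache_path, "w") as f:
    f.write(cache.serialize())

# a fresh instance reloads the entries from disk via ast.literal_eval
reloaded = InductorHashCache(cache_path)
assert reloaded[(None, 0)] == "hash_of_general_shape_graph"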


class CompilationConfig(BaseModel):
"""
Configuration for compilation.
@@ -2223,6 +2273,9 @@ class CompilationConfig(BaseModel):
- 2: dynamo once.
- 3: piecewise compilation.
- debug_dump_path: the path to dump the debug information.
- cache_dir: the directory to store the compiled graph, to
accelerate Inductor compilation. By default, it will use
model-related information to generate a cache directory.
- backend: the backend for compilation. It needs to be a string.
- "" (empty string): use the default backend.
- "eager"/"openxla"/...: use the specified backend registered in PyTorch.
@@ -2291,12 +2344,10 @@ class CompilationConfig(BaseModel):
""" # noqa
level: int = 0
debug_dump_path: str = ""
cache_dir: str = ""
backend: str = ""
custom_ops: List[str] = Field(default_factory=list)
splitting_ops: List[str] = Field(default_factory=lambda: [
"vllm.unified_attention",
"vllm.unified_attention_with_output",
])
splitting_ops: List[str] = Field(default=None) # type: ignore

use_inductor: bool = True
candidate_compile_sizes: Optional[List[int]] = Field(default=None)
@@ -2354,6 +2405,9 @@ def model_post_init(self, __context: Any) -> None:
enabled_custom_ops: Counter[str] = PrivateAttr
disabled_custom_ops: Counter[str] = PrivateAttr
compilation_time: float = PrivateAttr
# should be InductorHashCache, but Pydantic does not support it
inductor_hash_cache: Any = PrivateAttr
inductor_hash_cache_path: str = PrivateAttr

# Per-model forward context
# Mainly used to store attention cls
@@ -2375,6 +2429,19 @@ def model_post_init(self, __context: Any) -> None:
count_all = self.custom_ops.count("all")
assert count_none + count_all <= 1, "Can only specify 'none' or 'all'"

if self.splitting_ops is None:
if envs.VLLM_USE_V1:
# v1 must split the graph on attention ops
# for piecewise cudagraph
self.splitting_ops = [
"vllm.unified_attention",
"vllm.unified_attention_with_output",
]
else:
# v0 can use full graph compilation without splitting,
# splitting is optional.
self.splitting_ops = []
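
As a usage illustration, splitting_ops can also be set explicitly instead of relying on the default above; the constructor call is a sketch inferred from the pydantic fields in this diff, not copied from vLLM documentation:

from vllm.config import CompilationConfig, CompilationLevel

# hypothetical override: keep piecewise compilation but split only on the
# plain attention op; leaving splitting_ops unset falls back to the
# v0/v1 defaults chosen in model_post_init above
config = CompilationConfig(
    level=CompilationLevel.PIECEWISE,
    splitting_ops=["vllm.unified_attention"],
)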

for k, v in self.inductor_passes.items():
if not isinstance(v, str):
assert callable(v), (
@@ -2414,6 +2481,16 @@ def init_backend(self, vllm_config: "VllmConfig") -> Union[str, Callable]:
# TODO: pass user-specified backend to piecewise compilation
# merge with the config use_inductor
assert self.level == CompilationLevel.PIECEWISE

# every rank writes to its own cache dir
self.cache_dir = os.path.join(
self.cache_dir, f"rank_{vllm_config.parallel_config.rank}")
os.makedirs(self.cache_dir, exist_ok=True)
self.inductor_hash_cache_path = os.path.join(self.cache_dir,
"inductor_hash_cache.py")

Member Author (youkaichao): It would be better if we also save a serialized form of the config, but we need to design the serialized format.

Contributor: Which config is not serializable? Isn't CompilationConfig serializable?

Member Author (youkaichao): It is serializable, but I want a human-readable form, so that we can also manually check the config.
self.inductor_hash_cache = InductorHashCache(
self.inductor_hash_cache_path)

from vllm.compilation.backends import VllmBackend
return VllmBackend(vllm_config)

@@ -2649,6 +2726,29 @@ def __post_init__(self):
"Disabling `torch.compile`.")
self.compilation_config.level = CompilationLevel.NO_COMPILATION

if self.model_config is not None and \
not self.compilation_config.cache_dir:
# generate a cache directory based on the model information
# TODO: consider more factors that will affect model forward,
# and hence affect the compilation.

Member Author (youkaichao): I think I missed some quantization args that can affect model execution, but I don't know how to pull out all the factors that affect quantization.

Collaborator: vLLM version? We can add the git SHA to the key.

Collaborator: This is going to be a large source of potential bugs, so we should definitely be careful here. Most quantization-related stuff from NM goes in the model_config, but there are a lot of arguments to LLM that can affect things like dtype and quantization. Are these in the key already?

Member Author (youkaichao): Not yet; that's why I want to ask for reviews. One direction is to consider all factors affecting compilation, so we can use the compilation cache by default. Another approach is to not cache by default, but tell the user the cache directory, and users can specify the cache directory if they know nothing changed. Which one would you prefer?

Contributor: I think we should always check the known factors when we cache, and expose an accessible switch for enabling/disabling caching. Then it's less important whether it's on by default or not. For that decision @robertgshaw2-neuralmagic should chime in.

Member Author (youkaichao): I added more factors to consider in 2a7f729. Let me know if I missed anything.
model = self.model_config.model
assert self.parallel_config is not None
tp_size = self.parallel_config.tensor_parallel_size
pp_size = self.parallel_config.pipeline_parallel_size
splitting_ops = sorted(self.compilation_config.splitting_ops)
compilation_factors = (tp_size, pp_size, model, splitting_ops)
import hashlib
hash_str = hashlib.md5(
str(compilation_factors).encode()).hexdigest()[:10]
cache_dir = os.path.join(envs.VLLM_CACHE_ROOT,
"torch_compile_cache", hash_str)
os.makedirs(cache_dir, exist_ok=True)
self.compilation_config.cache_dir = cache_dir
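
To make the cache-key derivation above concrete, here is a standalone sketch with hypothetical values; the default VLLM_CACHE_ROOT is assumed to be ~/.cache/vllm:

import hashlib
import os

# hypothetical factors mirroring the ones collected above
model = "meta-llama/Llama-3.1-8B"
tp_size, pp_size = 2, 1
splitting_ops = sorted([
    "vllm.unified_attention",
    "vllm.unified_attention_with_output",
])

compilation_factors = (tp_size, pp_size, model, splitting_ops)
hash_str = hashlib.md5(str(compilation_factors).encode()).hexdigest()[:10]

# e.g. ~/.cache/vllm/torch_compile_cache/<hash_str>, with a rank_<i>
# subdirectory appended later in init_backend for each rank
cache_dir = os.path.join(os.path.expanduser("~/.cache/vllm"),
                         "torch_compile_cache", hash_str)
print(cache_dir)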

if self.compilation_config.level == CompilationLevel.PIECEWISE:
logger.info("Using cache directory: %s for vLLM's torch.compile",
self.compilation_config.cache_dir)

current_platform.check_and_update_config(self)

if not self.instance_id: