[torch.compile] fast inductor #11108

Merged Dec 17, 2024 (40 commits)

Changes from 35 commits

Commits (all by youkaichao):
02876a0  show draft (Dec 11, 2024)
de61c66  Merge branch 'main' into fast_inductor (Dec 11, 2024)
d59d2d8  polish (Dec 11, 2024)
9f599cc  fix (Dec 11, 2024)
fc0f60f  change read (Dec 11, 2024)
3a8e678  use InductorHashCache class (Dec 11, 2024)
9515e63  rename key to hash_str (Dec 11, 2024)
82d60e6  comment (Dec 11, 2024)
84aa0b3  python-style (Dec 11, 2024)
56a03c7  comments (Dec 11, 2024)
e3f0a14  add comments (Dec 11, 2024)
51a1efb  fix serialize (Dec 11, 2024)
b20435c  fix (Dec 11, 2024)
c41a8d4  fix (Dec 11, 2024)
eb0dba2  fix high-order ops (Dec 12, 2024)
f48a4f6  fix splitting ops (Dec 12, 2024)
955989c  move file writing inside InductorHashCache (Dec 13, 2024)
4f1c4a0  add comments (Dec 13, 2024)
6c325d9  rename (Dec 13, 2024)
37d744d  give error message (Dec 13, 2024)
516db43  move to another file (Dec 13, 2024)
28b98fb  add comments (Dec 13, 2024)
2a7f729  add more factors to consider (Dec 13, 2024)
8104cfa  Update vllm/config.py (Dec 13, 2024)
75cf1f5  typo (Dec 13, 2024)
75da0b6  Merge branch 'main' into fast_inductor (Dec 13, 2024)
d7946ab  merge conflict (Dec 13, 2024)
ee60692  consider all factors (Dec 13, 2024)
5c5eb2b  add vllm version (Dec 13, 2024)
b346bd9  bugfix (Dec 13, 2024)
f58b566  add disable (Dec 13, 2024)
a264175  redirect inductor (Dec 13, 2024)
76fcc99  add more logging (Dec 13, 2024)
59365c4  add more logging (Dec 13, 2024)
aacf7c8  fix shape 1 (Dec 13, 2024)
37829dd  fix tests (Dec 14, 2024)
c4bc393  fix inductor (Dec 15, 2024)
4b31b4f  Merge branch 'main' into fast_inductor (Dec 15, 2024)
40f1355  add warning (Dec 16, 2024)
c4478c8  add info logging (Dec 16, 2024)

208 changes: 205 additions & 3 deletions vllm/compilation/backends.py
@@ -1,6 +1,10 @@
import ast
import copy
import dataclasses
import os
import pprint
import time
from collections import defaultdict
from contextlib import ExitStack
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple
from unittest.mock import patch
@@ -21,6 +25,126 @@
logger = init_logger(__name__)


class InductorHashCache:
    """
    Disk format: a Python list of tuples, each tuple is
    (runtime_shape, graph_index, hash_str).
    We use a list of tuples for readability.

    In-memory format: a defaultdict of dict, where the key is
    runtime_shape, and the value is a dict of graph_index to hash_str.

    The data is essentially `Dict[Optional[int], Dict[int, str]]`;
    we don't use json here because json doesn't support int keys.

    TODO: better off-the-shelf solution to serialize the data?
    """

    def __init__(self, cache_dir: str, disabled: bool = False):
        self.cache: defaultdict = defaultdict(dict)
        self.disabled = disabled
        self.cache_dir = cache_dir
        self.cache_file_path = os.path.join(cache_dir,
                                            "inductor_hash_cache.py")
        if disabled:
            return
        # set flags so that Inductor and Triton store their cache
        # in the cache_dir; then users only need to copy the cache_dir
        # to another machine to reuse the cache.
        inductor_cache = os.path.join(cache_dir, "inductor_cache")
        os.makedirs(inductor_cache, exist_ok=True)
        os.environ["TORCHINDUCTOR_CACHE_DIR"] = inductor_cache
        triton_cache = os.path.join(cache_dir, "triton_cache")
        os.makedirs(triton_cache, exist_ok=True)
        os.environ["TRITON_CACHE_DIR"] = triton_cache
[Inline review comment by youkaichao (Member, Author) on lines +51 to +59: redirect inductor/triton cache to the vllm cache location.]

        if os.path.exists(self.cache_file_path):
            with open(self.cache_file_path) as f:
                self.deserialize(f.read())

    def deserialize(self, data: str):
        # we use ast.literal_eval to parse the data
        # because it is a safe way to parse Python literals.
        # do not use eval(), it is unsafe.
        list_data = ast.literal_eval(data)
        for runtime_shape, graph_index, hash_str in list_data:
            self.cache[runtime_shape][graph_index] = hash_str

    def serialize(self) -> str:
        data = []
        for runtime_shape, graph_index_to_hash_str in self.cache.items():
            for graph_index, hash_str in graph_index_to_hash_str.items():
                data.append((runtime_shape, graph_index, hash_str))
        printer = pprint.PrettyPrinter(indent=4)
        return printer.pformat(data)

    def save_to_file(self):
        if self.disabled:
            return
        with open(self.cache_file_path, "w") as f:
            f.write(self.serialize())

    def __contains__(self, key: Tuple[Optional[int], int]) -> bool:
        if self.disabled:
            return False
        runtime_shape, graph_index = key
        return runtime_shape in self.cache and graph_index in self.cache[
            runtime_shape]

    def __getitem__(self, key: Tuple[Optional[int], int]) -> str:
        if self.disabled:
            raise KeyError("cannot read from disabled cache")
        runtime_shape, graph_index = key
        return self.cache[runtime_shape][graph_index]

    def __setitem__(self, key: Tuple[Optional[int], int], value: str):
        # setitem for disabled cache is fine, because we
        # don't actually write to the disk
        runtime_shape, graph_index = key
        if runtime_shape == 1:
            # FIXME: it seems the cache only works for runtime_shape >= 2
            # we need to investigate why
            return
        self.cache[runtime_shape][graph_index] = value
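
As an illustration of the on-disk format described in the docstring above, here is a minimal, self-contained round-trip sketch (not part of the diff; the hash strings are invented). It also shows why a Python literal is used instead of json:

# Illustration only, not part of the diff: round-trip of the on-disk format
# described in the InductorHashCache docstring. The hash strings are invented.
import ast
import pprint
from collections import defaultdict

entries = [(None, 0, "fx_hash_for_general_shape"),
           (16, 0, "fx_hash_for_shape_16")]
disk_text = pprint.PrettyPrinter(indent=4).pformat(entries)

cache: defaultdict = defaultdict(dict)
for runtime_shape, graph_index, hash_str in ast.literal_eval(disk_text):
    cache[runtime_shape][graph_index] = hash_str

assert cache[None][0] == "fx_hash_for_general_shape"
assert cache[16][0] == "fx_hash_for_shape_16"
# json.dumps would coerce the None/int dict keys to strings, so a json
# round trip would not preserve `Dict[Optional[int], Dict[int, str]]`
# without extra key-conversion work.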


class AlwaysHitShapeEnv:
    """
    Why do we need this class:

    For normal `torch.compile` usage, every compilation will have
    one Dynamo bytecode compilation and one Inductor compilation.
    The Inductor compilation happens under the context of the
    Dynamo bytecode compilation, and that context is used to
    determine the dynamic shape information, etc.

    For our use case, we only run Dynamo bytecode compilation once,
    and run Inductor compilation multiple times with different shapes
    plus a general shape. The compilation for specific shapes happens
    outside of the context of the Dynamo bytecode compilation. At that
    time, we don't have a shape environment to provide to Inductor, and
    the Inductor code cache lookup will fail.

    By providing a dummy shape environment that always hits, we can
    make the Inductor code cache lookup always hit, and we can
    compile the graph for different shapes as needed.

    The following dummy methods were obtained by trial and error
    until the lookup worked.
    """

    def __init__(self) -> None:
        self.guards: List[Any] = []

    def evaluate_guards_expression(self, *args, **kwargs):
        return True

    def get_pruned_guards(self, *args, **kwargs):
        return []

    def produce_guards_expression(self, *args, **kwargs):
        return ""


def wrap_inductor(graph,
                  example_inputs,
                  additional_inductor_config,
@@ -55,9 +179,84 @@ def wrap_inductor(graph,
    # inductor can inplace modify the graph, so we need to copy it
    # see https://github.com/pytorch/pytorch/issues/138980
    graph = copy.deepcopy(graph)
    compiled_graph = compile_fx(graph,
                                example_inputs,
                                config_patches=current_config)

    cache_data = compilation_config.inductor_hash_cache
    if (runtime_shape, graph_index) in cache_data:
        # we compiled this graph before
        # so we can directly lookup the compiled graph via hash
        hash_str = cache_data[(runtime_shape, graph_index)]
        logger.debug(
            "directly lookup the %s-th graph for shape %s via hash %s",
            graph_index, str(runtime_shape), hash_str)
        from torch._inductor.codecache import FxGraphCache
        with patch("torch._inductor.codecache.FxGraphCache._get_shape_env",
                   lambda *args, **kwargs: AlwaysHitShapeEnv()):
            inductor_compiled_graph = FxGraphCache._lookup_graph(
                hash_str, example_inputs, True, False)
        assert inductor_compiled_graph is not None, (
            "Inductor cache lookup failed. Please remove "
            f"the cache file {compilation_config.inductor_hash_cache.cache_file_path} and try again."  # noqa
        )

        # Inductor calling convention (function signature):
        # f(list) -> tuple
        # Dynamo calling convention (function signature):
        # f(*args) -> Any

        # need to know if the graph returns a tuple
        from torch._inductor.compile_fx import graph_returns_tuple
        returns_tuple = graph_returns_tuple(graph)

        # this is the graph we return to Dynamo to run
        def compiled_graph(*args):
            # convert args to list
            list_args = list(args)
            graph_output = inductor_compiled_graph(list_args)
            # unpack the tuple if needed
            if returns_tuple:
                return graph_output
            else:
                return graph_output[0]
    else:
        # it's the first time we compile this graph
        # the assumption is that we don't have nested Inductor compilation.
        # compiled_fx_graph_hash will only be called once, and we can hook
        # it to get the hash of the compiled graph directly.
        from torch._inductor.codecache import compiled_fx_graph_hash

        def hijack_compiled_fx_graph_hash(*args, **kwargs):
            out = compiled_fx_graph_hash(*args, **kwargs)
            # store the hash in the cache
            nonlocal cache_data
            cache_data[(runtime_shape, graph_index)] = out[0]
            logger.debug("store the %s-th graph for shape %s via hash %s",
                         graph_index, str(runtime_shape), out[0])
            return out

        def _check_can_cache(*args, **kwargs):
            # no error means it can be cached.
            # Inductor refuses to cache the graph outside of Dynamo
            # tracing context, and also disables caching for graphs
            # with high-order ops.
            # For vLLM, in either case, we want to cache the graph.
            # see https://github.com/pytorch/pytorch/blob/9f5ebf3fc609105a74eab4ccc24932d6353ff566/torch/_inductor/codecache.py#L1221 # noqa
            return

        def _get_shape_env():
            return AlwaysHitShapeEnv()

        with patch(  # for hijacking the hash of the compiled graph
                "torch._inductor.codecache.compiled_fx_graph_hash",
                hijack_compiled_fx_graph_hash), \
             patch(  # for providing a dummy shape environment
                "torch._inductor.codecache.FxGraphCache._get_shape_env",
                _get_shape_env), \
             patch(  # for forcing the graph to be cached
                "torch._inductor.codecache.FxGraphCache._check_can_cache",
                _check_can_cache):
            compiled_graph = compile_fx(graph,
                                        example_inputs,
                                        config_patches=current_config)

    # after compiling the last graph, record the end time
    if graph_index == num_graphs - 1:
@@ -457,6 +656,9 @@ def __call__(self, *args) -> Any:

            # finished compilations for all required shapes
            if self.is_last_graph and not self.to_be_compiled_sizes:

                # save the hash of the inductor graph for the next run
                self.compilation_config.inductor_hash_cache.save_to_file()
                end_monitoring_torch_compile(self.vllm_config)

        if not entry.use_cudagraph:
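
The else branch in wrap_inductor above works by temporarily replacing ("hijacking") a torch-internal function with unittest.mock.patch so that the FX graph hash computed during compilation can be recorded for later cache lookups. Below is a minimal, self-contained sketch of that general pattern; the names Library, compute_key, run, and hijack are invented for illustration and are not vLLM or PyTorch APIs.

# Minimal sketch of the "hijack via patch" pattern used in the else branch
# above: replace a function with a wrapper that forwards the call unchanged
# while recording its return value. All names here are illustrative.
from unittest.mock import patch


class Library:

    @staticmethod
    def compute_key(x: int) -> str:
        # stands in for torch._inductor.codecache.compiled_fx_graph_hash
        return f"key-{x}"


def run(x: int) -> str:
    # stands in for compile_fx, which calls compute_key internally
    return Library.compute_key(x)


captured = []
original = Library.compute_key


def hijack(x: int) -> str:
    out = original(x)
    captured.append(out)  # record the value, like storing the hash in the cache
    return out  # forward it unchanged so run() is unaffected


with patch.object(Library, "compute_key", hijack):
    assert run(7) == "key-7"

assert captured == ["key-7"]  # the key was captured for reuse on a later run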