[torch.compile] fast inductor (vllm-project#11108)
Signed-off-by: youkaichao <[email protected]>
Co-authored-by: Tyler Michael Smith <[email protected]>
2 people authored and BKitor committed Dec 30, 2024
1 parent 9f1ec6b commit a5a89fc
Showing 3 changed files with 624 additions and 7 deletions.
213 changes: 210 additions & 3 deletions vllm/compilation/backends.py
@@ -1,6 +1,10 @@
import ast
import copy
import dataclasses
import os
import pprint
import time
from collections import defaultdict
from contextlib import ExitStack
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple
from unittest.mock import patch
@@ -21,6 +25,122 @@
logger = init_logger(__name__)


class InductorHashCache:
    """
    Disk format: a Python list of tuples, each tuple is
    (runtime_shape, graph_index, hash_str).
    We use a list of tuples for readability.

    In-memory format: a defaultdict of dicts, where the key is
    runtime_shape, and the value is a dict mapping graph_index to hash_str.

    The data is essentially `Dict[Optional[int], Dict[int, str]]`;
    we don't use json here because json doesn't support int as a key.

    TODO: better off-the-shelf solution to serialize the data?
    """

    def __init__(self, cache_dir: str, disabled: bool = False):
        self.cache: defaultdict = defaultdict(dict)
        self.disabled = disabled
        self.cache_dir = cache_dir
        self.cache_file_path = os.path.join(cache_dir,
                                            "inductor_hash_cache.py")
        if disabled:
            return
        # set flags so that Inductor and Triton store their cache
        # in the cache_dir, then users only need to copy the cache_dir
        # to another machine to reuse the cache.
        inductor_cache = os.path.join(cache_dir, "inductor_cache")
        os.makedirs(inductor_cache, exist_ok=True)
        os.environ["TORCHINDUCTOR_CACHE_DIR"] = inductor_cache
        triton_cache = os.path.join(cache_dir, "triton_cache")
        os.makedirs(triton_cache, exist_ok=True)
        os.environ["TRITON_CACHE_DIR"] = triton_cache
        if os.path.exists(self.cache_file_path):
            with open(self.cache_file_path) as f:
                self.deserialize(f.read())

    def deserialize(self, data: str):
        # we use ast.literal_eval to parse the data
        # because it is a safe way to parse Python literals.
        # do not use eval(), it is unsafe.
        list_data = ast.literal_eval(data)
        for runtime_shape, graph_index, hash_str in list_data:
            self.cache[runtime_shape][graph_index] = hash_str

    def serialize(self) -> str:
        data = []
        for runtime_shape, graph_index_to_hash_str in self.cache.items():
            for graph_index, hash_str in graph_index_to_hash_str.items():
                data.append((runtime_shape, graph_index, hash_str))
        printer = pprint.PrettyPrinter(indent=4)
        return printer.pformat(data)

    def save_to_file(self):
        if self.disabled:
            return
        with open(self.cache_file_path, "w") as f:
            f.write(self.serialize())

    def __contains__(self, key: Tuple[Optional[int], int]) -> bool:
        if self.disabled:
            return False
        runtime_shape, graph_index = key
        return runtime_shape in self.cache and graph_index in self.cache[
            runtime_shape]

    def __getitem__(self, key: Tuple[Optional[int], int]) -> str:
        if self.disabled:
            raise KeyError("cannot read from disabled cache")
        runtime_shape, graph_index = key
        return self.cache[runtime_shape][graph_index]

    def __setitem__(self, key: Tuple[Optional[int], int], value: str):
        # setitem for disabled cache is fine, because we
        # don't actually write to the disk
        runtime_shape, graph_index = key
        self.cache[runtime_shape][graph_index] = value


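Editor's note: to make the serialization format above concrete, here is a minimal usage sketch. It is not part of the diff; the cache directory and hash strings are hypothetical placeholders, and `InductorHashCache` is the class added in this file.

import tempfile

from vllm.compilation.backends import InductorHashCache

# hypothetical cache directory; in vLLM it comes from the compilation config
cache_dir = tempfile.mkdtemp()
cache = InductorHashCache(cache_dir)

# in-memory: Dict[Optional[int], Dict[int, str]]
cache[(None, 0)] = "hash_of_general_shape_graph_0"  # general (symbolic) shape
cache[(8, 0)] = "hash_of_batch_size_8_graph_0"      # a concrete runtime shape

# on disk: a Python list of (runtime_shape, graph_index, hash_str) tuples,
# e.g. [(None, 0, 'hash_of_general_shape_graph_0'), (8, 0, '...')]
print(cache.serialize())
cache.save_to_file()  # writes <cache_dir>/inductor_hash_cache.py

# a later run reconstructs the same mapping via ast.literal_eval
reloaded = InductorHashCache(cache_dir)
assert (8, 0) in reloaded and reloaded[(8, 0)] == cache[(8, 0)]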
class AlwaysHitShapeEnv:
    """
    Why do we need this class:

    For normal `torch.compile` usage, every compilation will have
    one Dynamo bytecode compilation and one Inductor compilation.
    The Inductor compilation happens under the context of the
    Dynamo bytecode compilation, and that context is used to
    determine the dynamic shape information, etc.

    For our use case, we only run Dynamo bytecode compilation once,
    and run Inductor compilation multiple times with different shapes
    plus a general shape. The compilation for specific shapes happens
    outside of the context of the Dynamo bytecode compilation. At that
    time, we don't have a shape environment to provide to Inductor, so
    the Inductor code cache lookup fails.

    By providing a dummy shape environment that always hits, we can
    make the Inductor code cache lookup always hit, and we can
    compile the graph for different shapes as needed.

    The following dummy methods were obtained by trial and error
    until the lookup works.
    """

    def __init__(self) -> None:
        self.guards: List[Any] = []

    def evaluate_guards_expression(self, *args, **kwargs):
        return True

    def get_pruned_guards(self, *args, **kwargs):
        return []

    def produce_guards_expression(self, *args, **kwargs):
        return ""


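Editor's note: a rough standalone sketch of how the dummy shape environment is meant to be used. The same `patch` appears in `wrap_inductor` below; the `_lookup_graph` call mirrors the one in this diff and relies on a private torch API whose signature may change across versions, and `hash_str` / `example_inputs` are placeholders.

from unittest.mock import patch

from torch._inductor.codecache import FxGraphCache

from vllm.compilation.backends import AlwaysHitShapeEnv


def lookup_cached_graph(hash_str, example_inputs):
    # outside Dynamo's tracing context there is no real ShapeEnv, so we
    # substitute one whose guard checks always pass; otherwise the
    # FxGraphCache lookup would fail
    with patch("torch._inductor.codecache.FxGraphCache._get_shape_env",
               lambda *args, **kwargs: AlwaysHitShapeEnv()):
        return FxGraphCache._lookup_graph(hash_str, example_inputs, True, False)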
def wrap_inductor(graph,
                  example_inputs,
                  additional_inductor_config,
Expand Down Expand Up @@ -55,9 +175,93 @@ def wrap_inductor(graph,
    # inductor can inplace modify the graph, so we need to copy it
    # see https://github.com/pytorch/pytorch/issues/138980
    graph = copy.deepcopy(graph)
    compiled_graph = compile_fx(graph,
                                example_inputs,
                                config_patches=current_config)

    cache_data = compilation_config.inductor_hash_cache
    if (runtime_shape, graph_index) in cache_data:
        # we compiled this graph before
        # so we can directly look up the compiled graph via its hash
        hash_str = cache_data[(runtime_shape, graph_index)]
        if graph_index == 0:
            # add some info logging for the first graph
            logger.info(
                "Directly lookup the graph for shape %s from the cache",
                str(runtime_shape))  # noqa
        logger.debug(
            "directly lookup the %s-th graph for shape %s via hash %s",
            graph_index, str(runtime_shape), hash_str)
        from torch._inductor.codecache import FxGraphCache
        with patch("torch._inductor.codecache.FxGraphCache._get_shape_env",
                   lambda *args, **kwargs: AlwaysHitShapeEnv()):
            inductor_compiled_graph = FxGraphCache._lookup_graph(
                hash_str, example_inputs, True, False)
            assert inductor_compiled_graph is not None, (
                "Inductor cache lookup failed. Please remove "
                f"the cache file {compilation_config.inductor_hash_cache.cache_file_path} and try again."  # noqa
            )

        # Inductor calling convention (function signature):
        # f(list) -> tuple
        # Dynamo calling convention (function signature):
        # f(*args) -> Any

        # need to know if the graph returns a tuple
        from torch._inductor.compile_fx import graph_returns_tuple
        returns_tuple = graph_returns_tuple(graph)

        # this is the graph we return to Dynamo to run
        def compiled_graph(*args):
            # convert args to list
            list_args = list(args)
            graph_output = inductor_compiled_graph(list_args)
            # unpack the tuple if needed
            if returns_tuple:
                return graph_output
            else:
                return graph_output[0]
    else:
        # it's the first time we compile this graph
        # the assumption is that we don't have nested Inductor compilation.
        # compiled_fx_graph_hash will only be called once, and we can hook
        # it to get the hash of the compiled graph directly.
        from torch._inductor.codecache import compiled_fx_graph_hash

        def hijack_compiled_fx_graph_hash(*args, **kwargs):
            out = compiled_fx_graph_hash(*args, **kwargs)
            # store the hash in the cache
            nonlocal cache_data
            cache_data[(runtime_shape, graph_index)] = out[0]
            if graph_index == 0:
                # add some info logging for the first graph
                logger.info("Cache the graph of shape %s for later use",
                            str(runtime_shape))
            logger.debug("store the %s-th graph for shape %s via hash %s",
                         graph_index, str(runtime_shape), out[0])
            return out

        def _check_can_cache(*args, **kwargs):
            # no error means it can be cached.
            # Inductor refuses to cache the graph outside of Dynamo
            # tracing context, and also disables caching for graphs
            # with high-order ops.
            # For vLLM, in either case, we want to cache the graph.
            # see https://github.com/pytorch/pytorch/blob/9f5ebf3fc609105a74eab4ccc24932d6353ff566/torch/_inductor/codecache.py#L1221  # noqa
            return

        def _get_shape_env():
            return AlwaysHitShapeEnv()

        with patch(  # for hijacking the hash of the compiled graph
                "torch._inductor.codecache.compiled_fx_graph_hash",
                hijack_compiled_fx_graph_hash), \
             patch(  # for providing a dummy shape environment
                "torch._inductor.codecache.FxGraphCache._get_shape_env",
                _get_shape_env), \
             patch(  # for forcing the graph to be cached
                "torch._inductor.codecache.FxGraphCache._check_can_cache",
                _check_can_cache):
            compiled_graph = compile_fx(graph,
                                        example_inputs,
                                        config_patches=current_config)

    # after compiling the last graph, record the end time
    if graph_index == num_graphs - 1:
@@ -457,6 +661,9 @@ def __call__(self, *args) -> Any:

            # finished compilations for all required shapes
            if self.is_last_graph and not self.to_be_compiled_sizes:

                # save the hash of the inductor graph for the next run
                self.compilation_config.inductor_hash_cache.save_to_file()
                end_monitoring_torch_compile(self.vllm_config)

        if not entry.use_cudagraph: