From a5a89fcaae40cfab8f140718b0c58aee28a564cf Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 16 Dec 2024 16:15:22 -0800 Subject: [PATCH] [torch.compile] fast inductor (#11108) Signed-off-by: youkaichao Co-authored-by: Tyler Michael Smith --- vllm/compilation/backends.py | 213 +++++++++++++++++- vllm/config.py | 415 ++++++++++++++++++++++++++++++++++- vllm/envs.py | 3 + 3 files changed, 624 insertions(+), 7 deletions(-) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 4a5dc337d01b8..0c7bbfe599b02 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -1,6 +1,10 @@ +import ast import copy import dataclasses +import os +import pprint import time +from collections import defaultdict from contextlib import ExitStack from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple from unittest.mock import patch @@ -21,6 +25,122 @@ logger = init_logger(__name__) +class InductorHashCache: + """ + Disk format: a Python list of tuples, each tuple is + (runtime_shape, graph_index, hash_str) + We use list of tuple for readability. + + In-memory format: a defaultdict of dict, where the key is + runtime_shape, and the value is a dict of graph_index to hash_str. + + The data is essentially `Dict[Optional[int], Dict[int, str]]`, + we don't use json here because json doesn't support int as key. + + TODO: better off-the-shelf solution to serialize the data? + """ + + def __init__(self, cache_dir: str, disabled: bool = False): + self.cache: defaultdict = defaultdict(dict) + self.disabled = disabled + self.cache_dir = cache_dir + self.cache_file_path = os.path.join(cache_dir, + "inductor_hash_cache.py") + if disabled: + return + # set flags so that Inductor and Triton store their cache + # in the cache_dir, then users only need to copy the cache_dir + # to another machine to reuse the cache. + inductor_cache = os.path.join(cache_dir, "inductor_cache") + os.makedirs(inductor_cache, exist_ok=True) + os.environ["TORCHINDUCTOR_CACHE_DIR"] = inductor_cache + triton_cache = os.path.join(cache_dir, "triton_cache") + os.makedirs(triton_cache, exist_ok=True) + os.environ["TRITON_CACHE_DIR"] = triton_cache + if os.path.exists(self.cache_file_path): + with open(self.cache_file_path) as f: + self.deserialize(f.read()) + + def deserialize(self, data: str): + # we use ast.literal_eval to parse the data + # because it is a safe way to parse Python literals. + # do not use eval(), it is unsafe. 
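+        # An illustrative example of the expected payload (the hash
+        # strings here are placeholders, not real Inductor hashes):
+        #   [(None, 0, 'hash_general_shape'), (32, 0, 'hash_batch_32')]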
+ list_data = ast.literal_eval(data) + for runtime_shape, graph_index, hash_str in list_data: + self.cache[runtime_shape][graph_index] = hash_str + + def serialize(self) -> str: + data = [] + for runtime_shape, graph_index_to_hash_str in self.cache.items(): + for graph_index, hash_str in graph_index_to_hash_str.items(): + data.append((runtime_shape, graph_index, hash_str)) + printer = pprint.PrettyPrinter(indent=4) + return printer.pformat(data) + + def save_to_file(self): + if self.disabled: + return + with open(self.cache_file_path, "w") as f: + f.write(self.serialize()) + + def __contains__(self, key: Tuple[Optional[int], int]) -> bool: + if self.disabled: + return False + runtime_shape, graph_index = key + return runtime_shape in self.cache and graph_index in self.cache[ + runtime_shape] + + def __getitem__(self, key: Tuple[Optional[int], int]) -> str: + if self.disabled: + raise KeyError("cannot read from disabled cache") + runtime_shape, graph_index = key + return self.cache[runtime_shape][graph_index] + + def __setitem__(self, key: Tuple[Optional[int], int], value: str): + # setitem for disabled cache is fine, because we + # don't actually write to the disk + runtime_shape, graph_index = key + self.cache[runtime_shape][graph_index] = value + + +class AlwaysHitShapeEnv: + """ + Why do we need this class: + + For normal `torch.compile` usage, every compilation will have + one Dynamo bytecode compilation and one Inductor compilation. + The Inductor compilation happens under the context of the + Dynamo bytecode compilation, and that context is used to + determine the dynamic shape information, etc. + + For our use case, we only run Dynamo bytecode compilation once, + and run Inductor compilation multiple times with different shapes + plus a general shape. The compilation for specific shapes happens + outside of the context of the Dynamo bytecode compilation. At that + time, we don't have shape environment to provide to Inductor, and + it will fail the Inductor code cache lookup. + + By providing a dummy shape environment that always hits, we can + make the Inductor code cache lookup always hit, and we can + compile the graph for different shapes as needed. + + The following dummy methods are obtained by trial-and-error + until it works. 
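+    They are just enough of the ShapeEnv interface for the Inductor
+    code cache lookup to succeed; a missing method would surface as an
+    AttributeError during the lookup.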
+ """ + + def __init__(self) -> None: + self.guards: List[Any] = [] + + def evaluate_guards_expression(self, *args, **kwargs): + return True + + def get_pruned_guards(self, *args, **kwargs): + return [] + + def produce_guards_expression(self, *args, **kwargs): + return "" + + def wrap_inductor(graph, example_inputs, additional_inductor_config, @@ -55,9 +175,93 @@ def wrap_inductor(graph, # inductor can inplace modify the graph, so we need to copy it # see https://github.com/pytorch/pytorch/issues/138980 graph = copy.deepcopy(graph) - compiled_graph = compile_fx(graph, - example_inputs, - config_patches=current_config) + + cache_data = compilation_config.inductor_hash_cache + if (runtime_shape, graph_index) in cache_data: + # we compiled this graph before + # so we can directly lookup the compiled graph via hash + hash_str = cache_data[(runtime_shape, graph_index)] + if graph_index == 0: + # adds some info logging for the first graph + logger.info( + "Directly lookup the graph for shape %s from the cache", + str(runtime_shape)) # noqa + logger.debug( + "directly lookup the %s-th graph for shape %s via hash %s", + graph_index, str(runtime_shape), hash_str) + from torch._inductor.codecache import FxGraphCache + with patch("torch._inductor.codecache.FxGraphCache._get_shape_env", + lambda *args, **kwargs: AlwaysHitShapeEnv()): + inductor_compiled_graph = FxGraphCache._lookup_graph( + hash_str, example_inputs, True, False) + assert inductor_compiled_graph is not None, ( + "Inductor cache lookup failed. Please remove" + f"the cache file {compilation_config.inductor_hash_cache.cache_file_path} and try again." # noqa + ) + + # Inductor calling convention (function signature): + # f(list) -> tuple + # Dynamo calling convention (function signature): + # f(*args) -> Any + + # need to know if the graph returns a tuple + from torch._inductor.compile_fx import graph_returns_tuple + returns_tuple = graph_returns_tuple(graph) + + # this is the graph we return to Dynamo to run + def compiled_graph(*args): + # convert args to list + list_args = list(args) + graph_output = inductor_compiled_graph(list_args) + # unpack the tuple if needed + if returns_tuple: + return graph_output + else: + return graph_output[0] + else: + # it's the first time we compile this graph + # the assumption is that we don't have nested Inductor compilation. + # compiled_fx_graph_hash will only be called once, and we can hook + # it to get the hash of the compiled graph directly. + from torch._inductor.codecache import compiled_fx_graph_hash + + def hijack_compiled_fx_graph_hash(*args, **kwargs): + out = compiled_fx_graph_hash(*args, **kwargs) + # store the hash in the cache + nonlocal cache_data + cache_data[(runtime_shape, graph_index)] = out[0] + if graph_index == 0: + # adds some info logging for the first graph + logger.info("Cache the graph of shape %s for later use", + str(runtime_shape)) + logger.debug("store the %s-th graph for shape %s via hash %s", + graph_index, str(runtime_shape), out[0]) + return out + + def _check_can_cache(*args, **kwargs): + # no error means it can be cached. + # Inductor refuses to cache the graph outside of Dynamo + # tracing context, and also disables caching for graphs + # with high-order ops. + # For vLLM, in either case, we want to cache the graph. 
+ # see https://github.com/pytorch/pytorch/blob/9f5ebf3fc609105a74eab4ccc24932d6353ff566/torch/_inductor/codecache.py#L1221 # noqa + return + + def _get_shape_env(): + return AlwaysHitShapeEnv() + + with patch(# for hijacking the hash of the compiled graph + "torch._inductor.codecache.compiled_fx_graph_hash", + hijack_compiled_fx_graph_hash), \ + patch(# for providing a dummy shape environment + "torch._inductor.codecache.FxGraphCache._get_shape_env", + _get_shape_env), \ + patch(# for forcing the graph to be cached + "torch._inductor.codecache.FxGraphCache._check_can_cache", + _check_can_cache): + compiled_graph = compile_fx(graph, + example_inputs, + config_patches=current_config) # after compiling the last graph, record the end time if graph_index == num_graphs - 1: @@ -457,6 +661,9 @@ def __call__(self, *args) -> Any: # finished compilations for all required shapes if self.is_last_graph and not self.to_be_compiled_sizes: + + # save the hash of the inductor graph for the next run + self.compilation_config.inductor_hash_cache.save_to_file() end_monitoring_torch_compile(self.vllm_config) if not entry.use_cudagraph: diff --git a/vllm/config.py b/vllm/config.py index fce8011be4015..9cfd08024ea7b 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -3,6 +3,7 @@ import enum import hashlib import json +import os import warnings from contextlib import contextmanager from dataclasses import dataclass, field, replace @@ -162,6 +163,30 @@ class ModelConfig: which allows no processors. """ + def compute_hash(self) -> str: + """ + WARNING: Whenever a new field is added to this config, + ensure that it is included in the factors list if + it affects the computation graph. + + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. + """ + factors: List[Any] = [] + factors.append(self.model) + factors.append(self.dtype) + factors.append(self.quantization) + factors.append(self.quantization_param_path) + factors.append(self.revision) + factors.append(self.code_revision) + factors.append(self.trust_remote_code) + factors.append(self.rope_scaling) + factors.append(self.rope_theta) + return hashlib.sha256(str(factors).encode()).hexdigest() + def __init__(self, model: str, task: Union[TaskOption, Literal["draft"]], @@ -203,6 +228,8 @@ def __init__(self, self.seed = seed self.revision = revision self.code_revision = code_revision + self.rope_scaling = rope_scaling + self.rope_theta = rope_theta if hf_overrides is None: hf_overrides = {} @@ -832,6 +859,24 @@ class CacheConfig: cpu_offload_gb: Size of the CPU offload buffer in GiB. """ + def compute_hash(self) -> str: + """ + WARNING: Whenever a new field is added to this config, + ensure that it is included in the factors list if + it affects the computation graph. + + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. + """ + factors: List[Any] = [] + factors.append(self.cache_dtype) + # `cpu_offload_gb` does not use `torch.compile` yet. 
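+        # md5 is used only to build a compact cache key here,
+        # not for any security purpose.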
+ hash_str = hashlib.md5(str(factors).encode()).hexdigest() + return hash_str + def __init__( self, block_size: int, @@ -928,6 +973,24 @@ class TokenizerPoolConfig: pool_type: Union[str, Type["BaseTokenizerGroup"]] extra_config: dict + def compute_hash(self) -> str: + """ + WARNING: Whenever a new field is added to this config, + ensure that it is included in the factors list if + it affects the computation graph. + + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. + """ + # no factors to consider. + # this config will not affect the computation graph. + factors: List[Any] = [] + hash_str = hashlib.md5(str(factors).encode()).hexdigest() + return hash_str + def __post_init__(self): if self.pool_type not in ("ray", ) and not isinstance( self.pool_type, type): @@ -1010,6 +1073,24 @@ class LoadConfig: default_factory=dict) ignore_patterns: Optional[Union[List[str], str]] = None + def compute_hash(self) -> str: + """ + WARNING: Whenever a new field is added to this config, + ensure that it is included in the factors list if + it affects the computation graph. + + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. + """ + # no factors to consider. + # this config will not affect the computation graph. + factors: List[Any] = [] + hash_str = hashlib.md5(str(factors).encode()).hexdigest() + return hash_str + def __post_init__(self): model_loader_extra_config = self.model_loader_extra_config or {} if isinstance(model_loader_extra_config, str): @@ -1073,6 +1154,19 @@ class ParallelConfig: rank: int = 0 + def compute_hash(self): + """ + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. + """ + factors: List[Any] = [] + factors.append(self.pipeline_parallel_size) + factors.append(self.tensor_parallel_size) + return hashlib.sha256(str(factors).encode()).hexdigest() + def __post_init__(self) -> None: self.world_size = self.pipeline_parallel_size * \ self.tensor_parallel_size @@ -1209,6 +1303,24 @@ class SchedulerConfig: chunked_prefill_enabled: bool = field(init=False) + def compute_hash(self) -> str: + """ + WARNING: Whenever a new field is added to this config, + ensure that it is included in the factors list if + it affects the computation graph. + + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. + """ + # no factors to consider. + # this config will not affect the computation graph. 
+ factors: List[Any] = [] + hash_str = hashlib.md5(str(factors).encode()).hexdigest() + return hash_str + def __post_init__(self) -> None: if self.max_num_batched_tokens is None: if self.enable_chunked_prefill: @@ -1286,6 +1398,25 @@ class DeviceConfig: device: Optional[torch.device] device_type: str + def compute_hash(self) -> str: + """ + WARNING: Whenever a new field is added to this config, + ensure that it is included in the factors list if + it affects the computation graph. + + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. + """ + # no factors to consider. + # the device/platform information will be summarized + # by torch/vllm automatically. + factors: List[Any] = [] + hash_str = hashlib.md5(str(factors).encode()).hexdigest() + return hash_str + def __init__(self, device: str = "auto") -> None: if device == "auto": # Automated device type detection @@ -1313,6 +1444,24 @@ class SpeculativeConfig: decoding with top-1 proposals. """ + def compute_hash(self) -> str: + """ + WARNING: Whenever a new field is added to this config, + ensure that it is included in the factors list if + it affects the computation graph. + + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. + """ + # no factors to consider. + # spec decode does not use `torch.compile` yet. + factors: List[Any] = [] + hash_str = hashlib.md5(str(factors).encode()).hexdigest() + return hash_str + @staticmethod def maybe_create_spec_config( target_model_config: ModelConfig, @@ -1753,6 +1902,24 @@ class LoRAConfig: long_lora_scaling_factors: Optional[Tuple[float]] = None bias_enabled: bool = False + def compute_hash(self) -> str: + """ + WARNING: Whenever a new field is added to this config, + ensure that it is included in the factors list if + it affects the computation graph. + + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. + """ + # no factors to consider. + # LoRA is not compatible with `torch.compile` . + factors: List[Any] = [] + hash_str = hashlib.md5(str(factors).encode()).hexdigest() + return hash_str + def __post_init__(self): # Setting the maximum rank to 256 should be able to satisfy the vast # majority of applications. @@ -1802,6 +1969,24 @@ class PromptAdapterConfig: max_cpu_prompt_adapters: Optional[int] = None prompt_adapter_dtype: Optional[torch.dtype] = None + def compute_hash(self) -> str: + """ + WARNING: Whenever a new field is added to this config, + ensure that it is included in the factors list if + it affects the computation graph. + + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. + """ + # no factors to consider. + # this config will not affect the computation graph. 
+ factors: List[Any] = [] + hash_str = hashlib.md5(str(factors).encode()).hexdigest() + return hash_str + def __post_init__(self): if self.max_prompt_adapters < 1: @@ -1830,6 +2015,24 @@ class MultiModalConfig: for each :class:`~vllm.multimodal.MultiModalPlugin`. """ + def compute_hash(self) -> str: + """ + WARNING: Whenever a new field is added to this config, + ensure that it is included in the factors list if + it affects the computation graph. + + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. + """ + # no factors to consider. + # this config will not affect the computation graph. + factors: List[Any] = [] + hash_str = hashlib.md5(str(factors).encode()).hexdigest() + return hash_str + # TODO: Add configs to init vision tower or not. @@ -1869,6 +2072,24 @@ class PoolerConfig: ``math-shepherd-mistral-7b-prm`` model. """ + def compute_hash(self) -> str: + """ + WARNING: Whenever a new field is added to this config, + ensure that it is included in the factors list if + it affects the computation graph. + + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. + """ + # no factors to consider. + # this config will not affect the computation graph. + factors: List[Any] = [] + hash_str = hashlib.md5(str(factors).encode()).hexdigest() + return hash_str + @staticmethod def from_json(json_str: str) -> "PoolerConfig": return PoolerConfig(**json.loads(json_str)) @@ -2103,6 +2324,24 @@ class DecodingConfig: # 'outlines' / 'lm-format-enforcer' / 'xgrammar' guided_decoding_backend: str = 'xgrammar' + def compute_hash(self) -> str: + """ + WARNING: Whenever a new field is added to this config, + ensure that it is included in the factors list if + it affects the computation graph. + + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. + """ + # no factors to consider. + # this config will not affect the computation graph. + factors: List[Any] = [] + hash_str = hashlib.md5(str(factors).encode()).hexdigest() + return hash_str + def __post_init__(self): valid_guided_backends = ['outlines', 'lm-format-enforcer', 'xgrammar'] backend = self.guided_decoding_backend @@ -2124,6 +2363,24 @@ class ObservabilityConfig: # If set, collects the model execute time for the request. collect_model_execute_time: bool = False + def compute_hash(self) -> str: + """ + WARNING: Whenever a new field is added to this config, + ensure that it is included in the factors list if + it affects the computation graph. + + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. + """ + # no factors to consider. + # this config will not affect the computation graph. 
+ factors: List[Any] = [] + hash_str = hashlib.md5(str(factors).encode()).hexdigest() + return hash_str + def __post_init__(self): if not is_otel_available() and self.otlp_traces_endpoint is not None: raise ValueError( @@ -2165,6 +2422,24 @@ class KVTransferConfig(BaseModel): # The KV connector port, used to build distributed connection kv_port: int = 14579 + def compute_hash(self) -> str: + """ + WARNING: Whenever a new field is added to this config, + ensure that it is included in the factors list if + it affects the computation graph. + + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. + """ + # no factors to consider. + # this config will not affect the computation graph. + factors: List[Any] = [] + hash_str = hashlib.md5(str(factors).encode()).hexdigest() + return hash_str + @classmethod def from_cli(cls, cli_value: str) -> "KVTransferConfig": """Parse the CLI value for the kv cache transfer config.""" @@ -2234,6 +2509,9 @@ class CompilationConfig(BaseModel): - 2: dynamo once. - 3: piecewise compilation. - debug_dump_path: the path to dump the debug information. + - cache_dir: the directory to store the compiled graph, to + accelerate Inductor compilation. By default, it will use + model-related information to generate a cache directory. - backend: the backend for compilation. It needs to be a string. - "" (empty string): use the default backend. - "eager"/"openxla"/...: use the specified backend registered in PyTorch. @@ -2302,12 +2580,10 @@ class CompilationConfig(BaseModel): """ # noqa level: int = 0 debug_dump_path: str = "" + cache_dir: str = "" backend: str = "" custom_ops: List[str] = Field(default_factory=list) - splitting_ops: List[str] = Field(default_factory=lambda: [ - "vllm.unified_attention", - "vllm.unified_attention_with_output", - ]) + splitting_ops: List[str] = Field(default=None) # type: ignore use_inductor: bool = True candidate_compile_sizes: Optional[List[int]] = Field(default=None) @@ -2371,12 +2647,37 @@ def model_post_init(self, __context: Any) -> None: enabled_custom_ops: Counter[str] = PrivateAttr disabled_custom_ops: Counter[str] = PrivateAttr compilation_time: float = PrivateAttr + # should be InductorHashCache, but Pydantic does not support it + inductor_hash_cache: Any = PrivateAttr # Per-model forward context # Mainly used to store attention cls # Map from layer name to the attention cls static_forward_context: Dict[str, Any] = PrivateAttr + def compute_hash(self) -> str: + """ + WARNING: Whenever a new field is added to this config, + ensure that it is included in the factors list if + it affects the computation graph. + + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. 
+ """ + factors: List[Any] = [] + factors.append(self.level) + factors.append(self.backend) + factors.append(self.custom_ops) + factors.append(self.splitting_ops) + factors.append(self.use_inductor) + factors.append(self.inductor_compile_config) + factors.append(self.inductor_passes) + factors.append(self.pass_config.uuid()) + return hashlib.sha256(str(factors).encode()).hexdigest() + def __repr__(self) -> str: exclude = { "static_forward_context", @@ -2405,6 +2706,27 @@ def model_post_init(self, __context: Any) -> None: count_all = self.custom_ops.count("all") assert count_none + count_all <= 1, "Can only specify 'none' or 'all'" + if self.splitting_ops is None: + if envs.VLLM_USE_V1: + # v1 must split the graph on attention ops + # for piecewise cudagraph + self.splitting_ops = [ + "vllm.unified_attention", + "vllm.unified_attention_with_output", + ] + else: + # v0 can use full graph compilation without splitting, + # splitting is optional. + # right now we still need it. kv cache shape + # will be included in the graph if we don't split + # the graph. + # TODO: hide kv cache in static forward context + # so that inductor does not see it. + self.splitting_ops = [ + "vllm.unified_attention", + "vllm.unified_attention_with_output", + ] + for k, v in self.inductor_passes.items(): if not isinstance(v, str): assert callable(v), ( @@ -2444,6 +2766,30 @@ def init_backend(self, vllm_config: "VllmConfig") -> Union[str, Callable]: # TODO: pass user-specified backend to piecewise compilation # merge with the config use_inductor assert self.level == CompilationLevel.PIECEWISE + + if not self.cache_dir: + # no provided cache dir, generate one based on the known factors + # that affects the compilation. if none of the factors change, + # the cache dir will be the same so that we can reuse the compiled + # graph. + hash_key = vllm_config.compute_hash() + cache_dir = os.path.join( + envs.VLLM_CACHE_ROOT, "torch_compile_cache", hash_key, + f"rank_{vllm_config.parallel_config.rank}") + os.makedirs(cache_dir, exist_ok=True) + self.cache_dir = cache_dir + + disabled = envs.VLLM_DISABLE_COMPILE_CACHE + from vllm.compilation.backends import InductorHashCache + self.inductor_hash_cache: InductorHashCache = InductorHashCache( + self.cache_dir, disabled=disabled) + if disabled: + logger.info("vLLM's torch.compile cache is disabled.") + else: + logger.info( + "Using cache directory: %s for vLLM's torch.compile", + self.cache_dir) + from vllm.compilation.backends import VllmBackend return VllmBackend(vllm_config) @@ -2520,6 +2866,67 @@ class VllmConfig: init=True) # type: ignore instance_id: str = "" + def compute_hash(self) -> str: + """ + WARNING: Whenever a new field is added to this config, + ensure that it is included in the factors list if + it affects the computation graph. + + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. 
+ """ + factors: List[Any] = [] + # summarize system state + from torch._inductor.codecache import CacheBase + system_factors = CacheBase.get_system() + factors.append(system_factors) + + # summarize pytorch state + from torch._inductor.codecache import torch_key + torch_factors = torch_key() + factors.append(torch_factors) + + # summarize vllm config + vllm_factors: List[Any] = [] + from vllm import __version__ + vllm_factors.append(__version__) + if self.model_config: + vllm_factors.append(self.model_config.compute_hash()) + if self.cache_config: + vllm_factors.append(self.cache_config.compute_hash()) + if self.parallel_config: + vllm_factors.append(self.parallel_config.compute_hash()) + if self.scheduler_config: + vllm_factors.append(self.scheduler_config.compute_hash()) + if self.device_config: + vllm_factors.append(self.device_config.compute_hash()) + if self.load_config: + vllm_factors.append(self.load_config.compute_hash()) + if self.lora_config: + vllm_factors.append(self.lora_config.compute_hash()) + if self.speculative_config: + vllm_factors.append(self.speculative_config.compute_hash()) + if self.decoding_config: + vllm_factors.append(self.decoding_config.compute_hash()) + if self.observability_config: + vllm_factors.append(self.observability_config.compute_hash()) + if self.prompt_adapter_config: + vllm_factors.append(self.prompt_adapter_config.compute_hash()) + if self.quant_config: + pass # should be captured by model_config.quantization + if self.compilation_config: + vllm_factors.append(self.compilation_config.compute_hash()) + if self.kv_transfer_config: + vllm_factors.append(self.kv_transfer_config.compute_hash()) + + factors.append(vllm_factors) + + hash_str = hashlib.md5(str(factors).encode()).hexdigest()[:10] + return hash_str + def pad_for_cudagraph(self, batch_size: int) -> int: # if batch_size > self.compilation_config.max_capture_size, # it should raise an IndexError. diff --git a/vllm/envs.py b/vllm/envs.py index bc19c6af798db..09c7216265515 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -71,6 +71,7 @@ VLLM_USE_V1: bool = False VLLM_ENABLE_V1_MULTIPROCESSING: bool = True VLLM_LOG_BATCHSIZE_INTERVAL: float = -1 + VLLM_DISABLE_COMPILE_CACHE: bool = False VLLM_LOCAL_RANK_DEV_MAP: str = "{}" @@ -464,6 +465,8 @@ def get_default_config_root(): lambda: bool(int(os.getenv("VLLM_ENABLE_V1_MULTIPROCESSING", "1"))), "VLLM_LOG_BATCHSIZE_INTERVAL": lambda: float(os.getenv("VLLM_LOG_BATCHSIZE_INTERVAL", "-1")), + "VLLM_DISABLE_COMPILE_CACHE": + lambda: bool(int(os.getenv("VLLM_DISABLE_COMPILE_CACHE", "0"))), "VLLM_LOCAL_RANK_DEV_MAP": lambda: os.getenv("VLLM_LOCAL_RANK_DEV_MAP", "{}") }