From 25d806e95391a8556deb69bdb214714425f776c9 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 24 Nov 2024 23:40:08 -0800 Subject: [PATCH] [misc] add torch.compile compatibility check (#10618) Signed-off-by: youkaichao --- tests/v1/engine/test_engine_core_client.py | 2 +- vllm/config.py | 14 ++++++++++++++ vllm/engine/arg_utils.py | 7 +++++++ 3 files changed, 22 insertions(+), 1 deletion(-) diff --git a/tests/v1/engine/test_engine_core_client.py b/tests/v1/engine/test_engine_core_client.py index 7b241bf836a0e..e248e35ae4069 100644 --- a/tests/v1/engine/test_engine_core_client.py +++ b/tests/v1/engine/test_engine_core_client.py @@ -81,7 +81,7 @@ def test_engine_core_client(monkeypatch, multiprocessing_mode: bool): with monkeypatch.context() as m: m.setenv("VLLM_USE_V1", "1") - engine_args = EngineArgs(model=MODEL_NAME) + engine_args = EngineArgs(model=MODEL_NAME, compilation_config=3) vllm_config = engine_args.create_engine_config() executor_class = AsyncLLM._get_executor_cls(vllm_config) client = EngineCoreClient.make_client( diff --git a/vllm/config.py b/vllm/config.py index dcdaf58b5ccdb..68720f3a3034d 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2394,6 +2394,20 @@ def __post_init__(self): self.compilation_config.pass_config.enable_reshape = False self.compilation_config.level = CompilationLevel.PIECEWISE + if self.cache_config is not None and \ + self.cache_config.cpu_offload_gb > 0 and \ + self.compilation_config.level != CompilationLevel.NO_COMPILATION: + logger.warning( + "CPU offload is not supported with `torch.compile` yet." + " Disabling `torch.compile`.") + self.compilation_config.level = CompilationLevel.NO_COMPILATION + + if self.lora_config is not None and self.compilation_config.level !=\ + CompilationLevel.NO_COMPILATION: + logger.warning("LoRA is not supported with `torch.compile` yet. " + "Disabling `torch.compile`.") + self.compilation_config.level = CompilationLevel.NO_COMPILATION + current_platform.check_and_update_config(self) def __str__(self): diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 82f1ef51255e9..a43e133f21ac2 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -197,6 +197,13 @@ def __post_init__(self): if not self.tokenizer: self.tokenizer = self.model + # support `EngineArgs(compilation_config={...})` + # without having to manually construct a + # CompilationConfig object + if isinstance(self.compilation_config, (int, dict)): + self.compilation_config = CompilationConfig.from_cli( + json.dumps(self.compilation_config)) + # Setup plugins from vllm.plugins import load_general_plugins load_general_plugins()