diff --git a/tests/compile/utils.py b/tests/compile/utils.py index 078c6bf9ea1df..7c92d165d05f7 100644 --- a/tests/compile/utils.py +++ b/tests/compile/utils.py @@ -4,7 +4,7 @@ from tests.quantization.utils import is_quant_method_supported from vllm import LLM, SamplingParams -from vllm.config import CompilationConfig, CompilationLevel +from vllm.config import CompilationLevel from vllm.platforms import current_platform TEST_MODELS = [ @@ -85,7 +85,7 @@ def check_full_graph_support(model, enforce_eager=True, tensor_parallel_size=tp_size, disable_custom_all_reduce=True, - compilation_config=CompilationConfig(level=optimization_level), + compilation_config=optimization_level, **model_kwargs) outputs = llm.generate(prompts, sampling_params)