From 05d1f8c9c64b4458ae7cee2650eb97498146ee50 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 25 Nov 2024 01:27:30 -0800 Subject: [PATCH] [misc] move functions to config.py (#10624) Signed-off-by: youkaichao --- tests/compile/piecewise/test_simple.py | 4 +- tests/compile/piecewise/test_toy_llama.py | 4 +- tests/kernels/test_encoder_decoder_attn.py | 3 +- .../model_executor/test_enabled_custom_ops.py | 3 +- vllm/attention/layer.py | 3 +- vllm/compilation/wrapper.py | 3 +- vllm/config.py | 51 +++++++++++++++++ vllm/model_executor/custom_op.py | 2 +- vllm/model_executor/model_loader/loader.py | 3 +- .../model_executor/model_loader/tensorizer.py | 3 +- vllm/plugins/__init__.py | 56 ------------------- 11 files changed, 62 insertions(+), 73 deletions(-) diff --git a/tests/compile/piecewise/test_simple.py b/tests/compile/piecewise/test_simple.py index 0db12d6b6a43c..7ef502abee345 100644 --- a/tests/compile/piecewise/test_simple.py +++ b/tests/compile/piecewise/test_simple.py @@ -10,8 +10,8 @@ from vllm.compilation.compile_context import set_compile_context from vllm.compilation.counter import compilation_counter from vllm.compilation.decorators import support_torch_compile -from vllm.config import CompilationConfig, CompilationLevel, VllmConfig -from vllm.plugins import set_current_vllm_config +from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig, + set_current_vllm_config) from vllm.utils import direct_register_custom_op global_counter = 0 diff --git a/tests/compile/piecewise/test_toy_llama.py b/tests/compile/piecewise/test_toy_llama.py index cfe661b8871e0..dbd5a3bbffeab 100644 --- a/tests/compile/piecewise/test_toy_llama.py +++ b/tests/compile/piecewise/test_toy_llama.py @@ -16,8 +16,8 @@ from vllm.compilation.compile_context import set_compile_context from vllm.compilation.counter import compilation_counter from vllm.compilation.decorators import support_torch_compile -from vllm.config import CompilationConfig, CompilationLevel, VllmConfig -from vllm.plugins import set_current_vllm_config +from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig, + set_current_vllm_config) from vllm.utils import direct_register_custom_op # create a library to hold the custom op diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index c4b72ba6bf4ee..d943b048b7934 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -18,10 +18,9 @@ from vllm.attention.backends.utils import STR_NOT_IMPL_ENC_DEC_ROCM_HIP from vllm.attention.selector import (_Backend, _cached_get_attn_backend, global_force_attn_backend_context_manager) -from vllm.config import VllmConfig +from vllm.config import VllmConfig, set_current_vllm_config from vllm.forward_context import set_forward_context from vllm.platforms import current_platform -from vllm.plugins import set_current_vllm_config # List of support backends for encoder/decoder models LIST_ENC_DEC_SUPPORTED_BACKENDS = [_Backend.XFORMERS, _Backend.FLASH_ATTN] diff --git a/tests/model_executor/test_enabled_custom_ops.py b/tests/model_executor/test_enabled_custom_ops.py index c54e30995da49..0a3aba255fd76 100644 --- a/tests/model_executor/test_enabled_custom_ops.py +++ b/tests/model_executor/test_enabled_custom_ops.py @@ -2,13 +2,12 @@ import pytest -from vllm.config import CompilationConfig, VllmConfig +from vllm.config import CompilationConfig, VllmConfig, set_current_vllm_config from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.layers.activation import (GeluAndMul, ReLUSquaredActivation, SiluAndMul) from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.plugins import set_current_vllm_config # Registered subclass for test diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 1bb335909484b..17157617248f7 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -7,13 +7,12 @@ import vllm.envs as envs from vllm.attention import AttentionMetadata, AttentionType from vllm.attention.selector import backend_name_to_enum, get_attn_backend -from vllm.config import CacheConfig +from vllm.config import CacheConfig, get_current_vllm_config from vllm.forward_context import ForwardContext, get_forward_context from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.platforms import current_platform -from vllm.plugins import get_current_vllm_config from vllm.utils import direct_register_custom_op diff --git a/vllm/compilation/wrapper.py b/vllm/compilation/wrapper.py index 0143d0301ca1a..bc4d292fef402 100644 --- a/vllm/compilation/wrapper.py +++ b/vllm/compilation/wrapper.py @@ -8,7 +8,7 @@ import torch import vllm.envs as envs -from vllm.config import CompilationLevel +from vllm.config import CompilationLevel, get_current_vllm_config class TorchCompileWrapperWithCustomDispatcher: @@ -32,7 +32,6 @@ def __init__(self, # default compilation settings # compiling the forward method - from vllm.plugins import get_current_vllm_config backend = get_current_vllm_config( ).compilation_config.init_backend() diff --git a/vllm/config.py b/vllm/config.py index 68720f3a3034d..0a390c4311ba6 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -3,6 +3,7 @@ import hashlib import json import warnings +from contextlib import contextmanager from dataclasses import dataclass, field, replace from pathlib import Path from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Counter, Dict, @@ -2450,3 +2451,53 @@ def __str__(self): self.cache_config.enable_prefix_caching, self.model_config.use_async_output_proc, self.model_config.mm_processor_kwargs) + + +_current_vllm_config: Optional[VllmConfig] = None + + +@contextmanager +def set_current_vllm_config(vllm_config: VllmConfig): + """ + Temporarily set the current VLLM config. + Used during model initialization. + We save the current VLLM config in a global variable, + so that all modules can access it, e.g. custom ops + can access the VLLM config to determine how to dispatch. + """ + global _current_vllm_config + old_vllm_config = _current_vllm_config + from vllm.compilation.counter import compilation_counter + num_models_seen = compilation_counter.num_models_seen + try: + _current_vllm_config = vllm_config + yield + finally: + logger.debug("enabled custom ops: %s", + vllm_config.compilation_config.enabled_custom_ops) + logger.debug("disabled custom ops: %s", + vllm_config.compilation_config.disabled_custom_ops) + if vllm_config.compilation_config.level == CompilationLevel.PIECEWISE \ + and compilation_counter.num_models_seen == num_models_seen: + # If the model supports compilation, + # compilation_counter.num_models_seen should be increased + # by at least 1. + # If it is not increased, it means the model does not support + # compilation (does not have @support_torch_compile decorator). + logger.warning( + "`torch.compile` is turned on, but the model %s" + " does not support it. Please open an issue on GitHub" + "if you want it to be supported.", + vllm_config.model_config.model) + _current_vllm_config = old_vllm_config + + +def get_current_vllm_config() -> VllmConfig: + if _current_vllm_config is None: + # in ci, usually when we test custom ops/modules directly, + # we don't set the vllm config. In that case, we set a default + # config. + logger.warning("Current VLLM config is not set.") + from vllm.config import VllmConfig + return VllmConfig() + return _current_vllm_config diff --git a/vllm/model_executor/custom_op.py b/vllm/model_executor/custom_op.py index b07966f2ab7d0..fddc8bad09ef5 100644 --- a/vllm/model_executor/custom_op.py +++ b/vllm/model_executor/custom_op.py @@ -2,9 +2,9 @@ import torch.nn as nn +from vllm.config import get_current_vllm_config from vllm.logger import init_logger from vllm.platforms import current_platform -from vllm.plugins import get_current_vllm_config from vllm.utils import print_warning_once logger = init_logger(__name__) diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index 34e0860162260..441dd409b4f9d 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -23,7 +23,7 @@ from transformers.utils import SAFE_WEIGHTS_INDEX_NAME from vllm.config import (LoadConfig, LoadFormat, ModelConfig, ParallelConfig, - VllmConfig) + VllmConfig, set_current_vllm_config) from vllm.distributed import (get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.envs import VLLM_USE_MODELSCOPE @@ -47,7 +47,6 @@ safetensors_weights_iterator) from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform -from vllm.plugins import set_current_vllm_config from vllm.utils import is_pin_memory_available diff --git a/vllm/model_executor/model_loader/tensorizer.py b/vllm/model_executor/model_loader/tensorizer.py index 3fd668765a1b1..87f3fcb5cae00 100644 --- a/vllm/model_executor/model_loader/tensorizer.py +++ b/vllm/model_executor/model_loader/tensorizer.py @@ -13,13 +13,12 @@ from transformers import PretrainedConfig import vllm.envs as envs -from vllm.config import ModelConfig, ParallelConfig +from vllm.config import ModelConfig, ParallelConfig, set_current_vllm_config from vllm.engine.arg_utils import EngineArgs from vllm.engine.llm_engine import LLMEngine from vllm.logger import init_logger from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) -from vllm.plugins import set_current_vllm_config from vllm.utils import FlexibleArgumentParser tensorizer_error_msg = None diff --git a/vllm/plugins/__init__.py b/vllm/plugins/__init__.py index 8b43167693598..3c64726ca3344 100644 --- a/vllm/plugins/__init__.py +++ b/vllm/plugins/__init__.py @@ -1,15 +1,10 @@ import logging import os -from contextlib import contextmanager -from typing import TYPE_CHECKING, Optional import torch import vllm.envs as envs -if TYPE_CHECKING: - from vllm.config import VllmConfig - logger = logging.getLogger(__name__) # make sure one process only loads plugins once @@ -64,54 +59,3 @@ def load_general_plugins(): logger.info("plugin %s loaded.", plugin.name) except Exception: logger.exception("Failed to load plugin %s", plugin.name) - - -_current_vllm_config: Optional["VllmConfig"] = None - - -@contextmanager -def set_current_vllm_config(vllm_config: "VllmConfig"): - """ - Temporarily set the current VLLM config. - Used during model initialization. - We save the current VLLM config in a global variable, - so that all modules can access it, e.g. custom ops - can access the VLLM config to determine how to dispatch. - """ - global _current_vllm_config - old_vllm_config = _current_vllm_config - from vllm.compilation.counter import compilation_counter - from vllm.config import CompilationLevel - num_models_seen = compilation_counter.num_models_seen - try: - _current_vllm_config = vllm_config - yield - finally: - logger.debug("enabled custom ops: %s", - vllm_config.compilation_config.enabled_custom_ops) - logger.debug("disabled custom ops: %s", - vllm_config.compilation_config.disabled_custom_ops) - if vllm_config.compilation_config.level == CompilationLevel.PIECEWISE \ - and compilation_counter.num_models_seen == num_models_seen: - # If the model supports compilation, - # compilation_counter.num_models_seen should be increased - # by at least 1. - # If it is not increased, it means the model does not support - # compilation (does not have @support_torch_compile decorator). - logger.warning( - "`torch.compile` is turned on, but the model %s" - " does not support it. Please open an issue on GitHub" - "if you want it to be supported.", - vllm_config.model_config.model) - _current_vllm_config = old_vllm_config - - -def get_current_vllm_config() -> "VllmConfig": - if _current_vllm_config is None: - # in ci, usually when we test custom ops/modules directly, - # we don't set the vllm config. In that case, we set a default - # config. - logger.warning("Current VLLM config is not set.") - from vllm.config import VllmConfig - return VllmConfig() - return _current_vllm_config