diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index fbaa427bb7270..baad54eaf6a91 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -313,14 +313,15 @@ steps:

 ##### models test #####

-- label: Basic Models Test # 10min
+- label: Basic Models Test # 30min
   source_file_dependencies:
   - vllm/
   - tests/models
   commands:
     - pip install -e ./plugins/vllm_add_dummy_model
     - pytest -v -s models/test_oot_registration.py # it needs a clean process
-    - pytest -v -s models/*.py --ignore=models/test_oot_registration.py
+    - pytest -v -s models/test_registry.py
+    - pytest -v -s models/test_initialization.py

 - label: Decoder-only Language Models Test (Standard) # 18min
   #mirror_hardwares: [amd]
diff --git a/tests/distributed/test_pipeline_parallel.py b/tests/distributed/test_pipeline_parallel.py
index 5d566f8308b70..c49ed9802cde8 100644
--- a/tests/distributed/test_pipeline_parallel.py
+++ b/tests/distributed/test_pipeline_parallel.py
@@ -166,14 +166,14 @@ def iter_params(self, model_name: str):
     "mistralai/Mixtral-8x7B-Instruct-v0.1": PPTestSettings.fast(tp_base=4),
     "mosaicml/mpt-7b": PPTestSettings.fast(),
     "nvidia/Minitron-8B-Base": PPTestSettings.fast(),
-    "allenai/OLMoE-1B-7B-0924-Instruct": PPTestSettings.fast(),
     "allenai/OLMo-1B-hf": PPTestSettings.fast(),
+    "allenai/OLMoE-1B-7B-0924-Instruct": PPTestSettings.fast(),
     "facebook/opt-iml-max-1.3b": PPTestSettings.fast(),
     "OrionStarAI/Orion-14B-Chat": PPTestSettings.fast(trust_remote_code=True),
+    "adept/persimmon-8b-chat": PPTestSettings.fast(),
     "microsoft/phi-2": PPTestSettings.fast(),
-    "microsoft/Phi-3.5-MoE-instruct": PPTestSettings.detailed(trust_remote_code=True, multi_node_only=True, load_format="dummy", hf_overrides='{"num_hidden_layers": 4, "hidden_size": 512, "intermediate_size": 800, "num_attention_heads": 4, "num_key_value_heads": 1}'),  # noqa: E501
     "microsoft/Phi-3-small-8k-instruct": PPTestSettings.fast(trust_remote_code=True),  # noqa: E501
-    "adept/persimmon-8b-chat": PPTestSettings.fast(),
+    "microsoft/Phi-3.5-MoE-instruct": PPTestSettings.detailed(trust_remote_code=True, multi_node_only=True, load_format="dummy", hf_overrides='{"num_hidden_layers": 4, "hidden_size": 512, "intermediate_size": 800, "num_attention_heads": 4, "num_key_value_heads": 1}'),  # noqa: E501
     "Qwen/Qwen-7B-Chat": PPTestSettings.fast(trust_remote_code=True),
     "Qwen/Qwen2-7B-Instruct": PPTestSettings.fast(),
     "Qwen/Qwen1.5-MoE-A2.7B-Chat": PPTestSettings.fast(),
diff --git a/tests/models/registry.py b/tests/models/registry.py
new file mode 100644
index 0000000000000..ec9ff52d112df
--- /dev/null
+++ b/tests/models/registry.py
@@ -0,0 +1,212 @@
+from dataclasses import dataclass, field
+from typing import AbstractSet, Mapping, Optional
+
+
+@dataclass(frozen=True)
+class _HfExamplesInfo:
+    default: str
+    """The default model to use for testing this architecture."""
+
+    extras: Mapping[str, str] = field(default_factory=dict)
+    """Extra models to use for testing this architecture."""
+
+    tokenizer: Optional[str] = None
+    """Set the tokenizer to load for this architecture."""
+
+    tokenizer_mode: str = "auto"
+    """Set the tokenizer type for this architecture."""
+
+    speculative_model: Optional[str] = None
+    """
+    The default model to use for testing this architecture, which is only used
+    for speculative decoding.
+    """
+
+    is_available_online: bool = True
+    """
+    Set this to ``False`` if the name of this architecture no longer exists on
+    the HF repo. To maintain backwards compatibility, we have not removed them
+    from the main model registry, so without this flag the registry tests will
+    fail.
+    """
+
+    trust_remote_code: bool = False
+    """The ``trust_remote_code`` level required to load the model."""
+
+
+# yapf: disable
+_TEXT_GENERATION_EXAMPLE_MODELS = {
+    # [Decoder-only]
+    "AquilaModel": _HfExamplesInfo("BAAI/AquilaChat-7B",
+                                   trust_remote_code=True),
+    "AquilaForCausalLM": _HfExamplesInfo("BAAI/AquilaChat2-7B",
+                                         trust_remote_code=True),
+    "ArcticForCausalLM": _HfExamplesInfo("Snowflake/snowflake-arctic-instruct",
+                                         trust_remote_code=True),
+    "BaiChuanForCausalLM": _HfExamplesInfo("baichuan-inc/Baichuan-7B",
+                                           trust_remote_code=True),
+    "BaichuanForCausalLM": _HfExamplesInfo("baichuan-inc/Baichuan2-7B-chat",
+                                           trust_remote_code=True),
+    "BloomForCausalLM": _HfExamplesInfo("bigscience/bloomz-1b1"),
+    # ChatGLMModel supports multimodal
+    "CohereForCausalLM": _HfExamplesInfo("CohereForAI/c4ai-command-r-v01",
+                                         trust_remote_code=True),
+    "DbrxForCausalLM": _HfExamplesInfo("databricks/dbrx-instruct"),
+    "DeciLMForCausalLM": _HfExamplesInfo("Deci/DeciLM-7B-instruct",
+                                         trust_remote_code=True),
+    "DeepseekForCausalLM": _HfExamplesInfo("deepseek-ai/deepseek-llm-7b-chat"),
+    "DeepseekV2ForCausalLM": _HfExamplesInfo("deepseek-ai/DeepSeek-V2-Lite-Chat",  # noqa: E501
+                                             trust_remote_code=True),
+    "ExaoneForCausalLM": _HfExamplesInfo("LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct"),  # noqa: E501
+    "FalconForCausalLM": _HfExamplesInfo("tiiuae/falcon-7b"),
+    "GemmaForCausalLM": _HfExamplesInfo("google/gemma-2b"),
+    "Gemma2ForCausalLM": _HfExamplesInfo("google/gemma-2-9b"),
+    "GPT2LMHeadModel": _HfExamplesInfo("gpt2"),
+    "GPTBigCodeForCausalLM": _HfExamplesInfo("bigcode/starcoder"),
+    "GPTJForCausalLM": _HfExamplesInfo("EleutherAI/gpt-j-6b"),
+    "GPTNeoXForCausalLM": _HfExamplesInfo("EleutherAI/pythia-160m"),
+    "GraniteForCausalLM": _HfExamplesInfo("ibm/PowerLM-3b"),
+    "GraniteMoeForCausalLM": _HfExamplesInfo("ibm/PowerMoE-3b"),
+    "InternLMForCausalLM": _HfExamplesInfo("internlm/internlm-chat-7b",
+                                           trust_remote_code=True),
+    "InternLM2ForCausalLM": _HfExamplesInfo("internlm/internlm2-chat-7b",
+                                            trust_remote_code=True),
+    "InternLM2VEForCausalLM": _HfExamplesInfo("OpenGVLab/Mono-InternVL-2B",
+                                              trust_remote_code=True),
+    "JAISLMHeadModel": _HfExamplesInfo("inceptionai/jais-13b-chat"),
+    "JambaForCausalLM": _HfExamplesInfo("ai21labs/AI21-Jamba-1.5-Mini"),
+    "LlamaForCausalLM": _HfExamplesInfo("meta-llama/Meta-Llama-3-8B"),
+    "LLaMAForCausalLM": _HfExamplesInfo("decapoda-research/llama-7b-hf",
+                                        is_available_online=False),
+    "MambaForCausalLM": _HfExamplesInfo("state-spaces/mamba-130m-hf"),
+    "FalconMambaForCausalLM": _HfExamplesInfo("tiiuae/falcon-mamba-7b-instruct"),  # noqa: E501
+    "MiniCPMForCausalLM": _HfExamplesInfo("openbmb/MiniCPM-2B-sft-bf16",
+                                          trust_remote_code=True),
+    "MiniCPM3ForCausalLM": _HfExamplesInfo("openbmb/MiniCPM3-4B",
+                                           trust_remote_code=True),
+    "MistralForCausalLM": _HfExamplesInfo("mistralai/Mistral-7B-Instruct-v0.1"),
+    "MixtralForCausalLM": _HfExamplesInfo("mistralai/Mixtral-8x7B-Instruct-v0.1"),  # noqa: E501
+    "QuantMixtralForCausalLM": _HfExamplesInfo("mistral-community/Mixtral-8x22B-v0.1-AWQ"),  # noqa: E501
+    "MptForCausalLM": _HfExamplesInfo("mpt", is_available_online=False),
+    "MPTForCausalLM": _HfExamplesInfo("mosaicml/mpt-7b"),
+    "NemotronForCausalLM": _HfExamplesInfo("nvidia/Minitron-8B-Base"),
+    "OlmoForCausalLM": _HfExamplesInfo("allenai/OLMo-1B-hf"),
+    "OlmoeForCausalLM": _HfExamplesInfo("allenai/OLMoE-1B-7B-0924-Instruct"),
+    "OPTForCausalLM": _HfExamplesInfo("facebook/opt-iml-max-1.3b"),
+    "OrionForCausalLM": _HfExamplesInfo("OrionStarAI/Orion-14B-Chat",
+                                        trust_remote_code=True),
+    "PersimmonForCausalLM": _HfExamplesInfo("adept/persimmon-8b-chat"),
+    "PhiForCausalLM": _HfExamplesInfo("microsoft/phi-2"),
+    "Phi3ForCausalLM": _HfExamplesInfo("microsoft/Phi-3-mini-4k-instruct"),
+    "Phi3SmallForCausalLM": _HfExamplesInfo("microsoft/Phi-3-small-8k-instruct",
+                                            trust_remote_code=True),
+    "PhiMoEForCausalLM": _HfExamplesInfo("microsoft/Phi-3.5-MoE-instruct",
+                                         trust_remote_code=True),
+    # QWenLMHeadModel supports multimodal
+    "Qwen2ForCausalLM": _HfExamplesInfo("Qwen/Qwen2-7B-Instruct"),
+    "Qwen2MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen1.5-MoE-A2.7B-Chat"),
+    "RWForCausalLM": _HfExamplesInfo("tiiuae/falcon-40b",
+                                     is_available_online=False),
+    "StableLMEpochForCausalLM": _HfExamplesInfo("stabilityai/stablelm-zephyr-3b",  # noqa: E501
+                                                is_available_online=False),
+    "StableLmForCausalLM": _HfExamplesInfo("stabilityai/stablelm-3b-4e1t"),
+    "Starcoder2ForCausalLM": _HfExamplesInfo("bigcode/starcoder2-3b"),
+    "SolarForCausalLM": _HfExamplesInfo("upstage/solar-pro-preview-instruct"),
+    "XverseForCausalLM": _HfExamplesInfo("xverse/XVERSE-7B-Chat",
+                                         is_available_online=False,
+                                         trust_remote_code=True),
+    # [Encoder-decoder]
+    "BartModel": _HfExamplesInfo("facebook/bart-base"),
+    "BartForConditionalGeneration": _HfExamplesInfo("facebook/bart-large-cnn"),
+    # Florence-2 uses BartFastTokenizer which can't be loaded from AutoTokenizer
+    # Therefore, we borrow the BartTokenizer from the original Bart model
+    "Florence2ForConditionalGeneration": _HfExamplesInfo("microsoft/Florence-2-base",  # noqa: E501
+                                                         tokenizer="facebook/bart-base",
+                                                         trust_remote_code=True),  # noqa: E501
+}
+
+_EMBEDDING_EXAMPLE_MODELS = {
+    # [Text-only]
+    "BertModel": _HfExamplesInfo("BAAI/bge-base-en-v1.5"),
+    "Gemma2Model": _HfExamplesInfo("BAAI/bge-multilingual-gemma2"),
+    "MistralModel": _HfExamplesInfo("intfloat/e5-mistral-7b-instruct"),
+    "Qwen2ForRewardModel": _HfExamplesInfo("Qwen/Qwen2.5-Math-RM-72B"),
+    "Qwen2ForSequenceClassification": _HfExamplesInfo("jason9693/Qwen2.5-1.5B-apeach"),  # noqa: E501
+    # [Multimodal]
+    "LlavaNextForConditionalGeneration": _HfExamplesInfo("royokong/e5-v"),
+    "Phi3VForCausalLM": _HfExamplesInfo("TIGER-Lab/VLM2Vec-Full",
+                                        trust_remote_code=True),
+    "Qwen2VLForConditionalGeneration": _HfExamplesInfo("MrLight/dse-qwen2-2b-mrl-v1"),  # noqa: E501
+}
+
+_MULTIMODAL_EXAMPLE_MODELS = {
+    # [Decoder-only]
+    "Blip2ForConditionalGeneration": _HfExamplesInfo("Salesforce/blip2-opt-2.7b"),  # noqa: E501
+    "ChameleonForConditionalGeneration": _HfExamplesInfo("facebook/chameleon-7b"),  # noqa: E501
+    "ChatGLMModel": _HfExamplesInfo("THUDM/glm-4v-9b",
+                                    extras={"text_only": "THUDM/chatglm3-6b"},
+                                    trust_remote_code=True),
+    "ChatGLMForConditionalGeneration": _HfExamplesInfo("chatglm2-6b",
+                                                       is_available_online=False),
+    "FuyuForCausalLM": _HfExamplesInfo("adept/fuyu-8b"),
+    "H2OVLChatModel": _HfExamplesInfo("h2oai/h2ovl-mississippi-800m"),
+    "InternVLChatModel": _HfExamplesInfo("OpenGVLab/InternVL2-1B",
+                                         trust_remote_code=True),
+    "Idefics3ForConditionalGeneration": _HfExamplesInfo("HuggingFaceM4/Idefics3-8B-Llama3"),  # noqa: E501
+    "LlavaForConditionalGeneration": _HfExamplesInfo("llava-hf/llava-1.5-7b-hf",
+                                                     extras={"mistral": "mistral-community/pixtral-12b"}),  # noqa: E501
+    "LlavaNextForConditionalGeneration": _HfExamplesInfo("llava-hf/llava-v1.6-mistral-7b-hf"),  # noqa: E501
+    "LlavaNextVideoForConditionalGeneration": _HfExamplesInfo("llava-hf/LLaVA-NeXT-Video-7B-hf"),  # noqa: E501
+    "LlavaOnevisionForConditionalGeneration": _HfExamplesInfo("llava-hf/llava-onevision-qwen2-0.5b-ov-hf"),  # noqa: E501
+    "MiniCPMV": _HfExamplesInfo("openbmb/MiniCPM-Llama3-V-2_5",
+                                trust_remote_code=True),
+    "MolmoForCausalLM": _HfExamplesInfo("allenai/Molmo-7B-D-0924",
+                                        trust_remote_code=True),
+    "NVLM_D": _HfExamplesInfo("nvidia/NVLM-D-72B",
+                              trust_remote_code=True),
+    "PaliGemmaForConditionalGeneration": _HfExamplesInfo("google/paligemma-3b-pt-224"),  # noqa: E501
+    "Phi3VForCausalLM": _HfExamplesInfo("microsoft/Phi-3-vision-128k-instruct",
+                                        trust_remote_code=True),
+    "PixtralForConditionalGeneration": _HfExamplesInfo("mistralai/Pixtral-12B-2409",  # noqa: E501
+                                                       tokenizer_mode="mistral"),
+    "QWenLMHeadModel": _HfExamplesInfo("Qwen/Qwen-VL-Chat",
+                                       extras={"text_only": "Qwen/Qwen-7B-Chat"},  # noqa: E501
+                                       trust_remote_code=True),
+    "Qwen2AudioForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2-Audio-7B-Instruct"),  # noqa: E501
+    "Qwen2VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2-VL-2B-Instruct"),  # noqa: E501
+    "UltravoxModel": _HfExamplesInfo("fixie-ai/ultravox-v0_3"),
+    # [Encoder-decoder]
+    "MllamaForConditionalGeneration": _HfExamplesInfo("meta-llama/Llama-3.2-11B-Vision-Instruct"),  # noqa: E501
+}
+
+_SPECULATIVE_DECODING_EXAMPLE_MODELS = {
+    "EAGLEModel": _HfExamplesInfo("JackFram/llama-68m",
+                                  speculative_model="abhigoyal/vllm-eagle-llama-68m-random"),  # noqa: E501
+    "MedusaModel": _HfExamplesInfo("JackFram/llama-68m",
+                                   speculative_model="abhigoyal/vllm-medusa-llama-68m-random"),  # noqa: E501
+    "MLPSpeculatorPreTrainedModel": _HfExamplesInfo("JackFram/llama-160m",
+                                                    speculative_model="ibm-fms/llama-160m-accelerator"),  # noqa: E501
+}
+
+_EXAMPLE_MODELS = {
+    **_TEXT_GENERATION_EXAMPLE_MODELS,
+    **_EMBEDDING_EXAMPLE_MODELS,
+    **_MULTIMODAL_EXAMPLE_MODELS,
+    **_SPECULATIVE_DECODING_EXAMPLE_MODELS,
+}
+
+
+class HfExampleModels:
+    def __init__(self, hf_models: Mapping[str, _HfExamplesInfo]) -> None:
+        super().__init__()
+
+        self.hf_models = hf_models
+
+    def get_supported_archs(self) -> AbstractSet[str]:
+        return self.hf_models.keys()
+
+    def get_hf_info(self, model_arch: str) -> _HfExamplesInfo:
+        return self.hf_models[model_arch]
+
+
+HF_EXAMPLE_MODELS = HfExampleModels(_EXAMPLE_MODELS)
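For orientation, the new `tests/models/registry.py` module above is meant to be queried by architecture name. A minimal sketch of how a test might consume it (the relative import assumes the caller lives in the `tests/models` package, as the new tests do; the architecture picked here is just an example):

```python
# Sketch only: how the example-model registry added above is expected to be queried.
from .registry import HF_EXAMPLE_MODELS

# All architectures covered by the test suite.
archs = HF_EXAMPLE_MODELS.get_supported_archs()

# Example checkpoint and loading options for one architecture.
info = HF_EXAMPLE_MODELS.get_hf_info("LlamaForCausalLM")
assert info.default == "meta-llama/Meta-Llama-3-8B"
assert info.is_available_online
assert not info.trust_remote_code
```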
diff --git a/tests/models/test_initialization.py b/tests/models/test_initialization.py
new file mode 100644
index 0000000000000..b8312c2d9b7cc
--- /dev/null
+++ b/tests/models/test_initialization.py
@@ -0,0 +1,55 @@
+from unittest.mock import patch
+
+import pytest
+import transformers
+from transformers import PretrainedConfig
+
+from vllm import LLM
+
+from .registry import HF_EXAMPLE_MODELS
+
+
+@pytest.mark.parametrize("model_arch", HF_EXAMPLE_MODELS.get_supported_archs())
+def test_can_initialize(model_arch):
+    if (model_arch == "Idefics3ForConditionalGeneration"
+            and transformers.__version__ < "4.46.0"):
+        pytest.skip(reason="Model introduced in HF >= 4.46.0")
+
+    model_info = HF_EXAMPLE_MODELS.get_hf_info(model_arch)
+    if not model_info.is_available_online:
+        pytest.skip("Model is not available online")
+
+    # Avoid OOM
+    def hf_overrides(hf_config: PretrainedConfig) -> PretrainedConfig:
+        if hasattr(hf_config, "text_config"):
+            text_config: PretrainedConfig = hf_config.text_config
+        else:
+            text_config = hf_config
+
+        text_config.update({
+            "num_layers": 1,
+            "num_hidden_layers": 1,
+            "num_experts": 2,
+            "num_experts_per_tok": 2,
+            "num_local_experts": 2,
+        })
+
+        return hf_config
+
+    # Avoid calling model.forward()
+    def _initialize_kv_caches(self) -> None:
+        self.cache_config.num_gpu_blocks = 0
+        self.cache_config.num_cpu_blocks = 0
+
+    with patch.object(LLM.get_engine_class(), "_initialize_kv_caches",
+                      _initialize_kv_caches):
+        LLM(
+            model_info.default,
+            tokenizer=model_info.tokenizer,
+            tokenizer_mode=model_info.tokenizer_mode,
+            speculative_model=model_info.speculative_model,
+            num_speculative_tokens=1 if model_info.speculative_model else None,
+            trust_remote_code=model_info.trust_remote_code,
+            load_format="dummy",
+            hf_overrides=hf_overrides,
+        )
diff --git a/tests/models/test_registry.py b/tests/models/test_registry.py
index a2194fa15f90e..dbc415796ee55 100644
--- a/tests/models/test_registry.py
+++ b/tests/models/test_registry.py
@@ -14,6 +14,7 @@
 from vllm.platforms import current_platform
 
 from ..utils import fork_new_process_for_each_test
+from .registry import HF_EXAMPLE_MODELS
 
 
 @pytest.mark.parametrize("model_arch", ModelRegistry.get_supported_archs())
@@ -73,3 +74,12 @@ def test_registry_is_pp(model_arch, is_pp, init_cuda):
             "This model no longer initializes CUDA on import. "
             "Please test using a different one.",
             stacklevel=2)
+
+
+def test_hf_registry_coverage():
+    untested_archs = (HF_EXAMPLE_MODELS.get_supported_archs() -
+                      set(ModelRegistry.get_supported_archs()))
+
+    assert not untested_archs, (
+        "Please add the following architectures to "
+        f"`tests/models/registry.py`: {untested_archs}")
diff --git a/vllm/config.py b/vllm/config.py
index 002adb4316969..83b1483eb99e0 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -3,8 +3,8 @@
 import json
 import warnings
 from dataclasses import dataclass, field, replace
-from typing import (TYPE_CHECKING, Any, ClassVar, Dict, Final, List, Literal,
-                    Mapping, Optional, Set, Tuple, Type, Union)
+from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Dict, Final, List,
+                    Literal, Mapping, Optional, Set, Tuple, Type, Union)
 
 import torch
 from transformers import PretrainedConfig
@@ -20,7 +20,7 @@
     get_hf_text_config, get_pooling_config,
     get_sentence_transformer_tokenizer_config, is_encoder_decoder, uses_mrope)
 from vllm.utils import (GiB_bytes, cuda_device_count_stateless, get_cpu_memory,
-                        print_warning_once)
+                        identity, print_warning_once)
 
 if TYPE_CHECKING:
     from ray.util.placement_group import PlacementGroup
@@ -44,6 +44,9 @@
 # "draft" is only used internally for speculative decoding
 _Task = Literal["generate", "embedding", "draft"]
 
+HfOverrides = Union[Dict[str, Any], Callable[[PretrainedConfig],
+                                             PretrainedConfig]]
+
 
 class ModelConfig:
     """Configuration for the model.
@@ -115,7 +118,9 @@ class ModelConfig:
             can not be gathered from the vllm arguments.
         config_format: The config format which shall be loaded.
             Defaults to 'auto' which defaults to 'hf'.
-        hf_overrides: Arguments to be forwarded to the HuggingFace config.
+        hf_overrides: If a dictionary, contains arguments to be forwarded to the
+            HuggingFace config. If a callable, it is called to update the
+            HuggingFace config.
         mm_processor_kwargs: Arguments to be forwarded to the model's processor
             for multi-modal data, e.g., image processor.
         pooling_type: Used to configure the pooling method in the embedding
@@ -164,7 +169,7 @@ def __init__(
         override_neuron_config: Optional[Dict[str, Any]] = None,
         config_format: ConfigFormat = ConfigFormat.AUTO,
         chat_template_text_format: str = "string",
-        hf_overrides: Optional[Dict[str, Any]] = None,
+        hf_overrides: Optional[HfOverrides] = None,
         mm_processor_kwargs: Optional[Dict[str, Any]] = None,
         pooling_type: Optional[str] = None,
         pooling_norm: Optional[bool] = None,
@@ -182,15 +187,23 @@ def __init__(
 
         if hf_overrides is None:
             hf_overrides = {}
+
+        if callable(hf_overrides):
+            hf_overrides_kw = {}
+            hf_overrides_fn = hf_overrides
+        else:
+            hf_overrides_kw = hf_overrides
+            hf_overrides_fn = identity
+
         if rope_scaling is not None:
             hf_override: Dict[str, Any] = {"rope_scaling": rope_scaling}
-            hf_overrides.update(hf_override)
+            hf_overrides_kw.update(hf_override)
             msg = ("`--rope-scaling` will be removed in a future release. "
                    f"'Please instead use `--hf-overrides '{hf_override!r}'`")
             warnings.warn(DeprecationWarning(msg), stacklevel=2)
         if rope_theta is not None:
             hf_override = {"rope_theta": rope_theta}
-            hf_overrides.update(hf_override)
+            hf_overrides_kw.update(hf_override)
             msg = ("`--rope-theta` will be removed in a future release. "
                    f"'Please instead use `--hf-overrides '{hf_override!r}'`")
             warnings.warn(DeprecationWarning(msg), stacklevel=2)
@@ -207,9 +220,12 @@ def __init__(
         self.max_logprobs = max_logprobs
         self.disable_sliding_window = disable_sliding_window
         self.skip_tokenizer_init = skip_tokenizer_init
-        self.hf_config = get_config(self.model, trust_remote_code, revision,
-                                    code_revision, config_format,
-                                    **hf_overrides)
+
+        hf_config = get_config(self.model, trust_remote_code, revision,
+                               code_revision, config_format, **hf_overrides_kw)
+        hf_config = hf_overrides_fn(hf_config)
+        self.hf_config = hf_config
+
         self.hf_text_config = get_hf_text_config(self.hf_config)
         self.encoder_config = self._get_encoder_config()
         self.hf_image_processor_config = get_hf_image_processor_config(
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 31aa8c5908719..244aa09e12552 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -9,9 +9,9 @@
 
 import vllm.envs as envs
 from vllm.config import (CacheConfig, ConfigFormat, DecodingConfig,
-                         DeviceConfig, LoadConfig, LoadFormat, LoRAConfig,
-                         ModelConfig, ObservabilityConfig, ParallelConfig,
-                         PromptAdapterConfig, SchedulerConfig,
+                         DeviceConfig, HfOverrides, LoadConfig, LoadFormat,
+                         LoRAConfig, ModelConfig, ObservabilityConfig,
+                         ParallelConfig, PromptAdapterConfig, SchedulerConfig,
                          SpeculativeConfig, TaskOption, TokenizerPoolConfig,
                          VllmConfig)
 from vllm.executor.executor_base import ExecutorBase
@@ -128,7 +128,7 @@ class EngineArgs:
     code_revision: Optional[str] = None
     rope_scaling: Optional[Dict[str, Any]] = None
     rope_theta: Optional[float] = None
-    hf_overrides: Optional[Dict[str, Any]] = None
+    hf_overrides: Optional[HfOverrides] = None
     tokenizer_revision: Optional[str] = None
     quantization: Optional[str] = None
     enforce_eager: Optional[bool] = None
diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py
index a15dbd1c45119..63c2bb6097079 100644
--- a/vllm/entrypoints/llm.py
+++ b/vllm/entrypoints/llm.py
@@ -9,7 +9,7 @@
 from vllm import envs
 from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
                               BeamSearchSequence, get_beam_search_score)
-from vllm.engine.arg_utils import EngineArgs, TaskOption
+from vllm.engine.arg_utils import EngineArgs, HfOverrides, TaskOption
 from vllm.engine.llm_engine import LLMEngine
 from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
                                          apply_hf_chat_template,
@@ -101,7 +101,9 @@ class LLM:
         disable_custom_all_reduce: See :class:`~vllm.config.ParallelConfig`
         disable_async_output_proc: Disable async output processing. This may
            result in lower performance.
-        hf_overrides: Arguments to be forwarded to the HuggingFace config.
+        hf_overrides: If a dictionary, contains arguments to be forwarded to the
+            HuggingFace config. If a callable, it is called to update the
+            HuggingFace config.
         **kwargs: Arguments for :class:`~vllm.EngineArgs`. (See
             :ref:`engine_args`)
 
@@ -156,7 +158,7 @@ def __init__(
         max_seq_len_to_capture: int = 8192,
         disable_custom_all_reduce: bool = False,
         disable_async_output_proc: bool = False,
-        hf_overrides: Optional[dict] = None,
+        hf_overrides: Optional[HfOverrides] = None,
         mm_processor_kwargs: Optional[Dict[str, Any]] = None,
         # After positional args are removed, move this right below `model`
         task: TaskOption = "auto",
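To show what the widened `hf_overrides` type means for callers: it still accepts a dictionary of config attributes, and now also accepts a callable that receives the loaded `PretrainedConfig` and returns the updated one (which is how `test_can_initialize` above shrinks each model). A hedged usage sketch; the model name and override values are arbitrary examples, not part of this change:

```python
from transformers import PretrainedConfig

from vllm import LLM

# Dictionary form: keys are forwarded to the HuggingFace config as before.
llm = LLM(model="meta-llama/Meta-Llama-3-8B",
          hf_overrides={"rope_theta": 500000.0})


# Callable form: edit the config after it has been loaded from the Hub.
def shrink(hf_config: PretrainedConfig) -> PretrainedConfig:
    hf_config.update({"num_hidden_layers": 2})
    return hf_config


llm = LLM(model="meta-llama/Meta-Llama-3-8B", hf_overrides=shrink)
```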
""" - def __init__( - self, - vllm_config: VllmConfig, - prefix: str = "", - ): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config multimodal_config = vllm_config.model_config.multimodal_config quant_config = vllm_config.quant_config @@ -699,12 +695,8 @@ def is_default_weight_loading(self, name: str) -> bool: class MiniCPMV2_0(MiniCPMVBaseModel): - def __init__( - self, - vllm_config: VllmConfig, - prefix: str = "", - ): - super().__init__(vllm_config) + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__(vllm_config=vllm_config, prefix=prefix) assert self.version == (2, 0) def init_llm( @@ -857,12 +849,8 @@ class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA): embedding_modules = {} embedding_padding_modules = [] - def __init__( - self, - vllm_config: VllmConfig, - prefix: str = "", - ): - super().__init__(vllm_config) + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__(vllm_config=vllm_config, prefix=prefix) assert self.version == (2, 5) def init_llm( @@ -999,12 +987,8 @@ class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA): embedding_modules = {} embedding_padding_modules = [] - def __init__( - self, - vllm_config: VllmConfig, - prefix: str = "", - ): - super().__init__(vllm_config) + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__(vllm_config=vllm_config, prefix=prefix) assert self.version == (2, 6) def init_llm( @@ -1117,7 +1101,7 @@ class MiniCPMV(MiniCPMVBaseModel, SupportsLoRA): embedding_modules = {} embedding_padding_modules = [] - def __new__(cls, vllm_config: VllmConfig, prefix: str = ""): + def __new__(cls, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config if not hasattr(config, "version"): if config.hidden_size == 2304 and config.query_num == 64: diff --git a/vllm/model_executor/models/mlp_speculator.py b/vllm/model_executor/models/mlp_speculator.py index 6aa43f22f4c93..4d7e82880041d 100644 --- a/vllm/model_executor/models/mlp_speculator.py +++ b/vllm/model_executor/models/mlp_speculator.py @@ -65,7 +65,7 @@ class MLPSpeculator(nn.Module): https://huggingface.co/ibm-fms and https://huggingface.co/ibm-granite """ - def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None: + def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: super().__init__() config = vllm_config.model_config.hf_config self.n_predict = config.n_predict diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index f22d1b04ebf09..c0d503a1c5ba2 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -1,3 +1,7 @@ +""" +Whenever you add an architecture to this page, please also update +`tests/models/registry.py` with example HuggingFace models for it. 
+""" import importlib import os import pickle @@ -58,14 +62,14 @@ "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"), "MambaForCausalLM": ("mamba", "MambaForCausalLM"), "FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"), + "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"), + "MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"), "MistralForCausalLM": ("llama", "LlamaForCausalLM"), "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"), "QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"), # transformers's mpt class has lower case "MptForCausalLM": ("mpt", "MPTForCausalLM"), "MPTForCausalLM": ("mpt", "MPTForCausalLM"), - "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"), - "MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"), "NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"), "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"), "OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"),