[Bugfix] Fix unable to load some models #10312

Merged: 18 commits, Nov 15, 2024
10 changes: 9 additions & 1 deletion .buildkite/test-pipeline.yaml
@@ -320,7 +320,15 @@ steps:
commands:
- pip install -e ./plugins/vllm_add_dummy_model
- pytest -v -s models/test_oot_registration.py # it needs a clean process
- pytest -v -s models/*.py --ignore=models/test_oot_registration.py
- pytest -v -s models/test_registry.py

- label: Model Initialization Test # 1h
optional: true
source_file_dependencies:
- vllm/
- tests/models
commands:
- pytest -v -s models/test_initialization.py

- label: Decoder-only Language Models Test (Standard) # 18min
#mirror_hardwares: [amd]
6 changes: 3 additions & 3 deletions tests/distributed/test_pipeline_parallel.py
@@ -166,14 +166,14 @@ def iter_params(self, model_name: str):
"mistralai/Mixtral-8x7B-Instruct-v0.1": PPTestSettings.fast(tp_base=4),
"mosaicml/mpt-7b": PPTestSettings.fast(),
"nvidia/Minitron-8B-Base": PPTestSettings.fast(),
"allenai/OLMoE-1B-7B-0924-Instruct": PPTestSettings.fast(),
"allenai/OLMo-1B-hf": PPTestSettings.fast(),
"allenai/OLMoE-1B-7B-0924-Instruct": PPTestSettings.fast(),
"facebook/opt-iml-max-1.3b": PPTestSettings.fast(),
"OrionStarAI/Orion-14B-Chat": PPTestSettings.fast(trust_remote_code=True),
"adept/persimmon-8b-chat": PPTestSettings.fast(),
"microsoft/phi-2": PPTestSettings.fast(),
"microsoft/Phi-3.5-MoE-instruct": PPTestSettings.detailed(trust_remote_code=True, multi_node_only=True, load_format="dummy", hf_overrides='{"num_hidden_layers": 4, "hidden_size": 512, "intermediate_size": 800, "num_attention_heads": 4, "num_key_value_heads": 1}'), # noqa: E501
"microsoft/Phi-3-small-8k-instruct": PPTestSettings.fast(trust_remote_code=True), # noqa: E501
"adept/persimmon-8b-chat": PPTestSettings.fast(),
"microsoft/Phi-3.5-MoE-instruct": PPTestSettings.detailed(trust_remote_code=True, multi_node_only=True, load_format="dummy", hf_overrides='{"num_hidden_layers": 4, "hidden_size": 512, "intermediate_size": 800, "num_attention_heads": 4, "num_key_value_heads": 1}'), # noqa: E501
"Qwen/Qwen-7B-Chat": PPTestSettings.fast(trust_remote_code=True),
"Qwen/Qwen2-7B-Instruct": PPTestSettings.fast(),
"Qwen/Qwen1.5-MoE-A2.7B-Chat": PPTestSettings.fast(),
189 changes: 189 additions & 0 deletions tests/models/registry.py
@@ -0,0 +1,189 @@
from dataclasses import dataclass, field
from typing import AbstractSet, Mapping


@dataclass(frozen=True)
class _HfExamplesInfo:
default: str
"""The default model to use for testing this architecture."""

extras: Mapping[str, str] = field(default_factory=dict)
"""Extra models to use for testing this architecture."""

is_available_online: bool = True
"""
Set this to ``False`` if the name of this architecture no longer exists on
the HF repo. To maintain backwards compatibility, we have not removed them
from the main model registry, so without this flag the registry tests will
fail.
"""

trust_remote_code: bool = False
"""The ``trust_remote_code`` level required to load the model."""


# yapf: disable
_TEXT_GENERATION_EXAMPLE_MODELS = {
# [Decoder-only]
"AquilaModel": _HfExamplesInfo("BAAI/AquilaChat-7B",
trust_remote_code=True),
"AquilaForCausalLM": _HfExamplesInfo("BAAI/AquilaChat2-7B",
trust_remote_code=True),
"ArcticForCausalLM": _HfExamplesInfo("Snowflake/snowflake-arctic-instruct",
trust_remote_code=True),
"BaiChuanForCausalLM": _HfExamplesInfo("baichuan-inc/Baichuan-7B",
trust_remote_code=True),
"BaichuanForCausalLM": _HfExamplesInfo("baichuan-inc/Baichuan2-7B-chat",
trust_remote_code=True),
"BloomForCausalLM": _HfExamplesInfo("bigscience/bloomz-1b1"),
# ChatGLMModel supports multimodal
"CohereForCausalLM": _HfExamplesInfo("CohereForAI/c4ai-command-r-v01",
trust_remote_code=True),
"DbrxForCausalLM": _HfExamplesInfo("databricks/dbrx-instruct"),
"DeciLMForCausalLM": _HfExamplesInfo("Deci/DeciLM-7B-instruct",
trust_remote_code=True),
"DeepseekForCausalLM": _HfExamplesInfo("deepseek-ai/deepseek-llm-7b-chat"),
"DeepseekV2ForCausalLM": _HfExamplesInfo("deepseek-ai/DeepSeek-V2-Lite-Chat", # noqa: E501
trust_remote_code=True),
"ExaoneForCausalLM": _HfExamplesInfo("LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct"), # noqa: E501
"FalconForCausalLM": _HfExamplesInfo("tiiuae/falcon-7b"),
"GemmaForCausalLM": _HfExamplesInfo("google/gemma-2b"),
"Gemma2ForCausalLM": _HfExamplesInfo("google/gemma-2-9b"),
"GPT2LMHeadModel": _HfExamplesInfo("gpt2"),
"GPTBigCodeForCausalLM": _HfExamplesInfo("bigcode/starcoder"),
"GPTJForCausalLM": _HfExamplesInfo("EleutherAI/gpt-j-6b"),
"GPTNeoXForCausalLM": _HfExamplesInfo("EleutherAI/pythia-160m"),
"GraniteForCausalLM": _HfExamplesInfo("ibm/PowerLM-3b"),
"GraniteMoeForCausalLM": _HfExamplesInfo("ibm/PowerMoE-3b"),
"InternLMForCausalLM": _HfExamplesInfo("internlm/internlm-chat-7b"),
"InternLM2ForCausalLM": _HfExamplesInfo("internlm/internlm2-chat-7b",
trust_remote_code=True),
"InternLM2VEForCausalLM": _HfExamplesInfo("OpenGVLab/Mono-InternVL-2B"),
"JAISLMHeadModel": _HfExamplesInfo("inceptionai/jais-13b-chat"),
"JambaForCausalLM": _HfExamplesInfo("ai21labs/AI21-Jamba-1.5-Mini"),
"LlamaForCausalLM": _HfExamplesInfo("meta-llama/Meta-Llama-3-8B"),
"LLaMAForCausalLM": _HfExamplesInfo("decapoda-research/llama-7b-hf",
is_available_online=False),
"MambaForCausalLM": _HfExamplesInfo("state-spaces/mamba-130m-hf"),
"FalconMambaForCausalLM": _HfExamplesInfo("tiiuae/falcon-mamba-7b-instruct"), # noqa: E501
"MiniCPMForCausalLM": _HfExamplesInfo("openbmb/MiniCPM-2B-sft-bf16",
trust_remote_code=True),
"MiniCPM3ForCausalLM": _HfExamplesInfo("openbmb/MiniCPM3-4B",
trust_remote_code=True),
"MistralForCausalLM": _HfExamplesInfo("mistralai/Mistral-7B-Instruct-v0.1"),
"MixtralForCausalLM": _HfExamplesInfo("mistralai/Mixtral-8x7B-Instruct-v0.1"), # noqa: E501
"QuantMixtralForCausalLM": _HfExamplesInfo("mistral-community/Mixtral-8x22B-v0.1-AWQ"), # noqa: E501
"MptForCausalLM": _HfExamplesInfo("mpt", is_available_online=False),
"MPTForCausalLM": _HfExamplesInfo("mosaicml/mpt-7b"),
"NemotronForCausalLM": _HfExamplesInfo("nvidia/Minitron-8B-Base"),
"OlmoForCausalLM": _HfExamplesInfo("allenai/OLMo-1B-hf"),
"OlmoeForCausalLM": _HfExamplesInfo("allenai/OLMoE-1B-7B-0924-Instruct"),
"OPTForCausalLM": _HfExamplesInfo("facebook/opt-iml-max-1.3b"),
"OrionForCausalLM": _HfExamplesInfo("OrionStarAI/Orion-14B-Chat",
trust_remote_code=True),
"PersimmonForCausalLM": _HfExamplesInfo("adept/persimmon-8b-chat"),
"PhiForCausalLM": _HfExamplesInfo("microsoft/phi-2"),
"Phi3ForCausalLM": _HfExamplesInfo("microsoft/Phi-3-mini-4k-instruct"),
"Phi3SmallForCausalLM": _HfExamplesInfo("microsoft/Phi-3-small-8k-instruct",
trust_remote_code=True),
"PhiMoEForCausalLM": _HfExamplesInfo("microsoft/Phi-3.5-MoE-instruct",
trust_remote_code=True),
# QWenLMHeadModel supports multimodal
"Qwen2ForCausalLM": _HfExamplesInfo("Qwen/Qwen2-7B-Instruct"),
"Qwen2MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen1.5-MoE-A2.7B-Chat"),
"RWForCausalLM": _HfExamplesInfo("tiiuae/falcon-40b",
is_available_online=False),
"StableLMEpochForCausalLM": _HfExamplesInfo("stabilityai/stablelm-zephyr-3b", # noqa: E501
is_available_online=False),
"StableLmForCausalLM": _HfExamplesInfo("stabilityai/stablelm-3b-4e1t"),
"Starcoder2ForCausalLM": _HfExamplesInfo("bigcode/starcoder2-3b"),
"SolarForCausalLM": _HfExamplesInfo("upstage/solar-pro-preview-instruct"),
"XverseForCausalLM": _HfExamplesInfo("xverse/XVERSE-7B-Chat",
trust_remote_code=True),
# [Encoder-decoder]
"BartModel": _HfExamplesInfo("facebook/bart-base"),
"BartForConditionalGeneration": _HfExamplesInfo("facebook/bart-large-cnn"),
"Florence2ForConditionalGeneration": _HfExamplesInfo("Xenova/tiny-random-Florence2ForConditionalGeneration"), # noqa: E501
}

_EMBEDDING_EXAMPLE_MODELS = {
# [Text-only]
"BertModel": _HfExamplesInfo("google-bert/bert-base"),
"Gemma2Model": _HfExamplesInfo("gBAAI/bge-multilingual-gemma2"),
"MistralModel": _HfExamplesInfo("intfloat/e5-mistral-7b-instruct"),
"Qwen2ForRewardModel": _HfExamplesInfo("Qwen/Qwen2.5-Math-RM-72B"),
"Qwen2ForSequenceClassification": _HfExamplesInfo("jason9693/Qwen2.5-1.5B-apeach"), # noqa: E501
# [Multimodal]
"LlavaNextForConditionalGeneration": _HfExamplesInfo("royokong/e5-v"), # noqa: E501
"Phi3VForCausalLM": _HfExamplesInfo("TIGER-Lab/VLM2Vec-Full",
trust_remote_code=True),
"Qwen2VLForConditionalGeneration": _HfExamplesInfo("MrLight/dse-qwen2-2b-mrl-v1") # noqa: E501,
}

_MULTIMODAL_EXAMPLE_MODELS = {
# [Decoder-only]
"Blip2ForConditionalGeneration": _HfExamplesInfo("Salesforce/blip2-opt-2.7b"), # noqa: E501
"ChameleonForConditionalGeneration": _HfExamplesInfo("facebook/chameleon-7b"), # noqa: E501
"ChatGLMModel": _HfExamplesInfo("THUDM/glm-4v-9b",
extras={"text_only": "THUDM/chatglm3-6b"},
trust_remote_code=True),
"ChatGLMForConditionalGeneration": _HfExamplesInfo("chatglm2-6b",
is_available_online=False),
"FuyuForCausalLM": _HfExamplesInfo("adept/fuyu-8b"),
"H2OVLChatModel": _HfExamplesInfo("h2oai/h2ovl-mississippi-800m"),
"InternVLChatModel": _HfExamplesInfo("OpenGVLab/InternVL2-1B",
trust_remote_code=True),
"Idefics3ForConditionalGeneration": _HfExamplesInfo("HuggingFaceM4/Idefics3-8B-Llama3"), # noqa: E501
"LlavaForConditionalGeneration": _HfExamplesInfo("llava-hf/llava-1.5-7b-hf",
extras={"mistral": "mistral-community/pixtral-12b"}), # noqa: E501
"LlavaNextForConditionalGeneration": _HfExamplesInfo("llava-hf/llava-v1.6-mistral-7b-hf"), # noqa: E501
"LlavaNextVideoForConditionalGeneration": _HfExamplesInfo("llava-hf/LLaVA-NeXT-Video-7B-hf"), # noqa: E501
"LlavaOnevisionForConditionalGeneration": _HfExamplesInfo("llava-hf/llava-onevision-qwen2-0.5b-ov-hf"), # noqa: E501
"MiniCPMV": _HfExamplesInfo("openbmb/MiniCPM-Llama3-V-2_5",
trust_remote_code=True),
"MolmoForCausalLM": _HfExamplesInfo("allenai/Molmo-7B-D-0924",
trust_remote_code=True),
"NVLM_D": _HfExamplesInfo("nvidia/NVLM-D-72B",
trust_remote_code=True),
"PaliGemmaForConditionalGeneration": _HfExamplesInfo("google/paligemma-3b-pt-224"), # noqa: E501
"Phi3VForCausalLM": _HfExamplesInfo("microsoft/Phi-3-vision-128k-instruct",
trust_remote_code=True),
"PixtralForConditionalGeneration": _HfExamplesInfo("mistralai/Pixtral-12B-2409"), # noqa: E501
"QWenLMHeadModel": _HfExamplesInfo("Qwen/Qwen-VL-Chat",
extras={"text_only": "Qwen/Qwen-7B-Chat"}, # noqa: E501
trust_remote_code=True),
"Qwen2AudioForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2-Audio-7B-Instruct"), # noqa: E501
"Qwen2VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2-VL-2B-Instruct"), # noqa: E501
"UltravoxModel": _HfExamplesInfo("fixie-ai/ultravox-v0_3"),
# [Encoder-decoder]
"MllamaForConditionalGeneration": _HfExamplesInfo("meta-llama/Llama-3.2-11B-Vision-Instruct"), # noqa: E501
}

_SPECULATIVE_DECODING_EXAMPLE_MODELS = {
"EAGLEModel": _HfExamplesInfo("abhigoyal/vllm-eagle-llama-68m-random"),
"MedusaModel": _HfExamplesInfo("abhigoyal/vllm-medusa-llama-68m-random"),
"MLPSpeculatorPreTrainedModel": _HfExamplesInfo("ibm-fms/llama-160m-accelerator"), # noqa: E501
}

_EXAMPLE_MODELS = {
**_TEXT_GENERATION_EXAMPLE_MODELS,
**_EMBEDDING_EXAMPLE_MODELS,
**_MULTIMODAL_EXAMPLE_MODELS,
**_SPECULATIVE_DECODING_EXAMPLE_MODELS,
}


class HfExampleModels:
def __init__(self, hf_models: Mapping[str, _HfExamplesInfo]) -> None:
super().__init__()

self.hf_models = hf_models

def get_supported_archs(self) -> AbstractSet[str]:
return self.hf_models.keys()

def get_hf_info(self, model_arch: str) -> _HfExamplesInfo:
return self.hf_models[model_arch]


HF_EXAMPLE_MODELS = HfExampleModels(_EXAMPLE_MODELS)
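For context, here is how the new registry is meant to be queried; a minimal sketch, assuming the `tests.models` package is importable from the repository root (the lookup mirrors what the new tests below do):

from tests.models.registry import HF_EXAMPLE_MODELS

# Look up the example HF checkpoint registered for an architecture.
info = HF_EXAMPLE_MODELS.get_hf_info("LlamaForCausalLM")
print(info.default)            # "meta-llama/Meta-Llama-3-8B"
print(info.trust_remote_code)  # False

# Architectures whose checkpoints were taken down from the HF Hub are
# flagged so tests can skip them instead of failing.
offline_archs = [
    arch for arch in HF_EXAMPLE_MODELS.get_supported_archs()
    if not HF_EXAMPLE_MODELS.get_hf_info(arch).is_available_online
]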
20 changes: 20 additions & 0 deletions tests/models/test_initialization.py
@@ -0,0 +1,20 @@
import pytest

from vllm import LLM

from .registry import HF_EXAMPLE_MODELS


@pytest.mark.parametrize("model_arch", HF_EXAMPLE_MODELS.get_supported_archs())
def test_can_initialize(model_arch):
model_info = HF_EXAMPLE_MODELS.get_hf_info(model_arch)

if not model_info.is_available_online:
pytest.skip("Model is not available online")

try:
LLM(model_info.default,
trust_remote_code=model_info.trust_remote_code,
load_format="dummy")
except Exception:
raise
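The same cheap construction check can be reproduced outside of pytest; a minimal sketch using a model already listed in the registry above (`load_format="dummy"` skips downloading real weights, so only model initialization is exercised):

from vllm import LLM

# Builds the model with randomly initialized ("dummy") weights; no download.
llm = LLM("facebook/opt-iml-max-1.3b", load_format="dummy")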
9 changes: 9 additions & 0 deletions tests/models/test_registry.py
@@ -14,6 +14,7 @@
from vllm.platforms import current_platform

from ..utils import fork_new_process_for_each_test
from .registry import HF_EXAMPLE_MODELS


@pytest.mark.parametrize("model_arch", ModelRegistry.get_supported_archs())
@@ -73,3 +74,11 @@ def test_registry_is_pp(model_arch, is_pp, init_cuda):
"This model no longer initializes CUDA on import. "
"Please test using a different one.",
stacklevel=2)


def test_hf_registry_coverage():
untested_archs = (HF_EXAMPLE_MODELS.get_supported_archs() -
set(ModelRegistry.get_supported_archs()))

assert not untested_archs, ("Please add the following architectures to "
f"`tests/models/registry.py`: {untested_archs}")
7 changes: 5 additions & 2 deletions vllm/model_executor/models/fuyu.py
@@ -41,7 +41,8 @@
from vllm.utils import is_list_of

from .interfaces import SupportsMultiModal, SupportsPP
from .utils import AutoWeightsLoader, flatten_bn, merge_multimodal_embeddings
from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
merge_multimodal_embeddings)

# Cannot find the following 2 numbers from hf config.
_IMAGE_TOKEN_ID = 71011
@@ -245,7 +246,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
gather_output=True,
)
self.language_model = PersimmonForCausalLM(
vllm_config.with_hf_config(config.text_config))
vllm_config=vllm_config.with_hf_config(config.text_config),
prefix=maybe_prefix(prefix, "language_model"),
)
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors)

8 changes: 1 addition & 7 deletions vllm/model_executor/models/internlm2_ve.py
@@ -161,11 +161,5 @@ class InternLM2VEForCausalLM(InternLM2ForCausalLM):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__(vllm_config=vllm_config, prefix=prefix)

config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config

self.model = InternLM2VEModel(config,
cache_config,
quant_config,
self.model = InternLM2VEModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
32 changes: 8 additions & 24 deletions vllm/model_executor/models/minicpmv.py
@@ -382,11 +382,7 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
instantiated.
"""

def __init__(
self,
vllm_config: VllmConfig,
prefix: str = "",
):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
config = vllm_config.model_config.hf_config
multimodal_config = vllm_config.model_config.multimodal_config
quant_config = vllm_config.quant_config
@@ -699,12 +695,8 @@ def is_default_weight_loading(self, name: str) -> bool:

class MiniCPMV2_0(MiniCPMVBaseModel):

def __init__(
self,
vllm_config: VllmConfig,
prefix: str = "",
):
super().__init__(vllm_config)
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__(vllm_config=vllm_config, prefix=prefix)
assert self.version == (2, 0)

def init_llm(
@@ -857,12 +849,8 @@ class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA):
embedding_modules = {}
embedding_padding_modules = []

def __init__(
self,
vllm_config: VllmConfig,
prefix: str = "",
):
super().__init__(vllm_config)
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__(vllm_config=vllm_config, prefix=prefix)
assert self.version == (2, 5)

def init_llm(
@@ -999,12 +987,8 @@ class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA):
embedding_modules = {}
embedding_padding_modules = []

def __init__(
self,
vllm_config: VllmConfig,
prefix: str = "",
):
super().__init__(vllm_config)
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__(vllm_config=vllm_config, prefix=prefix)
assert self.version == (2, 6)

def init_llm(
@@ -1117,7 +1101,7 @@ class MiniCPMV(MiniCPMVBaseModel, SupportsLoRA):
embedding_modules = {}
embedding_padding_modules = []

def __new__(cls, vllm_config: VllmConfig, prefix: str = ""):
def __new__(cls, *, vllm_config: VllmConfig, prefix: str = ""):
config = vllm_config.model_config.hf_config
if not hasattr(config, "version"):
if config.hidden_size == 2304 and config.query_num == 64:
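The MiniCPM-V and InternLM2-VE edits above converge on the constructor convention the rest of this PR relies on: model classes take `vllm_config` and `prefix` as keyword-only arguments and forward a derived prefix to each submodule. A minimal sketch of the pattern, using stand-in classes rather than the real vLLM ones:

from dataclasses import dataclass

@dataclass
class VllmConfig:
    """Stand-in for vllm.config.VllmConfig; only what this sketch needs."""
    model_name: str

def maybe_prefix(prefix: str, name: str) -> str:
    """Join a parent prefix with a child module name (mirrors models/utils.py)."""
    return f"{prefix}.{name}" if prefix else name

class InnerModel:
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        self.prefix = prefix  # used to name this submodule's weights

class OuterModel:
    # Keyword-only signature, matching the convention enforced in this PR.
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        # Forward the config and a derived prefix so the submodule's weight
        # names line up with the checkpoint layout.
        self.model = InnerModel(vllm_config=vllm_config,
                                prefix=maybe_prefix(prefix, "model"))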
9 changes: 7 additions & 2 deletions vllm/model_executor/models/registry.py
@@ -1,3 +1,8 @@
"""
Whenever you add an architecture to this page, please also update
`vllm/model_executor/models/registry.py` with example HuggingFace models
for that architecture.
"""
import importlib
import os
import pickle
@@ -58,14 +63,14 @@
"LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
"MambaForCausalLM": ("mamba", "MambaForCausalLM"),
"FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"),
"MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
"MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"),
"MistralForCausalLM": ("llama", "LlamaForCausalLM"),
"MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
"QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"),
# transformers's mpt class has lower case
"MptForCausalLM": ("mpt", "MPTForCausalLM"),
"MPTForCausalLM": ("mpt", "MPTForCausalLM"),
"MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
"MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"),
"NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"),
"OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
"OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"),