From 954dd061cb4402a32dffd492df37a1412dd50fb0 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Fri, 29 Nov 2024 08:55:40 +0000 Subject: [PATCH] Replace embedding models with generic adapter Signed-off-by: DarkLight1337 --- docs/source/models/supported_models.rst | 6 +- .../embedding/language/test_embedding.py | 5 + vllm/inputs/registry.py | 16 ++-- vllm/model_executor/model_loader/loader.py | 5 +- vllm/model_executor/model_loader/utils.py | 7 +- vllm/model_executor/models/adapters.py | 94 +++++++++++++++++++ vllm/model_executor/models/gemma2.py | 58 +----------- vllm/model_executor/models/llama.py | 1 + vllm/model_executor/models/llava_next.py | 19 +--- vllm/model_executor/models/phi3v.py | 19 +--- vllm/model_executor/models/qwen2.py | 28 +++--- vllm/model_executor/models/qwen2_vl.py | 18 +--- vllm/model_executor/models/registry.py | 14 ++- vllm/multimodal/base.py | 6 +- vllm/multimodal/registry.py | 5 +- vllm/utils.py | 19 +++- 16 files changed, 175 insertions(+), 145 deletions(-) create mode 100644 vllm/model_executor/models/adapters.py diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index 7b7a83f20871b..f4cab81b3d20b 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -357,7 +357,7 @@ Text Embedding - ✅︎ * - :code:`Qwen2Model`, :code:`Qwen2ForCausalLM` - Qwen2-based - - :code:`ssmits/Qwen2-7B-Instruct-embed-base`, :code:`Alibaba-NLP/gte-Qwen2-7B-instruct` (see note), etc. + - :code:`ssmits/Qwen2-7B-Instruct-embed-base` (see note), :code:`Alibaba-NLP/gte-Qwen2-7B-instruct` (see note), etc. - ✅︎ - ✅︎ * - :code:`RobertaModel`, :code:`RobertaForMaskedLM` @@ -378,6 +378,10 @@ Text Embedding .. tip:: You can override the model's pooling method by passing :code:`--override-pooler-config`. +.. note:: + :code:`ssmits/Qwen2-7B-Instruct-embed-base` has an improperly defined Sentence Transformers config. + You should manually set mean pooling by passing :code:`--override-pooler-config '{"pooling_type": "MEAN"}'`. + .. note:: Unlike base Qwen2, :code:`Alibaba-NLP/gte-Qwen2-7B-instruct` uses bi-directional attention. You can set :code:`--hf-overrides '{"is_causal": false}'` to change the attention mask accordingly. 
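For readers of the docs change above, here is a short offline sketch of the same pooling override, mirroring the new test case added to tests/models/embedding/language/test_embedding.py later in this patch. The LLM(..., task="embedding", override_pooler_config=...) arguments and the encode() call are assumed from that test and from the task check added in model_loader/utils.py below; they are illustrative, not part of the patch.

    # Programmatic equivalent of --override-pooler-config '{"pooling_type": "MEAN"}'
    # for a model whose Sentence Transformers config is improperly defined.
    from vllm import LLM
    from vllm.config import PoolerConfig

    llm = LLM(
        model="ssmits/Qwen2-7B-Instruct-embed-base",
        task="embedding",
        override_pooler_config=PoolerConfig(pooling_type="MEAN"),
    )

    # Each output carries the pooled embedding for one prompt.
    outputs = llm.encode(["Example sentence to embed."])
    print(outputs[0].outputs.embedding[:8])
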
diff --git a/tests/models/embedding/language/test_embedding.py b/tests/models/embedding/language/test_embedding.py index 36b1e5887981c..5ef8540265d14 100644 --- a/tests/models/embedding/language/test_embedding.py +++ b/tests/models/embedding/language/test_embedding.py @@ -4,6 +4,8 @@ """ import pytest +from vllm.config import PoolerConfig + from ..utils import check_embeddings_close @@ -33,6 +35,9 @@ def test_models( dtype: str, ) -> None: vllm_extra_kwargs = {} + if model == "ssmits/Qwen2-7B-Instruct-embed-base": + vllm_extra_kwargs["override_pooler_config"] = \ + PoolerConfig(pooling_type="MEAN") if model == "Alibaba-NLP/gte-Qwen2-7B-instruct": vllm_extra_kwargs["hf_overrides"] = {"is_causal": False} diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index 68b4756331e6d..874290bc94ed1 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -11,8 +11,8 @@ from vllm.logger import init_logger from vllm.transformers_utils.processor import cached_get_processor from vllm.transformers_utils.tokenizer import AnyTokenizer -from vllm.utils import (get_allowed_kwarg_only_overrides, print_warning_once, - resolve_mm_processor_kwargs) +from vllm.utils import (ClassRegistry, get_allowed_kwarg_only_overrides, + print_warning_once, resolve_mm_processor_kwargs) from .data import ProcessorInputs, SingletonInputs from .parse import is_encoder_decoder_inputs @@ -136,12 +136,12 @@ class InputRegistry: """ def __init__(self) -> None: - self._dummy_factories_by_model_type: Dict[Type[nn.Module], - DummyDataFactory] = {} - self._dummy_encoder_factories_by_model_type: Dict[ - Type[nn.Module], DummyDataFactory] = {} - self._input_processors_by_model_type: Dict[Type[nn.Module], - InputProcessor] = {} + self._dummy_factories_by_model_type = \ + ClassRegistry[nn.Module,DummyDataFactory]() + self._dummy_encoder_factories_by_model_type = \ + ClassRegistry[nn.Module, DummyDataFactory]() + self._input_processors_by_model_type = \ + ClassRegistry[nn.Module, InputProcessor]() def _default_dummy_data_factory( self, diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index 37c2d789030b6..6c0b6a3c7ccab 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -9,6 +9,7 @@ import json import math import os +import warnings from abc import ABC, abstractmethod from contextlib import contextmanager from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, cast @@ -107,12 +108,14 @@ def _initialize_model(vllm_config: VllmConfig, prefix: str = "") -> nn.Module: # new-style model class with set_current_vllm_config(vllm_config): return model_class(vllm_config=vllm_config, prefix=prefix) + msg = ("vLLM model class should accept `vllm_config` and `prefix` as " "input arguments. Possibly you have an old-style model class" " registered from out of tree and it is used for new vLLM version. 
" "Check https://docs.vllm.ai/en/latest/design/arch_overview.html " "for the design and update the model class accordingly.") - logger.warning(msg) + warnings.warn(msg, DeprecationWarning, stacklevel=2) + logger.warning( "Trying to guess the arguments for old-style model class %s", model_class, diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py index b95c0b7cd0612..1975f1e53e506 100644 --- a/vllm/model_executor/model_loader/utils.py +++ b/vllm/model_executor/model_loader/utils.py @@ -7,6 +7,7 @@ from vllm.config import ModelConfig from vllm.model_executor.models import ModelRegistry +from vllm.model_executor.models.adapters import for_embedding @contextlib.contextmanager @@ -32,7 +33,11 @@ def get_model_architecture( and "MixtralForCausalLM" in architectures): architectures = ["QuantMixtralForCausalLM"] - return ModelRegistry.resolve_model_cls(architectures) + model_cls, arch = ModelRegistry.resolve_model_cls(architectures) + if model_config.task == "embedding": + model_cls = for_embedding(model_cls) + + return model_cls, arch def get_architecture_class_name(model_config: ModelConfig) -> str: diff --git a/vllm/model_executor/models/adapters.py b/vllm/model_executor/models/adapters.py new file mode 100644 index 0000000000000..b529dcb5dd3b8 --- /dev/null +++ b/vllm/model_executor/models/adapters.py @@ -0,0 +1,94 @@ +from collections.abc import Iterable +from typing import Any, TypeVar + +import torch +import torch.nn as nn + +from .interfaces_base import VllmModelForEmbedding, is_embedding_model + +_T = TypeVar("_T", bound=type[nn.Module]) + + +def for_embedding(cls: _T) -> _T: + """Subclass an existing vLLM model to support embeddings.""" + # Avoid modifying existing embedding models + if is_embedding_model(cls): + return cls + + # Lazy import + from vllm.config import VllmConfig + from vllm.model_executor.layers.pooler import (Pooler, PoolerOutput, + PoolingType) + from vllm.model_executor.pooling_metadata import PoolingMetadata + + from .utils import AutoWeightsLoader, WeightsMapper + + class ModelForEmbedding(cls, VllmModelForEmbedding): + def __init__( + self, + *, + vllm_config: "VllmConfig", + prefix: str = "", + **kwargs: Any, + ) -> None: + super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs) + + # These are not used in embedding models + if hasattr(self, "lm_head"): + del self.lm_head + if hasattr(self, "logits_processor"): + del self.logits_processor + + pooler_config = vllm_config.model_config.pooler_config + assert pooler_config is not None + + # If the model already defines a pooler instance, don't overwrite it + if not getattr(self, "_pooler", None): + pooler = Pooler.from_config_with_defaults( + pooler_config, + pooling_type=PoolingType.LAST, + normalize=True, + softmax=False, + ) + assert pooler is not None + self._pooler = pooler + + def pooler( + self, + hidden_states: torch.Tensor, + pooling_metadata: PoolingMetadata, + ) -> PoolerOutput: + return self._pooler(hidden_states, pooling_metadata) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): + # We have deleted this attribute, so don't load it + weights = ((name, data) for name, data in weights + if not name.startswith("lm_head.")) + + + # If `*ForCausalLM` defines `load_weights` on the inner model + # and there are no other inner modules with parameters, + # we support loading from both `*Model` and `*ForCausalLM` + if (hasattr(self, "model") and hasattr(self.model, "load_weights") + and all( + name == "model" or all(False for _ in 
child.parameters()) + for name, child in self.named_children() + )): + mapper = WeightsMapper(orig_to_new_prefix={"model.": ""}) + weights = mapper.apply(weights) + + self.model.load_weights(weights) + # For most other models + elif hasattr(cls, "load_weights"): + cls.load_weights(self, weights) # type: ignore + # Fallback + else: + loader = AutoWeightsLoader(self) + loader.load_weights(weights) + + ModelForEmbedding.__name__ = cls.__name__ \ + .removesuffix("ForCausalLM") \ + .removesuffix("ForConditionalGeneration") + "ForEmbedding" + + return ModelForEmbedding # type: ignore + \ No newline at end of file diff --git a/vllm/model_executor/models/gemma2.py b/vllm/model_executor/models/gemma2.py index d35fcb012e166..4664aa53ea092 100644 --- a/vllm/model_executor/models/gemma2.py +++ b/vllm/model_executor/models/gemma2.py @@ -30,19 +30,17 @@ QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.pooler import Pooler, PoolingType from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors, PoolerOutput +from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP -from .utils import (AutoWeightsLoader, WeightsMapper, extract_layer_index, +from .utils import (AutoWeightsLoader, extract_layer_index, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) @@ -455,55 +453,3 @@ def load_weights(self, weights: Iterable[Tuple[str, if self.config.tie_word_embeddings else None), ) return loader.load_weights(weights) - - -class Gemma2EmbeddingModel(nn.Module, SupportsPP): - """ - A model that uses Gemma2 with additional embedding functionalities. - - This class encapsulates the Gemma2Model and provides an interface for - embedding operations and customized pooling functions. - - Attributes: - model: An instance of Gemma2Model used for forward operations. - _pooler: An instance of Pooler used for pooling operations. 
- """ - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__() - - self.model = Gemma2Model(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) - self._pooler = Pooler.from_config_with_defaults( - vllm_config.model_config.pooler_config, - pooling_type=PoolingType.LAST, - normalize=True, - softmax=False) - self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) - - def forward( - self, - input_ids: Optional[torch.Tensor], - positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: - return self.model(input_ids, positions, kv_caches, attn_metadata, - intermediate_tensors, inputs_embeds) - - def pooler( - self, - hidden_states: torch.Tensor, - pooling_metadata: PoolingMetadata, - ) -> Optional[PoolerOutput]: - return self._pooler(hidden_states, pooling_metadata) - - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): - hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""}) - weights = hf_to_vllm_mapper.apply(weights) - weights = ((name, data) for name, data in weights - if not name.startswith("lm_head.")) - self.model.load_weights(weights) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index fe94bb352961b..4daaf5ff3d37e 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -627,6 +627,7 @@ def permute(w: torch.Tensor, n_heads: int): return name, loaded_weight +# TODO: Remove this once reward modeling is separated from LlamaForCausalLM class LlamaEmbeddingModel(nn.Module, SupportsLoRA, SupportsPP): """ A model that uses Llama with additional embedding functionalities. 
diff --git a/vllm/model_executor/models/llava_next.py b/vllm/model_executor/models/llava_next.py index e113f5862830d..42c190811eba4 100644 --- a/vllm/model_executor/models/llava_next.py +++ b/vllm/model_executor/models/llava_next.py @@ -14,13 +14,11 @@ from vllm.config import VllmConfig from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, InputContext) -from vllm.model_executor.layers.pooler import Pooler, PoolingType from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler -from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import NestedTensors -from vllm.sequence import IntermediateTensors, PoolerOutput +from vllm.sequence import IntermediateTensors from vllm.utils import is_list_of from .clip import (CLIPVisionModel, dummy_image_for_clip, @@ -286,7 +284,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config - pooler_config = vllm_config.model_config.pooler_config multimodal_config = vllm_config.model_config.multimodal_config vision_feature_layer = config.vision_feature_layer @@ -325,13 +322,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: vllm_config=vllm_config, prefix=maybe_prefix(prefix, "language_model")) - # The same model class supports both language generation and embedding - # because the architecture name is the same - self._pooler = Pooler.from_config_with_defaults( - pooler_config, - pooling_type=PoolingType.LAST, - normalize=True, - softmax=False) self.make_empty_intermediate_tensors = ( self.language_model.make_empty_intermediate_tensors) @@ -678,13 +668,6 @@ def sample( ) -> Optional[SamplerOutput]: return self.language_model.sample(logits, sampling_metadata) - def pooler( - self, - hidden_states: torch.Tensor, - pooling_metadata: PoolingMetadata, - ) -> Optional[PoolerOutput]: - return self._pooler(hidden_states, pooling_metadata) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader(self) diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index 4cb874a13e0c1..a725590914533 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -29,19 +29,17 @@ from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, InputContext, token_inputs) from vllm.logger import init_logger -from vllm.model_executor.layers.pooler import Pooler, PoolingType from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) from vllm.model_executor.models.clip import CLIPVisionModel from vllm.model_executor.models.llama import LlamaForCausalLM -from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import NestedTensors, PlaceholderRange from vllm.multimodal.utils import cached_get_tokenizer, repeat_and_pad_token -from vllm.sequence import IntermediateTensors, PoolerOutput +from vllm.sequence import IntermediateTensors from vllm.utils import is_list_of from .clip import dummy_image_for_clip, 
dummy_seq_data_for_clip @@ -536,7 +534,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config - pooler_config = vllm_config.model_config.pooler_config multimodal_config = vllm_config.model_config.multimodal_config self.config = config self.multimodal_config = multimodal_config @@ -561,13 +558,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.language_model = LlamaForCausalLM(vllm_config=vllm_config, prefix="") - # The same model class supports both language generation and embedding - # because the architecture name is the same - self._pooler = Pooler.from_config_with_defaults( - pooler_config, - pooling_type=PoolingType.LAST, - normalize=True, - softmax=False) self.make_empty_intermediate_tensors = ( self.language_model.make_empty_intermediate_tensors) @@ -739,13 +729,6 @@ def sample( ) -> Optional[SamplerOutput]: return self.language_model.sample(logits, sampling_metadata) - def pooler( - self, - hidden_states: torch.Tensor, - pooling_metadata: PoolingMetadata, - ) -> Optional[PoolerOutput]: - return self._pooler(hidden_states, pooling_metadata) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: hf_to_vllm_mapper = WeightsMapper( diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index 87943e53d861c..7d4cc4b69e614 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -31,6 +31,7 @@ from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.logger import init_logger from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, @@ -55,6 +56,8 @@ make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) +logger = init_logger(__name__) + class Qwen2MLP(nn.Module): @@ -433,7 +436,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config lora_config = vllm_config.lora_config - pooler_config = vllm_config.model_config.pooler_config self.config = config self.lora_config = lora_config @@ -454,14 +456,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = get_sampler() - # The same model class supports both language generation and embedding - # because the architecture name is the same - self._pooler = Pooler.from_config_with_defaults( - pooler_config, - pooling_type=PoolingType.LAST, - normalize=True, - softmax=False) - self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) @@ -499,13 +493,6 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def pooler( - self, - hidden_states: torch.Tensor, - pooling_metadata: PoolingMetadata, - ) -> Optional[PoolerOutput]: - return self._pooler(hidden_states, pooling_metadata) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader( @@ -553,6 +540,15 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.model = Qwen2Model(vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")) + # TODO: Replace this model 
class with for_embedding(Qwen2ForCausalLM), + # after changing the default pooling method + if pooler_config.pooling_type is None: + logger.warning( + "This embedding model will default to last-token pooling in " + "an upcoming version. To avoid breaking changes, you should " + "pass `--override-pooler-config '{\"pooling_type\": \"MEAN\"}'`" + " explicitly.") + self._pooler = Pooler.from_config_with_defaults( pooler_config, pooling_type=PoolingType.MEAN, diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index 7956a98b21569..27175dbae7483 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -50,7 +50,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.pooler import Pooler, PoolingType from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization.gptq import GPTQConfig from vllm.model_executor.layers.quantization.gptq_marlin import ( @@ -59,14 +58,13 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.qwen2 import Qwen2Model -from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.image import cached_get_image_processor from vllm.multimodal.inputs import (MultiModalData, MultiModalDataDict, MultiModalKwargs, NestedTensors) from vllm.multimodal.utils import cached_get_tokenizer from vllm.platforms import _Backend -from vllm.sequence import IntermediateTensors, PoolerOutput, SequenceData +from vllm.sequence import IntermediateTensors, SequenceData from vllm.transformers_utils.config import uses_mrope from vllm.transformers_utils.processor import cached_get_processor @@ -1070,7 +1068,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config - pooler_config = vllm_config.model_config.pooler_config multimodal_config = vllm_config.model_config.multimodal_config assert not cache_config.enable_prefix_caching, \ "Qwen2-VL currently does not support prefix caching" @@ -1102,11 +1099,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = get_sampler() - self._pooler = Pooler.from_config_with_defaults( - pooler_config, - pooling_type=PoolingType.LAST, - normalize=True, - softmax=False) + self.make_empty_intermediate_tensors = ( make_empty_intermediate_tensors_factory( ["hidden_states", "residual"], config.hidden_size)) @@ -1361,13 +1354,6 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def pooler( - self, - hidden_states: torch.Tensor, - pooling_metadata: PoolingMetadata, - ) -> Optional[PoolerOutput]: - return self._pooler(hidden_states, pooling_metadata) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: stacked_params_mapping = [ diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index c400c7d59828c..8b606f0d2844e 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -20,6 +20,7 @@ from vllm.logger import init_logger 
from vllm.platforms import current_platform +from .adapters import for_embedding from .interfaces import (has_inner_state, is_attention_free, supports_cross_encoding, supports_multimodal, supports_pp) @@ -107,7 +108,7 @@ "RobertaForMaskedLM": ("roberta", "RobertaEmbeddingModel"), "XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"), "DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"), - "Gemma2Model": ("gemma2", "Gemma2EmbeddingModel"), + "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"), "GlmForCausalLM": ("glm", "GlmForCausalLM"), "LlamaModel": ("llama", "LlamaEmbeddingModel"), **{ @@ -218,9 +219,18 @@ class _ModelInfo: @staticmethod def from_model_cls(model: Type[nn.Module]) -> "_ModelInfo": + is_embedding_model_ = is_embedding_model(model) + if not is_embedding_model_: + try: + for_embedding(model) + except Exception: + pass + else: + is_embedding_model_ = True + return _ModelInfo( is_text_generation_model=is_text_generation_model(model), - is_embedding_model=is_embedding_model(model), + is_embedding_model=is_embedding_model_, supports_cross_encoding=supports_cross_encoding(model), supports_multimodal=supports_multimodal(model), supports_pp=supports_pp(model), diff --git a/vllm/multimodal/base.py b/vllm/multimodal/base.py index 6eec660e42ac4..bbb8fb4bc1cd1 100644 --- a/vllm/multimodal/base.py +++ b/vllm/multimodal/base.py @@ -7,7 +7,7 @@ from vllm.inputs import InputContext from vllm.logger import init_logger -from vllm.utils import (get_allowed_kwarg_only_overrides, +from vllm.utils import (ClassRegistry, get_allowed_kwarg_only_overrides, resolve_mm_processor_kwargs) if TYPE_CHECKING: @@ -54,8 +54,8 @@ class MultiModalPlugin(ABC): """ def __init__(self) -> None: - self._input_mappers: Dict[Type[nn.Module], MultiModalInputMapper] = {} - self._max_mm_tokens: Dict[Type[nn.Module], MultiModalTokensCalc] = {} + self._input_mappers = ClassRegistry[nn.Module, MultiModalInputMapper]() + self._max_mm_tokens = ClassRegistry[nn.Module, MultiModalTokensCalc]() @abstractmethod def get_data_key(self) -> str: diff --git a/vllm/multimodal/registry.py b/vllm/multimodal/registry.py index b992442d3b314..b73daee98bd80 100644 --- a/vllm/multimodal/registry.py +++ b/vllm/multimodal/registry.py @@ -9,6 +9,7 @@ from vllm.inputs import InputProcessingContext from vllm.logger import init_logger from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.utils import ClassRegistry from .audio import AudioPlugin from .base import MultiModalInputMapper, MultiModalPlugin, MultiModalTokensCalc @@ -62,8 +63,8 @@ def __init__( plugins: Sequence[MultiModalPlugin] = DEFAULT_PLUGINS) -> None: self._plugins = {p.get_data_key(): p for p in plugins} - self._processor_factories: Dict[Type[nn.Module], - MultiModalProcessorFactory] = {} + self._processor_factories = ClassRegistry[nn.Module, + MultiModalProcessorFactory]() # This is used for non-multimodal models self._disabled_limits_per_plugin = {k: 0 for k in self._plugins} diff --git a/vllm/utils.py b/vllm/utils.py index 6f7a6f8c54e47..83fbefd755870 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -20,7 +20,7 @@ import warnings import weakref from asyncio import FIRST_COMPLETED, AbstractEventLoop, Future, Task -from collections import defaultdict +from collections import UserDict, defaultdict from collections.abc import Iterable, Mapping from functools import lru_cache, partial, wraps from platform import uname @@ -1517,13 +1517,13 @@ def value(self): # Adapted from: https://stackoverflow.com/a/47212782/5082708 -class LazyDict(Mapping, Generic[T]): +class 
LazyDict(Mapping[str, T], Generic[T]): def __init__(self, factory: Dict[str, Callable[[], T]]): self._factory = factory self._dict: Dict[str, T] = {} - def __getitem__(self, key) -> T: + def __getitem__(self, key: str) -> T: if key not in self._dict: if key not in self._factory: raise KeyError(key) @@ -1540,6 +1540,19 @@ def __len__(self): return len(self._factory) +class ClassRegistry(UserDict[type[T], _V]): + + def __getitem__(self, key: type[T]) -> _V: + for cls in key.mro(): + if cls in self.data: + return self.data[cls] + + raise KeyError(key) + + def __contains__(self, key: type[T]) -> bool: + return any(cls in self.data for cls in key.mro()) + + def weak_ref_tensor(tensor: torch.Tensor) -> torch.Tensor: """ Create a weak reference to a tensor.
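
The ClassRegistry added at the end of vllm/utils.py resolves entries along a class's MRO, which is what lets registrations made against a base model class also apply to dynamically created subclasses such as those returned by for_embedding(). A self-contained sketch of that lookup behaviour follows; the class body is copied from this patch with the generic type parameters dropped, and the toy Base/Derived hierarchy is invented for the example.

    from collections import UserDict

    class ClassRegistry(UserDict):
        def __getitem__(self, key):
            # Walk the MRO so a registration on a parent class is found.
            for cls in key.mro():
                if cls in self.data:
                    return self.data[cls]
            raise KeyError(key)

        def __contains__(self, key):
            return any(cls in self.data for cls in key.mro())

    class Base: ...
    class Derived(Base): ...  # e.g. the subclass created by for_embedding()

    registry = ClassRegistry()
    registry[Base] = "factory-for-base"

    # A plain dict keyed on the exact class would miss this lookup;
    # ClassRegistry falls back along the MRO.
    assert Derived in registry
    print(registry[Derived])  # factory-for-base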