Commit 954dd06: Replace embedding models with generic adapter

Signed-off-by: DarkLight1337 <[email protected]>
Committed by DarkLight1337 on Nov 29, 2024 (parent: fa6ecb9)

Showing 16 changed files with 175 additions and 145 deletions.
docs/source/models/supported_models.rst (5 additions, 1 deletion)

@@ -357,7 +357,7 @@ Text Embedding
       - ✅︎
     * - :code:`Qwen2Model`, :code:`Qwen2ForCausalLM`
       - Qwen2-based
-      - :code:`ssmits/Qwen2-7B-Instruct-embed-base`, :code:`Alibaba-NLP/gte-Qwen2-7B-instruct` (see note), etc.
+      - :code:`ssmits/Qwen2-7B-Instruct-embed-base` (see note), :code:`Alibaba-NLP/gte-Qwen2-7B-instruct` (see note), etc.
       - ✅︎
       - ✅︎
     * - :code:`RobertaModel`, :code:`RobertaForMaskedLM`
@@ -378,6 +378,10 @@ Text Embedding
 .. tip::
     You can override the model's pooling method by passing :code:`--override-pooler-config`.
 
+.. note::
+    :code:`ssmits/Qwen2-7B-Instruct-embed-base` has an improperly defined Sentence Transformers config.
+    You should manually set mean pooling by passing :code:`--override-pooler-config '{"pooling_type": "MEAN"}'`.
+
 .. note::
     Unlike base Qwen2, :code:`Alibaba-NLP/gte-Qwen2-7B-instruct` uses bi-directional attention.
     You can set :code:`--hf-overrides '{"is_causal": false}'` to change the attention mask accordingly.
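For reference, the same overrides can be applied in offline inference. A minimal sketch, assuming the `LLM` constructor forwards `override_pooler_config` and `hf_overrides` to the engine the same way the CLI flags do (the test changes below pass exactly these keyword arguments):

    from vllm import LLM
    from vllm.config import PoolerConfig

    # ssmits/Qwen2-7B-Instruct-embed-base: force mean pooling.
    llm = LLM(
        model="ssmits/Qwen2-7B-Instruct-embed-base",
        task="embedding",
        override_pooler_config=PoolerConfig(pooling_type="MEAN"),
    )

    # Alibaba-NLP/gte-Qwen2-7B-instruct: use bi-directional attention.
    llm = LLM(
        model="Alibaba-NLP/gte-Qwen2-7B-instruct",
        task="embedding",
        hf_overrides={"is_causal": False},
    )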
tests/models/embedding/language/test_embedding.py (5 additions)

@@ -4,6 +4,8 @@
 """
 import pytest
 
+from vllm.config import PoolerConfig
+
 from ..utils import check_embeddings_close
 
 
@@ -33,6 +35,9 @@ def test_models(
     dtype: str,
 ) -> None:
     vllm_extra_kwargs = {}
+    if model == "ssmits/Qwen2-7B-Instruct-embed-base":
+        vllm_extra_kwargs["override_pooler_config"] = \
+            PoolerConfig(pooling_type="MEAN")
     if model == "Alibaba-NLP/gte-Qwen2-7B-instruct":
         vllm_extra_kwargs["hf_overrides"] = {"is_causal": False}
vllm/inputs/registry.py (8 additions, 8 deletions)

@@ -11,8 +11,8 @@
 from vllm.logger import init_logger
 from vllm.transformers_utils.processor import cached_get_processor
 from vllm.transformers_utils.tokenizer import AnyTokenizer
-from vllm.utils import (get_allowed_kwarg_only_overrides, print_warning_once,
-                        resolve_mm_processor_kwargs)
+from vllm.utils import (ClassRegistry, get_allowed_kwarg_only_overrides,
+                        print_warning_once, resolve_mm_processor_kwargs)
 
 from .data import ProcessorInputs, SingletonInputs
 from .parse import is_encoder_decoder_inputs
@@ -136,12 +136,12 @@ class InputRegistry:
     """
 
     def __init__(self) -> None:
-        self._dummy_factories_by_model_type: Dict[Type[nn.Module],
-                                                  DummyDataFactory] = {}
-        self._dummy_encoder_factories_by_model_type: Dict[
-            Type[nn.Module], DummyDataFactory] = {}
-        self._input_processors_by_model_type: Dict[Type[nn.Module],
-                                                   InputProcessor] = {}
+        self._dummy_factories_by_model_type = \
+            ClassRegistry[nn.Module, DummyDataFactory]()
+        self._dummy_encoder_factories_by_model_type = \
+            ClassRegistry[nn.Module, DummyDataFactory]()
+        self._input_processors_by_model_type = \
+            ClassRegistry[nn.Module, InputProcessor]()
 
     def _default_dummy_data_factory(
         self,
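The switch from plain `Dict[Type[nn.Module], ...]` to `ClassRegistry` matters for this commit: `for_embedding` (added below) creates subclasses dynamically, and an exact-type dict lookup would miss entries registered against the base class. A minimal sketch of the assumed lookup semantics, not the actual vLLM implementation:

    import torch.nn as nn

    class MroRegistry(dict):
        # Hypothetical class-keyed registry: lookups walk the key's MRO,
        # so subclasses resolve to their base class's entry.
        def find(self, cls: type):
            for base in cls.__mro__:
                if base in self:
                    return self[base]
            raise KeyError(cls)

    registry = MroRegistry()
    registry[nn.Linear] = "dummy-data-factory"

    class SubLinear(nn.Linear):  # stands in for a dynamically created subclass
        pass

    assert registry.find(SubLinear) == "dummy-data-factory"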
vllm/model_executor/model_loader/loader.py (4 additions, 1 deletion)

@@ -9,6 +9,7 @@
 import json
 import math
 import os
+import warnings
 from abc import ABC, abstractmethod
 from contextlib import contextmanager
 from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, cast
@@ -107,12 +108,14 @@ def _initialize_model(vllm_config: VllmConfig, prefix: str = "") -> nn.Module:
         # new-style model class
         with set_current_vllm_config(vllm_config):
             return model_class(vllm_config=vllm_config, prefix=prefix)
+
     msg = ("vLLM model class should accept `vllm_config` and `prefix` as "
            "input arguments. Possibly you have an old-style model class"
            " registered from out of tree and it is used for new vLLM version. "
            "Check https://docs.vllm.ai/en/latest/design/arch_overview.html "
            "for the design and update the model class accordingly.")
-    logger.warning(msg)
+    warnings.warn(msg, DeprecationWarning, stacklevel=2)
+
     logger.warning(
         "Trying to guess the arguments for old-style model class %s",
         model_class,
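Switching from `logger.warning(msg)` to `warnings.warn(msg, DeprecationWarning, stacklevel=2)` routes the deprecation through the standard `warnings` machinery, so downstream code can filter or escalate it. For example, to fail fast in CI:

    import warnings

    # Turn the old-style model class deprecation into a hard error.
    warnings.filterwarnings("error", category=DeprecationWarning)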
vllm/model_executor/model_loader/utils.py (6 additions, 1 deletion)

@@ -7,6 +7,7 @@
 
 from vllm.config import ModelConfig
 from vllm.model_executor.models import ModelRegistry
+from vllm.model_executor.models.adapters import for_embedding
 
 
 @contextlib.contextmanager
@@ -32,7 +33,11 @@ def get_model_architecture(
             and "MixtralForCausalLM" in architectures):
         architectures = ["QuantMixtralForCausalLM"]
 
-    return ModelRegistry.resolve_model_cls(architectures)
+    model_cls, arch = ModelRegistry.resolve_model_cls(architectures)
+    if model_config.task == "embedding":
+        model_cls = for_embedding(model_cls)
+
+    return model_cls, arch
 
 
 def get_architecture_class_name(model_config: ModelConfig) -> str:
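This is the core of the change: instead of registering hand-written `*EmbeddingModel` classes, the loader now derives an embedding variant from the resolved generation class whenever the task is embedding. A hedged sketch of the new resolution path (architecture name chosen for illustration):

    from vllm.model_executor.models import ModelRegistry
    from vllm.model_executor.models.adapters import for_embedding

    # Mirrors get_model_architecture() above for task == "embedding".
    model_cls, arch = ModelRegistry.resolve_model_cls(["Gemma2ForCausalLM"])
    embed_cls = for_embedding(model_cls)
    print(arch, embed_cls.__name__)  # Gemma2ForCausalLM Gemma2ForEmbedding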
vllm/model_executor/models/adapters.py (new file, 94 additions)

@@ -0,0 +1,94 @@
from collections.abc import Iterable
from typing import Any, TypeVar

import torch
import torch.nn as nn

from .interfaces_base import VllmModelForEmbedding, is_embedding_model

_T = TypeVar("_T", bound=type[nn.Module])


def for_embedding(cls: _T) -> _T:
    """Subclass an existing vLLM model to support embeddings."""
    # Avoid modifying existing embedding models
    if is_embedding_model(cls):
        return cls

    # Lazy import
    from vllm.config import VllmConfig
    from vllm.model_executor.layers.pooler import (Pooler, PoolerOutput,
                                                   PoolingType)
    from vllm.model_executor.pooling_metadata import PoolingMetadata

    from .utils import AutoWeightsLoader, WeightsMapper

    class ModelForEmbedding(cls, VllmModelForEmbedding):

        def __init__(
            self,
            *,
            vllm_config: "VllmConfig",
            prefix: str = "",
            **kwargs: Any,
        ) -> None:
            super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs)

            # These are not used in embedding models
            if hasattr(self, "lm_head"):
                del self.lm_head
            if hasattr(self, "logits_processor"):
                del self.logits_processor

            pooler_config = vllm_config.model_config.pooler_config
            assert pooler_config is not None

            # If the model already defines a pooler instance, don't overwrite it
            if not getattr(self, "_pooler", None):
                pooler = Pooler.from_config_with_defaults(
                    pooler_config,
                    pooling_type=PoolingType.LAST,
                    normalize=True,
                    softmax=False,
                )
                assert pooler is not None
                self._pooler = pooler

        def pooler(
            self,
            hidden_states: torch.Tensor,
            pooling_metadata: PoolingMetadata,
        ) -> PoolerOutput:
            return self._pooler(hidden_states, pooling_metadata)

        def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
            # We have deleted this attribute, so don't load it
            weights = ((name, data) for name, data in weights
                       if not name.startswith("lm_head."))

            # If `*ForCausalLM` defines `load_weights` on the inner model
            # and there are no other inner modules with parameters,
            # we support loading from both `*Model` and `*ForCausalLM`
            if (hasattr(self, "model")
                    and hasattr(self.model, "load_weights")
                    and all(
                        name == "model"
                        or all(False for _ in child.parameters())
                        for name, child in self.named_children())):
                mapper = WeightsMapper(orig_to_new_prefix={"model.": ""})
                weights = mapper.apply(weights)

                self.model.load_weights(weights)
            # For most other models
            elif hasattr(cls, "load_weights"):
                cls.load_weights(self, weights)  # type: ignore
            # Fallback
            else:
                loader = AutoWeightsLoader(self)
                loader.load_weights(weights)

    ModelForEmbedding.__name__ = cls.__name__ \
        .removesuffix("ForCausalLM") \
        .removesuffix("ForConditionalGeneration") + "ForEmbedding"

    return ModelForEmbedding  # type: ignore
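One subtlety in `load_weights` above: `all(False for _ in child.parameters())` is an idiom for "this child module has no parameters", because `all()` over an empty iterator is `True` and any yielded element makes it `False`. A self-contained demonstration:

    import torch.nn as nn

    act = nn.ReLU()        # no parameters
    lin = nn.Linear(4, 4)  # has weight and bias

    assert all(False for _ in act.parameters())      # empty iterator -> True
    assert not all(False for _ in lin.parameters())  # non-empty -> False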

vllm/model_executor/models/gemma2.py (2 additions, 56 deletions)

@@ -30,19 +30,17 @@
                                                QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
-from vllm.model_executor.layers.pooler import Pooler, PoolingType
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.pooling_metadata import PoolingMetadata
 from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.sequence import IntermediateTensors, PoolerOutput
+from vllm.sequence import IntermediateTensors
 
 from .interfaces import SupportsLoRA, SupportsPP
-from .utils import (AutoWeightsLoader, WeightsMapper, extract_layer_index,
+from .utils import (AutoWeightsLoader, extract_layer_index,
                     is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)
@@ -455,55 +453,3 @@ def load_weights(self, weights: Iterable[Tuple[str,
             if self.config.tie_word_embeddings else None),
         )
         return loader.load_weights(weights)
-
-
-class Gemma2EmbeddingModel(nn.Module, SupportsPP):
-    """
-    A model that uses Gemma2 with additional embedding functionalities.
-    This class encapsulates the Gemma2Model and provides an interface for
-    embedding operations and customized pooling functions.
-    Attributes:
-        model: An instance of Gemma2Model used for forward operations.
-        _pooler: An instance of Pooler used for pooling operations.
-    """
-
-    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
-        super().__init__()
-
-        self.model = Gemma2Model(vllm_config=vllm_config,
-                                 prefix=maybe_prefix(prefix, "model"))
-        self._pooler = Pooler.from_config_with_defaults(
-            vllm_config.model_config.pooler_config,
-            pooling_type=PoolingType.LAST,
-            normalize=True,
-            softmax=False)
-        self.make_empty_intermediate_tensors = (
-            self.model.make_empty_intermediate_tensors)
-
-    def forward(
-        self,
-        input_ids: Optional[torch.Tensor],
-        positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
-        intermediate_tensors: Optional[IntermediateTensors] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
-    ) -> Union[torch.Tensor, IntermediateTensors]:
-        return self.model(input_ids, positions, kv_caches, attn_metadata,
-                          intermediate_tensors, inputs_embeds)
-
-    def pooler(
-        self,
-        hidden_states: torch.Tensor,
-        pooling_metadata: PoolingMetadata,
-    ) -> Optional[PoolerOutput]:
-        return self._pooler(hidden_states, pooling_metadata)
-
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""})
-        weights = hf_to_vllm_mapper.apply(weights)
-        weights = ((name, data) for name, data in weights
-                   if not name.startswith("lm_head."))
-        self.model.load_weights(weights)
vllm/model_executor/models/llama.py (1 addition)

@@ -627,6 +627,7 @@ def permute(w: torch.Tensor, n_heads: int):
         return name, loaded_weight
 
 
+# TODO: Remove this once reward modeling is separated from LlamaForCausalLM
 class LlamaEmbeddingModel(nn.Module, SupportsLoRA, SupportsPP):
     """
     A model that uses Llama with additional embedding functionalities.
vllm/model_executor/models/llava_next.py (1 addition, 18 deletions)

@@ -14,13 +14,11 @@
 from vllm.config import VllmConfig
 from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
                          InputContext)
-from vllm.model_executor.layers.pooler import Pooler, PoolingType
 from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
-from vllm.model_executor.pooling_metadata import PoolingMetadata
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import NestedTensors
-from vllm.sequence import IntermediateTensors, PoolerOutput
+from vllm.sequence import IntermediateTensors
 from vllm.utils import is_list_of
 
 from .clip import (CLIPVisionModel, dummy_image_for_clip,
@@ -286,7 +284,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
         super().__init__()
         config = vllm_config.model_config.hf_config
         quant_config = vllm_config.quant_config
-        pooler_config = vllm_config.model_config.pooler_config
         multimodal_config = vllm_config.model_config.multimodal_config
 
         vision_feature_layer = config.vision_feature_layer
@@ -325,13 +322,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
             vllm_config=vllm_config,
             prefix=maybe_prefix(prefix, "language_model"))
 
-        # The same model class supports both language generation and embedding
-        # because the architecture name is the same
-        self._pooler = Pooler.from_config_with_defaults(
-            pooler_config,
-            pooling_type=PoolingType.LAST,
-            normalize=True,
-            softmax=False)
         self.make_empty_intermediate_tensors = (
             self.language_model.make_empty_intermediate_tensors)
 
@@ -678,13 +668,6 @@ def sample(
     ) -> Optional[SamplerOutput]:
         return self.language_model.sample(logits, sampling_metadata)
 
-    def pooler(
-        self,
-        hidden_states: torch.Tensor,
-        pooling_metadata: PoolingMetadata,
-    ) -> Optional[PoolerOutput]:
-        return self._pooler(hidden_states, pooling_metadata)
-
     def load_weights(self, weights: Iterable[Tuple[str,
                                                    torch.Tensor]]) -> Set[str]:
         loader = AutoWeightsLoader(self)
[Diffs for the remaining 7 changed files are not shown.]