Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Model] Automatic conversion of classification and reward models #11469

Merged
merged 7 commits into from
Dec 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 12 additions & 10 deletions docs/source/models/supported_models.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ llm = LLM(model=..., task="generate") # Name or path of your model
output = llm.generate("Hello, my name is")
print(output)

# For pooling models (task={embed,classify,reward}) only
# For pooling models (task={embed,classify,reward,score}) only
llm = LLM(model=..., task="embed") # Name or path of your model
output = llm.encode("Hello, my name is")
print(output)
Expand Down Expand Up @@ -59,7 +59,7 @@ llm = LLM(model=..., revision=..., task=..., trust_remote_code=True)
output = llm.generate("Hello, my name is")
print(output)

# For pooling models (task={embed,classify,reward}) only
# For pooling models (task={embed,classify,reward,score}) only
output = llm.encode("Hello, my name is")
print(output)
```
Expand Down Expand Up @@ -369,14 +369,6 @@ you should explicitly specify the task type to ensure that the model is used in

#### Text Embedding (`--task embed`)

Any text generation model can be converted into an embedding model by passing {code}`--task embed`.

```{note}
To get the best results, you should use pooling models that are specifically trained as such.
```

The following table lists those that are tested in vLLM.

```{eval-rst}
.. list-table::
:widths: 25 25 50 5 5
Expand Down Expand Up @@ -437,6 +429,10 @@ On the other hand, its 1.5B variant ({code}`Alibaba-NLP/gte-Qwen2-1.5B-instruct`
despite being described otherwise on its model card.
```

If your model is not in the above list, we will try to automatically convert the model using
{func}`vllm.model_executor.models.adapters.as_embedding_model`. By default, the embeddings
of the whole prompt are extracted from the normalized hidden state corresponding to the last token.

#### Reward Modeling (`--task reward`)

```{eval-rst}
Expand All @@ -461,6 +457,9 @@ despite being described otherwise on its model card.
- ✅︎
```

If your model is not in the above list, we will try to automatically convert the model using
{func}`vllm.model_executor.models.adapters.as_reward_model`. By default, we return the hidden states of each token directly.

```{important}
For process-supervised reward models such as {code}`peiyi9979/math-shepherd-mistral-7b-prm`, the pooling config should be set explicitly,
e.g.: {code}`--override-pooler-config '{"pooling_type": "STEP", "step_tag_id": 123, "returned_token_ids": [456, 789]}'`.
Expand Down Expand Up @@ -490,6 +489,9 @@ e.g.: {code}`--override-pooler-config '{"pooling_type": "STEP", "step_tag_id": 1
- ✅︎
```

If your model is not in the above list, we will try to automatically convert the model using
{func}`vllm.model_executor.models.adapters.as_classification_model`. By default, the class probabilities are extracted from the softmaxed hidden state corresponding to the last token.

#### Sentence Pair Scoring (`--task score`)

```{eval-rst}
Expand Down
5 changes: 1 addition & 4 deletions tests/models/embedding/language/test_cls_models.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,4 @@
"""Compare the outputs of HF and vLLM when using greedy sampling.

This test only tests small models. Big models such as 7B should be tested from
test_big_models.py because it could use a larger instance to run tests.
"""Compare the classification outputs of HF and vLLM models.

Run `pytest tests/models/embedding/language/test_cls_models.py`.
"""
Expand Down
4 changes: 2 additions & 2 deletions tests/models/embedding/language/test_scoring.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Compare the embedding outputs of HF and vLLM models.
"""Compare the scoring outputs of HF and vLLM models.

Run `pytest tests/models/embedding/language/test_embedding.py`.
Run `pytest tests/models/embedding/language/test_scoring.py`.
"""
import math

Expand Down
11 changes: 7 additions & 4 deletions tests/models/test_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
from vllm.model_executor.models import (is_pooling_model,
is_text_generation_model,
supports_multimodal)
from vllm.model_executor.models.adapters import as_embedding_model
from vllm.model_executor.models.adapters import (as_classification_model,
as_embedding_model,
as_reward_model)
from vllm.model_executor.models.registry import (_MULTIMODAL_MODELS,
_SPECULATIVE_DECODING_MODELS,
_TEXT_GENERATION_MODELS,
Expand All @@ -29,9 +31,10 @@ def test_registry_imports(model_arch):
or model_arch in _MULTIMODAL_MODELS):
assert is_text_generation_model(model_cls)

# All vLLM models should be convertible to an embedding model
embed_model = as_embedding_model(model_cls)
assert is_pooling_model(embed_model)
# All vLLM models should be convertible to a pooling model
assert is_pooling_model(as_classification_model(model_cls))
assert is_pooling_model(as_embedding_model(model_cls))
assert is_pooling_model(as_reward_model(model_cls))

if model_arch in _MULTIMODAL_MODELS:
assert supports_multimodal(model_cls)
Expand Down
10 changes: 8 additions & 2 deletions vllm/model_executor/model_loader/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@

from vllm.config import ModelConfig
from vllm.model_executor.models import ModelRegistry
from vllm.model_executor.models.adapters import as_embedding_model
from vllm.model_executor.models.adapters import (as_classification_model,
as_embedding_model,
as_reward_model)


@contextlib.contextmanager
Expand Down Expand Up @@ -35,8 +37,12 @@ def get_model_architecture(
architectures = ["QuantMixtralForCausalLM"]

model_cls, arch = ModelRegistry.resolve_model_cls(architectures)
if model_config.runner_type == "pooling":
if model_config.task == "embed":
model_cls = as_embedding_model(model_cls)
elif model_config.task == "classify":
model_cls = as_classification_model(model_cls)
elif model_config.task == "reward":
model_cls = as_reward_model(model_cls)

return model_cls, arch

Expand Down
190 changes: 170 additions & 20 deletions vllm/model_executor/models/adapters.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,48 @@
from collections.abc import Iterable
from typing import Any, TypeVar
from typing import TYPE_CHECKING, Any, Optional, TypeVar

import torch
import torch.nn as nn

from .interfaces_base import VllmModelForPooling, is_pooling_model

if TYPE_CHECKING:
from vllm.model_executor.layers.pooler import PoolingType

_T = TypeVar("_T", bound=type[nn.Module])

_GENERATE_SUFFIXES = [
"ForCausalLM",
"ForConditionalGeneration",
"ChatModel",
"LMHeadModel",
]

def as_embedding_model(cls: _T) -> _T:
"""Subclass an existing vLLM model to support embeddings."""
# Avoid modifying existing embedding models
if is_pooling_model(cls):
return cls

def _get_pooling_model_name(orig_model_name: str, pooling_suffix: str) -> str:
    """Derive the class name to use for a converted pooling model.

    Each entry of ``_GENERATE_SUFFIXES`` is stripped from the end of
    ``orig_model_name`` in list order (at most once per entry), and
    ``pooling_suffix`` is then appended to whatever remains.
    """
    base_name = orig_model_name

    for suffix in _GENERATE_SUFFIXES:
        # Same effect as ``str.removesuffix``: drop the suffix only when
        # it is non-empty and actually terminates the current name.
        if suffix and base_name.endswith(suffix):
            base_name = base_name[:-len(suffix)]

    return base_name + pooling_suffix


def _create_pooling_model_cls(
orig_cls: _T,
*,
default_pooling_type: "PoolingType",
default_normalize: bool,
default_softmax: bool,
) -> _T:
# Lazy import
from vllm.config import VllmConfig
from vllm.model_executor.layers.pooler import (Pooler, PoolerOutput,
PoolingType)
from vllm.model_executor.layers.pooler import Pooler, PoolerOutput
from vllm.model_executor.pooling_metadata import PoolingMetadata

from .utils import AutoWeightsLoader, WeightsMapper

class ModelForEmbedding(cls, VllmModelForPooling):
class ModelForPooling(orig_cls, VllmModelForPooling):

def __init__(
self,
Expand All @@ -34,7 +53,7 @@ def __init__(
) -> None:
super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs)

# These are not used in embedding models
# These are not used in pooling models
for attr in ("lm_head", "logits_processor"):
if hasattr(self, attr):
delattr(self, attr)
Expand All @@ -46,9 +65,9 @@ def __init__(
if not getattr(self, "_pooler", None):
self._pooler = Pooler.from_config_with_defaults(
pooler_config,
pooling_type=PoolingType.LAST,
normalize=True,
softmax=False,
pooling_type=default_pooling_type,
normalize=default_normalize,
softmax=default_softmax,
)

def pooler(
Expand Down Expand Up @@ -82,17 +101,148 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
return

# For most other models
if hasattr(cls, "load_weights"):
cls.load_weights(self, weights) # type: ignore
if hasattr(orig_cls, "load_weights"):
orig_cls.load_weights(self, weights) # type: ignore
# Fallback
else:
loader = AutoWeightsLoader(self)
loader.load_weights(weights)

ModelForEmbedding.__name__ = cls.__name__ \
.removesuffix("ForCausalLM") \
.removesuffix("ForConditionalGeneration") \
.removesuffix("ChatModel") \
.removesuffix("LMHeadModel") + "ForEmbedding"
return ModelForPooling # type: ignore


def as_embedding_model(cls: _T) -> _T:
    """
    Derive an embedding-capable subclass from an existing vLLM model class.

    Unless overridden, the embedding of the whole prompt is taken from the
    normalized hidden state of the last token.

    Note:
        The original model is assumed to have no additional layers on top;
        write a dedicated model class when that does not hold.
    """
    # Anything that is already a pooling model is handed back unchanged
    if is_pooling_model(cls):
        return cls

    # Deferred import (kept local to this function)
    from vllm.model_executor.layers.pooler import PoolingType

    embedding_cls = _create_pooling_model_cls(
        cls,
        default_pooling_type=PoolingType.LAST,
        default_normalize=True,
        default_softmax=False,
    )

    # Rename the generated class after the wrapped model
    embedding_name = _get_pooling_model_name(cls.__name__, "ForEmbedding")
    embedding_cls.__name__ = embedding_name

    return embedding_cls  # type: ignore


def as_classification_model(cls: _T) -> _T:
    """
    Derive a classification-capable subclass from an existing vLLM model.

    Unless overridden, class probabilities are obtained by applying softmax
    to the hidden state of the last token.

    Note:
        The classification head is assumed to be one linear layer exposed
        as the `score` attribute of the top-level model; write a dedicated
        model class when that does not hold.
    """
    # Anything that is already a pooling model is handed back unchanged
    if is_pooling_model(cls):
        return cls

    # Deferred imports (kept local to this function)
    from vllm.attention import AttentionMetadata
    from vllm.config import VllmConfig
    from vllm.model_executor.layers.linear import RowParallelLinear
    from vllm.model_executor.layers.pooler import PoolingType
    from vllm.sequence import IntermediateTensors

    from .utils import maybe_prefix

    pooling_base = _create_pooling_model_cls(
        cls,
        default_pooling_type=PoolingType.LAST,
        default_normalize=False,
        default_softmax=True,
    )

    class ModelForClassification(pooling_base):

        def __init__(
            self,
            *,
            vllm_config: "VllmConfig",
            prefix: str = "",
            **kwargs: Any,
        ) -> None:
            super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs)

            hf_config = vllm_config.model_config.hf_config
            quantization = vllm_config.quant_config

            # Bias-free linear head: hidden_size -> num_labels
            self.score = RowParallelLinear(
                hf_config.hidden_size,
                hf_config.num_labels,
                quant_config=quantization,
                input_is_parallel=False,
                bias=False,
                prefix=maybe_prefix(prefix, "score"),
            )

        def forward(
            self,
            input_ids: torch.Tensor,
            positions: torch.Tensor,
            kv_caches: list[torch.Tensor],
            attn_metadata: AttentionMetadata,
            intermediate_tensors: Optional[IntermediateTensors] = None,
            inputs_embeds: Optional[torch.Tensor] = None,
        ) -> torch.Tensor:
            # Run the wrapped model, then feed its output through the head
            backbone_output = super().forward(
                input_ids,
                positions,
                kv_caches,
                attn_metadata,
                intermediate_tensors,
                inputs_embeds,
            )
            # The head returns a pair; only the first element is used
            class_logits, _unused = self.score(backbone_output)
            return class_logits

    # Rename the generated class after the wrapped model
    classification_name = _get_pooling_model_name(cls.__name__,
                                                  "ForClassification")
    ModelForClassification.__name__ = classification_name

    return ModelForClassification  # type: ignore


def as_reward_model(cls: _T) -> _T:
    """
    Derive a reward-modeling subclass from an existing vLLM model class.

    Unless overridden, the hidden states of every token are returned
    directly.

    Note:
        The original model is assumed to have no additional layers on top;
        write a dedicated model class when that does not hold.
    """
    # Anything that is already a pooling model is handed back unchanged
    if is_pooling_model(cls):
        return cls

    # Deferred import (kept local to this function)
    from vllm.model_executor.layers.pooler import PoolingType

    reward_cls = _create_pooling_model_cls(
        cls,
        default_pooling_type=PoolingType.ALL,
        default_normalize=False,
        default_softmax=False,
    )

    # Rename the generated class after the wrapped model
    reward_name = _get_pooling_model_name(cls.__name__, "ForReward")
    reward_cls.__name__ = reward_name

    return reward_cls  # type: ignore
4 changes: 2 additions & 2 deletions vllm/model_executor/models/qwen2.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,8 +545,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.model = Qwen2Model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))

# TODO: Replace this model class with for_embedding(Qwen2ForCausalLM),
# after changing the default pooling method
# TODO: Replace this model class with as_embedding_model(
# Qwen2ForCausalLM) after changing the default pooling method
if pooler_config.pooling_type is None:
logger.warning(
"This embedding model will default to last-token pooling in "
Expand Down
Loading
Loading