Skip to content

Commit

Permalink
Support Roberta embedding models (#9387)
Browse files Browse the repository at this point in the history
Signed-off-by: Max de Bayser <[email protected]>
Signed-off-by: Flavia Beo <[email protected]>
Co-authored-by: Flavia Beo <[email protected]>
  • Loading branch information
maxdebayser and flaviabeo authored Nov 14, 2024
1 parent 1dbae03 commit 4a18fd1
Show file tree
Hide file tree
Showing 10 changed files with 202 additions and 14 deletions.
3 changes: 3 additions & 0 deletions csrc/attention/paged_attention_v1.cu
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,9 @@ void paged_attention_v1_launcher(
// NOTE(woosuk): To reduce the compilation time, we only compile for the
// head sizes that we use in the model. However, we can easily extend this
// to support any head size which is a multiple of 16.
case 32:
LAUNCH_PAGED_ATTENTION_V1(32);
break;
case 64:
LAUNCH_PAGED_ATTENTION_V1(64);
break;
Expand Down
3 changes: 3 additions & 0 deletions csrc/attention/paged_attention_v2.cu
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,9 @@ void paged_attention_v2_launcher(
// NOTE(woosuk): To reduce the compilation time, we only compile for the
// head sizes that we use in the model. However, we can easily extend this
// to support any head size which is a multiple of 16.
case 32:
LAUNCH_PAGED_ATTENTION_V2(32);
break;
case 64:
LAUNCH_PAGED_ATTENTION_V2(64);
break;
Expand Down
6 changes: 6 additions & 0 deletions csrc/cpu/attention.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,9 @@ void paged_attention_v1_impl_launcher(
int* seq_lens_ptr = seq_lens.data_ptr<int>();

switch (head_size) {
case 32:
LAUNCH_V1_ATTENTION_KERNEL(T, 32, BLOCK_SIZE);
break;
case 64:
LAUNCH_V1_ATTENTION_KERNEL(T, 64, BLOCK_SIZE);
break;
Expand Down Expand Up @@ -702,6 +705,9 @@ void paged_attention_v2_impl_launcher(
int* seq_lens_ptr = seq_lens.data_ptr<int>();

switch (head_size) {
case 32:
LAUNCH_V2_ATTENTION_KERNEL(T, 32, BLOCK_SIZE);
break;
case 64:
LAUNCH_V2_ATTENTION_KERNEL(T, 64, BLOCK_SIZE);
break;
Expand Down
44 changes: 44 additions & 0 deletions tests/model_executor/test_model_load_with_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,17 @@

from vllm.model_executor.layers.pooler import PoolingType
from vllm.model_executor.models.bert import BertEmbeddingModel
from vllm.model_executor.models.roberta import RobertaEmbeddingModel
from vllm.platforms import current_platform

MAX_MODEL_LEN = 128
MODEL_NAME = os.environ.get("MODEL_NAME", "BAAI/bge-base-en-v1.5")
REVISION = os.environ.get("REVISION", "main")

MODEL_NAME_ROBERTA = os.environ.get("MODEL_NAME",
"intfloat/multilingual-e5-large")
REVISION_ROBERTA = os.environ.get("REVISION", "main")


@pytest.mark.skipif(current_platform.is_rocm(),
reason="Xformers backend is not supported on ROCm.")
Expand Down Expand Up @@ -48,3 +53,42 @@ def test_model_loading_with_params(vllm_runner):
assert model._pooler.normalize
# assert output
assert output


@pytest.mark.skipif(current_platform.is_rocm(),
reason="Xformers backend is not supported on ROCm.")
def test_roberta_model_loading_with_params(vllm_runner):
"""
Test parameter weight loading with tp>1.
"""
with vllm_runner(model_name=MODEL_NAME_ROBERTA,
revision=REVISION_ROBERTA,
dtype="float16",
max_model_len=MAX_MODEL_LEN) as model:
output = model.encode("Write a short story about a robot that"
" dreams for the first time.\n")

model_config = model.model.llm_engine.model_config

model_tokenizer = model.model.llm_engine.tokenizer

# asserts on the bert model config file
assert model_config.encoder_config["max_seq_length"] == 512
assert not model_config.encoder_config["do_lower_case"]

# asserts on the pooling config files
assert model_config.pooler_config.pooling_type == PoolingType.MEAN.name
assert model_config.pooler_config.pooling_norm

# asserts on the tokenizer loaded
assert model_tokenizer.tokenizer_id == "intfloat/multilingual-e5-large"
assert not model_tokenizer.tokenizer_config["do_lower_case"]

model = model.model.llm_engine.model_executor\
.driver_worker.model_runner.model
assert isinstance(model, RobertaEmbeddingModel)
assert model._pooler.pooling_type == PoolingType.MEAN
assert model._pooler.normalize

# assert output
assert output
2 changes: 2 additions & 0 deletions tests/models/embedding/language/test_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,12 @@
"intfloat/e5-mistral-7b-instruct",
"BAAI/bge-base-en-v1.5",
"BAAI/bge-multilingual-gemma2",
"intfloat/multilingual-e5-large",
]

ENCODER_ONLY = [
"BAAI/bge-base-en-v1.5",
"intfloat/multilingual-e5-large",
]


Expand Down
2 changes: 1 addition & 1 deletion vllm/attention/ops/ipex_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ class PagedAttention:

@staticmethod
def get_supported_head_sizes() -> List[int]:
return [64, 80, 96, 112, 128, 256]
return [32, 64, 80, 96, 112, 128, 256]

@staticmethod
def get_kv_cache_shape(
Expand Down
2 changes: 1 addition & 1 deletion vllm/attention/ops/paged_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class PagedAttention:

@staticmethod
def get_supported_head_sizes() -> List[int]:
return [64, 80, 96, 112, 120, 128, 192, 256]
return [32, 64, 80, 96, 112, 120, 128, 192, 256]

@staticmethod
def get_kv_cache_shape(
Expand Down
35 changes: 23 additions & 12 deletions vllm/model_executor/models/bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from transformers import BertConfig

from vllm.attention import Attention, AttentionMetadata, AttentionType
from vllm.config import CacheConfig, VllmConfig
from vllm.config import CacheConfig, PoolerConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
Expand Down Expand Up @@ -305,14 +305,16 @@ def forward(self, hidden_states: torch.Tensor,

class BertModel(nn.Module):

def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
def __init__(self,
*,
vllm_config: VllmConfig,
prefix: str = "",
embedding_class: type = BertEmbedding):
super().__init__()

config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config

self.embeddings = BertEmbedding(config)
self.embeddings = embedding_class(config)
self.encoder = BertEncoder(config,
cache_config,
quant_config,
Expand Down Expand Up @@ -382,13 +384,9 @@ class BertEmbeddingModel(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
pooler_config = vllm_config.model_config.pooler_config
self.model = BertModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self._pooler = Pooler.from_config_with_defaults(
pooler_config,
pooling_type=PoolingType.CLS,
normalize=True,
softmax=False)
self.model = self._build_model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self._pooler = self._build_pooler(pooler_config)

def forward(
self,
Expand All @@ -415,3 +413,16 @@ def pooler(

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
self.model.load_weights(weights)

def _build_model(self,
vllm_config: VllmConfig,
prefix: str = "") -> BertModel:
return BertModel(vllm_config=vllm_config,
prefix=prefix,
embedding_class=BertEmbedding)

def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler:
return Pooler.from_config_with_defaults(pooler_config,
pooling_type=PoolingType.CLS,
normalize=True,
softmax=False)
2 changes: 2 additions & 0 deletions vllm/model_executor/models/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,8 @@
_EMBEDDING_MODELS = {
# [Text-only]
"BertModel": ("bert", "BertEmbeddingModel"),
"RobertaModel": ("roberta", "RobertaEmbeddingModel"),
"XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
"DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"),
"Gemma2Model": ("gemma2", "Gemma2EmbeddingModel"),
"LlamaModel": ("llama", "LlamaEmbeddingModel"),
Expand Down
117 changes: 117 additions & 0 deletions vllm/model_executor/models/roberta.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
from typing import List, Optional

import torch
from torch import nn
from transformers import RobertaConfig

from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.models.bert import BertEmbeddingModel, BertModel
from vllm.sequence import IntermediateTensors


class RobertaEmbedding(nn.Module):

def __init__(self, config: RobertaConfig):
super().__init__()
self.size = config.hidden_size
self.word_embeddings = VocabParallelEmbedding(config.vocab_size,
config.hidden_size)
self.padding_idx = config.pad_token_id
self.position_embeddings = nn.Embedding(config.max_position_embeddings,
config.hidden_size,
padding_idx=self.padding_idx)

self.token_type_embeddings = nn.Embedding(config.type_vocab_size,
config.hidden_size)
self.LayerNorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
self.position_ids = nn.Parameter(
torch.empty((1, config.max_position_embeddings)), )

self.position_embedding_type = config.position_embedding_type
if self.position_embedding_type != "absolute":
raise ValueError("Only 'absolute' position_embedding_type" +
" is supported")

def forward(
self,
input_ids: torch.Tensor,
position_ids: Optional[torch.Tensor] = None,
) -> torch.Tensor:
input_shape = input_ids.size()

# Input embeddings.
inputs_embeds = self.word_embeddings(input_ids)

# TODO: figure out if there is a better way
# to make to make position ids start at padding_idx + 1
# References:
# - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L133
# - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L1669
position_ids += self.padding_idx + 1

# Position embeddings.
position_embeddings = self.position_embeddings(position_ids)

# Token type embeddings. (TODO: move off hotpath?)
token_type_embeddings = self.token_type_embeddings(
torch.zeros(input_shape,
dtype=torch.long,
device=inputs_embeds.device))

embeddings = inputs_embeds + token_type_embeddings + position_embeddings
embeddings = self.LayerNorm(embeddings)
return embeddings


class RobertaEmbeddingModel(BertEmbeddingModel):
"""A model that uses Roberta to provide embedding functionalities.
This class encapsulates the BertModel and provides an interface for
embedding operations and customized pooling functions.
Attributes:
model: An instance of BertModel used for forward operations.
_pooler: An instance of Pooler used for pooling operations.
"""

def _build_model(self,
vllm_config: VllmConfig,
prefix: str = "") -> BertModel:
return BertModel(vllm_config=vllm_config,
prefix=prefix,
embedding_class=RobertaEmbedding)

def forward(
self,
input_ids: Optional[torch.Tensor],
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:

# Verify assumption that position are always a sequence from
# 0 to N. (Actually here we just check 0 and N to simplify).
# This is important to fix the position which are assumed to
# start from padding_idx + 1 instead of 0 in the Roberta models.
assert hasattr(attn_metadata, "seq_lens_tensor")
cumulative = attn_metadata.seq_lens_tensor.cumsum(dim=0)
start_pos = torch.cat(
(torch.tensor([0], device=attn_metadata.seq_lens_tensor.device),
cumulative[:-1]))
assert len(torch.nonzero(positions[start_pos])) == 0
end_pos = cumulative - 1
last_tokens = attn_metadata.seq_lens_tensor - 1
assert len(torch.nonzero(positions[end_pos] - last_tokens)) == 0

return super().forward(input_ids=input_ids,
positions=positions,
kv_caches=kv_caches,
attn_metadata=attn_metadata,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds)

0 comments on commit 4a18fd1

Please sign in to comment.