Skip to content

Commit

Permalink
finish generalizing the Bert classes
Browse files Browse the repository at this point in the history
Signed-off-by: Max de Bayser <[email protected]>
  • Loading branch information
maxdebayser committed Nov 13, 2024
1 parent aae474e commit 07c931c
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 35 deletions.
23 changes: 14 additions & 9 deletions vllm/model_executor/models/bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from transformers import BertConfig

from vllm.attention import Attention, AttentionMetadata, AttentionType
from vllm.config import CacheConfig, VllmConfig
from vllm.config import CacheConfig, PoolerConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
Expand Down Expand Up @@ -384,13 +384,9 @@ class BertEmbeddingModel(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
    """Create the wrapped encoder and its pooler.

    Both parts are created through the ``_build_model`` and
    ``_build_pooler`` hooks, so a subclass can swap either one without
    re-implementing this constructor.

    Args:
        vllm_config: Engine configuration; the pooler settings are read
            from ``vllm_config.model_config.pooler_config``.
        prefix: Name prefix of this module; the encoder is built under
            ``maybe_prefix(prefix, "model")``.
    """
    super().__init__()
    pooler_config = vllm_config.model_config.pooler_config
    # Build each component exactly once, via the hooks. Constructing a
    # BertModel/Pooler directly here and then overwriting it with the hook
    # result would instantiate the whole encoder twice.
    self.model = self._build_model(vllm_config=vllm_config,
                                   prefix=maybe_prefix(prefix, "model"))
    self._pooler = self._build_pooler(pooler_config)

def forward(
self,
Expand Down Expand Up @@ -418,6 +414,15 @@ def pooler(
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
    """Hand the ``(name, tensor)`` pairs to the wrapped model's loader.

    Args:
        weights: Iterable of ``(parameter name, tensor)`` pairs; it is
            passed through untouched to ``self.model.load_weights``.
    """
    wrapped_model = self.model
    wrapped_model.load_weights(weights)

def _build_model(self,
                 vllm_config: VllmConfig,
                 prefix: str = "") -> BertModel:
    """Construct the encoder wrapped by this embedding model.

    This is the hook the constructor calls to obtain ``self.model``;
    subclasses override it to plug in a different embedding class.

    Args:
        vllm_config: Engine configuration forwarded to ``BertModel``.
        prefix: Module-name prefix forwarded to ``BertModel``.

    Returns:
        A ``BertModel`` that uses ``BertEmbedding`` as its embedding class.
    """
    return BertModel(vllm_config=vllm_config,
                     prefix=prefix,
                     embedding_class=BertEmbedding)

def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler:
    """Construct the pooler applied to the encoder's hidden states.

    Args:
        pooler_config: Pooler settings handed to
            ``Pooler.from_config_with_defaults`` together with the
            defaults below.

    Returns:
        The ``Pooler`` produced by ``Pooler.from_config_with_defaults``.
    """
    # Defaults passed alongside the config: CLS pooling, normalized
    # output, no softmax.
    pooler_defaults = {
        "pooling_type": PoolingType.CLS,
        "normalize": True,
        "softmax": False,
    }
    return Pooler.from_config_with_defaults(pooler_config, **pooler_defaults)
33 changes: 7 additions & 26 deletions vllm/model_executor/models/roberta.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,10 @@
from transformers import RobertaConfig

from vllm.config import VllmConfig
from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.models.bert import (BertEmbedding, BertEmbeddingModel,
BertEncoder, BertModel)


class RobertaModel(BertModel):
    """BertModel variant whose embedding layer is a ``RobertaEmbedding``."""

    def __init__(self, vllm_config: VllmConfig):
        """Assemble the Roberta embeddings and the shared Bert encoder.

        Args:
            vllm_config: Engine configuration; the HF model config, cache
                config and quantization config are read from it.
        """
        # Initialize as a plain nn.Module: BertModel.__init__ is skipped on
        # purpose so the embedding layer can be replaced below.
        nn.Module.__init__(self)
        hf_config = vllm_config.model_config.hf_config
        self.embeddings = RobertaEmbedding(hf_config)
        self.encoder = BertEncoder(hf_config, vllm_config.cache_config,
                                   vllm_config.quant_config)
BertModel)


class RobertaEmbedding(BertEmbedding):
Expand Down Expand Up @@ -50,24 +38,17 @@ def __init__(self, config: RobertaConfig):
class RobertaEmbeddingModel(BertEmbeddingModel):
    """A model that uses Roberta to provide embedding functionalities.

    This class encapsulates the BertModel and provides an interface for
    embedding operations and customized pooling functions.

    Construction is inherited from ``BertEmbeddingModel``; only the
    ``_build_model`` hook is overridden so that the encoder is created with
    ``RobertaEmbedding`` as its embedding class.

    Attributes:
        model: An instance of BertModel used for forward operations.
        _pooler: An instance of Pooler used for pooling operations.
    """

    def _build_model(self,
                     vllm_config: VllmConfig,
                     prefix: str = "") -> BertModel:
        """Construct the encoder with Roberta-style embeddings.

        Args:
            vllm_config: Engine configuration forwarded to ``BertModel``.
            prefix: Module-name prefix forwarded to ``BertModel``.

        Returns:
            A ``BertModel`` that uses ``RobertaEmbedding`` as its embedding
            class.
        """
        return BertModel(vllm_config=vllm_config,
                         prefix=prefix,
                         embedding_class=RobertaEmbedding)

0 comments on commit 07c931c

Please sign in to comment.