diff --git a/vllm/model_executor/models/bert.py b/vllm/model_executor/models/bert.py
index 13e2f508c754a..42dd6119e76f1 100644
--- a/vllm/model_executor/models/bert.py
+++ b/vllm/model_executor/models/bert.py
@@ -5,7 +5,7 @@
 from transformers import BertConfig
 
 from vllm.attention import Attention, AttentionMetadata, AttentionType
-from vllm.config import CacheConfig, VllmConfig
+from vllm.config import CacheConfig, PoolerConfig, VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -384,13 +384,9 @@ class BertEmbeddingModel(nn.Module):
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         pooler_config = vllm_config.model_config.pooler_config
-        self.model = BertModel(vllm_config=vllm_config,
-                               prefix=maybe_prefix(prefix, "model"))
-        self._pooler = Pooler.from_config_with_defaults(
-            pooler_config,
-            pooling_type=PoolingType.CLS,
-            normalize=True,
-            softmax=False)
+        self.model = self._build_model(vllm_config=vllm_config,
+                                       prefix=maybe_prefix(prefix, "model"))
+        self._pooler = self._build_pooler(pooler_config)
 
     def forward(
         self,
@@ -418,6 +414,15 @@ def pooler(
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         self.model.load_weights(weights)
 
-    def _build_model(self, vllm_config: VllmConfig):
+    def _build_model(self,
+                     vllm_config: VllmConfig,
+                     prefix: str = "") -> BertModel:
         return BertModel(vllm_config=vllm_config,
+                         prefix=prefix,
                          embedding_class=BertEmbedding)
+
+    def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler:
+        return Pooler.from_config_with_defaults(pooler_config,
+                                                pooling_type=PoolingType.CLS,
+                                                normalize=True,
+                                                softmax=False)
diff --git a/vllm/model_executor/models/roberta.py b/vllm/model_executor/models/roberta.py
index b0bd58548bad7..5b36c91b584ac 100644
--- a/vllm/model_executor/models/roberta.py
+++ b/vllm/model_executor/models/roberta.py
@@ -3,22 +3,10 @@
 from transformers import RobertaConfig
 
 from vllm.config import VllmConfig
-from vllm.model_executor.layers.pooler import Pooler, PoolingType
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
 from vllm.model_executor.models.bert import (BertEmbedding, BertEmbeddingModel,
-                                             BertEncoder, BertModel)
-
-
-class RobertaModel(BertModel):
-
-    def __init__(self, vllm_config: VllmConfig):
-        nn.Module.__init__(self)
-        config = vllm_config.model_config.hf_config
-        cache_config = vllm_config.cache_config
-        quant_config = vllm_config.quant_config
-        self.embeddings = RobertaEmbedding(config)
-        self.encoder = BertEncoder(config, cache_config, quant_config)
+                                             BertModel)
 
 
 class RobertaEmbedding(BertEmbedding):
@@ -50,24 +38,17 @@ def __init__(self, config: RobertaConfig):
 class RobertaEmbeddingModel(BertEmbeddingModel):
     """A model that uses Roberta to provide embedding functionalities.
 
-    This class encapsulates the RobertaModel and provides an interface for
+    This class encapsulates the BertModel and provides an interface for
     embedding operations and customized pooling functions.
 
     Attributes:
-        model: An instance of RobertaModel used for forward operations.
+        model: An instance of BertModel used for forward operations.
         _pooler: An instance of Pooler used for pooling operations.
     """
 
-    def __init__(self, *, vllm_config: VllmConfig) -> None:
-        nn.Module.__init__(self)
-        pooler_config = vllm_config.model_config.pooler_config
-        self.model = RobertaModel(vllm_config=vllm_config)
-        self._pooler = Pooler.from_config_with_defaults(
-            pooler_config,
-            pooling_type=PoolingType.CLS,
-            normalize=True,
-            softmax=False)
-
-    def _build_model(self, vllm_config: VllmConfig):
+    def _build_model(self,
+                     vllm_config: VllmConfig,
+                     prefix: str = "") -> BertModel:
         return BertModel(vllm_config=vllm_config,
+                         prefix=prefix,
                          embedding_class=RobertaEmbedding)