diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index 5d5e8ae1ee532..aab523705a809 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -16,6 +16,7 @@ RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer +from vllm.model_executor.layers.pooler import Pooler, PoolingType from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -23,8 +24,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.mamba_cache import (MambaCacheManager, MambaCacheParams) +from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors +from vllm.sequence import IntermediateTensors, PoolerOutput from .interfaces import HasInnerState, SupportsLoRA from .utils import maybe_prefix @@ -548,3 +550,35 @@ def _is_moe_layer(name: str): "experts", "router", ]]) + + +class JambaForSequenceClassification(JambaForCausalLM): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__(vllm_config=vllm_config, prefix=prefix) + config = vllm_config.model_config.hf_config + num_labels: int = config.num_labels + score_bias: bool = getattr(config, 'score_bias', False) + self.score = nn.Linear(config.hidden_size, num_labels, bias=score_bias) + + pooler_config = vllm_config.model_config.pooler_config + self._pooler = Pooler.from_config_with_defaults( + pooler_config, + pooling_type=PoolingType.LAST, + normalize=False, + softmax=False) + + def pooler( + self, + hidden_states: torch.Tensor, + pooling_metadata: PoolingMetadata, + ) -> Optional[PoolerOutput]: + hidden_states = hidden_states.float() + logits = self.score(hidden_states) + return self._pooler(logits, pooling_metadata) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + # TODO: The reward weights themselves have float32 accuracy data, we + # would like to load them in fp32 to get that extra precision. + super().load_weights(weights) + self.score = self.score.float() diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index c66fbce018a62..be5211e6ae5f6 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -108,6 +108,7 @@ "RobertaForMaskedLM": ("roberta", "RobertaEmbeddingModel"), "XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"), "DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"), + "JambaForSequenceClassification": ("jamba", "JambaForSequenceClassification"), # noqa: E501 "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"), "GlmForCausalLM": ("glm", "GlmForCausalLM"), "LlamaModel": ("llama", "LlamaForCausalLM"),