From 6c7f8815416f1968a8c1578f52a7e5b63f9310ed Mon Sep 17 00:00:00 2001 From: Yehoshua Cohen <61619195+yecohn@users.noreply.github.com> Date: Thu, 19 Dec 2024 16:48:06 +0200 Subject: [PATCH] [Model] Add JambaForSequenceClassification model (#10860) Signed-off-by: Yehoshua Cohen Signed-off-by: DarkLight1337 Co-authored-by: Yehoshua Cohen Co-authored-by: DarkLight1337 --- docs/source/models/supported_models.rst | 5 ++++ tests/models/registry.py | 1 + vllm/model_executor/models/jamba.py | 36 ++++++++++++++++++++++++- vllm/model_executor/models/registry.py | 1 + vllm/worker/pooling_model_runner.py | 7 ++++- 5 files changed, 48 insertions(+), 2 deletions(-) diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index 8d39e6f14a59c..488fcc7709c77 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -476,6 +476,11 @@ Classification (``--task classify``) - Example HF Models - :ref:`LoRA ` - :ref:`PP ` + * - :code:`JambaForSequenceClassification` + - Jamba + - :code:`ai21labs/Jamba-tiny-reward-dev`, etc. + - ✅︎ + - ✅︎ * - :code:`Qwen2ForSequenceClassification` - Qwen2-based - :code:`jason9693/Qwen2.5-1.5B-apeach`, etc. diff --git a/tests/models/registry.py b/tests/models/registry.py index fac8c4b2e9b19..819ef957a07f3 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -138,6 +138,7 @@ class _HfExamplesInfo: "BertModel": _HfExamplesInfo("BAAI/bge-base-en-v1.5"), "Gemma2Model": _HfExamplesInfo("BAAI/bge-multilingual-gemma2"), "GritLM": _HfExamplesInfo("parasail-ai/GritLM-7B-vllm"), + "JambaForSequenceClassification": _HfExamplesInfo("ai21labs/Jamba-tiny-reward-dev"), # noqa: E501 "LlamaModel": _HfExamplesInfo("llama", is_available_online=False), "MistralModel": _HfExamplesInfo("intfloat/e5-mistral-7b-instruct"), "Qwen2Model": _HfExamplesInfo("ssmits/Qwen2-7B-Instruct-embed-base"), diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index 831db2ae52d74..91786db5ddc96 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -17,6 +17,7 @@ RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer +from vllm.model_executor.layers.pooler import Pooler, PoolingType from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -24,8 +25,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.mamba_cache import (MambaCacheManager, MambaCacheParams) +from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors +from vllm.sequence import IntermediateTensors, PoolerOutput from vllm.utils import LayerBlockType from .interfaces import HasInnerState, IsHybrid, SupportsLoRA, SupportsPP @@ -593,3 +595,35 @@ def _is_moe_layer(name: str): "experts", "router", ]]) + + +class JambaForSequenceClassification(JambaForCausalLM): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__(vllm_config=vllm_config, prefix=prefix) + config = vllm_config.model_config.hf_config + num_labels: int = config.num_labels + score_bias: bool = getattr(config, 'score_bias', False) + self.score = nn.Linear(config.hidden_size, num_labels, bias=score_bias) + + pooler_config = vllm_config.model_config.pooler_config + self._pooler = Pooler.from_config_with_defaults( + pooler_config, + pooling_type=PoolingType.LAST, + normalize=False, + softmax=False) + + def pooler( + self, + hidden_states: torch.Tensor, + pooling_metadata: PoolingMetadata, + ) -> Optional[PoolerOutput]: + hidden_states = hidden_states.float() + logits = self.score(hidden_states) + return self._pooler(logits, pooling_metadata) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + # TODO: The reward weights themselves have float32 accuracy data, we + # would like to load them in fp32 to get that extra precision. + super().load_weights(weights) + self.score = self.score.float() diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 68a2467a813a1..04d806c3c7eae 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -113,6 +113,7 @@ "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"), "GlmForCausalLM": ("glm", "GlmForCausalLM"), "GritLM": ("gritlm", "GritLM"), + "JambaForSequenceClassification": ("jamba", "JambaForSequenceClassification"), # noqa: E501 "LlamaModel": ("llama", "LlamaForCausalLM"), **{ # Multiple models share the same architecture, so we include them all diff --git a/vllm/worker/pooling_model_runner.py b/vllm/worker/pooling_model_runner.py index 1beae1e3884c5..f79b3773bcbd2 100644 --- a/vllm/worker/pooling_model_runner.py +++ b/vllm/worker/pooling_model_runner.py @@ -91,6 +91,10 @@ def execute_model( ] multi_modal_kwargs = model_input.multi_modal_kwargs or {} + seqlen_agnostic_kwargs = { + "finished_requests_ids": model_input.finished_requests_ids, + "request_ids_to_seq_ids": model_input.request_ids_to_seq_ids, + } if self.has_inner_state else {} if (self.observability_config is not None and self.observability_config.collect_model_forward_time): model_forward_start = torch.cuda.Event(enable_timing=True) @@ -110,7 +114,8 @@ def execute_model( intermediate_tensors=intermediate_tensors, **MultiModalKwargs.as_kwargs(multi_modal_kwargs, device=self.device), - **cross_enc_kwargs) + **cross_enc_kwargs, + **seqlen_agnostic_kwargs) if (self.observability_config is not None and self.observability_config.collect_model_forward_time):