Skip to content

Commit

Permalink
[Model] Add JambaForSequenceClassification model (vllm-project#10860)
Browse files Browse the repository at this point in the history
Signed-off-by: Yehoshua Cohen <[email protected]>
Signed-off-by: DarkLight1337 <[email protected]>
Co-authored-by: Yehoshua Cohen <[email protected]>
Co-authored-by: DarkLight1337 <[email protected]>
  • Loading branch information
3 people authored Dec 19, 2024
1 parent a0f7d53 commit 6c7f881
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 2 deletions.
5 changes: 5 additions & 0 deletions docs/source/models/supported_models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -476,6 +476,11 @@ Classification (``--task classify``)
- Example HF Models
- :ref:`LoRA <lora>`
- :ref:`PP <distributed_serving>`
* - :code:`JambaForSequenceClassification`
- Jamba
- :code:`ai21labs/Jamba-tiny-reward-dev`, etc.
- ✅︎
- ✅︎
* - :code:`Qwen2ForSequenceClassification`
- Qwen2-based
- :code:`jason9693/Qwen2.5-1.5B-apeach`, etc.
Expand Down
1 change: 1 addition & 0 deletions tests/models/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ class _HfExamplesInfo:
"BertModel": _HfExamplesInfo("BAAI/bge-base-en-v1.5"),
"Gemma2Model": _HfExamplesInfo("BAAI/bge-multilingual-gemma2"),
"GritLM": _HfExamplesInfo("parasail-ai/GritLM-7B-vllm"),
"JambaForSequenceClassification": _HfExamplesInfo("ai21labs/Jamba-tiny-reward-dev"), # noqa: E501
"LlamaModel": _HfExamplesInfo("llama", is_available_online=False),
"MistralModel": _HfExamplesInfo("intfloat/e5-mistral-7b-instruct"),
"Qwen2Model": _HfExamplesInfo("ssmits/Qwen2-7B-Instruct-embed-base"),
Expand Down
36 changes: 35 additions & 1 deletion vllm/model_executor/models/jamba.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,17 @@
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer
from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
MambaCacheParams)
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from vllm.sequence import IntermediateTensors, PoolerOutput
from vllm.utils import LayerBlockType

from .interfaces import HasInnerState, IsHybrid, SupportsLoRA, SupportsPP
Expand Down Expand Up @@ -593,3 +595,35 @@ def _is_moe_layer(name: str):
"experts",
"router",
]])


class JambaForSequenceClassification(JambaForCausalLM):

def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__(vllm_config=vllm_config, prefix=prefix)
config = vllm_config.model_config.hf_config
num_labels: int = config.num_labels
score_bias: bool = getattr(config, 'score_bias', False)
self.score = nn.Linear(config.hidden_size, num_labels, bias=score_bias)

pooler_config = vllm_config.model_config.pooler_config
self._pooler = Pooler.from_config_with_defaults(
pooler_config,
pooling_type=PoolingType.LAST,
normalize=False,
softmax=False)

def pooler(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> Optional[PoolerOutput]:
hidden_states = hidden_states.float()
logits = self.score(hidden_states)
return self._pooler(logits, pooling_metadata)

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
# TODO: The reward weights themselves have float32 accuracy data, we
# would like to load them in fp32 to get that extra precision.
super().load_weights(weights)
self.score = self.score.float()
1 change: 1 addition & 0 deletions vllm/model_executor/models/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@
"Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
"GlmForCausalLM": ("glm", "GlmForCausalLM"),
"GritLM": ("gritlm", "GritLM"),
"JambaForSequenceClassification": ("jamba", "JambaForSequenceClassification"), # noqa: E501
"LlamaModel": ("llama", "LlamaForCausalLM"),
**{
# Multiple models share the same architecture, so we include them all
Expand Down
7 changes: 6 additions & 1 deletion vllm/worker/pooling_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,10 @@ def execute_model(
]

multi_modal_kwargs = model_input.multi_modal_kwargs or {}
seqlen_agnostic_kwargs = {
"finished_requests_ids": model_input.finished_requests_ids,
"request_ids_to_seq_ids": model_input.request_ids_to_seq_ids,
} if self.has_inner_state else {}
if (self.observability_config is not None
and self.observability_config.collect_model_forward_time):
model_forward_start = torch.cuda.Event(enable_timing=True)
Expand All @@ -110,7 +114,8 @@ def execute_model(
intermediate_tensors=intermediate_tensors,
**MultiModalKwargs.as_kwargs(multi_modal_kwargs,
device=self.device),
**cross_enc_kwargs)
**cross_enc_kwargs,
**seqlen_agnostic_kwargs)

if (self.observability_config is not None
and self.observability_config.collect_model_forward_time):
Expand Down

0 comments on commit 6c7f881

Please sign in to comment.