Skip to content

Commit

Permalink
add jamba classification model
Browse files Browse the repository at this point in the history
Signed-off-by: Yehoshua Cohen <[email protected]>
  • Loading branch information
Yehoshua Cohen committed Dec 3, 2024
1 parent f6084f6 commit d4de8fd
Showing 1 changed file with 35 additions and 1 deletion.
36 changes: 35 additions & 1 deletion vllm/model_executor/models/jamba.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,17 @@
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer
from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
MambaCacheParams)
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from vllm.sequence import IntermediateTensors, PoolerOutput

from .interfaces import HasInnerState, SupportsLoRA
from .utils import maybe_prefix
Expand Down Expand Up @@ -548,3 +550,35 @@ def _is_moe_layer(name: str):
"experts",
"router",
]])


class JambaForSequenceClassification(JambaForCausalLM):

def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__(vllm_config=vllm_config, prefix=prefix)
config = vllm_config.model_config.hf_config
num_labels: int = config.num_labels
score_bias: bool = getattr(config, 'score_bias', False)
self.score = nn.Linear(config.hidden_size, num_labels, bias=score_bias)

pooler_config = vllm_config.model_config.pooler_config
self._pooler = Pooler.from_config_with_defaults(
pooler_config,
pooling_type=PoolingType.LAST,
normalize=False,
softmax=False)

def pooler(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> Optional[PoolerOutput]:
hidden_states = hidden_states.float()
logits = self.score(hidden_states)
return self._pooler(logits, pooling_metadata)

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
# TODO: The reward weights themselves have float32 accuracy data, we
# would like to load them in fp32 to get that extra precision.
super().load_weights(weights)
self.score = self.score.float()

0 comments on commit d4de8fd

Please sign in to comment.