diff --git a/vllm/engine/output_processor/single_step.py b/vllm/engine/output_processor/single_step.py
index 44de1d7ec5607..779d8609ec9fc 100644
--- a/vllm/engine/output_processor/single_step.py
+++ b/vllm/engine/output_processor/single_step.py
@@ -102,14 +102,16 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
                 new_child_seq_id: int = next(self.seq_counter)
                 child = parent.fork(new_child_seq_id)
                 child.append_token_id(child_sample.output_token,
-                                      child_sample.logprobs)
+                                      child_sample.logprobs,
+                                      child_sample.output_classification_probs)
                 child_seqs.append((child, parent))
             # Continue the parent sequence for the last child sample.
             # We reuse the parent sequence here to reduce redundant memory
             # copies, especially when using non-beam search sampling methods.
             last_child_sample = child_samples[-1]
-            parent.append_token_id(last_child_sample.output_token,
-                                   last_child_sample.logprobs)
+            parent.append_token_id(
+                last_child_sample.output_token, last_child_sample.logprobs,
+                last_child_sample.output_classification_probs)
             child_seqs.append((parent, parent))
 
         for seq, _ in child_seqs:
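The two call sites above only thread one extra value through: every appended token now carries a classification probability alongside its logprob, for forked children and for the reused parent alike. A minimal sketch of the invariant this maintains, using a hypothetical ToySequence stand-in rather than vLLM's real Sequence class:

from typing import List


class ToySequence:
    """Hypothetical stand-in illustrating the bookkeeping, not vLLM code."""

    def __init__(self) -> None:
        self.token_ids: List[int] = []
        self.classification_probs: List[float] = []
        self.cumulative_logprob = 0.0

    def append_token_id(self, token_id: int, logprob: float,
                        classification_prob: float) -> None:
        # Every appended token carries both a logprob and a classification
        # probability, so the two per-token lists stay index-aligned.
        self.token_ids.append(token_id)
        self.classification_probs.append(classification_prob)
        self.cumulative_logprob += logprob


seq = ToySequence()
seq.append_token_id(42, -0.7, 0.91)
seq.append_token_id(7, -1.2, 0.08)
assert len(seq.token_ids) == len(seq.classification_probs) == 2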
diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py
index a84f562909d50..c5f99e314371e 100644
--- a/vllm/model_executor/layers/sampler.py
+++ b/vllm/model_executor/layers/sampler.py
@@ -47,6 +47,10 @@ def __init__(self):
         # speculative decoding.
         self.include_gpu_probs_tensor = False
 
+        self.classification_head = torch.nn.Linear(1, 1, bias=False).to("cuda")
+        self.classification_head.weight.data = torch.load(
+            "classification_head.pth", map_location="cuda").bfloat16()
+
     def forward(
         self,
         logits: torch.Tensor,
@@ -62,6 +66,9 @@ def forward(
 
         logits = _apply_min_tokens_penalty(logits, sampling_metadata)
 
+        classification_probs = torch.nn.functional.sigmoid(
+            self.classification_head(logits)).flatten().tolist()
+
         # Prepare sampling tensors with pinned memory to avoid blocking.
         (sampling_tensors, do_penalties, do_top_p_top_k,
          do_min_p) = SamplingTensors.from_sampling_metadata(
@@ -111,7 +118,8 @@ def forward(
         # Get the logprobs query results.
         prompt_logprobs, sample_logprobs = _get_logprobs(
             logprobs, sampling_metadata, sample_results)
-        return _build_sampler_output(sample_results,
+        return _build_sampler_output(classification_probs,
+                                     sample_results,
                                      sampling_metadata,
                                      prompt_logprobs,
                                      sample_logprobs,
@@ -992,6 +1000,7 @@ def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor,
 
 
 def _build_sampler_output(
+    classification_probs,
     sample_results: SampleResultType,
     sampling_metadata: SamplingMetadata,
     prompt_logprobs: List[Optional[PromptLogprobs]],
@@ -1016,11 +1025,12 @@ def _build_sampler_output(
         seq_ids = seq_group.seq_ids
         next_token_ids, parent_ids = sample_result
         seq_outputs = []
-        for parent_id, next_token_id, logprobs in zip(parent_ids,
-                                                      next_token_ids,
-                                                      group_sample_logprobs):
+        for parent_id, next_token_id, logprobs, sample_idx in zip(
+                parent_ids, next_token_ids, group_sample_logprobs,
+                seq_group.sample_indices):
             seq_outputs.append(
-                SequenceOutput(seq_ids[parent_id], next_token_id, logprobs))
+                SequenceOutput(seq_ids[parent_id], next_token_id, logprobs,
+                               classification_probs[sample_idx]))
         sampler_output.append(
             CompletionSequenceGroupOutput(seq_outputs, group_prompt_logprobs))
 
@@ -1031,7 +1041,6 @@ def _build_sampler_output(
     else:
         sampled_token_probs, logprobs_tensor, sampled_token_ids = (None, None,
                                                                    None)
-
     return SamplerOutput(
         outputs=sampler_output,
         sampled_token_probs=sampled_token_probs,
diff --git a/vllm/outputs.py b/vllm/outputs.py
index 49f526b5f9300..0f1b076498112 100644
--- a/vllm/outputs.py
+++ b/vllm/outputs.py
@@ -30,6 +30,7 @@ class CompletionOutput:
     text: str
     token_ids: List[int]
    cumulative_logprob: float
+    classification_probs: List[float]
    logprobs: Optional[SampleLogprobs]
    finish_reason: Optional[str] = None
    stop_reason: Union[int, str, None] = None
@@ -43,6 +44,7 @@ def __repr__(self) -> str:
                f"text={self.text!r}, "
                f"token_ids={self.token_ids}, "
                f"cumulative_logprob={self.cumulative_logprob}, "
+                f"classification_probs={self.classification_probs}, "
                f"logprobs={self.logprobs}, "
                f"finish_reason={self.finish_reason}, "
                f"stop_reason={self.stop_reason})")
@@ -124,13 +126,16 @@ def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput":
        include_logprobs = seq_group.sampling_params.logprobs is not None
        text_buffer_length = seq_group.sampling_params.output_text_buffer_length
        outputs = [
-            CompletionOutput(seqs.index(seq),
-                             seq.get_output_text_to_return(text_buffer_length),
-                             seq.get_output_token_ids(),
-                             seq.get_cumulative_logprob(),
-                             seq.output_logprobs if include_logprobs else None,
-                             SequenceStatus.get_finished_reason(seq.status),
-                             seq.stop_reason) for seq in top_n_seqs
+            CompletionOutput(
+                seqs.index(seq),
+                seq.get_output_text_to_return(text_buffer_length),
+                seq.get_output_token_ids(),
+                seq.get_cumulative_logprob(),
+                seq.get_output_classification_probs(),
+                seq.output_logprobs if include_logprobs else None,
+                SequenceStatus.get_finished_reason(seq.status),
+                seq.stop_reason,
+            ) for seq in top_n_seqs
        ]
 
        # Every sequence in the sequence group should have the same prompt.
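Taken together, the sampler change is a one-layer binary classifier over the raw logits: a bias-free linear head whose weight is loaded from classification_head.pth, squashed through a sigmoid to give one probability per row of the logits batch (torch.nn.functional.sigmoid is the deprecated spelling of torch.sigmoid). A standalone sketch of that computation; the vocab size and random weight below are placeholders, and note the loaded tensor must effectively have shape (1, vocab_size), since it overwrites the Linear(1, 1) placeholder's weight:

import torch

vocab_size = 32000                      # placeholder; the real value comes
                                        # from the loaded weight's shape
head = torch.nn.Linear(vocab_size, 1, bias=False)
# In the patch the weight is loaded from disk instead of initialized randomly:
# head.weight.data = torch.load("classification_head.pth").bfloat16()

logits = torch.randn(4, vocab_size)     # one row per sequence in the batch
classification_probs = torch.sigmoid(head(logits)).flatten().tolist()

assert len(classification_probs) == 4   # one probability per row
assert all(0.0 <= p <= 1.0 for p in classification_probs)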
diff --git a/vllm/sequence.py b/vllm/sequence.py
index ee8c94bbf06f7..a8a3a87413107 100644
--- a/vllm/sequence.py
+++ b/vllm/sequence.py
@@ -116,6 +116,7 @@ class SequenceData:
     def __init__(
         self,
         prompt_token_ids: List[int],
+        output_classification_probs: Optional[List[float]] = None,
         output_token_ids: Optional[List[int]] = None,
     ) -> None:
         if output_token_ids is None:
@@ -124,13 +125,16 @@ def __init__(
         self.prompt_token_ids = prompt_token_ids
         self._prompt_token_ids_tuple: Tuple[int, ...] = tuple(prompt_token_ids)
         self.output_token_ids = output_token_ids
+        self.output_classification_probs = output_classification_probs or []
         self.cumulative_logprob = 0.0
         # The number of tokens that are computed (that run against the model).
         self._num_computed_tokens = 0
         self._stage: SequenceStage = SequenceStage.PREFILL
 
-    def append_token_id(self, token_id: int, logprob: float) -> None:
+    def append_token_id(self, token_id: int, logprob: float,
+                        classification_prob: float) -> None:
         self.output_token_ids.append(token_id)
+        self.output_classification_probs.append(classification_prob)
         self.cumulative_logprob += logprob
 
     def get_len(self) -> int:
@@ -139,6 +143,9 @@ def get_len(self) -> int:
     def get_prompt_len(self) -> int:
         return len(self.prompt_token_ids)
 
+    def get_output_classification_probs(self) -> List[float]:
+        return self.output_classification_probs
+
     def get_output_len(self) -> int:
         return len(self.output_token_ids)
 
@@ -230,8 +237,9 @@ def __init__(
         self.block_size = block_size
         self.eos_token_id = eos_token_id
         self.lora_request = lora_request
-        self.data = SequenceData(self.prompt_token_ids)
 
+        self.data = SequenceData(self.prompt_token_ids)
+        self.output_classification_probs: List[float] = []
         self.output_logprobs: SampleLogprobs = []
         self.output_text = ""
 
@@ -313,11 +320,14 @@ def append_token_id(
         self,
         token_id: int,
         logprobs: Dict[int, Logprob],
+        classification_prob: float,
     ) -> None:
         assert token_id in logprobs
         self._append_tokens_to_blocks([token_id])
         self.output_logprobs.append(logprobs)
-        self.data.append_token_id(token_id, logprobs[token_id].logprob)
+        self.output_classification_probs.append(classification_prob)
+        self.data.append_token_id(token_id, logprobs[token_id].logprob,
+                                  classification_prob)
 
     def get_len(self) -> int:
         return self.data.get_len()
@@ -328,6 +338,9 @@ def get_prompt_len(self) -> int:
     def get_output_len(self) -> int:
         return self.data.get_output_len()
 
+    def get_output_classification_probs(self) -> List[float]:
+        return self.data.get_output_classification_probs()
+
     def get_token_ids(self) -> List[int]:
         return self.data.get_token_ids()
 
@@ -708,20 +721,20 @@ class SequenceOutput:
         (Token id -> logP(x_i+1 | x_0, ..., x_i))
     """
 
-    def __init__(
-        self,
-        parent_seq_id: int,
-        output_token: int,
-        logprobs: Dict[int, Logprob],
-    ) -> None:
+    def __init__(self, parent_seq_id: int, output_token: int,
+                 logprobs: Dict[int, Logprob],
+                 classification_probs: List[float]) -> None:
         self.parent_seq_id = parent_seq_id
+        self.output_classification_probs = classification_probs
         self.output_token = output_token
         self.logprobs = logprobs
 
     def __repr__(self) -> str:
-        return (f"SequenceOutput(parent_seq_id={self.parent_seq_id}, "
-                f"output_token={self.output_token}, "
-                f"logprobs={self.logprobs})")
+        return (
+            f"SequenceOutput(parent_seq_id={self.parent_seq_id}, "
+            f"output_classification_probs={self.output_classification_probs}, "
+            f"output_token={self.output_token}, "
+            f"logprobs={self.logprobs})")
 
     def __eq__(self, other: object) -> bool:
         if not isinstance(other, SequenceOutput):
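End to end, the new per-token values surface as CompletionOutput.classification_probs. A usage sketch against a build carrying this patch; the model name is illustrative, and classification_head.pth must be present where the Sampler can load it:

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")    # illustrative model choice
params = SamplingParams(max_tokens=8)

request_output = llm.generate(["Hello, my name is"], params)[0]
completion = request_output.outputs[0]

# One classification probability per generated token, index-aligned with
# completion.token_ids.
for token_id, prob in zip(completion.token_ids,
                          completion.classification_probs):
    print(token_id, round(prob, 3))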