Return classification probabilities from single step #1

Draft · wants to merge 4 commits into base: main
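In outline, per the diff below: the sampler loads a linear classification head from classification_head.pth and applies it to the logits at every decoding step; the resulting per-token probability is threaded through SequenceOutput, SequenceData, and Sequence, and surfaced to callers as CompletionOutput.classification_probs.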
8 changes: 5 additions & 3 deletions vllm/engine/output_processor/single_step.py
@@ -102,14 +102,16 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
                 new_child_seq_id: int = next(self.seq_counter)
                 child = parent.fork(new_child_seq_id)
                 child.append_token_id(child_sample.output_token,
-                                      child_sample.logprobs)
+                                      child_sample.logprobs,
+                                      child_sample.output_classification_probs)
                 child_seqs.append((child, parent))
             # Continue the parent sequence for the last child sample.
             # We reuse the parent sequence here to reduce redundant memory
             # copies, especially when using non-beam search sampling methods.
             last_child_sample = child_samples[-1]
-            parent.append_token_id(last_child_sample.output_token,
-                                   last_child_sample.logprobs)
+            parent.append_token_id(
+                last_child_sample.output_token, last_child_sample.logprobs,
+                last_child_sample.output_classification_probs)
             child_seqs.append((parent, parent))

         for seq, _ in child_seqs:

21 changes: 15 additions & 6 deletions vllm/model_executor/layers/sampler.py
@@ -47,6 +47,10 @@ def __init__(self):
         # speculative decoding.
         self.include_gpu_probs_tensor = False

+        self.classification_head = torch.nn.Linear(1, 1, bias=False).to("cuda")
+        self.classification_head.weight.data = torch.load(
+            "classification_head.pth", map_location="cuda").bfloat16()
+
     def forward(
         self,
         logits: torch.Tensor,
@@ -62,6 +66,9 @@ def forward(

         logits = _apply_min_tokens_penalty(logits, sampling_metadata)

+        classification_probs = torch.nn.functional.sigmoid(
+            self.classification_head(logits)).flatten().tolist()
+
         # Prepare sampling tensors with pinned memory to avoid blocking.
         (sampling_tensors, do_penalties, do_top_p_top_k,
          do_min_p) = SamplingTensors.from_sampling_metadata(
@@ -111,7 +118,8 @@
         # Get the logprobs query results.
         prompt_logprobs, sample_logprobs = _get_logprobs(
             logprobs, sampling_metadata, sample_results)
-        return _build_sampler_output(sample_results,
+        return _build_sampler_output(classification_probs,
+                                     sample_results,
                                      sampling_metadata,
                                      prompt_logprobs,
                                      sample_logprobs,
@@ -992,6 +1000,7 @@ def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor,


 def _build_sampler_output(
+    classification_probs,
     sample_results: SampleResultType,
     sampling_metadata: SamplingMetadata,
     prompt_logprobs: List[Optional[PromptLogprobs]],
@@ -1016,11 +1025,12 @@
         seq_ids = seq_group.seq_ids
         next_token_ids, parent_ids = sample_result
         seq_outputs = []
-        for parent_id, next_token_id, logprobs in zip(parent_ids,
-                                                      next_token_ids,
-                                                      group_sample_logprobs):
+        for parent_id, next_token_id, logprobs, sample_idx in zip(
+                parent_ids, next_token_ids, group_sample_logprobs,
+                seq_group.sample_indices):
             seq_outputs.append(
-                SequenceOutput(seq_ids[parent_id], next_token_id, logprobs))
+                SequenceOutput(seq_ids[parent_id], next_token_id, logprobs,
+                               classification_probs[sample_idx]))
         sampler_output.append(
             CompletionSequenceGroupOutput(seq_outputs, group_prompt_logprobs))
@@ -1031,7 +1041,6 @@
     else:
         sampled_token_probs, logprobs_tensor, sampled_token_ids = (None, None,
                                                                    None)
-
     return SamplerOutput(
         outputs=sampler_output,
         sampled_token_probs=sampled_token_probs,
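For orientation, a minimal sketch of what the added head computes, assuming the weight loaded from classification_head.pth has shape [1, vocab_size]. The dimensions and random weights below are made up for illustration; the real code runs on CUDA in bfloat16:

```python
import torch

# Toy dimensions; in the patch the head's true shape comes from the
# tensor loaded out of "classification_head.pth".
vocab_size, num_tokens = 32_000, 3

# Stand-in for self.classification_head: maps each per-token logit
# vector to a single scalar score.
head = torch.nn.Linear(vocab_size, 1, bias=False)

logits = torch.randn(num_tokens, vocab_size)

# Mirrors the forward() addition: sigmoid squashes each score into
# (0, 1), yielding one classification probability per token.
classification_probs = torch.sigmoid(head(logits)).flatten().tolist()

assert len(classification_probs) == num_tokens
assert all(0.0 < p < 1.0 for p in classification_probs)
```

Note that `classification_probs` has one entry per row of `logits`, which is what lets `_build_sampler_output` index it with `sample_idx` further down the diff.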
19 changes: 12 additions & 7 deletions vllm/outputs.py
@@ -30,6 +30,7 @@ class CompletionOutput:
     text: str
     token_ids: List[int]
     cumulative_logprob: float
+    classification_probs: List[float]
     logprobs: Optional[SampleLogprobs]
     finish_reason: Optional[str] = None
     stop_reason: Union[int, str, None] = None
@@ -43,6 +44,7 @@ def __repr__(self) -> str:
                 f"text={self.text!r}, "
                 f"token_ids={self.token_ids}, "
                 f"cumulative_logprob={self.cumulative_logprob}, "
+                f"classification_probs={self.classification_probs}, "
                 f"logprobs={self.logprobs}, "
                 f"finish_reason={self.finish_reason}, "
                 f"stop_reason={self.stop_reason})")
@@ -124,13 +126,16 @@ def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput":
         include_logprobs = seq_group.sampling_params.logprobs is not None
         text_buffer_length = seq_group.sampling_params.output_text_buffer_length
         outputs = [
-            CompletionOutput(seqs.index(seq),
-                             seq.get_output_text_to_return(text_buffer_length),
-                             seq.get_output_token_ids(),
-                             seq.get_cumulative_logprob(),
-                             seq.output_logprobs if include_logprobs else None,
-                             SequenceStatus.get_finished_reason(seq.status),
-                             seq.stop_reason) for seq in top_n_seqs
+            CompletionOutput(
+                seqs.index(seq),
+                seq.get_output_text_to_return(text_buffer_length),
+                seq.get_output_token_ids(),
+                seq.get_cumulative_logprob(),
+                seq.get_output_classification_probs(),
+                seq.output_logprobs if include_logprobs else None,
+                SequenceStatus.get_finished_reason(seq.status),
+                seq.stop_reason,
+            ) for seq in top_n_seqs
         ]

         # Every sequence in the sequence group should have the same prompt.
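With the new field on CompletionOutput, a caller on this branch could read the probabilities alongside the generated tokens. A hypothetical usage sketch (the model name and prompt are arbitrary, and the patch as written also expects classification_head.pth in the working directory):

```python
from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")  # arbitrary example model
params = SamplingParams(max_tokens=8)

for request_output in llm.generate(["Hello, world"], params):
    completion = request_output.outputs[0]
    # classification_probs is parallel to token_ids: one probability
    # per generated token.
    for token_id, prob in zip(completion.token_ids,
                              completion.classification_probs):
        print(token_id, prob)
```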
37 changes: 25 additions & 12 deletions vllm/sequence.py
@@ -116,6 +116,7 @@ class SequenceData:
     def __init__(
         self,
         prompt_token_ids: List[int],
+        output_classification_probs: Optional[List[float]] = None,
         output_token_ids: Optional[List[int]] = None,
     ) -> None:
         if output_token_ids is None:
@@ -124,13 +125,16 @@ def __init__(
         self.prompt_token_ids = prompt_token_ids
         self._prompt_token_ids_tuple: Tuple[int, ...] = tuple(prompt_token_ids)
         self.output_token_ids = output_token_ids
+        self.output_classification_probs = output_classification_probs or []
         self.cumulative_logprob = 0.0
         # The number of tokens that are computed (that run against the model).
         self._num_computed_tokens = 0
         self._stage: SequenceStage = SequenceStage.PREFILL

-    def append_token_id(self, token_id: int, logprob: float) -> None:
+    def append_token_id(self, token_id: int, logprob: float,
+                        classification_prob: float) -> None:
         self.output_token_ids.append(token_id)
+        self.output_classification_probs.append(classification_prob)
         self.cumulative_logprob += logprob
@@ -139,6 +143,9 @@ def get_len(self) -> int:
     def get_prompt_len(self) -> int:
         return len(self.prompt_token_ids)

+    def get_output_classification_probs(self) -> List[float]:
+        return self.output_classification_probs
+
     def get_output_len(self) -> int:
         return len(self.output_token_ids)
@@ -230,8 +237,8 @@ def __init__(
         self.block_size = block_size
         self.eos_token_id = eos_token_id
         self.lora_request = lora_request
-
         self.data = SequenceData(self.prompt_token_ids)
+        self.output_classification_probs: List[float] = []
         self.output_logprobs: SampleLogprobs = []
         self.output_text = ""
@@ -313,11 +320,14 @@ def append_token_id(
         self,
         token_id: int,
         logprobs: Dict[int, Logprob],
+        classification_prob: float,
     ) -> None:
         assert token_id in logprobs
         self._append_tokens_to_blocks([token_id])
         self.output_logprobs.append(logprobs)
-        self.data.append_token_id(token_id, logprobs[token_id].logprob)
+        self.output_classification_probs.append(classification_prob)
+        self.data.append_token_id(token_id, logprobs[token_id].logprob,
+                                  classification_prob)

     def get_len(self) -> int:
         return self.data.get_len()
@@ -328,6 +338,9 @@ def get_prompt_len(self) -> int:
     def get_output_len(self) -> int:
         return self.data.get_output_len()

+    def get_output_classification_probs(self) -> List[float]:
+        return self.data.get_output_classification_probs()
+
     def get_token_ids(self) -> List[int]:
         return self.data.get_token_ids()
@@ -708,20 +721,20 @@ class SequenceOutput:
         (Token id -> logP(x_i+1 | x_0, ..., x_i))
     """

-    def __init__(
-        self,
-        parent_seq_id: int,
-        output_token: int,
-        logprobs: Dict[int, Logprob],
-    ) -> None:
+    def __init__(self, parent_seq_id: int, output_token: int,
+                 logprobs: Dict[int, Logprob],
+                 classification_probs: List[float]) -> None:
         self.parent_seq_id = parent_seq_id
+        self.output_classification_probs = classification_probs
         self.output_token = output_token
         self.logprobs = logprobs

     def __repr__(self) -> str:
-        return (f"SequenceOutput(parent_seq_id={self.parent_seq_id}, "
-                f"output_token={self.output_token}, "
-                f"logprobs={self.logprobs})")
+        return (
+            f"SequenceOutput(parent_seq_id={self.parent_seq_id}, "
+            f"output_classification_probs={self.output_classification_probs}, "
+            f"output_token={self.output_token}, "
+            f"logprobs={self.logprobs})")

     def __eq__(self, other: object) -> bool:
         if not isinstance(other, SequenceOutput):
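To illustrate the new bookkeeping, a small sketch against the patched SequenceData (keyword names taken from the diff; assumes this branch's vllm package is importable):

```python
from vllm.sequence import SequenceData

data = SequenceData(prompt_token_ids=[1, 2, 3])

# Every appended token now carries a logprob and a classification
# probability, stored side by side.
data.append_token_id(token_id=42, logprob=-0.7, classification_prob=0.91)
data.append_token_id(token_id=7, logprob=-1.2, classification_prob=0.13)

assert data.output_token_ids == [42, 7]
assert data.get_output_classification_probs() == [0.91, 0.13]
```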