Skip to content

Commit

Permalink
Call classification head inside sampler
Browse files Browse the repository at this point in the history
  • Loading branch information
g-eoj committed May 29, 2024
1 parent 069808b commit 6883da5
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 24 deletions.
4 changes: 2 additions & 2 deletions vllm/engine/output_processor/single_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,15 +103,15 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
child = parent.fork(new_child_seq_id)
child.append_token_id(child_sample.output_token,
child_sample.logprobs,
child_sample.output_logits)
child_sample.output_classification_probs)
child_seqs.append((child, parent))
# Continue the parent sequence for the last child sample.
# We reuse the parent sequence here to reduce redundant memory
# copies, especially when using non-beam search sampling methods.
last_child_sample = child_samples[-1]
parent.append_token_id(last_child_sample.output_token,
last_child_sample.logprobs,
last_child_sample.output_logits)
last_child_sample.output_classification_probs)
child_seqs.append((parent, parent))

for seq, _ in child_seqs:
Expand Down
13 changes: 10 additions & 3 deletions vllm/model_executor/layers/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ def __init__(self):
# speculative decoding.
self.include_gpu_probs_tensor = False

self.classification_head = torch.nn.Linear(1, 1, bias=False).to("cuda")
self.classification_head.weight.data = torch.load("classification_head.pth", map_location="cuda").bfloat16()

def forward(
self,
logits: torch.Tensor,
Expand All @@ -61,6 +64,10 @@ def forward(

logits = _apply_min_tokens_penalty(logits, sampling_metadata)

classification_probs = torch.nn.functional.sigmoid(
self.classification_head(logits)
).flatten().tolist()

# Prepare sampling tensors with pinned memory to avoid blocking.
(sampling_tensors, do_penalties, do_top_p_top_k,
do_min_p) = SamplingTensors.from_sampling_metadata(
Expand Down Expand Up @@ -110,7 +117,7 @@ def forward(
# Get the logprobs query results.
prompt_logprobs, sample_logprobs = _get_logprobs(
logprobs, sampling_metadata, sample_results)
return _build_sampler_output(logits,
return _build_sampler_output(classification_probs,
sample_results,
sampling_metadata,
prompt_logprobs,
Expand Down Expand Up @@ -971,7 +978,7 @@ def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor,


def _build_sampler_output(
logits,
classification_probs,
sample_results: SampleResultType,
sampling_metadata: SamplingMetadata,
prompt_logprobs: List[Optional[PromptLogprobs]],
Expand Down Expand Up @@ -1000,7 +1007,7 @@ def _build_sampler_output(
next_token_ids, parent_ids = sample_result
seq_outputs = []
for parent_id, next_token_id, logprobs, sample_idx in zip(parent_ids, next_token_ids, group_sample_logprobs, seq_group.sample_indices):
seq_outputs.append(SequenceOutput(seq_ids[parent_id], next_token_id, logprobs, logits[sample_idx]))
seq_outputs.append(SequenceOutput(seq_ids[parent_id], next_token_id, logprobs, classification_probs[sample_idx]))
sampler_output.append(SequenceGroupOutput(seq_outputs, group_prompt_logprobs))

return SamplerOutput(
Expand Down
8 changes: 4 additions & 4 deletions vllm/outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def __init__(
text: str,
token_ids: List[int],
cumulative_logprob: float,
logits: List[float],
classification_probs: List[float],
logprobs: Optional[SampleLogprobs],
finish_reason: Optional[str] = None,
stop_reason: Union[int, str, None] = None,
Expand All @@ -44,7 +44,7 @@ def __init__(
self.finish_reason = finish_reason
self.stop_reason = stop_reason
self.lora_request = lora_request
self.logits = logits
self.classification_probs = classification_probs

def finished(self) -> bool:
return self.finish_reason is not None
Expand All @@ -54,7 +54,7 @@ def __repr__(self) -> str:
f"text={self.text!r}, "
f"token_ids={self.token_ids}, "
f"cumulative_logprob={self.cumulative_logprob}, "
f"logits={self.logits[0][:3]} ... {self.logits[0][-3:]}, "
f"classification_probs={self.classification_probs}, "
f"logprobs={self.logprobs}, "
f"finish_reason={self.finish_reason}, "
f"stop_reason={self.stop_reason})")
Expand Down Expand Up @@ -122,7 +122,7 @@ def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput":
seq.get_output_text_to_return(text_buffer_length),
seq.get_output_token_ids(),
seq.get_cumulative_logprob(),
seq.get_output_logits(),
seq.get_output_classification_probs(),
seq.output_logprobs if include_logprobs else None,
SequenceStatus.get_finished_reason(seq.status),
seq.stop_reason,
Expand Down
30 changes: 15 additions & 15 deletions vllm/sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,23 +113,23 @@ class SequenceData:
def __init__(
self,
prompt_token_ids: List[int],
output_logits: Optional[List[float]] = None,
output_classification_probs: Optional[List[float]] = None,
output_token_ids: Optional[List[int]] = None,
) -> None:
if output_token_ids is None:
output_token_ids = []

self.prompt_token_ids = prompt_token_ids
self.output_token_ids = output_token_ids
self.output_logits = output_logits or []
self.output_classification_probs = output_classification_probs or []
self.cumulative_logprob = 0.0
# The number of tokens that are computed (that run against the model).
self._num_computed_tokens = 0
self._stage: SequenceStage = SequenceStage.PREFILL

def append_token_id(self, token_id: int, logprob: float, logits: List[float]) -> None:
def append_token_id(self, token_id: int, logprob: float, classification_probs: List[float]) -> None:
self.output_token_ids.append(token_id)
self.output_logits.append(logits)
self.output_classification_probs.append(classification_probs)
self.cumulative_logprob += logprob

def get_len(self) -> int:
Expand All @@ -138,8 +138,8 @@ def get_len(self) -> int:
def get_prompt_len(self) -> int:
return len(self.prompt_token_ids)

def get_output_logits(self) -> List[float]:
return self.output_logits
def get_output_classification_probs(self) -> List[float]:
return self.output_classification_probs

def get_output_len(self) -> int:
return len(self.output_token_ids)
Expand Down Expand Up @@ -225,7 +225,7 @@ def __init__(
self.lora_request = lora_request

self.data: SequenceData = SequenceData(prompt_token_ids)
self.output_logits = []
self.output_classification_probs = []
self.output_logprobs: SampleLogprobs = []
self.output_text = ""

Expand Down Expand Up @@ -295,13 +295,13 @@ def append_token_id(
self,
token_id: int,
logprobs: Dict[int, Logprob],
logits: List[float],
classification_probs: List[float],
) -> None:
assert token_id in logprobs
self._append_tokens_to_blocks([token_id])
self.output_logprobs.append(logprobs)
self.output_logits.append(logits)
self.data.append_token_id(token_id, logprobs[token_id].logprob, logits)
self.output_classification_probs.append(classification_probs)
self.data.append_token_id(token_id, logprobs[token_id].logprob, classification_probs)

def get_len(self) -> int:
return self.data.get_len()
Expand All @@ -312,8 +312,8 @@ def get_prompt_len(self) -> int:
def get_output_len(self) -> int:
return self.data.get_output_len()

def get_output_logits(self) -> List[float]:
return self.data.get_output_logits()
def get_output_classification_probs(self) -> List[float]:
return self.data.get_output_classification_probs()

def get_token_ids(self) -> List[int]:
return self.data.get_token_ids()
Expand Down Expand Up @@ -656,16 +656,16 @@ def __init__(
parent_seq_id: int,
output_token: int,
logprobs: Dict[int, Logprob],
logits
classification_probs: List[float]
) -> None:
self.parent_seq_id = parent_seq_id
self.output_logits = logits.tolist()
self.output_classification_probs = classification_probs
self.output_token = output_token
self.logprobs = logprobs

def __repr__(self) -> str:
return (f"SequenceOutput(parent_seq_id={self.parent_seq_id}, "
f"output_logits_len={len(self.output_logits)}, "
f"output_classification_probs={self.output_classification_probs}, "
f"output_token={self.output_token}, "
f"logprobs={self.logprobs})")

Expand Down

0 comments on commit 6883da5

Please sign in to comment.