refactor
Signed-off-by: Andrew Feldman <[email protected]>
abf149 committed Nov 26, 2024
1 parent 3460c18 commit 9ca0ce0
Showing 1 changed file with 7 additions and 10 deletions.
vllm/v1/worker/gpu_model_runner.py (17 changes: 7 additions & 10 deletions)
@@ -329,18 +329,15 @@ def _prepare_inputs(
         # request in the batch. While we should not sample any token from this
         # partial request, we do so for simplicity. We will ignore the sampled
         # token from the partial request.
-        # TODO: Support prompt logprobs.
         maybe_sample_logits_indices = query_start_loc[1:] - 1
         num_query_tokens = torch.diff(query_start_loc)

-        # One or more requests require prompt logprobs
-        complete_req_mask = torch.tensor(
-            [not x for x in scheduler_output.partial_running_reqs])
-
         if do_prompt_logprobs:
             prompt_logits_mask = torch.ones(num_input_tokens, dtype=torch.bool)
-            prompt_logits_mask[
-                maybe_sample_logits_indices[complete_req_mask]] = False
+            # Sequence offsets where a token is being decoded are *not* prompt
+            # tokens, unless the request in question is partial
+            prompt_logits_mask[maybe_sample_logits_indices[
+                ~torch.tensor(scheduler_output.partial_running_reqs)]] = False

         return (input_ids, attn_metadata, num_query_tokens,
                 maybe_sample_logits_indices, prompt_logits_mask)
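The key change in this hunk drops the intermediate complete_req_mask list comprehension in favor of directly indexing with an inverted boolean tensor. A minimal, self-contained sketch of the resulting masking logic is below; the batch shape and partial_running_reqs values are made up for illustration (3 requests with 4, 2, and 3 scheduled tokens, where only the last request is a partial prefill):

import torch

# Hypothetical inputs, purely for illustration.
query_start_loc = torch.tensor([0, 4, 6, 9])  # cumulative token offsets
num_input_tokens = 9
partial_running_reqs = [False, False, True]   # last request is partial

# Offset of the last scheduled token of each request (the would-be sample slot).
maybe_sample_logits_indices = query_start_loc[1:] - 1  # tensor([3, 5, 8])

# Start by treating every position as a prompt position...
prompt_logits_mask = torch.ones(num_input_tokens, dtype=torch.bool)
# ...then clear the final offset of each *complete* request: that token is
# being decoded rather than part of the prompt. Partial requests keep theirs.
prompt_logits_mask[maybe_sample_logits_indices[
    ~torch.tensor(partial_running_reqs)]] = False

print(prompt_logits_mask)
# tensor([ True,  True,  True, False,  True, False,  True,  True,  True])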
@@ -448,6 +445,9 @@ def execute_model(

         sampling_metadata = self._prepare_sampling(scheduler_output)

+        do_logprobs = sampling_metadata.max_num_logprobs > 0
+        do_prompt_logprobs = sampling_metadata.max_num_prompt_logprobs > 0
+
         num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
         if (self.use_cuda_graph
                 and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]):
@@ -459,9 +459,6 @@ def execute_model(
             # Eager mode.
             num_input_tokens = num_scheduled_tokens

-        do_logprobs = sampling_metadata.max_num_logprobs > 0
-        do_prompt_logprobs = sampling_metadata.max_num_prompt_logprobs > 0
-
         # Prepare the decoder inputs.
         (
             input_ids,
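The second and third hunks simply hoist the two logprob flags above the CUDA-graph padding branch: they depend only on sampling_metadata, not on the padded token count, so they can be computed once up front. A toy sketch of the derivation, using a hypothetical stand-in class with invented values (vLLM's real sampling metadata type is not reproduced here):

from dataclasses import dataclass

@dataclass
class FakeSamplingMetadata:
    # Hypothetical stand-in for vLLM's sampling metadata; values are invented.
    max_num_logprobs: int = 0
    max_num_prompt_logprobs: int = 5

sampling_metadata = FakeSamplingMetadata()
do_logprobs = sampling_metadata.max_num_logprobs > 0                # False
do_prompt_logprobs = sampling_metadata.max_num_prompt_logprobs > 0  # True
print(do_logprobs, do_prompt_logprobs)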
