Commit

attempted sample_metadata fix; sample logprobs work, prompt logprobs broken

Signed-off-by: Andrew Feldman <[email protected]>
abf149 committed Nov 26, 2024
1 parent 9ca0ce0 commit d277d37
Showing 4 changed files with 90 additions and 65 deletions.
31 changes: 17 additions & 14 deletions vllm/v1/core/scheduler.py
@@ -109,6 +109,7 @@ def schedule(self) -> "SchedulerOutput":
# V1 model runner.
# TODO(woosuk): Remove this constraint after refactoring model runner.
has_partial_request = False
partial_req_index = -1
req_index = 0
while req_index < len(self.running):
# Only the last request in the RUNNING queue can be "partial".
@@ -158,9 +159,11 @@ def schedule(self) -> "SchedulerOutput":
]
num_scheduled_tokens[request.request_id] = num_new_tokens
token_budget -= num_new_tokens
if (request.num_computed_tokens + num_new_tokens <
request.num_tokens):
has_partial_request = True
partial_req_index = req_index
req_index += 1
has_partial_request = (request.num_computed_tokens + num_new_tokens
< request.num_tokens)

# Encoder-related.
if encoder_inputs_to_schedule:
@@ -236,8 +239,10 @@ def schedule(self) -> "SchedulerOutput":
token_budget -= num_new_tokens
request.status = RequestStatus.RUNNING
request.num_computed_tokens = num_computed_tokens
has_partial_request = (num_computed_tokens + num_new_tokens <
request.num_tokens)
if (request.num_computed_tokens + num_new_tokens <
request.num_tokens):
has_partial_request = True
partial_req_index = req_index

# Encoder-related.
if encoder_inputs_to_schedule:
@@ -248,13 +253,6 @@ def schedule(self) -> "SchedulerOutput":
self.encoder_cache_manager.allocate(request, i)
encoder_budget = new_encoder_budget

# Now that requests are scheduled, generate a mask indicating which
# request is partial
partial_running_reqs = [
(req.num_computed_tokens + num_scheduled_tokens[req.request_id] <
req.num_tokens) for req in self.running
]

# Check if the scheduling constraints are satisfied.
total_num_scheduled_tokens = sum(num_scheduled_tokens.values())
assert total_num_scheduled_tokens <= self.max_num_scheduled_tokens
@@ -285,7 +283,7 @@ def schedule(self) -> "SchedulerOutput":
scheduled_new_reqs=new_reqs_data,
scheduled_resumed_reqs=resumed_reqs_data,
scheduled_running_reqs=running_reqs_data,
partial_running_reqs=partial_running_reqs,
partial_req_index=partial_req_index,
num_scheduled_tokens=num_scheduled_tokens,
total_num_scheduled_tokens=total_num_scheduled_tokens,
scheduled_encoder_inputs=scheduled_encoder_inputs,
@@ -470,9 +468,14 @@ def update_from_output(

if do_prompt_logprobs:
max_prompt_logprobs = request.max_prompt_logprobs
# Number of new prompt tokens is the number of scheduled
# tokens *if* the request is partial (because the sampled
# token is discarded and all sequence offsets are prompt
# offsets), otherwise it is the number of scheduled
# tokens minus one (for the sampled token)
num_new_prompt_tokens = (
num_scheduled_tokens[request.request_id] -
int(not scheduler_output.partial_running_reqs[req_index]))
int(scheduler_output.partial_req_index != req_index))

request_do_prompt_logprobs = (max_prompt_logprobs is not None
and max_prompt_logprobs > 0
@@ -774,7 +777,7 @@ class SchedulerOutput:
scheduled_new_reqs: List[NewRequestData]
scheduled_resumed_reqs: List[ResumedRequestData]
scheduled_running_reqs: List[RunningRequestData]
partial_running_reqs: List[bool] # True if running req is partial
partial_req_index: int  # >= 0 if a running req is partial, -1 otherwise

num_scheduled_tokens: Dict[str, int]
total_num_scheduled_tokens: int
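
Editor's note: a minimal standalone sketch (toy request IDs and token counts, no vLLM imports) of how the single partial_req_index replaces the old per-request partial_running_reqs mask when update_from_output counts new prompt tokens:

# Toy batch: three running requests; the request at index 2 is partial
# because its prompt is still being chunk-prefilled.
num_scheduled_tokens = {"req-0": 1, "req-1": 1, "req-2": 8}
req_ids = ["req-0", "req-1", "req-2"]
partial_req_index = 2  # -1 would mean no partial request in the batch

for req_index, req_id in enumerate(req_ids):
    # A non-partial request spends one scheduled position on the sampled
    # token, so only the rest are prompt positions; for the partial request
    # the sampled token is discarded and every position is a prompt position.
    num_new_prompt_tokens = (num_scheduled_tokens[req_id] -
                             int(partial_req_index != req_index))
    print(req_id, num_new_prompt_tokens)  # req-0 0, req-1 0, req-2 8
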
10 changes: 7 additions & 3 deletions vllm/v1/sample/metadata.py
@@ -21,6 +21,10 @@ class SamplingMetadata:
max_num_logprobs: int
max_num_prompt_logprobs: int

num_query_tokens: Optional[torch.Tensor] = None
maybe_sample_logits_indices: Optional[torch.Tensor] = None
prompt_logits_mask: Optional[torch.Tensor] = None
query_start_loc: Optional[torch.Tensor]
num_query_tokens: Optional[torch.Tensor]
#maybe_sample_logits_indices: Optional[torch.Tensor] = None
#prompt_logits_mask: Optional[torch.Tensor] = None

num_input_tokens: int
partial_req_index: int  # >= 0 if there is a partial request, -1 otherwise
15 changes: 12 additions & 3 deletions vllm/v1/sample/sampler.py
@@ -228,9 +228,18 @@ def forward(
do_any_logprobs = do_logprobs or do_prompt_logprobs

num_query_tokens = sampling_metadata.num_query_tokens
maybe_sample_logits_indices = (
sampling_metadata.maybe_sample_logits_indices)
prompt_logits_mask = sampling_metadata.prompt_logits_mask
# NOTE(woosuk): Due to chunked prefills, there can be at most 1 partial
# request in the batch. While we should not sample any token from this
# partial request, we do so for simplicity. We will ignore the sampled
# token from the partial request.
maybe_sample_logits_indices = sampling_metadata.query_start_loc[1:] - 1
prompt_logits_mask = torch.ones(sampling_metadata.num_input_tokens,
dtype=torch.bool)
# Sequence offsets where a token is being decoded are *not* prompt
# tokens...
prompt_logits_mask[maybe_sample_logits_indices] = False
# ...unless the request in question is partial.
prompt_logits_mask[sampling_metadata.partial_req_index] = True

# Apply temperature, top-k and top-p to logits at sequence offsets
# where a new token is being decoded.
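
Editor's note: a minimal standalone sketch (toy sizes) of the masking described by the comments above; it assumes the offset flipped back to True for a partial request is that request's last scheduled token offset, i.e. maybe_sample_logits_indices[partial_req_index]:

import torch

# Toy batch: three requests scheduled for 1, 1 and 8 tokens; request 2 is
# partial, so its "sampled" token will be discarded.
query_start_loc = torch.tensor([0, 1, 2, 10])
num_input_tokens = 10
partial_req_index = 2  # -1 would mean no partial request

# The last scheduled offset of each request is where a token may be sampled.
maybe_sample_logits_indices = query_start_loc[1:] - 1  # tensor([0, 1, 9])

# Offsets where a token is being decoded are not prompt offsets...
prompt_logits_mask = torch.ones(num_input_tokens, dtype=torch.bool)
prompt_logits_mask[maybe_sample_logits_indices] = False
# ...except for the partial request, whose last offset is still a prompt offset.
if partial_req_index >= 0:
    prompt_logits_mask[maybe_sample_logits_indices[partial_req_index]] = True
# prompt_logits_mask now marks offsets 2..9 as prompt, 0 and 1 as decode.
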
99 changes: 54 additions & 45 deletions vllm/v1/worker/gpu_model_runner.py
@@ -211,10 +211,8 @@ def _prepare_inputs(
self,
scheduler_output: "SchedulerOutput",
sampling_metadata: SamplingMetadata,
num_input_tokens: int,
) -> Tuple[torch.Tensor, FlashAttentionMetadata, torch.Tensor,
torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
do_prompt_logprobs = sampling_metadata.max_num_prompt_logprobs > 0

total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
assert total_num_scheduled_tokens > 0
@@ -291,14 +289,7 @@ def _prepare_inputs(
out=slot_mapping)

# Prepare the attention metadata.
query_start_loc = torch.empty((num_reqs + 1, ),
dtype=torch.int32,
device="cpu",
pin_memory=self.pin_memory)
query_start_loc_np = query_start_loc.numpy()
query_start_loc_np[0] = 0
np.cumsum(num_scheduled_tokens, out=query_start_loc_np[1:])

query_start_loc = sampling_metadata.query_start_loc
seq_lens = (self.input_batch.num_computed_tokens_cpu[:num_reqs] +
num_scheduled_tokens)
max_seq_len = seq_lens.max()
@@ -313,7 +304,6 @@
input_ids = input_ids.to(self.device, non_blocking=True)
self.positions[:total_num_scheduled_tokens].copy_(positions,
non_blocking=True)
query_start_loc = query_start_loc.to(self.device, non_blocking=True)
seq_start_loc = seq_start_loc.to(self.device, non_blocking=True)
slot_mapping = slot_mapping.to(self.device, non_blocking=True).long()
attn_metadata = FlashAttentionMetadata(
@@ -329,26 +319,12 @@
# request in the batch. While we should not sample any token from this
# partial request, we do so for simplicity. We will ignore the sampled
# token from the partial request.
maybe_sample_logits_indices = query_start_loc[1:] - 1
num_query_tokens = torch.diff(query_start_loc)

if do_prompt_logprobs:
prompt_logits_mask = torch.ones(num_input_tokens, dtype=torch.bool)
# Sequence offsets where a token is being decoded are *not* prompt
# tokens, unless the request in question is partial
prompt_logits_mask[maybe_sample_logits_indices[
~torch.tensor(scheduler_output.partial_running_reqs)]] = False

return (input_ids, attn_metadata, num_query_tokens,
maybe_sample_logits_indices, prompt_logits_mask)
else:
# No requests require prompt logprobs
return (input_ids, attn_metadata, num_query_tokens,
maybe_sample_logits_indices, None)
return (input_ids, attn_metadata)

def _prepare_sampling(
self,
scheduler_output: "SchedulerOutput",
num_input_tokens: int,
) -> SamplingMetadata:
skip_copy = True
if (scheduler_output.finished_req_ids
@@ -358,7 +334,11 @@ def _prepare_sampling(
or scheduler_output.scheduled_resumed_reqs):
skip_copy = False
# Create the sampling metadata.
sampling_metadata = self.input_batch.make_sampling_metadata(skip_copy)
sampling_metadata = self.input_batch.make_sampling_metadata(
scheduler_output,
num_input_tokens,
skip_copy,
)
return sampling_metadata

def _execute_encoder(self, scheduler_output: "SchedulerOutput"):
@@ -443,11 +423,6 @@ def execute_model(
self._execute_encoder(scheduler_output)
encoder_outputs = self._gather_encoder_outputs(scheduler_output)

sampling_metadata = self._prepare_sampling(scheduler_output)

do_logprobs = sampling_metadata.max_num_logprobs > 0
do_prompt_logprobs = sampling_metadata.max_num_prompt_logprobs > 0

num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
if (self.use_cuda_graph
and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]):
@@ -459,16 +434,17 @@
# Eager mode.
num_input_tokens = num_scheduled_tokens

sampling_metadata = self._prepare_sampling(scheduler_output,
num_input_tokens)
do_logprobs = sampling_metadata.max_num_logprobs > 0
do_prompt_logprobs = sampling_metadata.max_num_prompt_logprobs > 0

# Prepare the decoder inputs.
(
input_ids,
attn_metadata,
num_query_tokens,
maybe_sample_logits_indices,
prompt_logits_mask,
) = self._prepare_inputs(scheduler_output=scheduler_output,
sampling_metadata=sampling_metadata,
num_input_tokens=num_input_tokens)
sampling_metadata=sampling_metadata)

# Get the inputs embeds.
if encoder_outputs:
@@ -494,11 +470,6 @@

hidden_states = hidden_states[:num_scheduled_tokens]

sampling_metadata.num_query_tokens = num_query_tokens
sampling_metadata.maybe_sample_logits_indices = (
maybe_sample_logits_indices)
sampling_metadata.prompt_logits_mask = prompt_logits_mask

# Sample the next token and get logprobs if needed.
sampler_output = self.model.sample(
logits=self.model.compute_logits(hidden_states, None),
@@ -855,6 +826,8 @@ def condense(self, empty_req_indices: List[int]) -> None:

def make_sampling_metadata(
self,
scheduler_output: "SchedulerOutput",
num_input_tokens: int,
skip_copy: bool = False,
) -> SamplingMetadata:
if not skip_copy:
@@ -864,8 +837,36 @@ def make_sampling_metadata(
self.top_p_cpu_tensor[:self.num_reqs], non_blocking=True)
self.top_k[:self.num_reqs].copy_(
self.top_k_cpu_tensor[:self.num_reqs], non_blocking=True)

num_reqs = self.num_reqs

# Get the number of scheduled tokens for each request.
# TODO: The Python loop can be slow. Optimize.
num_scheduled_tokens = []
max_num_scheduled_tokens = 0
for req_id in self.req_ids[:num_reqs]:
num_tokens = scheduler_output.num_scheduled_tokens[req_id]
num_scheduled_tokens.append(num_tokens)
max_num_scheduled_tokens = max(max_num_scheduled_tokens,
num_tokens)
num_scheduled_tokens = np.array(num_scheduled_tokens, dtype=np.int32)
assert max_num_scheduled_tokens > 0

# Compute query start offsets. It makes sense to compute this here
# rather than in model runner _prepare_inputs() because query start
# offsets are required for computing num_query_tokens in the scenario
# where prompt logprobs are required by the batch.
query_start_loc = torch.empty((num_reqs + 1, ),
dtype=torch.int32,
device="cpu",
pin_memory=self.pin_memory)
query_start_loc_np = query_start_loc.numpy()
query_start_loc_np[0] = 0
np.cumsum(num_scheduled_tokens, out=query_start_loc_np[1:])
query_start_loc = query_start_loc.to(self.device, non_blocking=True)

return SamplingMetadata(
temperature=self.temperature[:self.num_reqs],
temperature=self.temperature[:num_reqs],
all_greedy=self.all_greedy,
all_random=self.all_random,
top_p=self.top_p[:self.num_reqs],
@@ -874,7 +875,15 @@
no_top_k=self.no_top_k,
generators=self.generators,
max_num_logprobs=self.max_num_logprobs,
max_num_prompt_logprobs=self.max_num_prompt_logprobs)
max_num_prompt_logprobs=self.max_num_prompt_logprobs,
query_start_loc=query_start_loc,
num_input_tokens=num_input_tokens,
partial_req_index=scheduler_output.partial_req_index,
# Required for prompt logprobs temperature computation.
# If prompt logprobs are not required for this batch,
# avoid storing num_query_tokens.
num_query_tokens=(torch.diff(query_start_loc)
if self.max_num_prompt_logprobs > 0 else None))

@property
def num_reqs(self) -> int:
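
Editor's note: for reference, a minimal standalone sketch (toy scheduled-token counts, pinned memory and device transfer omitted) of the query_start_loc and num_query_tokens computation that make_sampling_metadata now performs:

import numpy as np
import torch

# Toy batch: per-request scheduled token counts, as gathered from
# scheduler_output.num_scheduled_tokens.
num_scheduled_tokens = np.array([1, 1, 8], dtype=np.int32)
num_reqs = len(num_scheduled_tokens)

# Query start offsets: exclusive cumulative sum over the counts.
query_start_loc = torch.empty((num_reqs + 1, ), dtype=torch.int32)
query_start_loc_np = query_start_loc.numpy()
query_start_loc_np[0] = 0
np.cumsum(num_scheduled_tokens, out=query_start_loc_np[1:])
# query_start_loc == tensor([0, 1, 2, 10], dtype=torch.int32)

# Per-request query lengths, needed when any request wants prompt logprobs.
num_query_tokens = torch.diff(query_start_loc)
# num_query_tokens == tensor([1, 1, 8], dtype=torch.int32)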
