Commit

updated
robertgshaw2-neuralmagic committed Nov 17, 2024
1 parent 5a87b99 commit 63b301a
Showing 1 changed file with 19 additions and 8 deletions.
27 changes: 19 additions & 8 deletions vllm/v1/worker/tpu_model_runner.py
@@ -204,9 +204,8 @@ def _prepare_prefill_inputs(
self,
num_scheduled_tokens: List[int],
) -> PrefillInputData:
-# Prefills run separately, each with shape [1, prompt_len],
-# due to lack of variable length flashattention, so we
-# create a list that will be used in execute_model()
+# Each prefill runs separately with shape [1, padded_prompt_len],
+# so we create lists that will be used in execute_model().

prefill_request_ids = []
prefill_prompt_lens = []
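
The new comment above refers to a padded prompt length: each prefill runs as its own [1, padded_prompt_len] batch, so the prompt is padded up to a fixed bucket before being turned into tensors. A hypothetical standalone sketch of that idea; the power-of-two bucketing and pad_id below are assumptions for illustration, not necessarily vLLM's exact policy.

from typing import List

import torch

def pad_prompt(prompt_token_ids: List[int], pad_id: int = 0) -> torch.Tensor:
    # Pad the prompt up to the next power of two (assumed bucketing policy).
    padded_len = 1
    while padded_len < len(prompt_token_ids):
        padded_len *= 2
    padded = prompt_token_ids + [pad_id] * (padded_len - len(prompt_token_ids))
    # One prefill per request, shaped [1, padded_prompt_len].
    return torch.tensor(padded, dtype=torch.int32).reshape(1, -1)

print(pad_prompt(list(range(5))).shape)  # torch.Size([1, 8])
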
@@ -229,7 +228,7 @@ def _prepare_prefill_inputs(

# TOKEN_IDS.
token_ids = torch.from_numpy(self.input_batch.token_ids_cpu[
-idx, :padded_prompt_len].reshape(-1, 1))
+idx, :padded_prompt_len].reshape(1, -1))
prefill_token_ids.append(token_ids.to(self.device))

# POSITIONS.
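
The reshape fix above turns the sliced prompt tokens into a [1, padded_prompt_len] row instead of a [padded_prompt_len, 1] column. A minimal sketch of the difference with plain NumPy/PyTorch; the array contents and sizes are made up, only the reshape semantics are the point.

import numpy as np
import torch

padded_prompt_len = 8
# Toy stand-in for input_batch.token_ids_cpu (values and sizes are arbitrary).
token_ids_cpu = np.arange(2 * padded_prompt_len, dtype=np.int32).reshape(2, padded_prompt_len)

row = torch.from_numpy(token_ids_cpu[0, :padded_prompt_len].reshape(1, -1))
col = torch.from_numpy(token_ids_cpu[0, :padded_prompt_len].reshape(-1, 1))

print(row.shape)  # torch.Size([1, 8])  -> new layout: one prefill of [1, padded_prompt_len]
print(col.shape)  # torch.Size([8, 1])  -> old layout: one token per row
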
@@ -258,6 +257,10 @@ def _prepare_prefill_inputs(
context_lens=None,
))

print(f"PREFILL {token_ids.shape=}")
print(f"PREFILL {positions.shape=}")
print(f"PREFILL {slot_mapping.shape=}")

return PrefillInputData(
request_ids=prefill_request_ids,
prompt_lens=prefill_prompt_lens,
@@ -316,6 +319,12 @@ def _prepare_decode_inputs(self, num_decodes: int) -> DecodeInputData:
# CONTEXT_LENS [batch_size]
context_lens = (positions.reshape(-1) + 1)

print(f"{token_ids.shape=}")
print(f"{positions.shape=}")
print(f"{slot_mapping.shape=}")
print(f"{block_table.shape=}")
print(f"{context_lens.shape=}")

# CPU<>TPU sync happens here.
return DecodeInputData(num_decodes=num_decodes,
token_ids=token_ids.to(self.device),
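
For context on the shapes printed above, a hypothetical sketch of the decode-side tensors: one new token per sequence, so everything is laid out as [padded_batch, 1] except the per-sequence block table and context lengths. The sizes and dtypes here are assumptions for illustration, not vLLM internals.

import torch

padded_batch = 4        # assumed: num_decodes padded to a fixed batch bucket
max_blocks_per_seq = 8  # assumed size of the per-sequence block table

token_ids = torch.zeros(padded_batch, 1, dtype=torch.int32)     # one new token per sequence
positions = torch.zeros(padded_batch, 1, dtype=torch.int64)     # its position in each sequence
slot_mapping = torch.zeros(padded_batch, 1, dtype=torch.int64)  # KV-cache slot for that token
block_table = torch.zeros(padded_batch, max_blocks_per_seq, dtype=torch.int32)
context_lens = positions.reshape(-1) + 1                        # [padded_batch]

for name, t in [("token_ids", token_ids), ("positions", positions),
                ("slot_mapping", slot_mapping), ("block_table", block_table),
                ("context_lens", context_lens)]:
    print(f"{name}: {tuple(t.shape)}")
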
@@ -344,7 +353,7 @@ def _prepare_inputs(
num_tokens = scheduler_output.num_scheduled_tokens[req_id]
num_scheduled_tokens.append(num_tokens)

-# Assert Decodes Are Decodes.
+# NOTE: assert that all the decodes are "decodes".
if idx < num_decodes:
assert num_tokens == 1

@@ -368,6 +377,7 @@ def _prepare_sampling(
sampling_metadata = self.input_batch.make_sampling_metadata(skip_copy)
return sampling_metadata

+@torch.no_grad()
def execute_model(
self,
scheduler_output: "SchedulerOutput",
@@ -378,7 +388,7 @@ def execute_model(
sampled_token_ids = torch.empty(num_reqs, dtype=torch.int32)

######################### DECODES #########################
-# Decodes run as one single padded batch with shape [batch, 1]
+# Decodes run as one single batch with shape [padded_batch, 1].
if decode_data.num_decodes > 0:

# FORWARD.
@@ -410,6 +420,8 @@ def execute_model(
req_state.output_token_ids.append(token_id)

######################### PREFILLS #########################
+# Prefills run separately with shape [1, padded_prefill_len],
+# due to the lack of a variable-length attention kernel so far.
for idx, (req_id, prompt_len, token_ids, position_ids,
attn_metadata) in enumerate(prefill_data.zipped()):

@@ -440,14 +452,13 @@ def execute_model(
self.input_batch.token_ids_cpu[req_idx, seq_len] = token_id
req_state.output_token_ids.append(token_id)

-model_runner_output = ModelRunnerOutput(
+return ModelRunnerOutput(
req_ids=self.input_batch.req_ids[:num_reqs],
req_id_to_index=self.input_batch.req_id_to_index,
sampled_token_ids_cpu=sampled_token_ids,
logprob_token_ids_cpu=None,
logprobs_cpu=None,
)
-return model_runner_output

def load_model(self) -> None:

