From 63b301a88ec19bfb054c47ae7f98509a55960338 Mon Sep 17 00:00:00 2001
From: Robert Shaw
Date: Sun, 17 Nov 2024 22:21:49 +0000
Subject: [PATCH] updated

---
 vllm/v1/worker/tpu_model_runner.py | 27 +++++++++++++++++++--------
 1 file changed, 19 insertions(+), 8 deletions(-)

diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py
index 04750d42646db..048782d5e7b43 100644
--- a/vllm/v1/worker/tpu_model_runner.py
+++ b/vllm/v1/worker/tpu_model_runner.py
@@ -204,9 +204,8 @@ def _prepare_prefill_inputs(
         self,
         num_scheduled_tokens: List[int],
     ) -> PrefillInputData:
-        # Prefills run separately, each with shape [1, prompt_len],
-        # due to lack of variable length flashattention, so we
-        # create a list that will be used in execute_model()
+        # Each prefill runs separately with shape [1, padded_prompt_len].
+        # So we create lists that will be used in execute_model().
 
         prefill_request_ids = []
         prefill_prompt_lens = []
@@ -229,7 +228,7 @@ def _prepare_prefill_inputs(
 
             # TOKEN_IDS.
             token_ids = torch.from_numpy(self.input_batch.token_ids_cpu[
-                idx, :padded_prompt_len].reshape(-1, 1))
+                idx, :padded_prompt_len].reshape(1, -1))
             prefill_token_ids.append(token_ids.to(self.device))
 
             # POSITIONS.
@@ -258,6 +257,10 @@ def _prepare_prefill_inputs(
                 context_lens=None,
             ))
 
+        print(f"PREFILL {token_ids.shape=}")
+        print(f"PREFILL {positions.shape=}")
+        print(f"PREFILL {slot_mapping.shape=}")
+
         return PrefillInputData(
             request_ids=prefill_request_ids,
             prompt_lens=prefill_prompt_lens,
@@ -316,6 +319,12 @@ def _prepare_decode_inputs(self, num_decodes: int) -> DecodeInputData:
         # CONTEXT_LENS [batch_size]
         context_lens = (positions.reshape(-1) + 1)
 
+        print(f"{token_ids.shape=}")
+        print(f"{positions.shape=}")
+        print(f"{slot_mapping.shape=}")
+        print(f"{block_table.shape=}")
+        print(f"{context_lens.shape=}")
+
         # CPU<>TPU sync happens here.
         return DecodeInputData(num_decodes=num_decodes,
                                token_ids=token_ids.to(self.device),
@@ -344,7 +353,7 @@ def _prepare_inputs(
             num_tokens = scheduler_output.num_scheduled_tokens[req_id]
             num_scheduled_tokens.append(num_tokens)
 
-            # Assert Decodes Are Decodes.
+            # NOTE: assert that all the decodes are "decodes".
             if idx < num_decodes:
                 assert num_tokens == 1
 
@@ -368,6 +377,7 @@ def _prepare_sampling(
         sampling_metadata = self.input_batch.make_sampling_metadata(skip_copy)
         return sampling_metadata
 
+    @torch.no_grad()
     def execute_model(
         self,
         scheduler_output: "SchedulerOutput",
@@ -378,7 +388,7 @@ def execute_model(
         sampled_token_ids = torch.empty(num_reqs, dtype=torch.int32)
 
         ######################### DECODES #########################
-        # Decodes run as one single padded batch with shape [batch, 1]
+        # Decodes run as one single batch with shape [padded_batch, 1].
         if decode_data.num_decodes > 0:
 
             # FORWARD.
@@ -410,6 +420,8 @@ def execute_model(
                 req_state.output_token_ids.append(token_id)
 
         ######################### PREFILLS #########################
+        # Prefills run separately with shape [1, padded_prefill_len],
+        # due to the lack of a variable-length attention kernel so far.
         for idx, (req_id, prompt_len, token_ids, position_ids,
                   attn_metadata) in enumerate(prefill_data.zipped()):
 
@@ -440,14 +452,13 @@ def execute_model(
             self.input_batch.token_ids_cpu[req_idx, seq_len] = token_id
             req_state.output_token_ids.append(token_id)
 
-        model_runner_output = ModelRunnerOutput(
+        return ModelRunnerOutput(
             req_ids=self.input_batch.req_ids[:num_reqs],
             req_id_to_index=self.input_batch.req_id_to_index,
             sampled_token_ids_cpu=sampled_token_ids,
             logprob_token_ids_cpu=None,
             logprobs_cpu=None,
         )
-        return model_runner_output
 
     def load_model(self) -> None: