From 58e85eba90f9c438f3db484bf6fc011e20c2aea1 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Sun, 17 Nov 2024 19:26:00 +0000 Subject: [PATCH] a bit faster --- vllm/v1/worker/tpu_model_runner.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 4c7e6ec543337..868bb41d17365 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -376,6 +376,7 @@ def execute_model( ) # NOTE: TPU<>CPU sync happens here. + # It is important to call .cpu() before slicing to avoid compilation on the hot path. token_ids = selected_token_ids.cpu()[:num_decodes] sampled_token_ids_list = token_ids.tolist() sampled_token_ids[:num_decodes] = token_ids @@ -407,6 +408,7 @@ def execute_model( is_prompt=True ) # NOTE: TPU<>CPU sync happens here. + # It is important to call .cpu() before indexing to avoid compilation on the hot path. token_id = selected_token_ids.cpu()[prompt_len - 1].item() sampled_token_ids[num_decodes + idx] = token_id req_state = self.requests[req_id]