Commit 75c44b4
update to call .cpu() before slicing to avoid recompilation
robertgshaw2-redhat committed Nov 17, 2024
1 parent d89200d commit 75c44b4
Showing 3 changed files with 15 additions and 3 deletions.
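
Background on the change: torch_xla compiles traced graphs that are specialized on tensor shapes. Slicing the on-device sampler output with a batch-dependent bound such as num_decodes gives the traced graph a different result shape for each distinct value, so each new value can force a recompile. Calling .cpu() first transfers the full, fixed-shape tensor once and performs the slice in eager PyTorch on the host. A minimal sketch of the before/after pattern, assuming a torch_xla device tensor; the helper names are illustrative, not from the commit:

import torch

def take_old(selected_token_ids: torch.Tensor, n: int) -> torch.Tensor:
    # Slice on the XLA device: the bound `n` is baked into the traced graph,
    # so each distinct `n` can trigger a fresh compilation.
    return selected_token_ids[:n].cpu()

def take_new(selected_token_ids: torch.Tensor, n: int) -> torch.Tensor:
    # Transfer the full, fixed-shape tensor once (one TPU<>CPU sync), then
    # slice eagerly on the host; the traced graph shape never changes.
    return selected_token_ids.cpu()[:n]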
benchmarks/benchmark_throughput.py (7 additions, 0 deletions)

@@ -7,6 +7,7 @@
 from typing import List, Optional
 
 import torch
+# import torch_xla.debug.metrics as met
 import uvloop
 from PIL import Image
 from tqdm import tqdm
@@ -149,6 +150,8 @@ def run_vllm(
 
     use_beam_search = False
 
+    # met.clear_all()
+
     if not use_beam_search:
         start = time.perf_counter()
         llm.generate(prompts, sampling_params, use_tqdm=True)
@@ -168,6 +171,10 @@
             ignore_eos=True,
         ))
     end = time.perf_counter()
+
+    # print(met.metrics_report())
+    # print(met.short_metrics_report())
+
     return end - start


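The hooks above are left commented out; if enabled, they could confirm whether the recompilation actually stopped. A sketch of that workflow, assuming torch_xla is installed (run_benchmark is a stand-in for the timed generate loop, not a function from the commit):

import torch_xla.debug.metrics as met

met.clear_all()                    # reset counters before the timed region
run_benchmark()                    # stand-in for the llm.generate(...) call
print(met.short_metrics_report())  # compact counter/metric summary
# A "CompileTime" sample count that keeps growing after warmup points at
# shape-driven recompiles like the slice-before-.cpu() pattern removed here.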
vllm/v1/executor/tpu_executor.py (4 additions, 0 deletions)

@@ -8,6 +8,7 @@
 
 logger = init_logger(__name__)
 
+# import torch_xla.debug.profiler as xp
 
 class TPUExecutor:
 
@@ -28,6 +29,8 @@ def __init__(self, vllm_config: VllmConfig) -> None:
         self.worker.initialize()
         self.worker.load_model()
 
+        # self.server = xp.start_server(9012)
+
     def _create_worker(
         self,
         local_rank: int = 0,
@@ -67,6 +70,7 @@ def execute_model(
         self,
         scheduler_output,
     ) -> ModelRunnerOutput:
+        # xp.trace_detached('localhost:9012', "./profiles")
         output = self.worker.execute_model(scheduler_output)
         return output
 
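The profiler hooks in this file are likewise commented out. Roughly, and hedged (a sketch based on torch_xla.debug.profiler, not part of the committed code; the port and log directory mirror the comments):

import torch_xla.debug.profiler as xp

server = xp.start_server(9012)  # serve profiling data on localhost:9012
# ... the executor begins handling execute_model() calls ...
# From another thread or process, capture a trace viewable in TensorBoard:
xp.trace_detached('localhost:9012', './profiles')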
vllm/v1/worker/tpu_model_runner.py (4 additions, 3 deletions)

@@ -375,7 +375,8 @@ def execute_model(
             is_prompt=False
         )
 
-        token_ids = selected_token_ids[:num_decodes].cpu()
+        # NOTE: TPU<>CPU sync happens here.
+        token_ids = selected_token_ids.cpu()[:num_decodes]
         sampled_token_ids_list = token_ids.tolist()
         sampled_token_ids[:num_decodes] = token_ids
 
@@ -405,8 +406,8 @@
                 self.kv_caches,
                 is_prompt=True
             )
-            # TODO: move this into the model.
-            token_id = selected_token_ids[prompt_len - 1].cpu().item()
+            # NOTE: TPU<>CPU sync happens here.
+            token_id = selected_token_ids.cpu()[prompt_len - 1].item()
             sampled_token_ids[num_decodes + idx] = token_id
             req_state = self.requests[req_id]
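Both orderings return the same values; the change only moves where the slice executes. A quick eager-mode equivalence check (the recompilation difference itself only appears under XLA tracing):

import torch

selected_token_ids = torch.arange(32)  # stand-in for the sampler output
for num_decodes in (1, 5, 9):
    assert torch.equal(selected_token_ids[:num_decodes].cpu(),
                       selected_token_ids.cpu()[:num_decodes])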
