diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py index 8d9e177db1333..ae5583ffe0617 100644 --- a/benchmarks/benchmark_latency.py +++ b/benchmarks/benchmark_latency.py @@ -1,6 +1,8 @@ """Benchmark the latency of processing a single batch of requests.""" import argparse import time +from pathlib import Path +from typing import Optional import numpy as np import torch @@ -34,12 +36,15 @@ def main(args: argparse.Namespace): print(sampling_params) dummy_prompt_token_ids = [[0] * args.input_len] * args.batch_size - def run_to_completion(profile: bool = False): - if profile: - with torch.profiler.profile(activities=[ - torch.profiler.ProfilerActivity.CPU, - torch.profiler.ProfilerActivity.CUDA, - ]) as p: + def run_to_completion(profile_dir: Optional[str] = None): + if profile_dir: + with torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + on_trace_ready=torch.profiler.tensorboard_trace_handler( + str(profile_dir))) as p: llm.generate(prompt_token_ids=dummy_prompt_token_ids, sampling_params=sampling_params, use_tqdm=False) @@ -54,11 +59,14 @@ def run_to_completion(profile: bool = False): return latency print("Warming up...") - run_to_completion(profile=False) + run_to_completion(profile_dir=None) if args.profile: - print("Profiling...") - run_to_completion(profile=True) + profile_dir = args.profile_result_dir + if not profile_dir: + profile_dir = Path(".") / "vllm_benchmark_result" / f"latency_result_{time.time()}" + print(f"Profiling (results will be saved to '{profile_dir}')...") + run_to_completion(profile_dir=args.profile_result_dir) return # Benchmark. @@ -107,5 +115,13 @@ def run_to_completion(profile: bool = False): '--profile', action='store_true', help='profile the generation process of a single batch') + parser.add_argument( + '--profile-result-dir', + type=str, + default=None, + help=( + 'path to save the pytorch profiler output. Can be visualized ' + 'with ui.perfetto.dev or Tensorboard.' + )) args = parser.parse_args() main(args)