diff --git a/vllm/forward_context.py b/vllm/forward_context.py
index cd136f43c0c57..7f56575279e9b 100644
--- a/vllm/forward_context.py
+++ b/vllm/forward_context.py
@@ -1,9 +1,11 @@
 import time
-from collections import Counter
+from collections import defaultdict
 from contextlib import contextmanager
 from dataclasses import dataclass
 from typing import Any, Dict, Optional
 
+import torch
+
 import vllm.envs as envs
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
@@ -11,9 +13,10 @@
 logger = init_logger(__name__)
 
 track_batchsize: bool = envs.VLLM_LOG_BATCHSIZE_INTERVAL >= 0
-batchsize_counter: Counter = Counter()
 last_logging_time: float = 0
+forward_start_time: float = 0
 batchsize_logging_interval: float = envs.VLLM_LOG_BATCHSIZE_INTERVAL
+batchsize_forward_time: defaultdict = defaultdict(list)
 
 
 @dataclass
@@ -40,23 +43,10 @@ def set_forward_context(context: Any, vllm_config: VllmConfig):
     can be attention metadata, etc.
     Here we can inject common logic for every model forward pass.
     """
-    global track_batchsize, batchsize_counter
-    global last_logging_time, batchsize_logging_interval
-    if track_batchsize and context is not None:
-        if hasattr(context, "num_prefill_tokens"):
-            # for v0 attention backends
-            batchsize = context.num_prefill_tokens + context.num_decode_tokens
-        else:
-            # for v1 attention backends
-            batchsize = context.num_input_tokens
-        batchsize_counter[batchsize] += 1
-        if time.monotonic() - last_logging_time > batchsize_logging_interval:
-            last_logging_time = time.monotonic()
-            sorted_data = sorted(batchsize_counter.items(),
-                                 key=lambda x: x[1],
-                                 reverse=True)
-            logger.info("Batchsize distribution (batchsize, count): %s",
-                        sorted_data)
+    global forward_start_time
+    need_to_track_batchsize = track_batchsize and context is not None
+    if need_to_track_batchsize:
+        forward_start_time = time.perf_counter()
     global _forward_context
     prev_context = _forward_context
     _forward_context = ForwardContext(
@@ -66,4 +56,37 @@ def set_forward_context(context: Any, vllm_config: VllmConfig):
     try:
         yield
     finally:
+        global batchsize_forward_time
+        global last_logging_time, batchsize_logging_interval
+        if need_to_track_batchsize:
+            if hasattr(context, "num_prefill_tokens"):
+                # for v0 attention backends
+                batchsize = context.num_prefill_tokens + \
+                    context.num_decode_tokens
+            else:
+                # for v1 attention backends
+                batchsize = context.num_input_tokens
+            # we use synchronous scheduling right now,
+            # adding a sync point here should not affect
+            # scheduling of the next batch
+            torch.cuda.synchronize()
+            now = time.perf_counter()
+            # time measurement is in milliseconds
+            batchsize_forward_time[batchsize].append(
+                (now - forward_start_time) * 1000)
+            if now - last_logging_time > batchsize_logging_interval:
+                last_logging_time = now
+                forward_stats = []
+                for bs, times in batchsize_forward_time.items():
+                    if len(times) <= 1:
+                        # can be cudagraph / profiling run
+                        continue
+                    median = torch.quantile(torch.tensor(times), q=0.5).item()
+                    median = round(median, 2)
+                    forward_stats.append((bs, len(times), median))
+                forward_stats.sort(key=lambda x: x[1], reverse=True)
+                if forward_stats:
+                    logger.info(("Batchsize forward time stats "
+                                 "(batchsize, count, median_time(ms)): %s"),
+                                forward_stats)
         _forward_context = prev_context
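
For reference, here is a minimal standalone sketch (not part of the diff) of the measurement scheme the new code uses: per-batchsize wall-clock timing around a forward call, a torch.cuda.synchronize() before reading the clock, samples kept in a defaultdict(list) keyed by batch size, and a median computed with torch.quantile. The fake_forward helper and the synthetic batch sizes below are illustrative placeholders, not vLLM APIs:

# Sketch of the per-batchsize forward-time tracking introduced above.
# `fake_forward` and the hard-coded batch sizes are placeholders.
import time
from collections import defaultdict

import torch

batchsize_forward_time: defaultdict = defaultdict(list)


def fake_forward(batchsize: int) -> None:
    # Stand-in for a model forward pass.
    time.sleep(0.001 * batchsize)


def timed_forward(batchsize: int) -> None:
    start = time.perf_counter()
    fake_forward(batchsize)
    if torch.cuda.is_available():
        # Wait for queued kernels so wall-clock time covers the whole pass.
        torch.cuda.synchronize()
    # Record elapsed time in milliseconds, keyed by batch size.
    batchsize_forward_time[batchsize].append(
        (time.perf_counter() - start) * 1000)


for bs in (1, 1, 8, 8, 8, 32, 32):
    timed_forward(bs)

stats = []
for bs, times in batchsize_forward_time.items():
    if len(times) <= 1:
        # A single sample may be a cudagraph / profiling run; skip it.
        continue
    median = round(torch.quantile(torch.tensor(times), q=0.5).item(), 2)
    stats.append((bs, len(times), median))
stats.sort(key=lambda x: x[1], reverse=True)
print("Batchsize forward time stats (batchsize, count, median_time(ms)):",
      stats)

As in the diff, sorting by sample count puts the most frequently seen batch sizes first, and batch sizes observed only once are dropped because they are likely warmup, cudagraph capture, or profiling runs rather than steady-state traffic.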