[Core] Add span metrics for model_forward, scheduler and sampler time
sfc-gh-mkeralapura authored Aug 9, 2024
1 parent 70d268a commit 933790c
Showing 17 changed files with 189 additions and 21 deletions.
2 changes: 2 additions & 0 deletions tests/tracing/test_tracing.py
@@ -114,3 +114,5 @@ def test_traces(trace_service):
SpanAttributes.LLM_LATENCY_TIME_TO_FIRST_TOKEN) == ttft
e2e_time = metrics.finished_time - metrics.arrival_time
assert attributes.get(SpanAttributes.LLM_LATENCY_E2E) == e2e_time
assert attributes.get(SpanAttributes.LLM_LATENCY_TIME_IN_SCHEDULER
) == metrics.scheduler_time
1 change: 1 addition & 0 deletions tests/worker/test_model_runner.py
@@ -24,6 +24,7 @@ def _create_model_runner(model: str, *args, **kwargs) -> ModelRunner:
load_config=engine_config.load_config,
lora_config=engine_config.lora_config,
prompt_adapter_config=engine_config.prompt_adapter_config,
observability_config=engine_config.observability_config,
is_driver_worker=True,
)
return model_runner
15 changes: 15 additions & 0 deletions vllm/config.py
@@ -1656,11 +1656,26 @@ class ObservabilityConfig:
"""Configuration for observability."""
otlp_traces_endpoint: Optional[str] = None

# Collecting detailed timing information for each request can be expensive.

# If set, collects the model forward time for the request.
collect_model_forward_time: bool = False

# If set, collects the model execute time for the request.
collect_model_execute_time: bool = False

def __post_init__(self):
if not is_otel_installed() and self.otlp_traces_endpoint is not None:
raise ValueError("OpenTelemetry packages must be installed before "
"configuring 'otlp_traces_endpoint'")

if ((self.collect_model_forward_time
or self.collect_model_execute_time)
and self.otlp_traces_endpoint is None):
raise ValueError(
"collect_model_forward_time or collect_model_execute_time "
"requires --otlp-traces-endpoint to be set.")


@dataclass(frozen=True)
class EngineConfig:
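The two new ObservabilityConfig flags above are validated in __post_init__ together with the trace endpoint. A minimal sketch of the intended usage, assuming the OpenTelemetry packages are installed and using a placeholder collector URL:

from vllm.config import ObservabilityConfig

# Requesting detailed timing without a trace endpoint is rejected.
try:
    ObservabilityConfig(collect_model_forward_time=True)
except ValueError as err:
    print(err)  # requires --otlp-traces-endpoint to be set

# With an endpoint configured, either or both flags may be enabled.
cfg = ObservabilityConfig(
    otlp_traces_endpoint="http://localhost:4317",  # placeholder endpoint
    collect_model_forward_time=True,
    collect_model_execute_time=True,
)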
12 changes: 12 additions & 0 deletions vllm/core/scheduler.py
@@ -1032,6 +1032,7 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
# such as self.running, self.swapped, and self.waiting.
scheduler_outputs = self._schedule()
now = time.time()
scheduler_start_time = time.perf_counter()

if not self.cache_config.enable_prefix_caching:
common_computed_block_nums = []
@@ -1127,6 +1128,17 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:

self._seq_group_metadata_cache.reset()

scheduler_time = time.perf_counter() - scheduler_start_time
# Add this scheduler time to all the sequences that are currently
# running. This helps estimate whether the scheduler is a significant
# component of the e2e latency.
for seq_group in self.running:
if seq_group is not None and seq_group.metrics is not None:
if seq_group.metrics.scheduler_time is not None:
seq_group.metrics.scheduler_time += scheduler_time
else:
seq_group.metrics.scheduler_time = scheduler_time

return seq_group_metadata_list, scheduler_outputs

def fork_seq(self, parent_seq: Sequence, child_seq: Sequence) -> None:
33 changes: 32 additions & 1 deletion vllm/engine/arg_utils.py
@@ -20,6 +20,8 @@

logger = init_logger(__name__)

ALLOWED_DETAILED_TRACE_MODULES = ["model", "worker", "all"]


def nullable_str(val: str):
if not val or val == "None":
@@ -117,6 +119,7 @@ class EngineArgs:
disable_logprobs_during_spec_decoding: Optional[bool] = None

otlp_traces_endpoint: Optional[str] = None
collect_detailed_traces: Optional[str] = None

def __post_init__(self):
if self.tokenizer is None:
@@ -660,6 +663,16 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
type=str,
default=None,
help='Target URL to which OpenTelemetry traces will be sent.')
parser.add_argument(
'--collect-detailed-traces',
type=str,
default=None,
help="Valid choices are " +
",".join(ALLOWED_DETAILED_TRACE_MODULES) +
". It makes sense to set this only if --otlp-traces-endpoint is"
" set. If set, it will collect detailed traces for the specified "
"modules. This involves use of possibly costly and/or blocking "
"operations and hence might have a performance impact.")

return parser

@@ -852,8 +865,26 @@ def create_engine_config(self, ) -> EngineConfig:
decoding_config = DecodingConfig(
guided_decoding_backend=self.guided_decoding_backend)

detailed_trace_modules = []
if self.collect_detailed_traces is not None:
detailed_trace_modules = self.collect_detailed_traces.split(",")
for m in detailed_trace_modules:
if m not in ALLOWED_DETAILED_TRACE_MODULES:
raise ValueError(
f"Invalid module {m} in collect_detailed_traces. "
f"Valid modules are {ALLOWED_DETAILED_TRACE_MODULES}")
if (m == "model"
or m == "all") and self.pipeline_parallel_size > 1:
raise ValueError(
"Collection of detailed traces for the 'model' module is "
"not yet supported with pipeline parallelism.")
observability_config = ObservabilityConfig(
otlp_traces_endpoint=self.otlp_traces_endpoint)
otlp_traces_endpoint=self.otlp_traces_endpoint,
collect_model_forward_time="model" in detailed_trace_modules
or "all" in detailed_trace_modules,
collect_model_execute_time="worker" in detailed_trace_modules
or "all" in detailed_trace_modules,
)

if (model_config.get_sliding_window() is not None
and scheduler_config.chunked_prefill_enabled
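End to end, --collect-detailed-traces flows through EngineArgs into the ObservabilityConfig fields above. A hedged sketch of that mapping (the model name and endpoint are placeholders; it assumes the OpenTelemetry packages and a reachable OTLP collector are available):

from vllm.engine.arg_utils import EngineArgs

args = EngineArgs(
    model="facebook/opt-125m",                     # placeholder model
    otlp_traces_endpoint="http://localhost:4317",  # placeholder collector URL
    collect_detailed_traces="model,worker",        # same effect as "all" here
)
engine_config = args.create_engine_config()
obs = engine_config.observability_config
assert obs.collect_model_forward_time   # enabled because "model" was requested
assert obs.collect_model_execute_time   # enabled because "worker" was requested

On the command line, the equivalent is passing --otlp-traces-endpoint together with --collect-detailed-traces model,worker (or all).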
29 changes: 29 additions & 0 deletions vllm/engine/llm_engine.py
@@ -267,6 +267,7 @@ def __init__(
speculative_config=speculative_config,
load_config=load_config,
prompt_adapter_config=prompt_adapter_config,
observability_config=self.observability_config,
)

if not self.model_config.embedding_mode:
@@ -1183,6 +1184,22 @@ def _process_model_outputs(
seq_group = scheduled_seq_group.seq_group
seq_group.update_num_computed_tokens(
scheduled_seq_group.token_chunk_size)
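# Per-step timings reported in each SamplerOutput are accumulated onto
# the sequence group, so the final metrics cover every step in which
# this request was part of the batch.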
if output is not None and len(output) > 0:
for o in output:
if (isinstance(o, SamplerOutput)
and seq_group.metrics is not None):
if seq_group.metrics.model_forward_time is not None:
seq_group.metrics.model_forward_time += (
o.model_forward_time)
else:
seq_group.metrics.model_forward_time = (
o.model_forward_time)
if seq_group.metrics.model_execute_time is not None:
seq_group.metrics.model_execute_time += (
o.model_execute_time)
else:
seq_group.metrics.model_execute_time = (
o.model_execute_time)
if self.model_config.embedding_mode:
self._process_sequence_group_outputs(seq_group, outputs)
continue
@@ -1575,6 +1592,18 @@ def create_trace_span(self, seq_group: SequenceGroup) -> None:
seq_span.set_attribute(
SpanAttributes.LLM_LATENCY_TIME_TO_FIRST_TOKEN, ttft)
seq_span.set_attribute(SpanAttributes.LLM_LATENCY_E2E, e2e_time)
if metrics.scheduler_time is not None:
seq_span.set_attribute(
SpanAttributes.LLM_LATENCY_TIME_IN_SCHEDULER,
metrics.scheduler_time)
if metrics.model_forward_time is not None:
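# The accumulated forward time is tracked in milliseconds; dividing by
# 1000 reports the span attribute in seconds, consistent with the other
# latency attributes.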
seq_span.set_attribute(
SpanAttributes.LLM_LATENCY_TIME_IN_MODEL_FORWARD,
metrics.model_forward_time / 1000.0)
if metrics.model_execute_time is not None:
seq_span.set_attribute(
SpanAttributes.LLM_LATENCY_TIME_IN_MODEL_EXECUTE,
metrics.model_execute_time)

def is_encoder_decoder_model(self):
return self.model_config.is_encoder_decoder_model
7 changes: 4 additions & 3 deletions vllm/executor/executor_base.py
@@ -2,8 +2,8 @@
from typing import List, Optional, Set, Tuple

from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, MultiModalConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig,
ModelConfig, MultiModalConfig, ObservabilityConfig,
ParallelConfig, PromptAdapterConfig, SchedulerConfig,
SpeculativeConfig)
from vllm.lora.request import LoRARequest
from vllm.prompt_adapter.request import PromptAdapterRequest
@@ -32,6 +32,7 @@ def __init__(
multimodal_config: Optional[MultiModalConfig],
speculative_config: Optional[SpeculativeConfig],
prompt_adapter_config: Optional[PromptAdapterConfig],
observability_config: Optional[ObservabilityConfig],
) -> None:
self.model_config = model_config
self.cache_config = cache_config
@@ -43,7 +44,7 @@ def __init__(
self.multimodal_config = multimodal_config
self.speculative_config = speculative_config
self.prompt_adapter_config = prompt_adapter_config

self.observability_config = observability_config
self._init_executor()

@abstractmethod
1 change: 1 addition & 0 deletions vllm/executor/gpu_executor.py
@@ -60,6 +60,7 @@ def _get_worker_kwargs(
prompt_adapter_config=self.prompt_adapter_config,
is_driver_worker=(not self.parallel_config)
or (rank % self.parallel_config.tensor_parallel_size == 0),
observability_config=self.observability_config,
)

def _get_create_worker_kwargs(
17 changes: 17 additions & 0 deletions vllm/sequence.py
@@ -92,13 +92,23 @@ class RequestMetrics:
first_token_time: The time when the first token was generated.
time_in_queue: The time the request spent in the queue.
finished_time: The time when the request was finished.
scheduler_time: The time spent in the scheduler while this request was
being scheduled.
model_forward_time: The time spent in the model forward pass when this
request was in the batch.
model_execute_time: The time spent in the model execute function. This
will include model forward, block/sync across
workers, cpu-gpu sync time and sampling time.
"""
arrival_time: float
last_token_time: float
first_scheduled_time: Optional[float]
first_token_time: Optional[float]
time_in_queue: Optional[float]
finished_time: Optional[float] = None
scheduler_time: Optional[float] = None
model_forward_time: Optional[float] = None
model_execute_time: Optional[float] = None


class SequenceData:
@@ -968,6 +978,13 @@ class SamplerOutput:
# Optional last hidden states from the model.
hidden_states: Optional[torch.Tensor] = None

# Time taken in the forward pass for this step across all workers
model_forward_time: Optional[float] = None

# Time taken in the model execute function. This will include model forward,
# block/sync across workers, cpu-gpu sync time and sampling time.
model_execute_time: Optional[float] = None

def __getitem__(self, idx: int):
return self.outputs[idx]

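For downstream consumers, the three new RequestMetrics fields can be read alongside the existing latency numbers once a request finishes. A hedged sketch (it assumes the metrics object is surfaced on the request output, as the existing fields are; note that model_forward_time is tracked in milliseconds while scheduler_time and model_execute_time are in seconds, matching the span conversion above):

def summarize_request_timing(metrics) -> str:
    """Format the detailed timing fields of a RequestMetrics instance."""
    parts = []
    if metrics.scheduler_time is not None:
        parts.append(f"scheduler={metrics.scheduler_time:.4f}s")
    if metrics.model_forward_time is not None:
        parts.append(f"forward={metrics.model_forward_time:.1f}ms")
    if metrics.model_execute_time is not None:
        parts.append(f"execute={metrics.model_execute_time:.4f}s")
    return ", ".join(parts) or "detailed timing not collected"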
6 changes: 4 additions & 2 deletions vllm/spec_decode/draft_model_runner.py
@@ -23,8 +23,8 @@
FLASHINFER_WORKSPACE_BUFFER_SIZE = 0

from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, MultiModalConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig)
ModelConfig, MultiModalConfig, ObservabilityConfig,
ParallelConfig, PromptAdapterConfig, SchedulerConfig)
from vllm.logger import init_logger
from vllm.multimodal import MultiModalInputs
from vllm.sequence import (ExecuteModelRequest, IntermediateTensors,
@@ -69,6 +69,7 @@ def __init__(
multimodal_config: Optional[MultiModalConfig] = None,
prompt_adapter_config: Optional[PromptAdapterConfig] = None,
return_hidden_states: bool = False,
observability_config: Optional[ObservabilityConfig] = None,
):
if return_hidden_states:
raise ValueError(
@@ -88,6 +89,7 @@ def __init__(
multimodal_config=multimodal_config,
prompt_adapter_config=prompt_adapter_config,
return_hidden_states=return_hidden_states,
observability_config=observability_config,
)

self.flashinfer_decode_workspace_buffer = None
8 changes: 5 additions & 3 deletions vllm/spec_decode/target_model_runner.py
@@ -1,8 +1,8 @@
from typing import List, Optional

from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, MultiModalConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig)
ModelConfig, MultiModalConfig, ObservabilityConfig,
ParallelConfig, PromptAdapterConfig, SchedulerConfig)
from vllm.sequence import SequenceGroupMetadata
from vllm.worker.model_runner import (ModelInputForGPUWithSamplingMetadata,
ModelRunner)
@@ -32,7 +32,8 @@ def __init__(self,
is_driver_worker: bool = False,
prompt_adapter_config: Optional[PromptAdapterConfig] = None,
multimodal_config: Optional[MultiModalConfig] = None,
return_hidden_states: bool = False):
return_hidden_states: bool = False,
observability_config: Optional[ObservabilityConfig] = None):
# An internal boolean member variable to indicate if token log
# probabilities are needed or not.
self.disable_logprobs = True
@@ -49,6 +50,7 @@ def __init__(self,
multimodal_config=multimodal_config,
prompt_adapter_config=prompt_adapter_config,
return_hidden_states=return_hidden_states,
observability_config=observability_config,
)

def prepare_model_input(
6 changes: 6 additions & 0 deletions vllm/tracing.py
@@ -92,6 +92,12 @@ class SpanAttributes(BaseSpanAttributes):
LLM_LATENCY_TIME_IN_QUEUE = "gen_ai.latency.time_in_queue"
LLM_LATENCY_TIME_TO_FIRST_TOKEN = "gen_ai.latency.time_to_first_token"
LLM_LATENCY_E2E = "gen_ai.latency.e2e"
LLM_LATENCY_TIME_IN_SCHEDULER = "gen_ai.latency.time_in_scheduler"
# Time taken in the forward pass for this request across all workers
LLM_LATENCY_TIME_IN_MODEL_FORWARD = "gen_ai.latency.time_in_model_forward"
# Time taken in the model execute function. This will include model
# forward, block/sync across workers, cpu-gpu sync time and sampling time.
LLM_LATENCY_TIME_IN_MODEL_EXECUTE = "gen_ai.latency.time_in_model_execute"


def contains_trace_headers(headers: Mapping[str, str]) -> bool:
8 changes: 5 additions & 3 deletions vllm/worker/embedding_model_runner.py
@@ -4,8 +4,8 @@
import torch

from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, MultiModalConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig)
ModelConfig, MultiModalConfig, ObservabilityConfig,
ParallelConfig, PromptAdapterConfig, SchedulerConfig)
from vllm.logger import init_logger
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.multimodal import MultiModalInputs
@@ -45,6 +45,7 @@ def __init__(
is_driver_worker: bool = False,
prompt_adapter_config: Optional[PromptAdapterConfig] = None,
multimodal_config: Optional[MultiModalConfig] = None,
observability_config: Optional[ObservabilityConfig] = None,
):
super().__init__(model_config,
parallel_config,
@@ -56,7 +57,8 @@ def __init__(
kv_cache_dtype=kv_cache_dtype,
is_driver_worker=is_driver_worker,
prompt_adapter_config=prompt_adapter_config,
multimodal_config=multimodal_config)
multimodal_config=multimodal_config,
observability_config=observability_config)

@torch.inference_mode()
def execute_model(
5 changes: 3 additions & 2 deletions vllm/worker/enc_dec_model_runner.py
@@ -10,8 +10,8 @@
get_global_forced_attn_backend,
global_force_attn_backend)
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, MultiModalConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig)
ModelConfig, MultiModalConfig, ObservabilityConfig,
ParallelConfig, PromptAdapterConfig, SchedulerConfig)
from vllm.inputs import INPUT_REGISTRY
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
@@ -82,6 +82,7 @@ def __init__(
is_driver_worker: bool = False,
prompt_adapter_config: Optional[PromptAdapterConfig] = None,
multimodal_config: Optional[MultiModalConfig] = None,
observability_config: Optional[ObservabilityConfig] = None,
):
'''
EncoderDecoderModelRunner constructor.
