From f69bdea94c124d4625a79d5ef903df9b6b448dc2 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Tue, 12 Nov 2024 00:16:04 +0000 Subject: [PATCH 01/33] prototype tpu on v1 --- vllm/attention/selector.py | 10 +- vllm/config.py | 2 + vllm/utils.py | 3 + vllm/v1/attention/backends/flash_attn.py | 3 +- vllm/v1/attention/backends/pallas.py | 304 ++++++++ vllm/v1/core/scheduler.py | 9 +- vllm/v1/engine/llm_engine.py | 8 +- vllm/v1/executor/tpu_executor.py | 76 ++ vllm/v1/worker/tpu_model_runner.py | 898 +++++++++++++++++++++++ vllm/v1/worker/tpu_worker.py | 188 +++++ 10 files changed, 1495 insertions(+), 6 deletions(-) create mode 100644 vllm/v1/attention/backends/pallas.py create mode 100644 vllm/v1/executor/tpu_executor.py create mode 100644 vllm/v1/worker/tpu_model_runner.py create mode 100644 vllm/v1/worker/tpu_worker.py diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 664707e9dc65d..5f3745522aec5 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -25,6 +25,7 @@ class _Backend(enum.Enum): FLASHINFER = enum.auto() HPU_ATTN = enum.auto() PALLAS = enum.auto() + PALLAS_VLLM_V1 = enum.auto() IPEX = enum.auto() NO_ATTENTION = enum.auto() @@ -140,6 +141,10 @@ def _cached_get_attn_backend( from vllm.v1.attention.backends.flash_attn import ( # noqa: F401 FlashAttentionBackend as FlashAttentionBackendV1) return FlashAttentionBackendV1 + if backend == _Backend.PALLAS_VLLM_V1: + from vllm.v1.attention.backends.pallas import ( # noqa: F401 + PallasAttentionBackend as PallasAttentionBackendV1) + return PallasAttentionBackendV1 if backend == _Backend.XFORMERS: logger.info("Using XFormers backend.") from vllm.attention.backends.xformers import ( # noqa: F401 @@ -232,6 +237,8 @@ def which_attn_to_use(head_size: int, return _Backend.IPEX if current_platform.is_tpu(): + if selected_backend == _Backend.PALLAS_VLLM_V1: + return _Backend.PALLAS_VLLM_V1 if selected_backend != _Backend.PALLAS: logger.info("Cannot use %s backend on TPU.", selected_backend) return _Backend.PALLAS @@ -252,7 +259,8 @@ def which_attn_to_use(head_size: int, return _Backend.HPU_ATTN if use_v1: - return _Backend.FLASH_ATTN_VLLM_V1 + # return _Backend.FLASH_ATTN_VLLM_V1 + return _Backend.PALLAS_VLLM_V1 # FlashAttn in NVIDIA GPUs. if selected_backend == _Backend.FLASH_ATTN: diff --git a/vllm/config.py b/vllm/config.py index f9b230e1bc688..d8101ad9393a8 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1222,6 +1222,8 @@ def __init__(self, device: str = "auto") -> None: # Some device types require processing inputs on CPU if self.device_type in ["neuron", "openvino"]: self.device = torch.device("cpu") + # Device initialization should happen after initializing the + # distributed runtime. 
elif self.device_type in ["tpu"]: self.device = None else: diff --git a/vllm/utils.py b/vllm/utils.py index 1b02cbff79f78..4a548496a4785 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -728,6 +728,9 @@ def is_pin_memory_available() -> bool: elif current_platform.is_hpu(): print_warning_once("Pin memory is not supported on HPU.") return False + elif current_platform.is_tpu(): + print_warning_once("Pin memory is not supported on TPU.") + return False elif current_platform.is_cpu() or current_platform.is_openvino(): return False return True diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index e73a1e60b2730..057f281c220a6 100644 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -8,7 +8,6 @@ AttentionMetadata, AttentionType) from vllm.forward_context import get_forward_context from vllm.utils import direct_register_custom_op -from vllm.vllm_flash_attn import flash_attn_varlen_func class FlashAttentionBackend(AttentionBackend): @@ -202,6 +201,8 @@ def unified_v1_flash_attention( v_scale, ) + from vllm.vllm_flash_attn import flash_attn_varlen_func + attn_output = flash_attn_varlen_func( q=query[:num_actual_tokens], k=key_cache, diff --git a/vllm/v1/attention/backends/pallas.py b/vllm/v1/attention/backends/pallas.py new file mode 100644 index 0000000000000..a6cf7cbaa802b --- /dev/null +++ b/vllm/v1/attention/backends/pallas.py @@ -0,0 +1,304 @@ +"""Attention layer with FlashAttention.""" +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Type + +import torch +import torch_xla + +from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionMetadata, AttentionType) + + +class PallasAttentionBackend(AttentionBackend): + + @staticmethod + def get_supported_head_sizes() -> List[int]: + return [128] + + @staticmethod + def get_name() -> str: + return "pallas-vllm-v1" + + @staticmethod + def get_impl_cls() -> Type["PallasAttentionImpl"]: + return PallasAttentionImpl + + @staticmethod + def get_metadata_cls() -> Type["PallasAttentionMetadata"]: + return PallasAttentionMetadata + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + if block_size % 16 != 0: + raise ValueError("Block size must be a multiple of 16.") + return (2, num_blocks, block_size, num_kv_heads, head_size) + + +@dataclass +class PallasAttentionMetadata: + + # NOTE(sang): Definition of context_len, query_len, and seq_len. + # |---------- N-1 iteration --------| + # |---------------- N iteration ---------------------| + # |- tokenA -|......................|-- newTokens ---| + # |---------- context_len ----------| + # |-------------------- seq_len ---------------------| + # |-- query_len ---| + + num_actual_tokens: int # Number of tokens excluding padding. 
+ max_query_len: int + query_start_loc: torch.Tensor + max_seq_len: int + seq_start_loc: torch.Tensor + block_table: torch.Tensor + slot_mapping: torch.Tensor + + +class PallasAttentionImpl(AttentionImpl): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[Dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, + ) -> None: + if blocksparse_params is not None: + raise ValueError( + "FlashAttention does not support block-sparse attention.") + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_kv_heads + + if head_size % 128 != 0: + raise NotImplementedError("Head size must be a multiple of 128.") + if alibi_slopes is not None: + raise NotImplementedError("Alibi slopes is not supported.") + if sliding_window is not None: + raise NotImplementedError("Sliding window is not supported.") + if kv_cache_dtype != "auto": + raise NotImplementedError("FP8 KV cache dtype is not supported.") + if blocksparse_params is not None: + raise NotImplementedError("Blocksparse is not supported.") + if logits_soft_cap is not None: + raise NotImplementedError( + "Attention logits soft-capping is not supported.") + + if torch_xla.tpu.version() < 4: + raise NotImplementedError("TPU version must be 4 or higher.") + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + + support_head_sizes = PallasAttentionBackend.get_supported_head_sizes() + if head_size not in support_head_sizes: + raise ValueError( + f"Head size {head_size} is not supported by PallasAttention. " + f"Supported head sizes are: {support_head_sizes}.") + + self.megacore_mode = None + tpu_env = torch_xla.tpu.get_tpu_env() + tpu_type = (tpu_env.get("ACCELERATOR_TYPE", None) + or tpu_env.get("TYPE", None) + or tpu_env.get("TPU_ACCELERATOR_TYPE", None)) + assert tpu_type is not None + tpu_type = tpu_type.lower() + + if (("lite" not in tpu_type) and ("v6" not in tpu_type)): + if self.num_kv_heads % 2 == 0: + self.megacore_mode = "kv_head" + else: + # NOTE(woosuk): If the batch size is not a multiple of 2, the + # megacore mode will be None. + self.megacore_mode = "batch" + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: PallasAttentionMetadata, + k_scale: float = 1.0, + v_scale: float = 1.0, + attn_type: AttentionType = AttentionType.DECODER, + ) -> torch.Tensor: + """Forward pass with FlashAttention. + + Args: + query: shape = [num_tokens, num_heads * head_size] + key: shape = [num_tokens, num_kv_heads * head_size] + value: shape = [num_tokens, num_kv_heads * head_size] + kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size] + attn_metadata: Metadata for attention. 
+ Returns: + shape = [num_tokens, num_heads * head_size] + """ + + assert k_scale == 1.0 and v_scale == 1.0, ( + "key/v_scale is not supported in PallasAttentionImpl.") + + if attn_type != AttentionType.DECODER: + raise NotImplementedError("Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "PallasAttentionImpl") + + batch_size, seq_len, hidden_size = query.shape + query = query.view(batch_size, seq_len, self.num_heads, self.head_size) + key = key.view(batch_size, seq_len, self.num_kv_heads, self.head_size) + value = value.view(batch_size, seq_len, self.num_kv_heads, + self.head_size) + + # Write to KV cache. + if kv_cache[0].numel() > 0: + slot_mapping = attn_metadata.slot_mapping + key_cache, value_cache = kv_cache + write_to_kv_cache(key, value, key_cache, value_cache, slot_mapping) + + query = query * self.scale + if attn_metadata.num_prefills > 0: + assert seq_len % 16 == 0, ( + "Pallas FlashAttention kernel requires seq_len to be a " + f"multiple of 16 but got {seq_len}") + + # Handle GQA/MQA. + if self.num_kv_heads != self.num_heads: + key = key.repeat_interleave(self.num_queries_per_kv, dim=-2) + key = key.view(batch_size, seq_len, self.num_heads, + self.head_size) + value = value.repeat_interleave(self.num_queries_per_kv, + dim=-2) + value = value.view(batch_size, seq_len, self.num_heads, + self.head_size) + # FlashAttention requires [batch_size, num_heads, seq_len, d_model] + # while the input is [batch_size, seq_len, num_heads, d_model]. + # Permute the input to match the required format. + output = torch.ops.xla.flash_attention( + query.permute(0, 2, 1, 3), + key.permute(0, 2, 1, 3), + value.permute(0, 2, 1, 3), + True, + ) + output = output.permute(0, 2, 1, 3) + else: + # Decoding run. + assert kv_cache[0].numel() > 0 + query = query.squeeze(dim=1) + pages_per_compute_block = 16 # TODO(woosuk): Tune this value. + + assert attn_metadata.block_tables is not None + assert attn_metadata.context_lens is not None + # NOTE(woosuk): The PagedAttention Pallas kernel stores the entire + # block table in SMEM. Therefore, if the block table is too large, + # the kernel compilation will fail. To avoid this, we split the + # batch dimension into smaller chunks and run the kernel multiple + # times. + MAX_SMEM_USAGE = 512 * 1024 + size_per_seq = 4 * attn_metadata.block_tables.shape[1] + max_num_seq = MAX_SMEM_USAGE // size_per_seq + + if batch_size <= max_num_seq: + output = paged_attention( + query, + key_cache, + value_cache, + attn_metadata.context_lens, + attn_metadata.block_tables, + pages_per_compute_block, + self.megacore_mode, + ) + else: + chunk_size = max_num_seq + # Make sure the chunk size is a multiple of 2. + chunk_size = chunk_size // 2 * 2 + num_chunks = (batch_size + chunk_size - 1) // chunk_size + + output = torch.empty_like(query) + for chunk_idx in range(num_chunks): + chunk_start = chunk_idx * chunk_size + chunk_end = chunk_start + chunk_size + # NOTE(woosuk): We skip this line because it causes Dynamo + # compilation error. Instead, we rely on the slice operation + # to handle the out-of-bound case. + # chunk_end = min(chunk_end, batch_size) + chunk_output = paged_attention( + query[chunk_start:chunk_end], + key_cache, + value_cache, + attn_metadata.context_lens[chunk_start:chunk_end], + attn_metadata.block_tables[chunk_start:chunk_end], + pages_per_compute_block, + self.megacore_mode, + ) + output[chunk_start:chunk_end] = chunk_output + + # Reshape the output tensor. 
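Before the final reshape below, it helps to spell out how the SMEM note above turns into chunk sizes. This is a minimal pure-Python sketch only, assuming the 512 KiB SMEM budget and 4-byte block-table entries used in the code above; the clamping of the last chunk is simplified here, whereas the kernel dispatch above deliberately leaves the end index unclamped for Dynamo.

def split_batch_for_smem(batch_size, blocks_per_seq, smem_budget=512 * 1024):
    # Each sequence contributes one row of 4-byte block-table entries to SMEM.
    size_per_seq = 4 * blocks_per_seq
    max_num_seq = smem_budget // size_per_seq
    # Keep chunks even so the "batch" megacore mode stays usable
    # (floor of 2 guards against degenerate budgets in this sketch).
    chunk_size = max(2, max_num_seq // 2 * 2)
    return [(start, min(start + chunk_size, batch_size))
            for start in range(0, batch_size, chunk_size)]

# With 128 blocks per sequence (e.g. 2048-token contexts, block_size=16):
# split_batch_for_smem(4096, 128)
# -> [(0, 1024), (1024, 2048), (2048, 3072), (3072, 4096)]

Each (start, end) pair corresponds to one paged_attention launch in the loop above.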
+ return output.reshape(batch_size, seq_len, hidden_size) + + +def write_to_kv_cache( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, +) -> None: + torch.ops.xla.dynamo_set_buffer_donor_(key_cache, True) + torch.ops.xla.dynamo_set_buffer_donor_(value_cache, True) + + key = key.flatten(0, 2) + value = value.flatten(0, 2) + key_cache = key_cache.flatten(0, 2) + value_cache = value_cache.flatten(0, 2) + key_cache.index_copy_(0, slot_mapping, key) + value_cache.index_copy_(0, slot_mapping, value) + + +def paged_attention( + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + context_lens: torch.Tensor, + block_tables: torch.Tensor, + pages_per_compute_block: int, + megacore_mode: Optional[str], +) -> torch.Tensor: + batch_size = query.shape[0] + if megacore_mode == "batch" and batch_size % 2 != 0: + megacore_mode = None + else: + megacore_mode = megacore_mode + + # NOTE(woosuk): A temporary workaround to avoid the error: + # "xla::paged_attention() Expected a value of type 'str' for + # argument 'megacore_mode' but instead found type 'NoneType'." + if megacore_mode is not None: + output = torch.ops.xla.paged_attention( + query, + key_cache, + value_cache, + context_lens, + block_tables, + pages_per_compute_block, + megacore_mode=megacore_mode, + ) + else: + output = torch.ops.xla.paged_attention( + query, + key_cache, + value_cache, + context_lens, + block_tables, + pages_per_compute_block, + ) + return output diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index a60f8b8138ecf..3fe486e7a1cf5 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -146,7 +146,14 @@ def schedule(self) -> "SchedulerOutput": num_computed_tokens -= 1 num_new_tokens = 1 computed_blocks.pop() - num_new_tokens = min(num_new_tokens, token_budget) + + # Disabled Chunking. 
+ if not self.scheduler_config.chunked_prefill_enabled: + if num_new_tokens > token_budget: + break + else: + num_new_tokens = min(num_new_tokens, token_budget) + assert num_new_tokens > 0 new_blocks = self.kv_cache_manager.allocate_slots( request, num_new_tokens, computed_blocks) diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index 38d95ab44bb90..8edff5cb446a4 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -21,7 +21,8 @@ BaseTokenizerGroup, init_tokenizer_from_configs) from vllm.usage.usage_lib import UsageContext from vllm.v1.core.scheduler import Scheduler -from vllm.v1.executor.gpu_executor import GPUExecutor +# from vllm.v1.executor.gpu_executor import GPUExecutor +from vllm.v1.executor.tpu_executor import TPUExecutor from vllm.v1.request import Request, RequestStatus from vllm.v1.tokenizer.detokenizer import Detokenizer, DetokenizerInputs from vllm.version import __version__ as VLLM_VERSION @@ -34,7 +35,7 @@ class LLMEngine: def __init__( self, vllm_config: VllmConfig, - executor_class: Type[GPUExecutor], + executor_class: Type[TPUExecutor], log_stats: bool, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, stat_loggers: Optional[Dict[str, StatLoggerBase]] = None, @@ -489,7 +490,8 @@ def get_lora_config(self) -> LoRAConfig: @classmethod def _get_executor_cls(cls, engine_config: VllmConfig): - return GPUExecutor + # return GPUExecutor + return TPUExecutor def is_tracing_enabled(self) -> bool: return False diff --git a/vllm/v1/executor/tpu_executor.py b/vllm/v1/executor/tpu_executor.py new file mode 100644 index 0000000000000..66fa2cfed91e8 --- /dev/null +++ b/vllm/v1/executor/tpu_executor.py @@ -0,0 +1,76 @@ +from typing import Optional, Tuple + +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.utils import get_distributed_init_method, get_ip, get_open_port +from vllm.v1.outputs import ModelRunnerOutput +from vllm.v1.worker.tpu_worker import TPUWorker + +logger = init_logger(__name__) + + +class TPUExecutor: + + def __init__(self, vllm_config: VllmConfig) -> None: + self.vllm_config = vllm_config + self.model_config = vllm_config.model_config + self.cache_config = vllm_config.cache_config + self.lora_config = vllm_config.lora_config + self.load_config = vllm_config.load_config + self.parallel_config = vllm_config.parallel_config + self.scheduler_config = vllm_config.scheduler_config + self.device_config = vllm_config.device_config + self.speculative_config = vllm_config.speculative_config + self.prompt_adapter_config = vllm_config.prompt_adapter_config + self.observability_config = vllm_config.observability_config + + self.worker = self._create_worker() + self.worker.initialize() + self.worker.load_model() + + def _create_worker( + self, + local_rank: int = 0, + rank: int = 0, + distributed_init_method: Optional[str] = None + ) -> TPUWorker: + """Return worker init args for a given rank.""" + + if distributed_init_method is None: + distributed_init_method = get_distributed_init_method( + get_ip(), get_open_port()) + + return TPUWorker( + vllm_config=self.vllm_config, + local_rank=local_rank, + rank=rank, + distributed_init_method=distributed_init_method, + ) + + def determine_num_available_blocks(self) -> Tuple[int, int]: + """Determine the number of available KV blocks by invoking the + underlying worker. 
+ """ + return self.worker.determine_num_available_blocks() + + def initialize_cache(self, num_gpu_blocks: int) -> None: + """Initialize the KV cache by invoking the underlying worker. + """ + # NOTE: This is logged in the executor because there can be >1 worker + # with other executors. We could log in the engine level, but work + # remains to abstract away the device for non-GPU configurations. + logger.info("# TPU blocks: %d", num_gpu_blocks) + self.worker.initialize_cache(num_gpu_blocks) + self.worker.compile_or_warm_up_model() + + def execute_model( + self, + scheduler_output, + ) -> ModelRunnerOutput: + output = self.worker.execute_model(scheduler_output) + return output + + def check_health(self) -> None: + # GPUExecutor will always be healthy as long as + # it's running. + return diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py new file mode 100644 index 0000000000000..54bc579b561cc --- /dev/null +++ b/vllm/v1/worker/tpu_model_runner.py @@ -0,0 +1,898 @@ +import os +import time +from dataclasses import dataclass +from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple +from unittest.mock import patch + +import numpy as np +import torch +import torch.distributed +import torch.nn as nn +import torch_xla.core.xla_model as xm +import torch_xla.runtime as xr + +from vllm import envs +from vllm.compilation.compile_context import set_compile_context +from vllm.compilation.config import CompilationConfig +from vllm.compilation.levels import CompilationLevel +from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher +from vllm.config import VllmConfig +from vllm.forward_context import set_forward_context +from vllm.logger import init_logger +from vllm.model_executor.model_loader import get_model +from vllm.multimodal import MultiModalDataDict +from vllm.plugins import set_compilation_config +from vllm.sampling_params import SamplingParams, SamplingType +from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, cdiv, + is_pin_memory_available) +from vllm.v1.attention.backends.flash_attn import (FlashAttentionBackend, + FlashAttentionMetadata) +from vllm.v1.outputs import ModelRunnerOutput +from vllm.v1.sample.metadata import SamplingMetadata + +if TYPE_CHECKING: + from vllm.v1.core.scheduler import SchedulerOutput + +logger = init_logger(__name__) + +# FIXME(woosuk): A temporary hack to support `n > 1`. +# This can significantly affect the performance if too large. 
+_MAX_NUM_SAMPLES = 128 + + +class TPUModelRunner: + + def __init__( + self, + vllm_config: VllmConfig, + ): + # TODO: use ModelRunnerBase.__init__(self, vllm_config=vllm_config) + self.vllm_config = vllm_config + self.model_config = vllm_config.model_config + self.cache_config = vllm_config.cache_config + self.lora_config = vllm_config.lora_config + self.load_config = vllm_config.load_config + self.parallel_config = vllm_config.parallel_config + self.scheduler_config = vllm_config.scheduler_config + self.device_config = vllm_config.device_config + self.speculative_config = vllm_config.speculative_config + self.prompt_adapter_config = vllm_config.prompt_adapter_config + self.observability_config = vllm_config.observability_config + + model_config = self.model_config + cache_config = self.cache_config + scheduler_config = self.scheduler_config + parallel_config = self.parallel_config + self.device = self.device_config.device + self.pin_memory = is_pin_memory_available() + self.dtype = self.model_config.dtype + if cache_config.cache_dtype == "auto": + self.kv_cache_dtype = self.dtype + else: + self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[ + cache_config.cache_dtype] + + self.sliding_window = model_config.get_sliding_window() + self.block_size = cache_config.block_size + self.max_model_len = model_config.max_model_len + self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size) + self.max_num_tokens = scheduler_config.max_num_batched_tokens + + # Model-related. + self.num_attn_layers = model_config.get_num_attention_layers( + parallel_config) + self.num_kv_heads = model_config.get_num_kv_heads(parallel_config) + self.head_size = model_config.get_head_size() + + # Lazy initialization + # self.model: nn.Module # Set after load_model + self.kv_caches: List[torch.Tensor] = [] + + # Request states. + self.requests: Dict[str, CachedRequestState] = {} + # Persistent batch. + self.input_batch = InputBatch( + max_num_reqs=self.scheduler_config.max_num_seqs, + max_model_len=self.max_model_len, + max_num_blocks_per_req=self.max_num_blocks_per_req, + device=self.device, + pin_memory=self.pin_memory, + ) + + self.input_ids = torch.zeros(self.max_num_tokens, + dtype=torch.int32, + device=self.device) + self.positions = torch.zeros(self.max_num_tokens, + dtype=torch.int64, + device=self.device) + + def _update_states(self, scheduler_output: "SchedulerOutput") -> None: + # Remove stopped requests from the cached states. + # Keep the states of the pre-empted requests. + for req_id in scheduler_output.finished_req_ids: + self.requests.pop(req_id, None) + + # Remove the requests from the persistent batch. + stopped_req_ids = set().union( + scheduler_output.preempted_req_ids, + scheduler_output.finished_req_ids, + ) + removed_req_indices: List[int] = [] + for req_id in stopped_req_ids: + req_index = self.input_batch.remove_request(req_id) + if req_index is not None: + removed_req_indices.append(req_index) + + # Update the states of the running requests. + for req_data in scheduler_output.scheduled_running_reqs: + req_id = req_data.req_id + req_state = self.requests[req_id] + req_index = self.input_batch.req_id_to_index[req_id] + + # Update the num_computed_tokens. + req_state.num_computed_tokens = req_data.num_computed_tokens + self.input_batch.num_computed_tokens_cpu[req_index] = ( + req_data.num_computed_tokens) + + # Update the block table. 
+ num_new_blocks = len(req_data.new_block_ids) + if num_new_blocks == 0: + continue + start_index = len(req_state.block_ids) + end_index = start_index + num_new_blocks + req_state.block_ids.extend(req_data.new_block_ids) + self.input_batch.block_table_cpu[ + req_index, start_index:end_index] = req_data.new_block_ids + + req_ids_to_add: List[str] = [] + # Add new requests to the cached states. + for req_data in scheduler_output.scheduled_new_reqs: + req_id = req_data.req_id + sampling_params = req_data.sampling_params + if sampling_params.sampling_type == SamplingType.RANDOM_SEED: + generator = torch.Generator(device=self.device) + generator.manual_seed(sampling_params.seed) + else: + generator = None + + self.requests[req_id] = CachedRequestState( + req_id=req_id, + prompt_token_ids=req_data.prompt_token_ids, + prompt=req_data.prompt, + multi_modal_data=req_data.multi_modal_data, + sampling_params=sampling_params, + generator=generator, + block_ids=req_data.block_ids, + num_computed_tokens=req_data.num_computed_tokens, + output_token_ids=[], + ) + req_ids_to_add.append(req_id) + + # Update the cached states of the resumed requests. + for req_data in scheduler_output.scheduled_resumed_reqs: + req_id = req_data.req_id + req_state = self.requests[req_id] + + req_state.block_ids = req_data.block_ids + req_state.num_computed_tokens = req_data.num_computed_tokens + req_ids_to_add.append(req_id) + + # Add the new or resumed requests to the persistent batch. + # The smaller empty indices are filled first. + removed_req_indices = sorted(removed_req_indices, reverse=True) + for req_id in req_ids_to_add: + req_state = self.requests[req_id] + if removed_req_indices: + # Fill the empty index. + req_index = removed_req_indices.pop() + else: + # Append to the end. + req_index = None + self.input_batch.add_request(req_state, req_index) + + # Condense the batched states if there are empty indices. + if removed_req_indices: + self.input_batch.condense(removed_req_indices) + + def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): + total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens + assert total_num_scheduled_tokens > 0 + num_reqs = self.input_batch.num_reqs + assert num_reqs > 0 + + # OPTIMIZATION: Start copying the block table first. + # This way, we can overlap the copy with the following CPU operations. + self.input_batch.block_table[:num_reqs].copy_( + self.input_batch.block_table_cpu_tensor[:num_reqs], + non_blocking=True) + + # Get the number of scheduled tokens for each request. + # TODO: The Python loop can be slow. Optimize. + num_scheduled_tokens = [] + max_num_scheduled_tokens = 0 + for req_id in self.input_batch.req_ids[:num_reqs]: + num_tokens = scheduler_output.num_scheduled_tokens[req_id] + num_scheduled_tokens.append(num_tokens) + max_num_scheduled_tokens = max(max_num_scheduled_tokens, + num_tokens) + num_scheduled_tokens = np.array(num_scheduled_tokens, dtype=np.int32) + assert max_num_scheduled_tokens > 0 + + # Get request indices. + # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2] + indices = np.arange(num_reqs) + req_indices = np.repeat(indices, num_scheduled_tokens) + + # Get batched arange. + # E.g., [2, 5, 3] -> [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] + arange_matrix = np.tile(np.arange(max_num_scheduled_tokens), + (num_reqs, 1)) + mask = arange_matrix < num_scheduled_tokens[:, np.newaxis] + arange = arange_matrix[mask] + + # Get positions. 
+ positions = torch.empty((total_num_scheduled_tokens, ), + dtype=torch.int32, + device="cpu", + pin_memory=self.pin_memory) + positions_np = positions.numpy() + np.add(self.input_batch.num_computed_tokens_cpu[req_indices], + arange, + out=positions_np) + + # Get token indices. + # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] + # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2] + # where M is the max_model_len. + token_indices = positions_np + req_indices * self.max_model_len + token_indices = torch.from_numpy(token_indices) + input_ids = torch.empty((total_num_scheduled_tokens, ), + dtype=torch.int32, + device="cpu", + pin_memory=self.pin_memory) + torch.index_select(torch.from_numpy( + self.input_batch.token_ids_cpu).flatten(), + 0, + token_indices, + out=input_ids) + + # Calculate the slot mapping. + block_numbers = self.input_batch.block_table_cpu_tensor.flatten()[ + token_indices // self.block_size] + block_offsets = token_indices % self.block_size + slot_mapping = torch.empty((total_num_scheduled_tokens, ), + dtype=torch.int32, + device="cpu", + pin_memory=self.pin_memory) + torch.add(block_numbers * self.block_size, + block_offsets, + out=slot_mapping) + + # Prepare the attention metadata. + query_start_loc = torch.empty((num_reqs + 1, ), + dtype=torch.int32, + device="cpu", + pin_memory=self.pin_memory) + query_start_loc_np = query_start_loc.numpy() + query_start_loc_np[0] = 0 + np.cumsum(num_scheduled_tokens, out=query_start_loc_np[1:]) + + seq_lens = (self.input_batch.num_computed_tokens_cpu[:num_reqs] + + num_scheduled_tokens) + max_seq_len = seq_lens.max() + seq_start_loc = torch.empty((num_reqs + 1, ), + dtype=torch.int32, + device="cpu", + pin_memory=self.pin_memory) + seq_start_loc_np = seq_start_loc.numpy() + seq_start_loc_np[0] = 0 + np.cumsum(seq_lens, out=seq_start_loc_np[1:]) + + self.input_ids[:total_num_scheduled_tokens].copy_(input_ids, + non_blocking=True) + self.positions[:total_num_scheduled_tokens].copy_(positions, + non_blocking=True) + + query_start_loc = query_start_loc.to(self.device, non_blocking=True) + seq_start_loc = seq_start_loc.to(self.device, non_blocking=True) + slot_mapping = slot_mapping.to(self.device, non_blocking=True).long() + attn_metadata = FlashAttentionMetadata( + num_actual_tokens=total_num_scheduled_tokens, + max_query_len=max_num_scheduled_tokens, + query_start_loc=query_start_loc, + max_seq_len=max_seq_len, + seq_start_loc=seq_start_loc, + block_table=self.input_batch.block_table[:num_reqs], + slot_mapping=slot_mapping, + ) + # NOTE(woosuk): Due to chunked prefills, there can be at most 1 partial + # request in the batch. While we should not sample any token from this + # partial request, we do so for simplicity. We will ignore the sampled + # token from the partial request. + # TODO: Support prompt logprobs. + logits_indices = query_start_loc[1:] - 1 + return attn_metadata, logits_indices + + def _prepare_sampling( + self, + scheduler_output: "SchedulerOutput", + ) -> SamplingMetadata: + skip_copy = True + if (scheduler_output.finished_req_ids + or scheduler_output.preempted_req_ids): + skip_copy = False + if (scheduler_output.scheduled_new_reqs + or scheduler_output.scheduled_resumed_reqs): + skip_copy = False + # Create the sampling metadata. 
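The index bookkeeping in _prepare_inputs above is easiest to check on the toy example from its own comments. The following is a self-contained NumPy sketch, not part of the patch; the max_model_len, block size, and block table are made-up toy values.

import numpy as np

num_scheduled_tokens = np.array([2, 5, 3], dtype=np.int32)
num_reqs, max_n = len(num_scheduled_tokens), num_scheduled_tokens.max()

# [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
req_indices = np.repeat(np.arange(num_reqs), num_scheduled_tokens)

# [2, 5, 3] -> [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
arange_matrix = np.tile(np.arange(max_n), (num_reqs, 1))
arange = arange_matrix[arange_matrix < num_scheduled_tokens[:, None]]

# Positions and flattened token indices (no tokens computed yet).
max_model_len, block_size = 8, 4
num_computed_tokens = np.zeros(num_reqs, dtype=np.int32)
positions = num_computed_tokens[req_indices] + arange
token_indices = positions + req_indices * max_model_len
# -> [0, 1, 8, 9, 10, 11, 12, 16, 17, 18]

# Slot mapping: physical block number * block_size + in-block offset.
block_table = np.array([[0, 1], [2, 3], [4, 5]], dtype=np.int32)
block_numbers = block_table.flatten()[token_indices // block_size]
slot_mapping = block_numbers * block_size + token_indices % block_size
# -> [0, 1, 8, 9, 10, 11, 12, 16, 17, 18]
# (identical to token_indices only because this toy block table is the
#  identity mapping)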
+ sampling_metadata = self.input_batch.make_sampling_metadata(skip_copy) + return sampling_metadata + + + def execute_model( + self, + scheduler_output: "SchedulerOutput", + ) -> ModelRunnerOutput: + self._update_states(scheduler_output) + attn_metadata, logits_indices = self._prepare_inputs(scheduler_output) + num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens + if True: + # Use piecewise CUDA graphs. + # Add padding to the batch size. + num_input_tokens = self._get_padded_batch_size( + num_scheduled_tokens) + else: + # Eager mode. + num_input_tokens = num_scheduled_tokens + + with set_forward_context(attn_metadata): + hidden_states = self.model( + input_ids=self.input_ids[:num_input_tokens], + positions=self.positions[:num_input_tokens], + kv_caches=self.kv_caches, + attn_metadata=None, + ) + hidden_states = hidden_states[:num_scheduled_tokens] + hidden_states = hidden_states[logits_indices] + logits = self.model.compute_logits(hidden_states, None) + + # Sample the next token and get logprobs if needed. + sampling_metadata = self._prepare_sampling(scheduler_output) + sampler_output = self.model.sample( + logits=logits, + sampling_metadata=sampling_metadata, + ) + + # NOTE: CPU-GPU synchronization happens here. + sampled_token_ids = sampler_output.sampled_token_ids.cpu() + sampled_token_ids_list = sampled_token_ids.tolist() + # TODO(woosuk): The following loop can be slow since it iterates over + # the requests one by one. Optimize. + num_reqs = self.input_batch.num_reqs + for i, req_id in enumerate(self.input_batch.req_ids[:num_reqs]): + req_state = self.requests[req_id] + seq_len = (req_state.num_computed_tokens + + scheduler_output.num_scheduled_tokens[req_id]) + assert seq_len <= req_state.num_tokens + if seq_len == req_state.num_tokens: + # Append the sampled token to the output token ids. + token_id = sampled_token_ids_list[i] + self.input_batch.token_ids_cpu[i, seq_len] = token_id + req_state.output_token_ids.append(token_id) + else: + # Ignore the sampled token from the partial request. + # Rewind the generator state as if the token was not sampled. + generator = self.input_batch.generators.get(i) + if generator is not None: + # This relies on cuda-specific torch-internal impl details + generator.set_offset(generator.get_offset() - 4) + + if sampler_output.logprob_token_ids is None: + logprob_token_ids = None + else: + logprob_token_ids = sampler_output.logprob_token_ids.cpu() + if sampler_output.logprobs is None: + logprobs = None + else: + logprobs = sampler_output.logprobs.cpu() + model_runner_output = ModelRunnerOutput( + req_ids=self.input_batch.req_ids[:num_reqs], + req_id_to_index=self.input_batch.req_id_to_index, + sampled_token_ids_cpu=sampled_token_ids, + logprob_token_ids_cpu=logprob_token_ids, + logprobs_cpu=logprobs, + ) + return model_runner_output + + def load_model(self) -> None: + + # NOTE(woosuk): While the executor assigns the TP ranks to the worker + # process, the ranks can be different from the ranks internally assigned + # by the xm runtime. Therefore, there is a mismatch in the rank + # assignment between the gloo (cpu) runtime and the xm (tpu) runtime. + # This is not a problem in linear layers because all-reduce is + # rank-agnostic. However, it matters for all-gather as the ranks + # determine the order of concatenating the output tensors. + # As a workaround, we use the xm's rank assignment only when loading + # the embedding weights. 
+ # xm_tp_rank = xr.global_ordinal() + # with patch( + # "vllm.model_executor.layers.vocab_parallel_embedding." + # "get_tensor_model_parallel_rank", + # return_value=xm_tp_rank): + # model = get_model(vllm_config=self.vllm_config) + model = get_model(vllm_config=self.vllm_config) + model = model.eval() + xm.wait_device_ops() + self.model = ModelWrapper(model) + + def _dummy_run( + self, + batch_size: int, + seq_len: int, + is_prompt: bool + ) -> None: + assert is_prompt + + # use an empty tensor instead of `None`` to force Dynamo to pass + # it by reference, rather by specializing on the value `None`. + # the `dtype` argument does not matter, and we use `float32` as + # a placeholder (it has wide hardware support). + # it is important to create tensors inside the loop, rather than + # multiplying the list, to avoid Dynamo from treating them as + # tensor aliasing. + kv_caches = [(torch.tensor([], dtype=torch.float32, + device=self.device), + torch.tensor([], dtype=torch.float32, + device=self.device)) + for _ in range(self.num_attn_layers)] + + num_total_tokens = batch_size * seq_len + assert num_total_tokens <= self.input_ids.numel() + + slot_mapping = torch.zeros((num_total_tokens, ), + dtype=torch.int64, + device=self.device) + query_start_loc = torch.ones((batch_size + 1, ), + dtype=torch.int32, + device=self.device) + seq_start_loc = torch.empty((batch_size + 1, ), + dtype=torch.int32, + device=self.device) + + attn_metadata = FlashAttentionMetadata( + num_actual_tokens=num_total_tokens, + max_query_len=num_total_tokens, + query_start_loc=query_start_loc, + max_seq_len=num_total_tokens, + seq_start_loc=seq_start_loc, + block_table=self.input_batch.block_table[:batch_size], + slot_mapping=slot_mapping, + ) + + self.model(self.input_ids, + self.positions, + attn_metadata, + kv_caches, + is_prompt=is_prompt) + + + def profile_run(self) -> None: + self._dummy_run(batch_size=1, + seq_len=self.max_num_tokens, + is_prompt=True) + xm.wait_device_ops() + + + def capture_model(self) -> None: + if not self.use_cuda_graph: + logger.warning( + "Skipping CUDA graph capture. Please set " + "VLLM_TORCH_COMPILE_LEVEL=%d to use CUDA graphs.", + CompilationLevel.PIECEWISE) + return + + start_time = time.perf_counter() + start_free_gpu_memory = torch.cuda.mem_get_info()[0] + + with set_forward_context(None): + # Trigger CUDA graph capture for specific shapes. + # Capture the large shapes first so that the smaller shapes + # can reuse the memory pool allocated for the large shapes. + for num_tokens in reversed(self.cudagraph_batch_sizes): + self.model( + self.input_ids[:num_tokens], + self.positions[:num_tokens], + kv_caches=self.kv_caches, + attn_metadata=None, + ) + + end_time = time.perf_counter() + end_free_gpu_memory = torch.cuda.mem_get_info()[0] + elapsed_time = end_time - start_time + cuda_graph_size = start_free_gpu_memory - end_free_gpu_memory + # This usually takes 5~20 seconds. + logger.info("Graph capturing finished in %.0f secs, took %.2f GiB", + elapsed_time, cuda_graph_size / (1 << 30)) + + def initialize_kv_cache(self, num_blocks: int) -> None: + assert len(self.kv_caches) == 0 + kv_cache_shape = FlashAttentionBackend.get_kv_cache_shape( + num_blocks, self.block_size, self.num_kv_heads, self.head_size) + for _ in range(self.num_attn_layers): + self.kv_caches.append( + torch.zeros(kv_cache_shape, + dtype=self.kv_cache_dtype, + device=self.device)) + + def _get_padded_batch_size(self, batch_size: int) -> Optional[int]: + # TODO: Optimize this? 
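The selection that follows picks the smallest pre-captured batch size that fits the request count. A standalone sketch of the same lookup; the candidate list here is a made-up example, and self.cudagraph_batch_sizes does not appear to be set anywhere in this prototype runner.

CAPTURED_BATCH_SIZES = [1, 2, 4, 8, 16, 32, 64, 128]   # hypothetical example

def padded_batch_size(batch_size, sizes=CAPTURED_BATCH_SIZES):
    # Smallest captured size that fits, or None to signal an eager fallback.
    return next((size for size in sizes if batch_size <= size), None)

assert padded_batch_size(5) == 8
assert padded_batch_size(200) is None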
+ for size in self.cudagraph_batch_sizes: + if batch_size <= size: + return size + return None + + +@dataclass +class CachedRequestState: + + req_id: str + prompt_token_ids: List[int] + prompt: Optional[str] + multi_modal_data: Optional["MultiModalDataDict"] + sampling_params: SamplingParams + generator: Optional[torch.Generator] + + block_ids: List[int] + num_computed_tokens: int + output_token_ids: List[int] + + @property + def num_tokens(self) -> int: + return len(self.prompt_token_ids) + len(self.output_token_ids) + + +class InputBatch: + + def __init__( + self, + max_num_reqs: int, + max_model_len: int, + max_num_blocks_per_req: int, + device: torch.device, + pin_memory: bool, + ): + self.max_num_reqs = max_num_reqs + self.max_model_len = max_model_len + self.max_num_blocks_per_req = max_num_blocks_per_req + self.device = device + self.pin_memory = pin_memory + + self.req_ids: List[Optional[str]] = [None] * max_num_reqs + self.req_id_to_index: Dict[str, int] = {} + + self.token_ids_cpu = np.empty((max_num_reqs, max_model_len), + dtype=np.int32) + self.num_computed_tokens_cpu = np.empty(max_num_reqs, dtype=np.int32) + + # Attention-related. + self.block_table = torch.zeros((max_num_reqs, max_num_blocks_per_req), + device=self.device, + dtype=torch.int32) + self.block_table_cpu_tensor = torch.zeros( + (max_num_reqs, max_num_blocks_per_req), + device="cpu", + dtype=torch.int32, + pin_memory=pin_memory, + ) + self.block_table_cpu = self.block_table_cpu_tensor.numpy() + + # Sampling-related. + self.temperature = torch.empty((max_num_reqs, ), + dtype=torch.float32, + device=device) + self.temperature_cpu_tensor = torch.empty((max_num_reqs, ), + dtype=torch.float32, + device="cpu", + pin_memory=pin_memory) + self.temperature_cpu = self.temperature_cpu_tensor.numpy() + self.greedy_reqs: Set[str] = set() + self.random_reqs: Set[str] = set() + + self.top_p = torch.empty((max_num_reqs, ), + dtype=torch.float32, + device=device) + self.top_p_cpu_tensor = torch.empty((max_num_reqs, ), + dtype=torch.float32, + device="cpu", + pin_memory=pin_memory) + self.top_p_cpu = self.top_p_cpu_tensor.numpy() + self.top_p_reqs: Set[str] = set() + + self.top_k = torch.empty((max_num_reqs, ), + dtype=torch.int32, + device=device) + self.top_k_cpu_tensor = torch.empty((max_num_reqs, ), + dtype=torch.int32, + device="cpu", + pin_memory=pin_memory) + self.top_k_cpu = self.top_k_cpu_tensor.numpy() + self.top_k_reqs: Set[str] = set() + + # req_index -> generator + self.generators: Dict[int, torch.Generator] = {} + + self.num_logprobs: Dict[str, int] = {} + self.prompt_logprob_reqs: Set[str] = set() + + def add_request( + self, + request: "CachedRequestState", + req_index: Optional[int] = None, + ) -> None: + if req_index is None: + req_index = self.num_reqs + assert req_index < self.max_num_reqs + + req_id = request.req_id + self.req_ids[req_index] = req_id + self.req_id_to_index[req_id] = req_index + + # Copy the prompt token ids and output token ids. 
+ num_prompt_tokens = len(request.prompt_token_ids) + self.token_ids_cpu[ + req_index, :num_prompt_tokens] = request.prompt_token_ids + start_idx = num_prompt_tokens + end_idx = start_idx + len(request.output_token_ids) + self.token_ids_cpu[req_index, + start_idx:end_idx] = request.output_token_ids + + self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens + num_blocks = len(request.block_ids) + self.block_table_cpu[req_index, :num_blocks] = request.block_ids + + sampling_params = request.sampling_params + self.temperature_cpu[req_index] = sampling_params.temperature + if sampling_params.sampling_type == SamplingType.GREEDY: + self.greedy_reqs.add(req_id) + else: + self.random_reqs.add(req_id) + + self.top_p_cpu[req_index] = sampling_params.top_p + if sampling_params.top_p < 1: + self.top_p_reqs.add(req_id) + self.top_k_cpu[req_index] = sampling_params.top_k + if sampling_params.top_k > 0: + self.top_k_reqs.add(req_id) + + self.generators[req_index] = request.generator + + num_logprobs = sampling_params.logprobs + if num_logprobs is not None and num_logprobs > 0: + self.num_logprobs[req_id] = num_logprobs + if sampling_params.prompt_logprobs: + self.prompt_logprob_reqs.add(req_id) + + def remove_request(self, req_id: str) -> Optional[int]: + req_index = self.req_id_to_index.pop(req_id, None) + if req_index is None: + return None + self.req_ids[req_index] = None + + self.greedy_reqs.discard(req_id) + self.random_reqs.discard(req_id) + self.top_p_reqs.discard(req_id) + self.top_k_reqs.discard(req_id) + self.generators.pop(req_index, None) + self.num_logprobs.pop(req_id, None) + self.prompt_logprob_reqs.discard(req_id) + return req_index + + def clear(self) -> None: + self.req_ids = [None] * self.max_num_reqs + self.req_id_to_index.clear() + self.greedy_reqs.clear() + self.random_reqs.clear() + self.top_p_reqs.clear() + self.top_k_reqs.clear() + self.generators.clear() + self.num_logprobs.clear() + self.prompt_logprob_reqs.clear() + + def condense(self, empty_req_indices: List[int]) -> None: + if self.num_reqs == 0: + # The batched states are empty. + return + + # NOTE(woosuk): This function assumes that the empty_req_indices + # is sorted in descending order. + last_req_index = self.num_reqs + len(empty_req_indices) - 1 + while empty_req_indices: + # Find the largest non-empty index. + while last_req_index in empty_req_indices: + last_req_index -= 1 + + # Find the smallest empty index. + empty_index = empty_req_indices.pop() + if empty_index >= last_req_index: + break + + # Swap the states. + req_id = self.req_ids[last_req_index] + self.req_ids[empty_index] = req_id + self.req_ids[last_req_index] = None + self.req_id_to_index[req_id] = empty_index + + # TODO(woosuk): Optimize the copy of token_ids_cpu and + # block_table_cpu. + self.token_ids_cpu[empty_index] = self.token_ids_cpu[ + last_req_index] + self.num_computed_tokens_cpu[ + empty_index] = self.num_computed_tokens_cpu[last_req_index] + self.block_table_cpu[empty_index] = self.block_table_cpu[ + last_req_index] + self.temperature_cpu[empty_index] = self.temperature_cpu[ + last_req_index] + self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index] + self.top_k_cpu[empty_index] = self.top_k_cpu[last_req_index] + generator = self.generators.pop(last_req_index, None) + if generator is not None: + self.generators[empty_index] = generator + + # Decrement last_req_index since it is now empty. 
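The swap loop above is easier to follow on a plain list. Below is a toy re-implementation of the same compaction, illustrative only; the real version also has to move the per-request numpy and torch rows, not just the ids.

def condense(slots, empty_indices):
    # slots: request ids with None marking freed rows.
    # empty_indices must be sorted in descending order, as in the caller.
    num_live = sum(slot is not None for slot in slots)
    if num_live == 0:
        return slots
    last = num_live + len(empty_indices) - 1
    while empty_indices:
        while slots[last] is None:       # find the largest non-empty index
            last -= 1
        empty = empty_indices.pop()      # smallest empty index
        if empty >= last:
            break
        slots[empty], slots[last] = slots[last], None
        last -= 1
    return slots

# condense(["a", None, "b", None, "c"], [3, 1])
# -> ["a", "c", "b", None, None]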
+ last_req_index -= 1 + + def make_sampling_metadata( + self, + skip_copy: bool = False, + ) -> SamplingMetadata: + if not skip_copy: + self.temperature[:self.num_reqs].copy_( + self.temperature_cpu_tensor[:self.num_reqs], non_blocking=True) + self.top_p[:self.num_reqs].copy_( + self.top_p_cpu_tensor[:self.num_reqs], non_blocking=True) + self.top_k[:self.num_reqs].copy_( + self.top_k_cpu_tensor[:self.num_reqs], non_blocking=True) + return SamplingMetadata( + temperature=self.temperature[:self.num_reqs], + all_greedy=self.all_greedy, + all_random=self.all_random, + top_p=self.top_p[:self.num_reqs], + top_k=self.top_k[:self.num_reqs], + no_top_p=self.no_top_p, + no_top_k=self.no_top_k, + generators=self.generators, + max_num_logprobs=self.max_num_logprobs, + ) + + @property + def num_reqs(self) -> int: + return len(self.req_id_to_index) + + @property + def all_greedy(self) -> bool: + return len(self.random_reqs) == 0 + + @property + def all_random(self) -> bool: + return len(self.greedy_reqs) == 0 + + @property + def no_top_p(self) -> bool: + return len(self.top_p_reqs) == 0 + + @property + def no_top_k(self) -> bool: + return len(self.top_k_reqs) == 0 + + @property + def max_num_logprobs(self) -> int: + return max(self.num_logprobs.values()) if self.num_logprobs else 0 + + @property + def no_logprob(self) -> bool: + return len(self.num_logprobs) == 0 + + @property + def no_prompt_logprob(self) -> bool: + return len(self.prompt_logprob_reqs) == 0 + +class ModelWrapper(TorchCompileWrapperWithCustomDispatcher): + + def __init__(self, model: nn.Module): + self.model = model + compiled_callable = torch.compile(self.forward, + backend="openxla", + fullgraph=True, + dynamic=False) + super().__init__(compiled_callable) + + def __call__(self, *args, is_prompt: bool, **kwargs): + if len(self.compiled_codes) < 3 or not self.use_custom_dispatcher: + # not fully compiled yet, or not using the custom dispatcher, + # let PyTorch handle it + return self.compiled_callable(*args, **kwargs) + # the 3 compiled codes are: + # 0: for profiling + # 1: for prompt + # 2: for decode + # dispatch to the compiled code directly, skip PyTorch + if is_prompt: + with self.dispatch_to_code(1): + return self.forward(*args, **kwargs) + else: + with self.dispatch_to_code(2): + return self.forward(*args, **kwargs) + + def forward( + self, + token_ids: torch.Tensor, + position_ids: torch.Tensor, + attn_metadata: FlashAttentionMetadata, + kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], + ) -> torch.Tensor: + """Executes the forward pass of the model and samples the next token. + + Args: + token_ids: The input token IDs of shape [batch_size, seq_len]. + position_ids: The input position IDs of shape [batch_size, seq_len]. + attn_metadata: The Pallas attention metadata. + kv_caches: The key and value caches. They can be None during the + memory profiling at initialization. + """ + + # Skip this in memory profiling at initialization. + if kv_caches[0][0].numel() > 0: + # index_copy_(slot_mapping) only works when the inserted dimension + # is 0. However, the KV cache in the Pallas backend has the shape + # [num_kv_heads, num_blocks, block_size, head_size]. To make it + # work, we need to flatten the first three dimensions and modify + # the slot_mapping accordingly. 
+ num_kv_heads, num_blocks, block_size, _ = kv_caches[0][0].shape + slot_mapping = attn_metadata.slot_mapping + slot_mapping = slot_mapping.flatten() + head_indicies = torch.arange(0, + num_kv_heads, + device=slot_mapping.device, + dtype=slot_mapping.dtype) + head_indicies *= block_size * num_blocks + slot_mapping = slot_mapping.repeat_interleave(num_kv_heads).view( + -1, num_kv_heads) + slot_mapping = slot_mapping + head_indicies.view(1, -1) + slot_mapping = slot_mapping.flatten() + attn_metadata.slot_mapping = slot_mapping + + hidden_states = self.model( + token_ids, + position_ids, + kv_caches, + attn_metadata, + ) + hidden_states = hidden_states.flatten(0, 1) + logits = self.model.compute_logits(hidden_states, None) + + # Argmax sampling. + argmax_token_ids = torch.argmax(logits, dim=-1, keepdim=True) + num_samples = 1 + argmax_token_ids = argmax_token_ids.repeat(1, num_samples) + + return argmax_token_ids + + +def _get_padded_prefill_len(x: int) -> int: + # NOTE(woosuk): The pallas FlashAttention kernel requires the sequence + # length to be a multiple of 16. We pad the prompt length to the nearest + # multiple of 16. This is also good for performance. + if x <= 16: + return 16 + return 1 << (x - 1).bit_length() + + +def _get_padded_batch_size(batch_size: int) -> int: + # The GMM Pallas kernel requires num_tokens * topk to be a multiple of 16. + # To meet this requirement in the simplest way, we set the minimal batch + # size to 8. + if batch_size <= 8: + return 8 + else: + return ((batch_size + 15) // 16) * 16 \ No newline at end of file diff --git a/vllm/v1/worker/tpu_worker.py b/vllm/v1/worker/tpu_worker.py new file mode 100644 index 0000000000000..411c180a81d4e --- /dev/null +++ b/vllm/v1/worker/tpu_worker.py @@ -0,0 +1,188 @@ +"""A TPU worker class.""" + +import os +from typing import TYPE_CHECKING, Tuple + +import torch +import torch_xla.core.xla_model as xm +import torch_xla.runtime as xr + +import vllm.envs as envs +from vllm.distributed import (ensure_model_parallel_initialized, + init_distributed_environment) +from vllm.config import CacheConfig, ModelConfig, ParallelConfig, VllmConfig +from vllm.model_executor import set_random_seed +from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size +from vllm.v1.outputs import ModelRunnerOutput +from vllm.v1.worker.tpu_model_runner import TPUModelRunner + +if TYPE_CHECKING: + from vllm.v1.core.scheduler import SchedulerOutput + +class TPUWorker: + def __init__( + self, + vllm_config: VllmConfig, + local_rank: int, + rank: int, + distributed_init_method: str + ): + self.vllm_config = vllm_config + self.model_config = vllm_config.model_config + self.cache_config = vllm_config.cache_config + self.lora_config = vllm_config.lora_config + self.load_config = vllm_config.load_config + self.parallel_config = vllm_config.parallel_config + self.scheduler_config = vllm_config.scheduler_config + self.device_config = vllm_config.device_config + self.speculative_config = vllm_config.speculative_config + self.prompt_adapter_config = vllm_config.prompt_adapter_config + self.observability_config = vllm_config.observability_config + + self.local_rank = local_rank + self.rank = rank + self.distributed_init_method = distributed_init_method + + def initialize(self): + os.environ["PJRT_DEVICE"] = "TPU" + torch.set_grad_enabled(False) + torch.set_default_dtype(self.model_config.dtype) + + # NOTE: This is just to initialize the TP group and broadcast + # the input objects on CPU. 
The all-reduce and all-gather ops on TPU + # are invoked by `xm.all_reduce` and `xm.all_gather` which use their + # own context. + init_distributed_environment( + world_size=self.parallel_config.world_size, + rank=self.rank, + local_rank=self.local_rank, + distributed_init_method=self.distributed_init_method, + backend="gloo", + ) + ensure_model_parallel_initialized( + self.parallel_config.tensor_parallel_size, + self.parallel_config.pipeline_parallel_size) + + # Device initialization should happen after initializing the distributed + # runtime. + self.device = xm.xla_device() + self.device_config.device = self.device + + # Init ModelRunner here, so that we have access to self.device. + self.model_runner = TPUModelRunner(self.vllm_config) + + # Set random seed. + set_random_seed(self.model_config.seed) + xm.set_rng_state(self.model_config.seed, self.device) + + # Increase the cache size limit, which is the maximum number of + # dynamo graphs that can be compiled. + # NOTE(woosuk): Usually, we compile 10-15 graphs for prefill and + # 30-40 graphs for decode. 128 is an arbitrary safe number. + torch._dynamo.config.cache_size_limit = 128 + # Use persistent cache to avoid XLA recompilation. + # NOTE(woosuk): Set per-rank cache path since different ranks + # can have slightly different XLA graphs. + world_size = self.parallel_config.world_size + rank = xr.global_ordinal() + per_rank_path = os.path.join(envs.VLLM_XLA_CACHE_PATH, + f"tp{world_size}_rank{rank}") + xr.initialize_cache(per_rank_path, readonly=False) + + def load_model(self): + self.model_runner.load_model() + + + def determine_num_available_blocks(self) -> Tuple[int, int]: + """Profiles the peak memory usage of the model to determine how many + KV blocks may be allocated without OOMs. + + The engine will first conduct a profiling of the existing memory usage. + Then, it calculate the maximum possible number of GPU and CPU blocks + that can be allocated with the remaining free memory. + + .. tip:: + You may limit the usage of GPU memory + by adjusting the `gpu_memory_utilization` parameter. + """ + + self.model_runner.profile_run() + + # Synchronize before measuring the memory usage. + xm.wait_device_ops() + + # Get the maximum amount of memory used by the model weights and + # intermediate activations. + m = xm.get_memory_info(self.device) + total_tpu_memory = m["bytes_limit"] + peak_memory = m["peak_bytes_used"] # Weights + intermediate activations. + + cache_block_size = _get_cache_block_size(self.cache_config, + self.model_config, + self.parallel_config) + num_tpu_blocks = int( + (total_tpu_memory * self.cache_config.gpu_memory_utilization - + peak_memory) // cache_block_size) + num_tpu_blocks = (max(num_tpu_blocks, 0) // 8) * 8 + return num_tpu_blocks, 0 + + + def initialize_cache(self, num_tpu_blocks: int) -> None: + """Allocate TPU and CPU KV cache with the specified number of blocks.""" + + if num_tpu_blocks <= 0: + raise ValueError("No available memory for the cache blocks. " + "Try increasing `gpu_memory_utilization` when " + "initializing the engine.") + + max_seq_len = self.cache_config.block_size * num_tpu_blocks + max_model_len = self.model_config.max_model_len + if max_model_len > max_seq_len: + raise ValueError( + f"The model's max seq len ({max_model_len}) " + "is larger than the maximum number of tokens that can be " + f"stored in KV cache ({max_seq_len}). 
Try increasing " + "`gpu_memory_utilization` or decreasing `max_model_len` when " + "initializing the engine.") + + self.model_runner.initialize_kv_cache(num_tpu_blocks) + + + def compile_or_warm_up_model(self) -> None: + if not self.model_config.enforce_eager: + self.model_runner.capture_model() + + # Reset the seed to ensure that the random state is not affected by + # the model initialization and profiling. + set_random_seed(self.model_config.seed) + + + def execute_model( + self, + scheduler_output: "SchedulerOutput", + ) -> ModelRunnerOutput: + output = self.model_runner.execute_model(scheduler_output) + # TODO(woosuk): Send the output to the engine process. + return output + + +# TODO: this is a duplicate. +def _get_cache_block_size( + cache_config: CacheConfig, + model_config: ModelConfig, + parallel_config: ParallelConfig, +) -> int: + head_size = model_config.get_head_size() + num_heads = model_config.get_num_kv_heads(parallel_config) + num_attention_layers = model_config.get_num_attention_layers( + parallel_config) + + key_cache_block = cache_config.block_size * num_heads * head_size + value_cache_block = key_cache_block + total = num_attention_layers * (key_cache_block + value_cache_block) + if cache_config.cache_dtype == "auto": + dtype = model_config.dtype + else: + dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype] + dtype_size = get_dtype_size(dtype) + return dtype_size * total \ No newline at end of file From 1142c89a9ce4240d7bc1e640f96447516cd05bc1 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Tue, 12 Nov 2024 01:48:42 +0000 Subject: [PATCH 02/33] profile run complete --- vllm/v1/attention/backends/pallas.py | 19 +++------- vllm/v1/executor/tpu_executor.py | 2 +- vllm/v1/worker/tpu_model_runner.py | 53 +++++++++++----------------- 3 files changed, 25 insertions(+), 49 deletions(-) diff --git a/vllm/v1/attention/backends/pallas.py b/vllm/v1/attention/backends/pallas.py index a6cf7cbaa802b..a078e3bbf9333 100644 --- a/vllm/v1/attention/backends/pallas.py +++ b/vllm/v1/attention/backends/pallas.py @@ -42,20 +42,9 @@ def get_kv_cache_shape( @dataclass class PallasAttentionMetadata: - # NOTE(sang): Definition of context_len, query_len, and seq_len. - # |---------- N-1 iteration --------| - # |---------------- N iteration ---------------------| - # |- tokenA -|......................|-- newTokens ---| - # |---------- context_len ----------| - # |-------------------- seq_len ---------------------| - # |-- query_len ---| - - num_actual_tokens: int # Number of tokens excluding padding. - max_query_len: int - query_start_loc: torch.Tensor - max_seq_len: int - seq_start_loc: torch.Tensor - block_table: torch.Tensor + is_prompt: bool + block_tables: Optional[torch.Tensor] + context_lens: Optional[torch.Tensor] slot_mapping: torch.Tensor @@ -166,7 +155,7 @@ def forward( write_to_kv_cache(key, value, key_cache, value_cache, slot_mapping) query = query * self.scale - if attn_metadata.num_prefills > 0: + if attn_metadata.is_prompt: assert seq_len % 16 == 0, ( "Pallas FlashAttention kernel requires seq_len to be a " f"multiple of 16 but got {seq_len}") diff --git a/vllm/v1/executor/tpu_executor.py b/vllm/v1/executor/tpu_executor.py index 66fa2cfed91e8..122917a5e242d 100644 --- a/vllm/v1/executor/tpu_executor.py +++ b/vllm/v1/executor/tpu_executor.py @@ -71,6 +71,6 @@ def execute_model( return output def check_health(self) -> None: - # GPUExecutor will always be healthy as long as + # TPUExecutor will always be healthy as long as # it's running. 
return diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 54bc579b561cc..5ec4dde118d25 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -11,9 +11,6 @@ import torch_xla.core.xla_model as xm import torch_xla.runtime as xr -from vllm import envs -from vllm.compilation.compile_context import set_compile_context -from vllm.compilation.config import CompilationConfig from vllm.compilation.levels import CompilationLevel from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher from vllm.config import VllmConfig @@ -21,12 +18,11 @@ from vllm.logger import init_logger from vllm.model_executor.model_loader import get_model from vllm.multimodal import MultiModalDataDict -from vllm.plugins import set_compilation_config from vllm.sampling_params import SamplingParams, SamplingType -from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, cdiv, +from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, cdiv, is_pin_memory_available) -from vllm.v1.attention.backends.flash_attn import (FlashAttentionBackend, - FlashAttentionMetadata) +from vllm.v1.attention.backends.pallas import (PallasAttentionBackend, + PallasAttentionMetadata) from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.sample.metadata import SamplingMetadata @@ -35,11 +31,6 @@ logger = init_logger(__name__) -# FIXME(woosuk): A temporary hack to support `n > 1`. -# This can significantly affect the performance if too large. -_MAX_NUM_SAMPLES = 128 - - class TPUModelRunner: def __init__( @@ -296,7 +287,7 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): query_start_loc = query_start_loc.to(self.device, non_blocking=True) seq_start_loc = seq_start_loc.to(self.device, non_blocking=True) slot_mapping = slot_mapping.to(self.device, non_blocking=True).long() - attn_metadata = FlashAttentionMetadata( + attn_metadata = PallasAttentionMetadata( num_actual_tokens=total_num_scheduled_tokens, max_query_len=max_num_scheduled_tokens, query_start_loc=query_start_loc, @@ -447,31 +438,27 @@ def _dummy_run( device=self.device)) for _ in range(self.num_attn_layers)] - num_total_tokens = batch_size * seq_len - assert num_total_tokens <= self.input_ids.numel() + seq_len = (seq_len + 15) // 16 * 16 - slot_mapping = torch.zeros((num_total_tokens, ), - dtype=torch.int64, - device=self.device) - query_start_loc = torch.ones((batch_size + 1, ), - dtype=torch.int32, - device=self.device) - seq_start_loc = torch.empty((batch_size + 1, ), - dtype=torch.int32, + input_ids = torch.zeros((batch_size, seq_len), + dtype=torch.int32, + device=self.device) + positions = torch.zeros((batch_size, seq_len), + dtype=torch.int32, + device=self.device) + slot_mapping = torch.zeros((batch_size, seq_len), + dtype=torch.int64, device=self.device) - attn_metadata = FlashAttentionMetadata( - num_actual_tokens=num_total_tokens, - max_query_len=num_total_tokens, - query_start_loc=query_start_loc, - max_seq_len=num_total_tokens, - seq_start_loc=seq_start_loc, - block_table=self.input_batch.block_table[:batch_size], + attn_metadata = PallasAttentionMetadata( + is_prompt=is_prompt, + block_tables=None, + context_lens=None, slot_mapping=slot_mapping, ) - self.model(self.input_ids, - self.positions, + self.model(input_ids, + positions, attn_metadata, kv_caches, is_prompt=is_prompt) @@ -828,7 +815,7 @@ def forward( self, token_ids: torch.Tensor, position_ids: torch.Tensor, - attn_metadata: FlashAttentionMetadata, + attn_metadata: PallasAttentionMetadata, kv_caches: 
List[Tuple[torch.Tensor, torch.Tensor]], ) -> torch.Tensor: """Executes the forward pass of the model and samples the next token. From 9cc4fbee720f96dac24aafee2da12e85693ef0ba Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Tue, 12 Nov 2024 02:23:24 +0000 Subject: [PATCH 03/33] actually dummy run --- vllm/attention/backends/pallas.py | 2 +- vllm/attention/selector.py | 6 ++--- vllm/v1/attention/backends/pallas.py | 4 +-- vllm/v1/worker/tpu_model_runner.py | 39 ++++++++++++---------------- 4 files changed, 22 insertions(+), 29 deletions(-) diff --git a/vllm/attention/backends/pallas.py b/vllm/attention/backends/pallas.py index 6fee81de14420..5c3e8dcd3e88a 100644 --- a/vllm/attention/backends/pallas.py +++ b/vllm/attention/backends/pallas.py @@ -147,7 +147,7 @@ def forward( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - kv_cache: Tuple[torch.Tensor, torch.Tensor], + kv_cache: torch.Tensor, attn_metadata: PallasMetadata, k_scale: float = 1.0, v_scale: float = 1.0, diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 5f3745522aec5..3e3c9bbb0a508 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -237,11 +237,9 @@ def which_attn_to_use(head_size: int, return _Backend.IPEX if current_platform.is_tpu(): - if selected_backend == _Backend.PALLAS_VLLM_V1: - return _Backend.PALLAS_VLLM_V1 - if selected_backend != _Backend.PALLAS: + if selected_backend != _Backend.PALLAS_VLLM_V1: logger.info("Cannot use %s backend on TPU.", selected_backend) - return _Backend.PALLAS + return _Backend.PALLAS_VLLM_V1 if current_platform.is_rocm(): # AMD GPUs. diff --git a/vllm/v1/attention/backends/pallas.py b/vllm/v1/attention/backends/pallas.py index a078e3bbf9333..a74aad0108190 100644 --- a/vllm/v1/attention/backends/pallas.py +++ b/vllm/v1/attention/backends/pallas.py @@ -43,9 +43,9 @@ def get_kv_cache_shape( class PallasAttentionMetadata: is_prompt: bool + slot_mapping: torch.Tensor block_tables: Optional[torch.Tensor] context_lens: Optional[torch.Tensor] - slot_mapping: torch.Tensor class PallasAttentionImpl(AttentionImpl): @@ -149,7 +149,7 @@ def forward( self.head_size) # Write to KV cache. - if kv_cache[0].numel() > 0: + if kv_cache.numel() > 0: slot_mapping = attn_metadata.slot_mapping key_cache, value_cache = kv_cache write_to_kv_cache(key, value, key_cache, value_cache, slot_mapping) diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 5ec4dde118d25..c4a41fd81ffe1 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -417,12 +417,7 @@ def load_model(self) -> None: xm.wait_device_ops() self.model = ModelWrapper(model) - def _dummy_run( - self, - batch_size: int, - seq_len: int, - is_prompt: bool - ) -> None: + def _dummy_run(self, batch_size: int, seq_len: int, is_prompt: bool): assert is_prompt # use an empty tensor instead of `None`` to force Dynamo to pass @@ -432,36 +427,36 @@ def _dummy_run( # it is important to create tensors inside the loop, rather than # multiplying the list, to avoid Dynamo from treating them as # tensor aliasing. 
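As a quick illustration of the slimmed-down metadata introduced in this patch, the sketch below builds the two variants the Pallas kernels distinguish. The dataclass here is a local stand-in that mirrors the fields added above, the tensors live on CPU, and their contents are placeholders rather than real slot mappings or block tables:

from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class PallasAttentionMetadata:  # local stand-in mirroring the patch
    is_prompt: bool
    slot_mapping: torch.Tensor
    block_tables: Optional[torch.Tensor]
    context_lens: Optional[torch.Tensor]


batch_size, seq_len, max_blocks = 2, 32, 8

# Prefill: one contiguous chunk of new tokens per sequence, so the kernel
# only needs the slot mapping; block tables and context lengths stay None.
prefill = PallasAttentionMetadata(
    is_prompt=True,
    slot_mapping=torch.zeros(batch_size, seq_len, dtype=torch.int64),
    block_tables=None,
    context_lens=None,
)

# Decode: one new token per sequence, so the kernel also needs the block
# table and the current context length of every sequence.
decode = PallasAttentionMetadata(
    is_prompt=False,
    slot_mapping=torch.zeros(batch_size, 1, dtype=torch.int64),
    block_tables=torch.zeros(batch_size, max_blocks, dtype=torch.int32),
    context_lens=torch.ones(batch_size, dtype=torch.int32),
)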
- kv_caches = [(torch.tensor([], dtype=torch.float32, - device=self.device), - torch.tensor([], dtype=torch.float32, - device=self.device)) - for _ in range(self.num_attn_layers)] + dummy_kv_caches = [ + torch.tensor([], dtype=torch.float32, device=self.device) + for _ in range(self.num_attn_layers) + ] seq_len = (seq_len + 15) // 16 * 16 input_ids = torch.zeros((batch_size, seq_len), dtype=torch.int32, device=self.device) - positions = torch.zeros((batch_size, seq_len), - dtype=torch.int32, - device=self.device) + position_ids = torch.zeros((batch_size, seq_len), + dtype=torch.int32, + device=self.device) slot_mapping = torch.zeros((batch_size, seq_len), dtype=torch.int64, device=self.device) attn_metadata = PallasAttentionMetadata( is_prompt=is_prompt, + slot_mapping=slot_mapping, block_tables=None, context_lens=None, - slot_mapping=slot_mapping, ) - self.model(input_ids, - positions, - attn_metadata, - kv_caches, - is_prompt=is_prompt) + torch._dynamo.mark_dynamic(input_ids, 1) + torch._dynamo.mark_dynamic(position_ids, 1) + torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 1) + + self.model(input_ids, position_ids, attn_metadata, + dummy_kv_caches, is_prompt=is_prompt) def profile_run(self) -> None: @@ -816,7 +811,7 @@ def forward( token_ids: torch.Tensor, position_ids: torch.Tensor, attn_metadata: PallasAttentionMetadata, - kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], + kv_caches: List[torch.Tensor], ) -> torch.Tensor: """Executes the forward pass of the model and samples the next token. @@ -829,7 +824,7 @@ def forward( """ # Skip this in memory profiling at initialization. - if kv_caches[0][0].numel() > 0: + if kv_caches[0].numel() > 0: # index_copy_(slot_mapping) only works when the inserted dimension # is 0. However, the KV cache in the Pallas backend has the shape # [num_kv_heads, num_blocks, block_size, head_size]. To make it From 61f7792cfec0c583fde9e909aeb1c8435ba76f0e Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Tue, 12 Nov 2024 03:13:38 +0000 Subject: [PATCH 04/33] stash --- vllm/v1/worker/tpu_model_runner.py | 209 ++++++++++++++++++----------- 1 file changed, 132 insertions(+), 77 deletions(-) diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index c4a41fd81ffe1..22509028104e3 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -29,6 +29,9 @@ if TYPE_CHECKING: from vllm.v1.core.scheduler import SchedulerOutput +MIN_BATCH_SIZE = 8 +BATCH_SIZE_MULTIPLE = 16 + logger = init_logger(__name__) class TPUModelRunner: @@ -406,6 +409,7 @@ def load_model(self) -> None: # determine the order of concatenating the output tensors. # As a workaround, we use the xm's rank assignment only when loading # the embedding weights. + # xm_tp_rank = xr.global_ordinal() # with patch( # "vllm.model_executor.layers.vocab_parallel_embedding." 
@@ -417,8 +421,80 @@ def load_model(self) -> None: xm.wait_device_ops() self.model = ModelWrapper(model) - def _dummy_run(self, batch_size: int, seq_len: int, is_prompt: bool): - assert is_prompt + def _dummy_run( + self, + batch_size: int, + seq_len: int, + kv_caches: List[torch.Tensor], + is_prompt: bool + ) -> None: + """Dummy warmup run for memory usage and graph compilation.""" + + input_ids = torch.zeros( + (batch_size, seq_len), + dtype=torch.int32, + device=self.device + ) + position_ids = torch.zeros( + (batch_size, seq_len), + dtype=torch.int32, + device=self.device + ) + slot_mapping = torch.zeros( + (batch_size, seq_len), + dtype=torch.int64, + device=self.device + ) + + if is_prompt: + block_tables = None + context_lens = None + else: + block_tables = torch.zeros( + (batch_size, self.max_num_blocks_per_seq), + dtype=torch.int32, + device=self.device, + ) + context_lens = torch.ones( + (batch_size, ), + dtype=torch.int32, + device=self.device, + ) + attn_metadata = PallasAttentionMetadata( + is_prompt=is_prompt, + slot_mapping=slot_mapping, + block_tables=block_tables, + context_lens=context_lens, + ) + + # NOTE: There are two stages of compilation: torch.compile and + # XLA compilation. Using `mark_dynamic` can reduce the torch.compile + # overhead by reusing the FX graph for different shapes. + # However, the XLA graph will still require static shapes and needs to + # be re-compiled for every different shapes. This overhead is inevitable + # in the first run, but can be skipped afterwards as we cache the XLA + # graphs in the disk (VLLM_XLA_CACHE_PATH). + if is_prompt: + torch._dynamo.mark_dynamic(input_ids, 1) + torch._dynamo.mark_dynamic(position_ids, 1) + torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 1) + else: + torch._dynamo.mark_dynamic(input_ids, 0) + torch._dynamo.mark_dynamic(position_ids, 0) + torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0) + torch._dynamo.mark_dynamic(attn_metadata.context_lens, 0) + torch._dynamo.mark_dynamic(attn_metadata.block_tables, 0) + + # Dummy run. + self.model(input_ids, + position_ids, + attn_metadata, + kv_caches, + is_prompt=is_prompt) + + + def profile_run(self) -> None: + """Profile to measure peak memory during forward pass.""" # use an empty tensor instead of `None`` to force Dynamo to pass # it by reference, rather by specializing on the value `None`. @@ -432,74 +508,61 @@ def _dummy_run(self, batch_size: int, seq_len: int, is_prompt: bool): for _ in range(self.num_attn_layers) ] - seq_len = (seq_len + 15) // 16 * 16 - - input_ids = torch.zeros((batch_size, seq_len), - dtype=torch.int32, - device=self.device) - position_ids = torch.zeros((batch_size, seq_len), - dtype=torch.int32, - device=self.device) - slot_mapping = torch.zeros((batch_size, seq_len), - dtype=torch.int64, - device=self.device) - - attn_metadata = PallasAttentionMetadata( - is_prompt=is_prompt, - slot_mapping=slot_mapping, - block_tables=None, - context_lens=None, - ) - - torch._dynamo.mark_dynamic(input_ids, 1) - torch._dynamo.mark_dynamic(position_ids, 1) - torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 1) - - self.model(input_ids, position_ids, attn_metadata, - dummy_kv_caches, is_prompt=is_prompt) - + # Round to multiple of 16. + seq_len = (self.max_num_tokens + 15) // 16 * 16 - def profile_run(self) -> None: - self._dummy_run(batch_size=1, - seq_len=self.max_num_tokens, - is_prompt=True) - xm.wait_device_ops() + # Run empty forward. 
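The `mark_dynamic` trick used in the warmup above can be shown on a toy function independent of vLLM. The sketch only demonstrates how Dynamo is told that a dimension may vary so the FX graph is reused across shapes; the backend choice is arbitrary for the demo, and the XLA-side recompilation discussed in the NOTE comment is a separate concern:

import torch


def toy_forward(x: torch.Tensor) -> torch.Tensor:
    return torch.nn.functional.relu(x) * 2.0


compiled = torch.compile(toy_forward, backend="eager")

for seq_len in (16, 32, 64):
    x = torch.zeros(1, seq_len)
    # Mark dim 1 (the sequence length) as dynamic so torch.compile does not
    # specialize on the value and re-trace the graph for every new length.
    torch._dynamo.mark_dynamic(x, 1)
    compiled(x)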
+ self._dummy_run( + batch_size=1, + seq_len=seq_len, + kv_caches=dummy_kv_caches, + is_prompt=True) def capture_model(self) -> None: - if not self.use_cuda_graph: - logger.warning( - "Skipping CUDA graph capture. Please set " - "VLLM_TORCH_COMPILE_LEVEL=%d to use CUDA graphs.", - CompilationLevel.PIECEWISE) - return + """Compile the model.""" + + logger.info("Compiling the model with different input shapes.") + + # Prefill shapes. + start = time.perf_counter() + for batch_size in [1]: + seq_len = 16 + while True: + self._dummy_run(batch_size, seq_len, self.kv_caches, is_prompt=True) + xm.wait_device_ops() + logger.info("batch_size: %d, seq_len: %d", batch_size, seq_len) + + if seq_len >= self.model_config.max_model_len: + break + num_tokens = batch_size * seq_len + if num_tokens >= self.scheduler_config.max_num_batched_tokens: + break + seq_len = seq_len * 2 + + end = time.perf_counter() + logger.info("Compilation for prefill done in %.2f s.", end - start) + + # Decode shapes. + start = time.time() + seq_len = 1 + batch_size = MIN_BATCH_SIZE # Must be in sync with _get_padded_batch_size() + while True: + self._dummy_run(batch_size, seq_len, self.kv_caches, is_prompt=False) + xm.wait_device_ops() + logger.info("batch_size: %d, seq_len: %d", batch_size, seq_len) + + if batch_size >= self.scheduler_config.max_num_seqs: + break + batch_size = batch_size + 16 if batch_size >= 16 else batch_size * 2 + + end = time.time() + logger.info("Compilation for decode done in %.2f s.", end - start) - start_time = time.perf_counter() - start_free_gpu_memory = torch.cuda.mem_get_info()[0] - - with set_forward_context(None): - # Trigger CUDA graph capture for specific shapes. - # Capture the large shapes first so that the smaller shapes - # can reuse the memory pool allocated for the large shapes. - for num_tokens in reversed(self.cudagraph_batch_sizes): - self.model( - self.input_ids[:num_tokens], - self.positions[:num_tokens], - kv_caches=self.kv_caches, - attn_metadata=None, - ) - - end_time = time.perf_counter() - end_free_gpu_memory = torch.cuda.mem_get_info()[0] - elapsed_time = end_time - start_time - cuda_graph_size = start_free_gpu_memory - end_free_gpu_memory - # This usually takes 5~20 seconds. - logger.info("Graph capturing finished in %.0f secs, took %.2f GiB", - elapsed_time, cuda_graph_size / (1 << 30)) def initialize_kv_cache(self, num_blocks: int) -> None: assert len(self.kv_caches) == 0 - kv_cache_shape = FlashAttentionBackend.get_kv_cache_shape( + kv_cache_shape = PallasAttentionBackend.get_kv_cache_shape( num_blocks, self.block_size, self.num_kv_heads, self.head_size) for _ in range(self.num_attn_layers): self.kv_caches.append( @@ -508,11 +571,13 @@ def initialize_kv_cache(self, num_blocks: int) -> None: device=self.device)) def _get_padded_batch_size(self, batch_size: int) -> Optional[int]: - # TODO: Optimize this? - for size in self.cudagraph_batch_sizes: - if batch_size <= size: - return size - return None + # The GMM Pallas kernel requires num_tokens * topk to be a multiple of 16. + # To meet this requirement in the simplest way, we set the minimal batch + # size to 8 (== MIN_BATCH_SIZE). + if batch_size <= MIN_BATCH_SIZE: + return MIN_BATCH_SIZE + else: + return ((batch_size + 15) // 16) * 16 @dataclass @@ -868,13 +933,3 @@ def _get_padded_prefill_len(x: int) -> int: if x <= 16: return 16 return 1 << (x - 1).bit_length() - - -def _get_padded_batch_size(batch_size: int) -> int: - # The GMM Pallas kernel requires num_tokens * topk to be a multiple of 16. 
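Because the padding rules above are easy to get wrong, here is a tiny worked example of both bucketing schemes. The helper names are local to this sketch, not the exact helpers in the runner:

MIN_BATCH_SIZE = 8


def padded_prefill_len(x: int) -> int:
    # Prefill lengths are bucketed to powers of two (minimum 16) so only a
    # handful of prefill graphs ever get compiled.
    if x <= 16:
        return 16
    return 1 << (x - 1).bit_length()


def padded_batch_size(batch_size: int) -> int:
    # Decode batches start at 8 and then grow in multiples of 16, matching
    # the requirement that num_tokens * topk be divisible by 16.
    if batch_size <= MIN_BATCH_SIZE:
        return MIN_BATCH_SIZE
    return ((batch_size + 15) // 16) * 16


assert [padded_prefill_len(n) for n in (1, 17, 100)] == [16, 32, 128]
assert [padded_batch_size(n) for n in (1, 9, 40)] == [8, 16, 48]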
- # To meet this requirement in the simplest way, we set the minimal batch - # size to 8. - if batch_size <= 8: - return 8 - else: - return ((batch_size + 15) // 16) * 16 \ No newline at end of file From b8c644401905c7ee34272abefbe256d22dfac5ef Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Sun, 3 Nov 2024 20:07:05 +0000 Subject: [PATCH 05/33] updated Signed-off-by: Robert Shaw --- .buildkite/run-tpu-test.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.buildkite/run-tpu-test.sh b/.buildkite/run-tpu-test.sh index 770dad6ffa3a1..3591804b4d208 100644 --- a/.buildkite/run-tpu-test.sh +++ b/.buildkite/run-tpu-test.sh @@ -14,4 +14,4 @@ remove_docker_container # For HF_TOKEN. source /etc/environment # Run a simple end-to-end example. -docker run --privileged --net host --shm-size=16G -it -e "HF_TOKEN=$HF_TOKEN" --name tpu-test vllm-tpu /bin/bash -c "python3 -m pip install git+https://github.com/thuml/depyf.git && python3 -m pip install pytest && python3 -m pip install lm_eval[api]==0.4.4 && pytest -v -s /workspace/vllm/tests/entrypoints/openai/test_accuracy.py && pytest -v -s /workspace/vllm/tests/tpu/test_custom_dispatcher.py && python3 /workspace/vllm/tests/tpu/test_compilation.py && python3 /workspace/vllm/examples/offline_inference_tpu.py" +docker run --privileged --net host --shm-size=16G -it -e HF_TOKEN=$HF_TOKEN --name tpu-test vllm-tpu /bin/bash -c "python3 -m pip install git+https://github.com/thuml/depyf.git && python3 -m pip install pytest && python3 -m pip install lm_eval[api] && pytest -v -s /workspace/vllm/tests/entrypoints/openai/test_accuracy.py && pytest -v -s /workspace/vllm/tests/tpu/test_custom_dispatcher.py && python3 /workspace/vllm/tests/tpu/test_compilation.py && python3 /workspace/vllm/examples/offline_inference_tpu.py" From 75e2e5398e9477b4ae1d8d2a74df460760152ab8 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Tue, 12 Nov 2024 03:17:33 +0000 Subject: [PATCH 06/33] updated --- .buildkite/run-tpu-test.sh | 2 +- vllm/attention/backends/pallas.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.buildkite/run-tpu-test.sh b/.buildkite/run-tpu-test.sh index 3591804b4d208..770dad6ffa3a1 100644 --- a/.buildkite/run-tpu-test.sh +++ b/.buildkite/run-tpu-test.sh @@ -14,4 +14,4 @@ remove_docker_container # For HF_TOKEN. source /etc/environment # Run a simple end-to-end example. 
-docker run --privileged --net host --shm-size=16G -it -e HF_TOKEN=$HF_TOKEN --name tpu-test vllm-tpu /bin/bash -c "python3 -m pip install git+https://github.com/thuml/depyf.git && python3 -m pip install pytest && python3 -m pip install lm_eval[api] && pytest -v -s /workspace/vllm/tests/entrypoints/openai/test_accuracy.py && pytest -v -s /workspace/vllm/tests/tpu/test_custom_dispatcher.py && python3 /workspace/vllm/tests/tpu/test_compilation.py && python3 /workspace/vllm/examples/offline_inference_tpu.py" +docker run --privileged --net host --shm-size=16G -it -e "HF_TOKEN=$HF_TOKEN" --name tpu-test vllm-tpu /bin/bash -c "python3 -m pip install git+https://github.com/thuml/depyf.git && python3 -m pip install pytest && python3 -m pip install lm_eval[api]==0.4.4 && pytest -v -s /workspace/vllm/tests/entrypoints/openai/test_accuracy.py && pytest -v -s /workspace/vllm/tests/tpu/test_custom_dispatcher.py && python3 /workspace/vllm/tests/tpu/test_compilation.py && python3 /workspace/vllm/examples/offline_inference_tpu.py" diff --git a/vllm/attention/backends/pallas.py b/vllm/attention/backends/pallas.py index 5c3e8dcd3e88a..6fee81de14420 100644 --- a/vllm/attention/backends/pallas.py +++ b/vllm/attention/backends/pallas.py @@ -147,7 +147,7 @@ def forward( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - kv_cache: torch.Tensor, + kv_cache: Tuple[torch.Tensor, torch.Tensor], attn_metadata: PallasMetadata, k_scale: float = 1.0, v_scale: float = 1.0, From bebabfc263382c9362ad150d22dbdcf152c4d602 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Tue, 12 Nov 2024 03:18:18 +0000 Subject: [PATCH 07/33] more cleaning --- vllm/attention/selector.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 3e3c9bbb0a508..478aebda68fd3 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -257,8 +257,7 @@ def which_attn_to_use(head_size: int, return _Backend.HPU_ATTN if use_v1: - # return _Backend.FLASH_ATTN_VLLM_V1 - return _Backend.PALLAS_VLLM_V1 + return _Backend.FLASH_ATTN_VLLM_V1 # FlashAttn in NVIDIA GPUs. 
if selected_backend == _Backend.FLASH_ATTN: From 338e11c2fbeef5c92d28792d98ead54f77bf4fce Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Tue, 12 Nov 2024 03:20:12 +0000 Subject: [PATCH 08/33] cleanup llmengine --- vllm/engine/llm_engine.py | 2131 ++----------------------------------- 1 file changed, 104 insertions(+), 2027 deletions(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 69ed6e6bd59d2..1c75b66086acf 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -1,571 +1,70 @@ -import time -from collections import Counter as collectionsCounter -from collections import deque -from contextlib import contextmanager -from dataclasses import dataclass -from functools import partial -from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Deque, Dict, - Iterable, List, Mapping, NamedTuple, Optional) -from typing import Sequence as GenericSequence -from typing import Set, Type, Union, cast, overload +from typing import Dict, List, Mapping, Optional, Type, Union -import torch -from typing_extensions import TypeVar - -import vllm.envs as envs -from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig, - ObservabilityConfig, ParallelConfig, SchedulerConfig, - VllmConfig) -from vllm.core.scheduler import (ScheduledSequenceGroup, Scheduler, - SchedulerOutputs) +from vllm.config import VllmConfig from vllm.engine.arg_utils import EngineArgs -from vllm.engine.metrics_types import StatLoggerBase, Stats -from vllm.engine.output_processor.interfaces import ( - SequenceGroupOutputProcessor) -from vllm.engine.output_processor.stop_checker import StopChecker -from vllm.engine.output_processor.util import create_output_by_sequence_group -from vllm.entrypoints.openai.logits_processors import ( - get_logits_processors as get_openai_logits_processors) -from vllm.executor.executor_base import ExecutorBase -from vllm.executor.gpu_executor import GPUExecutor -from vllm.executor.ray_utils import initialize_ray_cluster -from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs, - PromptType) -from vllm.inputs.parse import is_encoder_decoder_inputs, is_token_prompt -from vllm.inputs.preprocess import InputPreprocessor +from vllm.engine.metrics_types import StatLoggerBase +from vllm.envs import VLLM_ENABLE_V1_MULTIPROCESSING +from vllm.inputs import INPUT_REGISTRY, InputRegistry, PromptType from vllm.logger import init_logger -from vllm.logits_process import get_bad_words_logits_processors from vllm.lora.request import LoRARequest -from vllm.model_executor.guided_decoding import ( - get_local_guided_decoding_logits_processor) -from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.outputs import (EmbeddingRequestOutput, RequestOutput, - RequestOutputFactory) +from vllm.outputs import RequestOutput from vllm.pooling_params import PoolingParams +from vllm.platforms import current_platform from vllm.prompt_adapter.request import PromptAdapterRequest -from vllm.sampling_params import RequestOutputKind, SamplingParams -from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest, - ParallelSampleSequenceGroup, Sequence, - SequenceGroup, SequenceGroupBase, - SequenceGroupMetadata, SequenceGroupOutput, - SequenceStatus) -from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context, - init_tracer) -from vllm.transformers_utils.config import try_get_generation_config -from vllm.transformers_utils.detokenizer import Detokenizer -from vllm.transformers_utils.tokenizer import AnyTokenizer -from 
vllm.transformers_utils.tokenizer_group import ( - BaseTokenizerGroup, init_tokenizer_from_configs) -from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled, - usage_message) -from vllm.utils import Counter, Device, deprecate_kwargs, weak_bind -from vllm.version import __version__ as VLLM_VERSION +from vllm.sampling_params import SamplingParams +from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs +from vllm.usage.usage_lib import UsageContext +from vllm.v1.engine.core_client import EngineCoreClient +from vllm.v1.engine.detokenizer import Detokenizer +from vllm.v1.engine.processor import Processor +from vllm.v1.executor.gpu_executor import GPUExecutor +from vllm.v1.executor.tpu_executor import TPUExecutor logger = init_logger(__name__) -_LOCAL_LOGGING_INTERVAL_SEC = 5 - - -def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]: - config = try_get_generation_config( - model_config.model, - trust_remote_code=model_config.trust_remote_code, - revision=model_config.revision, - ) - - if config is None: - return {} - - return config.to_diff_dict() - - -_G = TypeVar("_G", bound=BaseTokenizerGroup, default=BaseTokenizerGroup) -_O = TypeVar("_O", RequestOutput, EmbeddingRequestOutput) - - -@dataclass -class SchedulerOutputState: - """Caches the scheduler outputs for a virtual engine. Used for Multi-Step""" - seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None - scheduler_outputs: Optional[SchedulerOutputs] = None - allow_async_output_proc: bool = False - last_output: Optional[SamplerOutput] = None - - -class OutputData(NamedTuple): - outputs: List[SamplerOutput] - seq_group_metadata_list: List[SequenceGroupMetadata] - scheduler_outputs: SchedulerOutputs - is_async: bool - is_last_step: bool - # Indicates if this output is from the first step of the - # multi-step. When multi-step is disabled, this is always - # set to True. - # is_first_step_output is invalid when `outputs` has - # outputs from multiple steps. - is_first_step_output: Optional[bool] - skip: List[int] - - -class SchedulerContext: - - def __init__(self, multi_step_stream_outputs: bool = False): - self.output_queue: Deque[OutputData] = deque() - self.request_outputs: List[Union[RequestOutput, - EmbeddingRequestOutput]] = [] - self.seq_group_metadata_list: Optional[ - List[SequenceGroupMetadata]] = None - self.scheduler_outputs: Optional[SchedulerOutputs] = None - - self.multi_step_stream_outputs: bool = multi_step_stream_outputs - - def append_output(self, outputs: List[SamplerOutput], - seq_group_metadata_list: List[SequenceGroupMetadata], - scheduler_outputs: SchedulerOutputs, is_async: bool, - is_last_step: bool, - is_first_step_output: Optional[bool]): - self.output_queue.append( - OutputData(outputs=outputs, - seq_group_metadata_list=seq_group_metadata_list, - scheduler_outputs=scheduler_outputs, - is_async=is_async, - is_last_step=is_last_step, - is_first_step_output=is_first_step_output, - skip=[])) class LLMEngine: - """An LLM engine that receives requests and generates texts. - - This is the main class for the vLLM engine. It receives requests - from clients and generates texts from the LLM. It includes a tokenizer, a - language model (possibly distributed across multiple GPUs), and GPU memory - space allocated for intermediate states (aka KV cache). This class utilizes - iteration-level scheduling and efficient memory management to maximize the - serving throughput. 
- - The :class:`~vllm.LLM` class wraps this class for offline batched inference - and the :class:`AsyncLLMEngine` class wraps this class for online serving. - - The config arguments are derived from :class:`~vllm.EngineArgs`. (See - :ref:`engine_args`) - - Args: - model_config: The configuration related to the LLM model. - cache_config: The configuration related to the KV cache memory - management. - parallel_config: The configuration related to distributed execution. - scheduler_config: The configuration related to the request scheduler. - device_config: The configuration related to the device. - lora_config (Optional): The configuration related to serving multi-LoRA. - speculative_config (Optional): The configuration related to speculative - decoding. - executor_class: The model executor class for managing distributed - execution. - prompt_adapter_config (Optional): The configuration related to serving - prompt adapters. - log_stats: Whether to log statistics. - usage_context: Specified entry point, used for usage info collection. - """ - - DO_VALIDATE_OUTPUT: ClassVar[bool] = False - """A flag to toggle whether to validate the type of request output.""" - - @classmethod - @contextmanager - def enable_output_validation(cls): - cls.DO_VALIDATE_OUTPUT = True - - yield - - cls.DO_VALIDATE_OUTPUT = False - - @classmethod - def validate_output( - cls, - output: object, - output_type: Type[_O], - ) -> _O: - do_validate = cls.DO_VALIDATE_OUTPUT - - if ((TYPE_CHECKING or do_validate) - and not isinstance(output, output_type)): - raise TypeError(f"Expected output of type {output_type}, " - f"but found type {type(output)}") - - return cast(_O, output) - - @classmethod - def validate_outputs( - cls, - outputs: GenericSequence[object], - output_type: Type[_O], - ) -> List[_O]: - do_validate = cls.DO_VALIDATE_OUTPUT - - outputs_: List[_O] - if TYPE_CHECKING or do_validate: - outputs_ = [] - for output in outputs: - if not isinstance(output, output_type): - raise TypeError(f"Expected output of type {output_type}, " - f"but found type {type(output)}") - - outputs_.append(output) - else: - outputs_ = outputs - - return outputs_ - - tokenizer: Optional[BaseTokenizerGroup] + """Legacy LLMEngine for backwards compatibility.""" def __init__( self, vllm_config: VllmConfig, - executor_class: Type[ExecutorBase], + executor_class: Type[GPUExecutor], log_stats: bool, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, stat_loggers: Optional[Dict[str, StatLoggerBase]] = None, input_registry: InputRegistry = INPUT_REGISTRY, use_cached_outputs: bool = False, + multiprocess_mode: bool = False, ) -> None: - # TODO: remove the local variables and use self.* throughout the class. - model_config = self.model_config = vllm_config.model_config - cache_config = self.cache_config = vllm_config.cache_config - lora_config = self.lora_config = vllm_config.lora_config - parallel_config = self.parallel_config = vllm_config.parallel_config - scheduler_config = self.scheduler_config = vllm_config.scheduler_config - device_config = self.device_config = vllm_config.device_config - speculative_config = self.speculative_config = vllm_config.speculative_config # noqa - load_config = self.load_config = vllm_config.load_config - decoding_config = self.decoding_config = vllm_config.decoding_config or DecodingConfig( # noqa + # TODO: Can we avoid this? + self.model_config = vllm_config.model_config + + # Tokenizer (+ ensure liveness if running in another process). 
+ self.tokenizer = init_tokenizer_from_configs( + model_config=vllm_config.model_config, + scheduler_config=vllm_config.scheduler_config, + parallel_config=vllm_config.parallel_config, + enable_lora=bool(vllm_config.lora_config)) + self.tokenizer.ping() + + # Processor (convert Inputs --> EngineCoreRequests) + self.processor = Processor(vllm_config.model_config, + vllm_config.lora_config, self.tokenizer, + input_registry) + + # Detokenizer (converts EngineCoreOutputs --> RequestOutput) + self.detokenizer = Detokenizer(vllm_config.model_config.tokenizer) + + # EngineCore (gets EngineCoreRequests and gives EngineCoreOutputs) + self.engine_core = EngineCoreClient.make_client( + vllm_config, + executor_class, + usage_context, + multiprocess_mode=multiprocess_mode, + asyncio_mode=False, ) - prompt_adapter_config = self.prompt_adapter_config = vllm_config.prompt_adapter_config # noqa - observability_config = self.observability_config = vllm_config.observability_config or ObservabilityConfig( # noqa - ) - - logger.info( - "Initializing an LLM engine (v%s) with config: " - "model=%r, speculative_config=%r, tokenizer=%r, " - "skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, " - "override_neuron_config=%s, tokenizer_revision=%s, " - "trust_remote_code=%s, dtype=%s, max_seq_len=%d, " - "download_dir=%r, load_format=%s, tensor_parallel_size=%d, " - "pipeline_parallel_size=%d, " - "disable_custom_all_reduce=%s, quantization=%s, " - "enforce_eager=%s, kv_cache_dtype=%s, " - "quantization_param_path=%s, device_config=%s, " - "decoding_config=%r, observability_config=%r, " - "seed=%d, served_model_name=%s, " - "num_scheduler_steps=%d, chunked_prefill_enabled=%s " - "multi_step_stream_outputs=%s, enable_prefix_caching=%s, " - "use_async_output_proc=%s, use_cached_outputs=%s, " - "chat_template_text_format=%s, mm_processor_kwargs=%s, " - "pooler_config=%r)", - VLLM_VERSION, - model_config.model, - speculative_config, - model_config.tokenizer, - model_config.skip_tokenizer_init, - model_config.tokenizer_mode, - model_config.revision, - model_config.override_neuron_config, - model_config.tokenizer_revision, - model_config.trust_remote_code, - model_config.dtype, - model_config.max_model_len, - load_config.download_dir, - load_config.load_format, - parallel_config.tensor_parallel_size, - parallel_config.pipeline_parallel_size, - parallel_config.disable_custom_all_reduce, - model_config.quantization, - model_config.enforce_eager, - cache_config.cache_dtype, - model_config.quantization_param_path, - device_config.device, - decoding_config, - observability_config, - model_config.seed, - model_config.served_model_name, - scheduler_config.num_scheduler_steps, - scheduler_config.chunked_prefill_enabled, - scheduler_config.multi_step_stream_outputs, - cache_config.enable_prefix_caching, - model_config.use_async_output_proc, - use_cached_outputs, - model_config.chat_template_text_format, - model_config.mm_processor_kwargs, - model_config.pooler_config, - ) - # TODO(woosuk): Print more configs in debug mode. 
- self.model_config = model_config - self.cache_config = cache_config - self.lora_config = lora_config - self.parallel_config = parallel_config - self.scheduler_config = scheduler_config - self.device_config = device_config - self.speculative_config = speculative_config - self.load_config = load_config - self.decoding_config = decoding_config or DecodingConfig() - self.prompt_adapter_config = prompt_adapter_config - self.observability_config = observability_config or ObservabilityConfig( - ) - self.log_stats = log_stats - self.use_cached_outputs = use_cached_outputs - - if not self.model_config.skip_tokenizer_init: - self.tokenizer = self._init_tokenizer() - self.detokenizer = Detokenizer(self.tokenizer) - tokenizer_group = self.get_tokenizer_group() - else: - self.tokenizer = None - self.detokenizer = None - tokenizer_group = None - - # Ensure that the function doesn't contain a reference to self, - # to avoid engine GC issues - def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer: - assert tokenizer_group, ("tokenizer_group cannot be None, " - "make sure skip_tokenizer_init is False") - return tokenizer_group.get_lora_tokenizer(sequence.lora_request) - - self.seq_counter = Counter() - self.generation_config_fields = _load_generation_config_dict( - model_config) - - self.input_preprocessor = InputPreprocessor(model_config, - self.tokenizer) - - self.input_registry = input_registry - self.input_processor = input_registry.create_input_processor( - model_config) - - self.model_executor = executor_class(vllm_config=vllm_config, ) - - if self.model_config.task != "embedding": - self._initialize_kv_caches() - - # If usage stat is enabled, collect relevant info. - if is_usage_stats_enabled(): - from vllm.model_executor.model_loader import ( - get_architecture_class_name) - usage_message.report_usage( - get_architecture_class_name(model_config), - usage_context, - extra_kvs={ - # Common configuration - "dtype": - str(model_config.dtype), - "tensor_parallel_size": - parallel_config.tensor_parallel_size, - "block_size": - cache_config.block_size, - "gpu_memory_utilization": - cache_config.gpu_memory_utilization, - - # Quantization - "quantization": - model_config.quantization, - "kv_cache_dtype": - str(cache_config.cache_dtype), - - # Feature flags - "enable_lora": - bool(lora_config), - "enable_prompt_adapter": - bool(prompt_adapter_config), - "enable_prefix_caching": - cache_config.enable_prefix_caching, - "enforce_eager": - model_config.enforce_eager, - "disable_custom_all_reduce": - parallel_config.disable_custom_all_reduce, - }) - - if self.tokenizer: - # Ping the tokenizer to ensure liveness if it runs in a - # different process. - self.tokenizer.ping() - - self.cached_scheduler_outputs = [ - SchedulerOutputState() - for _ in range(self.parallel_config.pipeline_parallel_size) - ] - - self.scheduler_contexts = [ - SchedulerContext(multi_step_stream_outputs=self.scheduler_config. - multi_step_stream_outputs) - for _ in range(self.parallel_config.pipeline_parallel_size) - ] - - if model_config.use_async_output_proc: - process_model_outputs = weak_bind(self._process_model_outputs) - - self.async_callbacks = [ - partial(process_model_outputs, - ctx=self.scheduler_contexts[v_id]) - for v_id in range(self.parallel_config.pipeline_parallel_size) - ] - else: - self.async_callbacks = [] - - # Currently used by AsyncLLMEngine to ensure quick append - # of request outputs to asyncio queues - self.process_request_outputs_callback: Optional[Callable] = None - - # Create the scheduler. 
- # NOTE: the cache_config here have been updated with the numbers of - # GPU and CPU blocks, which are profiled in the distributed executor. - self.scheduler = [ - Scheduler( - scheduler_config, cache_config, lora_config, - parallel_config.pipeline_parallel_size, - self.async_callbacks[v_id] - if model_config.use_async_output_proc else None) - for v_id in range(parallel_config.pipeline_parallel_size) - ] - - # Metric Logging. - if self.log_stats: - if stat_loggers is not None: - self.stat_loggers = stat_loggers - else: - # Lazy import for prometheus multiprocessing. - # We need to set PROMETHEUS_MULTIPROC_DIR environment variable - # before prometheus_client is imported. - # See https://prometheus.github.io/client_python/multiprocess/ - from vllm.engine.metrics import (LoggingStatLogger, - PrometheusStatLogger) - - self.stat_loggers = { - "logging": - LoggingStatLogger( - local_interval=_LOCAL_LOGGING_INTERVAL_SEC), - "prometheus": - PrometheusStatLogger( - local_interval=_LOCAL_LOGGING_INTERVAL_SEC, - labels=dict(model_name=model_config.served_model_name), - max_model_len=self.model_config.max_model_len), - } - self.stat_loggers["prometheus"].info("cache_config", - self.cache_config) - - self.tracer = None - if self.observability_config.otlp_traces_endpoint: - self.tracer = init_tracer( - "vllm.llm_engine", - self.observability_config.otlp_traces_endpoint) - - # Create sequence output processor, e.g. for beam search or - # speculative decoding. - self.output_processor = ( - SequenceGroupOutputProcessor.create_output_processor( - self.scheduler_config, - self.detokenizer, - self.scheduler, - self.seq_counter, - get_tokenizer_for_seq, - stop_checker=StopChecker( - self.scheduler_config.max_model_len, - get_tokenizer_for_seq, - ), - )) - - self.seq_id_to_seq_group: Dict[str, SequenceGroupBase] = {} - - def _initialize_kv_caches(self) -> None: - """Initialize the KV cache in the worker(s). - - The workers will determine the number of blocks in both the GPU cache - and the swap CPU cache. - """ - num_gpu_blocks, num_cpu_blocks = ( - self.model_executor.determine_num_available_blocks()) - - if self.cache_config.num_gpu_blocks_override is not None: - num_gpu_blocks_override = self.cache_config.num_gpu_blocks_override - logger.info( - "Overriding num_gpu_blocks=%d with " - "num_gpu_blocks_override=%d", num_gpu_blocks, - num_gpu_blocks_override) - num_gpu_blocks = num_gpu_blocks_override - - self.cache_config.num_gpu_blocks = num_gpu_blocks - self.cache_config.num_cpu_blocks = num_cpu_blocks - - self.model_executor.initialize_cache(num_gpu_blocks, num_cpu_blocks) - - @classmethod - def _get_executor_cls(cls, - engine_config: VllmConfig) -> Type[ExecutorBase]: - distributed_executor_backend = ( - engine_config.parallel_config.distributed_executor_backend) - # Initialize the cluster and specify the executor class. - if isinstance(distributed_executor_backend, type): - if not issubclass(distributed_executor_backend, ExecutorBase): - raise TypeError( - "distributed_executor_backend must be a subclass of " - f"ExecutorBase. 
Got {distributed_executor_backend}.") - if distributed_executor_backend.uses_ray: # type: ignore - initialize_ray_cluster(engine_config.parallel_config) - executor_class = distributed_executor_backend - elif engine_config.device_config.device_type == "neuron": - from vllm.executor.neuron_executor import NeuronExecutor - executor_class = NeuronExecutor - elif engine_config.device_config.device_type == "tpu": - if distributed_executor_backend == "ray": - initialize_ray_cluster(engine_config.parallel_config) - from vllm.executor.ray_tpu_executor import RayTPUExecutor - executor_class = RayTPUExecutor - else: - assert distributed_executor_backend is None - from vllm.executor.tpu_executor import TPUExecutor - executor_class = TPUExecutor - elif engine_config.device_config.device_type == "cpu": - from vllm.executor.cpu_executor import CPUExecutor - executor_class = CPUExecutor - elif engine_config.device_config.device_type == "hpu": - if distributed_executor_backend == "ray": - initialize_ray_cluster(engine_config.parallel_config) - from vllm.executor.ray_hpu_executor import RayHPUExecutor - executor_class = RayHPUExecutor - else: - from vllm.executor.hpu_executor import HPUExecutor - executor_class = HPUExecutor - elif engine_config.device_config.device_type == "openvino": - from vllm.executor.openvino_executor import OpenVINOExecutor - executor_class = OpenVINOExecutor - elif engine_config.device_config.device_type == "xpu": - if distributed_executor_backend == "ray": - initialize_ray_cluster(engine_config.parallel_config) - from vllm.executor.ray_xpu_executor import RayXPUExecutor - executor_class = RayXPUExecutor - elif distributed_executor_backend == "mp": - # FIXME(kunshang): - # spawn needs calling `if __name__ == '__main__':`` - # fork is not supported for xpu start new process. - logger.error( - "Both start methods (spawn and fork) have issue " - "on XPU if you use mp backend, Please try ray instead.") - else: - from vllm.executor.xpu_executor import XPUExecutor - executor_class = XPUExecutor - elif distributed_executor_backend == "ray": - initialize_ray_cluster(engine_config.parallel_config) - from vllm.executor.ray_gpu_executor import RayGPUExecutor - executor_class = RayGPUExecutor - elif distributed_executor_backend == "mp": - from vllm.executor.multiproc_gpu_executor import ( - MultiprocessingGPUExecutor) - assert not envs.VLLM_USE_RAY_SPMD_WORKER, ( - "multiprocessing distributed executor backend does not " - "support VLLM_USE_RAY_SPMD_WORKER=1") - executor_class = MultiprocessingGPUExecutor - else: - from vllm.executor.gpu_executor import GPUExecutor - executor_class = GPUExecutor - return executor_class @classmethod def from_engine_args( @@ -573,176 +72,51 @@ def from_engine_args( engine_args: EngineArgs, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, stat_loggers: Optional[Dict[str, StatLoggerBase]] = None, + enable_multiprocessing: bool = False, ) -> "LLMEngine": """Creates an LLM engine from the engine arguments.""" - # Create the engine configs. - engine_config = engine_args.create_engine_config() - executor_class = cls._get_executor_cls(engine_config) - # Create the LLM engine. 
- engine = cls( - vllm_config=engine_config, - executor_class=executor_class, - log_stats=not engine_args.disable_log_stats, - usage_context=usage_context, - stat_loggers=stat_loggers, - ) - - return engine - - def __reduce__(self): - # This is to ensure that the LLMEngine is not referenced in - # the closure used to initialize Ray worker actors - raise RuntimeError("LLMEngine should not be pickled!") - def __del__(self): - # Shutdown model executor when engine is garbage collected - # Use getattr since __init__ can fail before the field is set - if model_executor := getattr(self, "model_executor", None): - model_executor.shutdown() - - def get_tokenizer_group( - self, - group_type: Type[_G] = BaseTokenizerGroup, - ) -> _G: - tokenizer_group = self.tokenizer - - if tokenizer_group is None: - raise ValueError("Unable to get tokenizer because " - "skip_tokenizer_init is True") - if not isinstance(tokenizer_group, group_type): - raise TypeError("Invalid type of tokenizer group. " - f"Expected type: {group_type}, but " - f"found type: {type(tokenizer_group)}") - - return tokenizer_group - - def get_tokenizer( - self, - lora_request: Optional[LoRARequest] = None, - ) -> AnyTokenizer: - return self.get_tokenizer_group().get_lora_tokenizer(lora_request) - - def _init_tokenizer(self) -> BaseTokenizerGroup: - return init_tokenizer_from_configs( - model_config=self.model_config, - scheduler_config=self.scheduler_config, - parallel_config=self.parallel_config, - enable_lora=bool(self.lora_config)) - - def _verify_args(self) -> None: - self.model_config.verify_with_parallel_config(self.parallel_config) - self.cache_config.verify_with_parallel_config(self.parallel_config) - if self.lora_config: - self.lora_config.verify_with_model_config(self.model_config) - self.lora_config.verify_with_scheduler_config( - self.scheduler_config) - if self.prompt_adapter_config: - self.prompt_adapter_config.verify_with_model_config( - self.model_config) - - def _add_processed_request( - self, - request_id: str, - processed_inputs: ProcessorInputs, - params: Union[SamplingParams, PoolingParams], - arrival_time: float, - lora_request: Optional[LoRARequest], - prompt_adapter_request: Optional[PromptAdapterRequest], - trace_headers: Optional[Mapping[str, str]] = None, - priority: int = 0, - ) -> Optional[SequenceGroup]: - """Add a processed request to the engine's request pool. - return the created sequence group. - """ - if isinstance(params, SamplingParams) and params.n > 1: - ParallelSampleSequenceGroup.add_request( - request_id, - self, - params, - processed_inputs=processed_inputs, - arrival_time=arrival_time, - lora_request=lora_request, - trace_headers=trace_headers, - prompt_adapter_request=prompt_adapter_request, - priority=priority, - ) - return None + # Create the engine configs. + vllm_config = engine_args.create_engine_config() + executor_class = cls._get_executor_cls(vllm_config) - self._validate_model_inputs(processed_inputs, lora_request) - # Create the sequences. - block_size = self.cache_config.block_size - seq_id = next(self.seq_counter) - eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request) + if VLLM_ENABLE_V1_MULTIPROCESSING: + logger.debug("Enabling multiprocessing for LLMEngine.") + enable_multiprocessing = True - if is_encoder_decoder_inputs(processed_inputs): - decoder_inputs = processed_inputs["decoder"] - encoder_inputs = processed_inputs["encoder"] - else: - decoder_inputs = processed_inputs - encoder_inputs = None + # Create the LLMEngine. 
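For context, this is roughly how the re-implemented engine is expected to be driven from user code. The construction call mirrors from_engine_args above; the step()-based loop and the RequestOutput fields are assumed to keep the legacy LLMEngine interface and are shown here only as a sketch, with the model name and sampling settings as placeholders:

from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine
from vllm.sampling_params import SamplingParams

engine = LLMEngine.from_engine_args(EngineArgs(model="facebook/opt-125m"))
engine.add_request(request_id="0",
                   prompt="Hello, my name is",
                   params=SamplingParams(temperature=0.0, max_tokens=16))

while engine.has_unfinished_requests():
    # step() is assumed to behave as in the legacy engine: run one iteration
    # and return the updated RequestOutputs.
    for output in engine.step():
        if output.finished:
            print(output.outputs[0].text)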
+ return cls(vllm_config=vllm_config, + executor_class=executor_class, + log_stats=not engine_args.disable_log_stats, + usage_context=usage_context, + stat_loggers=stat_loggers, + multiprocess_mode=enable_multiprocessing) - seq = Sequence(seq_id, decoder_inputs, block_size, eos_token_id, - lora_request, prompt_adapter_request) + @classmethod + def _get_executor_cls(cls, vllm_config: VllmConfig): + if current_platform.is_tpu(): + return TPUExecutor + return GPUExecutor - encoder_seq = (None if encoder_inputs is None else Sequence( - seq_id, encoder_inputs, block_size, eos_token_id, lora_request, - prompt_adapter_request)) + def stop_remote_worker_execution_loop(self) -> None: + raise NotImplementedError("TP not implemented yet.") - # Create a SequenceGroup based on SamplingParams or PoolingParams - if isinstance(params, SamplingParams): - seq_group = self._create_sequence_group_with_sampling( - request_id, - seq, - params, - arrival_time=arrival_time, - lora_request=lora_request, - trace_headers=trace_headers, - prompt_adapter_request=prompt_adapter_request, - encoder_seq=encoder_seq, - priority=priority) - elif isinstance(params, PoolingParams): - seq_group = self._create_sequence_group_with_pooling( - request_id, - seq, - params, - arrival_time=arrival_time, - lora_request=lora_request, - prompt_adapter_request=prompt_adapter_request, - encoder_seq=encoder_seq, - priority=priority) - else: - raise ValueError( - "Either SamplingParams or PoolingParams must be provided.") + def get_num_unfinished_requests(self) -> int: + return self.detokenizer.get_num_unfinished_requests() - # Add the sequence group to the scheduler with least unfinished seqs. - costs = [ - scheduler.get_num_unfinished_seq_groups() - for scheduler in self.scheduler - ] - min_cost_scheduler = self.scheduler[costs.index(min(costs))] - min_cost_scheduler.add_seq_group(seq_group) + def has_unfinished_requests(self) -> bool: + return self.detokenizer.has_unfinished_requests() - return seq_group + @classmethod + def validate_outputs(cls, outputs, output_type): + return outputs - def stop_remote_worker_execution_loop(self) -> None: - self.model_executor.stop_remote_worker_execution_loop() + def abort_request(self, request_ids: List[str]) -> None: + """Remove request_ids from EngineCore and Detokenizer.""" - @overload # DEPRECATED - def add_request( - self, - request_id: str, - *, - inputs: PromptType, - params: Union[SamplingParams, PoolingParams], - arrival_time: Optional[float] = None, - lora_request: Optional[LoRARequest] = None, - trace_headers: Optional[Mapping[str, str]] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None, - priority: int = 0, - ) -> None: - ... + self.engine_core.abort_requests(request_ids) + self.detokenizer.abort_requests(request_ids) - @overload def add_request( self, request_id: str, @@ -754,1343 +128,46 @@ def add_request( prompt_adapter_request: Optional[PromptAdapterRequest] = None, priority: int = 0, ) -> None: - ... 
- - @deprecate_kwargs( - "inputs", - additional_message="Please use the 'prompt' parameter instead.", - ) - def add_request( - self, - request_id: str, - prompt: Optional[PromptType] = None, - params: Optional[Union[SamplingParams, PoolingParams]] = None, - arrival_time: Optional[float] = None, - lora_request: Optional[LoRARequest] = None, - trace_headers: Optional[Mapping[str, str]] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None, - priority: int = 0, - *, - inputs: Optional[PromptType] = None, # DEPRECATED - ) -> None: - """Add a request to the engine's request pool. - - The request is added to the request pool and will be processed by the - scheduler as `engine.step()` is called. The exact scheduling policy is - determined by the scheduler. - - Args: - request_id: The unique ID of the request. - prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType` - for more details about the format of each input. - params: Parameters for sampling or pooling. - :class:`~vllm.SamplingParams` for text generation. - :class:`~vllm.PoolingParams` for pooling. - arrival_time: The arrival time of the request. If None, we use - the current monotonic time. - trace_headers: OpenTelemetry trace headers. - priority: The priority of the request. - Only applicable with priority scheduling. - - Details: - - Set arrival_time to the current time if it is None. - - Set prompt_token_ids to the encoded prompt if it is None. - - Create `n` number of :class:`~vllm.Sequence` objects. - - Create a :class:`~vllm.SequenceGroup` object - from the list of :class:`~vllm.Sequence`. - - Add the :class:`~vllm.SequenceGroup` object to the scheduler. - - Example: - >>> # initialize engine - >>> engine = LLMEngine.from_engine_args(engine_args) - >>> # set request arguments - >>> example_prompt = "Who is the president of the United States?" - >>> sampling_params = SamplingParams(temperature=0.0) - >>> request_id = 0 - >>> - >>> # add the request to the engine - >>> engine.add_request( - >>> str(request_id), - >>> example_prompt, - >>> SamplingParams(temperature=0.0)) - >>> # continue the request processing - >>> ... - """ - if inputs is not None: - prompt = inputs - assert prompt is not None and params is not None - - if lora_request is not None and not self.lora_config: - raise ValueError(f"Got lora_request {lora_request} but LoRA is " - "not enabled!") - - if priority != 0 and not self.scheduler_config.policy == "priority": - raise ValueError(f"Got priority {priority} but " - "Priority scheduling is not enabled.") - - if isinstance(params, SamplingParams) \ - and (params.guided_decoding or params.logits_processors) \ - and self.scheduler_config.num_scheduler_steps > 1: - raise ValueError( - "Guided decoding and logits processors are not supported " - "in multi-step decoding") - - if arrival_time is None: - arrival_time = time.time() - - if self.tokenizer is not None: - self._validate_token_prompt( - prompt, - tokenizer=self.get_tokenizer(lora_request=lora_request)) - - preprocessed_inputs = self.input_preprocessor.preprocess( - prompt, - request_id=request_id, - lora_request=lora_request, - prompt_adapter_request=prompt_adapter_request, - ) - processed_inputs = self.input_processor(preprocessed_inputs) - - # This is a bit of a hack - copy the mm_processor_kwargs that were - # used in the input processor to the processed output, since these - # kwargs are presumed to be immutable and the values should be aligned - # between the input processor (here) and the input mapper. 
- processed_inputs["mm_processor_kwargs"] = preprocessed_inputs.get( - "mm_processor_kwargs") - - self._add_processed_request( - request_id=request_id, - processed_inputs=processed_inputs, - params=params, - arrival_time=arrival_time, - lora_request=lora_request, - prompt_adapter_request=prompt_adapter_request, - trace_headers=trace_headers, - priority=priority, - ) - - def _validate_token_prompt(self, prompt: PromptType, - tokenizer: AnyTokenizer): - # Guard against out-of-vocab tokens. - # For some tokenizers, tokenizer.decode will happily return empty text - # for token ids that are out of vocab, and we don't detect token ids - # that are greater than the max token id before running the model. - # However, these token ids will later crash a cuda kernel at runtime - # with an index out of bounds error. This will crash the entire engine. - # This needs to happen before multimodal input pre-processing, which - # may add dummy tokens that aren't part of the tokenizer's - # vocabulary. - if is_token_prompt(prompt): - prompt_ids = prompt["prompt_token_ids"] - if len(prompt_ids) == 0: - # Empty prompt check is handled later - return - max_input_id = max(prompt_ids) - if max_input_id > tokenizer.max_token_id: - raise ValueError( - "Token id {} is out of vocabulary".format(max_input_id)) - - def _create_sequence_group_with_sampling( - self, - request_id: str, - seq: Sequence, - sampling_params: SamplingParams, - arrival_time: float, - lora_request: Optional[LoRARequest], - trace_headers: Optional[Mapping[str, str]] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None, - encoder_seq: Optional[Sequence] = None, - priority: int = 0, - ) -> SequenceGroup: - """Creates a SequenceGroup with SamplingParams.""" - max_logprobs = self.get_model_config().max_logprobs - if (sampling_params.logprobs - and sampling_params.logprobs > max_logprobs) or ( - sampling_params.prompt_logprobs - and sampling_params.prompt_logprobs > max_logprobs): - raise ValueError(f"Cannot request more than " - f"{max_logprobs} logprobs.") - - sampling_params = self._build_logits_processors( - sampling_params, lora_request) - - # Defensive copy of SamplingParams, which are used by the sampler, - # this doesn't deep-copy LogitsProcessor objects - sampling_params = sampling_params.clone() - - sampling_params.update_from_generation_config( - self.generation_config_fields, seq.eos_token_id) - - # Create the sequence group. - seq_group = SequenceGroup( - request_id=request_id, - seqs=[seq], - arrival_time=arrival_time, - sampling_params=sampling_params, - lora_request=lora_request, - trace_headers=trace_headers, - prompt_adapter_request=prompt_adapter_request, - encoder_seq=encoder_seq, - priority=priority) - - return seq_group - - def _create_sequence_group_with_pooling( - self, - request_id: str, - seq: Sequence, - pooling_params: PoolingParams, - arrival_time: float, - lora_request: Optional[LoRARequest], - prompt_adapter_request: Optional[PromptAdapterRequest], - encoder_seq: Optional[Sequence] = None, - priority: int = 0, - ) -> SequenceGroup: - """Creates a SequenceGroup with PoolingParams.""" - # Defensive copy of PoolingParams, which are used by the pooler - pooling_params = pooling_params.clone() - # Create the sequence group. 
- seq_group = SequenceGroup( - request_id=request_id, - seqs=[seq], - arrival_time=arrival_time, - lora_request=lora_request, - pooling_params=pooling_params, - prompt_adapter_request=prompt_adapter_request, - encoder_seq=encoder_seq, - priority=priority) - return seq_group - - def abort_request(self, request_id: Union[str, Iterable[str]]) -> None: - """Aborts a request(s) with the given ID. - - Args: - request_id: The ID(s) of the request to abort. - - Details: - - Refer to the - :meth:`~vllm.core.scheduler.Scheduler.abort_seq_group` - from class :class:`~vllm.core.scheduler.Scheduler`. - - Example: - >>> # initialize engine and add a request with request_id - >>> request_id = str(0) - >>> # abort the request - >>> engine.abort_request(request_id) - """ - for scheduler in self.scheduler: - scheduler.abort_seq_group(request_id) - - def get_model_config(self) -> ModelConfig: - """Gets the model configuration.""" - return self.model_config - - def get_parallel_config(self) -> ParallelConfig: - """Gets the parallel configuration.""" - return self.parallel_config - - def get_decoding_config(self) -> DecodingConfig: - """Gets the decoding configuration.""" - return self.decoding_config - - def get_scheduler_config(self) -> SchedulerConfig: - """Gets the scheduler configuration.""" - return self.scheduler_config - - def get_lora_config(self) -> LoRAConfig: - """Gets the LoRA configuration.""" - return self.lora_config - - def get_num_unfinished_requests(self) -> int: - """Gets the number of unfinished requests.""" - return sum(scheduler.get_num_unfinished_seq_groups() - for scheduler in self.scheduler) - - def has_unfinished_requests(self) -> bool: - """Returns True if there are unfinished requests.""" - return any(scheduler.has_unfinished_seqs() - for scheduler in self.scheduler) - - def has_unfinished_requests_for_virtual_engine( - self, virtual_engine: int) -> bool: - """ - Returns True if there are unfinished requests for the virtual engine. - """ - return self.scheduler[virtual_engine].has_unfinished_seqs() - - @staticmethod - def _process_sequence_group_outputs( - seq_group: SequenceGroup, - outputs: List[EmbeddingSequenceGroupOutput], - ) -> None: - seq_group.embeddings = outputs[0].embeddings - - for seq in seq_group.get_seqs(): - seq.status = SequenceStatus.FINISHED_STOPPED - - return - - def _update_num_computed_tokens_for_multi_step_prefill( - self, seq_group: SequenceGroup, - seq_group_meta: SequenceGroupMetadata, - is_first_step_output: Optional[bool]): - """ - This function updates num_computed_tokens for prompt sequences - when Multi-Step is enabled. - - seq_group: SequenceGroup to update the num_computed_tokens for. - seq_group_meta: Metadata of the given SequenceGroup. - is_first_step_output: Optional[bool] - - When available, is_first_step_output indicates if the appended - output token is the output of the first-step in multi-step. - A value of None indicates that outputs from all steps in - in multi-step are submitted in a single burst. - """ - - assert self.scheduler_config.is_multi_step - - if not seq_group_meta.is_prompt: - # num_computed_token updates for multi-step decodes happen after - # the tokens are appended to the sequence. - return - - do_update: bool = False - if self.scheduler_config.chunked_prefill_enabled: - # In multi-step + chunked-prefill case, the prompt sequences - # that are scheduled are fully processed in the first step. - do_update = is_first_step_output is None or is_first_step_output - else: - # Normal multi-step decoding case. 
In this case prompt-sequences - # are actually single-stepped. Always update in this case. - assert seq_group.state.num_steps == 1 - do_update = True - - if do_update: - seq_group.update_num_computed_tokens( - seq_group_meta.token_chunk_size) - - def _process_model_outputs(self, - ctx: SchedulerContext, - request_id: Optional[str] = None) -> None: - """Apply the model output to the sequences in the scheduled seq groups - and return responses. - - ctx: The virtual engine context to work on - request_id: If provided, then only this request is going to be processed - """ - - now = time.time() - - if len(ctx.output_queue) == 0: - return None - - # Get pending async postprocessor - if request_id: - # When we process only one request, no pop is required - # (since later we will process all of the rest) - (outputs, seq_group_metadata_list, scheduler_outputs, is_async, - is_last_step, is_first_step_output, skip) = ctx.output_queue[0] - else: - (outputs, seq_group_metadata_list, scheduler_outputs, is_async, - is_last_step, is_first_step_output, - skip) = ctx.output_queue.popleft() - - # Sanity check - assert len(seq_group_metadata_list) == len( - scheduler_outputs.scheduled_seq_groups) - - has_multiple_outputs: bool = len(outputs) > 1 - outputs_by_sequence_group: List[List[SequenceGroupOutput]] - if has_multiple_outputs: - assert self.scheduler_config.is_multi_step or \ - self.speculative_config - # Organize outputs by [step][sequence group] instead of - # [sequence group][step]. - outputs_by_sequence_group = create_output_by_sequence_group( - outputs, num_seq_groups=len(seq_group_metadata_list)) - # We have outputs for multiple steps submitted in a single burst, - # so invalidate is_first_step_output. - is_first_step_output = None - else: - outputs_by_sequence_group = outputs - - # Determine the requests we need to operate on - if request_id: - indices = [] - for i, seq_group_meta in enumerate(seq_group_metadata_list): - if seq_group_meta.request_id == request_id: - assert i not in skip # Cannot be called twice - indices.append(i) - break - - # If the request_id was not found, then it means that - # this is a new request that has no pending async - # postprocessor - if not indices: - return - else: - indices = range(len(seq_group_metadata_list)) # type: ignore - - finished_before: List[int] = [] - finished_now: List[int] = [] - for i in indices: - if i in skip: - continue - - seq_group_meta = seq_group_metadata_list[i] - scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i] - - seq_group: SequenceGroup = scheduled_seq_group.seq_group - - if seq_group.is_finished(): - finished_before.append(i) - continue - - output: List[SequenceGroupOutput] - if has_multiple_outputs: - output = outputs_by_sequence_group[i] - else: - output = [outputs_by_sequence_group[0][i]] - - if not is_async: - if self.scheduler_config.is_multi_step: - # Updates happen only if the sequence is prefill - self._update_num_computed_tokens_for_multi_step_prefill( - seq_group, seq_group_meta, is_first_step_output) - else: - seq_group.update_num_computed_tokens( - seq_group_meta.token_chunk_size or 0) - - if outputs: - for o in outputs: - if (isinstance(o, SamplerOutput) - and seq_group.metrics is not None): - if seq_group.metrics.model_forward_time is not None: - seq_group.metrics.model_forward_time += ( - o.model_forward_time or 0) - else: - seq_group.metrics.model_forward_time = ( - o.model_forward_time) - if seq_group.metrics.model_execute_time is not None: - seq_group.metrics.model_execute_time += ( - 
o.model_execute_time or 0) - else: - seq_group.metrics.model_execute_time = ( - o.model_execute_time) - - if self.model_config.task == "embedding": - self._process_sequence_group_outputs(seq_group, output) - else: - self.output_processor.process_prompt_logprob(seq_group, output) - if seq_group_meta.do_sample: - self.output_processor.process_outputs( - seq_group, output, is_async) - - if seq_group.is_finished(): - finished_now.append(i) - - # Generate outputs for the requests that finished this iteration - for i in finished_now: - scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i] - - seq_group = scheduled_seq_group.seq_group - seq_group.maybe_set_first_token_time(now) - request_output = RequestOutputFactory.create( - seq_group, - self.seq_id_to_seq_group, - use_cache=self.use_cached_outputs) - if request_output: - ctx.request_outputs.append(request_output) - - # When we process a single request, we skip it for the next time, - # and invoke the request output callback (if there was final output) - if request_id: - assert len(indices) == 1 - skip.append(indices[0]) - - if (finished_now - and self.process_request_outputs_callback is not None): - self.process_request_outputs_callback(ctx.request_outputs) - ctx.request_outputs.clear() - return - - # Free currently finished requests - if finished_now: - for scheduler in self.scheduler: - scheduler.free_finished_seq_groups() - - # For multi-step without streaming, don't create outputs each iteration - if not is_last_step and not ctx.multi_step_stream_outputs: - # Immediately process request outputs here (if callback is given) - if (finished_now - and self.process_request_outputs_callback is not None): - self.process_request_outputs_callback(ctx.request_outputs) - ctx.request_outputs.clear() - return - - # Create the outputs - for i in indices: - if i in skip or i in finished_before or i in finished_now: - continue # Avoids double processing - - scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i] - - seq_group = scheduled_seq_group.seq_group - seq_group.maybe_set_first_token_time(now) - request_output = RequestOutputFactory.create( - seq_group, - self.seq_id_to_seq_group, - use_cache=self.use_cached_outputs) - if request_output: - ctx.request_outputs.append(request_output) - - # For multi-step with streaming, create outputs each iteration - if not is_last_step and ctx.multi_step_stream_outputs: - # Immediately process request outputs here (if callback is given) - if self.process_request_outputs_callback is not None: - self.process_request_outputs_callback(ctx.request_outputs) - ctx.request_outputs.clear() - return - - for seq_group in scheduler_outputs.ignored_seq_groups: - params = seq_group.sampling_params - if params is not None and params.output_kind == ( - RequestOutputKind.DELTA) and not seq_group.is_finished(): - continue - - request_output = RequestOutputFactory.create( - seq_group, - self.seq_id_to_seq_group, - use_cache=self.use_cached_outputs, - ) - if request_output: - ctx.request_outputs.append(request_output) - - # Immediately process request outputs here (if callback is given) - if (ctx.request_outputs - and self.process_request_outputs_callback is not None): - self.process_request_outputs_callback(ctx.request_outputs) - ctx.request_outputs.clear() - - # For async case, we need to record the stats here. - # For non-async case, the stats are done in the - # LLMEngine/AsyncLLMEngine directly - if is_async: - # Log stats. 
- self.do_log_stats(scheduler_outputs, outputs, finished_before, - skip) - - # Tracing - self.do_tracing(scheduler_outputs, finished_before) - - return None - - def _advance_to_next_step( - self, output: List[SamplerOutput], - seq_group_metadata_list: List[SequenceGroupMetadata], - scheduled_seq_groups: List[ScheduledSequenceGroup]) -> None: - """Given model output from a single run, append the tokens to the - sequences. This is normally done inside output processor, but it is - required if the worker is to perform async forward pass to next step. - """ - for seq_group_metadata, sequence_group_outputs, scheduled_seq_group in \ - zip(seq_group_metadata_list, output, scheduled_seq_groups): - seq_group = scheduled_seq_group.seq_group - - if seq_group.is_finished(): - continue - - if self.scheduler_config.is_multi_step: - # Updates happen only if the sequence is prefill - self._update_num_computed_tokens_for_multi_step_prefill( - seq_group, seq_group_metadata, - seq_group.state.num_steps == 1) - else: - token_chunk_size = (seq_group_metadata.token_chunk_size - if seq_group_metadata.token_chunk_size - is not None else 0) - seq_group.update_num_computed_tokens(token_chunk_size) - - if seq_group_metadata.do_sample: - assert len(sequence_group_outputs.samples) == 1, ( - "Async output processor expects a single sample" - " (i.e sampling_params.n == 1)") - sample = sequence_group_outputs.samples[0] - - assert len(seq_group.seqs) == 1 - seq = seq_group.seqs[0] - - if self.scheduler_config.is_multi_step: - is_prefill_append = seq.data.get_num_uncomputed_tokens( - ) == 0 - seq.append_token_id(sample.output_token, sample.logprobs) - if not is_prefill_append: - seq_group.update_num_computed_tokens(1) - else: - seq.append_token_id(sample.output_token, sample.logprobs) - - def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: - """Performs one decoding iteration and returns newly generated results. - - .. figure:: https://i.imgur.com/sv2HssD.png - :alt: Overview of the step function - :align: center - - Overview of the step function. - - Details: - - Step 1: Schedules the sequences to be executed in the next - iteration and the token blocks to be swapped in/out/copy. - - - Depending on the scheduling policy, - sequences may be `preempted/reordered`. - - A Sequence Group (SG) refer to a group of sequences - that are generated from the same prompt. - - - Step 2: Calls the distributed executor to execute the model. - - Step 3: Processes the model output. This mainly includes: - - - Decodes the relevant outputs. - - Updates the scheduled sequence groups with model outputs - based on its `sampling parameters` (`use_beam_search` or not). - - Frees the finished sequence groups. - - - Finally, it creates and returns the newly generated results. - - Example: - >>> # Please see the example/ folder for more detailed examples. 
- >>> - >>> # initialize engine and request arguments - >>> engine = LLMEngine.from_engine_args(engine_args) - >>> example_inputs = [(0, "What is LLM?", - >>> SamplingParams(temperature=0.0))] - >>> - >>> # Start the engine with an event loop - >>> while True: - >>> if example_inputs: - >>> req_id, prompt, sampling_params = example_inputs.pop(0) - >>> engine.add_request(str(req_id),prompt,sampling_params) - >>> - >>> # continue the request processing - >>> request_outputs = engine.step() - >>> for request_output in request_outputs: - >>> if request_output.finished: - >>> # return or show the request output - >>> - >>> if not (engine.has_unfinished_requests() or example_inputs): - >>> break - """ - if self.parallel_config.pipeline_parallel_size > 1: - raise NotImplementedError( - "Pipeline parallelism is only supported through AsyncLLMEngine " - "as performance will be severely degraded otherwise.") - - # For llm_engine, there is no pipeline parallel support, so the engine - # used is always 0. - virtual_engine = 0 - - # These are cached outputs from previous iterations. None if on first - # iteration - cached_outputs = self.cached_scheduler_outputs[virtual_engine] - seq_group_metadata_list = cached_outputs.seq_group_metadata_list - scheduler_outputs = cached_outputs.scheduler_outputs - allow_async_output_proc = cached_outputs.allow_async_output_proc - - ctx = self.scheduler_contexts[virtual_engine] - - # Clear outputs for each new scheduler iteration - ctx.request_outputs.clear() - - # Skip the scheduler if there are any remaining steps in the seq groups. - # This ensures that the scheduler is only called again when the current - # batch has completed. - if not self._has_remaining_steps(seq_group_metadata_list): - # Schedule iteration - (seq_group_metadata_list, scheduler_outputs, - allow_async_output_proc - ) = self.scheduler[virtual_engine].schedule() - - ctx.seq_group_metadata_list = seq_group_metadata_list - ctx.scheduler_outputs = scheduler_outputs - - # Maybe switch from async mode to sync mode - if not allow_async_output_proc and len(ctx.output_queue) > 0: - self._process_model_outputs(ctx=ctx) - - if (self.scheduler_config.is_multi_step - and scheduler_outputs.num_lookahead_slots > 0): - # cache the scheduler outputs for the next iteration if we have - # lookahead slots - self._cache_scheduler_outputs_for_multi_step( - virtual_engine, seq_group_metadata_list, scheduler_outputs, - allow_async_output_proc) - - assert seq_group_metadata_list is not None - assert scheduler_outputs is not None - if not scheduler_outputs.is_empty(): - finished_requests_ids = self.scheduler[ - virtual_engine].get_and_reset_finished_requests_ids() + # 1) Process raw inputs into the request. + detokenizer_req, engine_core_req = self.processor.process_inputs( + request_id, prompt, params, arrival_time, lora_request, + trace_headers, prompt_adapter_request, priority) - # Check if we have a cached last_output from the previous iteration. - # For supporting PP this is probably the best way to pass the - # sampled_token_ids, as a separate broadcast over all the PP stages - # will cause one virtual engine's microbatch to block the pipeline. - last_sampled_token_ids = \ - self._get_last_sampled_token_ids(virtual_engine) + # 2) Add the request to Detokenizer. 
+ self.detokenizer.add_request(detokenizer_req) - execute_model_req = ExecuteModelRequest( - seq_group_metadata_list=seq_group_metadata_list, - blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in, - blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out, - blocks_to_copy=scheduler_outputs.blocks_to_copy, - num_lookahead_slots=scheduler_outputs.num_lookahead_slots, - running_queue_size=scheduler_outputs.running_queue_size, - finished_requests_ids=finished_requests_ids, - # We use ExecuteModelRequest to pass the last sampled_token_ids - # to each of the non-last PP stages for in-place prepare_input. - last_sampled_token_ids=last_sampled_token_ids) + # 3) Add the request to EngineCore. + self.engine_core.add_request(engine_core_req) - if allow_async_output_proc: - execute_model_req.async_callback = self.async_callbacks[ - virtual_engine] + def step(self) -> List[RequestOutput]: - outputs = self.model_executor.execute_model( - execute_model_req=execute_model_req) + # 1) Get EngineCoreOutput from the EngineCore. + engine_core_outputs = self.engine_core.get_output() - # We need to do this here so that last step's sampled_token_ids can - # be passed to the next iteration for PP. - if self.scheduler_config.is_multi_step: - self._update_cached_scheduler_output(virtual_engine, outputs) - else: - # Nothing scheduled => If there is pending async postprocessor, - # then finish it here. - if len(ctx.output_queue) > 0: - self._process_model_outputs(ctx=ctx) - # No outputs in this case - outputs = [] + # 2) Detokenizer the EngineCoreOutput. + request_outputs, requests_to_abort = self.detokenizer.step( + engine_core_outputs) - # Finish the current step for all the sequence groups. - if self.scheduler_config.is_multi_step: - for seq_group in seq_group_metadata_list: - seq_group.finish_step() + # 3) Abort requests that finished due to stopping criteria. + if requests_to_abort: + self.abort_request(requests_to_abort) - if not self._has_remaining_steps(seq_group_metadata_list): - # clear the cache if we have finished all the steps. - if self.scheduler_config.is_multi_step: - self.cached_scheduler_outputs[0] = SchedulerOutputState() + return request_outputs - # is_first_step_output is True only when the num_steps of all - # the sequences are 1. When the num_steps > 1, - # multi_step_model_runner does the first-step output append. - is_first_step_output: bool = False if not seq_group_metadata_list \ - else seq_group_metadata_list[0].state.num_steps == 1 + # TODO(rob): Can we get rid of these? - # Add results to the output_queue - ctx.append_output(outputs=outputs, - seq_group_metadata_list=seq_group_metadata_list, - scheduler_outputs=scheduler_outputs, - is_async=allow_async_output_proc, - is_last_step=True, - is_first_step_output=is_first_step_output) - - if outputs and allow_async_output_proc: - assert len(outputs) == 1, ( - "Async postprocessor expects only a single output set") - - self._advance_to_next_step( - outputs[0], seq_group_metadata_list, - scheduler_outputs.scheduled_seq_groups) - - # Check if need to run the usual non-async path - if not allow_async_output_proc: - self._process_model_outputs(ctx=ctx) - - # Log stats. 
- self.do_log_stats(scheduler_outputs, outputs) - - # Tracing - self.do_tracing(scheduler_outputs) - else: - # Multi-step case - return ctx.request_outputs - - if not self.has_unfinished_requests(): - # Drain async postprocessor (if exists) - if len(ctx.output_queue) > 0: - self._process_model_outputs(ctx=ctx) - assert len(ctx.output_queue) == 0 - - # Stop the execute model loop in parallel workers until there are - # more requests to process. This avoids waiting indefinitely in - # torch.distributed ops which may otherwise timeout, and unblocks - # the RPC thread in the workers so that they can process any other - # queued control plane messages, such as add/remove lora adapters. - logger.debug("Stopping remote worker execution loop.") - self.model_executor.stop_remote_worker_execution_loop() - - return ctx.request_outputs - - def _has_remaining_steps( - self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] - ) -> bool: - if (not self.scheduler_config.is_multi_step - or not seq_group_metadata_list): - return False - - # TODO(will) this is a sanity check for nowto make sure that all the - # seqs are on the same steps. Eventually we will want to do some sort of - # dynamic scheduling when doing multi-step decoding. - ref_remaining_steps = seq_group_metadata_list[0].state.remaining_steps - if any([ - seq_group.state.remaining_steps != ref_remaining_steps - for seq_group in seq_group_metadata_list[1:] - ]): - raise AssertionError("All running sequence groups should " - "have the same remaining steps.") - - return ref_remaining_steps > 0 - - def _cache_scheduler_outputs_for_multi_step( - self, virtual_engine: int, - seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], - scheduler_outputs: SchedulerOutputs, - allow_async_output_proc: bool) -> None: - co = self.cached_scheduler_outputs[virtual_engine] - - co.seq_group_metadata_list = seq_group_metadata_list - co.scheduler_outputs = scheduler_outputs - co.allow_async_output_proc = allow_async_output_proc - co.last_output = None - - def _update_cached_scheduler_output( - self, virtual_engine: int, - output: List[Optional[SamplerOutput]]) -> None: - if (self.parallel_config.pipeline_parallel_size > 1 and len(output) > 0 - and output[0] is not None): - last_output = output[-1] - assert last_output is not None - assert last_output.sampled_token_ids_cpu is not None - assert last_output.sampled_token_ids is None - assert last_output.sampled_token_probs is None - self.cached_scheduler_outputs[ - virtual_engine].last_output = last_output - - def _get_last_sampled_token_ids( - self, virtual_engine: int) -> Optional[torch.Tensor]: - cached_last_output = self.cached_scheduler_outputs[ - virtual_engine].last_output - if (self.scheduler_config.is_multi_step - and self.parallel_config.pipeline_parallel_size > 1 - and cached_last_output is not None - and cached_last_output.sampled_token_ids_cpu is not None): - return cached_last_output.sampled_token_ids_cpu - return None - - def add_logger(self, logger_name: str, logger: StatLoggerBase) -> None: - if not self.log_stats: - raise RuntimeError( - "Stat logging is disabled. Set `disable_log_stats=False` " - "argument to enable.") - if logger_name in self.stat_loggers: - raise KeyError(f"Logger with name {logger_name} already exists.") - self.stat_loggers[logger_name] = logger - - def remove_logger(self, logger_name: str) -> None: - if not self.log_stats: - raise RuntimeError( - "Stat logging is disabled. 
Set `disable_log_stats=False` " - "argument to enable.") - if logger_name not in self.stat_loggers: - raise KeyError(f"Logger with name {logger_name} does not exist.") - del self.stat_loggers[logger_name] - - def do_log_stats(self, - scheduler_outputs: Optional[SchedulerOutputs] = None, - model_output: Optional[List[SamplerOutput]] = None, - finished_before: Optional[List[int]] = None, - skip: Optional[List[int]] = None) -> None: - """Forced log when no requests active.""" - if self.log_stats: - stats = self._get_stats(scheduler_outputs, model_output, - finished_before, skip) - for logger in self.stat_loggers.values(): - logger.log(stats) - - def _get_stats(self, - scheduler_outputs: Optional[SchedulerOutputs], - model_output: Optional[List[SamplerOutput]] = None, - finished_before: Optional[List[int]] = None, - skip: Optional[List[int]] = None) -> Stats: - """Get Stats to be Logged to Prometheus. - - Args: - scheduler_outputs: Optional, used to populate metrics related to - the scheduled batch, - model_output: Optional, used to emit speculative decoding metrics - which are created by the workers. - finished_before: Optional, indices of sequences that were finished - before. These sequences will be ignored. - skip: Optional, indices of sequences that were preempted. These - sequences will be ignored. - """ - now = time.time() - - # System State - # Scheduler State - num_running_sys = sum( - len(scheduler.running) for scheduler in self.scheduler) - num_swapped_sys = sum( - len(scheduler.swapped) for scheduler in self.scheduler) - num_waiting_sys = sum( - len(scheduler.waiting) for scheduler in self.scheduler) - - # KV Cache Usage in % - num_total_gpu = self.cache_config.num_gpu_blocks - gpu_cache_usage_sys = 0. - if num_total_gpu: # Guard against both None and 0 - num_free_gpu = sum( - scheduler.block_manager.get_num_free_gpu_blocks() - for scheduler in self.scheduler) - gpu_cache_usage_sys = 1.0 - (num_free_gpu / num_total_gpu) - - num_total_cpu = self.cache_config.num_cpu_blocks - cpu_cache_usage_sys = 0. - if num_total_cpu: # Guard against both None and 0 - num_free_cpu = sum( - scheduler.block_manager.get_num_free_cpu_blocks() - for scheduler in self.scheduler) - cpu_cache_usage_sys = 1.0 - (num_free_cpu / num_total_cpu) - - # Prefix Cache Hit Rate. Note that we always use - # the cache hit rate of the first virtual engine. 
- cpu_prefix_cache_hit_rate = self.scheduler[ - 0].get_prefix_cache_hit_rate(Device.CPU) - gpu_prefix_cache_hit_rate = self.scheduler[ - 0].get_prefix_cache_hit_rate(Device.GPU) - - # Iteration stats - num_prompt_tokens_iter = 0 - num_generation_tokens_iter = 0 - num_tokens_iter = 0 - time_to_first_tokens_iter: List[float] = [] - time_per_output_tokens_iter: List[float] = [] - num_preemption_iter = (0 if scheduler_outputs is None else - scheduler_outputs.preempted) - - # Request stats - # Latency - time_e2e_requests: List[float] = [] - time_queue_requests: List[float] = [] - time_inference_requests: List[float] = [] - time_prefill_requests: List[float] = [] - time_decode_requests: List[float] = [] - time_in_queue_requests: List[float] = [] - model_forward_time_requests: List[float] = [] - model_execute_time_requests: List[float] = [] - # Metadata - num_prompt_tokens_requests: List[int] = [] - num_generation_tokens_requests: List[int] = [] - n_requests: List[int] = [] - max_num_generation_tokens_requests: List[int] = [] - max_tokens_requests: List[int] = [] - finished_reason_requests: List[str] = [] - - # Lora requests - running_lora_adapters = dict( - collectionsCounter([ - running_request.lora_request.lora_name - for scheduler in self.scheduler - for running_request in scheduler.running - if running_request.lora_request - ])) - waiting_lora_adapters = dict( - collectionsCounter([ - waiting_request.lora_request.lora_name - for scheduler in self.scheduler - for waiting_request in scheduler.waiting - if waiting_request.lora_request - ])) - max_lora_stat = "0" - if self.lora_config: - max_lora_stat = str(self.lora_config.max_loras) - - # NOTE: This loop assumes prefill seq_groups are before - # decode seq_groups in scheduled_seq_groups. - if scheduler_outputs is not None: - # For async postprocessor, already finished sequences need to be - # not counted (to avoid double counting) - actual_num_batched_tokens = scheduler_outputs.num_batched_tokens # type: ignore - - num_generation_tokens_from_prefill_groups = 0. - # NOTE: if scheduler_outputs.num_prefill_groups > 0 and - # the len of scheduler_outputs.scheduled_seq_groups is != - # scheduler_outputs.num_prefill_groups, this means that - # chunked prefills have been detected. - - for idx, scheduled_seq_group in enumerate( - scheduler_outputs.scheduled_seq_groups): - # Skip double logging when using async output proc - if finished_before and idx in finished_before: - actual_num_batched_tokens -= 1 - continue - - # Currently, skip == preempted sequences, so we need to skip - # their log stats - if skip and idx in skip: - continue - - group_was_prefill = idx < scheduler_outputs.num_prefill_groups - seq_group = scheduled_seq_group.seq_group - - # NOTE: a seq_group that completed all of its prefill tokens - # in the last iteration will have seq_group.is_prefill() = False - # with group_was_prefill = True - if group_was_prefill: - # Number of prompt tokens. - num_prompt_tokens_iter += ( - scheduled_seq_group.token_chunk_size) - - # If the seq_group just finished the prefill state - # get TTFT. - if not seq_group.is_prefill(): - latency = seq_group.get_last_latency(now) - time_to_first_tokens_iter.append(latency) - - # One generation token per finished prefill. - num_generation_tokens_from_prefill_groups += ( - seq_group.num_seqs()) - else: - # TPOTs. 
- latency = seq_group.get_last_latency(now) - time_per_output_tokens_iter.append(latency) - if seq_group.state.current_step == 0: - # For async_output_proc, the do_log_stats() - # is called following init_multi_step(), which - # sets the current_step to zero. - actual_num_batched_tokens +=\ - seq_group.state.num_steps - 1 - else: - actual_num_batched_tokens +=\ - seq_group.state.current_step - 1 - - # Because of chunked prefill, we can have a single sequence - # group that does multiple prompt_runs. To prevent logging - # the same metadata more than once per request, we standardize - # on logging request level information for finished requests, - # which can only happen once. - if seq_group.is_finished(): - # Latency timings - time_e2e_requests.append(now - - seq_group.metrics.arrival_time) - if (seq_group.metrics.first_scheduled_time is not None and - seq_group.metrics.first_token_time is not None): - time_queue_requests.append( - seq_group.metrics.first_scheduled_time - - seq_group.metrics.arrival_time) - time_prefill_requests.append( - seq_group.metrics.first_token_time - - seq_group.metrics.first_scheduled_time) - time_decode_requests.append( - now - seq_group.metrics.first_token_time) - time_inference_requests.append( - now - seq_group.metrics.first_scheduled_time) - if seq_group.metrics.time_in_queue is not None: - time_in_queue_requests.append( - seq_group.metrics.time_in_queue) - if seq_group.metrics.model_forward_time is not None: - model_forward_time_requests.append( - seq_group.metrics.model_forward_time) - if seq_group.metrics.model_execute_time is not None: - model_execute_time_requests.append( - seq_group.metrics.model_execute_time * 1000) - # Metadata - num_prompt_tokens_requests.append( - len(seq_group.prompt_token_ids)) - num_generation_tokens_requests.extend([ - seq.get_output_len() - for seq in seq_group.get_finished_seqs() - ]) - max_num_generation_tokens_requests.append( - max(seq.get_output_len() - for seq in seq_group.get_seqs())) - if seq_group.sampling_params is not None: - n_requests.append(seq_group.sampling_params.n) - max_tokens_requests.append( - seq_group.sampling_params.max_tokens) - finished_reason_requests.extend([ - SequenceStatus.get_finished_reason(seq.status) - for seq in seq_group.get_finished_seqs() - ]) - - # Number of generation tokens. - # num_batched_tokens equals the number of prompt_tokens plus the - # number of decode_tokens in a single iteration. So, - # num_generation_tokens = num_batched_tokens - num_prompt_tokens - # + num_generation_tokens_from_prefill_groups (since we generate - # one token on prefills on iters where the prefill finishes). - num_generation_tokens_iter = ( - actual_num_batched_tokens - num_prompt_tokens_iter + - num_generation_tokens_from_prefill_groups) - num_tokens_iter = (num_generation_tokens_iter + - num_prompt_tokens_iter) - # Spec decode, if enabled, emits specialized metrics from the worker in - # sampler output. 
- if model_output and (model_output[0].spec_decode_worker_metrics - is not None): - spec_decode_metrics = model_output[0].spec_decode_worker_metrics - else: - spec_decode_metrics = None - - return Stats( - now=now, - # System stats - # Scheduler State - num_running_sys=num_running_sys, - num_swapped_sys=num_swapped_sys, - num_waiting_sys=num_waiting_sys, - # KV Cache Usage in % - gpu_cache_usage_sys=gpu_cache_usage_sys, - cpu_cache_usage_sys=cpu_cache_usage_sys, - # Prefix Cache Hit Rate - cpu_prefix_cache_hit_rate=cpu_prefix_cache_hit_rate, - gpu_prefix_cache_hit_rate=gpu_prefix_cache_hit_rate, - - # Iteration stats - num_prompt_tokens_iter=num_prompt_tokens_iter, - num_generation_tokens_iter=num_generation_tokens_iter, - num_tokens_iter=num_tokens_iter, - time_to_first_tokens_iter=time_to_first_tokens_iter, - time_per_output_tokens_iter=time_per_output_tokens_iter, - spec_decode_metrics=spec_decode_metrics, - num_preemption_iter=num_preemption_iter, - - # Request stats - # Latency - time_e2e_requests=time_e2e_requests, - time_queue_requests=time_queue_requests, - time_inference_requests=time_inference_requests, - time_prefill_requests=time_prefill_requests, - time_decode_requests=time_decode_requests, - time_in_queue_requests=time_in_queue_requests, - model_forward_time_requests=model_forward_time_requests, - model_execute_time_requests=model_execute_time_requests, - # Metadata - num_prompt_tokens_requests=num_prompt_tokens_requests, - num_generation_tokens_requests=num_generation_tokens_requests, - max_num_generation_tokens_requests= - max_num_generation_tokens_requests, - n_requests=n_requests, - max_tokens_requests=max_tokens_requests, - finished_reason_requests=finished_reason_requests, - max_lora=str(max_lora_stat), - waiting_lora_adapters=list(waiting_lora_adapters.keys()), - running_lora_adapters=list(running_lora_adapters.keys())) - - def add_lora(self, lora_request: LoRARequest) -> bool: - return self.model_executor.add_lora(lora_request) - - def remove_lora(self, lora_id: int) -> bool: - return self.model_executor.remove_lora(lora_id) - - def list_loras(self) -> Set[int]: - return self.model_executor.list_loras() - - def pin_lora(self, lora_id: int) -> bool: - return self.model_executor.pin_lora(lora_id) - - def add_prompt_adapter( - self, prompt_adapter_request: PromptAdapterRequest) -> bool: - return self.model_executor.add_prompt_adapter(prompt_adapter_request) - - def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool: - return self.model_executor.remove_prompt_adapter(prompt_adapter_id) - - def list_prompt_adapters(self) -> List[int]: - return self.model_executor.list_prompt_adapters() - - def check_health(self) -> None: - if self.tokenizer: - self.tokenizer.check_health() - self.model_executor.check_health() - - def start_profile(self) -> None: - # using type instead of isinstance to check to avoid capturing - # inherited classes (MultiprocessingGPUExecutor) - if type(self.model_executor) == GPUExecutor: # noqa: E721 - self.model_executor.start_profile() - else: - self.model_executor._run_workers("start_profile") - - def stop_profile(self) -> None: - # using type instead of isinstance to check to avoid capturing - # inherited classes (MultiprocessingGPUExecutor) - if type(self.model_executor) == GPUExecutor: # noqa: E721 - self.model_executor.stop_profile() - else: - self.model_executor._run_workers("stop_profile") - - def is_tracing_enabled(self) -> bool: - return self.tracer is not None - - def do_tracing(self, - scheduler_outputs: SchedulerOutputs, - 
finished_before: Optional[List[int]] = None) -> None: - if self.tracer is None: - return - - for idx, scheduled_seq_group in enumerate( - scheduler_outputs.scheduled_seq_groups): - # Skip double tracing when using async output proc - if finished_before and idx in finished_before: - continue - - seq_group = scheduled_seq_group.seq_group - if seq_group.is_finished(): - self.create_trace_span(seq_group) - - def create_trace_span(self, seq_group: SequenceGroup) -> None: - if self.tracer is None or seq_group.sampling_params is None: - return - arrival_time_nano_seconds = int(seq_group.metrics.arrival_time * 1e9) - - trace_context = extract_trace_context(seq_group.trace_headers) - - with self.tracer.start_as_current_span( - "llm_request", - kind=SpanKind.SERVER, - context=trace_context, - start_time=arrival_time_nano_seconds) as seq_span: - metrics = seq_group.metrics - ttft = metrics.first_token_time - metrics.arrival_time - e2e_time = metrics.finished_time - metrics.arrival_time - # attribute names are based on - # https://github.com/open-telemetry/semantic-conventions/blob/main/docs/gen-ai/llm-spans.md - seq_span.set_attribute(SpanAttributes.LLM_RESPONSE_MODEL, - self.model_config.model) - seq_span.set_attribute(SpanAttributes.LLM_REQUEST_ID, - seq_group.request_id) - seq_span.set_attribute(SpanAttributes.LLM_REQUEST_TEMPERATURE, - seq_group.sampling_params.temperature) - seq_span.set_attribute(SpanAttributes.LLM_REQUEST_TOP_P, - seq_group.sampling_params.top_p) - seq_span.set_attribute(SpanAttributes.LLM_REQUEST_MAX_TOKENS, - seq_group.sampling_params.max_tokens) - seq_span.set_attribute(SpanAttributes.LLM_REQUEST_N, - seq_group.sampling_params.n) - seq_span.set_attribute(SpanAttributes.LLM_USAGE_NUM_SEQUENCES, - seq_group.num_seqs()) - seq_span.set_attribute(SpanAttributes.LLM_USAGE_PROMPT_TOKENS, - len(seq_group.prompt_token_ids)) - seq_span.set_attribute( - SpanAttributes.LLM_USAGE_COMPLETION_TOKENS, - sum([ - seq.get_output_len() - for seq in seq_group.get_finished_seqs() - ])) - seq_span.set_attribute(SpanAttributes.LLM_LATENCY_TIME_IN_QUEUE, - metrics.time_in_queue) - seq_span.set_attribute( - SpanAttributes.LLM_LATENCY_TIME_TO_FIRST_TOKEN, ttft) - seq_span.set_attribute(SpanAttributes.LLM_LATENCY_E2E, e2e_time) - if metrics.scheduler_time is not None: - seq_span.set_attribute( - SpanAttributes.LLM_LATENCY_TIME_IN_SCHEDULER, - metrics.scheduler_time) - if metrics.model_forward_time is not None: - seq_span.set_attribute( - SpanAttributes.LLM_LATENCY_TIME_IN_MODEL_FORWARD, - metrics.model_forward_time / 1000.0) - if metrics.model_execute_time is not None: - seq_span.set_attribute( - SpanAttributes.LLM_LATENCY_TIME_IN_MODEL_EXECUTE, - metrics.model_execute_time) + def get_model_config(self): + pass def is_encoder_decoder_model(self): - return self.input_preprocessor.is_encoder_decoder_model() - - def _validate_model_inputs(self, inputs: ProcessorInputs, - lora_request: Optional[LoRARequest]): - if is_encoder_decoder_inputs(inputs): - # For encoder-decoder multimodal models, the max_prompt_len - # restricts the decoder prompt length - prompt_inputs = inputs["decoder" if self.model_config. 
- is_multimodal_model else "encoder"] - else: - prompt_inputs = inputs - - prompt_ids = prompt_inputs.get("prompt_token_ids") - - if prompt_ids is None or len(prompt_ids) == 0: - raise ValueError("Prompt cannot be empty") - - if self.model_config.is_multimodal_model: - max_prompt_len = self.model_config.max_model_len - - if len(prompt_ids) > max_prompt_len: - raise ValueError( - f"The prompt (total length {len(prompt_ids)}) is too long " - f"to fit into the model (context length {max_prompt_len}). " - "Make sure that `max_model_len` is no smaller than the " - "number of text tokens plus multimodal tokens. For image " - "inputs, the number of image tokens depends on the number " - "of images, and possibly their aspect ratios as well.") - - # TODO: Find out how many placeholder tokens are there so we can - # check that chunked prefill does not truncate them - # max_batch_len = self.scheduler_config.max_num_batched_tokens - - def _build_logits_processors( - self, sampling_params: SamplingParams, - lora_request: Optional[LoRARequest]) -> SamplingParams: - """Constructs logits processors based on the guided_decoding, - logits_bias, and allowed_token_ids fields in sampling_params. Deletes - those fields and adds the constructed logits processors to the - logits_processors field. Returns the modified sampling params.""" - - logits_processors = [] - - if (guided_decoding := sampling_params.guided_decoding) is not None: - - logger.debug( - "Building guided decoding logits processor in " - "LLMEngine. Params: %s", guided_decoding) - - tokenizer = self.get_tokenizer(lora_request=lora_request) - guided_decoding.backend = guided_decoding.backend or \ - self.decoding_config.guided_decoding_backend - - processor = get_local_guided_decoding_logits_processor( - guided_params=guided_decoding, tokenizer=tokenizer) - if processor: - logits_processors.append(processor) - - # Unset so this doesn't get passed down to the model - sampling_params.guided_decoding = None - - if (sampling_params.logit_bias or sampling_params.allowed_token_ids): - tokenizer = self.get_tokenizer(lora_request=lora_request) - - processors = get_openai_logits_processors( - logit_bias=sampling_params.logit_bias, - allowed_token_ids=sampling_params.allowed_token_ids, - tokenizer=tokenizer) - logits_processors.extend(processors) - - # Unset so these don't get passed down to the model - sampling_params.logit_bias = None - sampling_params.allowed_token_ids = None + pass - if len(sampling_params.bad_words) > 0: - tokenizer = self.get_tokenizer(lora_request) - processors = get_bad_words_logits_processors( - bad_words=sampling_params.bad_words, tokenizer=tokenizer) - logits_processors.extend(processors) + def start_profile(self): + pass - if logits_processors: - if sampling_params.logits_processors is None: - sampling_params.logits_processors = logits_processors - else: - sampling_params.logits_processors.extend(logits_processors) + def stop_profile(self): + pass - return sampling_params + def get_tokenizer_group(self, group_type): + pass From db49d3b2d79c35512b701f3b9ad0dfc176a26e43 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Tue, 12 Nov 2024 03:22:37 +0000 Subject: [PATCH 09/33] Revert "cleanup llmengine" This reverts commit 338e11c2fbeef5c92d28792d98ead54f77bf4fce. 
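The step()/add_request() docstrings restored by this revert sketch the intended
driver loop for the engine. Expanded into runnable form it looks roughly like
the sketch below. This is illustrative only, not part of the diff: it assumes
the from_engine_args/add_request/step/has_unfinished_requests signatures shown
elsewhere in this series, and the model name is just a placeholder.

    from vllm import EngineArgs, LLMEngine, SamplingParams

    # Placeholder model; any model supported by the engine works the same way.
    engine = LLMEngine.from_engine_args(EngineArgs(model="facebook/opt-125m"))
    engine.add_request("0", "What is LLM?", SamplingParams(temperature=0.0))

    while engine.has_unfinished_requests():
        # Each step() advances one scheduling/execution iteration and returns
        # RequestOutputs for the requests updated in that iteration.
        for request_output in engine.step():
            if request_output.finished:
                print(request_output.outputs[0].text)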
--- vllm/engine/llm_engine.py | 2131 +++++++++++++++++++++++++++++++++++-- 1 file changed, 2027 insertions(+), 104 deletions(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 1c75b66086acf..69ed6e6bd59d2 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -1,70 +1,571 @@ -from typing import Dict, List, Mapping, Optional, Type, Union +import time +from collections import Counter as collectionsCounter +from collections import deque +from contextlib import contextmanager +from dataclasses import dataclass +from functools import partial +from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Deque, Dict, + Iterable, List, Mapping, NamedTuple, Optional) +from typing import Sequence as GenericSequence +from typing import Set, Type, Union, cast, overload -from vllm.config import VllmConfig +import torch +from typing_extensions import TypeVar + +import vllm.envs as envs +from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig, + ObservabilityConfig, ParallelConfig, SchedulerConfig, + VllmConfig) +from vllm.core.scheduler import (ScheduledSequenceGroup, Scheduler, + SchedulerOutputs) from vllm.engine.arg_utils import EngineArgs -from vllm.engine.metrics_types import StatLoggerBase -from vllm.envs import VLLM_ENABLE_V1_MULTIPROCESSING -from vllm.inputs import INPUT_REGISTRY, InputRegistry, PromptType +from vllm.engine.metrics_types import StatLoggerBase, Stats +from vllm.engine.output_processor.interfaces import ( + SequenceGroupOutputProcessor) +from vllm.engine.output_processor.stop_checker import StopChecker +from vllm.engine.output_processor.util import create_output_by_sequence_group +from vllm.entrypoints.openai.logits_processors import ( + get_logits_processors as get_openai_logits_processors) +from vllm.executor.executor_base import ExecutorBase +from vllm.executor.gpu_executor import GPUExecutor +from vllm.executor.ray_utils import initialize_ray_cluster +from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs, + PromptType) +from vllm.inputs.parse import is_encoder_decoder_inputs, is_token_prompt +from vllm.inputs.preprocess import InputPreprocessor from vllm.logger import init_logger +from vllm.logits_process import get_bad_words_logits_processors from vllm.lora.request import LoRARequest -from vllm.outputs import RequestOutput +from vllm.model_executor.guided_decoding import ( + get_local_guided_decoding_logits_processor) +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.outputs import (EmbeddingRequestOutput, RequestOutput, + RequestOutputFactory) from vllm.pooling_params import PoolingParams -from vllm.platforms import current_platform from vllm.prompt_adapter.request import PromptAdapterRequest -from vllm.sampling_params import SamplingParams -from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs -from vllm.usage.usage_lib import UsageContext -from vllm.v1.engine.core_client import EngineCoreClient -from vllm.v1.engine.detokenizer import Detokenizer -from vllm.v1.engine.processor import Processor -from vllm.v1.executor.gpu_executor import GPUExecutor -from vllm.v1.executor.tpu_executor import TPUExecutor +from vllm.sampling_params import RequestOutputKind, SamplingParams +from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest, + ParallelSampleSequenceGroup, Sequence, + SequenceGroup, SequenceGroupBase, + SequenceGroupMetadata, SequenceGroupOutput, + SequenceStatus) +from vllm.tracing import (SpanAttributes, SpanKind, 
extract_trace_context, + init_tracer) +from vllm.transformers_utils.config import try_get_generation_config +from vllm.transformers_utils.detokenizer import Detokenizer +from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.transformers_utils.tokenizer_group import ( + BaseTokenizerGroup, init_tokenizer_from_configs) +from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled, + usage_message) +from vllm.utils import Counter, Device, deprecate_kwargs, weak_bind +from vllm.version import __version__ as VLLM_VERSION logger = init_logger(__name__) +_LOCAL_LOGGING_INTERVAL_SEC = 5 + + +def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]: + config = try_get_generation_config( + model_config.model, + trust_remote_code=model_config.trust_remote_code, + revision=model_config.revision, + ) + + if config is None: + return {} + + return config.to_diff_dict() + + +_G = TypeVar("_G", bound=BaseTokenizerGroup, default=BaseTokenizerGroup) +_O = TypeVar("_O", RequestOutput, EmbeddingRequestOutput) + + +@dataclass +class SchedulerOutputState: + """Caches the scheduler outputs for a virtual engine. Used for Multi-Step""" + seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None + scheduler_outputs: Optional[SchedulerOutputs] = None + allow_async_output_proc: bool = False + last_output: Optional[SamplerOutput] = None + + +class OutputData(NamedTuple): + outputs: List[SamplerOutput] + seq_group_metadata_list: List[SequenceGroupMetadata] + scheduler_outputs: SchedulerOutputs + is_async: bool + is_last_step: bool + # Indicates if this output is from the first step of the + # multi-step. When multi-step is disabled, this is always + # set to True. + # is_first_step_output is invalid when `outputs` has + # outputs from multiple steps. + is_first_step_output: Optional[bool] + skip: List[int] + + +class SchedulerContext: + + def __init__(self, multi_step_stream_outputs: bool = False): + self.output_queue: Deque[OutputData] = deque() + self.request_outputs: List[Union[RequestOutput, + EmbeddingRequestOutput]] = [] + self.seq_group_metadata_list: Optional[ + List[SequenceGroupMetadata]] = None + self.scheduler_outputs: Optional[SchedulerOutputs] = None + + self.multi_step_stream_outputs: bool = multi_step_stream_outputs + + def append_output(self, outputs: List[SamplerOutput], + seq_group_metadata_list: List[SequenceGroupMetadata], + scheduler_outputs: SchedulerOutputs, is_async: bool, + is_last_step: bool, + is_first_step_output: Optional[bool]): + self.output_queue.append( + OutputData(outputs=outputs, + seq_group_metadata_list=seq_group_metadata_list, + scheduler_outputs=scheduler_outputs, + is_async=is_async, + is_last_step=is_last_step, + is_first_step_output=is_first_step_output, + skip=[])) class LLMEngine: - """Legacy LLMEngine for backwards compatibility.""" + """An LLM engine that receives requests and generates texts. + + This is the main class for the vLLM engine. It receives requests + from clients and generates texts from the LLM. It includes a tokenizer, a + language model (possibly distributed across multiple GPUs), and GPU memory + space allocated for intermediate states (aka KV cache). This class utilizes + iteration-level scheduling and efficient memory management to maximize the + serving throughput. + + The :class:`~vllm.LLM` class wraps this class for offline batched inference + and the :class:`AsyncLLMEngine` class wraps this class for online serving. + + The config arguments are derived from :class:`~vllm.EngineArgs`. 
(See + :ref:`engine_args`) + + Args: + model_config: The configuration related to the LLM model. + cache_config: The configuration related to the KV cache memory + management. + parallel_config: The configuration related to distributed execution. + scheduler_config: The configuration related to the request scheduler. + device_config: The configuration related to the device. + lora_config (Optional): The configuration related to serving multi-LoRA. + speculative_config (Optional): The configuration related to speculative + decoding. + executor_class: The model executor class for managing distributed + execution. + prompt_adapter_config (Optional): The configuration related to serving + prompt adapters. + log_stats: Whether to log statistics. + usage_context: Specified entry point, used for usage info collection. + """ + + DO_VALIDATE_OUTPUT: ClassVar[bool] = False + """A flag to toggle whether to validate the type of request output.""" + + @classmethod + @contextmanager + def enable_output_validation(cls): + cls.DO_VALIDATE_OUTPUT = True + + yield + + cls.DO_VALIDATE_OUTPUT = False + + @classmethod + def validate_output( + cls, + output: object, + output_type: Type[_O], + ) -> _O: + do_validate = cls.DO_VALIDATE_OUTPUT + + if ((TYPE_CHECKING or do_validate) + and not isinstance(output, output_type)): + raise TypeError(f"Expected output of type {output_type}, " + f"but found type {type(output)}") + + return cast(_O, output) + + @classmethod + def validate_outputs( + cls, + outputs: GenericSequence[object], + output_type: Type[_O], + ) -> List[_O]: + do_validate = cls.DO_VALIDATE_OUTPUT + + outputs_: List[_O] + if TYPE_CHECKING or do_validate: + outputs_ = [] + for output in outputs: + if not isinstance(output, output_type): + raise TypeError(f"Expected output of type {output_type}, " + f"but found type {type(output)}") + + outputs_.append(output) + else: + outputs_ = outputs + + return outputs_ + + tokenizer: Optional[BaseTokenizerGroup] def __init__( self, vllm_config: VllmConfig, - executor_class: Type[GPUExecutor], + executor_class: Type[ExecutorBase], log_stats: bool, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, stat_loggers: Optional[Dict[str, StatLoggerBase]] = None, input_registry: InputRegistry = INPUT_REGISTRY, use_cached_outputs: bool = False, - multiprocess_mode: bool = False, ) -> None: - # TODO: Can we avoid this? - self.model_config = vllm_config.model_config - - # Tokenizer (+ ensure liveness if running in another process). - self.tokenizer = init_tokenizer_from_configs( - model_config=vllm_config.model_config, - scheduler_config=vllm_config.scheduler_config, - parallel_config=vllm_config.parallel_config, - enable_lora=bool(vllm_config.lora_config)) - self.tokenizer.ping() - - # Processor (convert Inputs --> EngineCoreRequests) - self.processor = Processor(vllm_config.model_config, - vllm_config.lora_config, self.tokenizer, - input_registry) - - # Detokenizer (converts EngineCoreOutputs --> RequestOutput) - self.detokenizer = Detokenizer(vllm_config.model_config.tokenizer) - - # EngineCore (gets EngineCoreRequests and gives EngineCoreOutputs) - self.engine_core = EngineCoreClient.make_client( - vllm_config, - executor_class, - usage_context, - multiprocess_mode=multiprocess_mode, - asyncio_mode=False, + # TODO: remove the local variables and use self.* throughout the class. 
+ model_config = self.model_config = vllm_config.model_config + cache_config = self.cache_config = vllm_config.cache_config + lora_config = self.lora_config = vllm_config.lora_config + parallel_config = self.parallel_config = vllm_config.parallel_config + scheduler_config = self.scheduler_config = vllm_config.scheduler_config + device_config = self.device_config = vllm_config.device_config + speculative_config = self.speculative_config = vllm_config.speculative_config # noqa + load_config = self.load_config = vllm_config.load_config + decoding_config = self.decoding_config = vllm_config.decoding_config or DecodingConfig( # noqa ) + prompt_adapter_config = self.prompt_adapter_config = vllm_config.prompt_adapter_config # noqa + observability_config = self.observability_config = vllm_config.observability_config or ObservabilityConfig( # noqa + ) + + logger.info( + "Initializing an LLM engine (v%s) with config: " + "model=%r, speculative_config=%r, tokenizer=%r, " + "skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, " + "override_neuron_config=%s, tokenizer_revision=%s, " + "trust_remote_code=%s, dtype=%s, max_seq_len=%d, " + "download_dir=%r, load_format=%s, tensor_parallel_size=%d, " + "pipeline_parallel_size=%d, " + "disable_custom_all_reduce=%s, quantization=%s, " + "enforce_eager=%s, kv_cache_dtype=%s, " + "quantization_param_path=%s, device_config=%s, " + "decoding_config=%r, observability_config=%r, " + "seed=%d, served_model_name=%s, " + "num_scheduler_steps=%d, chunked_prefill_enabled=%s " + "multi_step_stream_outputs=%s, enable_prefix_caching=%s, " + "use_async_output_proc=%s, use_cached_outputs=%s, " + "chat_template_text_format=%s, mm_processor_kwargs=%s, " + "pooler_config=%r)", + VLLM_VERSION, + model_config.model, + speculative_config, + model_config.tokenizer, + model_config.skip_tokenizer_init, + model_config.tokenizer_mode, + model_config.revision, + model_config.override_neuron_config, + model_config.tokenizer_revision, + model_config.trust_remote_code, + model_config.dtype, + model_config.max_model_len, + load_config.download_dir, + load_config.load_format, + parallel_config.tensor_parallel_size, + parallel_config.pipeline_parallel_size, + parallel_config.disable_custom_all_reduce, + model_config.quantization, + model_config.enforce_eager, + cache_config.cache_dtype, + model_config.quantization_param_path, + device_config.device, + decoding_config, + observability_config, + model_config.seed, + model_config.served_model_name, + scheduler_config.num_scheduler_steps, + scheduler_config.chunked_prefill_enabled, + scheduler_config.multi_step_stream_outputs, + cache_config.enable_prefix_caching, + model_config.use_async_output_proc, + use_cached_outputs, + model_config.chat_template_text_format, + model_config.mm_processor_kwargs, + model_config.pooler_config, + ) + # TODO(woosuk): Print more configs in debug mode. 
+ self.model_config = model_config + self.cache_config = cache_config + self.lora_config = lora_config + self.parallel_config = parallel_config + self.scheduler_config = scheduler_config + self.device_config = device_config + self.speculative_config = speculative_config + self.load_config = load_config + self.decoding_config = decoding_config or DecodingConfig() + self.prompt_adapter_config = prompt_adapter_config + self.observability_config = observability_config or ObservabilityConfig( + ) + self.log_stats = log_stats + self.use_cached_outputs = use_cached_outputs + + if not self.model_config.skip_tokenizer_init: + self.tokenizer = self._init_tokenizer() + self.detokenizer = Detokenizer(self.tokenizer) + tokenizer_group = self.get_tokenizer_group() + else: + self.tokenizer = None + self.detokenizer = None + tokenizer_group = None + + # Ensure that the function doesn't contain a reference to self, + # to avoid engine GC issues + def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer: + assert tokenizer_group, ("tokenizer_group cannot be None, " + "make sure skip_tokenizer_init is False") + return tokenizer_group.get_lora_tokenizer(sequence.lora_request) + + self.seq_counter = Counter() + self.generation_config_fields = _load_generation_config_dict( + model_config) + + self.input_preprocessor = InputPreprocessor(model_config, + self.tokenizer) + + self.input_registry = input_registry + self.input_processor = input_registry.create_input_processor( + model_config) + + self.model_executor = executor_class(vllm_config=vllm_config, ) + + if self.model_config.task != "embedding": + self._initialize_kv_caches() + + # If usage stat is enabled, collect relevant info. + if is_usage_stats_enabled(): + from vllm.model_executor.model_loader import ( + get_architecture_class_name) + usage_message.report_usage( + get_architecture_class_name(model_config), + usage_context, + extra_kvs={ + # Common configuration + "dtype": + str(model_config.dtype), + "tensor_parallel_size": + parallel_config.tensor_parallel_size, + "block_size": + cache_config.block_size, + "gpu_memory_utilization": + cache_config.gpu_memory_utilization, + + # Quantization + "quantization": + model_config.quantization, + "kv_cache_dtype": + str(cache_config.cache_dtype), + + # Feature flags + "enable_lora": + bool(lora_config), + "enable_prompt_adapter": + bool(prompt_adapter_config), + "enable_prefix_caching": + cache_config.enable_prefix_caching, + "enforce_eager": + model_config.enforce_eager, + "disable_custom_all_reduce": + parallel_config.disable_custom_all_reduce, + }) + + if self.tokenizer: + # Ping the tokenizer to ensure liveness if it runs in a + # different process. + self.tokenizer.ping() + + self.cached_scheduler_outputs = [ + SchedulerOutputState() + for _ in range(self.parallel_config.pipeline_parallel_size) + ] + + self.scheduler_contexts = [ + SchedulerContext(multi_step_stream_outputs=self.scheduler_config. + multi_step_stream_outputs) + for _ in range(self.parallel_config.pipeline_parallel_size) + ] + + if model_config.use_async_output_proc: + process_model_outputs = weak_bind(self._process_model_outputs) + + self.async_callbacks = [ + partial(process_model_outputs, + ctx=self.scheduler_contexts[v_id]) + for v_id in range(self.parallel_config.pipeline_parallel_size) + ] + else: + self.async_callbacks = [] + + # Currently used by AsyncLLMEngine to ensure quick append + # of request outputs to asyncio queues + self.process_request_outputs_callback: Optional[Callable] = None + + # Create the scheduler. 
+ # NOTE: the cache_config here have been updated with the numbers of + # GPU and CPU blocks, which are profiled in the distributed executor. + self.scheduler = [ + Scheduler( + scheduler_config, cache_config, lora_config, + parallel_config.pipeline_parallel_size, + self.async_callbacks[v_id] + if model_config.use_async_output_proc else None) + for v_id in range(parallel_config.pipeline_parallel_size) + ] + + # Metric Logging. + if self.log_stats: + if stat_loggers is not None: + self.stat_loggers = stat_loggers + else: + # Lazy import for prometheus multiprocessing. + # We need to set PROMETHEUS_MULTIPROC_DIR environment variable + # before prometheus_client is imported. + # See https://prometheus.github.io/client_python/multiprocess/ + from vllm.engine.metrics import (LoggingStatLogger, + PrometheusStatLogger) + + self.stat_loggers = { + "logging": + LoggingStatLogger( + local_interval=_LOCAL_LOGGING_INTERVAL_SEC), + "prometheus": + PrometheusStatLogger( + local_interval=_LOCAL_LOGGING_INTERVAL_SEC, + labels=dict(model_name=model_config.served_model_name), + max_model_len=self.model_config.max_model_len), + } + self.stat_loggers["prometheus"].info("cache_config", + self.cache_config) + + self.tracer = None + if self.observability_config.otlp_traces_endpoint: + self.tracer = init_tracer( + "vllm.llm_engine", + self.observability_config.otlp_traces_endpoint) + + # Create sequence output processor, e.g. for beam search or + # speculative decoding. + self.output_processor = ( + SequenceGroupOutputProcessor.create_output_processor( + self.scheduler_config, + self.detokenizer, + self.scheduler, + self.seq_counter, + get_tokenizer_for_seq, + stop_checker=StopChecker( + self.scheduler_config.max_model_len, + get_tokenizer_for_seq, + ), + )) + + self.seq_id_to_seq_group: Dict[str, SequenceGroupBase] = {} + + def _initialize_kv_caches(self) -> None: + """Initialize the KV cache in the worker(s). + + The workers will determine the number of blocks in both the GPU cache + and the swap CPU cache. + """ + num_gpu_blocks, num_cpu_blocks = ( + self.model_executor.determine_num_available_blocks()) + + if self.cache_config.num_gpu_blocks_override is not None: + num_gpu_blocks_override = self.cache_config.num_gpu_blocks_override + logger.info( + "Overriding num_gpu_blocks=%d with " + "num_gpu_blocks_override=%d", num_gpu_blocks, + num_gpu_blocks_override) + num_gpu_blocks = num_gpu_blocks_override + + self.cache_config.num_gpu_blocks = num_gpu_blocks + self.cache_config.num_cpu_blocks = num_cpu_blocks + + self.model_executor.initialize_cache(num_gpu_blocks, num_cpu_blocks) + + @classmethod + def _get_executor_cls(cls, + engine_config: VllmConfig) -> Type[ExecutorBase]: + distributed_executor_backend = ( + engine_config.parallel_config.distributed_executor_backend) + # Initialize the cluster and specify the executor class. + if isinstance(distributed_executor_backend, type): + if not issubclass(distributed_executor_backend, ExecutorBase): + raise TypeError( + "distributed_executor_backend must be a subclass of " + f"ExecutorBase. 
Got {distributed_executor_backend}.") + if distributed_executor_backend.uses_ray: # type: ignore + initialize_ray_cluster(engine_config.parallel_config) + executor_class = distributed_executor_backend + elif engine_config.device_config.device_type == "neuron": + from vllm.executor.neuron_executor import NeuronExecutor + executor_class = NeuronExecutor + elif engine_config.device_config.device_type == "tpu": + if distributed_executor_backend == "ray": + initialize_ray_cluster(engine_config.parallel_config) + from vllm.executor.ray_tpu_executor import RayTPUExecutor + executor_class = RayTPUExecutor + else: + assert distributed_executor_backend is None + from vllm.executor.tpu_executor import TPUExecutor + executor_class = TPUExecutor + elif engine_config.device_config.device_type == "cpu": + from vllm.executor.cpu_executor import CPUExecutor + executor_class = CPUExecutor + elif engine_config.device_config.device_type == "hpu": + if distributed_executor_backend == "ray": + initialize_ray_cluster(engine_config.parallel_config) + from vllm.executor.ray_hpu_executor import RayHPUExecutor + executor_class = RayHPUExecutor + else: + from vllm.executor.hpu_executor import HPUExecutor + executor_class = HPUExecutor + elif engine_config.device_config.device_type == "openvino": + from vllm.executor.openvino_executor import OpenVINOExecutor + executor_class = OpenVINOExecutor + elif engine_config.device_config.device_type == "xpu": + if distributed_executor_backend == "ray": + initialize_ray_cluster(engine_config.parallel_config) + from vllm.executor.ray_xpu_executor import RayXPUExecutor + executor_class = RayXPUExecutor + elif distributed_executor_backend == "mp": + # FIXME(kunshang): + # spawn needs calling `if __name__ == '__main__':`` + # fork is not supported for xpu start new process. + logger.error( + "Both start methods (spawn and fork) have issue " + "on XPU if you use mp backend, Please try ray instead.") + else: + from vllm.executor.xpu_executor import XPUExecutor + executor_class = XPUExecutor + elif distributed_executor_backend == "ray": + initialize_ray_cluster(engine_config.parallel_config) + from vllm.executor.ray_gpu_executor import RayGPUExecutor + executor_class = RayGPUExecutor + elif distributed_executor_backend == "mp": + from vllm.executor.multiproc_gpu_executor import ( + MultiprocessingGPUExecutor) + assert not envs.VLLM_USE_RAY_SPMD_WORKER, ( + "multiprocessing distributed executor backend does not " + "support VLLM_USE_RAY_SPMD_WORKER=1") + executor_class = MultiprocessingGPUExecutor + else: + from vllm.executor.gpu_executor import GPUExecutor + executor_class = GPUExecutor + return executor_class @classmethod def from_engine_args( @@ -72,51 +573,176 @@ def from_engine_args( engine_args: EngineArgs, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, stat_loggers: Optional[Dict[str, StatLoggerBase]] = None, - enable_multiprocessing: bool = False, ) -> "LLMEngine": """Creates an LLM engine from the engine arguments.""" - # Create the engine configs. - vllm_config = engine_args.create_engine_config() - executor_class = cls._get_executor_cls(vllm_config) + engine_config = engine_args.create_engine_config() + executor_class = cls._get_executor_cls(engine_config) + # Create the LLM engine. 
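# Illustrative sketch of how the executor resolution above is exercised: on a
# TPU host with no distributed_executor_backend set, the chain lands on
# TPUExecutor, while "ray" selects RayTPUExecutor. The doctest below is an
# assumption-based example, not taken from this patch; it presumes the
# top-level EngineArgs/LLMEngine exports and the `device="tpu"` engine
# argument (the class being edited here lives under vllm.v1.engine.llm_engine).
# >>> from vllm import EngineArgs, LLMEngine
# >>> engine_args = EngineArgs(model="facebook/opt-125m", device="tpu")
# >>> engine = LLMEngine.from_engine_args(engine_args)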
+ engine = cls( + vllm_config=engine_config, + executor_class=executor_class, + log_stats=not engine_args.disable_log_stats, + usage_context=usage_context, + stat_loggers=stat_loggers, + ) - if VLLM_ENABLE_V1_MULTIPROCESSING: - logger.debug("Enabling multiprocessing for LLMEngine.") - enable_multiprocessing = True + return engine - # Create the LLMEngine. - return cls(vllm_config=vllm_config, - executor_class=executor_class, - log_stats=not engine_args.disable_log_stats, - usage_context=usage_context, - stat_loggers=stat_loggers, - multiprocess_mode=enable_multiprocessing) + def __reduce__(self): + # This is to ensure that the LLMEngine is not referenced in + # the closure used to initialize Ray worker actors + raise RuntimeError("LLMEngine should not be pickled!") - @classmethod - def _get_executor_cls(cls, vllm_config: VllmConfig): - if current_platform.is_tpu(): - return TPUExecutor - return GPUExecutor + def __del__(self): + # Shutdown model executor when engine is garbage collected + # Use getattr since __init__ can fail before the field is set + if model_executor := getattr(self, "model_executor", None): + model_executor.shutdown() - def stop_remote_worker_execution_loop(self) -> None: - raise NotImplementedError("TP not implemented yet.") + def get_tokenizer_group( + self, + group_type: Type[_G] = BaseTokenizerGroup, + ) -> _G: + tokenizer_group = self.tokenizer - def get_num_unfinished_requests(self) -> int: - return self.detokenizer.get_num_unfinished_requests() + if tokenizer_group is None: + raise ValueError("Unable to get tokenizer because " + "skip_tokenizer_init is True") + if not isinstance(tokenizer_group, group_type): + raise TypeError("Invalid type of tokenizer group. " + f"Expected type: {group_type}, but " + f"found type: {type(tokenizer_group)}") - def has_unfinished_requests(self) -> bool: - return self.detokenizer.has_unfinished_requests() + return tokenizer_group - @classmethod - def validate_outputs(cls, outputs, output_type): - return outputs + def get_tokenizer( + self, + lora_request: Optional[LoRARequest] = None, + ) -> AnyTokenizer: + return self.get_tokenizer_group().get_lora_tokenizer(lora_request) + + def _init_tokenizer(self) -> BaseTokenizerGroup: + return init_tokenizer_from_configs( + model_config=self.model_config, + scheduler_config=self.scheduler_config, + parallel_config=self.parallel_config, + enable_lora=bool(self.lora_config)) + + def _verify_args(self) -> None: + self.model_config.verify_with_parallel_config(self.parallel_config) + self.cache_config.verify_with_parallel_config(self.parallel_config) + if self.lora_config: + self.lora_config.verify_with_model_config(self.model_config) + self.lora_config.verify_with_scheduler_config( + self.scheduler_config) + if self.prompt_adapter_config: + self.prompt_adapter_config.verify_with_model_config( + self.model_config) + + def _add_processed_request( + self, + request_id: str, + processed_inputs: ProcessorInputs, + params: Union[SamplingParams, PoolingParams], + arrival_time: float, + lora_request: Optional[LoRARequest], + prompt_adapter_request: Optional[PromptAdapterRequest], + trace_headers: Optional[Mapping[str, str]] = None, + priority: int = 0, + ) -> Optional[SequenceGroup]: + """Add a processed request to the engine's request pool. + return the created sequence group. 
+ """ + if isinstance(params, SamplingParams) and params.n > 1: + ParallelSampleSequenceGroup.add_request( + request_id, + self, + params, + processed_inputs=processed_inputs, + arrival_time=arrival_time, + lora_request=lora_request, + trace_headers=trace_headers, + prompt_adapter_request=prompt_adapter_request, + priority=priority, + ) + return None + + self._validate_model_inputs(processed_inputs, lora_request) + # Create the sequences. + block_size = self.cache_config.block_size + seq_id = next(self.seq_counter) + eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request) + + if is_encoder_decoder_inputs(processed_inputs): + decoder_inputs = processed_inputs["decoder"] + encoder_inputs = processed_inputs["encoder"] + else: + decoder_inputs = processed_inputs + encoder_inputs = None + + seq = Sequence(seq_id, decoder_inputs, block_size, eos_token_id, + lora_request, prompt_adapter_request) + + encoder_seq = (None if encoder_inputs is None else Sequence( + seq_id, encoder_inputs, block_size, eos_token_id, lora_request, + prompt_adapter_request)) + + # Create a SequenceGroup based on SamplingParams or PoolingParams + if isinstance(params, SamplingParams): + seq_group = self._create_sequence_group_with_sampling( + request_id, + seq, + params, + arrival_time=arrival_time, + lora_request=lora_request, + trace_headers=trace_headers, + prompt_adapter_request=prompt_adapter_request, + encoder_seq=encoder_seq, + priority=priority) + elif isinstance(params, PoolingParams): + seq_group = self._create_sequence_group_with_pooling( + request_id, + seq, + params, + arrival_time=arrival_time, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request, + encoder_seq=encoder_seq, + priority=priority) + else: + raise ValueError( + "Either SamplingParams or PoolingParams must be provided.") - def abort_request(self, request_ids: List[str]) -> None: - """Remove request_ids from EngineCore and Detokenizer.""" + # Add the sequence group to the scheduler with least unfinished seqs. + costs = [ + scheduler.get_num_unfinished_seq_groups() + for scheduler in self.scheduler + ] + min_cost_scheduler = self.scheduler[costs.index(min(costs))] + min_cost_scheduler.add_seq_group(seq_group) - self.engine_core.abort_requests(request_ids) - self.detokenizer.abort_requests(request_ids) + return seq_group + def stop_remote_worker_execution_loop(self) -> None: + self.model_executor.stop_remote_worker_execution_loop() + + @overload # DEPRECATED + def add_request( + self, + request_id: str, + *, + inputs: PromptType, + params: Union[SamplingParams, PoolingParams], + arrival_time: Optional[float] = None, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + priority: int = 0, + ) -> None: + ... + + @overload def add_request( self, request_id: str, @@ -128,46 +754,1343 @@ def add_request( prompt_adapter_request: Optional[PromptAdapterRequest] = None, priority: int = 0, ) -> None: + ... 
+ + @deprecate_kwargs( + "inputs", + additional_message="Please use the 'prompt' parameter instead.", + ) + def add_request( + self, + request_id: str, + prompt: Optional[PromptType] = None, + params: Optional[Union[SamplingParams, PoolingParams]] = None, + arrival_time: Optional[float] = None, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + priority: int = 0, + *, + inputs: Optional[PromptType] = None, # DEPRECATED + ) -> None: + """Add a request to the engine's request pool. + + The request is added to the request pool and will be processed by the + scheduler as `engine.step()` is called. The exact scheduling policy is + determined by the scheduler. + + Args: + request_id: The unique ID of the request. + prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType` + for more details about the format of each input. + params: Parameters for sampling or pooling. + :class:`~vllm.SamplingParams` for text generation. + :class:`~vllm.PoolingParams` for pooling. + arrival_time: The arrival time of the request. If None, we use + the current monotonic time. + trace_headers: OpenTelemetry trace headers. + priority: The priority of the request. + Only applicable with priority scheduling. + + Details: + - Set arrival_time to the current time if it is None. + - Set prompt_token_ids to the encoded prompt if it is None. + - Create `n` number of :class:`~vllm.Sequence` objects. + - Create a :class:`~vllm.SequenceGroup` object + from the list of :class:`~vllm.Sequence`. + - Add the :class:`~vllm.SequenceGroup` object to the scheduler. + + Example: + >>> # initialize engine + >>> engine = LLMEngine.from_engine_args(engine_args) + >>> # set request arguments + >>> example_prompt = "Who is the president of the United States?" + >>> sampling_params = SamplingParams(temperature=0.0) + >>> request_id = 0 + >>> + >>> # add the request to the engine + >>> engine.add_request( + >>> str(request_id), + >>> example_prompt, + >>> SamplingParams(temperature=0.0)) + >>> # continue the request processing + >>> ... + """ + if inputs is not None: + prompt = inputs + assert prompt is not None and params is not None + + if lora_request is not None and not self.lora_config: + raise ValueError(f"Got lora_request {lora_request} but LoRA is " + "not enabled!") + + if priority != 0 and not self.scheduler_config.policy == "priority": + raise ValueError(f"Got priority {priority} but " + "Priority scheduling is not enabled.") + + if isinstance(params, SamplingParams) \ + and (params.guided_decoding or params.logits_processors) \ + and self.scheduler_config.num_scheduler_steps > 1: + raise ValueError( + "Guided decoding and logits processors are not supported " + "in multi-step decoding") + + if arrival_time is None: + arrival_time = time.time() + + if self.tokenizer is not None: + self._validate_token_prompt( + prompt, + tokenizer=self.get_tokenizer(lora_request=lora_request)) + + preprocessed_inputs = self.input_preprocessor.preprocess( + prompt, + request_id=request_id, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request, + ) + processed_inputs = self.input_processor(preprocessed_inputs) + + # This is a bit of a hack - copy the mm_processor_kwargs that were + # used in the input processor to the processed output, since these + # kwargs are presumed to be immutable and the values should be aligned + # between the input processor (here) and the input mapper. 
+ processed_inputs["mm_processor_kwargs"] = preprocessed_inputs.get( + "mm_processor_kwargs") + + self._add_processed_request( + request_id=request_id, + processed_inputs=processed_inputs, + params=params, + arrival_time=arrival_time, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request, + trace_headers=trace_headers, + priority=priority, + ) + + def _validate_token_prompt(self, prompt: PromptType, + tokenizer: AnyTokenizer): + # Guard against out-of-vocab tokens. + # For some tokenizers, tokenizer.decode will happily return empty text + # for token ids that are out of vocab, and we don't detect token ids + # that are greater than the max token id before running the model. + # However, these token ids will later crash a cuda kernel at runtime + # with an index out of bounds error. This will crash the entire engine. + # This needs to happen before multimodal input pre-processing, which + # may add dummy tokens that aren't part of the tokenizer's + # vocabulary. + if is_token_prompt(prompt): + prompt_ids = prompt["prompt_token_ids"] + if len(prompt_ids) == 0: + # Empty prompt check is handled later + return + max_input_id = max(prompt_ids) + if max_input_id > tokenizer.max_token_id: + raise ValueError( + "Token id {} is out of vocabulary".format(max_input_id)) + + def _create_sequence_group_with_sampling( + self, + request_id: str, + seq: Sequence, + sampling_params: SamplingParams, + arrival_time: float, + lora_request: Optional[LoRARequest], + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + encoder_seq: Optional[Sequence] = None, + priority: int = 0, + ) -> SequenceGroup: + """Creates a SequenceGroup with SamplingParams.""" + max_logprobs = self.get_model_config().max_logprobs + if (sampling_params.logprobs + and sampling_params.logprobs > max_logprobs) or ( + sampling_params.prompt_logprobs + and sampling_params.prompt_logprobs > max_logprobs): + raise ValueError(f"Cannot request more than " + f"{max_logprobs} logprobs.") + + sampling_params = self._build_logits_processors( + sampling_params, lora_request) + + # Defensive copy of SamplingParams, which are used by the sampler, + # this doesn't deep-copy LogitsProcessor objects + sampling_params = sampling_params.clone() + + sampling_params.update_from_generation_config( + self.generation_config_fields, seq.eos_token_id) + + # Create the sequence group. + seq_group = SequenceGroup( + request_id=request_id, + seqs=[seq], + arrival_time=arrival_time, + sampling_params=sampling_params, + lora_request=lora_request, + trace_headers=trace_headers, + prompt_adapter_request=prompt_adapter_request, + encoder_seq=encoder_seq, + priority=priority) + + return seq_group + + def _create_sequence_group_with_pooling( + self, + request_id: str, + seq: Sequence, + pooling_params: PoolingParams, + arrival_time: float, + lora_request: Optional[LoRARequest], + prompt_adapter_request: Optional[PromptAdapterRequest], + encoder_seq: Optional[Sequence] = None, + priority: int = 0, + ) -> SequenceGroup: + """Creates a SequenceGroup with PoolingParams.""" + # Defensive copy of PoolingParams, which are used by the pooler + pooling_params = pooling_params.clone() + # Create the sequence group. 
+ seq_group = SequenceGroup( + request_id=request_id, + seqs=[seq], + arrival_time=arrival_time, + lora_request=lora_request, + pooling_params=pooling_params, + prompt_adapter_request=prompt_adapter_request, + encoder_seq=encoder_seq, + priority=priority) + return seq_group + + def abort_request(self, request_id: Union[str, Iterable[str]]) -> None: + """Aborts a request(s) with the given ID. + + Args: + request_id: The ID(s) of the request to abort. + + Details: + - Refer to the + :meth:`~vllm.core.scheduler.Scheduler.abort_seq_group` + from class :class:`~vllm.core.scheduler.Scheduler`. + + Example: + >>> # initialize engine and add a request with request_id + >>> request_id = str(0) + >>> # abort the request + >>> engine.abort_request(request_id) + """ + for scheduler in self.scheduler: + scheduler.abort_seq_group(request_id) + + def get_model_config(self) -> ModelConfig: + """Gets the model configuration.""" + return self.model_config + + def get_parallel_config(self) -> ParallelConfig: + """Gets the parallel configuration.""" + return self.parallel_config + + def get_decoding_config(self) -> DecodingConfig: + """Gets the decoding configuration.""" + return self.decoding_config + + def get_scheduler_config(self) -> SchedulerConfig: + """Gets the scheduler configuration.""" + return self.scheduler_config + + def get_lora_config(self) -> LoRAConfig: + """Gets the LoRA configuration.""" + return self.lora_config + + def get_num_unfinished_requests(self) -> int: + """Gets the number of unfinished requests.""" + return sum(scheduler.get_num_unfinished_seq_groups() + for scheduler in self.scheduler) + + def has_unfinished_requests(self) -> bool: + """Returns True if there are unfinished requests.""" + return any(scheduler.has_unfinished_seqs() + for scheduler in self.scheduler) + + def has_unfinished_requests_for_virtual_engine( + self, virtual_engine: int) -> bool: + """ + Returns True if there are unfinished requests for the virtual engine. + """ + return self.scheduler[virtual_engine].has_unfinished_seqs() + + @staticmethod + def _process_sequence_group_outputs( + seq_group: SequenceGroup, + outputs: List[EmbeddingSequenceGroupOutput], + ) -> None: + seq_group.embeddings = outputs[0].embeddings + + for seq in seq_group.get_seqs(): + seq.status = SequenceStatus.FINISHED_STOPPED + + return + + def _update_num_computed_tokens_for_multi_step_prefill( + self, seq_group: SequenceGroup, + seq_group_meta: SequenceGroupMetadata, + is_first_step_output: Optional[bool]): + """ + This function updates num_computed_tokens for prompt sequences + when Multi-Step is enabled. + + seq_group: SequenceGroup to update the num_computed_tokens for. + seq_group_meta: Metadata of the given SequenceGroup. + is_first_step_output: Optional[bool] - + When available, is_first_step_output indicates if the appended + output token is the output of the first-step in multi-step. + A value of None indicates that outputs from all steps in + in multi-step are submitted in a single burst. + """ + + assert self.scheduler_config.is_multi_step + + if not seq_group_meta.is_prompt: + # num_computed_token updates for multi-step decodes happen after + # the tokens are appended to the sequence. + return + + do_update: bool = False + if self.scheduler_config.chunked_prefill_enabled: + # In multi-step + chunked-prefill case, the prompt sequences + # that are scheduled are fully processed in the first step. + do_update = is_first_step_output is None or is_first_step_output + else: + # Normal multi-step decoding case. 
In this case prompt-sequences + # are actually single-stepped. Always update in this case. + assert seq_group.state.num_steps == 1 + do_update = True + + if do_update: + seq_group.update_num_computed_tokens( + seq_group_meta.token_chunk_size) + + def _process_model_outputs(self, + ctx: SchedulerContext, + request_id: Optional[str] = None) -> None: + """Apply the model output to the sequences in the scheduled seq groups + and return responses. + + ctx: The virtual engine context to work on + request_id: If provided, then only this request is going to be processed + """ + + now = time.time() + + if len(ctx.output_queue) == 0: + return None + + # Get pending async postprocessor + if request_id: + # When we process only one request, no pop is required + # (since later we will process all of the rest) + (outputs, seq_group_metadata_list, scheduler_outputs, is_async, + is_last_step, is_first_step_output, skip) = ctx.output_queue[0] + else: + (outputs, seq_group_metadata_list, scheduler_outputs, is_async, + is_last_step, is_first_step_output, + skip) = ctx.output_queue.popleft() + + # Sanity check + assert len(seq_group_metadata_list) == len( + scheduler_outputs.scheduled_seq_groups) + + has_multiple_outputs: bool = len(outputs) > 1 + outputs_by_sequence_group: List[List[SequenceGroupOutput]] + if has_multiple_outputs: + assert self.scheduler_config.is_multi_step or \ + self.speculative_config + # Organize outputs by [step][sequence group] instead of + # [sequence group][step]. + outputs_by_sequence_group = create_output_by_sequence_group( + outputs, num_seq_groups=len(seq_group_metadata_list)) + # We have outputs for multiple steps submitted in a single burst, + # so invalidate is_first_step_output. + is_first_step_output = None + else: + outputs_by_sequence_group = outputs + + # Determine the requests we need to operate on + if request_id: + indices = [] + for i, seq_group_meta in enumerate(seq_group_metadata_list): + if seq_group_meta.request_id == request_id: + assert i not in skip # Cannot be called twice + indices.append(i) + break + + # If the request_id was not found, then it means that + # this is a new request that has no pending async + # postprocessor + if not indices: + return + else: + indices = range(len(seq_group_metadata_list)) # type: ignore + + finished_before: List[int] = [] + finished_now: List[int] = [] + for i in indices: + if i in skip: + continue + + seq_group_meta = seq_group_metadata_list[i] + scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i] + + seq_group: SequenceGroup = scheduled_seq_group.seq_group + + if seq_group.is_finished(): + finished_before.append(i) + continue + + output: List[SequenceGroupOutput] + if has_multiple_outputs: + output = outputs_by_sequence_group[i] + else: + output = [outputs_by_sequence_group[0][i]] + + if not is_async: + if self.scheduler_config.is_multi_step: + # Updates happen only if the sequence is prefill + self._update_num_computed_tokens_for_multi_step_prefill( + seq_group, seq_group_meta, is_first_step_output) + else: + seq_group.update_num_computed_tokens( + seq_group_meta.token_chunk_size or 0) + + if outputs: + for o in outputs: + if (isinstance(o, SamplerOutput) + and seq_group.metrics is not None): + if seq_group.metrics.model_forward_time is not None: + seq_group.metrics.model_forward_time += ( + o.model_forward_time or 0) + else: + seq_group.metrics.model_forward_time = ( + o.model_forward_time) + if seq_group.metrics.model_execute_time is not None: + seq_group.metrics.model_execute_time += ( + 
o.model_execute_time or 0) + else: + seq_group.metrics.model_execute_time = ( + o.model_execute_time) + + if self.model_config.task == "embedding": + self._process_sequence_group_outputs(seq_group, output) + else: + self.output_processor.process_prompt_logprob(seq_group, output) + if seq_group_meta.do_sample: + self.output_processor.process_outputs( + seq_group, output, is_async) + + if seq_group.is_finished(): + finished_now.append(i) + + # Generate outputs for the requests that finished this iteration + for i in finished_now: + scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i] + + seq_group = scheduled_seq_group.seq_group + seq_group.maybe_set_first_token_time(now) + request_output = RequestOutputFactory.create( + seq_group, + self.seq_id_to_seq_group, + use_cache=self.use_cached_outputs) + if request_output: + ctx.request_outputs.append(request_output) + + # When we process a single request, we skip it for the next time, + # and invoke the request output callback (if there was final output) + if request_id: + assert len(indices) == 1 + skip.append(indices[0]) + + if (finished_now + and self.process_request_outputs_callback is not None): + self.process_request_outputs_callback(ctx.request_outputs) + ctx.request_outputs.clear() + return + + # Free currently finished requests + if finished_now: + for scheduler in self.scheduler: + scheduler.free_finished_seq_groups() + + # For multi-step without streaming, don't create outputs each iteration + if not is_last_step and not ctx.multi_step_stream_outputs: + # Immediately process request outputs here (if callback is given) + if (finished_now + and self.process_request_outputs_callback is not None): + self.process_request_outputs_callback(ctx.request_outputs) + ctx.request_outputs.clear() + return + + # Create the outputs + for i in indices: + if i in skip or i in finished_before or i in finished_now: + continue # Avoids double processing + + scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i] + + seq_group = scheduled_seq_group.seq_group + seq_group.maybe_set_first_token_time(now) + request_output = RequestOutputFactory.create( + seq_group, + self.seq_id_to_seq_group, + use_cache=self.use_cached_outputs) + if request_output: + ctx.request_outputs.append(request_output) + + # For multi-step with streaming, create outputs each iteration + if not is_last_step and ctx.multi_step_stream_outputs: + # Immediately process request outputs here (if callback is given) + if self.process_request_outputs_callback is not None: + self.process_request_outputs_callback(ctx.request_outputs) + ctx.request_outputs.clear() + return + + for seq_group in scheduler_outputs.ignored_seq_groups: + params = seq_group.sampling_params + if params is not None and params.output_kind == ( + RequestOutputKind.DELTA) and not seq_group.is_finished(): + continue + + request_output = RequestOutputFactory.create( + seq_group, + self.seq_id_to_seq_group, + use_cache=self.use_cached_outputs, + ) + if request_output: + ctx.request_outputs.append(request_output) + + # Immediately process request outputs here (if callback is given) + if (ctx.request_outputs + and self.process_request_outputs_callback is not None): + self.process_request_outputs_callback(ctx.request_outputs) + ctx.request_outputs.clear() + + # For async case, we need to record the stats here. + # For non-async case, the stats are done in the + # LLMEngine/AsyncLLMEngine directly + if is_async: + # Log stats. 
+ self.do_log_stats(scheduler_outputs, outputs, finished_before, + skip) + + # Tracing + self.do_tracing(scheduler_outputs, finished_before) + + return None + + def _advance_to_next_step( + self, output: List[SamplerOutput], + seq_group_metadata_list: List[SequenceGroupMetadata], + scheduled_seq_groups: List[ScheduledSequenceGroup]) -> None: + """Given model output from a single run, append the tokens to the + sequences. This is normally done inside output processor, but it is + required if the worker is to perform async forward pass to next step. + """ + for seq_group_metadata, sequence_group_outputs, scheduled_seq_group in \ + zip(seq_group_metadata_list, output, scheduled_seq_groups): + seq_group = scheduled_seq_group.seq_group + + if seq_group.is_finished(): + continue + + if self.scheduler_config.is_multi_step: + # Updates happen only if the sequence is prefill + self._update_num_computed_tokens_for_multi_step_prefill( + seq_group, seq_group_metadata, + seq_group.state.num_steps == 1) + else: + token_chunk_size = (seq_group_metadata.token_chunk_size + if seq_group_metadata.token_chunk_size + is not None else 0) + seq_group.update_num_computed_tokens(token_chunk_size) + + if seq_group_metadata.do_sample: + assert len(sequence_group_outputs.samples) == 1, ( + "Async output processor expects a single sample" + " (i.e sampling_params.n == 1)") + sample = sequence_group_outputs.samples[0] + + assert len(seq_group.seqs) == 1 + seq = seq_group.seqs[0] + + if self.scheduler_config.is_multi_step: + is_prefill_append = seq.data.get_num_uncomputed_tokens( + ) == 0 + seq.append_token_id(sample.output_token, sample.logprobs) + if not is_prefill_append: + seq_group.update_num_computed_tokens(1) + else: + seq.append_token_id(sample.output_token, sample.logprobs) + + def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: + """Performs one decoding iteration and returns newly generated results. + + .. figure:: https://i.imgur.com/sv2HssD.png + :alt: Overview of the step function + :align: center + + Overview of the step function. + + Details: + - Step 1: Schedules the sequences to be executed in the next + iteration and the token blocks to be swapped in/out/copy. + + - Depending on the scheduling policy, + sequences may be `preempted/reordered`. + - A Sequence Group (SG) refer to a group of sequences + that are generated from the same prompt. + + - Step 2: Calls the distributed executor to execute the model. + - Step 3: Processes the model output. This mainly includes: + + - Decodes the relevant outputs. + - Updates the scheduled sequence groups with model outputs + based on its `sampling parameters` (`use_beam_search` or not). + - Frees the finished sequence groups. + + - Finally, it creates and returns the newly generated results. + + Example: + >>> # Please see the example/ folder for more detailed examples. 
+ >>> + >>> # initialize engine and request arguments + >>> engine = LLMEngine.from_engine_args(engine_args) + >>> example_inputs = [(0, "What is LLM?", + >>> SamplingParams(temperature=0.0))] + >>> + >>> # Start the engine with an event loop + >>> while True: + >>> if example_inputs: + >>> req_id, prompt, sampling_params = example_inputs.pop(0) + >>> engine.add_request(str(req_id),prompt,sampling_params) + >>> + >>> # continue the request processing + >>> request_outputs = engine.step() + >>> for request_output in request_outputs: + >>> if request_output.finished: + >>> # return or show the request output + >>> + >>> if not (engine.has_unfinished_requests() or example_inputs): + >>> break + """ + if self.parallel_config.pipeline_parallel_size > 1: + raise NotImplementedError( + "Pipeline parallelism is only supported through AsyncLLMEngine " + "as performance will be severely degraded otherwise.") + + # For llm_engine, there is no pipeline parallel support, so the engine + # used is always 0. + virtual_engine = 0 + + # These are cached outputs from previous iterations. None if on first + # iteration + cached_outputs = self.cached_scheduler_outputs[virtual_engine] + seq_group_metadata_list = cached_outputs.seq_group_metadata_list + scheduler_outputs = cached_outputs.scheduler_outputs + allow_async_output_proc = cached_outputs.allow_async_output_proc + + ctx = self.scheduler_contexts[virtual_engine] + + # Clear outputs for each new scheduler iteration + ctx.request_outputs.clear() + + # Skip the scheduler if there are any remaining steps in the seq groups. + # This ensures that the scheduler is only called again when the current + # batch has completed. + if not self._has_remaining_steps(seq_group_metadata_list): + # Schedule iteration + (seq_group_metadata_list, scheduler_outputs, + allow_async_output_proc + ) = self.scheduler[virtual_engine].schedule() + + ctx.seq_group_metadata_list = seq_group_metadata_list + ctx.scheduler_outputs = scheduler_outputs + + # Maybe switch from async mode to sync mode + if not allow_async_output_proc and len(ctx.output_queue) > 0: + self._process_model_outputs(ctx=ctx) + + if (self.scheduler_config.is_multi_step + and scheduler_outputs.num_lookahead_slots > 0): + # cache the scheduler outputs for the next iteration if we have + # lookahead slots + self._cache_scheduler_outputs_for_multi_step( + virtual_engine, seq_group_metadata_list, scheduler_outputs, + allow_async_output_proc) + + assert seq_group_metadata_list is not None + assert scheduler_outputs is not None - # 1) Process raw inputs into the request. - detokenizer_req, engine_core_req = self.processor.process_inputs( - request_id, prompt, params, arrival_time, lora_request, - trace_headers, prompt_adapter_request, priority) + if not scheduler_outputs.is_empty(): + finished_requests_ids = self.scheduler[ + virtual_engine].get_and_reset_finished_requests_ids() - # 2) Add the request to Detokenizer. - self.detokenizer.add_request(detokenizer_req) + # Check if we have a cached last_output from the previous iteration. + # For supporting PP this is probably the best way to pass the + # sampled_token_ids, as a separate broadcast over all the PP stages + # will cause one virtual engine's microbatch to block the pipeline. + last_sampled_token_ids = \ + self._get_last_sampled_token_ids(virtual_engine) - # 3) Add the request to EngineCore. 
- self.engine_core.add_request(engine_core_req) + execute_model_req = ExecuteModelRequest( + seq_group_metadata_list=seq_group_metadata_list, + blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in, + blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out, + blocks_to_copy=scheduler_outputs.blocks_to_copy, + num_lookahead_slots=scheduler_outputs.num_lookahead_slots, + running_queue_size=scheduler_outputs.running_queue_size, + finished_requests_ids=finished_requests_ids, + # We use ExecuteModelRequest to pass the last sampled_token_ids + # to each of the non-last PP stages for in-place prepare_input. + last_sampled_token_ids=last_sampled_token_ids) - def step(self) -> List[RequestOutput]: + if allow_async_output_proc: + execute_model_req.async_callback = self.async_callbacks[ + virtual_engine] - # 1) Get EngineCoreOutput from the EngineCore. - engine_core_outputs = self.engine_core.get_output() + outputs = self.model_executor.execute_model( + execute_model_req=execute_model_req) - # 2) Detokenizer the EngineCoreOutput. - request_outputs, requests_to_abort = self.detokenizer.step( - engine_core_outputs) + # We need to do this here so that last step's sampled_token_ids can + # be passed to the next iteration for PP. + if self.scheduler_config.is_multi_step: + self._update_cached_scheduler_output(virtual_engine, outputs) + else: + # Nothing scheduled => If there is pending async postprocessor, + # then finish it here. + if len(ctx.output_queue) > 0: + self._process_model_outputs(ctx=ctx) + # No outputs in this case + outputs = [] - # 3) Abort requests that finished due to stopping criteria. - if requests_to_abort: - self.abort_request(requests_to_abort) + # Finish the current step for all the sequence groups. + if self.scheduler_config.is_multi_step: + for seq_group in seq_group_metadata_list: + seq_group.finish_step() - return request_outputs + if not self._has_remaining_steps(seq_group_metadata_list): + # clear the cache if we have finished all the steps. + if self.scheduler_config.is_multi_step: + self.cached_scheduler_outputs[0] = SchedulerOutputState() - # TODO(rob): Can we get rid of these? + # is_first_step_output is True only when the num_steps of all + # the sequences are 1. When the num_steps > 1, + # multi_step_model_runner does the first-step output append. + is_first_step_output: bool = False if not seq_group_metadata_list \ + else seq_group_metadata_list[0].state.num_steps == 1 - def get_model_config(self): - pass + # Add results to the output_queue + ctx.append_output(outputs=outputs, + seq_group_metadata_list=seq_group_metadata_list, + scheduler_outputs=scheduler_outputs, + is_async=allow_async_output_proc, + is_last_step=True, + is_first_step_output=is_first_step_output) + + if outputs and allow_async_output_proc: + assert len(outputs) == 1, ( + "Async postprocessor expects only a single output set") + + self._advance_to_next_step( + outputs[0], seq_group_metadata_list, + scheduler_outputs.scheduled_seq_groups) + + # Check if need to run the usual non-async path + if not allow_async_output_proc: + self._process_model_outputs(ctx=ctx) + + # Log stats. 
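# In this synchronous path the iteration stats and traces are emitted directly
# from step(); when async output processing is enabled they are emitted from
# _process_model_outputs instead (see the is_async branch there), mirroring the
# comment in that method about where stats are recorded.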
+ self.do_log_stats(scheduler_outputs, outputs) + + # Tracing + self.do_tracing(scheduler_outputs) + else: + # Multi-step case + return ctx.request_outputs + + if not self.has_unfinished_requests(): + # Drain async postprocessor (if exists) + if len(ctx.output_queue) > 0: + self._process_model_outputs(ctx=ctx) + assert len(ctx.output_queue) == 0 + + # Stop the execute model loop in parallel workers until there are + # more requests to process. This avoids waiting indefinitely in + # torch.distributed ops which may otherwise timeout, and unblocks + # the RPC thread in the workers so that they can process any other + # queued control plane messages, such as add/remove lora adapters. + logger.debug("Stopping remote worker execution loop.") + self.model_executor.stop_remote_worker_execution_loop() + + return ctx.request_outputs + + def _has_remaining_steps( + self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] + ) -> bool: + if (not self.scheduler_config.is_multi_step + or not seq_group_metadata_list): + return False + + # TODO(will) this is a sanity check for nowto make sure that all the + # seqs are on the same steps. Eventually we will want to do some sort of + # dynamic scheduling when doing multi-step decoding. + ref_remaining_steps = seq_group_metadata_list[0].state.remaining_steps + if any([ + seq_group.state.remaining_steps != ref_remaining_steps + for seq_group in seq_group_metadata_list[1:] + ]): + raise AssertionError("All running sequence groups should " + "have the same remaining steps.") + + return ref_remaining_steps > 0 + + def _cache_scheduler_outputs_for_multi_step( + self, virtual_engine: int, + seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], + scheduler_outputs: SchedulerOutputs, + allow_async_output_proc: bool) -> None: + co = self.cached_scheduler_outputs[virtual_engine] + + co.seq_group_metadata_list = seq_group_metadata_list + co.scheduler_outputs = scheduler_outputs + co.allow_async_output_proc = allow_async_output_proc + co.last_output = None + + def _update_cached_scheduler_output( + self, virtual_engine: int, + output: List[Optional[SamplerOutput]]) -> None: + if (self.parallel_config.pipeline_parallel_size > 1 and len(output) > 0 + and output[0] is not None): + last_output = output[-1] + assert last_output is not None + assert last_output.sampled_token_ids_cpu is not None + assert last_output.sampled_token_ids is None + assert last_output.sampled_token_probs is None + self.cached_scheduler_outputs[ + virtual_engine].last_output = last_output + + def _get_last_sampled_token_ids( + self, virtual_engine: int) -> Optional[torch.Tensor]: + cached_last_output = self.cached_scheduler_outputs[ + virtual_engine].last_output + if (self.scheduler_config.is_multi_step + and self.parallel_config.pipeline_parallel_size > 1 + and cached_last_output is not None + and cached_last_output.sampled_token_ids_cpu is not None): + return cached_last_output.sampled_token_ids_cpu + return None + + def add_logger(self, logger_name: str, logger: StatLoggerBase) -> None: + if not self.log_stats: + raise RuntimeError( + "Stat logging is disabled. Set `disable_log_stats=False` " + "argument to enable.") + if logger_name in self.stat_loggers: + raise KeyError(f"Logger with name {logger_name} already exists.") + self.stat_loggers[logger_name] = logger + + def remove_logger(self, logger_name: str) -> None: + if not self.log_stats: + raise RuntimeError( + "Stat logging is disabled. 
Set `disable_log_stats=False` " + "argument to enable.") + if logger_name not in self.stat_loggers: + raise KeyError(f"Logger with name {logger_name} does not exist.") + del self.stat_loggers[logger_name] + + def do_log_stats(self, + scheduler_outputs: Optional[SchedulerOutputs] = None, + model_output: Optional[List[SamplerOutput]] = None, + finished_before: Optional[List[int]] = None, + skip: Optional[List[int]] = None) -> None: + """Forced log when no requests active.""" + if self.log_stats: + stats = self._get_stats(scheduler_outputs, model_output, + finished_before, skip) + for logger in self.stat_loggers.values(): + logger.log(stats) + + def _get_stats(self, + scheduler_outputs: Optional[SchedulerOutputs], + model_output: Optional[List[SamplerOutput]] = None, + finished_before: Optional[List[int]] = None, + skip: Optional[List[int]] = None) -> Stats: + """Get Stats to be Logged to Prometheus. + + Args: + scheduler_outputs: Optional, used to populate metrics related to + the scheduled batch, + model_output: Optional, used to emit speculative decoding metrics + which are created by the workers. + finished_before: Optional, indices of sequences that were finished + before. These sequences will be ignored. + skip: Optional, indices of sequences that were preempted. These + sequences will be ignored. + """ + now = time.time() + + # System State + # Scheduler State + num_running_sys = sum( + len(scheduler.running) for scheduler in self.scheduler) + num_swapped_sys = sum( + len(scheduler.swapped) for scheduler in self.scheduler) + num_waiting_sys = sum( + len(scheduler.waiting) for scheduler in self.scheduler) + + # KV Cache Usage in % + num_total_gpu = self.cache_config.num_gpu_blocks + gpu_cache_usage_sys = 0. + if num_total_gpu: # Guard against both None and 0 + num_free_gpu = sum( + scheduler.block_manager.get_num_free_gpu_blocks() + for scheduler in self.scheduler) + gpu_cache_usage_sys = 1.0 - (num_free_gpu / num_total_gpu) + + num_total_cpu = self.cache_config.num_cpu_blocks + cpu_cache_usage_sys = 0. + if num_total_cpu: # Guard against both None and 0 + num_free_cpu = sum( + scheduler.block_manager.get_num_free_cpu_blocks() + for scheduler in self.scheduler) + cpu_cache_usage_sys = 1.0 - (num_free_cpu / num_total_cpu) + + # Prefix Cache Hit Rate. Note that we always use + # the cache hit rate of the first virtual engine. 
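# Worked example for the usage gauges computed above (numbers are invented for
# illustration): with 1000 profiled GPU blocks in total and 250 free across
# all schedulers,
# >>> 1.0 - (250 / 1000)
# 0.75
# i.e. gpu_cache_usage_sys would be reported as 75% utilization.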
+ cpu_prefix_cache_hit_rate = self.scheduler[ + 0].get_prefix_cache_hit_rate(Device.CPU) + gpu_prefix_cache_hit_rate = self.scheduler[ + 0].get_prefix_cache_hit_rate(Device.GPU) + + # Iteration stats + num_prompt_tokens_iter = 0 + num_generation_tokens_iter = 0 + num_tokens_iter = 0 + time_to_first_tokens_iter: List[float] = [] + time_per_output_tokens_iter: List[float] = [] + num_preemption_iter = (0 if scheduler_outputs is None else + scheduler_outputs.preempted) + + # Request stats + # Latency + time_e2e_requests: List[float] = [] + time_queue_requests: List[float] = [] + time_inference_requests: List[float] = [] + time_prefill_requests: List[float] = [] + time_decode_requests: List[float] = [] + time_in_queue_requests: List[float] = [] + model_forward_time_requests: List[float] = [] + model_execute_time_requests: List[float] = [] + # Metadata + num_prompt_tokens_requests: List[int] = [] + num_generation_tokens_requests: List[int] = [] + n_requests: List[int] = [] + max_num_generation_tokens_requests: List[int] = [] + max_tokens_requests: List[int] = [] + finished_reason_requests: List[str] = [] + + # Lora requests + running_lora_adapters = dict( + collectionsCounter([ + running_request.lora_request.lora_name + for scheduler in self.scheduler + for running_request in scheduler.running + if running_request.lora_request + ])) + waiting_lora_adapters = dict( + collectionsCounter([ + waiting_request.lora_request.lora_name + for scheduler in self.scheduler + for waiting_request in scheduler.waiting + if waiting_request.lora_request + ])) + max_lora_stat = "0" + if self.lora_config: + max_lora_stat = str(self.lora_config.max_loras) + + # NOTE: This loop assumes prefill seq_groups are before + # decode seq_groups in scheduled_seq_groups. + if scheduler_outputs is not None: + # For async postprocessor, already finished sequences need to be + # not counted (to avoid double counting) + actual_num_batched_tokens = scheduler_outputs.num_batched_tokens # type: ignore + + num_generation_tokens_from_prefill_groups = 0. + # NOTE: if scheduler_outputs.num_prefill_groups > 0 and + # the len of scheduler_outputs.scheduled_seq_groups is != + # scheduler_outputs.num_prefill_groups, this means that + # chunked prefills have been detected. + + for idx, scheduled_seq_group in enumerate( + scheduler_outputs.scheduled_seq_groups): + # Skip double logging when using async output proc + if finished_before and idx in finished_before: + actual_num_batched_tokens -= 1 + continue + + # Currently, skip == preempted sequences, so we need to skip + # their log stats + if skip and idx in skip: + continue + + group_was_prefill = idx < scheduler_outputs.num_prefill_groups + seq_group = scheduled_seq_group.seq_group + + # NOTE: a seq_group that completed all of its prefill tokens + # in the last iteration will have seq_group.is_prefill() = False + # with group_was_prefill = True + if group_was_prefill: + # Number of prompt tokens. + num_prompt_tokens_iter += ( + scheduled_seq_group.token_chunk_size) + + # If the seq_group just finished the prefill state + # get TTFT. + if not seq_group.is_prefill(): + latency = seq_group.get_last_latency(now) + time_to_first_tokens_iter.append(latency) + + # One generation token per finished prefill. + num_generation_tokens_from_prefill_groups += ( + seq_group.num_seqs()) + else: + # TPOTs. 
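# Both latency series reuse seq_group.get_last_latency(now): the prefill branch
# above records TTFT once, when a group finishes its prefill in this iteration,
# while this branch records a per-iteration TPOT sample for groups that are
# already decoding.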
+ latency = seq_group.get_last_latency(now) + time_per_output_tokens_iter.append(latency) + if seq_group.state.current_step == 0: + # For async_output_proc, the do_log_stats() + # is called following init_multi_step(), which + # sets the current_step to zero. + actual_num_batched_tokens +=\ + seq_group.state.num_steps - 1 + else: + actual_num_batched_tokens +=\ + seq_group.state.current_step - 1 + + # Because of chunked prefill, we can have a single sequence + # group that does multiple prompt_runs. To prevent logging + # the same metadata more than once per request, we standardize + # on logging request level information for finished requests, + # which can only happen once. + if seq_group.is_finished(): + # Latency timings + time_e2e_requests.append(now - + seq_group.metrics.arrival_time) + if (seq_group.metrics.first_scheduled_time is not None and + seq_group.metrics.first_token_time is not None): + time_queue_requests.append( + seq_group.metrics.first_scheduled_time - + seq_group.metrics.arrival_time) + time_prefill_requests.append( + seq_group.metrics.first_token_time - + seq_group.metrics.first_scheduled_time) + time_decode_requests.append( + now - seq_group.metrics.first_token_time) + time_inference_requests.append( + now - seq_group.metrics.first_scheduled_time) + if seq_group.metrics.time_in_queue is not None: + time_in_queue_requests.append( + seq_group.metrics.time_in_queue) + if seq_group.metrics.model_forward_time is not None: + model_forward_time_requests.append( + seq_group.metrics.model_forward_time) + if seq_group.metrics.model_execute_time is not None: + model_execute_time_requests.append( + seq_group.metrics.model_execute_time * 1000) + # Metadata + num_prompt_tokens_requests.append( + len(seq_group.prompt_token_ids)) + num_generation_tokens_requests.extend([ + seq.get_output_len() + for seq in seq_group.get_finished_seqs() + ]) + max_num_generation_tokens_requests.append( + max(seq.get_output_len() + for seq in seq_group.get_seqs())) + if seq_group.sampling_params is not None: + n_requests.append(seq_group.sampling_params.n) + max_tokens_requests.append( + seq_group.sampling_params.max_tokens) + finished_reason_requests.extend([ + SequenceStatus.get_finished_reason(seq.status) + for seq in seq_group.get_finished_seqs() + ]) + + # Number of generation tokens. + # num_batched_tokens equals the number of prompt_tokens plus the + # number of decode_tokens in a single iteration. So, + # num_generation_tokens = num_batched_tokens - num_prompt_tokens + # + num_generation_tokens_from_prefill_groups (since we generate + # one token on prefills on iters where the prefill finishes). + num_generation_tokens_iter = ( + actual_num_batched_tokens - num_prompt_tokens_iter + + num_generation_tokens_from_prefill_groups) + num_tokens_iter = (num_generation_tokens_iter + + num_prompt_tokens_iter) + # Spec decode, if enabled, emits specialized metrics from the worker in + # sampler output. 
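# Worked example of the accounting a few lines above (illustrative numbers
# only): an iteration that schedules 512 prompt tokens plus 64 decode tokens,
# with 8 prefills finishing and each contributing one sampled token, yields
# >>> (512 + 64) - 512 + 8      # num_generation_tokens_iter
# 72
# >>> 72 + 512                  # num_tokens_iter
# 584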
+ if model_output and (model_output[0].spec_decode_worker_metrics + is not None): + spec_decode_metrics = model_output[0].spec_decode_worker_metrics + else: + spec_decode_metrics = None + + return Stats( + now=now, + # System stats + # Scheduler State + num_running_sys=num_running_sys, + num_swapped_sys=num_swapped_sys, + num_waiting_sys=num_waiting_sys, + # KV Cache Usage in % + gpu_cache_usage_sys=gpu_cache_usage_sys, + cpu_cache_usage_sys=cpu_cache_usage_sys, + # Prefix Cache Hit Rate + cpu_prefix_cache_hit_rate=cpu_prefix_cache_hit_rate, + gpu_prefix_cache_hit_rate=gpu_prefix_cache_hit_rate, + + # Iteration stats + num_prompt_tokens_iter=num_prompt_tokens_iter, + num_generation_tokens_iter=num_generation_tokens_iter, + num_tokens_iter=num_tokens_iter, + time_to_first_tokens_iter=time_to_first_tokens_iter, + time_per_output_tokens_iter=time_per_output_tokens_iter, + spec_decode_metrics=spec_decode_metrics, + num_preemption_iter=num_preemption_iter, + + # Request stats + # Latency + time_e2e_requests=time_e2e_requests, + time_queue_requests=time_queue_requests, + time_inference_requests=time_inference_requests, + time_prefill_requests=time_prefill_requests, + time_decode_requests=time_decode_requests, + time_in_queue_requests=time_in_queue_requests, + model_forward_time_requests=model_forward_time_requests, + model_execute_time_requests=model_execute_time_requests, + # Metadata + num_prompt_tokens_requests=num_prompt_tokens_requests, + num_generation_tokens_requests=num_generation_tokens_requests, + max_num_generation_tokens_requests= + max_num_generation_tokens_requests, + n_requests=n_requests, + max_tokens_requests=max_tokens_requests, + finished_reason_requests=finished_reason_requests, + max_lora=str(max_lora_stat), + waiting_lora_adapters=list(waiting_lora_adapters.keys()), + running_lora_adapters=list(running_lora_adapters.keys())) + + def add_lora(self, lora_request: LoRARequest) -> bool: + return self.model_executor.add_lora(lora_request) + + def remove_lora(self, lora_id: int) -> bool: + return self.model_executor.remove_lora(lora_id) + + def list_loras(self) -> Set[int]: + return self.model_executor.list_loras() + + def pin_lora(self, lora_id: int) -> bool: + return self.model_executor.pin_lora(lora_id) + + def add_prompt_adapter( + self, prompt_adapter_request: PromptAdapterRequest) -> bool: + return self.model_executor.add_prompt_adapter(prompt_adapter_request) + + def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool: + return self.model_executor.remove_prompt_adapter(prompt_adapter_id) + + def list_prompt_adapters(self) -> List[int]: + return self.model_executor.list_prompt_adapters() + + def check_health(self) -> None: + if self.tokenizer: + self.tokenizer.check_health() + self.model_executor.check_health() + + def start_profile(self) -> None: + # using type instead of isinstance to check to avoid capturing + # inherited classes (MultiprocessingGPUExecutor) + if type(self.model_executor) == GPUExecutor: # noqa: E721 + self.model_executor.start_profile() + else: + self.model_executor._run_workers("start_profile") + + def stop_profile(self) -> None: + # using type instead of isinstance to check to avoid capturing + # inherited classes (MultiprocessingGPUExecutor) + if type(self.model_executor) == GPUExecutor: # noqa: E721 + self.model_executor.stop_profile() + else: + self.model_executor._run_workers("stop_profile") + + def is_tracing_enabled(self) -> bool: + return self.tracer is not None + + def do_tracing(self, + scheduler_outputs: SchedulerOutputs, + 
finished_before: Optional[List[int]] = None) -> None: + if self.tracer is None: + return + + for idx, scheduled_seq_group in enumerate( + scheduler_outputs.scheduled_seq_groups): + # Skip double tracing when using async output proc + if finished_before and idx in finished_before: + continue + + seq_group = scheduled_seq_group.seq_group + if seq_group.is_finished(): + self.create_trace_span(seq_group) + + def create_trace_span(self, seq_group: SequenceGroup) -> None: + if self.tracer is None or seq_group.sampling_params is None: + return + arrival_time_nano_seconds = int(seq_group.metrics.arrival_time * 1e9) + + trace_context = extract_trace_context(seq_group.trace_headers) + + with self.tracer.start_as_current_span( + "llm_request", + kind=SpanKind.SERVER, + context=trace_context, + start_time=arrival_time_nano_seconds) as seq_span: + metrics = seq_group.metrics + ttft = metrics.first_token_time - metrics.arrival_time + e2e_time = metrics.finished_time - metrics.arrival_time + # attribute names are based on + # https://github.com/open-telemetry/semantic-conventions/blob/main/docs/gen-ai/llm-spans.md + seq_span.set_attribute(SpanAttributes.LLM_RESPONSE_MODEL, + self.model_config.model) + seq_span.set_attribute(SpanAttributes.LLM_REQUEST_ID, + seq_group.request_id) + seq_span.set_attribute(SpanAttributes.LLM_REQUEST_TEMPERATURE, + seq_group.sampling_params.temperature) + seq_span.set_attribute(SpanAttributes.LLM_REQUEST_TOP_P, + seq_group.sampling_params.top_p) + seq_span.set_attribute(SpanAttributes.LLM_REQUEST_MAX_TOKENS, + seq_group.sampling_params.max_tokens) + seq_span.set_attribute(SpanAttributes.LLM_REQUEST_N, + seq_group.sampling_params.n) + seq_span.set_attribute(SpanAttributes.LLM_USAGE_NUM_SEQUENCES, + seq_group.num_seqs()) + seq_span.set_attribute(SpanAttributes.LLM_USAGE_PROMPT_TOKENS, + len(seq_group.prompt_token_ids)) + seq_span.set_attribute( + SpanAttributes.LLM_USAGE_COMPLETION_TOKENS, + sum([ + seq.get_output_len() + for seq in seq_group.get_finished_seqs() + ])) + seq_span.set_attribute(SpanAttributes.LLM_LATENCY_TIME_IN_QUEUE, + metrics.time_in_queue) + seq_span.set_attribute( + SpanAttributes.LLM_LATENCY_TIME_TO_FIRST_TOKEN, ttft) + seq_span.set_attribute(SpanAttributes.LLM_LATENCY_E2E, e2e_time) + if metrics.scheduler_time is not None: + seq_span.set_attribute( + SpanAttributes.LLM_LATENCY_TIME_IN_SCHEDULER, + metrics.scheduler_time) + if metrics.model_forward_time is not None: + seq_span.set_attribute( + SpanAttributes.LLM_LATENCY_TIME_IN_MODEL_FORWARD, + metrics.model_forward_time / 1000.0) + if metrics.model_execute_time is not None: + seq_span.set_attribute( + SpanAttributes.LLM_LATENCY_TIME_IN_MODEL_EXECUTE, + metrics.model_execute_time) def is_encoder_decoder_model(self): - pass + return self.input_preprocessor.is_encoder_decoder_model() + + def _validate_model_inputs(self, inputs: ProcessorInputs, + lora_request: Optional[LoRARequest]): + if is_encoder_decoder_inputs(inputs): + # For encoder-decoder multimodal models, the max_prompt_len + # restricts the decoder prompt length + prompt_inputs = inputs["decoder" if self.model_config. 
+ is_multimodal_model else "encoder"] + else: + prompt_inputs = inputs + + prompt_ids = prompt_inputs.get("prompt_token_ids") + + if prompt_ids is None or len(prompt_ids) == 0: + raise ValueError("Prompt cannot be empty") + + if self.model_config.is_multimodal_model: + max_prompt_len = self.model_config.max_model_len + + if len(prompt_ids) > max_prompt_len: + raise ValueError( + f"The prompt (total length {len(prompt_ids)}) is too long " + f"to fit into the model (context length {max_prompt_len}). " + "Make sure that `max_model_len` is no smaller than the " + "number of text tokens plus multimodal tokens. For image " + "inputs, the number of image tokens depends on the number " + "of images, and possibly their aspect ratios as well.") + + # TODO: Find out how many placeholder tokens are there so we can + # check that chunked prefill does not truncate them + # max_batch_len = self.scheduler_config.max_num_batched_tokens + + def _build_logits_processors( + self, sampling_params: SamplingParams, + lora_request: Optional[LoRARequest]) -> SamplingParams: + """Constructs logits processors based on the guided_decoding, + logits_bias, and allowed_token_ids fields in sampling_params. Deletes + those fields and adds the constructed logits processors to the + logits_processors field. Returns the modified sampling params.""" + + logits_processors = [] + + if (guided_decoding := sampling_params.guided_decoding) is not None: + + logger.debug( + "Building guided decoding logits processor in " + "LLMEngine. Params: %s", guided_decoding) + + tokenizer = self.get_tokenizer(lora_request=lora_request) + guided_decoding.backend = guided_decoding.backend or \ + self.decoding_config.guided_decoding_backend + + processor = get_local_guided_decoding_logits_processor( + guided_params=guided_decoding, tokenizer=tokenizer) + if processor: + logits_processors.append(processor) + + # Unset so this doesn't get passed down to the model + sampling_params.guided_decoding = None + + if (sampling_params.logit_bias or sampling_params.allowed_token_ids): + tokenizer = self.get_tokenizer(lora_request=lora_request) + + processors = get_openai_logits_processors( + logit_bias=sampling_params.logit_bias, + allowed_token_ids=sampling_params.allowed_token_ids, + tokenizer=tokenizer) + logits_processors.extend(processors) + + # Unset so these don't get passed down to the model + sampling_params.logit_bias = None + sampling_params.allowed_token_ids = None - def start_profile(self): - pass + if len(sampling_params.bad_words) > 0: + tokenizer = self.get_tokenizer(lora_request) + processors = get_bad_words_logits_processors( + bad_words=sampling_params.bad_words, tokenizer=tokenizer) + logits_processors.extend(processors) - def stop_profile(self): - pass + if logits_processors: + if sampling_params.logits_processors is None: + sampling_params.logits_processors = logits_processors + else: + sampling_params.logits_processors.extend(logits_processors) - def get_tokenizer_group(self, group_type): - pass + return sampling_params From 4ade5b0f3bcc2565be27c1eb68165d512cbed6c7 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Tue, 12 Nov 2024 03:24:12 +0000 Subject: [PATCH 10/33] fixt --- vllm/v1/engine/llm_engine.py | 550 +++++++---------------------------- 1 file changed, 104 insertions(+), 446 deletions(-) diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index c739e25312ef5..af8f28377f31a 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -1,176 +1,70 @@ -import time -from typing import (Any, 
Dict, Iterable, List, Mapping, Optional, Tuple, Type, - Union) +from typing import Dict, List, Mapping, Optional, Type, Union -from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig, - ObservabilityConfig, ParallelConfig, SchedulerConfig, - VllmConfig) +from vllm.config import VllmConfig from vllm.engine.arg_utils import EngineArgs from vllm.engine.metrics_types import StatLoggerBase -from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, - EncoderDecoderLLMInputs, InputRegistry, PromptType) -from vllm.inputs.preprocess import InputPreprocessor +from vllm.envs import VLLM_ENABLE_V1_MULTIPROCESSING +from vllm.inputs import INPUT_REGISTRY, InputRegistry, PromptType from vllm.logger import init_logger from vllm.lora.request import LoRARequest -from vllm.outputs import CompletionOutput, RequestOutput +from vllm.outputs import RequestOutput from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest -from vllm.sampling_params import RequestOutputKind, SamplingParams -from vllm.transformers_utils.config import try_get_generation_config -from vllm.transformers_utils.tokenizer_group import ( - BaseTokenizerGroup, init_tokenizer_from_configs) +from vllm.platforms import current_platform +from vllm.sampling_params import SamplingParams +from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs from vllm.usage.usage_lib import UsageContext -from vllm.v1.core.scheduler import Scheduler -# from vllm.v1.executor.gpu_executor import GPUExecutor +from vllm.v1.engine.core_client import EngineCoreClient +from vllm.v1.engine.detokenizer import Detokenizer +from vllm.v1.engine.processor import Processor +from vllm.v1.executor.gpu_executor import GPUExecutor from vllm.v1.executor.tpu_executor import TPUExecutor -from vllm.v1.request import Request, RequestStatus -from vllm.v1.tokenizer.detokenizer import Detokenizer, DetokenizerInputs -from vllm.version import __version__ as VLLM_VERSION logger = init_logger(__name__) class LLMEngine: + """Legacy LLMEngine for backwards compatibility.""" def __init__( self, vllm_config: VllmConfig, - executor_class: Type[TPUExecutor], + executor_class: Type[Union[GPUExecutor, TPUExecutor]], log_stats: bool, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, stat_loggers: Optional[Dict[str, StatLoggerBase]] = None, input_registry: InputRegistry = INPUT_REGISTRY, use_cached_outputs: bool = False, + multiprocess_mode: bool = False, ) -> None: - # TODO: remove the local variables and use self.* throughout the class. - model_config = self.model_config = vllm_config.model_config - cache_config = self.cache_config = vllm_config.cache_config - lora_config = self.lora_config = vllm_config.lora_config - parallel_config = self.parallel_config = vllm_config.parallel_config - scheduler_config = self.scheduler_config = vllm_config.scheduler_config - device_config = self.device_config = vllm_config.device_config - speculative_config = self.speculative_config = vllm_config.speculative_config # noqa - load_config = self.load_config = vllm_config.load_config - decoding_config = self.decoding_config = vllm_config.decoding_config or DecodingConfig( # noqa + # TODO: Can we avoid this? + self.model_config = vllm_config.model_config + + # Tokenizer (+ ensure liveness if running in another process). 
+ self.tokenizer = init_tokenizer_from_configs( + model_config=vllm_config.model_config, + scheduler_config=vllm_config.scheduler_config, + parallel_config=vllm_config.parallel_config, + enable_lora=bool(vllm_config.lora_config)) + self.tokenizer.ping() + + # Processor (convert Inputs --> EngineCoreRequests) + self.processor = Processor(vllm_config.model_config, + vllm_config.lora_config, self.tokenizer, + input_registry) + + # Detokenizer (converts EngineCoreOutputs --> RequestOutput) + self.detokenizer = Detokenizer(vllm_config.model_config.tokenizer) + + # EngineCore (gets EngineCoreRequests and gives EngineCoreOutputs) + self.engine_core = EngineCoreClient.make_client( + vllm_config, + executor_class, + usage_context, + multiprocess_mode=multiprocess_mode, + asyncio_mode=False, ) - prompt_adapter_config = self.prompt_adapter_config = vllm_config.prompt_adapter_config # noqa - observability_config = self.observability_config = vllm_config.observability_config or ObservabilityConfig( # noqa - ) - - # Override the configs for V1. - # FIXME - if usage_context == UsageContext.LLM_CLASS: - scheduler_config.max_num_seqs = 1024 - scheduler_config.max_num_batched_tokens = 8192 - elif usage_context == UsageContext.OPENAI_API_SERVER: - scheduler_config.max_num_seqs = 1024 - scheduler_config.max_num_batched_tokens = 2048 - - logger.info( - "Initializing an LLM engine (v%s) with config: " - "model=%r, speculative_config=%r, tokenizer=%r, " - "skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, " - "override_neuron_config=%s, " - "rope_scaling=%r, rope_theta=%r, tokenizer_revision=%s, " - "trust_remote_code=%s, dtype=%s, max_seq_len=%d, " - "download_dir=%r, load_format=%s, tensor_parallel_size=%d, " - "pipeline_parallel_size=%d, " - "disable_custom_all_reduce=%s, quantization=%s, " - "enforce_eager=%s, kv_cache_dtype=%s, " - "quantization_param_path=%s, device_config=%s, " - "decoding_config=%r, observability_config=%r, " - "seed=%d, served_model_name=%s, " - "num_scheduler_steps=%d, enable_prefix_caching=%s, " - "use_async_output_proc=%s, mm_processor_kwargs=%s)", - VLLM_VERSION, - model_config.model, - speculative_config, - model_config.tokenizer, - model_config.skip_tokenizer_init, - model_config.tokenizer_mode, - model_config.revision, - model_config.override_neuron_config, - model_config.rope_scaling, - model_config.rope_theta, - model_config.tokenizer_revision, - model_config.trust_remote_code, - model_config.dtype, - model_config.max_model_len, - load_config.download_dir, - load_config.load_format, - parallel_config.tensor_parallel_size, - parallel_config.pipeline_parallel_size, - parallel_config.disable_custom_all_reduce, - model_config.quantization, - model_config.enforce_eager, - cache_config.cache_dtype, - model_config.quantization_param_path, - device_config.device, - decoding_config, - observability_config, - model_config.seed, - model_config.served_model_name, - scheduler_config.num_scheduler_steps, - cache_config.enable_prefix_caching, - model_config.use_async_output_proc, - model_config.mm_processor_kwargs, - ) - - self.log_stats = log_stats - - assert not self.model_config.skip_tokenizer_init - self.tokenizer = self._init_tokenizer() - if self.tokenizer: - # Ping the tokenizer to ensure liveness if it runs in a - # different process. 
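The three components wired up at the top of this hunk replace the old monolithic engine loop: the Processor turns prompts into EngineCoreRequests, the EngineCore executes them, and the Detokenizer turns EngineCoreOutputs back into RequestOutputs. As a rough usage sketch of the resulting facade (the model name and sampling values are placeholders; `add_request`, `step`, and `has_unfinished_requests` are defined later in this patch):

    from vllm.engine.arg_utils import EngineArgs
    from vllm.sampling_params import SamplingParams
    from vllm.v1.engine.llm_engine import LLMEngine

    engine = LLMEngine.from_engine_args(EngineArgs(model="facebook/opt-125m"))
    engine.add_request("req-0", "Hello, my name is", SamplingParams(max_tokens=16))

    outputs = []
    while engine.has_unfinished_requests():
        # Each step drains EngineCoreOutputs and returns detokenized RequestOutputs.
        outputs.extend(engine.step())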
- self.tokenizer.ping() - self.detokenizer = Detokenizer(self.model_config.tokenizer) - - self.generation_config_fields = _load_generation_config_dict( - model_config) - self.input_preprocessor = InputPreprocessor(model_config, - self.tokenizer) - self.input_registry = input_registry - self.input_processor = input_registry.create_input_processor( - model_config) - - # Request id -> Request - self.requests: Dict[str, Request] = {} - # NOTE(woosuk): Now that the detokenizer works asynchronously, we need - # to keep track of how many steps each request has been lagged behind - # in terms of detokenization. - # Request id -> how many detokenizer steps the request should wait for. - self.num_lagged_steps: Dict[str, int] = {} - # OPTIMIZATION: Cache the request output and update it incrementally. - # This is used to avoid creating a new RequestOutput object every step. - # Request id -> RequestOutput - self.request_outputs: Dict[str, RequestOutput] = {} - - self.model_executor = executor_class(vllm_config=vllm_config) - assert self.model_config.task != "embedding" - self._initialize_kv_caches() - - # Create the scheduler. - # NOTE: the cache_config here have been updated with the numbers of - # GPU and CPU blocks, which are profiled in the distributed executor. - self.scheduler = Scheduler(scheduler_config, cache_config, lora_config) - - def _initialize_kv_caches(self) -> None: - num_gpu_blocks, _ = self.model_executor.determine_num_available_blocks( - ) - - if self.cache_config.num_gpu_blocks_override is not None: - num_gpu_blocks_override = self.cache_config.num_gpu_blocks_override - logger.info( - "Overriding num_gpu_blocks=%d with " - "num_gpu_blocks_override=%d", num_gpu_blocks, - num_gpu_blocks_override) - num_gpu_blocks = num_gpu_blocks_override - - self.cache_config.num_gpu_blocks = num_gpu_blocks - self.cache_config.num_cpu_blocks = 0 - self.model_executor.initialize_cache(num_gpu_blocks) @classmethod def from_engine_args( @@ -178,71 +72,51 @@ def from_engine_args( engine_args: EngineArgs, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, stat_loggers: Optional[Dict[str, StatLoggerBase]] = None, + enable_multiprocessing: bool = False, ) -> "LLMEngine": """Creates an LLM engine from the engine arguments.""" + # Create the engine configs. - engine_config = engine_args.create_engine_config() - executor_class = cls._get_executor_cls(engine_config) - # Create the LLM engine. 
- engine = cls( - vllm_config=engine_config, - executor_class=executor_class, - log_stats=not engine_args.disable_log_stats, - usage_context=usage_context, - stat_loggers=stat_loggers, - ) - return engine - - def _init_tokenizer(self) -> BaseTokenizerGroup: - return init_tokenizer_from_configs( - model_config=self.model_config, - scheduler_config=self.scheduler_config, - parallel_config=self.parallel_config, - enable_lora=bool(self.lora_config)) - - def _verify_args(self) -> None: - self.model_config.verify_with_parallel_config(self.parallel_config) - self.cache_config.verify_with_parallel_config(self.parallel_config) - if self.lora_config: - self.lora_config.verify_with_model_config(self.model_config) - self.lora_config.verify_with_scheduler_config( - self.scheduler_config) - if self.prompt_adapter_config: - self.prompt_adapter_config.verify_with_model_config( - self.model_config) - - def _add_processed_request( - self, - request_id: str, - processed_inputs: Union[DecoderOnlyInputs, EncoderDecoderLLMInputs], - params: Union[SamplingParams, PoolingParams], - arrival_time: float, - lora_request: Optional[LoRARequest], - prompt_adapter_request: Optional[PromptAdapterRequest], - trace_headers: Optional[Mapping[str, str]] = None, - ) -> None: - assert prompt_adapter_request is None - assert trace_headers is None - self._validate_model_inputs(processed_inputs) - eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request) - - # TODO(woosuk): Support embedding mode. - assert isinstance(params, SamplingParams) - sampling_params = params.clone() - sampling_params.update_from_generation_config( - self.generation_config_fields, eos_token_id) - - # TODO(woosuk): Check max_logprobs - # TODO(woosuk): Support encoder-decoder models. - req = Request(request_id, processed_inputs, params, eos_token_id, - arrival_time) - self.requests[request_id] = req - self.num_lagged_steps[request_id] = 0 - self.scheduler.add_request(req) + vllm_config = engine_args.create_engine_config() + executor_class = cls._get_executor_cls(vllm_config) + + if VLLM_ENABLE_V1_MULTIPROCESSING: + logger.debug("Enabling multiprocessing for LLMEngine.") + enable_multiprocessing = True + + # Create the LLMEngine. 
+ return cls(vllm_config=vllm_config, + executor_class=executor_class, + log_stats=not engine_args.disable_log_stats, + usage_context=usage_context, + stat_loggers=stat_loggers, + multiprocess_mode=enable_multiprocessing) + + @classmethod + def _get_executor_cls(cls, vllm_config: VllmConfig): + if current_platform.is_tpu(): + return TPUExecutor + return GPUExecutor def stop_remote_worker_execution_loop(self) -> None: raise NotImplementedError("TP not implemented yet.") + def get_num_unfinished_requests(self) -> int: + return self.detokenizer.get_num_unfinished_requests() + + def has_unfinished_requests(self) -> bool: + return self.detokenizer.has_unfinished_requests() + + @classmethod + def validate_outputs(cls, outputs, output_type): + return outputs + + def abort_request(self, request_ids: List[str]) -> None: + """Remove request_ids from EngineCore and Detokenizer.""" + + self.engine_core.abort_requests(request_ids) + self.detokenizer.abort_requests(request_ids) + def add_request( self, request_id: str, @@ -254,262 +128,46 @@ def add_request( prompt_adapter_request: Optional[PromptAdapterRequest] = None, priority: int = 0, ) -> None: - if lora_request is not None and not self.lora_config: - raise ValueError(f"Got lora_request {lora_request} but LoRA is " - "not enabled!") - if arrival_time is None: - arrival_time = time.time() - assert priority == 0, "vLLM V1 does not support priority at the moment." - - preprocessed_inputs = self.input_preprocessor.preprocess( - prompt, - request_id=request_id, - lora_request=lora_request, - prompt_adapter_request=prompt_adapter_request, - ) - processed_inputs = self.input_processor(preprocessed_inputs) - - self._add_processed_request( - request_id=request_id, - processed_inputs=processed_inputs, - params=params, - arrival_time=arrival_time, - lora_request=lora_request, - prompt_adapter_request=prompt_adapter_request, - trace_headers=trace_headers, - ) - def abort_request(self, request_id: Union[str, Iterable[str]]) -> None: - self.scheduler.finish_requests(request_id, - RequestStatus.FINISHED_ABORTED) - self._free_request(request_id) + # 1) Process raw inputs into the request. + detokenizer_req, engine_core_req = self.processor.process_inputs( + request_id, prompt, params, arrival_time, lora_request, + trace_headers, prompt_adapter_request, priority) - def get_num_unfinished_requests(self) -> int: - """Gets the number of unfinished requests.""" - return len(self.requests) + # 2) Add the request to Detokenizer. + self.detokenizer.add_request(detokenizer_req) - def has_unfinished_requests(self) -> bool: - """Returns True if there are unfinished requests.""" - return len(self.requests) > 0 + # 3) Add the request to EngineCore. + self.engine_core.add_request(engine_core_req) def step(self) -> List[RequestOutput]: - # NOTE(woosuk): This method may return an empty list when the - # detokenizer is still processing the outputs. This should not be - # considered as the end of the generation process. - # FIXME(woosuk): Currently, the step method is inefficient because it - # creates RequestOutput objects for all running requests, while they - # may not be needed unless the output is streamed to the client. 
- if self.scheduler.has_unfinished_requests(): - scheduler_output = self.scheduler.schedule() - output = self.model_executor.execute_model(scheduler_output) - sampled = self.scheduler.update_from_output( - scheduler_output, output) - self.send_to_detokenizer(sampled) - req_outputs = self.recv_from_detokenizer() - return req_outputs - - def send_to_detokenizer(self, sampled: List[Tuple[Request, int]]) -> None: - inputs = DetokenizerInputs( - req_ids=[], - prompt_token_ids=[], - new_token_ids=[], - skip_special_tokens=[], - spaces_between_special_tokens=[], - free_req_ids=[], # TODO(woosuk): Implement freeing. - ) - for req, num_tokens in sampled: - inputs.req_ids.append(req.request_id) - if len(req.output_token_ids) == num_tokens: - # The request is first detokenized. - inputs.prompt_token_ids.append(req.prompt_token_ids) - else: - # The prompt token ids are already cached in the detokenizer. - inputs.prompt_token_ids.append([]) - inputs.new_token_ids.append(req.output_token_ids[-num_tokens:]) - inputs.skip_special_tokens.append( - req.sampling_params.skip_special_tokens) - inputs.spaces_between_special_tokens.append( - req.sampling_params.spaces_between_special_tokens) - - # Update the number of lagged steps. - self.num_lagged_steps[req.request_id] += 1 - self.detokenizer.send(inputs) - - def recv_from_detokenizer(self) -> List[RequestOutput]: - detokenizer_output = self.detokenizer.recv() - if detokenizer_output is None: - return [] - - req_outputs: List[RequestOutput] = [] - num_reqs = len(detokenizer_output.req_ids) - for i in range(num_reqs): - req_id = detokenizer_output.req_ids[i] - if req_id not in self.requests: - # The request has been aborted while the detokenizer was - # processing the outputs. - continue - - req = self.requests[req_id] - req.output_text += detokenizer_output.detokenized_texts[i] - - self.num_lagged_steps[req_id] -= 1 - finished = (self.num_lagged_steps[req_id] == 0 - and req.is_finished()) - req_output = self._make_request_output( - req, detokenizer_output.num_output_token_ids[i], - detokenizer_output.detokenized_texts[i], finished) - req_outputs.append(req_output) - - if finished: - self._free_request(req_id) - return req_outputs - - def terminate_detokenizer(self) -> None: - self.detokenizer.terminate() - - def _make_request_output( - self, - request: Request, - num_output_tokens: int, - new_output_text: str, - finished: bool, - ) -> RequestOutput: - req_output = self.request_outputs.get(request.request_id) - if req_output is None: - # TODO: Support `n` > 1. 
- completion_output = CompletionOutput( - index=0, - text="", - token_ids=[], - cumulative_logprob=None, - logprobs=None, # TODO - finish_reason=None, - stop_reason=None, - lora_request=None, - ) - req_output = RequestOutput( - request_id=request.request_id, - prompt=request.prompt, - prompt_token_ids=request.prompt_token_ids, - prompt_logprobs=None, # TODO - outputs=[completion_output], - finished=False, - metrics=None, - lora_request=None, - encoder_prompt=None, - encoder_prompt_token_ids=None, - ) - self.request_outputs[request.request_id] = req_output - - completion_output = req_output.outputs[0] - if request.sampling_params.output_kind == RequestOutputKind.CUMULATIVE: - completion_output.text += new_output_text - completion_output.token_ids = ( - request.output_token_ids[:num_output_tokens]) - elif request.sampling_params.output_kind == RequestOutputKind.DELTA: - completion_output.text = new_output_text - num_prev_tokens = len(completion_output.token_ids) - completion_output.token_ids = request.output_token_ids[ - num_prev_tokens:num_output_tokens] - elif (request.sampling_params.output_kind == - RequestOutputKind.FINAL_ONLY): - if finished: - completion_output.text = request.output_text - completion_output.token_ids = request.output_token_ids - else: - completion_output.text = "" - completion_output.token_ids = [] - - if finished: - completion_output.finish_reason = request.get_finished_reason() - completion_output.stop_reason = request.stop_reason - req_output.finished = finished - return req_output - - def _free_request(self, request_id: str) -> None: - self.requests.pop(request_id, None) - self.num_lagged_steps.pop(request_id, None) - self.request_outputs.pop(request_id, None) - - def check_health(self) -> None: - if self.tokenizer: - self.tokenizer.check_health() - self.model_executor.check_health() - - def _validate_model_inputs(self, inputs: Union[DecoderOnlyInputs, - EncoderDecoderLLMInputs]): - prompt_ids = inputs.get("prompt_token_ids") - if prompt_ids is None or len(prompt_ids) == 0: - raise ValueError("Prompt cannot be empty") - - if self.model_config.is_multimodal_model: - max_prompt_len = self.model_config.max_model_len - - if len(prompt_ids) > max_prompt_len: - raise ValueError( - f"The prompt (total length {len(prompt_ids)}) is too long " - f"to fit into the model (context length {max_prompt_len}). " - "Make sure that `max_model_len` is no smaller than the " - "number of text tokens plus multimodal tokens. For image " - "inputs, the number of image tokens depends on the number " - "of images, and possibly their aspect ratios as well.") - @classmethod - def validate_outputs(cls, outputs, output_type): - return outputs - - def get_model_config(self) -> ModelConfig: - """Gets the model configuration.""" - return self.model_config - - def get_parallel_config(self) -> ParallelConfig: - """Gets the parallel configuration.""" - return self.parallel_config + # 1) Get EngineCoreOutput from the EngineCore. + engine_core_outputs = self.engine_core.get_output() - def get_decoding_config(self) -> DecodingConfig: - """Gets the decoding configuration.""" - return self.decoding_config + # 2) Detokenizer the EngineCoreOutput. + request_outputs, requests_to_abort = self.detokenizer.step( + engine_core_outputs) - def get_scheduler_config(self) -> SchedulerConfig: - """Gets the scheduler configuration.""" - return self.scheduler_config + # 3) Abort requests that finished due to stopping criteria. 
+ if requests_to_abort: + self.abort_request(requests_to_abort) - def get_lora_config(self) -> LoRAConfig: - """Gets the LoRA configuration.""" - return self.lora_config + return request_outputs - @classmethod - def _get_executor_cls(cls, engine_config: VllmConfig): - # return GPUExecutor - return TPUExecutor - - def is_tracing_enabled(self) -> bool: - return False + # TODO(rob): Can we get rid of these? - def do_log_stats(self, *args, **kwargs) -> None: + def get_model_config(self): pass - def is_encoder_decoder_model(self) -> bool: - return False - - def start_profile(self) -> None: + def is_encoder_decoder_model(self): pass - def stop_profile(self) -> None: + def start_profile(self): pass - def get_tokenizer_group(self, *args, **kwargs): - return self.tokenizer - - -def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]: - config = try_get_generation_config( - model_config.model, - trust_remote_code=model_config.trust_remote_code, - revision=model_config.revision, - ) - - if config is None: - return {} + def stop_profile(self): + pass - return config.to_diff_dict() + def get_tokenizer_group(self, group_type): + pass From dc784518e9bed729abfedb090e5f49fea629aa7f Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Tue, 12 Nov 2024 13:37:56 +0000 Subject: [PATCH 11/33] warmup is working! --- vllm/attention/selector.py | 7 +++- vllm/v1/attention/backends/pallas.py | 26 ++++++++----- vllm/v1/engine/core.py | 5 ++- vllm/v1/executor/tpu_executor.py | 6 +-- vllm/v1/worker/tpu_model_runner.py | 55 +++++++++++++++------------- vllm/v1/worker/tpu_worker.py | 16 ++++++++ 6 files changed, 72 insertions(+), 43 deletions(-) diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 478aebda68fd3..32100f9fd5f16 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -237,9 +237,12 @@ def which_attn_to_use(head_size: int, return _Backend.IPEX if current_platform.is_tpu(): - if selected_backend != _Backend.PALLAS_VLLM_V1: + if (selected_backend != _Backend.PALLAS and + selected_backend != _Backend.PALLAS_VLLM_V1): logger.info("Cannot use %s backend on TPU.", selected_backend) - return _Backend.PALLAS_VLLM_V1 + if use_v1: + return _Backend.PALLAS_VLLM_V1 + return _Backend.PALLAS if current_platform.is_rocm(): # AMD GPUs. diff --git a/vllm/v1/attention/backends/pallas.py b/vllm/v1/attention/backends/pallas.py index a74aad0108190..936b662caab5b 100644 --- a/vllm/v1/attention/backends/pallas.py +++ b/vllm/v1/attention/backends/pallas.py @@ -36,7 +36,7 @@ def get_kv_cache_shape( ) -> Tuple[int, ...]: if block_size % 16 != 0: raise ValueError("Block size must be a multiple of 16.") - return (2, num_blocks, block_size, num_kv_heads, head_size) + return (num_kv_heads, num_blocks, block_size, head_size) @dataclass @@ -115,7 +115,7 @@ def forward( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - kv_cache: torch.Tensor, + kv_cache: Tuple[torch.Tensor, torch.Tensor], attn_metadata: PallasAttentionMetadata, k_scale: float = 1.0, v_scale: float = 1.0, @@ -124,13 +124,14 @@ def forward( """Forward pass with FlashAttention. 
Args: - query: shape = [num_tokens, num_heads * head_size] - key: shape = [num_tokens, num_kv_heads * head_size] - value: shape = [num_tokens, num_kv_heads * head_size] - kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size] + query: shape = [batch_size, seq_len, num_heads * head_size] + key: shape = [batch_size, seq_len, num_kv_heads * head_size] + value: shape = [batch_size, seq_len, num_kv_heads * head_size] + kv_cache = [2, num_kv_heads, num_blocks, block_size, head_size] + NOTE: kv_cache will be an empty tensor for profiling run attn_metadata: Metadata for attention. Returns: - shape = [num_tokens, num_heads * head_size] + shape = [batch_size, seq_len, num_heads * head_size] """ assert k_scale == 1.0 and v_scale == 1.0, ( @@ -142,6 +143,10 @@ def forward( "are not implemented for " "PallasAttentionImpl") + # Empty KV cache when profiling (skip write to cache). + is_profiling = kv_cache[0].numel() == 0 + + # Unpack batch_size, seq_len, hidden_size = query.shape query = query.view(batch_size, seq_len, self.num_heads, self.head_size) key = key.view(batch_size, seq_len, self.num_kv_heads, self.head_size) @@ -149,9 +154,10 @@ def forward( self.head_size) # Write to KV cache. - if kv_cache.numel() > 0: + if not is_profiling: slot_mapping = attn_metadata.slot_mapping - key_cache, value_cache = kv_cache + key_cache = kv_cache[0] + value_cache = kv_cache[1] write_to_kv_cache(key, value, key_cache, value_cache, slot_mapping) query = query * self.scale @@ -181,7 +187,7 @@ def forward( output = output.permute(0, 2, 1, 3) else: # Decoding run. - assert kv_cache[0].numel() > 0 + assert not is_profiling query = query.squeeze(dim=1) pages_per_compute_block = 16 # TODO(woosuk): Tune this value. diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index f9d3473d0131c..170df4fce01a7 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -40,8 +40,9 @@ def __init__( # Override the configs for V1. # FIXME if usage_context == UsageContext.LLM_CLASS: - vllm_config.scheduler_config.max_num_seqs = 1024 - vllm_config.scheduler_config.max_num_batched_tokens = 8192 + # vllm_config.scheduler_config.max_num_seqs = 1024 + # vllm_config.scheduler_config.max_num_batched_tokens = 8192 + pass elif usage_context == UsageContext.OPENAI_API_SERVER: vllm_config.scheduler_config.max_num_seqs = 1024 vllm_config.scheduler_config.max_num_batched_tokens = 2048 diff --git a/vllm/v1/executor/tpu_executor.py b/vllm/v1/executor/tpu_executor.py index 122917a5e242d..03ba0ca72359f 100644 --- a/vllm/v1/executor/tpu_executor.py +++ b/vllm/v1/executor/tpu_executor.py @@ -53,14 +53,14 @@ def determine_num_available_blocks(self) -> Tuple[int, int]: """ return self.worker.determine_num_available_blocks() - def initialize_cache(self, num_gpu_blocks: int) -> None: + def initialize_cache(self, num_tpu_blocks: int) -> None: """Initialize the KV cache by invoking the underlying worker. """ # NOTE: This is logged in the executor because there can be >1 worker # with other executors. We could log in the engine level, but work # remains to abstract away the device for non-GPU configurations. 
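Stepping back to the Pallas attention changes above: the "Write to KV cache" path scatters each token's key/value into a cache laid out as (num_kv_heads, num_blocks, block_size, head_size), addressed by slot = block_number * block_size + block_offset. A naive per-token reference loop (purely illustrative; the actual write_to_kv_cache helper is a fused, XLA-friendly scatter) would be:

    import torch

    def write_to_kv_cache_reference(key, value, key_cache, value_cache, slot_mapping):
        # key/value:    [batch_size, seq_len, num_kv_heads, head_size]
        # caches:       [num_kv_heads, num_blocks, block_size, head_size]
        # slot_mapping: one slot per token, slot = block_number * block_size + offset
        num_kv_heads, num_blocks, block_size, head_size = key_cache.shape
        flat_key = key.reshape(-1, num_kv_heads, head_size)
        flat_value = value.reshape(-1, num_kv_heads, head_size)
        for i, slot in enumerate(slot_mapping.reshape(-1).tolist()):
            if slot >= num_blocks * block_size:
                continue  # deliberately out-of-range padding slot; real kernel drops it
            block, offset = divmod(slot, block_size)
            key_cache[:, block, offset, :] = flat_key[i]
            value_cache[:, block, offset, :] = flat_value[i]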
- logger.info("# TPU blocks: %d", num_gpu_blocks) - self.worker.initialize_cache(num_gpu_blocks) + logger.info("# TPU blocks: %d", num_tpu_blocks) + self.worker.initialize_cache(num_tpu_blocks) self.worker.compile_or_warm_up_model() def execute_model( diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 22509028104e3..0c7c902e3e772 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -30,7 +30,6 @@ from vllm.v1.core.scheduler import SchedulerOutput MIN_BATCH_SIZE = 8 -BATCH_SIZE_MULTIPLE = 16 logger = init_logger(__name__) @@ -53,6 +52,8 @@ def __init__( self.prompt_adapter_config = vllm_config.prompt_adapter_config self.observability_config = vllm_config.observability_config + print(f"{self.scheduler_config.max_num_seqs=}") + model_config = self.model_config cache_config = self.cache_config scheduler_config = self.scheduler_config @@ -80,7 +81,8 @@ def __init__( # Lazy initialization # self.model: nn.Module # Set after load_model - self.kv_caches: List[torch.Tensor] = [] + # List[k_cache, v_cache]s + self.kv_caches: List[Tuple[torch.Tensor, torch.Tensor]] = [] # Request states. self.requests: Dict[str, CachedRequestState] = {} @@ -445,21 +447,16 @@ def _dummy_run( dtype=torch.int64, device=self.device ) - - if is_prompt: - block_tables = None - context_lens = None - else: - block_tables = torch.zeros( - (batch_size, self.max_num_blocks_per_seq), - dtype=torch.int32, - device=self.device, - ) - context_lens = torch.ones( - (batch_size, ), - dtype=torch.int32, - device=self.device, - ) + block_tables = None if is_prompt else torch.zeros( + (batch_size, self.max_num_blocks_per_req), + dtype=torch.int32, + device=self.device, + ) + context_lens = None if is_prompt else torch.ones( + (batch_size, ), + dtype=torch.int32, + device=self.device, + ) attn_metadata = PallasAttentionMetadata( is_prompt=is_prompt, slot_mapping=slot_mapping, @@ -503,9 +500,10 @@ def profile_run(self) -> None: # it is important to create tensors inside the loop, rather than # multiplying the list, to avoid Dynamo from treating them as # tensor aliasing. - dummy_kv_caches = [ - torch.tensor([], dtype=torch.float32, device=self.device) - for _ in range(self.num_attn_layers) + dummy_kv_caches = [( + torch.tensor([], dtype=torch.float32, device=self.device), + torch.tensor([], dtype=torch.float32, device=self.device), + ) for _ in range(self.num_attn_layers) ] # Round to multiple of 16. @@ -546,7 +544,7 @@ def capture_model(self) -> None: # Decode shapes. start = time.time() seq_len = 1 - batch_size = MIN_BATCH_SIZE # Must be in sync with _get_padded_batch_size() + batch_size = 8 # Must be in sync with _get_padded_batch_size() while True: self._dummy_run(batch_size, seq_len, self.kv_caches, is_prompt=False) xm.wait_device_ops() @@ -565,17 +563,21 @@ def initialize_kv_cache(self, num_blocks: int) -> None: kv_cache_shape = PallasAttentionBackend.get_kv_cache_shape( num_blocks, self.block_size, self.num_kv_heads, self.head_size) for _ in range(self.num_attn_layers): - self.kv_caches.append( + self.kv_caches.append(( + torch.zeros(kv_cache_shape, + dtype=self.kv_cache_dtype, + device=self.device), torch.zeros(kv_cache_shape, dtype=self.kv_cache_dtype, - device=self.device)) + device=self.device), + )) def _get_padded_batch_size(self, batch_size: int) -> Optional[int]: # The GMM Pallas kernel requires num_tokens * topk to be a multiple of 16. # To meet this requirement in the simplest way, we set the minimal batch # size to 8 (== MIN_BATCH_SIZE). 
- if batch_size <= MIN_BATCH_SIZE: - return MIN_BATCH_SIZE + if batch_size <= 8: + return 8 else: return ((batch_size + 15) // 16) * 16 @@ -889,7 +891,8 @@ def forward( """ # Skip this in memory profiling at initialization. - if kv_caches[0].numel() > 0: + is_profiling = kv_caches[0][0].numel() == 0 + if not is_profiling: # index_copy_(slot_mapping) only works when the inserted dimension # is 0. However, the KV cache in the Pallas backend has the shape # [num_kv_heads, num_blocks, block_size, head_size]. To make it diff --git a/vllm/v1/worker/tpu_worker.py b/vllm/v1/worker/tpu_worker.py index 411c180a81d4e..e8c5384a27d00 100644 --- a/vllm/v1/worker/tpu_worker.py +++ b/vllm/v1/worker/tpu_worker.py @@ -10,6 +10,7 @@ import vllm.envs as envs from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment) +from vllm.logger import init_logger from vllm.config import CacheConfig, ModelConfig, ParallelConfig, VllmConfig from vllm.model_executor import set_random_seed from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size @@ -19,6 +20,8 @@ if TYPE_CHECKING: from vllm.v1.core.scheduler import SchedulerOutput +logger = init_logger(__name__) + class TPUWorker: def __init__( self, @@ -116,6 +119,10 @@ def determine_num_available_blocks(self) -> Tuple[int, int]: m = xm.get_memory_info(self.device) total_tpu_memory = m["bytes_limit"] peak_memory = m["peak_bytes_used"] # Weights + intermediate activations. + logger.debug("Peak Used: %sGB", + peak_memory // 1024 // 1024 // 1024) + logger.debug("Total Memory: %sGB", + total_tpu_memory // 1024 // 1024 // 1024) cache_block_size = _get_cache_block_size(self.cache_config, self.model_config, @@ -147,6 +154,15 @@ def initialize_cache(self, num_tpu_blocks: int) -> None: self.model_runner.initialize_kv_cache(num_tpu_blocks) + # Get the maximum amount of memory used by the model weights and + # intermediate activations. + xm.mark_step() + xm.wait_device_ops() + m = xm.get_memory_info(self.device) + peak_memory = m["peak_bytes_used"] # Weights + intermediate activations. 
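The bytes_limit / peak_bytes_used values logged in determine_num_available_blocks above feed the usual vLLM sizing rule: whatever memory remains after weights and activations (scaled by the configured utilization) is divided by the per-block KV footprint. _get_cache_block_size itself is not shown in these hunks, so the arithmetic below is only a sketch of the assumed formula (2-byte KV dtype by default):

    def estimate_num_kv_blocks(total_bytes: int, peak_bytes: int,
                               memory_utilization: float, block_size: int,
                               num_kv_heads: int, head_size: int,
                               num_layers: int, dtype_bytes: int = 2) -> int:
        # One block stores `block_size` tokens of both K and V for every layer.
        block_bytes = (2 * block_size * num_kv_heads * head_size
                       * num_layers * dtype_bytes)
        usable_bytes = int(total_bytes * memory_utilization) - peak_bytes
        return max(usable_bytes // block_bytes, 0)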
+ logger.debug("Peak GB Used Post KV Cache: %sGB", + peak_memory // 1024 // 1024 // 1024) + def compile_or_warm_up_model(self) -> None: if not self.model_config.enforce_eager: From 7f8fdee16ae40e5271ffc9e10ded7ec88239e7e7 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Tue, 12 Nov 2024 13:38:32 +0000 Subject: [PATCH 12/33] stash --- vllm/v1/worker/tpu_model_runner.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 0c7c902e3e772..4464dfe539b06 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -29,8 +29,6 @@ if TYPE_CHECKING: from vllm.v1.core.scheduler import SchedulerOutput -MIN_BATCH_SIZE = 8 - logger = init_logger(__name__) class TPUModelRunner: From f7de1b44d26b1e8dcdffdcef4a2a55c6f34f14ab Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Fri, 15 Nov 2024 20:55:14 +0000 Subject: [PATCH 13/33] stash --- vllm/v1/worker/tpu_model_runner.py | 380 ++++++++++++++++------------- 1 file changed, 216 insertions(+), 164 deletions(-) diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 4464dfe539b06..611ed675086e9 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -31,6 +31,21 @@ logger = init_logger(__name__) +_PAD_SLOT_ID = 1_000_000_000 + +@dataclass +class PrefillData: + request_ids: List + token_ids: List + position_ids: List + attn_metadata: List + + def zipped(self): + return zip(self.request_ids, + self.token_ids, + self.position_ids, + self.attn_metadata) + class TPUModelRunner: def __init__( @@ -50,8 +65,6 @@ def __init__( self.prompt_adapter_config = vllm_config.prompt_adapter_config self.observability_config = vllm_config.observability_config - print(f"{self.scheduler_config.max_num_seqs=}") - model_config = self.model_config cache_config = self.cache_config scheduler_config = self.scheduler_config @@ -77,9 +90,7 @@ def __init__( self.num_kv_heads = model_config.get_num_kv_heads(parallel_config) self.head_size = model_config.get_head_size() - # Lazy initialization - # self.model: nn.Module # Set after load_model - # List[k_cache, v_cache]s + # List[k_cache, v_cache] self.kv_caches: List[Tuple[torch.Tensor, torch.Tensor]] = [] # Request states. @@ -93,12 +104,11 @@ def __init__( pin_memory=self.pin_memory, ) - self.input_ids = torch.zeros(self.max_num_tokens, - dtype=torch.int32, - device=self.device) - self.positions = torch.zeros(self.max_num_tokens, - dtype=torch.int64, - device=self.device) + self.prefill_positions = torch.Tensor( + range(self.max_model_len), + dtype=torch.int64, + device="cpu", + ) def _update_states(self, scheduler_output: "SchedulerOutput") -> None: # Remove stopped requests from the cached states. @@ -164,148 +174,196 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: # Update the cached states of the resumed requests. for req_data in scheduler_output.scheduled_resumed_reqs: - req_id = req_data.req_id - req_state = self.requests[req_id] + # TODO: handle preemption. + assert False - req_state.block_ids = req_data.block_ids - req_state.num_computed_tokens = req_data.num_computed_tokens - req_ids_to_add.append(req_id) + # Condense the batched states if there are empty indices. + removed_req_indices = sorted(removed_req_indices, reverse=True) + if removed_req_indices: + self.input_batch.condense(removed_req_indices) # Add the new or resumed requests to the persistent batch. - # The smaller empty indices are filled first. 
- removed_req_indices = sorted(removed_req_indices, reverse=True) + # These are added at the end after the bacth is condensed. + self.num_prefills = len(req_ids_to_add) for req_id in req_ids_to_add: req_state = self.requests[req_id] - if removed_req_indices: - # Fill the empty index. - req_index = removed_req_indices.pop() - else: - # Append to the end. - req_index = None - self.input_batch.add_request(req_state, req_index) + self.input_batch.add_request(req_state, None) - # Condense the batched states if there are empty indices. - if removed_req_indices: - self.input_batch.condense(removed_req_indices) def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens assert total_num_scheduled_tokens > 0 - num_reqs = self.input_batch.num_reqs - assert num_reqs > 0 - # OPTIMIZATION: Start copying the block table first. - # This way, we can overlap the copy with the following CPU operations. - self.input_batch.block_table[:num_reqs].copy_( - self.input_batch.block_table_cpu_tensor[:num_reqs], - non_blocking=True) + num_reqs = self.input_batch.num_reqs + num_decodes = self.input_batch.num_decodes + num_prefills = self.input_batch.num_prefills + + assert num_decodes + num_prefills > 0 + # Get the number of scheduled tokens for each request. # TODO: The Python loop can be slow. Optimize. num_scheduled_tokens = [] max_num_scheduled_tokens = 0 - for req_id in self.input_batch.req_ids[:num_reqs]: + for idx, req_id in enumerate(self.input_batch.req_ids[:num_reqs]): num_tokens = scheduler_output.num_scheduled_tokens[req_id] num_scheduled_tokens.append(num_tokens) max_num_scheduled_tokens = max(max_num_scheduled_tokens, num_tokens) + + # Assert Decodes Are Decodes. + if idx < num_decodes: + assert num_scheduled_tokens == 1 + num_scheduled_tokens = np.array(num_scheduled_tokens, dtype=np.int32) assert max_num_scheduled_tokens > 0 - # Get request indices. - # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2] - indices = np.arange(num_reqs) - req_indices = np.repeat(indices, num_scheduled_tokens) - - # Get batched arange. - # E.g., [2, 5, 3] -> [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] - arange_matrix = np.tile(np.arange(max_num_scheduled_tokens), - (num_reqs, 1)) - mask = arange_matrix < num_scheduled_tokens[:, np.newaxis] - arange = arange_matrix[mask] - - # Get positions. - positions = torch.empty((total_num_scheduled_tokens, ), - dtype=torch.int32, - device="cpu", - pin_memory=self.pin_memory) - positions_np = positions.numpy() - np.add(self.input_batch.num_computed_tokens_cpu[req_indices], - arange, - out=positions_np) - - # Get token indices. - # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] - # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2] - # where M is the max_model_len. - token_indices = positions_np + req_indices * self.max_model_len - token_indices = torch.from_numpy(token_indices) - input_ids = torch.empty((total_num_scheduled_tokens, ), - dtype=torch.int32, - device="cpu", - pin_memory=self.pin_memory) - torch.index_select(torch.from_numpy( - self.input_batch.token_ids_cpu).flatten(), - 0, - token_indices, - out=input_ids) - - # Calculate the slot mapping. 
- block_numbers = self.input_batch.block_table_cpu_tensor.flatten()[ - token_indices // self.block_size] - block_offsets = token_indices % self.block_size - slot_mapping = torch.empty((total_num_scheduled_tokens, ), - dtype=torch.int32, - device="cpu", - pin_memory=self.pin_memory) - torch.add(block_numbers * self.block_size, - block_offsets, - out=slot_mapping) - - # Prepare the attention metadata. - query_start_loc = torch.empty((num_reqs + 1, ), - dtype=torch.int32, - device="cpu", - pin_memory=self.pin_memory) - query_start_loc_np = query_start_loc.numpy() - query_start_loc_np[0] = 0 - np.cumsum(num_scheduled_tokens, out=query_start_loc_np[1:]) - - seq_lens = (self.input_batch.num_computed_tokens_cpu[:num_reqs] + - num_scheduled_tokens) - max_seq_len = seq_lens.max() - seq_start_loc = torch.empty((num_reqs + 1, ), - dtype=torch.int32, - device="cpu", - pin_memory=self.pin_memory) - seq_start_loc_np = seq_start_loc.numpy() - seq_start_loc_np[0] = 0 - np.cumsum(seq_lens, out=seq_start_loc_np[1:]) - - self.input_ids[:total_num_scheduled_tokens].copy_(input_ids, - non_blocking=True) - self.positions[:total_num_scheduled_tokens].copy_(positions, - non_blocking=True) - - query_start_loc = query_start_loc.to(self.device, non_blocking=True) - seq_start_loc = seq_start_loc.to(self.device, non_blocking=True) - slot_mapping = slot_mapping.to(self.device, non_blocking=True).long() - attn_metadata = PallasAttentionMetadata( - num_actual_tokens=total_num_scheduled_tokens, - max_query_len=max_num_scheduled_tokens, - query_start_loc=query_start_loc, - max_seq_len=max_seq_len, - seq_start_loc=seq_start_loc, - block_table=self.input_batch.block_table[:num_reqs], - slot_mapping=slot_mapping, + + # PREFILLS + prefill_request_ids = [] + prefill_token_ids = [] + prefill_position_ids = [] + prefill_attn_metadata = [] + + for prefill_idx in range(num_decodes, num_prefills + num_decodes): + + # Pad to power of 2. + prefill_len = num_scheduled_tokens[prefill_idx] + padded_prefill_len = _get_padded_prefill_len(prefill_len) + assert padded_prefill_len < self.max_model_len + token_ids = self.input_batch.token_ids_cpu[prefill_idx, :padded_prefill_len] + positions = self.prefill_positions[:padded_prefill_len] + + # Block number / offsets for every token. + block_numbers = self.input_batch.block_table_cpu_tensor[prefill_idx, positions // self.block_size] + block_offsets = positions % self.block_size + slot_mapping = block_numbers * self.block_size + block_offsets + slot_mapping[prefill_len:] = _PAD_SLOT_ID + + attn_metadata = PallasAttentionMetadata( + is_prompt=True, + slot_mapping=slot_mapping.to(self.device), + block_tables=None, + context_lens=None, + ) + + prefill_request_ids.append(self.input_batch.req_ids[prefill_idx]) + prefill_token_ids.append(token_ids.to(self.device)) + prefill_position_ids.append(positions.to(self.device)) + prefill_attn_metadata.append(attn_metadata) + + prefill_data = PrefillData( + request_ids=prefill_request_ids, + token_ids=prefill_token_ids, + position_ids=prefill_position_ids, + attn_metadata=prefill_attn_metadata, ) - # NOTE(woosuk): Due to chunked prefills, there can be at most 1 partial - # request in the batch. While we should not sample any token from this - # partial request, we do so for simplicity. We will ignore the sampled - # token from the partial request. - # TODO: Support prompt logprobs. 
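Worked through concretely, the index bookkeeping removed above (it survives, commented out, in the TPU model runner below as a reference) behaves as follows; this is only a standalone numpy illustration of the E.g. comments:

    import numpy as np

    num_scheduled_tokens = np.array([2, 5, 3], dtype=np.int32)
    req_indices = np.repeat(np.arange(3), num_scheduled_tokens)
    # -> [0 0 1 1 1 1 1 2 2 2]
    arange_matrix = np.tile(np.arange(num_scheduled_tokens.max()), (3, 1))
    arange = arange_matrix[arange_matrix < num_scheduled_tokens[:, np.newaxis]]
    # -> [0 1 0 1 2 3 4 0 1 2]
    query_start_loc = np.concatenate(([0], np.cumsum(num_scheduled_tokens)))
    # -> [0 2 7 10]; the last token of each request sits at query_start_loc[1:] - 1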
- logits_indices = query_start_loc[1:] - 1 - return attn_metadata, logits_indices + + return prefill_data, None + + # DECODES + + + # OPTIMIZATION: Start copying the block table first. + # This way, we can overlap the copy with the following CPU operations. + # self.input_batch.block_table[:num_decodes].copy_( + # self.input_batch.block_table_cpu_tensor[:num_decodes], + # non_blocking=True) + + # # Get request indices. + # # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2] + # indices = np.arange(num_reqs) + # req_indices = np.repeat(indices, num_scheduled_tokens) + + # # Get batched arange. + # # E.g., [2, 5, 3] -> [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] + # arange_matrix = np.tile(np.arange(max_num_scheduled_tokens), + # (num_reqs, 1)) + # mask = arange_matrix < num_scheduled_tokens[:, np.newaxis] + # arange = arange_matrix[mask] + + # # Get positions. + # positions = torch.empty((total_num_scheduled_tokens, ), + # dtype=torch.int32, + # device="cpu", + # pin_memory=self.pin_memory) + # positions_np = positions.numpy() + # np.add(self.input_batch.num_computed_tokens_cpu[req_indices], + # arange, + # out=positions_np) + + # # Get token indices. + # # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] + # # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2] + # # where M is the max_model_len. + # token_indices = positions_np + req_indices * self.max_model_len + # token_indices = torch.from_numpy(token_indices) + # input_ids = torch.empty((total_num_scheduled_tokens, ), + # dtype=torch.int32, + # device="cpu", + # pin_memory=self.pin_memory) + # torch.index_select(torch.from_numpy( + # self.input_batch.token_ids_cpu).flatten(), + # 0, + # token_indices, + # out=input_ids) + + # # Calculate the slot mapping. + # block_numbers = self.input_batch.block_table_cpu_tensor.flatten()[ + # token_indices // self.block_size] + # block_offsets = token_indices % self.block_size + # slot_mapping = torch.empty((total_num_scheduled_tokens, ), + # dtype=torch.int32, + # device="cpu", + # pin_memory=self.pin_memory) + # torch.add(block_numbers * self.block_size, + # block_offsets, + # out=slot_mapping) + + # # Prepare the attention metadata. 
+ # query_start_loc = torch.empty((num_reqs + 1, ), + # dtype=torch.int32, + # device="cpu", + # pin_memory=self.pin_memory) + # query_start_loc_np = query_start_loc.numpy() + # query_start_loc_np[0] = 0 + # np.cumsum(num_scheduled_tokens, out=query_start_loc_np[1:]) + + # seq_lens = (self.input_batch.num_computed_tokens_cpu[:num_reqs] + + # num_scheduled_tokens) + # max_seq_len = seq_lens.max() + # seq_start_loc = torch.empty((num_reqs + 1, ), + # dtype=torch.int32, + # device="cpu", + # pin_memory=self.pin_memory) + # seq_start_loc_np = seq_start_loc.numpy() + # seq_start_loc_np[0] = 0 + # np.cumsum(seq_lens, out=seq_start_loc_np[1:]) + + # self.input_ids[:total_num_scheduled_tokens].copy_(input_ids, + # non_blocking=True) + # self.positions[:total_num_scheduled_tokens].copy_(positions, + # non_blocking=True) + + # query_start_loc = query_start_loc.to(self.device, non_blocking=True) + # seq_start_loc = seq_start_loc.to(self.device, non_blocking=True) + # slot_mapping = slot_mapping.to(self.device, non_blocking=True).long() + # attn_metadata = PallasAttentionMetadata( + # num_actual_tokens=total_num_scheduled_tokens, + # max_query_len=max_num_scheduled_tokens, + # query_start_loc=query_start_loc, + # max_seq_len=max_seq_len, + # seq_start_loc=seq_start_loc, + # block_table=self.input_batch.block_table[:num_reqs], + # slot_mapping=slot_mapping, + # ) + # # NOTE(woosuk): Due to chunked prefills, there can be at most 1 partial + # # request in the batch. While we should not sample any token from this + # # partial request, we do so for simplicity. We will ignore the sampled + # # token from the partial request. + # # TODO: Support prompt logprobs. + # logits_indices = query_start_loc[1:] - 1 + # return attn_metadata, logits_indices def _prepare_sampling( self, @@ -328,34 +386,29 @@ def execute_model( scheduler_output: "SchedulerOutput", ) -> ModelRunnerOutput: self._update_states(scheduler_output) - attn_metadata, logits_indices = self._prepare_inputs(scheduler_output) - num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens - if True: - # Use piecewise CUDA graphs. - # Add padding to the batch size. - num_input_tokens = self._get_padded_batch_size( - num_scheduled_tokens) - else: - # Eager mode. - num_input_tokens = num_scheduled_tokens - - with set_forward_context(attn_metadata): - hidden_states = self.model( - input_ids=self.input_ids[:num_input_tokens], - positions=self.positions[:num_input_tokens], - kv_caches=self.kv_caches, - attn_metadata=None, - ) - hidden_states = hidden_states[:num_scheduled_tokens] - hidden_states = hidden_states[logits_indices] - logits = self.model.compute_logits(hidden_states, None) + prefill_data, decode_data = self._prepare_inputs(scheduler_output) - # Sample the next token and get logprobs if needed. - sampling_metadata = self._prepare_sampling(scheduler_output) - sampler_output = self.model.sample( - logits=logits, - sampling_metadata=sampling_metadata, - ) + for req_id, token_ids, position_ids, attn_metadata in prefill_data.zipped(): + token_id = self.model(token_ids, + position_ids, + attn_metadata) + + req_state = self.requests[req_id] + seq_len = (req_state.num_computed_tokens + + scheduler_output.num_scheduled_tokens[req_id]) + + # No chunked prefill so far. + assert seq_len == req_state.num_tokens + # Append the sampled token to the output token ids. 
+ req_idx = self.input_batch.req_id_to_index[req_id] + self.input_batch.token_ids_cpu[req_idx, seq_len] = token_id + req_state.output_token_ids.append(token_id) + + + + + + # NOTE: CPU-GPU synchronization happens here. sampled_token_ids = sampler_output.sampled_token_ids.cpu() @@ -381,14 +434,7 @@ def execute_model( # This relies on cuda-specific torch-internal impl details generator.set_offset(generator.get_offset() - 4) - if sampler_output.logprob_token_ids is None: - logprob_token_ids = None - else: - logprob_token_ids = sampler_output.logprob_token_ids.cpu() - if sampler_output.logprobs is None: - logprobs = None - else: - logprobs = sampler_output.logprobs.cpu() + model_runner_output = ModelRunnerOutput( req_ids=self.input_batch.req_ids[:num_reqs], req_id_to_index=self.input_batch.req_id_to_index, @@ -672,6 +718,8 @@ def __init__( self.num_logprobs: Dict[str, int] = {} self.prompt_logprob_reqs: Set[str] = set() + self.num_prefills = 0 + def add_request( self, request: "CachedRequestState", @@ -816,6 +864,10 @@ def make_sampling_metadata( def num_reqs(self) -> int: return len(self.req_id_to_index) + @property + def num_decodes(self) -> int: + return self.num_reqs - self.num_prefills + @property def all_greedy(self) -> bool: return len(self.random_reqs) == 0 From 5de1d9f0c8d9e0990ed4f2507d075e2b3ab530c1 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Sat, 16 Nov 2024 00:05:07 +0000 Subject: [PATCH 14/33] workin for prefill, except when I compile decode cudagraphs? --- vllm/v1/worker/tpu_model_runner.py | 146 ++++++++++++++++------------- 1 file changed, 81 insertions(+), 65 deletions(-) diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 611ed675086e9..68f8e4eb692ef 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -31,17 +31,21 @@ logger = init_logger(__name__) +# Here we utilize the behavior that out-of-bound index is ignored. +# FIXME(woosuk): Find a more reliable way to prevent possible bugs. _PAD_SLOT_ID = 1_000_000_000 @dataclass class PrefillData: request_ids: List + prompt_lens: List token_ids: List position_ids: List attn_metadata: List def zipped(self): return zip(self.request_ids, + self.prompt_lens, self.token_ids, self.position_ids, self.attn_metadata) @@ -104,11 +108,10 @@ def __init__( pin_memory=self.pin_memory, ) - self.prefill_positions = torch.Tensor( + self.prefill_positions = torch.tensor( range(self.max_model_len), - dtype=torch.int64, device="cpu", - ) + ).to(torch.int64).reshape(1,-1) def _update_states(self, scheduler_output: "SchedulerOutput") -> None: # Remove stopped requests from the cached states. @@ -184,7 +187,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: # Add the new or resumed requests to the persistent batch. # These are added at the end after the bacth is condensed. - self.num_prefills = len(req_ids_to_add) + self.input_batch.num_prefills = len(req_ids_to_add) for req_id in req_ids_to_add: req_state = self.requests[req_id] self.input_batch.add_request(req_state, None) @@ -221,24 +224,28 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): # PREFILLS prefill_request_ids = [] + prefill_prompt_lens = [] prefill_token_ids = [] prefill_position_ids = [] prefill_attn_metadata = [] for prefill_idx in range(num_decodes, num_prefills + num_decodes): - # Pad to power of 2. 
- prefill_len = num_scheduled_tokens[prefill_idx] - padded_prefill_len = _get_padded_prefill_len(prefill_len) - assert padded_prefill_len < self.max_model_len - token_ids = self.input_batch.token_ids_cpu[prefill_idx, :padded_prefill_len] - positions = self.prefill_positions[:padded_prefill_len] + prompt_len = num_scheduled_tokens[prefill_idx] + padded_prompt_len = _get_padded_prefill_len(prompt_len) + assert padded_prompt_len < self.max_model_len + + token_ids = torch.tensor( + self.input_batch.token_ids_cpu[prefill_idx, :padded_prompt_len].reshape(1,-1), + device=self.device + ) + positions = self.prefill_positions[:, :padded_prompt_len] # Block number / offsets for every token. - block_numbers = self.input_batch.block_table_cpu_tensor[prefill_idx, positions // self.block_size] + block_numbers = self.input_batch.block_table_cpu_tensor[prefill_idx, positions // self.block_size].reshape(1,-1) block_offsets = positions % self.block_size slot_mapping = block_numbers * self.block_size + block_offsets - slot_mapping[prefill_len:] = _PAD_SLOT_ID + slot_mapping[prompt_len:] = _PAD_SLOT_ID attn_metadata = PallasAttentionMetadata( is_prompt=True, @@ -248,12 +255,14 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): ) prefill_request_ids.append(self.input_batch.req_ids[prefill_idx]) - prefill_token_ids.append(token_ids.to(self.device)) + prefill_prompt_lens.append(prompt_len) + prefill_token_ids.append(token_ids) prefill_position_ids.append(positions.to(self.device)) prefill_attn_metadata.append(attn_metadata) prefill_data = PrefillData( request_ids=prefill_request_ids, + prompt_lens=prefill_prompt_lens, token_ids=prefill_token_ids, position_ids=prefill_position_ids, attn_metadata=prefill_attn_metadata, @@ -386,63 +395,70 @@ def execute_model( scheduler_output: "SchedulerOutput", ) -> ModelRunnerOutput: self._update_states(scheduler_output) - prefill_data, decode_data = self._prepare_inputs(scheduler_output) - - for req_id, token_ids, position_ids, attn_metadata in prefill_data.zipped(): - token_id = self.model(token_ids, - position_ids, - attn_metadata) + prefill_data, _ = self._prepare_inputs(scheduler_output) + for req_id, prompt_len, token_ids, position_ids, attn_metadata in prefill_data.zipped(): + # [padded_prompt_len] + selected_token_ids = self.model( + token_ids, + position_ids, + attn_metadata, + self.kv_caches, + is_prompt=True) + # TODO: move this into the model. + token_id = selected_token_ids[prompt_len - 1].cpu() + breakpoint() req_state = self.requests[req_id] + + # TODO: prefix caching. + assert req_state.num_computed_tokens == 0 seq_len = (req_state.num_computed_tokens + scheduler_output.num_scheduled_tokens[req_id]) - # No chunked prefill so far. + # TODO: chunked prefill. assert seq_len == req_state.num_tokens + assert prompt_len == seq_len + # Append the sampled token to the output token ids. req_idx = self.input_batch.req_id_to_index[req_id] self.input_batch.token_ids_cpu[req_idx, seq_len] = token_id req_state.output_token_ids.append(token_id) - - - - - - - # NOTE: CPU-GPU synchronization happens here. - sampled_token_ids = sampler_output.sampled_token_ids.cpu() - sampled_token_ids_list = sampled_token_ids.tolist() - # TODO(woosuk): The following loop can be slow since it iterates over - # the requests one by one. Optimize. 
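Isolated from the batching code, the per-prefill slot mapping built above in this patch reduces to the sketch below (flat 1-D tensors for clarity; note that with the [1, padded_len] layout used above, the padding assignment needs the column slice slot_mapping[:, prompt_len:]):

    import torch

    _PAD_SLOT_ID = 1_000_000_000  # deliberately out of range; padded writes are dropped

    def prefill_slot_mapping(block_table_row: torch.Tensor, prompt_len: int,
                             padded_len: int, block_size: int) -> torch.Tensor:
        # block_table_row: physical block ids allocated to this request
        positions = torch.arange(padded_len, dtype=torch.int64)
        block_numbers = block_table_row[positions // block_size]
        block_offsets = positions % block_size
        slot_mapping = block_numbers * block_size + block_offsets
        # Tokens past the real prompt length must not touch the KV cache.
        slot_mapping[prompt_len:] = _PAD_SLOT_ID
        return slot_mapping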
- num_reqs = self.input_batch.num_reqs - for i, req_id in enumerate(self.input_batch.req_ids[:num_reqs]): - req_state = self.requests[req_id] - seq_len = (req_state.num_computed_tokens + - scheduler_output.num_scheduled_tokens[req_id]) - assert seq_len <= req_state.num_tokens - if seq_len == req_state.num_tokens: - # Append the sampled token to the output token ids. - token_id = sampled_token_ids_list[i] - self.input_batch.token_ids_cpu[i, seq_len] = token_id - req_state.output_token_ids.append(token_id) - else: - # Ignore the sampled token from the partial request. - # Rewind the generator state as if the token was not sampled. - generator = self.input_batch.generators.get(i) - if generator is not None: - # This relies on cuda-specific torch-internal impl details - generator.set_offset(generator.get_offset() - 4) - - - model_runner_output = ModelRunnerOutput( - req_ids=self.input_batch.req_ids[:num_reqs], - req_id_to_index=self.input_batch.req_id_to_index, - sampled_token_ids_cpu=sampled_token_ids, - logprob_token_ids_cpu=logprob_token_ids, - logprobs_cpu=logprobs, - ) - return model_runner_output + breakpoint() + + # # NOTE: CPU-GPU synchronization happens here. + # sampled_token_ids = sampler_output.sampled_token_ids.cpu() + # sampled_token_ids_list = sampled_token_ids.tolist() + # # TODO(woosuk): The following loop can be slow since it iterates over + # # the requests one by one. Optimize. + # num_reqs = self.input_batch.num_reqs + # for i, req_id in enumerate(self.input_batch.req_ids[:num_reqs]): + # req_state = self.requests[req_id] + # seq_len = (req_state.num_computed_tokens + + # scheduler_output.num_scheduled_tokens[req_id]) + # assert seq_len <= req_state.num_tokens + # if seq_len == req_state.num_tokens: + # # Append the sampled token to the output token ids. + # token_id = sampled_token_ids_list[i] + # self.input_batch.token_ids_cpu[i, seq_len] = token_id + # req_state.output_token_ids.append(token_id) + # else: + # # Ignore the sampled token from the partial request. + # # Rewind the generator state as if the token was not sampled. 
+ # generator = self.input_batch.generators.get(i) + # if generator is not None: + # # This relies on cuda-specific torch-internal impl details + # generator.set_offset(generator.get_offset() - 4) + + + # model_runner_output = ModelRunnerOutput( + # req_ids=self.input_batch.req_ids[:num_reqs], + # req_id_to_index=self.input_batch.req_id_to_index, + # sampled_token_ids_cpu=sampled_token_ids, + # logprob_token_ids_cpu=logprob_token_ids, + # logprobs_cpu=logprobs, + # ) + # return model_runner_output def load_model(self) -> None: @@ -573,8 +589,9 @@ def capture_model(self) -> None: while True: self._dummy_run(batch_size, seq_len, self.kv_caches, is_prompt=True) xm.wait_device_ops() + xm.mark_step() logger.info("batch_size: %d, seq_len: %d", batch_size, seq_len) - + break if seq_len >= self.model_config.max_model_len: break num_tokens = batch_size * seq_len @@ -592,6 +609,7 @@ def capture_model(self) -> None: while True: self._dummy_run(batch_size, seq_len, self.kv_caches, is_prompt=False) xm.wait_device_ops() + xm.mark_step() logger.info("batch_size: %d, seq_len: %d", batch_size, seq_len) if batch_size >= self.scheduler_config.max_num_seqs: @@ -916,6 +934,7 @@ def __call__(self, *args, is_prompt: bool, **kwargs): # 1: for prompt # 2: for decode # dispatch to the compiled code directly, skip PyTorch + print(f"{is_prompt=}") if is_prompt: with self.dispatch_to_code(1): return self.forward(*args, **kwargs) @@ -971,12 +990,9 @@ def forward( hidden_states = hidden_states.flatten(0, 1) logits = self.model.compute_logits(hidden_states, None) - # Argmax sampling. + # Greedy sampling. argmax_token_ids = torch.argmax(logits, dim=-1, keepdim=True) - num_samples = 1 - argmax_token_ids = argmax_token_ids.repeat(1, num_samples) - - return argmax_token_ids + return argmax_token_ids.squeeze(dim=1) def _get_padded_prefill_len(x: int) -> int: From 15a2f74dfcf1d1c4dd530cee913450b8f81e5dc6 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Sat, 16 Nov 2024 01:18:56 +0000 Subject: [PATCH 15/33] working! It was the type of the position ids! --- vllm/v1/attention/backends/pallas.py | 4 ++-- vllm/v1/worker/tpu_model_runner.py | 11 ++++------- 2 files changed, 6 insertions(+), 9 deletions(-) diff --git a/vllm/v1/attention/backends/pallas.py b/vllm/v1/attention/backends/pallas.py index 936b662caab5b..b80f402d2669f 100644 --- a/vllm/v1/attention/backends/pallas.py +++ b/vllm/v1/attention/backends/pallas.py @@ -44,8 +44,8 @@ class PallasAttentionMetadata: is_prompt: bool slot_mapping: torch.Tensor - block_tables: Optional[torch.Tensor] - context_lens: Optional[torch.Tensor] + block_tables: Optional[torch.Tensor] = None + context_lens: Optional[torch.Tensor] = None class PallasAttentionImpl(AttentionImpl): diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 68f8e4eb692ef..e4f54f0a55950 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -111,7 +111,7 @@ def __init__( self.prefill_positions = torch.tensor( range(self.max_model_len), device="cpu", - ).to(torch.int64).reshape(1,-1) + ).to(torch.int32).reshape(1,-1) def _update_states(self, scheduler_output: "SchedulerOutput") -> None: # Remove stopped requests from the cached states. 
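The fix in this commit comes down to the dtype of the precomputed positions: they are now built as int32 rather than int64 before being sliced per prompt. A small sketch of the resulting pattern, with an illustrative max_model_len:

    import torch

    max_model_len = 2048  # illustrative value
    # Precompute positions once on CPU as int32, shape [1, max_model_len].
    prefill_positions = torch.arange(
        max_model_len, device="cpu", dtype=torch.int32).reshape(1, -1)
    # Per request, slice to the padded prompt length (e.g. 32).
    positions = prefill_positions[:, :32]  # shape [1, 32], dtype torch.int32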
@@ -246,6 +246,7 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): block_offsets = positions % self.block_size slot_mapping = block_numbers * self.block_size + block_offsets slot_mapping[prompt_len:] = _PAD_SLOT_ID + slot_mapping = slot_mapping.long() attn_metadata = PallasAttentionMetadata( is_prompt=True, @@ -549,7 +550,6 @@ def _dummy_run( kv_caches, is_prompt=is_prompt) - def profile_run(self) -> None: """Profile to measure peak memory during forward pass.""" @@ -589,7 +589,6 @@ def capture_model(self) -> None: while True: self._dummy_run(batch_size, seq_len, self.kv_caches, is_prompt=True) xm.wait_device_ops() - xm.mark_step() logger.info("batch_size: %d, seq_len: %d", batch_size, seq_len) break if seq_len >= self.model_config.max_model_len: @@ -609,7 +608,6 @@ def capture_model(self) -> None: while True: self._dummy_run(batch_size, seq_len, self.kv_caches, is_prompt=False) xm.wait_device_ops() - xm.mark_step() logger.info("batch_size: %d, seq_len: %d", batch_size, seq_len) if batch_size >= self.scheduler_config.max_num_seqs: @@ -947,7 +945,7 @@ def forward( token_ids: torch.Tensor, position_ids: torch.Tensor, attn_metadata: PallasAttentionMetadata, - kv_caches: List[torch.Tensor], + kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], ) -> torch.Tensor: """Executes the forward pass of the model and samples the next token. @@ -960,8 +958,7 @@ def forward( """ # Skip this in memory profiling at initialization. - is_profiling = kv_caches[0][0].numel() == 0 - if not is_profiling: + if kv_caches[0][0].numel() != 0: # index_copy_(slot_mapping) only works when the inserted dimension # is 0. However, the KV cache in the Pallas backend has the shape # [num_kv_heads, num_blocks, block_size, head_size]. To make it From 14b9500e10b00af532a8027f83591d84eaacc2d1 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Sat, 16 Nov 2024 04:27:33 +0000 Subject: [PATCH 16/33] forward pass --- vllm/v1/worker/tpu_model_runner.py | 281 ++++++++++++++--------------- vllm/worker/tpu_model_runner.py | 5 + 2 files changed, 136 insertions(+), 150 deletions(-) diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index e4f54f0a55950..368099da93bc0 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -49,6 +49,13 @@ def zipped(self): self.token_ids, self.position_ids, self.attn_metadata) +@dataclass +class DecodeData: + num_decodes: int + token_ids: torch.Tensor + position_ids: torch.Tensor + attn_metadata: PallasAttentionMetadata + class TPUModelRunner: @@ -203,7 +210,6 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): assert num_decodes + num_prefills > 0 - # Get the number of scheduled tokens for each request. # TODO: The Python loop can be slow. Optimize. num_scheduled_tokens = [] @@ -216,13 +222,12 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): # Assert Decodes Are Decodes. if idx < num_decodes: - assert num_scheduled_tokens == 1 + assert num_tokens == 1 num_scheduled_tokens = np.array(num_scheduled_tokens, dtype=np.int32) assert max_num_scheduled_tokens > 0 - - # PREFILLS + ######################### PREFILLS ######################### prefill_request_ids = [] prefill_prompt_lens = [] prefill_token_ids = [] @@ -269,111 +274,66 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): attn_metadata=prefill_attn_metadata, ) - return prefill_data, None - - # DECODES - - - # OPTIMIZATION: Start copying the block table first. 
- # This way, we can overlap the copy with the following CPU operations. - # self.input_batch.block_table[:num_decodes].copy_( - # self.input_batch.block_table_cpu_tensor[:num_decodes], - # non_blocking=True) - - # # Get request indices. - # # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2] - # indices = np.arange(num_reqs) - # req_indices = np.repeat(indices, num_scheduled_tokens) - - # # Get batched arange. - # # E.g., [2, 5, 3] -> [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] - # arange_matrix = np.tile(np.arange(max_num_scheduled_tokens), - # (num_reqs, 1)) - # mask = arange_matrix < num_scheduled_tokens[:, np.newaxis] - # arange = arange_matrix[mask] - - # # Get positions. - # positions = torch.empty((total_num_scheduled_tokens, ), - # dtype=torch.int32, - # device="cpu", - # pin_memory=self.pin_memory) - # positions_np = positions.numpy() - # np.add(self.input_batch.num_computed_tokens_cpu[req_indices], - # arange, - # out=positions_np) - - # # Get token indices. - # # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] - # # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2] - # # where M is the max_model_len. - # token_indices = positions_np + req_indices * self.max_model_len - # token_indices = torch.from_numpy(token_indices) - # input_ids = torch.empty((total_num_scheduled_tokens, ), - # dtype=torch.int32, - # device="cpu", - # pin_memory=self.pin_memory) - # torch.index_select(torch.from_numpy( - # self.input_batch.token_ids_cpu).flatten(), - # 0, - # token_indices, - # out=input_ids) - - # # Calculate the slot mapping. - # block_numbers = self.input_batch.block_table_cpu_tensor.flatten()[ - # token_indices // self.block_size] - # block_offsets = token_indices % self.block_size - # slot_mapping = torch.empty((total_num_scheduled_tokens, ), - # dtype=torch.int32, - # device="cpu", - # pin_memory=self.pin_memory) - # torch.add(block_numbers * self.block_size, - # block_offsets, - # out=slot_mapping) - - # # Prepare the attention metadata. - # query_start_loc = torch.empty((num_reqs + 1, ), - # dtype=torch.int32, - # device="cpu", - # pin_memory=self.pin_memory) - # query_start_loc_np = query_start_loc.numpy() - # query_start_loc_np[0] = 0 - # np.cumsum(num_scheduled_tokens, out=query_start_loc_np[1:]) - - # seq_lens = (self.input_batch.num_computed_tokens_cpu[:num_reqs] + - # num_scheduled_tokens) - # max_seq_len = seq_lens.max() - # seq_start_loc = torch.empty((num_reqs + 1, ), - # dtype=torch.int32, - # device="cpu", - # pin_memory=self.pin_memory) - # seq_start_loc_np = seq_start_loc.numpy() - # seq_start_loc_np[0] = 0 - # np.cumsum(seq_lens, out=seq_start_loc_np[1:]) - - # self.input_ids[:total_num_scheduled_tokens].copy_(input_ids, - # non_blocking=True) - # self.positions[:total_num_scheduled_tokens].copy_(positions, - # non_blocking=True) - - # query_start_loc = query_start_loc.to(self.device, non_blocking=True) - # seq_start_loc = seq_start_loc.to(self.device, non_blocking=True) - # slot_mapping = slot_mapping.to(self.device, non_blocking=True).long() - # attn_metadata = PallasAttentionMetadata( - # num_actual_tokens=total_num_scheduled_tokens, - # max_query_len=max_num_scheduled_tokens, - # query_start_loc=query_start_loc, - # max_seq_len=max_seq_len, - # seq_start_loc=seq_start_loc, - # block_table=self.input_batch.block_table[:num_reqs], - # slot_mapping=slot_mapping, - # ) - # # NOTE(woosuk): Due to chunked prefills, there can be at most 1 partial - # # request in the batch. While we should not sample any token from this - # # partial request, we do so for simplicity. 
We will ignore the sampled - # # token from the partial request. - # # TODO: Support prompt logprobs. - # logits_indices = query_start_loc[1:] - 1 - # return attn_metadata, logits_indices + if num_decodes == 0: + return prefill_data, None + + ######################### DECODES ######################### + + # PAD FOR STATIC SHAPE + batch_size = _get_padded_batch_size(num_decodes) + + # INDEX FOR EACH SEQUENCE (current location). + index = torch.tensor(self.input_batch.num_computed_tokens_cpu[:num_decodes], + dtype=torch.int64).reshape(-1,1) + + # TOKEN_IDS + token_ids = torch.zeros((batch_size, 1), dtype=torch.int32) + token_ids[:num_decodes] = torch.gather( + input=torch.tensor(self.input_batch.token_ids_cpu), + dim=1, + index=index, + ) + + # POSITION_IDS + position_ids = torch.zeros((batch_size, 1), + dtype=torch.int64) + position_ids[:num_decodes] = index + + # SLOT_MAPPING + slot_mapping = torch.full( + (batch_size, 1), + _PAD_SLOT_ID, + dtype=torch.int64, + ) + block_number = torch.gather( + input=self.input_batch.block_table_cpu_tensor[:num_decodes], + dim=1, + index=(index // self.block_size) + ) + block_offset = index % self.block_size + slot_mapping[:num_decodes] = (block_number + block_offset) + + # BLOCK_TABLE + self.input_batch.block_table[:batch_size].copy_( + self.input_batch.block_table_cpu_tensor[:batch_size]) + + # CONTEXT_LENS + context_lens = torch.zeros(batch_size, dtype=torch.int32) + context_lens[:num_decodes] = (index.reshape(-1) + 1) + + decode_data = DecodeData( + num_decodes=num_decodes, + token_ids=token_ids.to(self.device), + position_ids=position_ids.to(self.device), + attn_metadata=PallasAttentionMetadata( + is_prompt=False, + slot_mapping=slot_mapping.to(self.device), + block_tables=self.input_batch.block_table[:batch_size], + context_lens=context_lens.to(self.device), + ) + ) + + return prefill_data, decode_data def _prepare_sampling( self, @@ -396,19 +356,59 @@ def execute_model( scheduler_output: "SchedulerOutput", ) -> ModelRunnerOutput: self._update_states(scheduler_output) - prefill_data, _ = self._prepare_inputs(scheduler_output) - for req_id, prompt_len, token_ids, position_ids, attn_metadata in prefill_data.zipped(): + + prefill_data, decode_data = self._prepare_inputs(scheduler_output) + + num_reqs = self.input_batch.num_reqs + sampled_token_ids = torch.empty(num_reqs, dtype=torch.int32) + + ########## DECODES ########## + num_decodes = 0 + if decode_data: + num_decodes = decode_data.num_decodes + + selected_token_ids = self.model( + decode_data.token_ids, + decode_data.position_ids, + decode_data.attn_metadata, + self.kv_caches, + is_prompt=False + ) + + breakpoint() + token_ids = selected_token_ids[:num_decodes].cpu() + sampled_token_ids_list = token_ids.tolist() + sampled_token_ids[:num_decodes] = token_ids + + for i, req_id in enumerate(self.input_batch.req_ids[:decode_data.num_decodes]): + req_state = self.requests[req_id] + + # NO CHUNKED PREFILL + assert scheduler_output.num_scheduled_tokens[req_id] == 1 + seq_len = (req_state.num_computed_tokens + + scheduler_output.num_scheduled_tokens[req_id]) + assert seq_len == req_state.num_tokens + + token_id = sampled_token_ids_list[i] + self.input_batch.token_ids_cpu[i, seq_len] = token_id + req_state.output_token_ids.append(token_id) + + ########## PREFILLS ########## + for idx, (req_id, prompt_len, + token_ids, position_ids, + attn_metadata) in enumerate(prefill_data.zipped()): + # [padded_prompt_len] selected_token_ids = self.model( token_ids, position_ids, attn_metadata, self.kv_caches, - 
is_prompt=True) + is_prompt=True + ) # TODO: move this into the model. token_id = selected_token_ids[prompt_len - 1].cpu() - breakpoint() - + sampled_token_ids[num_decodes + idx] = token_id req_state = self.requests[req_id] # TODO: prefix caching. @@ -425,41 +425,14 @@ def execute_model( self.input_batch.token_ids_cpu[req_idx, seq_len] = token_id req_state.output_token_ids.append(token_id) - breakpoint() - - # # NOTE: CPU-GPU synchronization happens here. - # sampled_token_ids = sampler_output.sampled_token_ids.cpu() - # sampled_token_ids_list = sampled_token_ids.tolist() - # # TODO(woosuk): The following loop can be slow since it iterates over - # # the requests one by one. Optimize. - # num_reqs = self.input_batch.num_reqs - # for i, req_id in enumerate(self.input_batch.req_ids[:num_reqs]): - # req_state = self.requests[req_id] - # seq_len = (req_state.num_computed_tokens + - # scheduler_output.num_scheduled_tokens[req_id]) - # assert seq_len <= req_state.num_tokens - # if seq_len == req_state.num_tokens: - # # Append the sampled token to the output token ids. - # token_id = sampled_token_ids_list[i] - # self.input_batch.token_ids_cpu[i, seq_len] = token_id - # req_state.output_token_ids.append(token_id) - # else: - # # Ignore the sampled token from the partial request. - # # Rewind the generator state as if the token was not sampled. - # generator = self.input_batch.generators.get(i) - # if generator is not None: - # # This relies on cuda-specific torch-internal impl details - # generator.set_offset(generator.get_offset() - 4) - - - # model_runner_output = ModelRunnerOutput( - # req_ids=self.input_batch.req_ids[:num_reqs], - # req_id_to_index=self.input_batch.req_id_to_index, - # sampled_token_ids_cpu=sampled_token_ids, - # logprob_token_ids_cpu=logprob_token_ids, - # logprobs_cpu=logprobs, - # ) - # return model_runner_output + model_runner_output = ModelRunnerOutput( + req_ids=self.input_batch.req_ids[:num_reqs], + req_id_to_index=self.input_batch.req_id_to_index, + sampled_token_ids_cpu=sampled_token_ids, + logprob_token_ids_cpu=None, + logprobs_cpu=None, + ) + return model_runner_output def load_model(self) -> None: @@ -932,7 +905,6 @@ def __call__(self, *args, is_prompt: bool, **kwargs): # 1: for prompt # 2: for decode # dispatch to the compiled code directly, skip PyTorch - print(f"{is_prompt=}") if is_prompt: with self.dispatch_to_code(1): return self.forward(*args, **kwargs) @@ -992,6 +964,15 @@ def forward( return argmax_token_ids.squeeze(dim=1) +def _get_padded_batch_size(batch_size: int) -> int: + # The GMM Pallas kernel requires num_tokens * topk to be a multiple of 16. + # To meet this requirement in the simplest way, we set the minimal batch + # size to 8. + if batch_size <= 8: + return 8 + else: + return ((batch_size + 15) // 16) * 16 + def _get_padded_prefill_len(x: int) -> int: # NOTE(woosuk): The pallas FlashAttention kernel requires the sequence # length to be a multiple of 16. 
We pad the prompt length to the nearest diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py index a721186137328..0c73a5dde8be6 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -417,6 +417,11 @@ def _prepare_decode( block_tables=block_tables, context_lens=context_lens, ) + + print(f"{input_tokens.shape=}") + print(f"{input_positions.shape=}") + print(f"{attn_metadata.slot_mapping.shape=}") + print(f"{attn_metadata.context_lens.shape=}") return input_tokens, input_positions, attn_metadata, input_lens def _prepare_sample( From 6eeecb7f100071369d66d655fbbe2df396253b60 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Sat, 16 Nov 2024 04:36:23 +0000 Subject: [PATCH 17/33] correct output for single prompt with --enforce-eager --- vllm/v1/worker/tpu_model_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 368099da93bc0..c8968903b4e11 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -375,7 +375,7 @@ def execute_model( is_prompt=False ) - breakpoint() + # breakpoint() token_ids = selected_token_ids[:num_decodes].cpu() sampled_token_ids_list = token_ids.tolist() sampled_token_ids[:num_decodes] = token_ids From 0b256c296b71ef86c4d8acfbde8883afb076c1bf Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Sat, 16 Nov 2024 04:42:15 +0000 Subject: [PATCH 18/33] end to end passing working for single request with CUDAGraphs! --- vllm/v1/worker/tpu_model_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index c8968903b4e11..65638d5912e64 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -296,7 +296,7 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): # POSITION_IDS position_ids = torch.zeros((batch_size, 1), - dtype=torch.int64) + dtype=torch.int32) position_ids[:num_decodes] = index # SLOT_MAPPING From b44227dee5499b2b6058d4d01804885ec9a4cf22 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Sat, 16 Nov 2024 20:39:32 +0000 Subject: [PATCH 19/33] yay! working with multiple requests! the issue was copy_() does not seem to work for xla --- vllm/v1/attention/backends/pallas.py | 13 ++++---- vllm/v1/core/kv_cache_manager.py | 1 + vllm/v1/worker/tpu_model_runner.py | 43 ++++++++++++++++---------- vllm/v1/worker/tpu_worker.py | 45 ++++++++++++++-------------- vllm/worker/tpu_model_runner.py | 15 ++++++---- 5 files changed, 68 insertions(+), 49 deletions(-) diff --git a/vllm/v1/attention/backends/pallas.py b/vllm/v1/attention/backends/pallas.py index b80f402d2669f..dc976981e7fa3 100644 --- a/vllm/v1/attention/backends/pallas.py +++ b/vllm/v1/attention/backends/pallas.py @@ -127,8 +127,10 @@ def forward( query: shape = [batch_size, seq_len, num_heads * head_size] key: shape = [batch_size, seq_len, num_kv_heads * head_size] value: shape = [batch_size, seq_len, num_kv_heads * head_size] - kv_cache = [2, num_kv_heads, num_blocks, block_size, head_size] - NOTE: kv_cache will be an empty tensor for profiling run + kv_cache[0] = [num_kv_heads, num_blocks, block_size, head_size] + kv_cache[1] = [num_kv_heads, num_blocks, block_size, head_size] + NOTE: kv_cache[0] and kv_cache[1] will be an empty tensor + with shape [0] for profiling run. attn_metadata: Metadata for attention. 
Returns: shape = [batch_size, seq_len, num_heads * head_size] @@ -143,9 +145,6 @@ def forward( "are not implemented for " "PallasAttentionImpl") - # Empty KV cache when profiling (skip write to cache). - is_profiling = kv_cache[0].numel() == 0 - # Unpack batch_size, seq_len, hidden_size = query.shape query = query.view(batch_size, seq_len, self.num_heads, self.head_size) @@ -154,7 +153,7 @@ def forward( self.head_size) # Write to KV cache. - if not is_profiling: + if kv_cache[0].numel() > 0: slot_mapping = attn_metadata.slot_mapping key_cache = kv_cache[0] value_cache = kv_cache[1] @@ -187,7 +186,7 @@ def forward( output = output.permute(0, 2, 1, 3) else: # Decoding run. - assert not is_profiling + assert kv_cache[0].numel() > 0 query = query.squeeze(dim=1) pages_per_compute_block = 16 # TODO(woosuk): Tune this value. diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 38f1c03a4d3ac..a07f477ecbdc1 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -37,6 +37,7 @@ def __init__( # N new empty blocks. self.num_preallocate_tokens = num_preallocate_tokens self.num_preallocate_blocks = cdiv(num_preallocate_tokens, block_size) + self.num_preallocate_blocks = 0 # A Block pool of all kv-cache blocks. self.block_pool: List[KVCacheBlock] = [ diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 65638d5912e64..767aad4817f2c 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -35,6 +35,9 @@ # FIXME(woosuk): Find a more reliable way to prevent possible bugs. _PAD_SLOT_ID = 1_000_000_000 +from transformers import AutoTokenizer +tok = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B-Instruct") + @dataclass class PrefillData: request_ids: List @@ -250,7 +253,7 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): block_numbers = self.input_batch.block_table_cpu_tensor[prefill_idx, positions // self.block_size].reshape(1,-1) block_offsets = positions % self.block_size slot_mapping = block_numbers * self.block_size + block_offsets - slot_mapping[prompt_len:] = _PAD_SLOT_ID + slot_mapping[:, prompt_len:] = _PAD_SLOT_ID slot_mapping = slot_mapping.long() attn_metadata = PallasAttentionMetadata( @@ -266,6 +269,7 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): prefill_position_ids.append(positions.to(self.device)) prefill_attn_metadata.append(attn_metadata) + prefill_data = PrefillData( request_ids=prefill_request_ids, prompt_lens=prefill_prompt_lens, @@ -310,12 +314,12 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): dim=1, index=(index // self.block_size) ) - block_offset = index % self.block_size - slot_mapping[:num_decodes] = (block_number + block_offset) + block_offsets = index % self.block_size + slot_mapping[:num_decodes] = (block_number * self.block_size + block_offsets) # BLOCK_TABLE - self.input_batch.block_table[:batch_size].copy_( - self.input_batch.block_table_cpu_tensor[:batch_size]) + # cannot do a _copy - silently fails (cry) + block_table = self.input_batch.block_table_cpu_tensor[:batch_size] # CONTEXT_LENS context_lens = torch.zeros(batch_size, dtype=torch.int32) @@ -328,7 +332,7 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): attn_metadata=PallasAttentionMetadata( is_prompt=False, slot_mapping=slot_mapping.to(self.device), - block_tables=self.input_batch.block_table[:batch_size], + block_tables=block_table.to(self.device), context_lens=context_lens.to(self.device), ) ) @@ 
-356,9 +360,7 @@ def execute_model( scheduler_output: "SchedulerOutput", ) -> ModelRunnerOutput: self._update_states(scheduler_output) - prefill_data, decode_data = self._prepare_inputs(scheduler_output) - num_reqs = self.input_batch.num_reqs sampled_token_ids = torch.empty(num_reqs, dtype=torch.int32) @@ -374,8 +376,12 @@ def execute_model( self.kv_caches, is_prompt=False ) - - # breakpoint() + # print(decode_data.token_ids) + # print(decode_data.position_ids) + # print(decode_data.attn_metadata) + # print(tok.decode(self.requests["0"].output_token_ids)) + # # breakpoint() + token_ids = selected_token_ids[:num_decodes].cpu() sampled_token_ids_list = token_ids.tolist() sampled_token_ids[:num_decodes] = token_ids @@ -397,8 +403,9 @@ def execute_model( for idx, (req_id, prompt_len, token_ids, position_ids, attn_metadata) in enumerate(prefill_data.zipped()): - + # [padded_prompt_len] + # breakpoint() selected_token_ids = self.model( token_ids, position_ids, @@ -406,8 +413,14 @@ def execute_model( self.kv_caches, is_prompt=True ) + + # print(token_ids) + # print(position_ids) + # print(attn_metadata) + # breakpoint() + # TODO: move this into the model. - token_id = selected_token_ids[prompt_len - 1].cpu() + token_id = selected_token_ids[prompt_len - 1].cpu().item() sampled_token_ids[num_decodes + idx] = token_id req_state = self.requests[req_id] @@ -653,9 +666,9 @@ def __init__( self.req_ids: List[Optional[str]] = [None] * max_num_reqs self.req_id_to_index: Dict[str, int] = {} - self.token_ids_cpu = np.empty((max_num_reqs, max_model_len), + self.token_ids_cpu = np.zeros((max_num_reqs, max_model_len), dtype=np.int32) - self.num_computed_tokens_cpu = np.empty(max_num_reqs, dtype=np.int32) + self.num_computed_tokens_cpu = np.zeros(max_num_reqs, dtype=np.int32) # Attention-related. self.block_table = torch.zeros((max_num_reqs, max_num_blocks_per_req), @@ -930,7 +943,7 @@ def forward( """ # Skip this in memory profiling at initialization. - if kv_caches[0][0].numel() != 0: + if kv_caches[0][0].numel() > 0: # index_copy_(slot_mapping) only works when the inserted dimension # is 0. However, the KV cache in the Pallas backend has the shape # [num_kv_heads, num_blocks, block_size, head_size]. To make it diff --git a/vllm/v1/worker/tpu_worker.py b/vllm/v1/worker/tpu_worker.py index e8c5384a27d00..666d550618b30 100644 --- a/vllm/v1/worker/tpu_worker.py +++ b/vllm/v1/worker/tpu_worker.py @@ -109,28 +109,29 @@ def determine_num_available_blocks(self) -> Tuple[int, int]: by adjusting the `gpu_memory_utilization` parameter. """ - self.model_runner.profile_run() - - # Synchronize before measuring the memory usage. - xm.wait_device_ops() - - # Get the maximum amount of memory used by the model weights and - # intermediate activations. - m = xm.get_memory_info(self.device) - total_tpu_memory = m["bytes_limit"] - peak_memory = m["peak_bytes_used"] # Weights + intermediate activations. - logger.debug("Peak Used: %sGB", - peak_memory // 1024 // 1024 // 1024) - logger.debug("Total Memory: %sGB", - total_tpu_memory // 1024 // 1024 // 1024) - - cache_block_size = _get_cache_block_size(self.cache_config, - self.model_config, - self.parallel_config) - num_tpu_blocks = int( - (total_tpu_memory * self.cache_config.gpu_memory_utilization - - peak_memory) // cache_block_size) - num_tpu_blocks = (max(num_tpu_blocks, 0) // 8) * 8 + # self.model_runner.profile_run() + + # # Synchronize before measuring the memory usage. 
+ # xm.wait_device_ops() + + # # Get the maximum amount of memory used by the model weights and + # # intermediate activations. + # m = xm.get_memory_info(self.device) + # total_tpu_memory = m["bytes_limit"] + # peak_memory = m["peak_bytes_used"] # Weights + intermediate activations. + # logger.debug("Peak Used: %sGB", + # peak_memory // 1024 // 1024 // 1024) + # logger.debug("Total Memory: %sGB", + # total_tpu_memory // 1024 // 1024 // 1024) + + # cache_block_size = _get_cache_block_size(self.cache_config, + # self.model_config, + # self.parallel_config) + # num_tpu_blocks = int( + # (total_tpu_memory * self.cache_config.gpu_memory_utilization - + # peak_memory) // cache_block_size) + # num_tpu_blocks = (max(num_tpu_blocks, 0) // 8) * 8 + return 3144, 0 return num_tpu_blocks, 0 diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py index 0c73a5dde8be6..d85a7d47468bc 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -418,10 +418,6 @@ def _prepare_decode( context_lens=context_lens, ) - print(f"{input_tokens.shape=}") - print(f"{input_positions.shape=}") - print(f"{attn_metadata.slot_mapping.shape=}") - print(f"{attn_metadata.context_lens.shape=}") return input_tokens, input_positions, attn_metadata, input_lens def _prepare_sample( @@ -585,6 +581,10 @@ def execute_model( model_input.num_samples, kv_caches, is_prompt=True) + print(f"{token_ids=}") + print(f"{position_ids=}") + print(f"{attn_metadata=}") + # breakpoint() next_token_ids.append(output_token_ids[0]) start_idx = end_idx @@ -629,6 +629,7 @@ def execute_model( input_lens = model_input.input_lens.to(self.device) for i in range(num_steps): slot_mapping = attn_metadata.slot_mapping + output_token_ids = self.model(token_ids, position_ids, attn_metadata, @@ -638,6 +639,10 @@ def execute_model( model_input.num_samples, kv_caches, is_prompt=False) + print(f"{token_ids=}") + print(f"{position_ids=}") + print(f"{attn_metadata=}") + # breakpoint() self.cached_step_outputs.append(output_token_ids) if i < num_steps - 1: @@ -769,7 +774,7 @@ def forward( logits = self.model.compute_logits(hidden_states, sampling_metadata) # Argmax sampling. - argmax_token_ids = torch.argmax(logits, dim=-1, keepdim=True) + argmax_token_ids = torch.argmax(logits, dim=-1, keepdim=True) argmax_token_ids = argmax_token_ids.repeat(1, num_samples) # Zero temperature means greedy decoding. Avoid division by zero. From 451dfbff924bb3d634261f08c425411dcdb49515 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Sat, 16 Nov 2024 22:22:51 +0000 Subject: [PATCH 20/33] yay! working end to end via lm eval harness! 
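This run goes through the OpenAI-compatible server plus lm-eval-harness, using tests/entrypoints/openai/test_accuracy.py (the diff below only relaxes that test's CUDA-only skip so it also runs on TPU). Roughly, the check it performs looks like the sketch below; the task, filter and threshold values are illustrative rather than copied from the test, and the model name is the one used elsewhere in this series:

    import lm_eval

    TASK = "gsm8k"                        # illustrative
    FILTER = "exact_match,strict-match"   # illustrative
    EXPECTED_VALUE, RTOL = 0.58, 0.03     # illustrative

    results = lm_eval.simple_evaluate(
        model="local-completions",
        model_args=("base_url=http://localhost:8000/v1/completions,"
                    "model=meta-llama/Llama-3.2-3B-Instruct,"
                    "num_concurrent=32"),
        tasks=[TASK],
    )
    measured_value = results["results"][TASK][FILTER]
    assert EXPECTED_VALUE - RTOL < measured_value < EXPECTED_VALUE + RTOL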
--- tests/entrypoints/openai/test_accuracy.py | 2 +- vllm/v1/core/scheduler.py | 16 ++++++---- vllm/v1/engine/async_llm.py | 4 +++ vllm/v1/engine/core.py | 5 +-- vllm/v1/worker/tpu_model_runner.py | 39 ++++++++--------------- vllm/v1/worker/tpu_worker.py | 3 +- 6 files changed, 33 insertions(+), 36 deletions(-) diff --git a/tests/entrypoints/openai/test_accuracy.py b/tests/entrypoints/openai/test_accuracy.py index b1d4461d164aa..7097f0db0610f 100644 --- a/tests/entrypoints/openai/test_accuracy.py +++ b/tests/entrypoints/openai/test_accuracy.py @@ -66,7 +66,7 @@ def run_test(more_args): ), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}" -@pytest.mark.skipif(not current_platform.is_cuda(), +@pytest.mark.skipif(not current_platform.is_cuda() and not current_platform.is_tpu(), reason="V1 currently only supported on CUDA") def test_lm_eval_accuracy_v1_engine(monkeypatch): """Run with the V1 Engine.""" diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index d2ba0235aa9a0..fa3ea20e2ba04 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -22,6 +22,10 @@ def __init__( cache_config: CacheConfig, lora_config: Optional[LoRAConfig], ) -> None: + # TODO: properly handle for TPU. + cache_config.enable_prefix_caching = False + scheduler_config.chunked_prefill_enabled = False + self.scheduler_config = scheduler_config self.cache_config = cache_config self.lora_config = lora_config @@ -148,12 +152,12 @@ def schedule(self) -> "SchedulerOutput": num_new_tokens = 1 computed_blocks.pop() - # Disabled Chunking. - if not self.scheduler_config.chunked_prefill_enabled: - if num_new_tokens > token_budget: - break - else: - num_new_tokens = min(num_new_tokens, token_budget) + # If chunked prefill is not enabled, breakout of the loop. 
+ if (not self.scheduler_config.chunked_prefill_enabled and + num_new_tokens > token_budget): + break + + num_new_tokens = min(num_new_tokens, token_budget) assert num_new_tokens > 0 new_blocks = self.kv_cache_manager.allocate_slots( diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index 2d7c58cfea13b..bd537fac50e76 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -10,6 +10,7 @@ from vllm.lora.request import LoRARequest from vllm.outputs import EmbeddingRequestOutput, RequestOutput from vllm.pooling_params import PoolingParams +from vllm.platforms import current_platform from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams from vllm.transformers_utils.tokenizer import AnyTokenizer @@ -20,6 +21,7 @@ from vllm.v1.engine.detokenizer import Detokenizer from vllm.v1.engine.processor import Processor from vllm.v1.executor.gpu_executor import GPUExecutor +from vllm.v1.executor.tpu_executor import TPUExecutor logger = init_logger(__name__) @@ -120,6 +122,8 @@ def shutdown(self): @classmethod def _get_executor_cls(cls, vllm_config: VllmConfig): + if current_platform.is_tpu: + return TPUExecutor return GPUExecutor async def add_request( diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 170df4fce01a7..09f7b0762b6ba 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -44,8 +44,9 @@ def __init__( # vllm_config.scheduler_config.max_num_batched_tokens = 8192 pass elif usage_context == UsageContext.OPENAI_API_SERVER: - vllm_config.scheduler_config.max_num_seqs = 1024 - vllm_config.scheduler_config.max_num_batched_tokens = 2048 + # vllm_config.scheduler_config.max_num_seqs = 1024 + # vllm_config.scheduler_config.max_num_batched_tokens = 2048 + pass # TODO (ywang96): Enable APC by default when VLM supports it. if not vllm_config.model_config.is_multimodal_model: diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 767aad4817f2c..27ab1d55bc7b4 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -1,4 +1,3 @@ -import os import time from dataclasses import dataclass from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple @@ -9,12 +8,9 @@ import torch.distributed import torch.nn as nn import torch_xla.core.xla_model as xm -import torch_xla.runtime as xr -from vllm.compilation.levels import CompilationLevel from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher from vllm.config import VllmConfig -from vllm.forward_context import set_forward_context from vllm.logger import init_logger from vllm.model_executor.model_loader import get_model from vllm.multimodal import MultiModalDataDict @@ -32,11 +28,9 @@ logger = init_logger(__name__) # Here we utilize the behavior that out-of-bound index is ignored. -# FIXME(woosuk): Find a more reliable way to prevent possible bugs. +# FIXME: Find a more reliable way to prevent possible bugs. _PAD_SLOT_ID = 1_000_000_000 -from transformers import AutoTokenizer -tok = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B-Instruct") @dataclass class PrefillData: @@ -187,15 +181,20 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: # Update the cached states of the resumed requests. for req_data in scheduler_output.scheduled_resumed_reqs: - # TODO: handle preemption. 
- assert False + req_id = req_data.req_id + req_state = self.requests[req_id] + + req_state.block_ids = req_data.block_ids + req_state.num_computed_tokens = req_data.num_computed_tokens + req_ids_to_add.append(req_id) + # THIS MOVES ALL THE DECODES TO THE FIRST N IN BATCH. # Condense the batched states if there are empty indices. removed_req_indices = sorted(removed_req_indices, reverse=True) if removed_req_indices: self.input_batch.condense(removed_req_indices) - # Add the new or resumed requests to the persistent batch. + # ALL THE PREFILLS ARE THE LAST M IN THE BATCH. # These are added at the end after the bacth is condensed. self.input_batch.num_prefills = len(req_ids_to_add) for req_id in req_ids_to_add: @@ -226,8 +225,7 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): # Assert Decodes Are Decodes. if idx < num_decodes: assert num_tokens == 1 - - num_scheduled_tokens = np.array(num_scheduled_tokens, dtype=np.int32) + assert max_num_scheduled_tokens > 0 ######################### PREFILLS ######################### @@ -241,7 +239,7 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): # Pad to power of 2. prompt_len = num_scheduled_tokens[prefill_idx] padded_prompt_len = _get_padded_prefill_len(prompt_len) - assert padded_prompt_len < self.max_model_len + assert padded_prompt_len <= self.max_model_len token_ids = torch.tensor( self.input_batch.token_ids_cpu[prefill_idx, :padded_prompt_len].reshape(1,-1), @@ -376,11 +374,6 @@ def execute_model( self.kv_caches, is_prompt=False ) - # print(decode_data.token_ids) - # print(decode_data.position_ids) - # print(decode_data.attn_metadata) - # print(tok.decode(self.requests["0"].output_token_ids)) - # # breakpoint() token_ids = selected_token_ids[:num_decodes].cpu() sampled_token_ids_list = token_ids.tolist() @@ -405,7 +398,6 @@ def execute_model( attn_metadata) in enumerate(prefill_data.zipped()): # [padded_prompt_len] - # breakpoint() selected_token_ids = self.model( token_ids, position_ids, @@ -413,18 +405,14 @@ def execute_model( self.kv_caches, is_prompt=True ) - - # print(token_ids) - # print(position_ids) - # print(attn_metadata) - # breakpoint() - # TODO: move this into the model. token_id = selected_token_ids[prompt_len - 1].cpu().item() sampled_token_ids[num_decodes + idx] = token_id req_state = self.requests[req_id] # TODO: prefix caching. + if req_state.num_computed_tokens > 0: + breakpoint() assert req_state.num_computed_tokens == 0 seq_len = (req_state.num_computed_tokens + scheduler_output.num_scheduled_tokens[req_id]) @@ -576,7 +564,6 @@ def capture_model(self) -> None: self._dummy_run(batch_size, seq_len, self.kv_caches, is_prompt=True) xm.wait_device_ops() logger.info("batch_size: %d, seq_len: %d", batch_size, seq_len) - break if seq_len >= self.model_config.max_model_len: break num_tokens = batch_size * seq_len diff --git a/vllm/v1/worker/tpu_worker.py b/vllm/v1/worker/tpu_worker.py index 666d550618b30..b016922460251 100644 --- a/vllm/v1/worker/tpu_worker.py +++ b/vllm/v1/worker/tpu_worker.py @@ -109,6 +109,8 @@ def determine_num_available_blocks(self) -> Tuple[int, int]: by adjusting the `gpu_memory_utilization` parameter. """ + return 3144, 0 + # self.model_runner.profile_run() # # Synchronize before measuring the memory usage. 
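For reference, the profiling-based sizing that the hard-coded `return 3144, 0` above temporarily bypasses reduces to the following computation (it mirrors the commented-out block in this function, with the memory figures taken as plain numbers):

    def num_kv_cache_blocks(total_tpu_memory: int, peak_memory: int,
                            gpu_memory_utilization: float,
                            cache_block_size: int) -> int:
        # Memory left for the KV cache after weights and activations, under
        # the configured utilization cap, expressed in whole blocks and
        # rounded down to a multiple of 8.
        usable = total_tpu_memory * gpu_memory_utilization - peak_memory
        num_blocks = int(usable // cache_block_size)
        return (max(num_blocks, 0) // 8) * 8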
@@ -131,7 +133,6 @@ def determine_num_available_blocks(self) -> Tuple[int, int]: # (total_tpu_memory * self.cache_config.gpu_memory_utilization - # peak_memory) // cache_block_size) # num_tpu_blocks = (max(num_tpu_blocks, 0) // 8) * 8 - return 3144, 0 return num_tpu_blocks, 0 From d2ae4a52b862bce60a0f3b038d6ce31cc71527c3 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Sat, 16 Nov 2024 22:35:00 +0000 Subject: [PATCH 21/33] we have end to end correctness --- tests/entrypoints/openai/test_accuracy.py | 4 +++- vllm/v1/core/kv_cache_manager.py | 1 - vllm/v1/engine/async_llm.py | 2 +- vllm/v1/engine/core.py | 3 ++- 4 files changed, 6 insertions(+), 4 deletions(-) diff --git a/tests/entrypoints/openai/test_accuracy.py b/tests/entrypoints/openai/test_accuracy.py index 7097f0db0610f..99e3eaca2f2de 100644 --- a/tests/entrypoints/openai/test_accuracy.py +++ b/tests/entrypoints/openai/test_accuracy.py @@ -61,12 +61,14 @@ def run_test(more_args): ) measured_value = results["results"][TASK][FILTER] + print(f"{measured_value=}") assert (measured_value - RTOL < EXPECTED_VALUE and measured_value + RTOL > EXPECTED_VALUE ), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}" -@pytest.mark.skipif(not current_platform.is_cuda() and not current_platform.is_tpu(), +@pytest.mark.skipif(not current_platform.is_cuda() and + not current_platform.is_tpu(), reason="V1 currently only supported on CUDA") def test_lm_eval_accuracy_v1_engine(monkeypatch): """Run with the V1 Engine.""" diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index a07f477ecbdc1..38f1c03a4d3ac 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -37,7 +37,6 @@ def __init__( # N new empty blocks. self.num_preallocate_tokens = num_preallocate_tokens self.num_preallocate_blocks = cdiv(num_preallocate_tokens, block_size) - self.num_preallocate_blocks = 0 # A Block pool of all kv-cache blocks. self.block_pool: List[KVCacheBlock] = [ diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index bd537fac50e76..5b2561ecc98bf 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -31,7 +31,7 @@ class AsyncLLM(EngineClient): def __init__( self, vllm_config: VllmConfig, - executor_class: Type[GPUExecutor], + executor_class: Type[Union[GPUExecutor, TPUExecutor]], log_stats: bool, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, stat_loggers: Optional[Dict[str, StatLoggerBase]] = None, diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 09f7b0762b6ba..95a910b3e1c4c 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -18,6 +18,7 @@ from vllm.v1.engine import (EngineCoreOutput, EngineCoreOutputs, EngineCoreRequest, EngineCoreRequestType) from vllm.v1.executor.gpu_executor import GPUExecutor +from vllm.v1.executor.tpu_executor import TPUExecutor from vllm.v1.request import Request, RequestStatus from vllm.version import __version__ as VLLM_VERSION @@ -34,7 +35,7 @@ class EngineCore: def __init__( self, vllm_config: VllmConfig, - executor_class: Type[GPUExecutor], + executor_class: Type[Union[GPUExecutor, TPUExecutor]], usage_context: UsageContext, ): # Override the configs for V1. 
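Patches 20 and 21 together make V1 pick its executor by platform. A condensed sketch of that dispatch (written with `current_platform.is_tpu` as a call, since it is used as a method elsewhere in this series, e.g. in the updated accuracy test):

    from vllm.platforms import current_platform

    def _get_executor_cls(vllm_config):
        # Pick the V1 executor class based on the detected platform.
        if current_platform.is_tpu():
            from vllm.v1.executor.tpu_executor import TPUExecutor
            return TPUExecutor
        from vllm.v1.executor.gpu_executor import GPUExecutor
        return GPUExecutor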
From 7dd18e0a8b37801a18f7efb089416ae4c247f246 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Sat, 16 Nov 2024 22:36:25 +0000 Subject: [PATCH 22/33] nits --- vllm/worker/tpu_model_runner.py | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py index d85a7d47468bc..4e0f52a6bac39 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -581,10 +581,6 @@ def execute_model( model_input.num_samples, kv_caches, is_prompt=True) - print(f"{token_ids=}") - print(f"{position_ids=}") - print(f"{attn_metadata=}") - # breakpoint() next_token_ids.append(output_token_ids[0]) start_idx = end_idx @@ -629,7 +625,6 @@ def execute_model( input_lens = model_input.input_lens.to(self.device) for i in range(num_steps): slot_mapping = attn_metadata.slot_mapping - output_token_ids = self.model(token_ids, position_ids, attn_metadata, @@ -639,10 +634,6 @@ def execute_model( model_input.num_samples, kv_caches, is_prompt=False) - print(f"{token_ids=}") - print(f"{position_ids=}") - print(f"{attn_metadata=}") - # breakpoint() self.cached_step_outputs.append(output_token_ids) if i < num_steps - 1: @@ -774,7 +765,7 @@ def forward( logits = self.model.compute_logits(hidden_states, sampling_metadata) # Argmax sampling. - argmax_token_ids = torch.argmax(logits, dim=-1, keepdim=True) + argmax_token_ids = torch.argmax(logits, dim=-1, keepdim=True) argmax_token_ids = argmax_token_ids.repeat(1, num_samples) # Zero temperature means greedy decoding. Avoid division by zero. From d89200d016e17425f76b31b6748af85ef80b6c03 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Sat, 16 Nov 2024 23:04:07 +0000 Subject: [PATCH 23/33] updated --- vllm/v1/core/scheduler.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index fa3ea20e2ba04..215afcbe0c9f7 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -158,7 +158,6 @@ def schedule(self) -> "SchedulerOutput": break num_new_tokens = min(num_new_tokens, token_budget) - assert num_new_tokens > 0 new_blocks = self.kv_cache_manager.allocate_slots( request, num_new_tokens, computed_blocks) From 75c44b4b47de99432456dfa9b9fef7553129d863 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Sun, 17 Nov 2024 19:18:44 +0000 Subject: [PATCH 24/33] update to call .cpu() before slicing to avoid recompilation --- benchmarks/benchmark_throughput.py | 7 +++++++ vllm/v1/executor/tpu_executor.py | 4 ++++ vllm/v1/worker/tpu_model_runner.py | 7 ++++--- 3 files changed, 15 insertions(+), 3 deletions(-) diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index 159cf055737ce..3a444489ad26d 100644 --- a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -7,6 +7,7 @@ from typing import List, Optional import torch +# import torch_xla.debug.metrics as met import uvloop from PIL import Image from tqdm import tqdm @@ -149,6 +150,8 @@ def run_vllm( use_beam_search = False + # met.clear_all() + if not use_beam_search: start = time.perf_counter() llm.generate(prompts, sampling_params, use_tqdm=True) @@ -168,6 +171,10 @@ def run_vllm( ignore_eos=True, )) end = time.perf_counter() + + # print(met.metrics_report()) + # print(met.short_metrics_report()) + return end - start diff --git a/vllm/v1/executor/tpu_executor.py b/vllm/v1/executor/tpu_executor.py index 03ba0ca72359f..910fd8c2f5583 100644 --- a/vllm/v1/executor/tpu_executor.py +++ b/vllm/v1/executor/tpu_executor.py @@ -8,6 
+8,7 @@ logger = init_logger(__name__) +# import torch_xla.debug.profiler as xp class TPUExecutor: @@ -28,6 +29,8 @@ def __init__(self, vllm_config: VllmConfig) -> None: self.worker.initialize() self.worker.load_model() + # self.server = xp.start_server(9012) + def _create_worker( self, local_rank: int = 0, @@ -67,6 +70,7 @@ def execute_model( self, scheduler_output, ) -> ModelRunnerOutput: + # xp.trace_detached('localhost:9012', "./profiles") output = self.worker.execute_model(scheduler_output) return output diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 27ab1d55bc7b4..4c7e6ec543337 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -375,7 +375,8 @@ def execute_model( is_prompt=False ) - token_ids = selected_token_ids[:num_decodes].cpu() + # NOTE: TPU<>CPU sync happens here. + token_ids = selected_token_ids.cpu()[:num_decodes] sampled_token_ids_list = token_ids.tolist() sampled_token_ids[:num_decodes] = token_ids @@ -405,8 +406,8 @@ def execute_model( self.kv_caches, is_prompt=True ) - # TODO: move this into the model. - token_id = selected_token_ids[prompt_len - 1].cpu().item() + # NOTE: TPU<>CPU sync happens here. + token_id = selected_token_ids.cpu()[prompt_len - 1].item() sampled_token_ids[num_decodes + idx] = token_id req_state = self.requests[req_id] From 58e85eba90f9c438f3db484bf6fc011e20c2aea1 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Sun, 17 Nov 2024 19:26:00 +0000 Subject: [PATCH 25/33] a bit faster --- vllm/v1/worker/tpu_model_runner.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 4c7e6ec543337..868bb41d17365 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -376,6 +376,7 @@ def execute_model( ) # NOTE: TPU<>CPU sync happens here. + # It is important to call .cpu() first to avoid compilation on hotpath. token_ids = selected_token_ids.cpu()[:num_decodes] sampled_token_ids_list = token_ids.tolist() sampled_token_ids[:num_decodes] = token_ids @@ -407,6 +408,7 @@ def execute_model( is_prompt=True ) # NOTE: TPU<>CPU sync happens here. + # It is important to call .cpu() first to avoid compilation on hotpath. token_id = selected_token_ids.cpu()[prompt_len - 1].item() sampled_token_ids[num_decodes + idx] = token_id req_state = self.requests[req_id] From fcf46817bdc9c3af7a79fa3584d31653b4b4f94c Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Sun, 17 Nov 2024 21:23:16 +0000 Subject: [PATCH 26/33] better performance due to better input processing --- vllm/v1/worker/tpu_model_runner.py | 138 +++++++++++++++++------------ 1 file changed, 79 insertions(+), 59 deletions(-) diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 868bb41d17365..ef5dfcb0954c1 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -117,6 +117,7 @@ def __init__( device="cpu", ).to(torch.int32).reshape(1,-1) + def _update_states(self, scheduler_output: "SchedulerOutput") -> None: # Remove stopped requests from the cached states. # Keep the states of the pre-empted requests. @@ -215,59 +216,73 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): # Get the number of scheduled tokens for each request. # TODO: The Python loop can be slow. Optimize. 
num_scheduled_tokens = [] - max_num_scheduled_tokens = 0 for idx, req_id in enumerate(self.input_batch.req_ids[:num_reqs]): num_tokens = scheduler_output.num_scheduled_tokens[req_id] num_scheduled_tokens.append(num_tokens) - max_num_scheduled_tokens = max(max_num_scheduled_tokens, - num_tokens) # Assert Decodes Are Decodes. if idx < num_decodes: assert num_tokens == 1 - - assert max_num_scheduled_tokens > 0 ######################### PREFILLS ######################### + # Prefills run separately, each with shape [1, padded_prompt_len], + # due to lack of variable length flashattention. + # + # Due to static shapes, prefills are padded to the nearest power + # of two, such that we can avoid recompilation. + prefill_request_ids = [] prefill_prompt_lens = [] prefill_token_ids = [] prefill_position_ids = [] prefill_attn_metadata = [] - for prefill_idx in range(num_decodes, num_prefills + num_decodes): - # Pad to power of 2. - prompt_len = num_scheduled_tokens[prefill_idx] + # DECODES are the first num_decodes REQUESTS. + # PREFILLS are the next num_reqs - num_decodes REQUESTS. + for idx in range(num_decodes, num_reqs): + prefill_request_ids.append(self.input_batch.req_ids[idx]) + + # STATIC SHAPE: prefills are padded to the next power of 2. + prompt_len = num_scheduled_tokens[idx] padded_prompt_len = _get_padded_prefill_len(prompt_len) + prefill_prompt_lens.append(prompt_len) assert padded_prompt_len <= self.max_model_len - token_ids = torch.tensor( - self.input_batch.token_ids_cpu[prefill_idx, :padded_prompt_len].reshape(1,-1), - device=self.device + # TOKEN_IDS. + prefill_token_ids.append( + torch.from_numpy( + self.input_batch.token_ids_cpu[idx:idx+1, :padded_prompt_len] + ).to(self.device) ) + + # POSITIONS. positions = self.prefill_positions[:, :padded_prompt_len] + prefill_position_ids.append( + positions.to(self.device) + ) - # Block number / offsets for every token. - block_numbers = self.input_batch.block_table_cpu_tensor[prefill_idx, positions // self.block_size].reshape(1,-1) + # SLOT_MAPPING. + # The "slot" is the "physical index" of a token in the KV cache. + # We look up the block_idx in the block table (logical <> physical map) + # to compute this. + block_numbers = self.input_batch.block_table_cpu_tensor[idx, positions // self.block_size].reshape(1,-1) block_offsets = positions % self.block_size slot_mapping = block_numbers * self.block_size + block_offsets + # Set an out of range value for the padding tokens so that they + # are ignored when inserting into the KV cache. slot_mapping[:, prompt_len:] = _PAD_SLOT_ID slot_mapping = slot_mapping.long() - attn_metadata = PallasAttentionMetadata( - is_prompt=True, - slot_mapping=slot_mapping.to(self.device), - block_tables=None, - context_lens=None, + # ATTN_METADATA. 
+ prefill_attn_metadata.append( + PallasAttentionMetadata( + is_prompt=True, + slot_mapping=slot_mapping.to(self.device), + block_tables=None, + context_lens=None, + ) ) - prefill_request_ids.append(self.input_batch.req_ids[prefill_idx]) - prefill_prompt_lens.append(prompt_len) - prefill_token_ids.append(token_ids) - prefill_position_ids.append(positions.to(self.device)) - prefill_attn_metadata.append(attn_metadata) - - prefill_data = PrefillData( request_ids=prefill_request_ids, prompt_lens=prefill_prompt_lens, @@ -280,53 +295,58 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): return prefill_data, None ######################### DECODES ######################### + # Decodes run as one single padded batch with shape [batch, 1] + # + # We need to set _PAD_SLOT_ID for the padding tokens in the + # slot_mapping, such that the attention KV cache insertion + # logic knows to ignore those indicies. Otherwise, the + # padding data can be dummy since we have a causal mask. + + # PAD FOR STATIC SHAPES. + padded_batch_size = _get_padded_batch_size(num_decodes) + + # POSITIONS. [batch, 1] + # We slice at the end, since we use the positions for gathering. + positions = torch.from_numpy( + self.input_batch.num_computed_tokens_cpu.reshape(-1,1) + ) + index = positions.to(torch.int64) + positions = positions[:padded_batch_size] - # PAD FOR STATIC SHAPE - batch_size = _get_padded_batch_size(num_decodes) - - # INDEX FOR EACH SEQUENCE (current location). - index = torch.tensor(self.input_batch.num_computed_tokens_cpu[:num_decodes], - dtype=torch.int64).reshape(-1,1) - - # TOKEN_IDS - token_ids = torch.zeros((batch_size, 1), dtype=torch.int32) - token_ids[:num_decodes] = torch.gather( - input=torch.tensor(self.input_batch.token_ids_cpu), + # TOKEN_IDS. [batch, 1] + token_ids = torch.gather( + input=torch.from_numpy(self.input_batch.token_ids_cpu), dim=1, index=index, - ) - - # POSITION_IDS - position_ids = torch.zeros((batch_size, 1), - dtype=torch.int32) - position_ids[:num_decodes] = index - - # SLOT_MAPPING - slot_mapping = torch.full( - (batch_size, 1), - _PAD_SLOT_ID, - dtype=torch.int64, - ) + )[:padded_batch_size] + + # SLOT_MAPPING [batch, 1] + # The "slot" is the "physical index" of a token in the KV cache. + # We look up the block_idx in the block table (logical <> physical map) + # to compute this. block_number = torch.gather( - input=self.input_batch.block_table_cpu_tensor[:num_decodes], + input=self.input_batch.block_table_cpu_tensor, dim=1, index=(index // self.block_size) ) block_offsets = index % self.block_size - slot_mapping[:num_decodes] = (block_number * self.block_size + block_offsets) - - # BLOCK_TABLE - # cannot do a _copy - silently fails (cry) - block_table = self.input_batch.block_table_cpu_tensor[:batch_size] + slot_mapping = block_number * self.block_size + block_offsets + # Set an out of range value for the padding tokens so that they + # are ignored when inserting into the KV cache. + slot_mapping[-num_decodes:] = _PAD_SLOT_ID + slot_mapping = slot_mapping[:padded_batch_size] + + # BLOCK_TABLE [batch, max_num_blocks_per_req] + block_table = self.input_batch.block_table_cpu_tensor[:padded_batch_size] - # CONTEXT_LENS - context_lens = torch.zeros(batch_size, dtype=torch.int32) - context_lens[:num_decodes] = (index.reshape(-1) + 1) + # CONTEXT_LENS [batch_size] + context_lens = (positions.reshape(-1) + 1) + # CPU<>TPU sync happens here. 
decode_data = DecodeData( num_decodes=num_decodes, token_ids=token_ids.to(self.device), - position_ids=position_ids.to(self.device), + position_ids=positions.to(self.device), attn_metadata=PallasAttentionMetadata( is_prompt=False, slot_mapping=slot_mapping.to(self.device), From d9dc36ad85373c293448cb6155c253068409a14c Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Sun, 17 Nov 2024 21:28:12 +0000 Subject: [PATCH 27/33] cleanup PR --- benchmarks/benchmark_throughput.py | 2 +- tests/entrypoints/openai/test_accuracy.py | 4 +- vllm/attention/selector.py | 4 +- vllm/config.py | 2 +- vllm/v1/core/scheduler.py | 8 +- vllm/v1/executor/tpu_executor.py | 10 +- vllm/v1/worker/tpu_model_runner.py | 188 ++++++++++------------ vllm/v1/worker/tpu_worker.py | 76 ++++----- vllm/worker/tpu_model_runner.py | 1 - 9 files changed, 133 insertions(+), 162 deletions(-) diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index 3a444489ad26d..e64241c1a6e3c 100644 --- a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -171,7 +171,7 @@ def run_vllm( ignore_eos=True, )) end = time.perf_counter() - + # print(met.metrics_report()) # print(met.short_metrics_report()) diff --git a/tests/entrypoints/openai/test_accuracy.py b/tests/entrypoints/openai/test_accuracy.py index 99e3eaca2f2de..d9ded25ee9163 100644 --- a/tests/entrypoints/openai/test_accuracy.py +++ b/tests/entrypoints/openai/test_accuracy.py @@ -67,8 +67,8 @@ def run_test(more_args): ), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}" -@pytest.mark.skipif(not current_platform.is_cuda() and - not current_platform.is_tpu(), +@pytest.mark.skipif(not current_platform.is_cuda() + and not current_platform.is_tpu(), reason="V1 currently only supported on CUDA") def test_lm_eval_accuracy_v1_engine(monkeypatch): """Run with the V1 Engine.""" diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 32100f9fd5f16..9919c31ef8ab2 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -237,8 +237,8 @@ def which_attn_to_use(head_size: int, return _Backend.IPEX if current_platform.is_tpu(): - if (selected_backend != _Backend.PALLAS and - selected_backend != _Backend.PALLAS_VLLM_V1): + if (selected_backend != _Backend.PALLAS + and selected_backend != _Backend.PALLAS_VLLM_V1): logger.info("Cannot use %s backend on TPU.", selected_backend) if use_v1: return _Backend.PALLAS_VLLM_V1 diff --git a/vllm/config.py b/vllm/config.py index 0d5ed5dc51e48..6760bcbc24c05 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1225,7 +1225,7 @@ def __init__(self, device: str = "auto") -> None: # Some device types require processing inputs on CPU if self.device_type in ["neuron", "openvino"]: self.device = torch.device("cpu") - # Device initialization should happen after initializing the + # Device initialization should happen after initializing the # distributed runtime. elif self.device_type in ["tpu"]: self.device = None diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index 215afcbe0c9f7..98eeb07ba5cfb 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -151,12 +151,12 @@ def schedule(self) -> "SchedulerOutput": num_computed_tokens -= 1 num_new_tokens = 1 computed_blocks.pop() - + # If chunked prefill is not enabled, breakout of the loop. 
- if (not self.scheduler_config.chunked_prefill_enabled and - num_new_tokens > token_budget): + if (not self.scheduler_config.chunked_prefill_enabled + and num_new_tokens > token_budget): break - + num_new_tokens = min(num_new_tokens, token_budget) assert num_new_tokens > 0 new_blocks = self.kv_cache_manager.allocate_slots( diff --git a/vllm/v1/executor/tpu_executor.py b/vllm/v1/executor/tpu_executor.py index 910fd8c2f5583..5e6e63086946d 100644 --- a/vllm/v1/executor/tpu_executor.py +++ b/vllm/v1/executor/tpu_executor.py @@ -10,6 +10,7 @@ # import torch_xla.debug.profiler as xp + class TPUExecutor: def __init__(self, vllm_config: VllmConfig) -> None: @@ -32,11 +33,10 @@ def __init__(self, vllm_config: VllmConfig) -> None: # self.server = xp.start_server(9012) def _create_worker( - self, - local_rank: int = 0, - rank: int = 0, - distributed_init_method: Optional[str] = None - ) -> TPUWorker: + self, + local_rank: int = 0, + rank: int = 0, + distributed_init_method: Optional[str] = None) -> TPUWorker: """Return worker init args for a given rank.""" if distributed_init_method is None: diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index ef5dfcb0954c1..a57387c2797bd 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -41,18 +41,17 @@ class PrefillData: attn_metadata: List def zipped(self): - return zip(self.request_ids, - self.prompt_lens, - self.token_ids, - self.position_ids, - self.attn_metadata) + return zip(self.request_ids, self.prompt_lens, self.token_ids, + self.position_ids, self.attn_metadata) + + @dataclass class DecodeData: num_decodes: int token_ids: torch.Tensor position_ids: torch.Tensor attn_metadata: PallasAttentionMetadata - + class TPUModelRunner: @@ -115,8 +114,7 @@ def __init__( self.prefill_positions = torch.tensor( range(self.max_model_len), device="cpu", - ).to(torch.int32).reshape(1,-1) - + ).to(torch.int32).reshape(1, -1) def _update_states(self, scheduler_output: "SchedulerOutput") -> None: # Remove stopped requests from the cached states. @@ -202,7 +200,6 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: req_state = self.requests[req_id] self.input_batch.add_request(req_state, None) - def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens assert total_num_scheduled_tokens > 0 @@ -210,7 +207,7 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): num_reqs = self.input_batch.num_reqs num_decodes = self.input_batch.num_decodes num_prefills = self.input_batch.num_prefills - + assert num_decodes + num_prefills > 0 # Get the number of scheduled tokens for each request. @@ -219,7 +216,7 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): for idx, req_id in enumerate(self.input_batch.req_ids[:num_reqs]): num_tokens = scheduler_output.num_scheduled_tokens[req_id] num_scheduled_tokens.append(num_tokens) - + # Assert Decodes Are Decodes. if idx < num_decodes: assert num_tokens == 1 @@ -227,7 +224,7 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): ######################### PREFILLS ######################### # Prefills run separately, each with shape [1, padded_prompt_len], # due to lack of variable length flashattention. - # + # # Due to static shapes, prefills are padded to the nearest power # of two, such that we can avoid recompilation. @@ -251,28 +248,27 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): # TOKEN_IDS. 
prefill_token_ids.append( torch.from_numpy( - self.input_batch.token_ids_cpu[idx:idx+1, :padded_prompt_len] - ).to(self.device) - ) + self.input_batch.token_ids_cpu[idx:idx + + 1, :padded_prompt_len]).to( + self.device)) # POSITIONS. positions = self.prefill_positions[:, :padded_prompt_len] - prefill_position_ids.append( - positions.to(self.device) - ) + prefill_position_ids.append(positions.to(self.device)) # SLOT_MAPPING. # The "slot" is the "physical index" of a token in the KV cache. # We look up the block_idx in the block table (logical <> physical map) # to compute this. - block_numbers = self.input_batch.block_table_cpu_tensor[idx, positions // self.block_size].reshape(1,-1) + block_numbers = self.input_batch.block_table_cpu_tensor[ + idx, positions // self.block_size].reshape(1, -1) block_offsets = positions % self.block_size slot_mapping = block_numbers * self.block_size + block_offsets # Set an out of range value for the padding tokens so that they # are ignored when inserting into the KV cache. slot_mapping[:, prompt_len:] = _PAD_SLOT_ID slot_mapping = slot_mapping.long() - + # ATTN_METADATA. prefill_attn_metadata.append( PallasAttentionMetadata( @@ -280,8 +276,7 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): slot_mapping=slot_mapping.to(self.device), block_tables=None, context_lens=None, - ) - ) + )) prefill_data = PrefillData( request_ids=prefill_request_ids, @@ -293,23 +288,22 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): if num_decodes == 0: return prefill_data, None - + ######################### DECODES ######################### # Decodes run as one single padded batch with shape [batch, 1] # - # We need to set _PAD_SLOT_ID for the padding tokens in the + # We need to set _PAD_SLOT_ID for the padding tokens in the # slot_mapping, such that the attention KV cache insertion - # logic knows to ignore those indicies. Otherwise, the + # logic knows to ignore those indicies. Otherwise, the # padding data can be dummy since we have a causal mask. # PAD FOR STATIC SHAPES. padded_batch_size = _get_padded_batch_size(num_decodes) - + # POSITIONS. [batch, 1] # We slice at the end, since we use the positions for gathering. positions = torch.from_numpy( - self.input_batch.num_computed_tokens_cpu.reshape(-1,1) - ) + self.input_batch.num_computed_tokens_cpu.reshape(-1, 1)) index = positions.to(torch.int64) positions = positions[:padded_batch_size] @@ -327,8 +321,7 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): block_number = torch.gather( input=self.input_batch.block_table_cpu_tensor, dim=1, - index=(index // self.block_size) - ) + index=(index // self.block_size)) block_offsets = index % self.block_size slot_mapping = block_number * self.block_size + block_offsets # Set an out of range value for the padding tokens so that they @@ -337,23 +330,22 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): slot_mapping = slot_mapping[:padded_batch_size] # BLOCK_TABLE [batch, max_num_blocks_per_req] - block_table = self.input_batch.block_table_cpu_tensor[:padded_batch_size] - + block_table = self.input_batch.block_table_cpu_tensor[: + padded_batch_size] + # CONTEXT_LENS [batch_size] - context_lens = (positions.reshape(-1) + 1) - + context_lens = (positions.reshape(-1) + 1) + # CPU<>TPU sync happens here. 
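Note: the torch.gather in the decode path above picks, for each request, the token sitting at its current position in the CPU token buffer. A toy version (buffer contents invented):

    import torch

    # token_ids_cpu[req, pos] holds the token at position pos of request req.
    token_ids_cpu = torch.tensor([[11, 12, 13, 14],
                                  [21, 22, 23, 24]])
    # Each decode reads the token at index == num_computed_tokens.
    index = torch.tensor([[3], [1]])
    next_input = torch.gather(token_ids_cpu, dim=1, index=index)
    # next_input == tensor([[14], [22]]), shape [batch, 1]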
- decode_data = DecodeData( - num_decodes=num_decodes, - token_ids=token_ids.to(self.device), - position_ids=positions.to(self.device), - attn_metadata=PallasAttentionMetadata( - is_prompt=False, - slot_mapping=slot_mapping.to(self.device), - block_tables=block_table.to(self.device), - context_lens=context_lens.to(self.device), - ) - ) + decode_data = DecodeData(num_decodes=num_decodes, + token_ids=token_ids.to(self.device), + position_ids=positions.to(self.device), + attn_metadata=PallasAttentionMetadata( + is_prompt=False, + slot_mapping=slot_mapping.to(self.device), + block_tables=block_table.to(self.device), + context_lens=context_lens.to(self.device), + )) return prefill_data, decode_data @@ -372,7 +364,6 @@ def _prepare_sampling( sampling_metadata = self.input_batch.make_sampling_metadata(skip_copy) return sampling_metadata - def execute_model( self, scheduler_output: "SchedulerOutput", @@ -387,23 +378,22 @@ def execute_model( if decode_data: num_decodes = decode_data.num_decodes - selected_token_ids = self.model( - decode_data.token_ids, - decode_data.position_ids, - decode_data.attn_metadata, - self.kv_caches, - is_prompt=False - ) - + selected_token_ids = self.model(decode_data.token_ids, + decode_data.position_ids, + decode_data.attn_metadata, + self.kv_caches, + is_prompt=False) + # NOTE: TPU<>CPU sync happens here. # It is important to call .cpu() first to avoid compilation on hotpath. token_ids = selected_token_ids.cpu()[:num_decodes] sampled_token_ids_list = token_ids.tolist() sampled_token_ids[:num_decodes] = token_ids - for i, req_id in enumerate(self.input_batch.req_ids[:decode_data.num_decodes]): + for i, req_id in enumerate( + self.input_batch.req_ids[:decode_data.num_decodes]): req_state = self.requests[req_id] - + # NO CHUNKED PREFILL assert scheduler_output.num_scheduled_tokens[req_id] == 1 seq_len = (req_state.num_computed_tokens + @@ -415,18 +405,15 @@ def execute_model( req_state.output_token_ids.append(token_id) ########## PREFILLS ########## - for idx, (req_id, prompt_len, - token_ids, position_ids, + for idx, (req_id, prompt_len, token_ids, position_ids, attn_metadata) in enumerate(prefill_data.zipped()): # [padded_prompt_len] - selected_token_ids = self.model( - token_ids, - position_ids, - attn_metadata, - self.kv_caches, - is_prompt=True - ) + selected_token_ids = self.model(token_ids, + position_ids, + attn_metadata, + self.kv_caches, + is_prompt=True) # NOTE: TPU<>CPU sync happens here. # It is important to call .cpu() first to avoid compilation on hotpath. token_id = selected_token_ids.cpu()[prompt_len - 1].item() @@ -437,7 +424,7 @@ def execute_model( if req_state.num_computed_tokens > 0: breakpoint() assert req_state.num_computed_tokens == 0 - seq_len = (req_state.num_computed_tokens + + seq_len = (req_state.num_computed_tokens + scheduler_output.num_scheduled_tokens[req_id]) # TODO: chunked prefill. 
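Note: a minimal sketch of the pattern behind the ".cpu() first" comments in execute_model, assuming torch_xla's lazy-execution semantics. The tensor here is a plain CPU stand-in for the device output, so the snippet only illustrates the host-side slicing order.

    import torch

    padded_batch_size, num_decodes = 8, 3  # num_decodes changes every step
    # On TPU this would be a lazily evaluated XLA tensor with a fixed padded
    # shape; slicing it on-device with a varying bound would trigger a fresh
    # compilation for each new value of num_decodes.
    selected_token_ids = torch.arange(padded_batch_size)

    token_ids = selected_token_ids.cpu()        # device->host sync happens here
    sampled = token_ids[:num_decodes].tolist()  # dynamic slicing stays on CPU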
@@ -481,30 +468,19 @@ def load_model(self) -> None: xm.wait_device_ops() self.model = ModelWrapper(model) - def _dummy_run( - self, - batch_size: int, - seq_len: int, - kv_caches: List[torch.Tensor], - is_prompt: bool - ) -> None: + def _dummy_run(self, batch_size: int, seq_len: int, + kv_caches: List[torch.Tensor], is_prompt: bool) -> None: """Dummy warmup run for memory usage and graph compilation.""" - input_ids = torch.zeros( - (batch_size, seq_len), - dtype=torch.int32, - device=self.device - ) - position_ids = torch.zeros( - (batch_size, seq_len), - dtype=torch.int32, - device=self.device - ) - slot_mapping = torch.zeros( - (batch_size, seq_len), - dtype=torch.int64, - device=self.device - ) + input_ids = torch.zeros((batch_size, seq_len), + dtype=torch.int32, + device=self.device) + position_ids = torch.zeros((batch_size, seq_len), + dtype=torch.int32, + device=self.device) + slot_mapping = torch.zeros((batch_size, seq_len), + dtype=torch.int64, + device=self.device) block_tables = None if is_prompt else torch.zeros( (batch_size, self.max_num_blocks_per_req), dtype=torch.int32, @@ -521,7 +497,7 @@ def _dummy_run( block_tables=block_tables, context_lens=context_lens, ) - + # NOTE: There are two stages of compilation: torch.compile and # XLA compilation. Using `mark_dynamic` can reduce the torch.compile # overhead by reusing the FX graph for different shapes. @@ -560,31 +536,31 @@ def profile_run(self) -> None: dummy_kv_caches = [( torch.tensor([], dtype=torch.float32, device=self.device), torch.tensor([], dtype=torch.float32, device=self.device), - ) for _ in range(self.num_attn_layers) - ] + ) for _ in range(self.num_attn_layers)] # Round to multiple of 16. seq_len = (self.max_num_tokens + 15) // 16 * 16 # Run empty forward. - self._dummy_run( - batch_size=1, - seq_len=seq_len, - kv_caches=dummy_kv_caches, - is_prompt=True) - + self._dummy_run(batch_size=1, + seq_len=seq_len, + kv_caches=dummy_kv_caches, + is_prompt=True) def capture_model(self) -> None: """Compile the model.""" - + logger.info("Compiling the model with different input shapes.") - + # Prefill shapes. 
start = time.perf_counter() for batch_size in [1]: seq_len = 16 while True: - self._dummy_run(batch_size, seq_len, self.kv_caches, is_prompt=True) + self._dummy_run(batch_size, + seq_len, + self.kv_caches, + is_prompt=True) xm.wait_device_ops() logger.info("batch_size: %d, seq_len: %d", batch_size, seq_len) if seq_len >= self.model_config.max_model_len: @@ -602,7 +578,10 @@ def capture_model(self) -> None: seq_len = 1 batch_size = 8 # Must be in sync with _get_padded_batch_size() while True: - self._dummy_run(batch_size, seq_len, self.kv_caches, is_prompt=False) + self._dummy_run(batch_size, + seq_len, + self.kv_caches, + is_prompt=False) xm.wait_device_ops() logger.info("batch_size: %d, seq_len: %d", batch_size, seq_len) @@ -613,7 +592,6 @@ def capture_model(self) -> None: end = time.time() logger.info("Compilation for decode done in %.2f s.", end - start) - def initialize_kv_cache(self, num_blocks: int) -> None: assert len(self.kv_caches) == 0 kv_cache_shape = PallasAttentionBackend.get_kv_cache_shape( @@ -621,8 +599,8 @@ def initialize_kv_cache(self, num_blocks: int) -> None: for _ in range(self.num_attn_layers): self.kv_caches.append(( torch.zeros(kv_cache_shape, - dtype=self.kv_cache_dtype, - device=self.device), + dtype=self.kv_cache_dtype, + device=self.device), torch.zeros(kv_cache_shape, dtype=self.kv_cache_dtype, device=self.device), @@ -908,6 +886,7 @@ def no_logprob(self) -> bool: def no_prompt_logprob(self) -> bool: return len(self.prompt_logprob_reqs) == 0 + class ModelWrapper(TorchCompileWrapperWithCustomDispatcher): def __init__(self, model: nn.Module): @@ -996,6 +975,7 @@ def _get_padded_batch_size(batch_size: int) -> int: else: return ((batch_size + 15) // 16) * 16 + def _get_padded_prefill_len(x: int) -> int: # NOTE(woosuk): The pallas FlashAttention kernel requires the sequence # length to be a multiple of 16. We pad the prompt length to the nearest diff --git a/vllm/v1/worker/tpu_worker.py b/vllm/v1/worker/tpu_worker.py index b016922460251..ecdc88745fa59 100644 --- a/vllm/v1/worker/tpu_worker.py +++ b/vllm/v1/worker/tpu_worker.py @@ -22,14 +22,11 @@ logger = init_logger(__name__) + class TPUWorker: - def __init__( - self, - vllm_config: VllmConfig, - local_rank: int, - rank: int, - distributed_init_method: str - ): + + def __init__(self, vllm_config: VllmConfig, local_rank: int, rank: int, + distributed_init_method: str): self.vllm_config = vllm_config self.model_config = vllm_config.model_config self.cache_config = vllm_config.cache_config @@ -41,7 +38,7 @@ def __init__( self.speculative_config = vllm_config.speculative_config self.prompt_adapter_config = vllm_config.prompt_adapter_config self.observability_config = vllm_config.observability_config - + self.local_rank = local_rank self.rank = rank self.distributed_init_method = distributed_init_method @@ -91,11 +88,10 @@ def initialize(self): per_rank_path = os.path.join(envs.VLLM_XLA_CACHE_PATH, f"tp{world_size}_rank{rank}") xr.initialize_cache(per_rank_path, readonly=False) - + def load_model(self): self.model_runner.load_model() - def determine_num_available_blocks(self) -> Tuple[int, int]: """Profiles the peak memory usage of the model to determine how many KV blocks may be allocated without OOMs. @@ -108,33 +104,30 @@ def determine_num_available_blocks(self) -> Tuple[int, int]: You may limit the usage of GPU memory by adjusting the `gpu_memory_utilization` parameter. """ - - return 3144, 0 - - # self.model_runner.profile_run() - - # # Synchronize before measuring the memory usage. 
- # xm.wait_device_ops() - - # # Get the maximum amount of memory used by the model weights and - # # intermediate activations. - # m = xm.get_memory_info(self.device) - # total_tpu_memory = m["bytes_limit"] - # peak_memory = m["peak_bytes_used"] # Weights + intermediate activations. - # logger.debug("Peak Used: %sGB", - # peak_memory // 1024 // 1024 // 1024) - # logger.debug("Total Memory: %sGB", - # total_tpu_memory // 1024 // 1024 // 1024) - - # cache_block_size = _get_cache_block_size(self.cache_config, - # self.model_config, - # self.parallel_config) - # num_tpu_blocks = int( - # (total_tpu_memory * self.cache_config.gpu_memory_utilization - - # peak_memory) // cache_block_size) - # num_tpu_blocks = (max(num_tpu_blocks, 0) // 8) * 8 - return num_tpu_blocks, 0 + self.model_runner.profile_run() + + # Synchronize before measuring the memory usage. + xm.wait_device_ops() + + # Get the maximum amount of memory used by the model weights and + # intermediate activations. + m = xm.get_memory_info(self.device) + total_tpu_memory = m["bytes_limit"] + peak_memory = m["peak_bytes_used"] # Weights + intermediate activations. + logger.debug("Peak Used: %sGB", + peak_memory // 1024 // 1024 // 1024) + logger.debug("Total Memory: %sGB", + total_tpu_memory // 1024 // 1024 // 1024) + + cache_block_size = _get_cache_block_size(self.cache_config, + self.model_config, + self.parallel_config) + num_tpu_blocks = int( + (total_tpu_memory * self.cache_config.gpu_memory_utilization - + peak_memory) // cache_block_size) + num_tpu_blocks = (max(num_tpu_blocks, 0) // 8) * 8 + return num_tpu_blocks, 0 def initialize_cache(self, num_tpu_blocks: int) -> None: """Allocate TPU and CPU KV cache with the specified number of blocks.""" @@ -143,7 +136,7 @@ def initialize_cache(self, num_tpu_blocks: int) -> None: raise ValueError("No available memory for the cache blocks. " "Try increasing `gpu_memory_utilization` when " "initializing the engine.") - + max_seq_len = self.cache_config.block_size * num_tpu_blocks max_model_len = self.model_config.max_model_len if max_model_len > max_seq_len: @@ -161,11 +154,11 @@ def initialize_cache(self, num_tpu_blocks: int) -> None: xm.mark_step() xm.wait_device_ops() m = xm.get_memory_info(self.device) - peak_memory = m["peak_bytes_used"] # Weights + intermediate activations. - logger.debug("Peak GB Used Post KV Cache: %sGB", + peak_memory = m[ + "peak_bytes_used"] # Weights + intermediate activations. + logger.debug("Peak GB Used Post KV Cache: %sGB", peak_memory // 1024 // 1024 // 1024) - def compile_or_warm_up_model(self) -> None: if not self.model_config.enforce_eager: self.model_runner.capture_model() @@ -174,7 +167,6 @@ def compile_or_warm_up_model(self) -> None: # the model initialization and profiling. 
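Note: the block-count arithmetic restored in determine_num_available_blocks works out as below with invented numbers. The real values come from xm.get_memory_info and the model config, and the cache_block_size factors are only indicative.

    # Illustrative only; all numbers are made up.
    total_tpu_memory = 16 * 1024**3      # HBM reported by the runtime
    peak_memory = 6 * 1024**3            # weights + activations after profiling
    gpu_memory_utilization = 0.90
    cache_block_size = 2 * 16 * 8 * 128 * 2 * 28  # K/V, block, kv_heads, head_dim, bytes, layers

    num_tpu_blocks = int(
        (total_tpu_memory * gpu_memory_utilization - peak_memory)
        // cache_block_size)
    num_tpu_blocks = (max(num_tpu_blocks, 0) // 8) * 8  # keep a multiple of 8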
set_random_seed(self.model_config.seed) - def execute_model( self, scheduler_output: "SchedulerOutput", @@ -203,4 +195,4 @@ def _get_cache_block_size( else: dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype] dtype_size = get_dtype_size(dtype) - return dtype_size * total \ No newline at end of file + return dtype_size * total diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py index 4e0f52a6bac39..a721186137328 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -417,7 +417,6 @@ def _prepare_decode( block_tables=block_tables, context_lens=context_lens, ) - return input_tokens, input_positions, attn_metadata, input_lens def _prepare_sample( From 85bc15403644c8ce3e5ef758c42d3515205071b9 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Sun, 17 Nov 2024 21:28:48 +0000 Subject: [PATCH 28/33] cleanup --- benchmarks/benchmark_throughput.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index e64241c1a6e3c..18f21463cc5a3 100644 --- a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -7,7 +7,6 @@ from typing import List, Optional import torch -# import torch_xla.debug.metrics as met import uvloop from PIL import Image from tqdm import tqdm @@ -150,8 +149,6 @@ def run_vllm( use_beam_search = False - # met.clear_all() - if not use_beam_search: start = time.perf_counter() llm.generate(prompts, sampling_params, use_tqdm=True) @@ -172,9 +169,6 @@ def run_vllm( )) end = time.perf_counter() - # print(met.metrics_report()) - # print(met.short_metrics_report()) - return end - start From 25fff99c658178bd70ce92fa7a461781371e3075 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Sun, 17 Nov 2024 21:29:06 +0000 Subject: [PATCH 29/33] cleanup pr --- benchmarks/benchmark_throughput.py | 1 - 1 file changed, 1 deletion(-) diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index 18f21463cc5a3..159cf055737ce 100644 --- a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -168,7 +168,6 @@ def run_vllm( ignore_eos=True, )) end = time.perf_counter() - return end - start From 5a87b9984719fedc8fe49ff790dfc7e225dbd779 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Sun, 17 Nov 2024 21:53:12 +0000 Subject: [PATCH 30/33] formatting --- vllm/v1/attention/backends/pallas.py | 2 +- vllm/v1/engine/async_llm.py | 2 +- vllm/v1/engine/llm_engine.py | 2 +- vllm/v1/worker/tpu_model_runner.py | 163 +++++++++++++-------------- vllm/v1/worker/tpu_worker.py | 8 +- 5 files changed, 86 insertions(+), 91 deletions(-) diff --git a/vllm/v1/attention/backends/pallas.py b/vllm/v1/attention/backends/pallas.py index dc976981e7fa3..b2cdc06ee78cb 100644 --- a/vllm/v1/attention/backends/pallas.py +++ b/vllm/v1/attention/backends/pallas.py @@ -6,7 +6,7 @@ import torch_xla from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionMetadata, AttentionType) + AttentionType) class PallasAttentionBackend(AttentionBackend): diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index 5b2561ecc98bf..080fa46c6d54a 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -9,8 +9,8 @@ from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.outputs import EmbeddingRequestOutput, RequestOutput -from vllm.pooling_params import PoolingParams from vllm.platforms import current_platform +from vllm.pooling_params import PoolingParams 
from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams from vllm.transformers_utils.tokenizer import AnyTokenizer diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index af8f28377f31a..4d419b8f97bfa 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -8,9 +8,9 @@ from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.outputs import RequestOutput +from vllm.platforms import current_platform from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest -from vllm.platforms import current_platform from vllm.sampling_params import SamplingParams from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs from vllm.usage.usage_lib import UsageContext diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index a57387c2797bd..04750d42646db 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -1,7 +1,6 @@ import time from dataclasses import dataclass from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple -from unittest.mock import patch import numpy as np import torch @@ -15,8 +14,7 @@ from vllm.model_executor.model_loader import get_model from vllm.multimodal import MultiModalDataDict from vllm.sampling_params import SamplingParams, SamplingType -from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, cdiv, - is_pin_memory_available) +from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv, is_pin_memory_available from vllm.v1.attention.backends.pallas import (PallasAttentionBackend, PallasAttentionMetadata) from vllm.v1.outputs import ModelRunnerOutput @@ -33,7 +31,8 @@ @dataclass -class PrefillData: +class PrefillInputData: + request_ids: List prompt_lens: List token_ids: List @@ -46,11 +45,12 @@ def zipped(self): @dataclass -class DecodeData: +class DecodeInputData: + num_decodes: int - token_ids: torch.Tensor - position_ids: torch.Tensor - attn_metadata: PallasAttentionMetadata + token_ids: Optional[torch.Tensor] = None + position_ids: Optional[torch.Tensor] = None + attn_metadata: PallasAttentionMetadata = None class TPUModelRunner: @@ -200,33 +200,13 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: req_state = self.requests[req_id] self.input_batch.add_request(req_state, None) - def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): - total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens - assert total_num_scheduled_tokens > 0 - - num_reqs = self.input_batch.num_reqs - num_decodes = self.input_batch.num_decodes - num_prefills = self.input_batch.num_prefills - - assert num_decodes + num_prefills > 0 - - # Get the number of scheduled tokens for each request. - # TODO: The Python loop can be slow. Optimize. - num_scheduled_tokens = [] - for idx, req_id in enumerate(self.input_batch.req_ids[:num_reqs]): - num_tokens = scheduler_output.num_scheduled_tokens[req_id] - num_scheduled_tokens.append(num_tokens) - - # Assert Decodes Are Decodes. - if idx < num_decodes: - assert num_tokens == 1 - - ######################### PREFILLS ######################### - # Prefills run separately, each with shape [1, padded_prompt_len], - # due to lack of variable length flashattention. - # - # Due to static shapes, prefills are padded to the nearest power - # of two, such that we can avoid recompilation. 
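Note: the padding rules referred to by these comments behave roughly like the sketch below. padded_batch_size mirrors the _get_padded_batch_size helper shown later in this file; the prompt-length helper is reconstructed from the comments (pad to the next power of two, floor of 16) and may differ in detail from the real _get_padded_prefill_len.

    def padded_batch_size(batch_size: int) -> int:
        # Minimum of 8, otherwise round up to a multiple of 16
        # (the GMM Pallas kernel wants num_tokens * topk % 16 == 0).
        return 8 if batch_size <= 8 else ((batch_size + 15) // 16) * 16

    def padded_prefill_len(prompt_len: int) -> int:
        # Assumed: pad to the next power of two, with a floor of 16.
        if prompt_len <= 16:
            return 16
        return 1 << (prompt_len - 1).bit_length()

    assert padded_batch_size(3) == 8 and padded_batch_size(20) == 32
    assert padded_prefill_len(5) == 16 and padded_prefill_len(100) == 128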
+ def _prepare_prefill_inputs( + self, + num_scheduled_tokens: List[int], + ) -> PrefillInputData: + # Prefills run separately, each with shape [1, prompt_len], + # due to lack of variable length flashattention, so we + # create a list that will be used in execute_model() prefill_request_ids = [] prefill_prompt_lens = [] @@ -236,6 +216,8 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): # DECODES are the first num_decodes REQUESTS. # PREFILLS are the next num_reqs - num_decodes REQUESTS. + num_reqs = self.input_batch.num_reqs + num_decodes = self.input_batch.num_decodes for idx in range(num_decodes, num_reqs): prefill_request_ids.append(self.input_batch.req_ids[idx]) @@ -246,11 +228,9 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): assert padded_prompt_len <= self.max_model_len # TOKEN_IDS. - prefill_token_ids.append( - torch.from_numpy( - self.input_batch.token_ids_cpu[idx:idx + - 1, :padded_prompt_len]).to( - self.device)) + token_ids = torch.from_numpy(self.input_batch.token_ids_cpu[ + idx, :padded_prompt_len].reshape(-1, 1)) + prefill_token_ids.append(token_ids.to(self.device)) # POSITIONS. positions = self.prefill_positions[:, :padded_prompt_len] @@ -258,7 +238,7 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): # SLOT_MAPPING. # The "slot" is the "physical index" of a token in the KV cache. - # We look up the block_idx in the block table (logical <> physical map) + # Look up the block_idx in the block table (logical<>physical map) # to compute this. block_numbers = self.input_batch.block_table_cpu_tensor[ idx, positions // self.block_size].reshape(1, -1) @@ -278,7 +258,7 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): context_lens=None, )) - prefill_data = PrefillData( + return PrefillInputData( request_ids=prefill_request_ids, prompt_lens=prefill_prompt_lens, token_ids=prefill_token_ids, @@ -286,10 +266,7 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): attn_metadata=prefill_attn_metadata, ) - if num_decodes == 0: - return prefill_data, None - - ######################### DECODES ######################### + def _prepare_decode_inputs(self, num_decodes: int) -> DecodeInputData: # Decodes run as one single padded batch with shape [batch, 1] # # We need to set _PAD_SLOT_ID for the padding tokens in the @@ -297,6 +274,9 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): # logic knows to ignore those indicies. Otherwise, the # padding data can be dummy since we have a causal mask. + if num_decodes == 0: + return DecodeInputData(num_decodes=0) + # PAD FOR STATIC SHAPES. padded_batch_size = _get_padded_batch_size(num_decodes) @@ -316,7 +296,7 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): # SLOT_MAPPING [batch, 1] # The "slot" is the "physical index" of a token in the KV cache. - # We look up the block_idx in the block table (logical <> physical map) + # Look up the block_idx in the block table (logical<>physical map) # to compute this. block_number = torch.gather( input=self.input_batch.block_table_cpu_tensor, @@ -337,17 +317,41 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): context_lens = (positions.reshape(-1) + 1) # CPU<>TPU sync happens here. 
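Note: one line of intuition for context_lens above: a request with N computed tokens is about to write position N, so attention must cover N + 1 tokens. A tiny example with invented lengths:

    import torch

    num_computed_tokens = torch.tensor([5, 17, 40])
    positions = num_computed_tokens.reshape(-1, 1)  # query position per request
    context_lens = positions.reshape(-1) + 1        # tensor([ 6, 18, 41])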
- decode_data = DecodeData(num_decodes=num_decodes, - token_ids=token_ids.to(self.device), - position_ids=positions.to(self.device), - attn_metadata=PallasAttentionMetadata( - is_prompt=False, - slot_mapping=slot_mapping.to(self.device), - block_tables=block_table.to(self.device), - context_lens=context_lens.to(self.device), - )) - - return prefill_data, decode_data + return DecodeInputData(num_decodes=num_decodes, + token_ids=token_ids.to(self.device), + position_ids=positions.to(self.device), + attn_metadata=PallasAttentionMetadata( + is_prompt=False, + slot_mapping=slot_mapping.to(self.device), + block_tables=block_table.to(self.device), + context_lens=context_lens.to(self.device), + )) + + def _prepare_inputs( + self, scheduler_output: "SchedulerOutput" + ) -> Tuple[PrefillInputData, Optional[DecodeInputData]]: + + total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens + assert total_num_scheduled_tokens > 0 + + num_reqs = self.input_batch.num_reqs + num_decodes = self.input_batch.num_decodes + + # Get the number of scheduled tokens for each request. + # TODO: The Python loop can be slow. Optimize. + num_scheduled_tokens = [] + for idx, req_id in enumerate(self.input_batch.req_ids[:num_reqs]): + num_tokens = scheduler_output.num_scheduled_tokens[req_id] + num_scheduled_tokens.append(num_tokens) + + # Assert Decodes Are Decodes. + if idx < num_decodes: + assert num_tokens == 1 + + return ( + self._prepare_prefill_inputs(num_scheduled_tokens), + self._prepare_decode_inputs(num_decodes), + ) def _prepare_sampling( self, @@ -373,11 +377,11 @@ def execute_model( num_reqs = self.input_batch.num_reqs sampled_token_ids = torch.empty(num_reqs, dtype=torch.int32) - ########## DECODES ########## - num_decodes = 0 - if decode_data: - num_decodes = decode_data.num_decodes + ######################### DECODES ######################### + # Decodes run as one single padded batch with shape [batch, 1] + if decode_data.num_decodes > 0: + # FORWARD. selected_token_ids = self.model(decode_data.token_ids, decode_data.position_ids, decode_data.attn_metadata, @@ -385,16 +389,17 @@ def execute_model( is_prompt=False) # NOTE: TPU<>CPU sync happens here. - # It is important to call .cpu() first to avoid compilation on hotpath. - token_ids = selected_token_ids.cpu()[:num_decodes] + # We need to call .cpu() first to avoid recompilation. + token_ids = selected_token_ids.cpu()[:decode_data.num_decodes] sampled_token_ids_list = token_ids.tolist() - sampled_token_ids[:num_decodes] = token_ids + sampled_token_ids[:decode_data.num_decodes] = token_ids + # UPDATE REQUEST STATE. for i, req_id in enumerate( self.input_batch.req_ids[:decode_data.num_decodes]): req_state = self.requests[req_id] - # NO CHUNKED PREFILL + # TODO: ASSERT NO CHUNKED PREFILL. assert scheduler_output.num_scheduled_tokens[req_id] == 1 seq_len = (req_state.num_computed_tokens + scheduler_output.num_scheduled_tokens[req_id]) @@ -404,34 +409,33 @@ def execute_model( self.input_batch.token_ids_cpu[i, seq_len] = token_id req_state.output_token_ids.append(token_id) - ########## PREFILLS ########## + ######################### PREFILLS ######################### for idx, (req_id, prompt_len, token_ids, position_ids, attn_metadata) in enumerate(prefill_data.zipped()): - # [padded_prompt_len] + # FORWARD. selected_token_ids = self.model(token_ids, position_ids, attn_metadata, self.kv_caches, is_prompt=True) + # NOTE: TPU<>CPU sync happens here. - # It is important to call .cpu() first to avoid compilation on hotpath. 
+ # We need to call .cpu() first to avoid recompilation. token_id = selected_token_ids.cpu()[prompt_len - 1].item() - sampled_token_ids[num_decodes + idx] = token_id + sampled_token_ids[decode_data.num_decodes + idx] = token_id req_state = self.requests[req_id] - # TODO: prefix caching. - if req_state.num_computed_tokens > 0: - breakpoint() + # TODO: ASSERT NO PREFIX CACHING. assert req_state.num_computed_tokens == 0 seq_len = (req_state.num_computed_tokens + scheduler_output.num_scheduled_tokens[req_id]) - # TODO: chunked prefill. + # TODO: ASSERT NO CHUNKED PREFILL. assert seq_len == req_state.num_tokens assert prompt_len == seq_len - # Append the sampled token to the output token ids. + # UPDATE REQUEST STATE. req_idx = self.input_batch.req_id_to_index[req_id] self.input_batch.token_ids_cpu[req_idx, seq_len] = token_id req_state.output_token_ids.append(token_id) @@ -606,15 +610,6 @@ def initialize_kv_cache(self, num_blocks: int) -> None: device=self.device), )) - def _get_padded_batch_size(self, batch_size: int) -> Optional[int]: - # The GMM Pallas kernel requires num_tokens * topk to be a multiple of 16. - # To meet this requirement in the simplest way, we set the minimal batch - # size to 8 (== MIN_BATCH_SIZE). - if batch_size <= 8: - return 8 - else: - return ((batch_size + 15) // 16) * 16 - @dataclass class CachedRequestState: diff --git a/vllm/v1/worker/tpu_worker.py b/vllm/v1/worker/tpu_worker.py index ecdc88745fa59..866c1dbf6ea98 100644 --- a/vllm/v1/worker/tpu_worker.py +++ b/vllm/v1/worker/tpu_worker.py @@ -8,10 +8,10 @@ import torch_xla.runtime as xr import vllm.envs as envs +from vllm.config import CacheConfig, ModelConfig, ParallelConfig, VllmConfig from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment) from vllm.logger import init_logger -from vllm.config import CacheConfig, ModelConfig, ParallelConfig, VllmConfig from vllm.model_executor import set_random_seed from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size from vllm.v1.outputs import ModelRunnerOutput @@ -114,9 +114,9 @@ def determine_num_available_blocks(self) -> Tuple[int, int]: # intermediate activations. m = xm.get_memory_info(self.device) total_tpu_memory = m["bytes_limit"] - peak_memory = m["peak_bytes_used"] # Weights + intermediate activations. - logger.debug("Peak Used: %sGB", - peak_memory // 1024 // 1024 // 1024) + peak_memory = m[ + "peak_bytes_used"] # Weights + intermediate activations. + logger.debug("Peak Used: %sGB", peak_memory // 1024 // 1024 // 1024) logger.debug("Total Memory: %sGB", total_tpu_memory // 1024 // 1024 // 1024) From 63b301a88ec19bfb054c47ae7f98509a55960338 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Sun, 17 Nov 2024 22:21:49 +0000 Subject: [PATCH 31/33] updated --- vllm/v1/worker/tpu_model_runner.py | 27 +++++++++++++++++++-------- 1 file changed, 19 insertions(+), 8 deletions(-) diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 04750d42646db..048782d5e7b43 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -204,9 +204,8 @@ def _prepare_prefill_inputs( self, num_scheduled_tokens: List[int], ) -> PrefillInputData: - # Prefills run separately, each with shape [1, prompt_len], - # due to lack of variable length flashattention, so we - # create a list that will be used in execute_model() + # Each prefill run separately with shape [1, padded_prompt_len]. + # So we create lists that will be used in execute_model(). 
prefill_request_ids = [] prefill_prompt_lens = [] @@ -229,7 +228,7 @@ def _prepare_prefill_inputs( # TOKEN_IDS. token_ids = torch.from_numpy(self.input_batch.token_ids_cpu[ - idx, :padded_prompt_len].reshape(-1, 1)) + idx, :padded_prompt_len].reshape(1, -1)) prefill_token_ids.append(token_ids.to(self.device)) # POSITIONS. @@ -258,6 +257,10 @@ def _prepare_prefill_inputs( context_lens=None, )) + print(f"PREFILL {token_ids.shape=}") + print(f"PREFILL {positions.shape=}") + print(f"PREFILL {slot_mapping.shape=}") + return PrefillInputData( request_ids=prefill_request_ids, prompt_lens=prefill_prompt_lens, @@ -316,6 +319,12 @@ def _prepare_decode_inputs(self, num_decodes: int) -> DecodeInputData: # CONTEXT_LENS [batch_size] context_lens = (positions.reshape(-1) + 1) + print(f"{token_ids.shape=}") + print(f"{positions.shape=}") + print(f"{slot_mapping.shape=}") + print(f"{block_table.shape=}") + print(f"{context_lens.shape=}") + # CPU<>TPU sync happens here. return DecodeInputData(num_decodes=num_decodes, token_ids=token_ids.to(self.device), @@ -344,7 +353,7 @@ def _prepare_inputs( num_tokens = scheduler_output.num_scheduled_tokens[req_id] num_scheduled_tokens.append(num_tokens) - # Assert Decodes Are Decodes. + # NOTE: assert that all the decodes are "decodes". if idx < num_decodes: assert num_tokens == 1 @@ -368,6 +377,7 @@ def _prepare_sampling( sampling_metadata = self.input_batch.make_sampling_metadata(skip_copy) return sampling_metadata + @torch.no_grad() def execute_model( self, scheduler_output: "SchedulerOutput", @@ -378,7 +388,7 @@ def execute_model( sampled_token_ids = torch.empty(num_reqs, dtype=torch.int32) ######################### DECODES ######################### - # Decodes run as one single padded batch with shape [batch, 1] + # Decodes run as one single batch with [padded_batch, 1] if decode_data.num_decodes > 0: # FORWARD. @@ -410,6 +420,8 @@ def execute_model( req_state.output_token_ids.append(token_id) ######################### PREFILLS ######################### + # Prefills run separately with shape [1, padded_prefill_len], + # due to lack of variable length attention kernel so far. 
for idx, (req_id, prompt_len, token_ids, position_ids, attn_metadata) in enumerate(prefill_data.zipped()): @@ -440,14 +452,13 @@ def execute_model( self.input_batch.token_ids_cpu[req_idx, seq_len] = token_id req_state.output_token_ids.append(token_id) - model_runner_output = ModelRunnerOutput( + return ModelRunnerOutput( req_ids=self.input_batch.req_ids[:num_reqs], req_id_to_index=self.input_batch.req_id_to_index, sampled_token_ids_cpu=sampled_token_ids, logprob_token_ids_cpu=None, logprobs_cpu=None, ) - return model_runner_output def load_model(self) -> None: From 1af03e022a5b1490183fb31fb740af2811e65ddb Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Sun, 17 Nov 2024 22:25:22 +0000 Subject: [PATCH 32/33] updated --- vllm/v1/worker/tpu_model_runner.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 048782d5e7b43..e88e4d9b57d77 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -257,10 +257,6 @@ def _prepare_prefill_inputs( context_lens=None, )) - print(f"PREFILL {token_ids.shape=}") - print(f"PREFILL {positions.shape=}") - print(f"PREFILL {slot_mapping.shape=}") - return PrefillInputData( request_ids=prefill_request_ids, prompt_lens=prefill_prompt_lens, @@ -319,12 +315,6 @@ def _prepare_decode_inputs(self, num_decodes: int) -> DecodeInputData: # CONTEXT_LENS [batch_size] context_lens = (positions.reshape(-1) + 1) - print(f"{token_ids.shape=}") - print(f"{positions.shape=}") - print(f"{slot_mapping.shape=}") - print(f"{block_table.shape=}") - print(f"{context_lens.shape=}") - # CPU<>TPU sync happens here. return DecodeInputData(num_decodes=num_decodes, token_ids=token_ids.to(self.device), From 02ee3042ef0f8dbb08b38c40e4386d29e10341cf Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Sun, 17 Nov 2024 22:55:52 +0000 Subject: [PATCH 33/33] fixed accuracy bug --- vllm/v1/worker/tpu_model_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index e88e4d9b57d77..7963fe4973b55 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -305,7 +305,7 @@ def _prepare_decode_inputs(self, num_decodes: int) -> DecodeInputData: slot_mapping = block_number * self.block_size + block_offsets # Set an out of range value for the padding tokens so that they # are ignored when inserting into the KV cache. - slot_mapping[-num_decodes:] = _PAD_SLOT_ID + slot_mapping[num_decodes:] = _PAD_SLOT_ID slot_mapping = slot_mapping[:padded_batch_size] # BLOCK_TABLE [batch, max_num_blocks_per_req]