From 25b79d9fd38e2c53ce281be23241d8939ec7320c Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Mon, 9 Dec 2024 12:33:41 -0500 Subject: [PATCH] [V1] Input Batch Relocation (#10962) Signed-off-by: Varun Sundar Rabindranath Co-authored-by: Varun Sundar Rabindranath --- vllm/v1/worker/gpu_input_batch.py | 280 +++++++++++++++++++++++++++++ vllm/v1/worker/gpu_model_runner.py | 273 +--------------------------- 2 files changed, 283 insertions(+), 270 deletions(-) create mode 100644 vllm/v1/worker/gpu_input_batch.py diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py new file mode 100644 index 0000000000000..457784bb0287c --- /dev/null +++ b/vllm/v1/worker/gpu_input_batch.py @@ -0,0 +1,280 @@ +# Datastructures defining an input batch + +from dataclasses import dataclass +from typing import TYPE_CHECKING, Dict, List, Optional, Set + +import numpy as np +import torch + +from vllm.multimodal import MultiModalKwargs +from vllm.sampling_params import SamplingParams, SamplingType +from vllm.v1.sample.metadata import SamplingMetadata + +if TYPE_CHECKING: + from vllm.multimodal.inputs import PlaceholderRange + + +@dataclass +class CachedRequestState: + + req_id: str + prompt_token_ids: List[int] + prompt: Optional[str] + mm_inputs: List[MultiModalKwargs] + mm_positions: List["PlaceholderRange"] + sampling_params: SamplingParams + generator: Optional[torch.Generator] + + block_ids: List[int] + num_computed_tokens: int + output_token_ids: List[int] + + @property + def num_tokens(self) -> int: + return len(self.prompt_token_ids) + len(self.output_token_ids) + + +class InputBatch: + + def __init__( + self, + max_num_reqs: int, + max_model_len: int, + max_num_blocks_per_req: int, + device: torch.device, + pin_memory: bool, + ): + self.max_num_reqs = max_num_reqs + self.max_model_len = max_model_len + self.max_num_blocks_per_req = max_num_blocks_per_req + self.device = device + self.pin_memory = pin_memory + + self.req_ids: List[Optional[str]] = [None] * max_num_reqs + self.req_id_to_index: Dict[str, int] = {} + + self.token_ids_cpu = np.empty((max_num_reqs, max_model_len), + dtype=np.int32) + self.num_computed_tokens_cpu = np.empty(max_num_reqs, dtype=np.int32) + + # Attention-related. + self.block_table = torch.zeros((max_num_reqs, max_num_blocks_per_req), + device=self.device, + dtype=torch.int32) + self.block_table_cpu_tensor = torch.zeros( + (max_num_reqs, max_num_blocks_per_req), + device="cpu", + dtype=torch.int32, + pin_memory=pin_memory, + ) + self.block_table_cpu = self.block_table_cpu_tensor.numpy() + + # Sampling-related. + self.temperature = torch.empty((max_num_reqs, ), + dtype=torch.float32, + device=device) + self.temperature_cpu_tensor = torch.empty((max_num_reqs, ), + dtype=torch.float32, + device="cpu", + pin_memory=pin_memory) + self.temperature_cpu = self.temperature_cpu_tensor.numpy() + self.greedy_reqs: Set[str] = set() + self.random_reqs: Set[str] = set() + + self.top_p = torch.empty((max_num_reqs, ), + dtype=torch.float32, + device=device) + self.top_p_cpu_tensor = torch.empty((max_num_reqs, ), + dtype=torch.float32, + device="cpu", + pin_memory=pin_memory) + self.top_p_cpu = self.top_p_cpu_tensor.numpy() + self.top_p_reqs: Set[str] = set() + + self.top_k = torch.empty((max_num_reqs, ), + dtype=torch.int32, + device=device) + self.top_k_cpu_tensor = torch.empty((max_num_reqs, ), + dtype=torch.int32, + device="cpu", + pin_memory=pin_memory) + self.top_k_cpu = self.top_k_cpu_tensor.numpy() + self.top_k_reqs: Set[str] = set() + + # req_index -> generator + self.generators: Dict[int, torch.Generator] = {} + + self.num_logprobs: Dict[str, int] = {} + self.prompt_logprob_reqs: Set[str] = set() + + def add_request( + self, + request: "CachedRequestState", + req_index: Optional[int] = None, + ) -> None: + if req_index is None: + req_index = self.num_reqs + assert req_index < self.max_num_reqs + + req_id = request.req_id + self.req_ids[req_index] = req_id + self.req_id_to_index[req_id] = req_index + + # Copy the prompt token ids and output token ids. + num_prompt_tokens = len(request.prompt_token_ids) + self.token_ids_cpu[ + req_index, :num_prompt_tokens] = request.prompt_token_ids + start_idx = num_prompt_tokens + end_idx = start_idx + len(request.output_token_ids) + self.token_ids_cpu[req_index, + start_idx:end_idx] = request.output_token_ids + + self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens + num_blocks = len(request.block_ids) + self.block_table_cpu[req_index, :num_blocks] = request.block_ids + + sampling_params = request.sampling_params + self.temperature_cpu[req_index] = sampling_params.temperature + if sampling_params.sampling_type == SamplingType.GREEDY: + self.greedy_reqs.add(req_id) + else: + self.random_reqs.add(req_id) + + self.top_p_cpu[req_index] = sampling_params.top_p + if sampling_params.top_p < 1: + self.top_p_reqs.add(req_id) + self.top_k_cpu[req_index] = sampling_params.top_k + if sampling_params.top_k > 0: + self.top_k_reqs.add(req_id) + + self.generators[req_index] = request.generator + + num_logprobs = sampling_params.logprobs + if num_logprobs is not None and num_logprobs > 0: + self.num_logprobs[req_id] = num_logprobs + if sampling_params.prompt_logprobs: + self.prompt_logprob_reqs.add(req_id) + + def remove_request(self, req_id: str) -> Optional[int]: + req_index = self.req_id_to_index.pop(req_id, None) + if req_index is None: + return None + self.req_ids[req_index] = None + + self.greedy_reqs.discard(req_id) + self.random_reqs.discard(req_id) + self.top_p_reqs.discard(req_id) + self.top_k_reqs.discard(req_id) + self.generators.pop(req_index, None) + self.num_logprobs.pop(req_id, None) + self.prompt_logprob_reqs.discard(req_id) + return req_index + + def clear(self) -> None: + self.req_ids = [None] * self.max_num_reqs + self.req_id_to_index.clear() + self.greedy_reqs.clear() + self.random_reqs.clear() + self.top_p_reqs.clear() + self.top_k_reqs.clear() + self.generators.clear() + self.num_logprobs.clear() + self.prompt_logprob_reqs.clear() + + def condense(self, empty_req_indices: List[int]) -> None: + if self.num_reqs == 0: + # The batched states are empty. + return + + # NOTE(woosuk): This function assumes that the empty_req_indices + # is sorted in descending order. + last_req_index = self.num_reqs + len(empty_req_indices) - 1 + while empty_req_indices: + # Find the largest non-empty index. + while last_req_index in empty_req_indices: + last_req_index -= 1 + + # Find the smallest empty index. + empty_index = empty_req_indices.pop() + if empty_index >= last_req_index: + break + + # Swap the states. + req_id = self.req_ids[last_req_index] + self.req_ids[empty_index] = req_id + self.req_ids[last_req_index] = None + self.req_id_to_index[req_id] = empty_index + + # TODO(woosuk): Optimize the copy of token_ids_cpu and + # block_table_cpu. + self.token_ids_cpu[empty_index] = self.token_ids_cpu[ + last_req_index] + self.num_computed_tokens_cpu[ + empty_index] = self.num_computed_tokens_cpu[last_req_index] + self.block_table_cpu[empty_index] = self.block_table_cpu[ + last_req_index] + self.temperature_cpu[empty_index] = self.temperature_cpu[ + last_req_index] + self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index] + self.top_k_cpu[empty_index] = self.top_k_cpu[last_req_index] + generator = self.generators.pop(last_req_index, None) + if generator is not None: + self.generators[empty_index] = generator + + # Decrement last_req_index since it is now empty. + last_req_index -= 1 + + def make_sampling_metadata( + self, + skip_copy: bool = False, + ) -> SamplingMetadata: + if not skip_copy: + self.temperature[:self.num_reqs].copy_( + self.temperature_cpu_tensor[:self.num_reqs], non_blocking=True) + self.top_p[:self.num_reqs].copy_( + self.top_p_cpu_tensor[:self.num_reqs], non_blocking=True) + self.top_k[:self.num_reqs].copy_( + self.top_k_cpu_tensor[:self.num_reqs], non_blocking=True) + return SamplingMetadata( + temperature=self.temperature[:self.num_reqs], + all_greedy=self.all_greedy, + all_random=self.all_random, + top_p=self.top_p[:self.num_reqs], + top_k=self.top_k[:self.num_reqs], + no_top_p=self.no_top_p, + no_top_k=self.no_top_k, + generators=self.generators, + max_num_logprobs=self.max_num_logprobs, + ) + + @property + def num_reqs(self) -> int: + return len(self.req_id_to_index) + + @property + def all_greedy(self) -> bool: + return len(self.random_reqs) == 0 + + @property + def all_random(self) -> bool: + return len(self.greedy_reqs) == 0 + + @property + def no_top_p(self) -> bool: + return len(self.top_p_reqs) == 0 + + @property + def no_top_k(self) -> bool: + return len(self.top_k_reqs) == 0 + + @property + def max_num_logprobs(self) -> int: + return max(self.num_logprobs.values()) if self.num_logprobs else 0 + + @property + def no_logprob(self) -> bool: + return len(self.num_logprobs) == 0 + + @property + def no_prompt_logprob(self) -> bool: + return len(self.prompt_logprob_reqs) == 0 diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index e8d964a722f60..7f95be06188e3 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1,7 +1,6 @@ import gc import time -from dataclasses import dataclass -from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple import numpy as np import torch @@ -15,16 +14,16 @@ from vllm.logger import init_logger from vllm.model_executor.model_loader import get_model from vllm.multimodal import MultiModalKwargs -from vllm.sampling_params import SamplingParams, SamplingType +from vllm.sampling_params import SamplingType from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, cdiv, is_pin_memory_available) from vllm.v1.attention.backends.flash_attn import (FlashAttentionBackend, FlashAttentionMetadata) from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.sample.metadata import SamplingMetadata +from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch if TYPE_CHECKING: - from vllm.multimodal.inputs import PlaceholderRange from vllm.v1.core.scheduler import SchedulerOutput logger = init_logger(__name__) @@ -609,269 +608,3 @@ def _get_padded_batch_size(self, batch_size: int) -> Optional[int]: if batch_size <= size: return size return None - - -@dataclass -class CachedRequestState: - - req_id: str - prompt_token_ids: List[int] - prompt: Optional[str] - mm_inputs: List[MultiModalKwargs] - mm_positions: List["PlaceholderRange"] - sampling_params: SamplingParams - generator: Optional[torch.Generator] - - block_ids: List[int] - num_computed_tokens: int - output_token_ids: List[int] - - @property - def num_tokens(self) -> int: - return len(self.prompt_token_ids) + len(self.output_token_ids) - - -class InputBatch: - - def __init__( - self, - max_num_reqs: int, - max_model_len: int, - max_num_blocks_per_req: int, - device: torch.device, - pin_memory: bool, - ): - self.max_num_reqs = max_num_reqs - self.max_model_len = max_model_len - self.max_num_blocks_per_req = max_num_blocks_per_req - self.device = device - self.pin_memory = pin_memory - - self.req_ids: List[Optional[str]] = [None] * max_num_reqs - self.req_id_to_index: Dict[str, int] = {} - - self.token_ids_cpu = np.empty((max_num_reqs, max_model_len), - dtype=np.int32) - self.num_computed_tokens_cpu = np.empty(max_num_reqs, dtype=np.int32) - - # Attention-related. - self.block_table = torch.zeros((max_num_reqs, max_num_blocks_per_req), - device=self.device, - dtype=torch.int32) - self.block_table_cpu_tensor = torch.zeros( - (max_num_reqs, max_num_blocks_per_req), - device="cpu", - dtype=torch.int32, - pin_memory=pin_memory, - ) - self.block_table_cpu = self.block_table_cpu_tensor.numpy() - - # Sampling-related. - self.temperature = torch.empty((max_num_reqs, ), - dtype=torch.float32, - device=device) - self.temperature_cpu_tensor = torch.empty((max_num_reqs, ), - dtype=torch.float32, - device="cpu", - pin_memory=pin_memory) - self.temperature_cpu = self.temperature_cpu_tensor.numpy() - self.greedy_reqs: Set[str] = set() - self.random_reqs: Set[str] = set() - - self.top_p = torch.empty((max_num_reqs, ), - dtype=torch.float32, - device=device) - self.top_p_cpu_tensor = torch.empty((max_num_reqs, ), - dtype=torch.float32, - device="cpu", - pin_memory=pin_memory) - self.top_p_cpu = self.top_p_cpu_tensor.numpy() - self.top_p_reqs: Set[str] = set() - - self.top_k = torch.empty((max_num_reqs, ), - dtype=torch.int32, - device=device) - self.top_k_cpu_tensor = torch.empty((max_num_reqs, ), - dtype=torch.int32, - device="cpu", - pin_memory=pin_memory) - self.top_k_cpu = self.top_k_cpu_tensor.numpy() - self.top_k_reqs: Set[str] = set() - - # req_index -> generator - self.generators: Dict[int, torch.Generator] = {} - - self.num_logprobs: Dict[str, int] = {} - self.prompt_logprob_reqs: Set[str] = set() - - def add_request( - self, - request: "CachedRequestState", - req_index: Optional[int] = None, - ) -> None: - if req_index is None: - req_index = self.num_reqs - assert req_index < self.max_num_reqs - - req_id = request.req_id - self.req_ids[req_index] = req_id - self.req_id_to_index[req_id] = req_index - - # Copy the prompt token ids and output token ids. - num_prompt_tokens = len(request.prompt_token_ids) - self.token_ids_cpu[ - req_index, :num_prompt_tokens] = request.prompt_token_ids - start_idx = num_prompt_tokens - end_idx = start_idx + len(request.output_token_ids) - self.token_ids_cpu[req_index, - start_idx:end_idx] = request.output_token_ids - - self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens - num_blocks = len(request.block_ids) - self.block_table_cpu[req_index, :num_blocks] = request.block_ids - - sampling_params = request.sampling_params - self.temperature_cpu[req_index] = sampling_params.temperature - if sampling_params.sampling_type == SamplingType.GREEDY: - self.greedy_reqs.add(req_id) - else: - self.random_reqs.add(req_id) - - self.top_p_cpu[req_index] = sampling_params.top_p - if sampling_params.top_p < 1: - self.top_p_reqs.add(req_id) - self.top_k_cpu[req_index] = sampling_params.top_k - if sampling_params.top_k > 0: - self.top_k_reqs.add(req_id) - - self.generators[req_index] = request.generator - - num_logprobs = sampling_params.logprobs - if num_logprobs is not None and num_logprobs > 0: - self.num_logprobs[req_id] = num_logprobs - if sampling_params.prompt_logprobs: - self.prompt_logprob_reqs.add(req_id) - - def remove_request(self, req_id: str) -> Optional[int]: - req_index = self.req_id_to_index.pop(req_id, None) - if req_index is None: - return None - self.req_ids[req_index] = None - - self.greedy_reqs.discard(req_id) - self.random_reqs.discard(req_id) - self.top_p_reqs.discard(req_id) - self.top_k_reqs.discard(req_id) - self.generators.pop(req_index, None) - self.num_logprobs.pop(req_id, None) - self.prompt_logprob_reqs.discard(req_id) - return req_index - - def clear(self) -> None: - self.req_ids = [None] * self.max_num_reqs - self.req_id_to_index.clear() - self.greedy_reqs.clear() - self.random_reqs.clear() - self.top_p_reqs.clear() - self.top_k_reqs.clear() - self.generators.clear() - self.num_logprobs.clear() - self.prompt_logprob_reqs.clear() - - def condense(self, empty_req_indices: List[int]) -> None: - if self.num_reqs == 0: - # The batched states are empty. - return - - # NOTE(woosuk): This function assumes that the empty_req_indices - # is sorted in descending order. - last_req_index = self.num_reqs + len(empty_req_indices) - 1 - while empty_req_indices: - # Find the largest non-empty index. - while last_req_index in empty_req_indices: - last_req_index -= 1 - - # Find the smallest empty index. - empty_index = empty_req_indices.pop() - if empty_index >= last_req_index: - break - - # Swap the states. - req_id = self.req_ids[last_req_index] - self.req_ids[empty_index] = req_id - self.req_ids[last_req_index] = None - self.req_id_to_index[req_id] = empty_index - - # TODO(woosuk): Optimize the copy of token_ids_cpu and - # block_table_cpu. - self.token_ids_cpu[empty_index] = self.token_ids_cpu[ - last_req_index] - self.num_computed_tokens_cpu[ - empty_index] = self.num_computed_tokens_cpu[last_req_index] - self.block_table_cpu[empty_index] = self.block_table_cpu[ - last_req_index] - self.temperature_cpu[empty_index] = self.temperature_cpu[ - last_req_index] - self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index] - self.top_k_cpu[empty_index] = self.top_k_cpu[last_req_index] - generator = self.generators.pop(last_req_index, None) - if generator is not None: - self.generators[empty_index] = generator - - # Decrement last_req_index since it is now empty. - last_req_index -= 1 - - def make_sampling_metadata( - self, - skip_copy: bool = False, - ) -> SamplingMetadata: - if not skip_copy: - self.temperature[:self.num_reqs].copy_( - self.temperature_cpu_tensor[:self.num_reqs], non_blocking=True) - self.top_p[:self.num_reqs].copy_( - self.top_p_cpu_tensor[:self.num_reqs], non_blocking=True) - self.top_k[:self.num_reqs].copy_( - self.top_k_cpu_tensor[:self.num_reqs], non_blocking=True) - return SamplingMetadata( - temperature=self.temperature[:self.num_reqs], - all_greedy=self.all_greedy, - all_random=self.all_random, - top_p=self.top_p[:self.num_reqs], - top_k=self.top_k[:self.num_reqs], - no_top_p=self.no_top_p, - no_top_k=self.no_top_k, - generators=self.generators, - max_num_logprobs=self.max_num_logprobs, - ) - - @property - def num_reqs(self) -> int: - return len(self.req_id_to_index) - - @property - def all_greedy(self) -> bool: - return len(self.random_reqs) == 0 - - @property - def all_random(self) -> bool: - return len(self.greedy_reqs) == 0 - - @property - def no_top_p(self) -> bool: - return len(self.top_p_reqs) == 0 - - @property - def no_top_k(self) -> bool: - return len(self.top_k_reqs) == 0 - - @property - def max_num_logprobs(self) -> int: - return max(self.num_logprobs.values()) if self.num_logprobs else 0 - - @property - def no_logprob(self) -> bool: - return len(self.num_logprobs) == 0 - - @property - def no_prompt_logprob(self) -> bool: - return len(self.prompt_logprob_reqs) == 0