[V1] Input Batch Relocation (#10962)
Signed-off-by: Varun Sundar Rabindranath <[email protected]>
Co-authored-by: Varun Sundar Rabindranath <[email protected]>
Commit 25b79d9 (1 parent: aea2fc3)
Showing 2 changed files with 283 additions and 270 deletions.
@@ -0,0 +1,280 @@
# Datastructures defining an input batch

from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Set

import numpy as np
import torch

from vllm.multimodal import MultiModalKwargs
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.v1.sample.metadata import SamplingMetadata

if TYPE_CHECKING:
    from vllm.multimodal.inputs import PlaceholderRange

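# Per-request state tracked across scheduler steps: prompt token ids,
# multimodal inputs, sampling parameters, the allocated KV-cache block ids,
# and the output tokens generated so far.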
@dataclass
class CachedRequestState:

    req_id: str
    prompt_token_ids: List[int]
    prompt: Optional[str]
    mm_inputs: List[MultiModalKwargs]
    mm_positions: List["PlaceholderRange"]
    sampling_params: SamplingParams
    generator: Optional[torch.Generator]

    block_ids: List[int]
    num_computed_tokens: int
    output_token_ids: List[int]

    @property
    def num_tokens(self) -> int:
        return len(self.prompt_token_ids) + len(self.output_token_ids)

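# A persistent batch of up to max_num_reqs requests, stored as parallel
# arrays indexed by request slot. Sampling parameters live both in CPU
# tensors (optionally pinned, written as requests are added) and in device
# tensors that are refreshed lazily by make_sampling_metadata().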
class InputBatch:

    def __init__(
        self,
        max_num_reqs: int,
        max_model_len: int,
        max_num_blocks_per_req: int,
        device: torch.device,
        pin_memory: bool,
    ):
        self.max_num_reqs = max_num_reqs
        self.max_model_len = max_model_len
        self.max_num_blocks_per_req = max_num_blocks_per_req
        self.device = device
        self.pin_memory = pin_memory

        self.req_ids: List[Optional[str]] = [None] * max_num_reqs
        self.req_id_to_index: Dict[str, int] = {}

        self.token_ids_cpu = np.empty((max_num_reqs, max_model_len),
                                      dtype=np.int32)
        self.num_computed_tokens_cpu = np.empty(max_num_reqs, dtype=np.int32)

        # Attention-related.
        self.block_table = torch.zeros((max_num_reqs, max_num_blocks_per_req),
                                       device=self.device,
                                       dtype=torch.int32)
        self.block_table_cpu_tensor = torch.zeros(
            (max_num_reqs, max_num_blocks_per_req),
            device="cpu",
            dtype=torch.int32,
            pin_memory=pin_memory,
        )
        self.block_table_cpu = self.block_table_cpu_tensor.numpy()

        # Sampling-related.
        self.temperature = torch.empty((max_num_reqs, ),
                                       dtype=torch.float32,
                                       device=device)
        self.temperature_cpu_tensor = torch.empty((max_num_reqs, ),
                                                  dtype=torch.float32,
                                                  device="cpu",
                                                  pin_memory=pin_memory)
        self.temperature_cpu = self.temperature_cpu_tensor.numpy()
        self.greedy_reqs: Set[str] = set()
        self.random_reqs: Set[str] = set()

        self.top_p = torch.empty((max_num_reqs, ),
                                 dtype=torch.float32,
                                 device=device)
        self.top_p_cpu_tensor = torch.empty((max_num_reqs, ),
                                            dtype=torch.float32,
                                            device="cpu",
                                            pin_memory=pin_memory)
        self.top_p_cpu = self.top_p_cpu_tensor.numpy()
        self.top_p_reqs: Set[str] = set()

        self.top_k = torch.empty((max_num_reqs, ),
                                 dtype=torch.int32,
                                 device=device)
        self.top_k_cpu_tensor = torch.empty((max_num_reqs, ),
                                            dtype=torch.int32,
                                            device="cpu",
                                            pin_memory=pin_memory)
        self.top_k_cpu = self.top_k_cpu_tensor.numpy()
        self.top_k_reqs: Set[str] = set()

        # req_index -> generator
        self.generators: Dict[int, torch.Generator] = {}

        self.num_logprobs: Dict[str, int] = {}
        self.prompt_logprob_reqs: Set[str] = set()

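    # Adds a request at `req_index` (defaults to appending at index
    # num_reqs). Prompt/output token ids and block ids are copied into the
    # CPU buffers and the sampling parameters are recorded; the device
    # tensors are only updated later by make_sampling_metadata().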
    def add_request(
        self,
        request: "CachedRequestState",
        req_index: Optional[int] = None,
    ) -> None:
        if req_index is None:
            req_index = self.num_reqs
        assert req_index < self.max_num_reqs

        req_id = request.req_id
        self.req_ids[req_index] = req_id
        self.req_id_to_index[req_id] = req_index

        # Copy the prompt token ids and output token ids.
        num_prompt_tokens = len(request.prompt_token_ids)
        self.token_ids_cpu[
            req_index, :num_prompt_tokens] = request.prompt_token_ids
        start_idx = num_prompt_tokens
        end_idx = start_idx + len(request.output_token_ids)
        self.token_ids_cpu[req_index,
                           start_idx:end_idx] = request.output_token_ids

        self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens
        num_blocks = len(request.block_ids)
        self.block_table_cpu[req_index, :num_blocks] = request.block_ids

        sampling_params = request.sampling_params
        self.temperature_cpu[req_index] = sampling_params.temperature
        if sampling_params.sampling_type == SamplingType.GREEDY:
            self.greedy_reqs.add(req_id)
        else:
            self.random_reqs.add(req_id)

        self.top_p_cpu[req_index] = sampling_params.top_p
        if sampling_params.top_p < 1:
            self.top_p_reqs.add(req_id)
        self.top_k_cpu[req_index] = sampling_params.top_k
        if sampling_params.top_k > 0:
            self.top_k_reqs.add(req_id)

        self.generators[req_index] = request.generator

        num_logprobs = sampling_params.logprobs
        if num_logprobs is not None and num_logprobs > 0:
            self.num_logprobs[req_id] = num_logprobs
        if sampling_params.prompt_logprobs:
            self.prompt_logprob_reqs.add(req_id)

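    # Drops a request from the batch bookkeeping and returns the slot it
    # occupied (None if the request was not present). The slot contents are
    # left in place, so callers can later compact the batch via condense().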
    def remove_request(self, req_id: str) -> Optional[int]:
        req_index = self.req_id_to_index.pop(req_id, None)
        if req_index is None:
            return None
        self.req_ids[req_index] = None

        self.greedy_reqs.discard(req_id)
        self.random_reqs.discard(req_id)
        self.top_p_reqs.discard(req_id)
        self.top_k_reqs.discard(req_id)
        self.generators.pop(req_index, None)
        self.num_logprobs.pop(req_id, None)
        self.prompt_logprob_reqs.discard(req_id)
        return req_index

    def clear(self) -> None:
        self.req_ids = [None] * self.max_num_reqs
        self.req_id_to_index.clear()
        self.greedy_reqs.clear()
        self.random_reqs.clear()
        self.top_p_reqs.clear()
        self.top_k_reqs.clear()
        self.generators.clear()
        self.num_logprobs.clear()
        self.prompt_logprob_reqs.clear()

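    # Compacts the batch after removals: the request in the highest occupied
    # slot is moved into the lowest empty slot until the first num_reqs slots
    # are contiguous. `empty_req_indices` must be sorted in descending order,
    # so popping from its end yields the smallest empty index.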
    def condense(self, empty_req_indices: List[int]) -> None:
        if self.num_reqs == 0:
            # The batched states are empty.
            return

        # NOTE(woosuk): This function assumes that the empty_req_indices
        # is sorted in descending order.
        last_req_index = self.num_reqs + len(empty_req_indices) - 1
        while empty_req_indices:
            # Find the largest non-empty index.
            while last_req_index in empty_req_indices:
                last_req_index -= 1

            # Find the smallest empty index.
            empty_index = empty_req_indices.pop()
            if empty_index >= last_req_index:
                break

            # Swap the states.
            req_id = self.req_ids[last_req_index]
            self.req_ids[empty_index] = req_id
            self.req_ids[last_req_index] = None
            self.req_id_to_index[req_id] = empty_index

            # TODO(woosuk): Optimize the copy of token_ids_cpu and
            # block_table_cpu.
            self.token_ids_cpu[empty_index] = self.token_ids_cpu[
                last_req_index]
            self.num_computed_tokens_cpu[
                empty_index] = self.num_computed_tokens_cpu[last_req_index]
            self.block_table_cpu[empty_index] = self.block_table_cpu[
                last_req_index]
            self.temperature_cpu[empty_index] = self.temperature_cpu[
                last_req_index]
            self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index]
            self.top_k_cpu[empty_index] = self.top_k_cpu[last_req_index]
            generator = self.generators.pop(last_req_index, None)
            if generator is not None:
                self.generators[empty_index] = generator

            # Decrement last_req_index since it is now empty.
            last_req_index -= 1

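    # Builds SamplingMetadata for the current batch. Unless skip_copy is set
    # (intended for steps where the batch did not change), the CPU-side
    # temperature/top_p/top_k values are copied to the device tensors with
    # non-blocking host-to-device copies.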
    def make_sampling_metadata(
        self,
        skip_copy: bool = False,
    ) -> SamplingMetadata:
        if not skip_copy:
            self.temperature[:self.num_reqs].copy_(
                self.temperature_cpu_tensor[:self.num_reqs], non_blocking=True)
            self.top_p[:self.num_reqs].copy_(
                self.top_p_cpu_tensor[:self.num_reqs], non_blocking=True)
            self.top_k[:self.num_reqs].copy_(
                self.top_k_cpu_tensor[:self.num_reqs], non_blocking=True)
        return SamplingMetadata(
            temperature=self.temperature[:self.num_reqs],
            all_greedy=self.all_greedy,
            all_random=self.all_random,
            top_p=self.top_p[:self.num_reqs],
            top_k=self.top_k[:self.num_reqs],
            no_top_p=self.no_top_p,
            no_top_k=self.no_top_k,
            generators=self.generators,
            max_num_logprobs=self.max_num_logprobs,
        )

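    # The properties below summarize the batch so downstream code can take
    # fast paths, e.g. skipping top-p/top-k masking when no request uses it.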
    @property
    def num_reqs(self) -> int:
        return len(self.req_id_to_index)

    @property
    def all_greedy(self) -> bool:
        return len(self.random_reqs) == 0

    @property
    def all_random(self) -> bool:
        return len(self.greedy_reqs) == 0

    @property
    def no_top_p(self) -> bool:
        return len(self.top_p_reqs) == 0

    @property
    def no_top_k(self) -> bool:
        return len(self.top_k_reqs) == 0

    @property
    def max_num_logprobs(self) -> int:
        return max(self.num_logprobs.values()) if self.num_logprobs else 0

    @property
    def no_logprob(self) -> bool:
        return len(self.num_logprobs) == 0

    @property
    def no_prompt_logprob(self) -> bool:
        return len(self.prompt_logprob_reqs) == 0
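For orientation, a minimal usage sketch of the classes above. It is not part of this commit: it assumes a vLLM checkout exposing these V1 modules, uses a CPU device so no GPU or pinned memory is needed, and the import path and request values are illustrative only.

import torch

from vllm.sampling_params import SamplingParams
# The module path below is an assumption; import from wherever this file lives.
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch

batch = InputBatch(
    max_num_reqs=4,
    max_model_len=128,
    max_num_blocks_per_req=8,
    device=torch.device("cpu"),  # CPU keeps the sketch GPU-free
    pin_memory=False,
)

req = CachedRequestState(
    req_id="req-0",
    prompt_token_ids=[1, 2, 3],
    prompt=None,
    mm_inputs=[],
    mm_positions=[],
    sampling_params=SamplingParams(temperature=0.8, top_p=0.9),
    generator=None,
    block_ids=[0, 1],
    num_computed_tokens=0,
    output_token_ids=[],
)

batch.add_request(req)
metadata = batch.make_sampling_metadata()  # copies CPU values into device tensors
print(batch.num_reqs, batch.all_random, batch.no_top_p)  # 1 True False

# Removing a request leaves a hole; condense() compacts the remaining slots.
empty_index = batch.remove_request("req-0")
if empty_index is not None:
    batch.condense([empty_index])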