Skip to content

Commit

Permalink
[V1] Add BlockTable class (#11693)
Browse files Browse the repository at this point in the history
Signed-off-by: Woosuk Kwon <[email protected]>
  • Loading branch information
WoosukKwon authored Jan 6, 2025
1 parent 408e560 commit 06bfb51
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 25 deletions.
78 changes: 78 additions & 0 deletions vllm/v1/worker/block_table.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
from typing import List

import numpy as np
import torch

from vllm.logger import init_logger

logger = init_logger(__name__)


class BlockTable:

def __init__(
self,
max_num_reqs: int,
max_model_len: int,
max_num_blocks_per_req: int,
pin_memory: bool,
device: torch.device,
):
self.max_num_reqs = max_num_reqs
self.max_model_len = max_model_len
self.max_num_blocks_per_req = max_num_blocks_per_req
self.pin_memory = pin_memory
self.device = device

self.block_table = torch.zeros(
(max_num_reqs, max_num_blocks_per_req),
device=self.device,
dtype=torch.int32,
)
self.block_table_cpu = torch.zeros(
(max_num_reqs, max_num_blocks_per_req),
device="cpu",
dtype=torch.int32,
pin_memory=pin_memory,
)
self.block_table_np = self.block_table_cpu.numpy()
self.num_blocks_per_row = np.zeros(max_num_reqs, dtype=np.int32)

def append_row(
self,
row_idx: int,
start: int,
block_ids: List[int],
) -> None:
num_blocks = len(block_ids)
self.block_table_np[row_idx, start:start + num_blocks] = block_ids
self.num_blocks_per_row[row_idx] = start + num_blocks

def add_row(self, row_idx: int, block_ids: List[int]) -> None:
self.append_row(row_idx, 0, block_ids)

def move_row(self, src: int, tgt: int) -> None:
num_blocks = self.num_blocks_per_row[src]
self.block_table_np[tgt, :num_blocks] = self.block_table_np[
src, :num_blocks]
self.num_blocks_per_row[tgt] = num_blocks

def commit(self, num_reqs: int) -> None:
self.block_table[:num_reqs].copy_(self.block_table_cpu[:num_reqs],
non_blocking=True)

def clear(self) -> None:
self.block_table.fill_(0)
self.block_table_cpu.fill_(0)

def get_device_tensor(self) -> torch.Tensor:
"""Ruturns the device tensor of the block table."""
return self.block_table

def get_cpu_tensor(self) -> torch.Tensor:
"""Returns the CPU tensor of the block table."""
return self.block_table_cpu

def get_numpy_array(self) -> np.ndarray:
"""Returns the numpy array of the block table."""
return self.block_table_np
25 changes: 9 additions & 16 deletions vllm/v1/worker/gpu_input_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from vllm.multimodal import MultiModalKwargs
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.worker.block_table import BlockTable

if TYPE_CHECKING:
from vllm.multimodal.inputs import PlaceholderRange
Expand Down Expand Up @@ -70,19 +71,14 @@ def __init__(
self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32)
self.num_computed_tokens_cpu = np.empty(max_num_reqs, dtype=np.int32)

# Attention-related.
self.block_table = torch.zeros(
(max_num_reqs, max_num_blocks_per_req),
device=self.device,
dtype=torch.int32,
)
self.block_table_cpu_tensor = torch.zeros(
(max_num_reqs, max_num_blocks_per_req),
device="cpu",
dtype=torch.int32,
# Block table.
self.block_table = BlockTable(
max_num_reqs=max_num_reqs,
max_model_len=max_model_len,
max_num_blocks_per_req=max_num_blocks_per_req,
pin_memory=pin_memory,
device=device,
)
self.block_table_cpu = self.block_table_cpu_tensor.numpy()

# Sampling-related.
self.temperature = torch.empty((max_num_reqs, ),
Expand Down Expand Up @@ -193,8 +189,7 @@ def add_request(
self.num_tokens[req_index] = request.num_tokens

self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens
num_blocks = len(request.block_ids)
self.block_table_cpu[req_index, :num_blocks] = request.block_ids
self.block_table.add_row(req_index, request.block_ids)

sampling_params = request.sampling_params
self.temperature_cpu[req_index] = sampling_params.temperature
Expand Down Expand Up @@ -300,9 +295,7 @@ def condense(self, empty_req_indices: List[int]) -> None:
self.num_prompt_tokens[last_req_index]
self.num_computed_tokens_cpu[
empty_index] = self.num_computed_tokens_cpu[last_req_index]
# TODO(woosuk): Optimize the copy of block_table_cpu.
self.block_table_cpu[empty_index] = self.block_table_cpu[
last_req_index]
self.block_table.move_row(last_req_index, empty_index)
self.temperature_cpu[empty_index] = self.temperature_cpu[
last_req_index]
self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index]
Expand Down
16 changes: 7 additions & 9 deletions vllm/v1/worker/gpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,10 +211,9 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
if num_new_blocks == 0:
continue
start_index = len(req_state.block_ids)
end_index = start_index + num_new_blocks
req_state.block_ids.extend(req_data.new_block_ids)
self.input_batch.block_table_cpu[
req_index, start_index:end_index] = req_data.new_block_ids
self.input_batch.block_table.append_row(req_index, start_index,
req_data.new_block_ids)

req_ids_to_add: List[str] = []
# Add new requests to the cached states.
Expand Down Expand Up @@ -275,9 +274,7 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):

# OPTIMIZATION: Start copying the block table first.
# This way, we can overlap the copy with the following CPU operations.
self.input_batch.block_table[:num_reqs].copy_(
self.input_batch.block_table_cpu_tensor[:num_reqs],
non_blocking=True)
self.input_batch.block_table.commit(num_reqs)

# Get the number of scheduled tokens for each request.
# TODO: The Python loop can be slow. Optimize.
Expand Down Expand Up @@ -333,8 +330,8 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
# NOTE(woosuk): We use torch.index_select instead of np.take here
# because torch.index_select is much faster than np.take for large
# tensors.
block_numbers = (self.input_batch.block_table_cpu_tensor.flatten()
[block_table_indices].numpy())
block_table_cpu = self.input_batch.block_table.get_cpu_tensor()
block_numbers = block_table_cpu.flatten()[block_table_indices].numpy()
block_offsets = positions_np % self.block_size
np.add(block_numbers * self.block_size,
block_offsets,
Expand Down Expand Up @@ -450,7 +447,8 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
query_start_loc=query_start_loc,
max_seq_len=max_seq_len,
seq_start_loc=seq_start_loc,
block_table=self.input_batch.block_table[:num_reqs],
block_table=(
self.input_batch.block_table.get_device_tensor()[:num_reqs]),
slot_mapping=slot_mapping,
use_cascade=use_cascade,
common_prefix_len=common_prefix_len,
Expand Down

0 comments on commit 06bfb51

Please sign in to comment.