Commit 9c56442

comaniac committed Nov 8, 2024
1 parent 33cacf9 commit 9c56442
Signed-off-by: Cody Yu <[email protected]>
Showing 4 changed files with 221 additions and 208 deletions.
tests/v1/core/test_prefix_caching.py (15 changes: 9 additions & 6 deletions)
@@ -1,8 +1,8 @@
"""Compare the with and without prefix caching."""
from vllm.inputs import DecoderOnlyInputs
from vllm.sampling_params import SamplingParams
from vllm.v1.core.kv_cache_manager import (KVCacheManager, Request,
hash_block_tokens)
from vllm.v1.core.kv_cache_manager import KVCacheManager, Request
from vllm.v1.core.kv_cache_utils import hash_block_tokens


def make_request(request_id, prompt_token_ids):
@@ -46,7 +46,7 @@ def test_prefill():
         assert manager.block_pool[block_id].ref_cnt == 1
         assert manager.block_pool[block_id].num_hashed_tokens == 16 * (
             block_id + 1)
-        assert manager.block_pool[block_id].token_ids == [block_id] * 16
+        assert manager.block_pool[block_id].token_ids == tuple([block_id] * 16)
         parent_block_hash = block_hash
 
     # Check partial/preallocated block metadata
@@ -144,7 +144,8 @@ def test_decode():

     # Append slots without allocating a new block.
     req0.num_computed_tokens = 55
-    req0.output_token_ids = [8] * 4
+    for _ in range(4):
+        req0.append_output_token_ids(8)
     new_blocks = manager.append_slots(req0, 4)
     assert new_blocks is not None and len(new_blocks) == 0
     assert len(manager.block_pool[3].token_ids) == 11
@@ -154,7 +155,8 @@ def test_decode():
     req0.num_computed_tokens = 59
     # 5 tokens to fill the previous block, and 10 tokens to fill
     # the preallocated block.
-    req0.output_token_ids += [7] * (5 + 10)
+    for _ in range(5 + 10):
+        req0.append_output_token_ids(7)
     new_blocks = manager.append_slots(req0, 15)
     assert new_blocks is not None and len(new_blocks) == 0
     assert len(manager.block_pool[3].token_ids) == 16
@@ -164,7 +166,8 @@ def test_decode():
     req0.num_computed_tokens = 74
     # 6 tokens to fill the previous block, and 11 tokens to fill
     # the preallocated block.
-    req0.output_token_ids += [12] * (6 + 11)
+    for _ in range(6 + 11):
+        req0.append_output_token_ids(12)
     new_blocks = manager.append_slots(req0, 17)
     # Plus one preallocated block.
     assert new_blocks is not None and len(new_blocks) == 2
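These test updates track an API change on Request: output tokens are no longer assigned as a plain list but appended one at a time, and a full cached block's token_ids becomes an immutable tuple. A minimal sketch of the interface the updated tests rely on (SketchRequest is a hypothetical stand-in, not vLLM's actual Request class):

from typing import List


class SketchRequest:
    # Hypothetical stand-in for vllm.v1.request.Request (assumed shape).

    def __init__(self, prompt_token_ids: List[int]) -> None:
        self.prompt_token_ids = prompt_token_ids
        self._output_token_ids: List[int] = []
        self.num_computed_tokens = 0

    def append_output_token_ids(self, token_id: int) -> None:
        # Appending through a method lets the request keep derived state
        # such as all_token_ids consistent; direct list assignment cannot.
        self._output_token_ids.append(token_id)

    @property
    def all_token_ids(self) -> List[int]:
        # Prompt and output tokens exposed as one sequence.
        return self.prompt_token_ids + self._output_token_ids


req = SketchRequest([0] * 16)
for _ in range(4):
    req.append_output_token_ids(8)
assert req.all_token_ids == [0] * 16 + [8] * 4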
vllm/v1/core/kv_cache_manager.py (219 changes: 17 additions & 202 deletions)
@@ -1,147 +1,16 @@
 from collections import defaultdict
-from dataclasses import dataclass, field
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional
 
 from vllm.logger import init_logger
 from vllm.utils import cdiv
+from vllm.v1.core.kv_cache_utils import (BlockHashType, FreeKVCacheBlockQueue,
+                                         KVCacheBlock, hash_block_tokens,
+                                         hash_request_tokens)
 from vllm.v1.request import Request
 
 logger = init_logger(__name__)
 
 
-@dataclass
-class KVCacheBlock:
-    """KV-cache block metadata."""
-    # Block ID, ranging from 0 to num_gpu_blocks - 1.
-    block_id: int
-    # Reference count.
-    ref_cnt: int = 0
-    # Token IDs in the block.
-    token_ids: List[int] = field(default_factory=list)
-    # The hash of the block. It is only available when the block is full.
-    block_hash: Optional[int] = None
-    # The number of hashed tokens. More hashed tokens means the block
-    # is closer to the end of a prompt and more likely to be evicted.
-    num_hashed_tokens: int = 0
-
-    # Used to construct a doubly linked list for free blocks.
-    # These two attributes should only be manipulated by
-    # FreeKVCacheBlockQueue.
-    prev_free_block: Optional["KVCacheBlock"] = None
-    next_free_block: Optional["KVCacheBlock"] = None
-
-    def reset(self):
-        """Reset the block metadata."""
-        self.ref_cnt = 0
-        self.token_ids.clear()
-        self.block_hash = None
-        self.num_hashed_tokens = 0
-
-
-class FreeKVCacheBlockQueue:
-    """This class organizes a list of KVCacheBlock objects into a doubly
-    linked list of free blocks. We implement this class instead of using
-    the Python builtin deque to support removing a block in the middle of
-    the queue in O(1) time. To close the performance gap to the builtin
-    deque, which is implemented in C++, this class does not allocate any
-    Python objects when manipulating the linked list. Instead, it
-    manipulates the prev_free_block and next_free_block attributes of the
-    given blocks.
-
-    The queue is ordered by block ID in the beginning. When a block is
-    allocated and then freed, it is appended back with the eviction order:
-    1. The least recently used block is at the front (LRU).
-    2. If two blocks have the same last accessed time (allocated by the
-       same sequence), the one with more hashed tokens (the tail of a
-       block chain) is at the front.
-
-    Note that we maintain this order by reversing the block order when
-    freeing the blocks of a request. This operation is outside of this
-    class.
-
-    Args:
-        blocks: A list of KVCacheBlock objects.
-    """
-
-    def __init__(self, blocks: List[KVCacheBlock]) -> None:
-        self.num_free_blocks = len(blocks)
-
-        # Initialize the doubly linked list of free blocks.
-        self.free_list_head = blocks[0]
-        self.free_list_tail = blocks[-1]
-        for i in range(self.num_free_blocks):
-            if i > 0:
-                blocks[i].prev_free_block = blocks[i - 1]
-            if i < self.num_free_blocks - 1:
-                blocks[i].next_free_block = blocks[i + 1]
-
-    def popleft(self) -> KVCacheBlock:
-        """Pop the first free block and reduce num_free_blocks by 1.
-
-        Returns:
-            The first free block.
-        """
-        if not self.free_list_head:
-            raise ValueError("No free blocks available")
-
-        block = self.free_list_head
-        self.remove(block)
-        return block
-
-    def remove(self, block: KVCacheBlock) -> None:
-        """Remove a block in the free list and reduce num_free_blocks by 1.
-
-        Args:
-            block: The block to remove.
-        """
-        if block.prev_free_block is not None:
-            # Link the previous block to the next block.
-            block.prev_free_block.next_free_block = block.next_free_block
-        if block.next_free_block is not None:
-            # Link the next block to the previous block.
-            block.next_free_block.prev_free_block = block.prev_free_block
-
-        if block == self.free_list_head:
-            # Update the head if the block is the head.
-            self.free_list_head = block.next_free_block
-        if block == self.free_list_tail:
-            # Update the tail if the block is the tail.
-            self.free_list_tail = block.prev_free_block
-
-        # Remove the block from the linked list.
-        block.prev_free_block = block.next_free_block = None
-        self.num_free_blocks -= 1
-
-    def append(self, block: KVCacheBlock) -> None:
-        """Put a block back into the free list and increase
-        num_free_blocks by 1.
-
-        Args:
-            block: The block to append.
-        """
-        if self.free_list_tail is not None:
-            # Link the last block to the new block.
-            self.free_list_tail.next_free_block = block
-            block.prev_free_block = self.free_list_tail
-            self.free_list_tail = block
-        else:
-            # The free list is empty.
-            assert self.free_list_head is None
-            self.free_list_head = self.free_list_tail = block
-
-        block.next_free_block = None
-        self.num_free_blocks += 1
-
-    def get_all_free_blocks(self) -> List[KVCacheBlock]:
-        """Get all free blocks in the free list. Mainly used for testing.
-
-        Returns:
-            A list of free blocks.
-        """
-        ret = []
-        curr_block = self.free_list_head
-        while curr_block is not None:
-            ret.append(curr_block)
-            curr_block = curr_block.next_free_block
-        return ret
-
-
 class KVCacheManager:
 
     def __init__(
@@ -187,7 +56,7 @@ def __init__(
         # if there is already an identical block in the cache. This is because
         # we want to make sure the allocated block IDs won't change so that
         # block tables are append-only.
-        self.cached_block_hash_to_block: Dict[int, Dict[
+        self.cached_block_hash_to_block: Dict[BlockHashType, Dict[
             int, KVCacheBlock]] = defaultdict(dict)
 
         # Mapping from request ID to blocks to track the blocks allocated
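The comment above is why this mapping is two-level: duplicate blocks with identical contents may coexist, since an already-allocated block must keep its ID to leave block tables append-only. A toy illustration of the structure with assumed values (string hashes and block names are placeholders):

from collections import defaultdict
from typing import Dict

cached: Dict[str, Dict[int, str]] = defaultdict(dict)
# Two physical blocks happen to hold identical contents (same hash).
cached["hash-A"][0] = "block-0"
cached["hash-A"][7] = "block-7"
# A prefix-cache hit takes the first entry; both remain valid because
# existing block tables may already reference either block ID.
first = next(iter(cached["hash-A"].values()))
assert first == "block-0"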
@@ -210,7 +79,8 @@ def get_computed_blocks(self, request: Request) -> List[KVCacheBlock]:
             return []
 
         computed_blocks = []
-        block_hashes = self.hash_prompt_tokens(request.prompt_token_ids)
+        block_hashes = hash_request_tokens(self.block_size,
+                                           request.all_token_ids)
 
         for block_hash in block_hashes:
             # block_hashes is a chain of block hashes. If a block hash is not
@@ -255,18 +125,9 @@ def append_slots(
         parent_block = None
         if self.enable_caching:
             # Figure out the token IDs to add to the blocks.
-            if request.num_computed_tokens < request.num_prompt_tokens:
-                # (Chunked) Prefill.
-                new_token_ids = request.prompt_token_ids[
-                    request.num_computed_tokens:request.num_computed_tokens +
-                    num_tokens]
-            else:
-                # Decode.
-                num_computed_output_tokens = (request.num_computed_tokens -
-                                              request.num_prompt_tokens)
-                new_token_ids = request.output_token_ids[
-                    num_computed_output_tokens:num_computed_output_tokens +
-                    num_tokens]
+            new_token_ids = request.all_token_ids[
+                request.num_computed_tokens:request.num_computed_tokens +
+                num_tokens]
 
             # Find the last full block index.
             # TODO: This may be optimized by calculating the computed tokens.
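Since all_token_ids concatenates prompt and output tokens, the single slice above covers both the chunked-prefill and the decode case that the deleted branch handled separately. A toy check with assumed numbers:

# 4 prompt tokens followed by 3 decoded tokens (values are arbitrary).
all_token_ids = [10, 11, 12, 13] + [20, 21, 22]

# Chunked prefill: 2 prompt tokens computed, appending 2 more.
assert all_token_ids[2:2 + 2] == [12, 13]

# Decode: whole prompt plus 1 output token computed, appending 2 more.
assert all_token_ids[5:5 + 2] == [21, 22]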
@@ -353,15 +214,12 @@ def allocate_slots(
         self._touch(computed_blocks)
 
         # Get the token IDs for the blocks being allocated for hashing.
-        # Note that we expect allocate_slots to be called only once per
-        # new request, so num_computed_tokens + num_tokens must be less
-        # than or equal to the total number of tokens in the prompt.
-        new_token_ids = request.prompt_token_ids[
+        new_token_ids = request.all_token_ids[
             num_computed_tokens:num_computed_tokens + num_tokens]
         if not new_token_ids:
             raise RuntimeError(
                 "Failed to infer the token IDs for allocation. "
-                f"#prompt_tokens={len(request.prompt_token_ids)} < "
+                f"#all_tokens={len(request.all_token_ids)} < "
                 f"#computed_tokens={num_computed_tokens}")
 
         # Get the parent block ID to construct the block chain.
@@ -458,14 +316,16 @@ def _cache_full_block(self,
"""
parent_block_hash = (parent_block.block_hash
if parent_block is not None else None)
block_hash = hash_block_tokens(parent_block_hash,
tuple(block.token_ids))
assert len(block.token_ids) == self.block_size
block.token_ids = tuple(block.token_ids)
block_hash = hash_block_tokens(parent_block_hash, block.token_ids)
block.block_hash = block_hash
block.num_hashed_tokens = self.block_size + (
parent_block.num_hashed_tokens if parent_block is not None else 0)
self.cached_block_hash_to_block[block_hash][block.block_id] = block

def _get_cached_block(self, block_hash: int) -> Optional[KVCacheBlock]:
def _get_cached_block(self,
block_hash: BlockHashType) -> Optional[KVCacheBlock]:
"""Get a cached block by the block hash, or None if cache miss.
If there are duplicated blocks, we return the first block in the cache.
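Freezing token_ids into a tuple before hashing is what the updated test assertion (token_ids == tuple([block_id] * 16)) checks: tuples are hashable and immutable, so a block registered in the prefix cache cannot be mutated after its hash is computed. A one-line illustration:

# Lists are unhashable; tuples can key the prefix-cache dict directly.
try:
    hash([1, 2, 3])
except TypeError:
    pass  # unhashable type: 'list'
assert isinstance(hash((1, 2, 3)), int)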
@@ -534,48 +394,3 @@ def _add_token_ids_to_blocks(
             parent_block = curr_block
             token_id_start = token_id_end
         return token_id_start
-
-    def hash_prompt_tokens(self, token_ids: List[int]) -> List[int]:
-        """Computes hash values of a chain of blocks given a sequence of
-        token IDs. The hash value is used for prefix caching.
-
-        Args:
-            token_ids: A sequence of token ids in the prompt.
-
-        Returns:
-            The list of computed hash values.
-        """
-        ret = []
-        parent_block_hash = None
-        for start in range(0, len(token_ids), self.block_size):
-            end = start + self.block_size
-            block_token_ids = tuple(token_ids[start:end])
-            # Do not hash the block if it is not full.
-            if len(block_token_ids) < self.block_size:
-                break
-            block_hash = hash_block_tokens(parent_block_hash, block_token_ids)
-            ret.append(block_hash)
-            parent_block_hash = block_hash
-        return ret
-
-
-def hash_block_tokens(parent_block_hash: Optional[int],
-                      cur_block_token_ids: Tuple[int, ...]) -> int:
-    """Computes a hash value corresponding to the contents of a block and
-    the contents of the preceding block(s). The hash value is used for
-    prefix caching. We use LRU cache for this function to avoid recomputing
-    hash values for the same block contents.
-
-    TODO: Support arbitrary metadata so that we could support more
-    features such as LoRA adapter.
-
-    Args:
-        parent_block_hash: The hash of the parent block. None
-            if this is the first block.
-        cur_block_token_ids: A tuple of token ids in the current
-            block. The current block is assumed to be full.
-
-    Returns:
-        The computed hash value for the block.
-    """
-    return hash((parent_block_hash, *cur_block_token_ids))
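The helpers now in kv_cache_utils chain each block hash to its parent's, so matching the i-th hash implies the whole prefix up to block i matches. A self-contained sketch of the scheme (hash_request_tokens mirrors the deleted hash_prompt_tokens with block_size passed explicitly, as at the new call site; plain int hashes stand in for BlockHashType):

from typing import List, Optional, Tuple


def hash_block_tokens(parent_block_hash: Optional[int],
                      cur_block_token_ids: Tuple[int, ...]) -> int:
    # Fold the parent hash into the current block's hash.
    return hash((parent_block_hash, *cur_block_token_ids))


def hash_request_tokens(block_size: int,
                        token_ids: List[int]) -> List[int]:
    ret = []
    parent_block_hash = None
    for start in range(0, len(token_ids), block_size):
        block = tuple(token_ids[start:start + block_size])
        if len(block) < block_size:
            break  # Do not hash a partial trailing block.
        parent_block_hash = hash_block_tokens(parent_block_hash, block)
        ret.append(parent_block_hash)
    return ret


# Two requests sharing only their first 16-token block agree on the
# first hash and diverge on the second.
a = hash_request_tokens(16, list(range(32)))
b = hash_request_tokens(16, list(range(16)) + [99] * 16)
assert a[0] == b[0] and a[1] != b[1]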