Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[v1] Refactor KVCacheManager for more hash input than token ids #10507

Merged
merged 8 commits into from
Nov 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
225 changes: 206 additions & 19 deletions tests/v1/core/test_prefix_caching.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
"""Compare the behavior with and without prefix caching."""
import pytest

from vllm.inputs import token_inputs
from vllm.sampling_params import SamplingParams
from vllm.utils import cdiv
from vllm.v1.core.kv_cache_manager import KVCacheManager, Request
from vllm.v1.core.kv_cache_utils import hash_block_tokens
from vllm.v1.core.kv_cache_utils import KVCacheBlock, hash_block_tokens


def make_request(request_id, prompt_token_ids):
Expand Down Expand Up @@ -31,7 +34,8 @@ def test_prefill():
# Fully cache miss
# Incomplete 1 block (7 tokens)
unique_token_ids = [3] * 7
req0 = make_request("0", common_token_ids + unique_token_ids)
all_token_ids = common_token_ids + unique_token_ids
req0 = make_request("0", all_token_ids)
computed_blocks = manager.get_computed_blocks(req0)
assert not computed_blocks
blocks = manager.allocate_slots(req0, 55, computed_blocks)
Expand All @@ -40,24 +44,16 @@ def test_prefill():
# Check full block metadata
parent_block_hash = None
for block_id in (0, 1, 2):
block_hash = hash_block_tokens(parent_block_hash,
manager.block_pool[block_id].token_ids)
block_tokens = tuple(all_token_ids[block_id * 16:(block_id + 1) * 16])
block_hash = hash_block_tokens(parent_block_hash, block_tokens)
assert manager.block_pool[block_id].block_hash == block_hash
assert manager.block_pool[block_id].ref_cnt == 1
assert manager.block_pool[block_id].num_hashed_tokens == 16 * (
block_id + 1)
assert manager.block_pool[block_id].token_ids == tuple([block_id] * 16)
parent_block_hash = block_hash

# Check partial/preallocated block metadata
for block_id in (3, 4):
assert manager.block_pool[block_id].block_hash is None
assert manager.block_pool[block_id].ref_cnt == 1
assert manager.block_pool[block_id].num_hashed_tokens == 0
if block_id == 3:
assert manager.block_pool[block_id].token_ids == [3] * 7
else:
assert not manager.block_pool[block_id].token_ids

# Cache hit in the common prefix when the original block is still in use.
# Incomplete 1 block (5 tokens)
Expand Down Expand Up @@ -113,7 +109,7 @@ def test_prefill():
req3 = make_request("3", [99] * (16 * 9))
computed_blocks = manager.get_computed_blocks(req3)
assert not computed_blocks
blocks = manager.allocate_slots(req2, 16 * 9, computed_blocks)
blocks = manager.allocate_slots(req3, 16 * 9, computed_blocks)
rickyyx marked this conversation as resolved.
Show resolved Hide resolved
# This block ID order also checks the eviction order.
assert [b.block_id for b in blocks] == [9, 4, 3, 6, 5, 8, 7, 2, 1, 0]
assert manager.free_block_queue.num_free_blocks == 0
Expand Down Expand Up @@ -148,7 +144,7 @@ def test_decode():
req0.append_output_token_ids(8)
new_blocks = manager.append_slots(req0, 4)
assert new_blocks is not None and len(new_blocks) == 0
assert len(manager.block_pool[3].token_ids) == 11
assert manager.req_to_blocks[req0.request_id][-2].block_hash is None

# Append slots without allocating a new block, but start using the
# preallocated block.
Expand All @@ -159,8 +155,7 @@ def test_decode():
req0.append_output_token_ids(7)
new_blocks = manager.append_slots(req0, 15)
assert new_blocks is not None and len(new_blocks) == 0
assert len(manager.block_pool[3].token_ids) == 16
assert len(manager.block_pool[4].token_ids) == 10
assert manager.req_to_blocks[req0.request_id][-2].block_hash is not None

# Append slots with allocating a new block.
req0.num_computed_tokens = 74
Expand All @@ -171,9 +166,6 @@ def test_decode():
new_blocks = manager.append_slots(req0, 17)
# Plus one preallocated block.
assert new_blocks is not None and len(new_blocks) == 2
assert len(manager.block_pool[4].token_ids) == 16
assert len(manager.block_pool[5].token_ids) == 11
assert len(manager.block_pool[6].token_ids) == 0


def test_evict():
Expand Down Expand Up @@ -217,3 +209,198 @@ def test_evict():
blocks = manager.allocate_slots(req2, 3, computed_blocks)
assert [b.block_id for b in blocks] == [6, 5]
assert manager.free_block_queue.num_free_blocks == 6


def test_hash_block_correct_reuse():
    """
    A block that was cached and later handed out again as a fresh block
    must no longer carry its old hash metadata.
    """
    block_size = 16
    # The pool holds a single block, so the second request below can only
    # be served by recycling the block used by the first request.
    manager = KVCacheManager(
        block_size=block_size,
        num_gpu_blocks=1,
        sliding_window=False,
        enable_caching=True,
        num_preallocate_tokens=0,
    )

    # First request: exactly one full block worth of tokens.
    full_len = block_size * 1
    first_req = make_request("0", list(range(full_len)))
    hits = manager.get_computed_blocks(first_req)
    assert not hits
    allocated = manager.allocate_slots(first_req, full_len, hits)
    assert len(allocated) == 1

    # Give the block back to the manager.
    manager.free(first_req)

    # Second request: one token short of a full block, so the block it
    # receives is not full and must show no block hash.
    partial_len = full_len - 1
    second_req = make_request("1", list(range(partial_len)))
    hits = manager.get_computed_blocks(second_req)
    assert not hits
    allocated = manager.allocate_slots(second_req, partial_len, hits)
    assert len(allocated) == 1

    reused_block_id = allocated[0].block_id
    assert manager.block_pool[reused_block_id].block_hash is None


def test_computed_blocks_not_evicted():
    """
    When a request hits a cached block and needs additional blocks, the
    hit block must not be the one evicted as long as another free block
    is available.
    """
    block_size = 16
    manager = KVCacheManager(
        block_size=block_size,
        num_gpu_blocks=2,
        sliding_window=False,
        enable_caching=True,
        num_preallocate_tokens=0,
    )
    tokens_per_req = block_size * 1

    # req0 fills block 0 with tokens [0, tokens_per_req).
    req0 = make_request("0", list(range(tokens_per_req)))
    hits = manager.get_computed_blocks(req0)
    assert not hits
    allocated = manager.allocate_slots(req0, tokens_per_req, hits)
    assert len(allocated) == 1
    assert allocated[0].block_id == 0

    # req1 fills block 1 with the next tokens_per_req token ids.
    req1 = make_request(
        "1", list(range(tokens_per_req, tokens_per_req * 2)))
    hits = manager.get_computed_blocks(req1)
    assert not hits
    allocated = manager.allocate_slots(req1, tokens_per_req, hits)
    assert len(allocated) == 1
    assert allocated[0].block_id == 1

    # Release both requests, in the same order they were created.
    manager.free(req0)
    manager.free(req1)

    # req2 shares its first block with req0: block 0 must be reported as
    # a cache hit ...
    total_tokens = tokens_per_req * 2
    req2 = make_request("2", list(range(total_tokens)))
    hits = manager.get_computed_blocks(req2)
    assert len(hits) == 1
    assert hits[0].block_id == 0

    # ... and the remaining tokens must land in block 1 (i.e. block 1 is
    # the one evicted and reused, not the hit block 0).
    remaining_tokens = total_tokens - tokens_per_req
    allocated = manager.allocate_slots(req2, remaining_tokens, hits)
    assert len(allocated) == 1
    assert allocated[0].block_id == 1


def test_basic_prefix_caching_disabled():
    """
    With enable_caching=False, no request should ever see computed
    (cached) blocks, even when it shares a prefix with an earlier one.
    """
    block_size = 4
    manager = KVCacheManager(
        block_size=block_size,
        num_gpu_blocks=4,
        sliding_window=False,
        enable_caching=False,
        num_preallocate_tokens=0,
    )

    # 10 tokens at block_size 4: two full blocks plus a partial one.
    req1 = make_request("1", list(range(10)))
    hits = manager.get_computed_blocks(req1)
    assert not hits
    allocated = manager.allocate_slots(req1, 10, hits)
    assert len(allocated) == 3

    # Return req1's blocks to the pool.
    manager.free(req1)

    # req2 starts with the same token ids as req1, yet must get no hit.
    req2 = make_request("2", list(range(16)))
    hits = manager.get_computed_blocks(req2)
    assert not hits
    allocated = manager.allocate_slots(req2, 16, hits)
    assert len(allocated) == 4

    # req2 was not freed and holds all 4 blocks, so a further request
    # should be handed no blocks at all.
    req3 = make_request("3", list(range(4)))
    hits = manager.get_computed_blocks(req3)
    assert not hits
    allocated = manager.allocate_slots(req3, 4, hits)
    assert not allocated


@pytest.mark.parametrize("num_preallocate_tokens", list(range(0, 8)))
@pytest.mark.parametrize("block_size", [4])
def test_preallocate_blocks(num_preallocate_tokens: int, block_size: int):
    """
    Check that allocate_slots / append_slots return the requested block
    plus the expected number of preallocated blocks.
    """
    manager = KVCacheManager(
        block_size=block_size,
        num_gpu_blocks=10,
        sliding_window=False,
        enable_caching=True,
        num_preallocate_tokens=num_preallocate_tokens,
    )
    # Preallocation is expressed in tokens; round up to whole blocks.
    extra_blocks = cdiv(num_preallocate_tokens, block_size)
    expected_block_count = 1 + extra_blocks

    req = make_request("0", list(range(block_size * 30)))
    hits = manager.get_computed_blocks(req)
    assert not hits

    # Request a single block worth of slots.
    first_batch = manager.allocate_slots(req, block_size, hits)
    assert len(first_batch) == expected_block_count

    # Pretend every allocated block is already fully used, then ask for
    # one more block worth of slots.
    req.num_computed_tokens = block_size * len(first_batch)
    second_batch = manager.append_slots(req, block_size)
    assert len(second_batch) == expected_block_count


def test_cache_blocks():
    """
    Unit test that calls KVCacheManager._cache_full_blocks directly and
    checks that the passed blocks are hashed and registered.
    """
    block_size = 4
    manager = KVCacheManager(
        block_size=block_size,
        num_gpu_blocks=5,
        sliding_window=False,
        enable_caching=True,
        num_preallocate_tokens=0,
    )
    # Token layout of the request (14 tokens, block_size 4):
    #   Block 0: [0, 1, 2, 3]
    #   Block 1: [4, 5, 6, 7]
    #   Block 2: [8, 9, 10, 11]
    #   Block 3: [12, 13]          (partial)
    req = make_request("0", list(range(14)))

    # Case 1: cache the first two full blocks, starting at index 0.
    leading_blocks = [KVCacheBlock(block_id=idx) for idx in range(2)]
    manager._cache_full_blocks(
        request=req,
        blk_start_idx=0,
        full_blocks=leading_blocks,
        prev_block=None,
    )

    assert len(manager.cached_block_hash_to_block) == 2
    assert all(blk.block_hash is not None for blk in leading_blocks)

    # Case 2: cache a block that does not start at the beginning.
    # NOTE(review): prev_block is None although blk_start_idx is 2 --
    # confirm this is the intended way to call the helper.
    later_blocks = [KVCacheBlock(block_id=2)]
    manager._cache_full_blocks(
        request=req,
        blk_start_idx=2,
        full_blocks=later_blocks,
        prev_block=None,
    )
    assert len(manager.cached_block_hash_to_block) == 3
    assert later_blocks[0].block_hash is not None
Loading