Skip to content

Commit

Permalink
tests
Browse files Browse the repository at this point in the history
Signed-off-by: rickyx <[email protected]>
  • Loading branch information
rickyyx committed Nov 22, 2024
1 parent 4de5815 commit c2bfbd5
Showing 1 changed file with 36 additions and 2 deletions.
38 changes: 36 additions & 2 deletions tests/v1/core/test_prefix_caching.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
"""Compare the with and without prefix caching."""
import pytest

from vllm.inputs import token_inputs
from vllm.sampling_params import SamplingParams
from vllm.utils import cdiv
from vllm.v1.core.kv_cache_manager import KVCacheManager, Request
from vllm.v1.core.kv_cache_utils import KVCacheBlock, hash_block_tokens

Expand Down Expand Up @@ -246,6 +249,10 @@ def test_hash_block_correct_reuse():


def test_computed_blocks_not_evicted():
"""
Test that the computed blocks are not evicted when getting new blocks
for a request if there are any other free blocks.
"""
block_size = 16
manager = KVCacheManager(
block_size=block_size,
Expand Down Expand Up @@ -290,6 +297,9 @@ def test_computed_blocks_not_evicted():


def test_basic_prefix_caching_disabled():
"""
This tests that the prefix caching is disabled.
"""
block_size = 4
manager = KVCacheManager(
block_size=block_size,
Expand Down Expand Up @@ -324,8 +334,32 @@ def test_basic_prefix_caching_disabled():
assert not blocks


def test_preallocate_blocks():
pass
@pytest.mark.parametrize("num_preallocate_tokens", list(range(8)))
@pytest.mark.parametrize("block_size", [4])
def test_preallocate_blocks(num_preallocate_tokens: int, block_size: int):
    """
    Check that allocations include the expected number of preallocated blocks.

    Both the initial ``allocate_slots`` call and the follow-up
    ``append_slots`` call ask for one block's worth of tokens; each is
    expected to return that one block plus
    ``cdiv(num_preallocate_tokens, block_size)`` additional blocks.
    """
    # One requested block plus however many blocks the preallocated tokens
    # round up to.
    expected_num_blocks = 1 + cdiv(num_preallocate_tokens, block_size)

    manager = KVCacheManager(
        block_size=block_size,
        num_gpu_blocks=10,
        sliding_window=False,
        enable_caching=True,
        num_preallocate_tokens=num_preallocate_tokens,
    )

    req = make_request("0", list(range(block_size * 30)))
    computed_blocks = manager.get_computed_blocks(req)
    assert not computed_blocks

    # Request exactly one block's worth of tokens.
    allocated = manager.allocate_slots(req, block_size, computed_blocks)
    assert len(allocated) == expected_num_blocks

    # Mark every allocated block as fully used, then ask for one more block.
    req.num_computed_tokens = block_size * len(allocated)
    appended = manager.append_slots(req, block_size)
    assert len(appended) == expected_num_blocks


def test_cache_blocks():
Expand Down

0 comments on commit c2bfbd5

Please sign in to comment.