[Core] Add support for multimodal models + prefix caching
Signed-off-by: Peter Salas <[email protected]>
petersalas committed Nov 8, 2024
1 parent 10b67d8 commit 342d3d0
Showing 29 changed files with 802 additions and 208 deletions.
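Throughout the test diffs below, plain List[int] token sequences and the chunk_list helper are replaced by a TokenIds object imported from vllm.core.block.token_ids, which this commit adds elsewhere and which is not shown on this page. As a reading aid, here is a minimal sketch of the interface the tests appear to rely on; the name TokenIdsSketch and the mm_content field are hypothetical stand-ins inferred only from how the tests construct, concatenate, slice, compare, and chunk these objects, not the real implementation.

# Hedged sketch only: the real TokenIds class lives in
# vllm/core/block/token_ids.py and is not part of the diffs shown here.
# The second constructor argument (written as "()" in the tests) presumably
# carries multimodal content; it is kept opaque in this sketch.
from typing import Iterable, Iterator, Tuple


class TokenIdsSketch:

    def __init__(self,
                 token_ids: Iterable[int] = (),
                 mm_content: Tuple = ()) -> None:
        self._token_ids: Tuple[int, ...] = tuple(token_ids)
        self._mm_content = tuple(mm_content)

    def __len__(self) -> int:
        return len(self._token_ids)

    def __eq__(self, other: object) -> bool:
        return (isinstance(other, TokenIdsSketch)
                and self._token_ids == other._token_ids
                and self._mm_content == other._mm_content)

    def __add__(self, other: "TokenIdsSketch") -> "TokenIdsSketch":
        # Concatenation, as in "token_ids + token_ids_to_append" below.
        return TokenIdsSketch(self._token_ids + other._token_ids,
                              self._mm_content + other._mm_content)

    def __getitem__(self, index: slice) -> "TokenIdsSketch":
        # Slicing by token position, as in "unique_token_ids[:n]" below.
        return TokenIdsSketch(self._token_ids[index])

    def to_chunks(self, chunk_size: int) -> Iterator["TokenIdsSketch"]:
        # Yield chunk_size-sized pieces; the final chunk may be shorter.
        for start in range(0, len(self._token_ids), chunk_size):
            yield TokenIdsSketch(self._token_ids[start:start + chunk_size])

Carrying multimodal content alongside the token ids is presumably what lets prefix caching hash full blocks of multimodal prompts correctly, which matches the commit title.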
15 changes: 6 additions & 9 deletions tests/core/block/test_block_manager.py
@@ -1,11 +1,11 @@
import pytest

+from vllm.core.block.token_ids import TokenIds
from vllm.core.block.utils import (STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE,
STR_NOT_IMPL_ENC_DEC_SWA)
from vllm.core.block_manager import SelfAttnBlockSpaceManager
from vllm.core.interfaces import AllocStatus
from vllm.sequence import Logprob, SequenceStatus
-from vllm.utils import chunk_list

from ..utils import (create_dummy_prompt, create_seq_group,
create_seq_group_encoder_decoder)
@@ -248,14 +248,11 @@ def test_append_slots(block_size, prompt_len, num_slots_to_append,
block_manager.get_num_free_gpu_blocks())

# Expect consumed blocks to be new blocks required to support the new slots.
-expected_consumed_blocks = len(
-list(
-chunk_list(
-list(
-range(prompt_len + num_slots_to_append +
-num_lookahead_slots)),
-block_size))) - len(
-list(chunk_list(list(range(prompt_len)), block_size)))
+required_blocks = list(
+TokenIds(range(prompt_len + num_slots_to_append + num_lookahead_slots),
+()).to_chunks(block_size))
+existing_blocks = list(TokenIds(range(prompt_len)).to_chunks(block_size))
+expected_consumed_blocks = len(required_blocks) - len(existing_blocks)
assert num_consumed_blocks == expected_consumed_blocks


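A quick sanity check of the expected_consumed_blocks arithmetic in the hunk above, with illustrative values rather than the test's actual parametrization: for plain token ids, to_chunks(block_size) produces ceil(n / block_size) chunks, so the expected count is a difference of two ceilings.

# Illustrative values only; the test parametrizes these differently.
import math

block_size = 16
prompt_len = 20
num_slots_to_append = 10
num_lookahead_slots = 4

required_blocks = math.ceil(
    (prompt_len + num_slots_to_append + num_lookahead_slots) / block_size)
existing_blocks = math.ceil(prompt_len / block_size)
expected_consumed_blocks = required_blocks - existing_blocks

# ceil(34 / 16) == 3 blocks needed, ceil(20 / 16) == 2 already allocated,
# so appending should consume exactly 1 new block.
assert (required_blocks, existing_blocks, expected_consumed_blocks) == (3, 2, 1)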
77 changes: 43 additions & 34 deletions tests/core/block/test_block_table.py
@@ -4,7 +4,15 @@

from vllm.core.block.block_table import BlockTable
from vllm.core.block.cpu_gpu_block_allocator import CpuGpuBlockAllocator
-from vllm.utils import Device, cdiv, chunk_list
+from vllm.core.block.token_ids import TokenIds
+from vllm.utils import Device, cdiv


+def _get_all_token_ids(block_table: BlockTable) -> TokenIds:
+token_ids = TokenIds()
+for block in block_table.blocks:
+token_ids += block.token_ids
+return token_ids


@pytest.mark.parametrize("block_size", [16])
@@ -27,8 +35,8 @@ def test_allocate_naive(block_size: int, sequence_len: int):
block_size=block_size,
)

-token_ids = list(range(sequence_len))
-num_blocks_per_alloc = len(list(chunk_list(token_ids, block_size)))
+token_ids = TokenIds(range(sequence_len))
+num_blocks_per_alloc = len(list(token_ids.to_chunks(block_size)))

block_tables: List[BlockTable] = []
for i in range(5):
@@ -68,8 +76,8 @@ def test_allocate_prefix_caching(block_size: int, sequence_len: int):
block_size=block_size,
)

-token_ids = list(range(sequence_len))
-chunked_tokens = list(chunk_list(token_ids, block_size))
+token_ids = TokenIds(range(sequence_len))
+chunked_tokens = list(token_ids.to_chunks(block_size))
num_mutable_blocks_per_alloc = 0 if len(
chunked_tokens[-1]) == block_size else 1
num_immutable_blocks_per_alloc = len(
@@ -117,8 +125,8 @@ def test_allocate_free(block_size: int, sequence_len: int, allocator_type: str,
block_size=block_size,
)

-token_ids = list(range(sequence_len))
-num_blocks_per_alloc = len(list(chunk_list(token_ids, block_size)))
+token_ids = TokenIds(range(sequence_len))
+num_blocks_per_alloc = len(list(token_ids.to_chunks(block_size)))

block_table = BlockTable(
block_size=block_size,
@@ -160,19 +168,19 @@ def test_append_token_ids_allocation(block_size: int, sequence_len: int,
block_size=block_size,
)

-token_ids = list(range(sequence_len))
-token_ids_to_append = list(range(append_len))
+token_ids = TokenIds(range(sequence_len))
+token_ids_to_append = TokenIds(range(append_len))

block_table = BlockTable(
block_size=block_size,
block_allocator=allocator,
)

num_expected_blocks_before_append = len(
-list(chunk_list(token_ids, block_size)))
+list(token_ids.to_chunks(block_size)))
num_expected_appended_blocks = len(
-list(chunk_list(token_ids + token_ids_to_append,
-block_size))) - num_expected_blocks_before_append
+list((token_ids + token_ids_to_append
+).to_chunks(block_size))) - num_expected_blocks_before_append

block_table.allocate(token_ids=token_ids, device=Device.GPU)

@@ -210,18 +218,19 @@ def test_ensure_num_empty_slots_allocation(block_size: int, sequence_len: int,
block_size=block_size,
)

-token_ids = list(range(sequence_len))
+token_ids = TokenIds(range(sequence_len))
+empty_slots = TokenIds([-1] * num_empty_slots)

block_table = BlockTable(
block_size=block_size,
block_allocator=allocator,
)

num_expected_blocks_before_append = len(
-list(chunk_list(token_ids, block_size)))
+list(token_ids.to_chunks(block_size)))
num_expected_appended_blocks = len(
-list(chunk_list(token_ids + [-1] * num_empty_slots,
-block_size))) - num_expected_blocks_before_append
+list((token_ids + empty_slots
+).to_chunks(block_size))) - num_expected_blocks_before_append

block_table.allocate(token_ids=token_ids, device=Device.GPU)

@@ -236,7 +245,7 @@ def test_ensure_num_empty_slots_allocation(block_size: int, sequence_len: int,

# Now, ensure no additional blocks consumed as we fill up the empty slots.
num_free_blocks = allocator.get_num_free_blocks(device=Device.GPU)
-block_table.append_token_ids(token_ids=list(range(num_empty_slots)))
+block_table.append_token_ids(token_ids=TokenIds(range(num_empty_slots)))
assert num_free_blocks == allocator.get_num_free_blocks(device=Device.GPU)


@@ -261,23 +270,23 @@ def test_append_token_ids_correct_content(block_size: int, sequence_len: int,
block_size=block_size,
)

-token_ids = list(range(sequence_len))
-token_ids_to_append = list(range(append_len))
+token_ids = TokenIds(range(sequence_len))
+token_ids_to_append = TokenIds(range(append_len))

block_table = BlockTable(
block_size=block_size,
block_allocator=allocator,
)
block_table.allocate(token_ids=token_ids, device=Device.GPU)

-appended_so_far: List[int] = []
-for append in chunk_list(token_ids_to_append, append_size):
+appended_so_far: TokenIds = TokenIds()
+for append in token_ids_to_append.to_chunks(append_size):
block_table.append_token_ids(append)
-appended_so_far.extend(append)
+appended_so_far += append

-assert block_table._get_all_token_ids() == token_ids + appended_so_far
+assert _get_all_token_ids(block_table) == token_ids + appended_so_far

-assert block_table._get_all_token_ids() == token_ids + token_ids_to_append
+assert _get_all_token_ids(block_table) == token_ids + token_ids_to_append


@pytest.mark.parametrize("seq_len", [1, 9, 129])
@@ -302,7 +311,7 @@ def test_fork(seq_len: int, block_size: int, allocator_type: str):
block_size=block_size,
)

-token_ids = list(range(seq_len))
+token_ids = TokenIds(range(seq_len))

block_table = BlockTable(
block_size=block_size,
@@ -319,8 +328,8 @@ def test_fork(seq_len: int, block_size: int, allocator_type: str):
# Expect physical_block_ids and token_ids to match.
assert (block_table.physical_block_ids ==
forked_block_table.physical_block_ids)
-assert block_table._get_all_token_ids(
-) == forked_block_table._get_all_token_ids()
+assert _get_all_token_ids(block_table) == _get_all_token_ids(
+forked_block_table)

# Do not expect any additional allocations.
assert allocator.get_num_free_blocks(
@@ -360,8 +369,8 @@ def test_cow(block_size: int, sequence_len: int, append_len: int,
block_size=block_size,
)

-token_ids = list(range(sequence_len))
-token_ids_to_append = list(range(append_len))
+token_ids = TokenIds(range(sequence_len))
+token_ids_to_append = TokenIds(range(append_len))

original_block_table = BlockTable(
block_size=block_size,
@@ -446,8 +455,8 @@ def test_cow_lookahead_simple(block_size: int, sequence_len: int,
block_size=block_size,
)

-token_ids = list(range(sequence_len))
-token_ids_to_append = list(range(append_len))
+token_ids = TokenIds(range(sequence_len))
+token_ids_to_append = TokenIds(range(append_len))

original_block_table = BlockTable(
block_size=block_size,
@@ -528,8 +537,8 @@ def test_num_blocks_touched_by_append_slots(block_size: int, sequence_len: int,
block_size=block_size,
)

-token_ids = list(range(sequence_len))
-token_ids_to_append = list(range(num_new_tokens))
+token_ids = TokenIds(range(sequence_len))
+token_ids_to_append = TokenIds(range(num_new_tokens))

block_table = BlockTable(
block_size=block_size,
@@ -548,7 +557,7 @@ def test_num_blocks_touched_by_append_slots(block_size: int, sequence_len: int,
# Determine how many blocks should be touched.
expected_num_touched_blocks = (
block_table.get_num_blocks_touched_by_append_slots(
-token_ids=token_ids_to_append,
+num_token_ids=len(token_ids_to_append),
num_lookahead_slots=num_lookahead_slots))

# Measure how many blocks are touched by measuring num_free_blocks before
15 changes: 8 additions & 7 deletions tests/core/block/test_cpu_gpu_block_allocator.py
@@ -1,7 +1,8 @@
import pytest

from vllm.core.block.cpu_gpu_block_allocator import CpuGpuBlockAllocator
-from vllm.utils import Device, chunk_list
+from vllm.core.block.token_ids import TokenIds
+from vllm.utils import Device


@pytest.mark.parametrize("num_cpu_blocks", [0, 512])
@@ -56,12 +57,12 @@ def test_allocate_immutable_block(num_cpu_blocks: int, num_gpu_blocks: int,
block_size=block_size,
)

-unique_token_ids = list(
-range((num_cpu_blocks + num_gpu_blocks) * block_size))
-gpu_token_ids = list(
-chunk_list(unique_token_ids[:num_gpu_blocks * block_size], block_size))
-cpu_token_ids = list(
-chunk_list(unique_token_ids[num_gpu_blocks * block_size:], block_size))
+unique_token_ids = TokenIds(
+range((num_cpu_blocks + num_gpu_blocks) * block_size), ())
+gpu_token_ids = list(unique_token_ids[:num_gpu_blocks *
+block_size].to_chunks(block_size))
+cpu_token_ids = list(unique_token_ids[num_gpu_blocks *
+block_size:].to_chunks(block_size))

assert allocator.get_num_free_blocks(Device.CPU) == num_cpu_blocks
assert allocator.get_num_free_blocks(Device.GPU) == num_gpu_blocks
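Reusing the hypothetical TokenIdsSketch from the note near the top of this page, the slicing-plus-chunking pattern in the hunk above splits one contiguous range of unique token ids into per-device block contents; a small illustration with made-up sizes:

# Made-up sizes; TokenIdsSketch is the hypothetical stand-in defined above,
# not the real vllm.core.block.token_ids.TokenIds.
block_size = 4
num_gpu_blocks, num_cpu_blocks = 2, 1

unique_token_ids = TokenIdsSketch(
    range((num_cpu_blocks + num_gpu_blocks) * block_size))
gpu_token_ids = list(
    unique_token_ids[:num_gpu_blocks * block_size].to_chunks(block_size))
cpu_token_ids = list(
    unique_token_ids[num_gpu_blocks * block_size:].to_chunks(block_size))

assert len(gpu_token_ids) == num_gpu_blocks  # 2 full GPU-sized chunks
assert len(cpu_token_ids) == num_cpu_blocks  # 1 full CPU-sized chunk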
17 changes: 9 additions & 8 deletions tests/core/block/test_naive_block.py
@@ -1,9 +1,10 @@
-from typing import List, Optional
+from typing import Optional

import pytest

from vllm.core.block.interfaces import Block, BlockAllocator
from vllm.core.block.naive_block import NaiveBlock, NaiveBlockAllocator
+from vllm.core.block.token_ids import TokenIds


class TestNaiveBlockAllocator:
@@ -12,7 +13,7 @@ class TestNaiveBlockAllocator:
def create_allocate_lambda(allocate_type: str,
allocator: NaiveBlockAllocator,
prev_block: Optional[Block],
-token_ids: List[int]):
+token_ids: TokenIds):
if allocate_type == "immutable":
allocate_block = lambda: allocator.allocate_immutable_block(
prev_block=prev_block, token_ids=token_ids)
@@ -37,7 +38,7 @@ def test_allocate_ooms(allocate_type: str, num_blocks: int,
allocate_type,
allocator,
prev_block=None,
-token_ids=list(range(block_size)))
+token_ids=TokenIds(range(block_size)))

[allocate_block() for _ in range(num_blocks)]
with pytest.raises(BlockAllocator.NoFreeBlocksError):
@@ -56,7 +57,7 @@ def test_free_prevents_oom(allocate_type: str, num_blocks: int,
allocate_type,
allocator,
prev_block=None,
-token_ids=list(range(block_size)))
+token_ids=TokenIds(range(block_size)))

blocks = [allocate_block() for _ in range(num_blocks)]

@@ -91,7 +92,7 @@ def test_get_num_free_blocks(allocate_type: str, num_blocks: int,
allocate_type,
allocator,
prev_block=None,
-token_ids=list(range(block_size)))
+token_ids=TokenIds(range(block_size)))

assert allocator.get_num_free_blocks() == num_blocks

@@ -120,7 +121,7 @@ def test_naive_block_get_num_full_blocks_touched(num_blocks, block_size):
"immutable",
allocator_src,
prev_block=None,
-token_ids=list(range(block_size)))
+token_ids=TokenIds(range(block_size)))
src_blocks = [allocate_block() for _ in range(num_blocks - 1)]

# All blocks are cached
@@ -134,12 +135,12 @@ def test_naive_block_get_num_full_blocks_touched(num_blocks, block_size):
prev_block=src_blocks[-1], token_ids=[]
)
src_blocks.append(allocate_non_full_block())
-src_blocks[-1].append_token_ids([0])
+src_blocks[-1].append_token_ids(TokenIds([0]))

assert allocator_dst.get_num_full_blocks_touched(
src_blocks) == num_blocks - 1
# Fill up the last source block and then invoke
# get_num_blocks_touched
-src_blocks[-1].append_token_ids([0] * (block_size - 1))
+src_blocks[-1].append_token_ids(TokenIds([0] * (block_size - 1)))
assert allocator_dst.get_num_full_blocks_touched(
src_blocks) == num_blocks
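The last hunk above suggests that get_num_full_blocks_touched only counts blocks holding a full block_size worth of tokens, which is why the single-token block is excluded until it is topped up. A tiny illustration of that counting rule, inferred from the assertions rather than taken from vLLM's implementation:

# Counting rule inferred from the assertions above, not vLLM's code.
def count_full_blocks(tokens_per_block, block_size):
    return sum(1 for n in tokens_per_block if n == block_size)

# Mirrors the test with block_size == 16 and three source blocks: two full
# blocks plus one holding a single token, then the last block topped up.
assert count_full_blocks([16, 16, 1], block_size=16) == 2
assert count_full_blocks([16, 16, 16], block_size=16) == 3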
(The remaining changed files in this commit are not shown here.)
