diff --git a/tests/core/block/test_block_manager.py b/tests/core/block/test_block_manager.py index cfd749ad58694..11de444407f5a 100644 --- a/tests/core/block/test_block_manager.py +++ b/tests/core/block/test_block_manager.py @@ -1,11 +1,11 @@ import pytest +from vllm.core.block.token_ids import TokenIds from vllm.core.block.utils import (STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE, STR_NOT_IMPL_ENC_DEC_SWA) from vllm.core.block_manager import SelfAttnBlockSpaceManager from vllm.core.interfaces import AllocStatus from vllm.sequence import Logprob, SequenceStatus -from vllm.utils import chunk_list from ..utils import (create_dummy_prompt, create_seq_group, create_seq_group_encoder_decoder) @@ -248,14 +248,11 @@ def test_append_slots(block_size, prompt_len, num_slots_to_append, block_manager.get_num_free_gpu_blocks()) # Expect consumed blocks to be new blocks required to support the new slots. - expected_consumed_blocks = len( - list( - chunk_list( - list( - range(prompt_len + num_slots_to_append + - num_lookahead_slots)), - block_size))) - len( - list(chunk_list(list(range(prompt_len)), block_size))) + required_blocks = list( + TokenIds(range(prompt_len + num_slots_to_append + num_lookahead_slots), + ()).to_chunks(block_size)) + existing_blocks = list(TokenIds(range(prompt_len)).to_chunks(block_size)) + expected_consumed_blocks = len(required_blocks) - len(existing_blocks) assert num_consumed_blocks == expected_consumed_blocks diff --git a/tests/core/block/test_block_table.py b/tests/core/block/test_block_table.py index e2391a5680b36..bf8ce4ac1b8a8 100644 --- a/tests/core/block/test_block_table.py +++ b/tests/core/block/test_block_table.py @@ -4,7 +4,15 @@ from vllm.core.block.block_table import BlockTable from vllm.core.block.cpu_gpu_block_allocator import CpuGpuBlockAllocator -from vllm.utils import Device, cdiv, chunk_list +from vllm.core.block.token_ids import TokenIds +from vllm.utils import Device, cdiv + + +def _get_all_token_ids(block_table: BlockTable) -> TokenIds: + token_ids = TokenIds() + for block in block_table.blocks: + token_ids += block.token_ids + return token_ids @pytest.mark.parametrize("block_size", [16]) @@ -27,8 +35,8 @@ def test_allocate_naive(block_size: int, sequence_len: int): block_size=block_size, ) - token_ids = list(range(sequence_len)) - num_blocks_per_alloc = len(list(chunk_list(token_ids, block_size))) + token_ids = TokenIds(range(sequence_len)) + num_blocks_per_alloc = len(list(token_ids.to_chunks(block_size))) block_tables: List[BlockTable] = [] for i in range(5): @@ -68,8 +76,8 @@ def test_allocate_prefix_caching(block_size: int, sequence_len: int): block_size=block_size, ) - token_ids = list(range(sequence_len)) - chunked_tokens = list(chunk_list(token_ids, block_size)) + token_ids = TokenIds(range(sequence_len)) + chunked_tokens = list(token_ids.to_chunks(block_size)) num_mutable_blocks_per_alloc = 0 if len( chunked_tokens[-1]) == block_size else 1 num_immutable_blocks_per_alloc = len( @@ -117,8 +125,8 @@ def test_allocate_free(block_size: int, sequence_len: int, allocator_type: str, block_size=block_size, ) - token_ids = list(range(sequence_len)) - num_blocks_per_alloc = len(list(chunk_list(token_ids, block_size))) + token_ids = TokenIds(range(sequence_len)) + num_blocks_per_alloc = len(list(token_ids.to_chunks(block_size))) block_table = BlockTable( block_size=block_size, @@ -160,8 +168,8 @@ def test_append_token_ids_allocation(block_size: int, sequence_len: int, block_size=block_size, ) - token_ids = list(range(sequence_len)) - token_ids_to_append = 
list(range(append_len)) + token_ids = TokenIds(range(sequence_len)) + token_ids_to_append = TokenIds(range(append_len)) block_table = BlockTable( block_size=block_size, @@ -169,10 +177,10 @@ def test_append_token_ids_allocation(block_size: int, sequence_len: int, ) num_expected_blocks_before_append = len( - list(chunk_list(token_ids, block_size))) + list(token_ids.to_chunks(block_size))) num_expected_appended_blocks = len( - list(chunk_list(token_ids + token_ids_to_append, - block_size))) - num_expected_blocks_before_append + list((token_ids + token_ids_to_append + ).to_chunks(block_size))) - num_expected_blocks_before_append block_table.allocate(token_ids=token_ids, device=Device.GPU) @@ -210,7 +218,8 @@ def test_ensure_num_empty_slots_allocation(block_size: int, sequence_len: int, block_size=block_size, ) - token_ids = list(range(sequence_len)) + token_ids = TokenIds(range(sequence_len)) + empty_slots = TokenIds([-1] * num_empty_slots) block_table = BlockTable( block_size=block_size, @@ -218,10 +227,10 @@ def test_ensure_num_empty_slots_allocation(block_size: int, sequence_len: int, ) num_expected_blocks_before_append = len( - list(chunk_list(token_ids, block_size))) + list(token_ids.to_chunks(block_size))) num_expected_appended_blocks = len( - list(chunk_list(token_ids + [-1] * num_empty_slots, - block_size))) - num_expected_blocks_before_append + list((token_ids + empty_slots + ).to_chunks(block_size))) - num_expected_blocks_before_append block_table.allocate(token_ids=token_ids, device=Device.GPU) @@ -236,7 +245,7 @@ def test_ensure_num_empty_slots_allocation(block_size: int, sequence_len: int, # Now, ensure no additional blocks consumed as we fill up the empty slots. num_free_blocks = allocator.get_num_free_blocks(device=Device.GPU) - block_table.append_token_ids(token_ids=list(range(num_empty_slots))) + block_table.append_token_ids(token_ids=TokenIds(range(num_empty_slots))) assert num_free_blocks == allocator.get_num_free_blocks(device=Device.GPU) @@ -261,8 +270,8 @@ def test_append_token_ids_correct_content(block_size: int, sequence_len: int, block_size=block_size, ) - token_ids = list(range(sequence_len)) - token_ids_to_append = list(range(append_len)) + token_ids = TokenIds(range(sequence_len)) + token_ids_to_append = TokenIds(range(append_len)) block_table = BlockTable( block_size=block_size, @@ -270,14 +279,14 @@ def test_append_token_ids_correct_content(block_size: int, sequence_len: int, ) block_table.allocate(token_ids=token_ids, device=Device.GPU) - appended_so_far: List[int] = [] - for append in chunk_list(token_ids_to_append, append_size): + appended_so_far: TokenIds = TokenIds() + for append in token_ids_to_append.to_chunks(append_size): block_table.append_token_ids(append) - appended_so_far.extend(append) + appended_so_far += append - assert block_table._get_all_token_ids() == token_ids + appended_so_far + assert _get_all_token_ids(block_table) == token_ids + appended_so_far - assert block_table._get_all_token_ids() == token_ids + token_ids_to_append + assert _get_all_token_ids(block_table) == token_ids + token_ids_to_append @pytest.mark.parametrize("seq_len", [1, 9, 129]) @@ -302,7 +311,7 @@ def test_fork(seq_len: int, block_size: int, allocator_type: str): block_size=block_size, ) - token_ids = list(range(seq_len)) + token_ids = TokenIds(range(seq_len)) block_table = BlockTable( block_size=block_size, @@ -319,8 +328,8 @@ def test_fork(seq_len: int, block_size: int, allocator_type: str): # Expect physical_block_ids and token_ids to match. 
assert (block_table.physical_block_ids == forked_block_table.physical_block_ids) - assert block_table._get_all_token_ids( - ) == forked_block_table._get_all_token_ids() + assert _get_all_token_ids(block_table) == _get_all_token_ids( + forked_block_table) # Do not expect any additional allocations. assert allocator.get_num_free_blocks( @@ -360,8 +369,8 @@ def test_cow(block_size: int, sequence_len: int, append_len: int, block_size=block_size, ) - token_ids = list(range(sequence_len)) - token_ids_to_append = list(range(append_len)) + token_ids = TokenIds(range(sequence_len)) + token_ids_to_append = TokenIds(range(append_len)) original_block_table = BlockTable( block_size=block_size, @@ -446,8 +455,8 @@ def test_cow_lookahead_simple(block_size: int, sequence_len: int, block_size=block_size, ) - token_ids = list(range(sequence_len)) - token_ids_to_append = list(range(append_len)) + token_ids = TokenIds(range(sequence_len)) + token_ids_to_append = TokenIds(range(append_len)) original_block_table = BlockTable( block_size=block_size, @@ -528,8 +537,8 @@ def test_num_blocks_touched_by_append_slots(block_size: int, sequence_len: int, block_size=block_size, ) - token_ids = list(range(sequence_len)) - token_ids_to_append = list(range(num_new_tokens)) + token_ids = TokenIds(range(sequence_len)) + token_ids_to_append = TokenIds(range(num_new_tokens)) block_table = BlockTable( block_size=block_size, @@ -548,7 +557,7 @@ def test_num_blocks_touched_by_append_slots(block_size: int, sequence_len: int, # Determine how many blocks should be touched. expected_num_touched_blocks = ( block_table.get_num_blocks_touched_by_append_slots( - token_ids=token_ids_to_append, + num_token_ids=len(token_ids_to_append), num_lookahead_slots=num_lookahead_slots)) # Measure how many blocks are touched by measuring num_free_blocks before diff --git a/tests/core/block/test_cpu_gpu_block_allocator.py b/tests/core/block/test_cpu_gpu_block_allocator.py index a9e38d40444a9..ee3c65594c15b 100644 --- a/tests/core/block/test_cpu_gpu_block_allocator.py +++ b/tests/core/block/test_cpu_gpu_block_allocator.py @@ -1,7 +1,8 @@ import pytest from vllm.core.block.cpu_gpu_block_allocator import CpuGpuBlockAllocator -from vllm.utils import Device, chunk_list +from vllm.core.block.token_ids import TokenIds +from vllm.utils import Device @pytest.mark.parametrize("num_cpu_blocks", [0, 512]) @@ -56,12 +57,12 @@ def test_allocate_immutable_block(num_cpu_blocks: int, num_gpu_blocks: int, block_size=block_size, ) - unique_token_ids = list( - range((num_cpu_blocks + num_gpu_blocks) * block_size)) - gpu_token_ids = list( - chunk_list(unique_token_ids[:num_gpu_blocks * block_size], block_size)) - cpu_token_ids = list( - chunk_list(unique_token_ids[num_gpu_blocks * block_size:], block_size)) + unique_token_ids = TokenIds( + range((num_cpu_blocks + num_gpu_blocks) * block_size), ()) + gpu_token_ids = list(unique_token_ids[:num_gpu_blocks * + block_size].to_chunks(block_size)) + cpu_token_ids = list(unique_token_ids[num_gpu_blocks * + block_size:].to_chunks(block_size)) assert allocator.get_num_free_blocks(Device.CPU) == num_cpu_blocks assert allocator.get_num_free_blocks(Device.GPU) == num_gpu_blocks diff --git a/tests/core/block/test_naive_block.py b/tests/core/block/test_naive_block.py index 10d5964dcfe8a..ab56a35aeda24 100644 --- a/tests/core/block/test_naive_block.py +++ b/tests/core/block/test_naive_block.py @@ -1,9 +1,10 @@ -from typing import List, Optional +from typing import Optional import pytest from vllm.core.block.interfaces import Block, 
BlockAllocator from vllm.core.block.naive_block import NaiveBlock, NaiveBlockAllocator +from vllm.core.block.token_ids import TokenIds class TestNaiveBlockAllocator: @@ -12,7 +13,7 @@ class TestNaiveBlockAllocator: def create_allocate_lambda(allocate_type: str, allocator: NaiveBlockAllocator, prev_block: Optional[Block], - token_ids: List[int]): + token_ids: TokenIds): if allocate_type == "immutable": allocate_block = lambda: allocator.allocate_immutable_block( prev_block=prev_block, token_ids=token_ids) @@ -37,7 +38,7 @@ def test_allocate_ooms(allocate_type: str, num_blocks: int, allocate_type, allocator, prev_block=None, - token_ids=list(range(block_size))) + token_ids=TokenIds(range(block_size))) [allocate_block() for _ in range(num_blocks)] with pytest.raises(BlockAllocator.NoFreeBlocksError): @@ -56,7 +57,7 @@ def test_free_prevents_oom(allocate_type: str, num_blocks: int, allocate_type, allocator, prev_block=None, - token_ids=list(range(block_size))) + token_ids=TokenIds(range(block_size))) blocks = [allocate_block() for _ in range(num_blocks)] @@ -91,7 +92,7 @@ def test_get_num_free_blocks(allocate_type: str, num_blocks: int, allocate_type, allocator, prev_block=None, - token_ids=list(range(block_size))) + token_ids=TokenIds(range(block_size))) assert allocator.get_num_free_blocks() == num_blocks @@ -120,7 +121,7 @@ def test_naive_block_get_num_full_blocks_touched(num_blocks, block_size): "immutable", allocator_src, prev_block=None, - token_ids=list(range(block_size))) + token_ids=TokenIds(range(block_size))) src_blocks = [allocate_block() for _ in range(num_blocks - 1)] # All blocks are cached @@ -134,12 +135,12 @@ def test_naive_block_get_num_full_blocks_touched(num_blocks, block_size): prev_block=src_blocks[-1],token_ids=[] ) src_blocks.append(allocate_non_full_block()) - src_blocks[-1].append_token_ids([0]) + src_blocks[-1].append_token_ids(TokenIds([0])) assert allocator_dst.get_num_full_blocks_touched( src_blocks) == num_blocks - 1 # Fill up the last source block and then invoke # get_num_blocks_touched - src_blocks[-1].append_token_ids([0] * (block_size - 1)) + src_blocks[-1].append_token_ids(TokenIds([0] * (block_size - 1))) assert allocator_dst.get_num_full_blocks_touched( src_blocks) == num_blocks diff --git a/tests/core/block/test_prefix_caching_block.py b/tests/core/block/test_prefix_caching_block.py index d325b9606843e..d73128b9e8ca0 100644 --- a/tests/core/block/test_prefix_caching_block.py +++ b/tests/core/block/test_prefix_caching_block.py @@ -8,6 +8,7 @@ from vllm.core.block.interfaces import Block, BlockAllocator from vllm.core.block.prefix_caching_block import (PrefixCachingBlock, PrefixCachingBlockAllocator) +from vllm.core.block.token_ids import TokenIds class TestPrefixCachingBlock: @@ -23,7 +24,7 @@ def test_first_block_has_correct_content_hash(seed: int, block_size: int, random.seed(seed) num_to_fill = block_size if is_curr_block_full else random.randint( 0, block_size - 1) - token_ids = list(range(num_to_fill)) + token_ids = TokenIds(range(num_to_fill)) mock_allocator = MagicMock(spec=PrefixCachingBlockAllocator) block_with_prev = PrefixCachingBlock(prev_block=None, @@ -63,7 +64,7 @@ def test_nth_block_has_correct_content_hash(seed: int, block_size: int, num_to_fill = block_size if is_curr_block_full else random.randint( 0, block_size - 1) - token_ids = list(range(num_to_fill)) + token_ids = TokenIds(range(num_to_fill)) mock_allocator = MagicMock(spec=PrefixCachingBlockAllocator) block_with_prev = PrefixCachingBlock( @@ -97,7 +98,8 @@ def 
test_blocks_have_correct_hash_in_chain(block_size: int, """ random.seed(0) - token_ids = [random.randint(0, 50_000) for _ in range(num_tokens)] + token_ids = TokenIds( + [random.randint(0, 50_000) for _ in range(num_tokens)]) first_chain, second_chain = (TestPrefixCachingBlock.create_chain( block_size=block_size, @@ -116,7 +118,7 @@ def test_blocks_have_correct_hash_in_chain(block_size: int, @staticmethod def create_chain(block_size: int, - token_ids: List[int], + token_ids: TokenIds, num_empty_trailing_blocks=0) -> List[PrefixCachingBlock]: """Helper method which creates a chain of blocks. """ @@ -176,7 +178,7 @@ def test_allocate_mutable_ooms(num_blocks: int, block_size: int): allocate_type="mutable", allocator=allocator, prev_block=None, - token_ids=list(range(block_size)), + token_ids=TokenIds(range(block_size)), ) [allocate_block() for _ in range(num_blocks)] @@ -194,7 +196,7 @@ def test_allocate_immutable_does_not_oom_single_hash( allocate_type="immutable", allocator=allocator, prev_block=None, - token_ids=list(range(block_size)), + token_ids=TokenIds(range(block_size)), ) blocks = [allocate_block() for _ in range(num_blocks)] @@ -220,7 +222,7 @@ def test_allocate_immutable_ooms_many_hash(num_blocks: int, block_size=block_size) # Create token ids that will exhaust all blocks. - token_ids = list(range(num_blocks * block_size)) + token_ids = TokenIds(range(num_blocks * block_size)) chain = TestPrefixCachingBlockAllocator.create_immutable_chain( block_size=block_size, @@ -231,7 +233,7 @@ def test_allocate_immutable_ooms_many_hash(num_blocks: int, # Expect allocation with unseen hash to fail. with pytest.raises(BlockAllocator.NoFreeBlocksError): allocator.allocate_immutable_block(prev_block=chain[-1], - token_ids=list( + token_ids=TokenIds( range(block_size))) # Expect mutable allocation to fail. @@ -258,7 +260,7 @@ def test_free_prevents_oom(num_blocks: int, block_size: int): block_size=block_size) # Create token ids that will exhaust all blocks. - token_ids = list(range(num_blocks * block_size)) + token_ids = TokenIds(range(num_blocks * block_size)) chain = TestPrefixCachingBlockAllocator.create_immutable_chain( block_size=block_size, @@ -297,7 +299,7 @@ def test_get_num_free_blocks(num_blocks: int, block_size: int, seed: int): num_blocks_to_consume = random.randint(1, num_blocks - 1) # Create token ids that will exhaust all blocks. - token_ids = list(range(num_blocks_to_consume * block_size)) + token_ids = TokenIds(range(num_blocks_to_consume * block_size)) chain = TestPrefixCachingBlockAllocator.create_immutable_chain( block_size=block_size, @@ -327,7 +329,7 @@ def test_prefix_caching_block_get_num_full_blocks_touched( block_size=block_size) # Create token ids that will exhaust all blocks except the last - token_ids = list(range((num_blocks - 1) * block_size)) + token_ids = TokenIds(range((num_blocks - 1) * block_size)) # Create a chain of cacheable blocks in the dst cached_blocks = TestPrefixCachingBlockAllocator.create_immutable_chain( @@ -358,13 +360,13 @@ def test_prefix_caching_block_get_num_full_blocks_touched( # Insert one non-full block in the src non_full_block = allocator_src.allocate_mutable_block( blocks_to_swap_in[-1]) - non_full_block.append_token_ids([0]) + non_full_block.append_token_ids(TokenIds([0])) blocks_to_swap_in.append(non_full_block) assert allocator_dst.get_num_full_blocks_touched( blocks_to_swap_in) == 1 # Fill up the last mutable block and invoke get_num_blocks_touched. # Note: The last block is not cached so it will be touched. 
- non_full_block.append_token_ids([0] * (block_size - 1)) + non_full_block.append_token_ids(TokenIds([0] * (block_size - 1))) assert allocator_dst.get_num_full_blocks_touched( blocks_to_swap_in) == 2 @@ -383,7 +385,7 @@ def test_get_num_free_blocks_shared(num_blocks: int, block_size: int, num_blocks_to_consume = random.randint(1, num_blocks - 1) # Create token ids that will exhaust all blocks. - token_ids = list(range(num_blocks_to_consume * block_size)) + token_ids = TokenIds(range(num_blocks_to_consume * block_size)) first_chain = TestPrefixCachingBlockAllocator.create_immutable_chain( block_size=block_size, @@ -428,7 +430,7 @@ def test_get_common_computed_block_ids(num_blocks: int, block_size: int, num_blocks_to_consume = random.randint(1, num_blocks - 1) # Create token ids that will exhaust all blocks. - token_ids = list(range(num_blocks_to_consume * block_size)) + token_ids = TokenIds(range(num_blocks_to_consume * block_size)) first_chain = TestPrefixCachingBlockAllocator.create_immutable_chain( block_size=block_size, @@ -440,7 +442,8 @@ def test_get_common_computed_block_ids(num_blocks: int, block_size: int, # make it different from here comparing with first_chain zero_point = random.randint(1, len(token_ids) - 1) zero_point_blocks = zero_point // block_size - token_ids[zero_point:] = [-1] * (len(token_ids) - zero_point) + token_ids = token_ids[:zero_point] + TokenIds( + [-1] * (len(token_ids) - zero_point)) second_chain = TestPrefixCachingBlockAllocator.create_immutable_chain( block_size=block_size, @@ -471,7 +474,7 @@ def test_alloc_promotion(num_blocks: int, block_size: int, seed: int): allocator = PrefixCachingBlockAllocator(num_blocks=num_blocks, block_size=block_size) - token_ids = list(range(block_size)) + token_ids = TokenIds(range(block_size)) block = allocator.allocate_immutable_block(prev_block=None, token_ids=token_ids) @@ -481,7 +484,7 @@ def test_alloc_promotion(num_blocks: int, block_size: int, seed: int): block_id = m.block_id for i in range(block_size): - m.append_token_ids([i]) + m.append_token_ids(TokenIds([i])) # After block get promoted to immutable from mutable, if there is # already same content hash block, then it shall be released into @@ -505,7 +508,7 @@ def test_eviction_alloc_mixed(num_blocks: int, block_size: int, seed: int): one_ref = {i: 1 for i in range(num_blocks)} allocator = PrefixCachingBlockAllocator(num_blocks=num_blocks, block_size=block_size) - token_ids = list(range(num_blocks * block_size)) + token_ids = TokenIds(range(num_blocks * block_size)) # Verify initial/pre-alloc state @@ -628,7 +631,7 @@ def test_eviction_order(num_blocks: int, block_size: int, seed: int): block_size=block_size) num_blocks_to_consume = num_blocks + 1 - token_ids = list(range(num_blocks_to_consume * block_size)) + token_ids = TokenIds(range(num_blocks_to_consume * block_size)) num_blocks_in_first_chain = 2 num_tokens_in_first_chain = block_size * num_blocks_in_first_chain @@ -690,7 +693,7 @@ def test_metric(): # Test when no query (0/0) assert allocator.get_prefix_cache_hit_rate() == 0.0 - token_ids = list(range(block_size)) + token_ids = TokenIds(range(block_size)) allocator.allocate_immutable_block(prev_block=None, token_ids=token_ids) # Test 0/1 hit rate @@ -716,7 +719,7 @@ def test_touch_block(): allocator = PrefixCachingBlockAllocator(num_blocks=8, block_size=block_size) - common_token_ids = list(range(block_size * common_blocks)) + common_token_ids = TokenIds(range(block_size * common_blocks)) # Mimic the behavior of allocating the same block chain # (i.e., 
common prefix) for a batch of 3 different prefill sequences. @@ -741,7 +744,7 @@ def test_touch_block(): @staticmethod def create_immutable_chain( block_size: int, - token_ids: List[int], + token_ids: TokenIds, allocator: PrefixCachingBlockAllocator, ) -> List[PrefixCachingBlock]: """Helper method which creates a chain of blocks. diff --git a/tests/core/block/test_token_ids.py b/tests/core/block/test_token_ids.py new file mode 100644 index 0000000000000..8b2f80d8d5a78 --- /dev/null +++ b/tests/core/block/test_token_ids.py @@ -0,0 +1,162 @@ +import pytest + +from vllm.core.block.token_ids import TokenIds, TokenRangeAnnotation + + +@pytest.mark.parametrize( + "value", + [ + # Must be contained within the token IDs. + [TokenRangeAnnotation(0, 0, -1, 2)], + [TokenRangeAnnotation(0, 0, 0, 5)], + [TokenRangeAnnotation(0, 0, 4, 5)], + + # Must not overlap. + [ + TokenRangeAnnotation(000, 0, 0, 1), + TokenRangeAnnotation(111, 0, 0, 1) + ], + [ + TokenRangeAnnotation(000, 0, 0, 2), + TokenRangeAnnotation(111, 0, 1, 3) + ], + [ + TokenRangeAnnotation(000, 0, 2, 3), + TokenRangeAnnotation(111, 0, 0, 4) + ], + + # Must be sorted. + [ + TokenRangeAnnotation(000, 0, 2, 3), + TokenRangeAnnotation(111, 0, 0, 1) + ], + [ + TokenRangeAnnotation(000, 0, 0, 1), + TokenRangeAnnotation(111, 0, 3, 4), + TokenRangeAnnotation(222, 0, 2, 3) + ], + ]) +def test_invalid_annotations_should_raise(value): + with pytest.raises(ValueError): + TokenIds(range(4), value) + + +@pytest.mark.parametrize("value", [ + TokenIds(()), + TokenIds((1, 2, 3)), + TokenIds((1, 2, 3), [TokenRangeAnnotation(0, 0, 0, 1)]) +]) +def test_token_ids_add_unit(value): + assert value + TokenIds() == value + assert TokenIds() + value == value + + +def test_token_ids_add_without_annotations(): + a = TokenIds((1, 2, 3)) + b = TokenIds((4, 5, 6)) + assert a + b == TokenIds((1, 2, 3, 4, 5, 6)) + + +def test_token_ids_add_with_annotations(): + a = TokenIds((1, 2, 3)) + b = TokenIds((4, 5, 6), [TokenRangeAnnotation(0, 0, 0, 1)]) + + assert a + b == TokenIds((1, 2, 3, 4, 5, 6), + [TokenRangeAnnotation(0, 0, 3, 4)]) + assert b + a == TokenIds((4, 5, 6, 1, 2, 3), + [TokenRangeAnnotation(0, 0, 0, 1)]) + + +def test_token_ids_add_can_coalesce(): + a = TokenIds((1, 2, 3), [TokenRangeAnnotation(111, 0, 1, 3)]) + b = TokenIds((4, 5, 6), [TokenRangeAnnotation(111, 2, 0, 1)]) + + assert a + b == TokenIds((1, 2, 3, 4, 5, 6), + [TokenRangeAnnotation(111, 0, 1, 4)]) + + +def test_token_ids_add_cannot_coalesce_different_offsets(): + a = TokenIds((1, 2, 3), [TokenRangeAnnotation(111, 0, 1, 3)]) + b = TokenIds((4, 5, 6), [TokenRangeAnnotation(111, 4, 0, 1)]) + + assert a + b == TokenIds((1, 2, 3, 4, 5, 6), [ + TokenRangeAnnotation(111, 0, 1, 3), + TokenRangeAnnotation(111, 4, 3, 4) + ]) + + +def test_token_ids_add_cannot_coalesce_different_hash(): + a = TokenIds((1, 2, 3), [TokenRangeAnnotation(111, 0, 1, 3)]) + b = TokenIds((4, 5, 6), [TokenRangeAnnotation(222, 2, 0, 1)]) + + assert a + b == TokenIds((1, 2, 3, 4, 5, 6), [ + TokenRangeAnnotation(111, 0, 1, 3), + TokenRangeAnnotation(222, 2, 3, 4) + ]) + + +def test_annotation_clipping(): + r = TokenRangeAnnotation(111, 0, 2, 5) + # Overlapping windows + assert r.clipped_to_slice(0, 1) is None + assert r.clipped_to_slice(0, 2) is None + assert r.clipped_to_slice(0, 3) == TokenRangeAnnotation(111, 0, 2, 3) + assert r.clipped_to_slice(0, 4) == TokenRangeAnnotation(111, 0, 2, 4) + assert r.clipped_to_slice(1, 5) == TokenRangeAnnotation(111, 0, 1, 4) + assert r.clipped_to_slice(2, 6) == TokenRangeAnnotation(111, 0, 0, 3) + 
assert r.clipped_to_slice(3, 7) == TokenRangeAnnotation(111, 1, 0, 2) + assert r.clipped_to_slice(4, 8) == TokenRangeAnnotation(111, 2, 0, 1) + assert r.clipped_to_slice(5, 9) is None + + # Interior windows + assert r.clipped_to_slice(2, 3) == TokenRangeAnnotation(111, 0, 0, 1) + assert r.clipped_to_slice(2, 4) == TokenRangeAnnotation(111, 0, 0, 2) + assert r.clipped_to_slice(2, 5) == TokenRangeAnnotation(111, 0, 0, 3) + assert r.clipped_to_slice(3, 4) == TokenRangeAnnotation(111, 1, 0, 1) + assert r.clipped_to_slice(3, 5) == TokenRangeAnnotation(111, 1, 0, 2) + assert r.clipped_to_slice(4, 5) == TokenRangeAnnotation(111, 2, 0, 1) + assert r.clipped_to_slice(3, 3) is None + + +def test_token_id_chunks(): + token_ids = TokenIds(range(8), [ + TokenRangeAnnotation(111, 0, 2, 5), + TokenRangeAnnotation(222, 0, 6, 7), + TokenRangeAnnotation(333, 0, 7, 8) + ]) + single_chunks = [ + TokenIds([0]), + TokenIds([1]), + TokenIds([2], [TokenRangeAnnotation(111, 0, 0, 1)]), + TokenIds([3], [TokenRangeAnnotation(111, 1, 0, 1)]), + TokenIds([4], [TokenRangeAnnotation(111, 2, 0, 1)]), + TokenIds([5]), + TokenIds([6], [TokenRangeAnnotation(222, 0, 0, 1)]), + TokenIds([7], [TokenRangeAnnotation(333, 0, 0, 1)]), + ] + + # Without overriding initial chunk size + for chunk_size in range(1, len(single_chunks) + 1): + chunks = list(token_ids.to_chunks(chunk_size)) + expected_chunks = [ + sum(single_chunks[i:i + chunk_size], start=TokenIds()) + for i in range(0, len(single_chunks), chunk_size) + ] + assert chunks == expected_chunks + + # With overriding first chunk size + for first_chunk_size in range(1, len(single_chunks) + 1): + first_chunk = sum(single_chunks[0:first_chunk_size], start=TokenIds()) + for chunk_size in range(1, len(single_chunks) + 1): + chunks = list( + token_ids.to_chunks(chunk_size, + first_chunk_size=first_chunk_size)) + expected_chunks = [first_chunk] + [ + sum(single_chunks[i:i + chunk_size], start=TokenIds()) for i in + range(first_chunk_size, len(single_chunks), chunk_size) + ] + assert chunks == expected_chunks + + # Slicing should be equivalent + for i in range(len(single_chunks)): + assert token_ids[i:] == sum(single_chunks[i:], start=TokenIds()) diff --git a/tests/engine/output_processor/test_multi_step.py b/tests/engine/output_processor/test_multi_step.py index 88f3fad4c79f8..b8eb61566c983 100644 --- a/tests/engine/output_processor/test_multi_step.py +++ b/tests/engine/output_processor/test_multi_step.py @@ -4,6 +4,7 @@ import pytest from transformers import PreTrainedTokenizer +from vllm.core.block.token_ids import TokenIds from vllm.core.scheduler import Scheduler from vllm.engine.output_processor.multi_step import MultiStepOutputProcessor from vllm.engine.output_processor.stop_checker import StopChecker @@ -49,7 +50,7 @@ def test_appends_token_ids(num_new_tokens: int, seq_output_len: int): seq = seq_group.get_seqs()[0] seq.status = SequenceStatus.RUNNING - new_token_ids = list(range(num_new_tokens)) + new_token_ids = TokenIds(range(num_new_tokens)) outputs = [ CompletionSequenceGroupOutput( @@ -101,7 +102,7 @@ def test_respects_max_tokens(num_new_tokens: int, seq_prompt_len: int, seq = seq_group.get_seqs()[0] seq.status = SequenceStatus.RUNNING - new_token_ids = list(range(num_new_tokens)) + new_token_ids = TokenIds(range(num_new_tokens)) outputs = [ CompletionSequenceGroupOutput( @@ -165,10 +166,11 @@ def test_respects_eos_token_id(num_new_tokens: int, seq_prompt_len: int, seq = seq_group.get_seqs()[0] seq.status = SequenceStatus.RUNNING - new_token_ids = 
list(range(num_new_tokens)) + new_token_ids = TokenIds(range(num_new_tokens)) assert eos_token_id not in new_token_ids eos_index = random.randint(0, len(new_token_ids) - 1) - new_token_ids[eos_index] = eos_token_id + new_token_ids = (new_token_ids[:eos_index] + TokenIds( + (eos_token_id, )) + new_token_ids[eos_index + 1:]) outputs = [ CompletionSequenceGroupOutput( @@ -234,10 +236,11 @@ def test_ignores_eos_token_id(num_new_tokens: int, seq_prompt_len: int, seq = seq_group.get_seqs()[0] seq.status = SequenceStatus.RUNNING - new_token_ids = list(range(num_new_tokens)) + new_token_ids = TokenIds(range(num_new_tokens)) assert eos_token_id not in new_token_ids eos_index = random.randint(0, len(new_token_ids) - 1) - new_token_ids[eos_index] = eos_token_id + new_token_ids = (new_token_ids[:eos_index] + TokenIds( + (eos_token_id, )) + new_token_ids[eos_index + 1:]) outputs = [ CompletionSequenceGroupOutput( diff --git a/tests/prefix_caching/test_multi_modal_prefix_caching.py b/tests/prefix_caching/test_multi_modal_prefix_caching.py new file mode 100644 index 0000000000000..40e801ea0ea6e --- /dev/null +++ b/tests/prefix_caching/test_multi_modal_prefix_caching.py @@ -0,0 +1,84 @@ +"""Compare the with and without prefix caching. + +Run `pytest tests/prefix_caching/test_multi_modal_prefix_caching.py`. +""" +from typing import Tuple + +import pytest +from transformers import AutoTokenizer + +from ..models.utils import check_logprobs_close + +MODEL_NAME = "fixie-ai/ultravox-v0_3" +AUDIO_PLACEHOLDER = "<|reserved_special_token_0|>" + + +@pytest.fixture +def prompt(): + tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) + + return tokenizer.apply_chat_template( + [{ + "role": "system", + "content": "You are a helpful assistant." + }, { + 'role': 'user', + 'content': f"{AUDIO_PLACEHOLDER}\n\nDescribe the audio above." + }], + tokenize=False, + add_generation_prompt=True) + + +@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("audio_asset_names", + [("winning_call", "mary_had_lamb")]) +@pytest.mark.parametrize("num_logprobs", [5]) +@pytest.mark.parametrize("max_tokens", [30]) +def test_multi_modal_prefix_caching( + vllm_runner, + prompt: str, + audio_asset_names: Tuple[str, str], + dtype: str, + num_logprobs: int, + max_tokens: int, +) -> None: + """ + Test the case when some sequences have the prefix cache hit + and the others don't. 
+ """ + from vllm.assets.audio import AudioAsset + + audios = [ + AudioAsset(asset).audio_and_sample_rate for asset in audio_asset_names + ] + prompts = [prompt for _ in audios] + + with vllm_runner( + MODEL_NAME, + dtype=dtype, + enable_prefix_caching=True, + ) as vllm_model: + # Run against the first prompt so the cache is populated + _ = vllm_model.generate_greedy(prompts[:1], + max_tokens, + audios=audios[:1]) + + # Run all the prompts + with_prefix_caching = vllm_model.generate_greedy_logprobs( + prompts, max_tokens, num_logprobs, audios=audios) + + with vllm_runner( + MODEL_NAME, + dtype=dtype, + enable_prefix_caching=False, + ) as vllm_model: + # Run all the prompts + without_prefix_caching = vllm_model.generate_greedy_logprobs( + prompts, max_tokens, num_logprobs, audios=audios) + + check_logprobs_close( + outputs_0_lst=with_prefix_caching, + outputs_1_lst=without_prefix_caching, + name_0="prefix_caching=True", + name_1="prefix_caching=False", + ) diff --git a/vllm/config.py b/vllm/config.py index bed58fcecb5cb..9d6d67b4ff208 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -240,6 +240,8 @@ def __init__( self.is_attention_free = self._init_attention_free() self.has_inner_state = self._init_has_inner_state() + self.supports_chunked_prefill = self._init_supports_chunked_prefill() + self.suports_prefix_caching = self._init_supports_prefix_caching() if current_platform.is_neuron(): self.override_neuron_config = override_neuron_config @@ -311,6 +313,14 @@ def _init_has_inner_state(self) -> bool: architectures = getattr(self.hf_config, "architectures", []) return ModelRegistry.model_has_inner_state(architectures) + def _init_supports_chunked_prefill(self) -> bool: + architectures = getattr(self.hf_config, "architectures", []) + return ModelRegistry.model_supports_chunked_prefill(architectures) + + def _init_supports_prefix_caching(self) -> bool: + architectures = getattr(self.hf_config, "architectures", []) + return ModelRegistry.model_supports_prefix_caching(architectures) + def _verify_tokenizer_mode(self) -> None: tokenizer_mode = self.tokenizer_mode.lower() if tokenizer_mode not in ["auto", "slow", "mistral"]: diff --git a/vllm/core/block/block_table.py b/vllm/core/block/block_table.py index d10cb29ef4a7c..8ae481e70b97e 100644 --- a/vllm/core/block/block_table.py +++ b/vllm/core/block/block_table.py @@ -3,7 +3,9 @@ from vllm.core.block.common import BlockList from vllm.core.block.interfaces import Block, DeviceAwareBlockAllocator -from vllm.utils import Device, cdiv, chunk_list +from vllm.core.block.token_ids import TokenIds +from vllm.sequence import Sequence +from vllm.utils import Device, cdiv class BlockTable: @@ -55,7 +57,7 @@ def __init__( self._num_full_slots = self._get_num_token_ids() @staticmethod - def get_num_required_blocks(token_ids: List[int], + def get_num_required_blocks(token_count: int, block_size: int, num_lookahead_slots: int = 0) -> int: """Calculates the minimum number of blocks required to store a given @@ -66,7 +68,7 @@ def get_num_required_blocks(token_ids: List[int], allocation (e.g. ignoring prefix caching). Args: - token_ids (List[int]): The sequence of token IDs to be stored. + token_ids (int): The number of token ids to be stored. block_size (int): The maximum number of tokens that can be stored in a single block. 
num_lookahead_slots (int): look-ahead slots that the sequence may @@ -76,10 +78,10 @@ def get_num_required_blocks(token_ids: List[int], int: The minimum number of blocks required to store the given sequence of token IDs along with any required look-ahead slots. """ - return cdiv(len(token_ids) + num_lookahead_slots, block_size) + return cdiv(token_count + num_lookahead_slots, block_size) def allocate(self, - token_ids: List[int], + token_ids: TokenIds, device: Device = Device.GPU) -> None: """Allocates memory blocks for storing the given sequence of token IDs. @@ -87,7 +89,8 @@ def allocate(self, sequence of token IDs. Args: - token_ids (List[int]): The sequence of token IDs to be stored. + token_ids (TokenIds): The sequence of token IDs to be + stored. device (Device, optional): The device on which the blocks should be allocated. Defaults to Device.GPU. """ @@ -106,7 +109,7 @@ def update(self, blocks: List[Block]) -> None: self._blocks.update(blocks) def append_token_ids(self, - token_ids: List[int], + token_ids: TokenIds, num_lookahead_slots: int = 0, num_computed_slots: Optional[int] = None) -> None: """Appends a sequence of token IDs to the existing blocks in the @@ -122,7 +125,7 @@ def append_token_ids(self, separate block. Args: - token_ids (List[int]): The sequence of token IDs to be appended. + token_ids (TokenIds): The token IDs to be appended. num_computed_slots (Optional[int]): The number of KV cache slots that are already filled (computed). When sliding window is enabled, this is used to compute how many @@ -240,33 +243,35 @@ def physical_block_ids(self) -> List[int]: """ return self._blocks.ids() - def get_unseen_token_ids(self, sequence_token_ids: List[int]) -> List[int]: + def get_unseen_token_id_count(self, sequence: Sequence) -> int: + # Since the block table is append-only, the unseen token ids are the + # ones after the appended ones. + return max(0, sequence.get_len() - self.num_full_slots) + + def get_unseen_token_ids(self, sequence: Sequence) -> TokenIds: """Get the number of "unseen" tokens in the sequence. Unseen tokens are tokens in the sequence corresponding to this block table, but are not yet appended to this block table. Args: - sequence_token_ids (List[int]): The list of token ids in the - sequence. + sequence (Sequence): The sequence. Returns: - List[int]: The postfix of sequence_token_ids that has not yet been - appended to the block table. + TokenIds: The postfix of the sequence's tokens that has + not yet been appended to the block table. """ - # Since the block table is append-only, the unseen token ids are the - # ones after the appended ones. - return sequence_token_ids[self.num_full_slots:] + return sequence.get_token_ids()[self._num_full_slots:] def _allocate_blocks_for_token_ids(self, prev_block: Optional[Block], - token_ids: List[int], + token_ids: TokenIds, device: Device) -> List[Block]: blocks: List[Block] = [] - block_token_ids = [] - tail_token_ids = [] - for cur_token_ids in chunk_list(token_ids, self._block_size): + block_token_ids: List[TokenIds] = [] + tail_token_ids: List[TokenIds] = [] + for cur_token_ids in token_ids.to_chunks(self._block_size): if len(cur_token_ids) == self._block_size: block_token_ids.append(cur_token_ids) else: @@ -291,18 +296,6 @@ def _allocate_blocks_for_token_ids(self, prev_block: Optional[Block], return blocks - def _get_all_token_ids(self) -> List[int]: - # NOTE: This function is O(seq_len); use sparingly. 
- token_ids: List[int] = [] - - if not self._is_allocated: - return token_ids - - for block in self.blocks: - token_ids.extend(block.token_ids) - - return token_ids - def _get_num_token_ids(self) -> int: res = 0 for block in self.blocks: @@ -334,7 +327,7 @@ def num_full_slots(self) -> int: return self._num_full_slots def get_num_blocks_touched_by_append_slots( - self, token_ids: List[int], num_lookahead_slots: int) -> int: + self, num_token_ids: int, num_lookahead_slots: int) -> int: """Determine how many blocks will be "touched" by appending the token ids. @@ -346,15 +339,15 @@ def get_num_blocks_touched_by_append_slots( # token_blocks = self._chunk_token_blocks_for_append(all_token_ids) # return len(token_blocks) - num_token_ids = len(token_ids) + num_lookahead_slots + num_token_ids = num_token_ids + num_lookahead_slots first_chunk_size = self._block_size - (self._num_full_slots % self._block_size) num_token_blocks = (1 + math.ceil( (num_token_ids - first_chunk_size) / self._block_size)) return num_token_blocks - def _chunk_token_blocks_for_append( - self, token_ids: List[int]) -> List[List[int]]: + def _chunk_token_blocks_for_append(self, + token_ids: TokenIds) -> List[TokenIds]: """Split the token ids into block-sized chunks so they can be easily appended to blocks. The first such "token block" may have less token ids than the block size, since the last allocated block may be partially @@ -363,12 +356,7 @@ def _chunk_token_blocks_for_append( If no token ids are provided, then no chunks are returned. """ - if not token_ids: - return [] - - first_chunk_size = self._block_size - (self._num_full_slots % - self._block_size) - token_blocks = [token_ids[:first_chunk_size]] - token_blocks.extend( - chunk_list(token_ids[first_chunk_size:], self._block_size)) - return token_blocks + return list( + token_ids.to_chunks(self._block_size, + first_chunk_size=self._block_size - + (self._num_full_slots % self._block_size))) diff --git a/vllm/core/block/common.py b/vllm/core/block/common.py index eb190adfbe802..bba4895607be3 100644 --- a/vllm/core/block/common.py +++ b/vllm/core/block/common.py @@ -3,6 +3,7 @@ from typing import Deque, Dict, Iterable, List, Optional, Protocol, Tuple from vllm.core.block.interfaces import Block, BlockAllocator +from vllm.core.block.token_ids import TokenIds BlockId = int RefCount = int @@ -174,7 +175,7 @@ def __init__(self, block_size: int, create_block: Block.Factory, for i in range(self._pool_size): self._pool.append( self._create_block(prev_block=None, - token_ids=[], + token_ids=TokenIds(), block_size=self._block_size, allocator=self._allocator, block_id=None)) @@ -191,12 +192,12 @@ def increase_pool(self): for i in range(cur_pool_size, new_pool_size): self._pool.append( self._create_block(prev_block=None, - token_ids=[], + token_ids=TokenIds(), block_size=self._block_size, allocator=self._allocator, block_id=None)) - def init_block(self, prev_block: Optional[Block], token_ids: List[int], + def init_block(self, prev_block: Optional[Block], token_ids: TokenIds, block_size: int, physical_block_id: Optional[int]) -> Block: if len(self._free_ids) == 0: self.increase_pool() @@ -248,7 +249,7 @@ def update(self, blocks: List[Block]): for block in self._blocks: self._add_block_id(block.block_id) - def append_token_ids(self, block_index: int, token_ids: List[int]) -> None: + def append_token_ids(self, block_index: int, token_ids: TokenIds) -> None: block = self._blocks[block_index] prev_block_id = block.block_id diff --git a/vllm/core/block/cpu_gpu_block_allocator.py 
b/vllm/core/block/cpu_gpu_block_allocator.py index 9727f6e19b84e..3859e772bafbf 100644 --- a/vllm/core/block/cpu_gpu_block_allocator.py +++ b/vllm/core/block/cpu_gpu_block_allocator.py @@ -4,6 +4,7 @@ DeviceAwareBlockAllocator) from vllm.core.block.naive_block import NaiveBlock, NaiveBlockAllocator from vllm.core.block.prefix_caching_block import PrefixCachingBlockAllocator +from vllm.core.block.token_ids import TokenIds from vllm.platforms import current_platform from vllm.utils import Device @@ -136,7 +137,7 @@ def allocate_mutable_block(self, prev_block: Optional[Block], return self._allocators[device].allocate_mutable_block(prev_block) def allocate_immutable_blocks(self, prev_block: Optional[Block], - block_token_ids: List[List[int]], + block_token_ids: List[TokenIds], device: Device) -> List[Block]: """Allocates a new group of immutable blocks with the provided block token IDs on the specified device. @@ -144,7 +145,7 @@ def allocate_immutable_blocks(self, prev_block: Optional[Block], Args: prev_block (Optional[Block]): The previous block in the sequence. Used for prefix hashing. - block_token_ids (List[int]): The list of block token IDs to be + block_token_ids (TokenIds): The block token IDs to be stored in the new blocks. device (Device): The device on which to allocate the new block. @@ -156,16 +157,15 @@ def allocate_immutable_blocks(self, prev_block: Optional[Block], prev_block, block_token_ids) def allocate_immutable_block(self, prev_block: Optional[Block], - token_ids: List[int], - device: Device) -> Block: + token_ids: TokenIds, device: Device) -> Block: """Allocates a new immutable block with the provided token IDs on the specified device. Args: prev_block (Optional[Block]): The previous block in the sequence. Used for prefix hashing. - token_ids (List[int]): The list of token IDs to be stored in the new - block. + token_ids (TokenIds): The token IDs to be stored in the + new block. device (Device): The device on which to allocate the new block. 
Returns: @@ -356,7 +356,7 @@ def __init__(self, proxy: Block): super().__init__() self._proxy = proxy - def append_token_ids(self, token_ids: List[BlockId]): + def append_token_ids(self, token_ids: TokenIds): raise ValueError("null block should not be modified") @property @@ -368,7 +368,7 @@ def block_id(self, value: Optional[BlockId]): raise ValueError("null block should not be modified") @property - def token_ids(self) -> List[BlockId]: + def token_ids(self) -> TokenIds: return self._proxy.token_ids @property diff --git a/vllm/core/block/interfaces.py b/vllm/core/block/interfaces.py index 72bbab1dcea5d..08e87274d116d 100644 --- a/vllm/core/block/interfaces.py +++ b/vllm/core/block/interfaces.py @@ -1,6 +1,7 @@ from abc import ABC, abstractmethod from typing import Dict, FrozenSet, List, Optional, Protocol, Tuple +from vllm.core.block.token_ids import TokenIds from vllm.utils import Device BlockId = int @@ -9,7 +10,7 @@ class Block(ABC): @abstractmethod - def append_token_ids(self, token_ids: List[int]) -> None: + def append_token_ids(self, token_ids: TokenIds) -> None: pass @property @@ -25,7 +26,7 @@ def block_id(self, value: Optional[int]) -> None: @property @abstractmethod - def token_ids(self) -> List[int]: + def token_ids(self) -> TokenIds: pass @property @@ -77,7 +78,7 @@ class Factory(Protocol): def __call__( self, prev_block: Optional["Block"], - token_ids: List[int], + token_ids: TokenIds, block_size: int, allocator: "BlockAllocator", block_id: Optional[int] = None, @@ -104,13 +105,13 @@ def allocate_mutable_block(self, prev_block: Optional[Block]) -> Block: @abstractmethod def allocate_immutable_block(self, prev_block: Optional[Block], - token_ids: List[int]) -> Block: + token_ids: TokenIds) -> Block: pass @abstractmethod def allocate_immutable_blocks( self, prev_block: Optional[Block], - block_token_ids: List[List[int]]) -> List[Block]: + block_token_ids: List[TokenIds]) -> List[Block]: pass @abstractmethod @@ -202,13 +203,12 @@ def allocate_mutable_block(self, prev_block: Optional[Block], @abstractmethod def allocate_immutable_block(self, prev_block: Optional[Block], - token_ids: List[int], - device: Device) -> Block: + token_ids: TokenIds, device: Device) -> Block: pass @abstractmethod def allocate_immutable_blocks(self, prev_block: Optional[Block], - block_token_ids: List[List[int]], + block_token_ids: List[TokenIds], device: Device) -> List[Block]: pass diff --git a/vllm/core/block/naive_block.py b/vllm/core/block/naive_block.py index 9341a518d11c6..be4816679307a 100644 --- a/vllm/core/block/naive_block.py +++ b/vllm/core/block/naive_block.py @@ -4,6 +4,7 @@ from vllm.core.block.common import (BlockPool, CopyOnWriteTracker, RefCounter, get_all_blocks_recursively) from vllm.core.block.interfaces import Block, BlockAllocator, BlockId, Device +from vllm.core.block.token_ids import TokenIds Refcount = int @@ -62,7 +63,7 @@ def __init__( def allocate_immutable_block(self, prev_block: Optional[Block], - token_ids: List[int], + token_ids: TokenIds, device: Optional[Device] = None) -> Block: """Allocates a new immutable block with the given token IDs, linked to the previous block. @@ -71,7 +72,8 @@ def allocate_immutable_block(self, prev_block (Optional[Block]): The previous block in the sequence. If None, then the block to be allocated is the first block in the sequence. - token_ids (List[int]): The token IDs to be stored in the new block. + token_ids (TokenIds): The token IDs to be stored in the + new block. Returns: Block: The newly allocated immutable block. 
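The hunks above and below migrate the allocators from List[int] to the new TokenIds container. As a minimal illustrative sketch (not part of the patch; the block_size value and the 0xA content hash are made up), this is how TokenIds.to_chunks, defined later in this diff in vllm/core/block/token_ids.py, splits a sequence and its TokenRangeAnnotations into the per-block chunks that allocate_immutable_blocks receives:

from vllm.core.block.token_ids import TokenIds, TokenRangeAnnotation

block_size = 4

# Ten prompt tokens; the half-open range [2, 5) holds placeholders for
# content whose hash is 0xA.
token_ids = TokenIds(range(10), [TokenRangeAnnotation(0xA, 0, 2, 5)])

chunks = list(token_ids.to_chunks(block_size))

# Full chunks become immutable (hashable) blocks; the trailing partial chunk
# stays mutable until it fills up.
full, tail = chunks[:-1], chunks[-1]
assert all(len(c) == block_size for c in full)
assert len(tail) == 2

# The annotation is clipped per chunk: the second chunk starts at content
# offset 2 because two annotated tokens already landed in the first chunk.
assert full[0].annotations == (TokenRangeAnnotation(0xA, 0, 2, 4), )
assert full[1].annotations == (TokenRangeAnnotation(0xA, 2, 0, 1), )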
@@ -84,7 +86,7 @@ def allocate_immutable_block(self, def allocate_immutable_blocks( self, prev_block: Optional[Block], - block_token_ids: List[List[int]], + block_token_ids: List[TokenIds], device: Optional[Device] = None) -> List[Block]: assert device is None num_blocks = len(block_token_ids) @@ -120,7 +122,7 @@ def allocate_mutable_block(self, assert device is None block_id = self._allocate_block_id() block = self._block_pool.init_block(prev_block=prev_block, - token_ids=[], + token_ids=TokenIds(), block_size=self._block_size, physical_block_id=block_id) return block @@ -340,7 +342,8 @@ class NaiveBlock(Block): Args: prev_block (Block): The previous block in the sequence. - token_ids (List[int]): The initial token IDs to be stored in the block. + token_ids (TokenIds): The initial token IDs to be stored in + the block. block_size (int): The maximum number of token IDs that can be stored in the block. allocator (BlockAllocator): The block allocator associated with this @@ -354,12 +357,12 @@ class NaiveBlock(Block): def __init__(self, prev_block: Optional[Block], - token_ids: List[int], + token_ids: TokenIds, block_size: int, allocator: BlockAllocator, block_id: Optional[int] = None, _cow_target: Optional[Block] = None): - self._token_ids: List[int] = [] + self._token_ids = TokenIds() self._block_size = block_size self._prev_block = prev_block self._block_id = block_id @@ -368,12 +371,12 @@ def __init__(self, self._append_token_ids_no_cow(token_ids) - def append_token_ids(self, token_ids: List[int]) -> None: + def append_token_ids(self, token_ids: TokenIds) -> None: """Appends the given token IDs to the block and performs a copy-on-write if necessary. Args: - token_ids (Optional[List[int]]): The token IDs to be appended + token_ids (TokenIds): The token IDs to be appended to the block. """ self._append_token_ids_no_cow(token_ids) @@ -382,18 +385,18 @@ def append_token_ids(self, token_ids: List[int]) -> None: self._block_id = (self._allocator.cow_block_if_not_appendable( self._cow_target)) - def _append_token_ids_no_cow(self, token_ids: List[int]) -> None: + def _append_token_ids_no_cow(self, token_ids: TokenIds) -> None: """Appends the given token IDs to the block Args: - token_ids (List[int]): The token IDs to be appended to the block. + token_ids (TokenIds): The token IDs to be appended to the + block. 
""" if len(token_ids) == 0: return assert len(token_ids) <= self.num_empty_slots - - self._token_ids.extend(token_ids) + self._token_ids += token_ids @property def computed(self) -> bool: @@ -428,7 +431,7 @@ def num_empty_slots(self) -> int: return self._block_size - len(self.token_ids) @property - def token_ids(self) -> List[int]: + def token_ids(self) -> TokenIds: return self._token_ids @property diff --git a/vllm/core/block/prefix_caching_block.py b/vllm/core/block/prefix_caching_block.py index 57527e39b9bdd..86c61ac1b2a0c 100644 --- a/vllm/core/block/prefix_caching_block.py +++ b/vllm/core/block/prefix_caching_block.py @@ -7,6 +7,7 @@ from vllm.core.block.interfaces import Block, BlockAllocator, BlockId, Device from vllm.core.block.naive_block import (BlockPool, NaiveBlock, NaiveBlockAllocator) +from vllm.core.block.token_ids import TokenIds from vllm.core.evictor import EvictionPolicy, Evictor, make_evictor PrefixHash = int @@ -116,7 +117,7 @@ def __init__( def _create_block( self, prev_block: Optional[Block], - token_ids: List[int], + token_ids: TokenIds, block_size: int, allocator: BlockAllocator, block_id: Optional[int] = None, @@ -136,14 +137,15 @@ def _create_block( def allocate_immutable_block(self, prev_block: Optional[Block], - token_ids: List[int], + token_ids: TokenIds, device: Optional[Device] = None) -> Block: """Allocates an immutable block with the given token IDs, reusing cached blocks if possible. Args: prev_block (Optional[Block]): The previous block in the sequence. - token_ids (List[int]): The token IDs to be stored in the block. + token_ids (TokenIds): The token IDs to be stored in the + block. Returns: Block: The allocated immutable block. @@ -175,7 +177,7 @@ def allocate_immutable_block(self, def allocate_immutable_blocks( self, prev_block: Optional[Block], - block_token_ids: List[List[int]], + block_token_ids: List[TokenIds], device: Optional[Device] = None) -> List[Block]: blocks = [] for token_ids in block_token_ids: @@ -203,7 +205,7 @@ def allocate_mutable_block(self, block_id = self._allocate_block_id() block = self._block_pool.init_block(prev_block=prev_block, - token_ids=[], + token_ids=TokenIds(), block_size=self._block_size, physical_block_id=block_id) assert not block.computed @@ -646,7 +648,8 @@ class PrefixCachingBlock(Block): Args: prev_block (Optional[PrefixCachingBlock]): The previous block in the sequence. - token_ids (List[int]): The initial token IDs to be stored in the block. + token_ids (TokenIds): The initial token IDs to be stored in + the block. block_size (int): The maximum number of token IDs that can be stored in the block. allocator (BlockAllocator): The prefix @@ -658,7 +661,7 @@ class PrefixCachingBlock(Block): def __init__( self, prev_block: Optional[Block], - token_ids: List[int], + token_ids: TokenIds, block_size: int, allocator: BlockAllocator, block_id: Optional[int] = None, @@ -726,12 +729,13 @@ def last_accessed(self) -> float: def last_accessed(self, last_accessed_ts: float): self._last_accessed = last_accessed_ts - def append_token_ids(self, token_ids: List[int]) -> None: + def append_token_ids(self, token_ids: TokenIds) -> None: """Appends the given token IDs to the block and registers the block as immutable if the block becomes full. Args: - token_ids (List[int]): The token IDs to be appended to the block. + token_ids (TokenIds): The token IDs to be appended to the + block. 
""" # Ensure this is mutable block (not promoted) assert self.content_hash is None @@ -741,7 +745,7 @@ def append_token_ids(self, token_ids: List[int]) -> None: return # Ensure there are input tokens - assert token_ids, "Got token_ids = {}".format(token_ids) + assert token_ids, "Got token_ids = {}".format(token_ids.token_ids) # Naive block handles CoW. self._block.append_token_ids(token_ids) @@ -778,7 +782,7 @@ def block_size(self) -> int: return self._block.block_size @property - def token_ids(self) -> List[int]: + def token_ids(self) -> TokenIds: return self._block.token_ids @property @@ -820,7 +824,7 @@ def content_hash(self) -> Optional[int]: @staticmethod def hash_block_tokens(is_first_block: bool, prev_block_hash: Optional[int], - cur_block_token_ids: List[int]) -> int: + cur_block_token_ids: TokenIds) -> int: """Computes a hash value corresponding to the contents of a block and the contents of the preceding block(s). The hash value is used for prefix caching. @@ -839,7 +843,9 @@ def hash_block_tokens(is_first_block: bool, prev_block_hash: Optional[int], - int: The computed hash value for the block. """ assert (prev_block_hash is None) == is_first_block - return hash((is_first_block, prev_block_hash, *cur_block_token_ids)) + return hash( + (is_first_block, prev_block_hash, cur_block_token_ids.token_ids, + cur_block_token_ids.annotations)) class ComputedBlocksTracker: diff --git a/vllm/core/block/token_ids.py b/vllm/core/block/token_ids.py new file mode 100644 index 0000000000000..bd081048d5eed --- /dev/null +++ b/vllm/core/block/token_ids.py @@ -0,0 +1,257 @@ +from typing import Iterable, List, NamedTuple, Optional, Tuple, overload + + +class TokenRangeAnnotation(NamedTuple): + """ + Annotates a range of placeholder tokens to capture content that will + replace them. + """ + + content_hash: int + content_offset: int + token_start_index: int + token_end_index: int + + @property + def token_count(self) -> int: + return self.token_end_index - self.token_start_index + + @staticmethod + def are_adjacent(left: "TokenRangeAnnotation", + right: "TokenRangeAnnotation") -> bool: + """ + Indicates whether two annotations represent adjacent ranges in the + hashed content. + """ + + return (left.content_hash == right.content_hash and + left.content_offset + left.token_count == right.content_offset) + + def clipped_to_slice(self, tokens_start: int, + tokens_end: int) -> Optional["TokenRangeAnnotation"]: + """ + Computes a new TokenRangeAnnotation that corresponds to the same + content in a slice of the original token IDs. + + For example, consider the following token IDs/annotations: + + AAAA BBBB What do these images have in common? + + A = TokenRangeAnnotation(0xA, 0, 0, 4) + B = TokenRangeAnnotation(0xB, 0, 5, 9) + + tokens = AAAA BBBB What do these images have in common? + [AAAA] + + A.clipped_to_slice(0, 4) = TokenRangeAnnotation(0xA, 0, 0, 4) + B.clipped_to_slice(0, 4) = None + + tokens = AAAA BBBB What do these images have in common? + [AA BB] + + A.clipped_to_slice(2, 7) = TokenRangeAnnotation(0xA, 2, 0, 2) + B.clipped_to_slice(2, 7) = TokenRangeAnnotation(0xB, 0, 3, 5) + + tokens = AAAA BBBB What do these images have in common? 
+ [BBBB What] + + A.clipped_to_slice(5, 14) = None + B.clipped_to_slice(5, 14) = TokenRangeAnnotation(0xB, 0, 0, 4) + """ + + unclamped_annotation_start = self.token_start_index - tokens_start + unclamped_annotation_end = self.token_end_index - tokens_start + + annotation_start = max(0, unclamped_annotation_start) + annotation_end = min(tokens_end - tokens_start, + unclamped_annotation_end) + + if annotation_start >= annotation_end: + # There is no overlap. + return None + + return TokenRangeAnnotation(content_hash=self.content_hash, + content_offset=self.content_offset + + annotation_start - + unclamped_annotation_start, + token_start_index=annotation_start, + token_end_index=annotation_end) + + +class TokenIds: + token_ids: Tuple[int, ...] + annotations: Tuple[TokenRangeAnnotation, ...] + + def __init__(self, + token_ids: Iterable[int] = (), + annotations: Iterable[TokenRangeAnnotation] = ()): + self.token_ids = tuple(token_ids) + self.annotations = tuple(annotations) + + # Ensure that the token annotations are monotonic. + current_token_index = 0 + for annotation in self.annotations: + if (annotation.token_start_index < current_token_index or + annotation.token_end_index < annotation.token_start_index): + raise ValueError("TokenRangeAnnotations must be sorted and " + "non-overlapping.") + + current_token_index = annotation.token_end_index + + if current_token_index > len(self.token_ids): + raise ValueError("TokenRangeAnnotations must be entirely " + "contained within the token IDs.") + + def to_chunks(self, + chunk_size: int, + *, + first_chunk_size: Optional[int] = None): + """ + Yields successive chunks over the TokenIds, taking care to filter + or split TokenRangeAnnotations accordingly. + """ + + current_annotation_index = 0 + current_chunk_start = 0 + current_chunk_end = (chunk_size + if first_chunk_size is None else first_chunk_size) + + while current_chunk_start < len(self.token_ids): + current_chunk_annotations: List[TokenRangeAnnotation] = [] + while current_annotation_index < len(self.annotations): + existing_annotation = self.annotations[ + current_annotation_index] + if existing_annotation.token_start_index >= current_chunk_end: + # This annotation starts after the current chunk. + break + + # Create a new annotation. + new_annotation = existing_annotation.clipped_to_slice( + tokens_start=current_chunk_start, + tokens_end=current_chunk_end) + assert new_annotation is not None, ( + "The existing annotation should overlap with the new one.") + current_chunk_annotations.append(new_annotation) + if (current_chunk_start + new_annotation.token_end_index == + existing_annotation.token_end_index): + # We've used up this annotation. + current_annotation_index += 1 + else: + break + + yield TokenIds( + self.token_ids[current_chunk_start:current_chunk_end], + current_chunk_annotations) + + current_chunk_start = current_chunk_end + current_chunk_end = current_chunk_start + chunk_size + + def __eq__(self, other: object) -> bool: + if isinstance(other, TokenIds): + return (self.token_ids == other.token_ids + and self.annotations == other.annotations) + + return NotImplemented + + def __add__(self, other: "TokenIds") -> "TokenIds": + """ + Combines two ``TokenIds``, possibly merging ``TokenRangeAnnotion``s. + + ``TokenRangeAnnotation``s at the boundary will be coalesced into a + single annotation if they have the same content hash and they cover + adjacent portions of the hashed content. 
+ """ + + if not self.token_ids: + return other + elif not other.token_ids: + return self + + # Merge the token annotations if necessary + if not other.annotations: + combined_annotations: Iterable[ + TokenRangeAnnotation] = self.annotations + else: + combined_annotations = list(self.annotations) + for annotation in other.annotations: + if combined_annotations: + # Check if we can coalesce this annotation with the last. + last_annotation = combined_annotations[-1] + if (TokenRangeAnnotation.are_adjacent( + last_annotation, annotation) + and last_annotation.token_end_index == len( + self.token_ids) + and annotation.token_start_index == 0): + combined_annotations[-1] = TokenRangeAnnotation( + content_hash=last_annotation.content_hash, + content_offset=last_annotation.content_offset, + token_start_index=last_annotation. + token_start_index, + token_end_index=last_annotation.token_end_index + + annotation.token_count) + continue + + combined_annotations.append( + TokenRangeAnnotation( + content_hash=annotation.content_hash, + content_offset=annotation.content_offset, + token_start_index=len(self.token_ids) + + annotation.token_start_index, + token_end_index=len(self.token_ids) + + annotation.token_end_index)) + + return TokenIds(token_ids=self.token_ids + other.token_ids, + annotations=combined_annotations) + + def __len__(self) -> int: + return len(self.token_ids) + + @overload + def __getitem__(self, key: int) -> int: + ... + + @overload + def __getitem__(self, key: slice) -> "TokenIds": + ... + + def __getitem__(self, key): + """ + Gets a single token at an index or a slice of ``TokenIds``. + """ + if isinstance(key, int): + return self.token_ids[key] + + if isinstance(key, slice): + if key.step: + raise IndexError("Step is not supported.") + + # Resolve negative indices. + start = key.start or 0 + start += len(self) if start < 0 else 0 + + stop = key.stop if key.stop is not None else len(self) + stop += len(self) if stop < 0 else 0 + + # Fast path for the common case where the new slice doesn't + # include any annotations (e.g. slicing a decoded token). + if (not self.annotations + or start >= self.annotations[-1].token_end_index + or stop <= self.annotations[0].token_start_index): + return TokenIds(self.token_ids[start:stop]) + + # Clamp the indices. + start = max(0, min(len(self), start)) + stop = max(start, min(len(self), stop)) + + chunks_iter = iter( + self.to_chunks(chunk_size=stop - start, + first_chunk_size=start)) + + # Drop the first chunk and return the second chunk. 
+ try: + next(chunks_iter) + return next(chunks_iter) + except StopIteration: + return TokenIds() + + raise TypeError(f"Unsupported key type: {type(key)}") diff --git a/vllm/core/block_manager.py b/vllm/core/block_manager.py index 61ed7afba12ed..27e4f1095c3db 100644 --- a/vllm/core/block_manager.py +++ b/vllm/core/block_manager.py @@ -8,6 +8,7 @@ from vllm.core.block.interfaces import Block from vllm.core.block.prefix_caching_block import (ComputedBlocksTracker, LastAccessBlocksTracker) +from vllm.core.block.token_ids import TokenIds from vllm.core.block.utils import check_no_caching_or_swa_for_blockmgr_encdec from vllm.core.interfaces import AllocStatus, BlockSpaceManager from vllm.sequence import Sequence, SequenceGroup, SequenceStatus @@ -115,7 +116,7 @@ def can_allocate(self, seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0] num_required_blocks = BlockTable.get_num_required_blocks( - seq.get_token_ids(), + seq.get_len(), block_size=self.block_size, num_lookahead_slots=num_lookahead_slots, ) @@ -124,7 +125,7 @@ def can_allocate(self, encoder_seq = seq_group.get_encoder_seq() assert encoder_seq is not None num_required_blocks += BlockTable.get_num_required_blocks( - encoder_seq.get_token_ids(), + encoder_seq.get_len(), block_size=self.block_size, ) @@ -150,7 +151,7 @@ def _allocate_sequence(self, seq: Sequence) -> BlockTable: block_allocator=self.block_allocator, max_block_sliding_window=self.max_block_sliding_window, ) - if seq.get_token_ids(): + if seq.get_len(): # Add blocks to the block table only if the sequence is non empty. block_table.allocate(seq.get_token_ids()) @@ -219,8 +220,7 @@ def can_append_slots(self, seq_group: SequenceGroup, num_touched_blocks += ( block_table.get_num_blocks_touched_by_append_slots( - token_ids=block_table.get_unseen_token_ids( - seq.get_token_ids()), + num_token_ids=block_table.get_unseen_token_id_count(seq), num_lookahead_slots=num_lookahead_slots, )) @@ -237,7 +237,7 @@ def append_slots( block_table = self.block_tables[seq.seq_id] block_table.append_token_ids( - token_ids=block_table.get_unseen_token_ids(seq.get_token_ids()), + token_ids=block_table.get_unseen_token_ids(seq), num_lookahead_slots=num_lookahead_slots, num_computed_slots=seq.data.get_num_computed_tokens(), ) @@ -483,7 +483,7 @@ def _can_swap(self, # to be touched for the swap. 
                num_blocks_touched += \
                    block_table.get_num_blocks_touched_by_append_slots(
-                        block_table.get_unseen_token_ids(seq.get_token_ids()),
+                        block_table.get_unseen_token_id_count(seq),
                         num_lookahead_slots=num_lookahead_slots)
                 blocks.extend(block_table.blocks)
         # Compute the number of full blocks to touch and add it to the
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 8c5b442e9f624..c4476f69a07d7 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -998,11 +998,10 @@ def create_engine_config(self) -> VllmConfig:
         device_config = DeviceConfig(device=self.device)
         model_config = self.create_model_config()
 
-        if model_config.is_multimodal_model:
-            if self.enable_prefix_caching:
-                logger.warning(
-                    "--enable-prefix-caching is currently not "
-                    "supported for multimodal models and has been disabled.")
+        if (not model_config.supports_prefix_caching
+                and self.enable_prefix_caching):
+            logger.warning("--enable-prefix-caching is currently not "
+                           "supported by this model and has been disabled.")
             self.enable_prefix_caching = False
 
         maybe_register_config_serialize_by_value(self.trust_remote_code)
@@ -1040,10 +1039,7 @@ def create_engine_config(self) -> VllmConfig:
         # If not explicitly set, enable chunked prefill by default for
         # long context (> 32K) models. This is to avoid OOM errors in the
         # initial memory profiling phase.
-
-        # Chunked prefill is currently disabled for multimodal models by
-        # default.
-        if use_long_context and not model_config.is_multimodal_model:
+        if use_long_context and model_config.supports_chunked_prefill:
             is_gpu = device_config.device_type == "cuda"
             use_sliding_window = (model_config.get_sliding_window()
                                   is not None)
diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py
index dcead65115132..e48606927be3e 100644
--- a/vllm/model_executor/models/interfaces.py
+++ b/vllm/model_executor/models/interfaces.py
@@ -27,6 +27,16 @@ class SupportsMultiModal(Protocol):
     MRO of your model class.
     """
 
+    supports_chunked_prefill: ClassVar[bool] = False
+    """
+    A flag that indicates this model supports chunked prefill.
+    """
+
+    supports_prefix_caching: ClassVar[bool] = False
+    """
+    A flag that indicates this model supports prefix caching.
+    """
+
     def __init__(self, *, multimodal_config: "MultiModalConfig") -> None:
         ...
 
@@ -36,6 +46,8 @@ def __init__(self, *, multimodal_config: "MultiModalConfig") -> None:
 @runtime_checkable
 class _SupportsMultiModalType(Protocol):
     supports_multimodal: Literal[True]
+    supports_chunked_prefill: bool
+    supports_prefix_caching: bool
 
     def __call__(self, *, multimodal_config: "MultiModalConfig") -> None:
         ...
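The engine-args check above reads model_config.supports_prefix_caching and model_config.supports_chunked_prefill, but the ModelConfig side of that wiring is not included in this diff. A minimal sketch of how those properties could delegate to the registry helpers added below; the property names, the use of hf_config.architectures, and the placement in vllm/config.py are assumptions, not something this change confirms:

    # Hypothetical ModelConfig properties (illustrative only, not in this diff).
    from vllm.model_executor.models import ModelRegistry


    class ModelConfig:
        ...

        @property
        def supports_chunked_prefill(self) -> bool:
            # Non-multimodal models default to True in the registry; multimodal
            # models opt in via SupportsMultiModal.supports_chunked_prefill.
            architectures = getattr(self.hf_config, "architectures", [])
            return ModelRegistry.model_supports_chunked_prefill(architectures)

        @property
        def supports_prefix_caching(self) -> bool:
            architectures = getattr(self.hf_config, "architectures", [])
            return ModelRegistry.model_supports_prefix_caching(architectures)
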
diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 32750602b988c..0a5fbcfb39c99 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -187,6 +187,8 @@ class _ModelInfo: supports_pp: bool has_inner_state: bool is_attention_free: bool + supports_chunked_prefill: bool + supports_prefix_caching: bool @staticmethod def from_model_cls(model: Type[nn.Module]) -> "_ModelInfo": @@ -197,6 +199,10 @@ def from_model_cls(model: Type[nn.Module]) -> "_ModelInfo": supports_pp=supports_pp(model), has_inner_state=has_inner_state(model), is_attention_free=is_attention_free(model), + supports_chunked_prefill=model.supports_chunked_prefill + if supports_multimodal(model) else True, + supports_prefix_caching=model.supports_prefix_caching + if supports_multimodal(model) else True, ) @@ -425,6 +431,14 @@ def is_attention_free_model(self, architectures: Union[str, List[str]]) -> bool: return self.inspect_model_cls(architectures).is_attention_free + def model_supports_chunked_prefill( + self, architectures: Union[str, List[str]]) -> bool: + return self.inspect_model_cls(architectures).supports_chunked_prefill + + def model_supports_prefix_caching( + self, architectures: Union[str, List[str]]) -> bool: + return self.inspect_model_cls(architectures).supports_prefix_caching + ModelRegistry = _ModelRegistry({ model_arch: _LazyRegisteredModel( @@ -479,4 +493,4 @@ def _run() -> None: if __name__ == "__main__": - _run() \ No newline at end of file + _run() diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index 411584b1a6c3c..859fc3126a503 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -339,6 +339,8 @@ def forward( @INPUT_REGISTRY.register_dummy_data(dummy_data_for_ultravox) @INPUT_REGISTRY.register_input_processor(input_processor_for_ultravox) class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP): + supports_chunked_prefill = True + supports_prefix_caching = True def __init__(self, config: UltravoxConfig, diff --git a/vllm/multimodal/audio.py b/vllm/multimodal/audio.py index 04d71826f29fa..d738fc0b646e1 100644 --- a/vllm/multimodal/audio.py +++ b/vllm/multimodal/audio.py @@ -1,3 +1,7 @@ +from typing import Tuple, Union, cast + +import numpy as np + from vllm.inputs.registry import InputContext from vllm.multimodal.base import MultiModalInputs, MultiModalPlugin @@ -8,6 +12,10 @@ class AudioPlugin(MultiModalPlugin): def get_data_key(self) -> str: return "audio" + def hash_content(self, data: object) -> int: + (audio, sr) = cast(Tuple[np.ndarray, Union[float, int]], data) + return hash((audio.data.tobytes(), sr)) + def _default_input_mapper(self, ctx: InputContext, data: object, **mm_processor_kwargs) -> MultiModalInputs: raise NotImplementedError("There is no default audio input mapper") diff --git a/vllm/multimodal/base.py b/vllm/multimodal/base.py index 26c94cf2d0b20..c8df70f4c82f8 100644 --- a/vllm/multimodal/base.py +++ b/vllm/multimodal/base.py @@ -236,6 +236,14 @@ def _default_input_mapper( """ raise NotImplementedError + @abstractmethod + def hash_content(self, data: object) -> int: + """ + Calculates a content-based hash of the multi-modal item that can be + used to represent the content for prefix caching. 
+ """ + raise NotImplementedError + def register_input_mapper( self, mapper: Optional[MultiModalInputMapper] = None, diff --git a/vllm/multimodal/image.py b/vllm/multimodal/image.py index 3f6bb6c8338d2..d82886e209650 100644 --- a/vllm/multimodal/image.py +++ b/vllm/multimodal/image.py @@ -26,6 +26,9 @@ class ImagePlugin(MultiModalPlugin): def get_data_key(self) -> str: return "image" + def hash_content(self, data: object) -> int: + raise NotImplementedError("Image hashing is not yet implemented") + def _get_hf_image_processor( self, model_config: "ModelConfig", diff --git a/vllm/multimodal/registry.py b/vllm/multimodal/registry.py index bce2f4c6abe5b..0afdc04d21588 100644 --- a/vllm/multimodal/registry.py +++ b/vllm/multimodal/registry.py @@ -243,3 +243,6 @@ def get_mm_limits_per_prompt( This should be called after :meth:`init_mm_limits_per_prompt`. """ return self._limits_by_model[model_config] + + def hash_mm_item_content(self, data_type_key: str, item: object) -> int: + return self._get_plugin(data_type_key).hash_content(item) diff --git a/vllm/multimodal/video.py b/vllm/multimodal/video.py index 40a92fed28c87..1c1e0cd79fbc7 100644 --- a/vllm/multimodal/video.py +++ b/vllm/multimodal/video.py @@ -38,6 +38,9 @@ class VideoPlugin(ImagePlugin): def get_data_key(self) -> str: return "video" + def hash_content(self, data: object) -> int: + raise NotImplementedError("Video hashing is not yet implemented") + def _get_hf_video_processor( self, model_config: "ModelConfig", diff --git a/vllm/sequence.py b/vllm/sequence.py index 7d7ddc7ec4447..07bb7a2c47a15 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -15,8 +15,10 @@ import torch from typing_extensions import assert_never +from vllm.core.block.token_ids import TokenIds, TokenRangeAnnotation from vllm.lora.request import LoRARequest -from vllm.multimodal import MultiModalDataDict, MultiModalPlaceholderDict +from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalDataDict, + MultiModalPlaceholderDict) from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import RequestOutputKind, SamplingParams @@ -476,7 +478,7 @@ def mm_processor_kwargs(self) -> Dict[str, Any]: assert_never(inputs) - @property + @cached_property def multi_modal_placeholders(self) -> MultiModalPlaceholderDict: inputs = self.inputs @@ -485,6 +487,29 @@ def multi_modal_placeholders(self) -> MultiModalPlaceholderDict: assert_never(inputs) + @cached_property + def token_range_annotations(self): + annotations: List[TokenRangeAnnotation] = [] + for modality, ranges in self.multi_modal_placeholders.items(): + mm_items = self.multi_modal_data.get(modality) + if not isinstance(mm_items, list): + mm_items = [mm_items] + + for range, mm_item in zip(ranges, mm_items): + content_hash = MULTIMODAL_REGISTRY.hash_mm_item_content( + modality, mm_item) + annotations.append( + TokenRangeAnnotation( + content_hash=content_hash, + content_offset=0, + token_start_index=range["offset"], + token_end_index=range["offset"] + range["length"], + )) + + sorted_annotations = sorted(annotations, + key=lambda a: a.token_start_index) + return sorted_annotations + @property def lora_int_id(self) -> int: return self.lora_request.lora_int_id if self.lora_request else 0 @@ -569,8 +594,9 @@ def get_prompt_len(self) -> int: def get_output_len(self) -> int: return self.data.get_output_len() - def get_token_ids(self) -> List[int]: - return self.data.get_token_ids() + def get_token_ids(self) -> TokenIds: + return 
TokenIds(self.data.get_token_ids(), + self.token_range_annotations) def get_prompt_token_ids(self) -> Tuple[int, ...]: return self.data.get_prompt_token_ids() diff --git a/vllm/transformers_utils/detokenizer.py b/vllm/transformers_utils/detokenizer.py index 7c8423d2b0a34..cdfe51ae14e2b 100644 --- a/vllm/transformers_utils/detokenizer.py +++ b/vllm/transformers_utils/detokenizer.py @@ -40,7 +40,7 @@ def decode_prompt_logprobs_inplace(self, seq_group: SequenceGroup, # We can pick any sequence for the prompt. seq = seq_group.get_seqs()[0] # Only prompt, without the generated token. - all_token_ids = seq.get_token_ids() + all_token_ids = list(seq.get_token_ids().token_ids) prompt_token_ids = all_token_ids[:-1] tokenizer = self.get_tokenizer_for_seq(seq) prefix_offset = 0 @@ -105,7 +105,7 @@ def decode_sequence_inplace(self, seq: Sequence, Returns: The number of characters added to the output text. """ - all_input_ids = seq.get_token_ids() + all_input_ids = list(seq.get_token_ids().token_ids) token_id_generated_this_iteration = all_input_ids[-1] tokenizer = self.get_tokenizer_for_seq(seq) diff --git a/vllm/utils.py b/vllm/utils.py index 13d7f6d475346..a5b8a36772586 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -557,12 +557,6 @@ def update_environment_variables(envs: Dict[str, str]): os.environ[k] = v -def chunk_list(lst: List[T], chunk_size: int): - """Yield successive chunk_size chunks from lst.""" - for i in range(0, len(lst), chunk_size): - yield lst[i:i + chunk_size] - - def cdiv(a: int, b: int) -> int: """Ceiling division.""" return -(a // -b)
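For reference, a worked example of the TokenIds API added in vllm/core/block/token_ids.py; this is not part of the diff, and the token values and the 0xA content hash are purely illustrative:

    from vllm.core.block.token_ids import TokenIds, TokenRangeAnnotation

    # An 8-token prompt in which tokens [2, 6) are placeholders for a single
    # multimodal item whose content hashes to 0xA.
    annotation = TokenRangeAnnotation(content_hash=0xA, content_offset=0,
                                      token_start_index=2, token_end_index=6)
    prompt = TokenIds(range(8), (annotation,))

    # Chunking into block-sized pieces splits the annotation at the boundary
    # and advances content_offset for the second half.
    first, second = prompt.to_chunks(4)
    assert first == TokenIds((0, 1, 2, 3), (TokenRangeAnnotation(0xA, 0, 2, 4),))
    assert second == TokenIds((4, 5, 6, 7), (TokenRangeAnnotation(0xA, 2, 0, 2),))

    # Concatenation coalesces the adjacent halves back into one annotation.
    assert first + second == prompt

    # Slicing clips annotations to the slice and offsets them into it.
    assert prompt[3:5] == TokenIds((3, 4), (TokenRangeAnnotation(0xA, 1, 0, 2),))

This is the behaviour the updated hash_block_tokens relies on: blocks whose placeholder token IDs are identical but which stand in for different content (or different offsets into the same content) now produce distinct block hashes, which is what makes prefix caching safe for multimodal inputs.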