From bf8717ebaea8d74279df84fbe127ad22cf62e219 Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Tue, 17 Dec 2024 16:37:59 -0800 Subject: [PATCH] [V1] Prefix caching for vision language models (#11187) Signed-off-by: Cody Yu --- tests/v1/core/test_prefix_caching.py | 88 +++++++++++++++++++- tests/v1/engine/test_engine_args.py | 15 ---- vllm/engine/arg_utils.py | 27 ++++--- vllm/inputs/data.py | 20 +++++ vllm/multimodal/inputs.py | 3 + vllm/v1/core/kv_cache_manager.py | 74 +++++++++++------ vllm/v1/core/kv_cache_utils.py | 115 ++++++++++++++++++++++++--- vllm/v1/core/scheduler.py | 2 + vllm/v1/engine/async_llm.py | 10 ++- vllm/v1/engine/core.py | 8 +- vllm/v1/engine/llm_engine.py | 9 ++- vllm/v1/engine/mm_input_mapper.py | 33 ++++---- vllm/v1/engine/processor.py | 12 +-- vllm/v1/request.py | 24 +++++- 14 files changed, 342 insertions(+), 98 deletions(-) diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py index 00f7b0fcfe1dc..ed04f0a373c51 100644 --- a/tests/v1/core/test_prefix_caching.py +++ b/tests/v1/core/test_prefix_caching.py @@ -2,16 +2,23 @@ import pytest from vllm.inputs import token_inputs +from vllm.multimodal.inputs import PlaceholderRange from vllm.sampling_params import SamplingParams from vllm.utils import cdiv from vllm.v1.core.kv_cache_manager import KVCacheManager, Request from vllm.v1.core.kv_cache_utils import KVCacheBlock, hash_block_tokens -def make_request(request_id, prompt_token_ids): +def make_request(request_id, + prompt_token_ids, + mm_positions=None, + mm_hashes=None): return Request( request_id=request_id, - inputs=token_inputs(prompt_token_ids=prompt_token_ids), + inputs=token_inputs(prompt_token_ids=prompt_token_ids, + multi_modal_placeholders={"image": mm_positions} + if mm_positions else None, + multi_modal_hashes=mm_hashes), sampling_params=SamplingParams(max_tokens=17), eos_token_id=100, arrival_time=0, @@ -38,6 +45,7 @@ def test_prefill(): all_token_ids = common_token_ids + unique_token_ids req0 = make_request("0", all_token_ids) computed_blocks = manager.get_computed_blocks(req0) + assert len(req0.kv_block_hashes) == 3 assert not computed_blocks blocks = manager.allocate_slots(req0, 55, computed_blocks) assert [b.block_id for b in blocks] == [0, 1, 2, 3, 4] @@ -61,6 +69,7 @@ def test_prefill(): unique_token_ids = [3] * 5 req1 = make_request("1", common_token_ids + unique_token_ids) computed_blocks = manager.get_computed_blocks(req1) + assert len(req1.kv_block_hashes) == 3 assert [b.block_id for b in computed_blocks] == [0, 1, 2] num_new_tokens = 53 - 3 * 16 blocks = manager.allocate_slots(req1, num_new_tokens, computed_blocks) @@ -90,6 +99,7 @@ def test_prefill(): unique_token_ids = [3] * 6 req2 = make_request("2", common_token_ids + unique_token_ids) computed_block = manager.get_computed_blocks(req2) + assert len(req2.kv_block_hashes) == 3 assert [b.block_id for b in computed_block] == [0, 1, 2] num_new_tokens = 53 - 3 * 16 blocks = manager.allocate_slots(req2, num_new_tokens, computed_blocks) @@ -416,3 +426,77 @@ def test_cache_blocks(): ) assert len(manager.cached_block_hash_to_block) == 3 assert blocks[0].block_hash is not None + + +def test_mm_prefix_caching(): + """ + This tests that the multi-modal prefix caching is correct. 
+ """ + manager = KVCacheManager( + block_size=16, + num_gpu_blocks=10, + max_model_len=8192, + sliding_window=None, + enable_caching=True, + num_preallocate_tokens=16, + ) + + # Common prompt tokens (T is text tokens and P is image placeholder tokens) + # [T,...,T, P0,...,P0], [P0,...,P0,T,...,T,P1,...,P1], [P1,...,P1] + common_token_ids = list(range(10)) + [-1] * 6 + common_token_ids += [-1] * 4 + list(range(10, 20)) + [-1] * 2 + common_token_ids += [-1] * 16 + + common_mm_positions = [ + PlaceholderRange(offset=11, length=10), + PlaceholderRange(offset=30, length=18), + ] + common_mm_hashes = ["aaa", "bbb"] + + # A unique image plus some text tokens. + unique_token_ids = [-1] * 7 + [100] * 4 + all_token_ids = common_token_ids + unique_token_ids + mm_positions = common_mm_positions + [ + PlaceholderRange(offset=48, length=7) + ] + mm_hashes = common_mm_hashes + ["ccc"] + req0 = make_request("0", + all_token_ids, + mm_positions=mm_positions, + mm_hashes=mm_hashes) + computed_blocks = manager.get_computed_blocks(req0) + + # Completed block should have hashes with extra keys. + assert not computed_blocks + assert len(req0.kv_block_hashes) == 3 + assert req0.kv_block_hashes[0].extra_keys == (("aaa", 0), ) + assert req0.kv_block_hashes[1].extra_keys == (("aaa", 5), ("bbb", 0)) + assert req0.kv_block_hashes[2].extra_keys == (("bbb", 2), ) + + blocks = manager.allocate_slots(req0, 59, computed_blocks) + assert [b.block_id for b in blocks] == [0, 1, 2, 3, 4] + req0.num_computed_tokens = 59 + + # Append slots without allocating a new block. + for _ in range(5): + req0.append_output_token_ids(8) + new_blocks = manager.append_slots(req0, 5) + assert new_blocks is not None and len(new_blocks) == 0 + + # The just completed block should have hashes with extra keys. + assert len(req0.kv_block_hashes) == 4 + assert req0.kv_block_hashes[3].extra_keys == (("ccc", 0), ) + + # Cache hit. 
+ unique_token_ids = [-1] * 7 + [200] * 5 + all_token_ids = common_token_ids + unique_token_ids + mm_positions = common_mm_positions + [ + PlaceholderRange(offset=48, length=7) + ] + mm_hashes = common_mm_hashes + ["ccc"] + req1 = make_request("1", + all_token_ids, + mm_positions=mm_positions, + mm_hashes=mm_hashes) + computed_blocks = manager.get_computed_blocks(req1) + assert len(computed_blocks) == 3 diff --git a/tests/v1/engine/test_engine_args.py b/tests/v1/engine/test_engine_args.py index ac5e7dde525a7..ff38a4568ecb1 100644 --- a/tests/v1/engine/test_engine_args.py +++ b/tests/v1/engine/test_engine_args.py @@ -31,14 +31,6 @@ def test_prefix_caching_from_cli(): assert engine_args.enable_prefix_caching -def test_defaults(): - engine_args = EngineArgs(model="facebook/opt-125m") - - # Assert V1 defaults - assert (engine_args.enable_prefix_caching - ), "V1 turns on prefix caching by default" - - def test_defaults_with_usage_context(): engine_args = EngineArgs(model="facebook/opt-125m") vllm_config: VllmConfig = engine_args.create_engine_config( @@ -52,10 +44,3 @@ def test_defaults_with_usage_context(): UsageContext.OPENAI_API_SERVER) assert vllm_config.scheduler_config.max_num_seqs == 1024 assert vllm_config.scheduler_config.max_num_batched_tokens == 2048 - - -def test_prefix_cache_disabled_with_multimodel(): - engine_args = EngineArgs(model="llava-hf/llava-1.5-7b-hf") - - vllm_config = engine_args.create_engine_config(UsageContext.LLM_CLASS) - assert not vllm_config.cache_config.enable_prefix_caching diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index f6d276fe7c0c8..674577f23eba6 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -205,6 +205,7 @@ def __post_init__(self): # by user. if self.enable_prefix_caching is None: self.enable_prefix_caching = bool(envs.VLLM_USE_V1) + # Override max_num_seqs if it's not set by user. if self.max_num_seqs is None: self.max_num_seqs = 256 if not envs.VLLM_USE_V1 else 1024 @@ -1026,11 +1027,11 @@ def create_engine_config(self, device_config = DeviceConfig(device=self.device) model_config = self.create_model_config() - if model_config.is_multimodal_model: - if self.enable_prefix_caching: - logger.warning( - "--enable-prefix-caching is currently not " - "supported for multimodal models and has been disabled.") + if (model_config.is_multimodal_model and not envs.VLLM_USE_V1 + and self.enable_prefix_caching): + logger.warning("--enable-prefix-caching is currently not " + "supported for multimodal models in v0 and " + "has been disabled.") self.enable_prefix_caching = False cache_config = CacheConfig( @@ -1249,11 +1250,14 @@ def _override_v1_engine_args(self, usage_context: UsageContext) -> None: # When no user override, set the default values based on the usage # context. # TODO(woosuk): Tune the default values for different hardware. 
- if self.max_num_batched_tokens is None: - if usage_context == UsageContext.LLM_CLASS: - self.max_num_batched_tokens = 8192 - elif usage_context == UsageContext.OPENAI_API_SERVER: - self.max_num_batched_tokens = 2048 + default_max_num_batched_tokens = { + UsageContext.LLM_CLASS: 8192, + UsageContext.OPENAI_API_SERVER: 2048, + } + if (self.max_num_batched_tokens is None + and usage_context in default_max_num_batched_tokens): + self.max_num_batched_tokens = default_max_num_batched_tokens[ + usage_context] logger.warning( "Setting max_num_batched_tokens to %d for %s usage context.", self.max_num_batched_tokens, usage_context.value) @@ -1263,9 +1267,6 @@ def _override_v1_engine_config(self, engine_config: VllmConfig) -> None: Override the EngineConfig's configs based on the usage context for V1. """ assert envs.VLLM_USE_V1, "V1 is not enabled" - if engine_config.model_config.is_multimodal_model: - # TODO (ywang96): Enable APC by default when VLM supports it. - assert not engine_config.cache_config.enable_prefix_caching @dataclass diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index 85aaaa776907f..d54cbb5c37819 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -162,6 +162,11 @@ class TokenInputs(TypedDict): Placeholder ranges for the multi-modal data. """ + multi_modal_hashes: NotRequired[List[str]] + """ + The hashes of the multi-modal data. + """ + mm_processor_kwargs: NotRequired[Dict[str, Any]] """ Optional multi-modal processor kwargs to be forwarded to the @@ -177,6 +182,7 @@ def token_inputs( prompt: Optional[str] = None, multi_modal_data: Optional["MultiModalDataDict"] = None, multi_modal_inputs: Optional["MultiModalKwargs"] = None, + multi_modal_hashes: Optional[List[str]] = None, multi_modal_placeholders: Optional["MultiModalPlaceholderDict"] = None, mm_processor_kwargs: Optional[Dict[str, Any]] = None, ) -> TokenInputs: @@ -191,6 +197,8 @@ def token_inputs( inputs["multi_modal_data"] = multi_modal_data if multi_modal_inputs is not None: inputs["multi_modal_inputs"] = multi_modal_inputs + if multi_modal_hashes is not None: + inputs["multi_modal_hashes"] = multi_modal_hashes if multi_modal_placeholders is not None: inputs["multi_modal_placeholders"] = multi_modal_placeholders if mm_processor_kwargs is not None: @@ -295,6 +303,18 @@ def multi_modal_inputs(self) -> Union[Dict, "MultiModalKwargs"]: assert_never(inputs) + @cached_property + def multi_modal_hashes(self) -> List[str]: + inputs = self.inputs + + if inputs["type"] == "token": + return inputs.get("multi_modal_hashes", []) + + if inputs["type"] == "multimodal": + return inputs.get("mm_hashes", []) + + assert_never(inputs) + @cached_property def multi_modal_placeholders(self) -> "MultiModalPlaceholderDict": inputs = self.inputs diff --git a/vllm/multimodal/inputs.py b/vllm/multimodal/inputs.py index 229a8fbdf5831..c00943a5f26d9 100644 --- a/vllm/multimodal/inputs.py +++ b/vllm/multimodal/inputs.py @@ -215,6 +215,9 @@ class MultiModalInputsV2(TypedDict): mm_kwargs: MultiModalKwargs """Keyword arguments to be directly passed to the model after batching.""" + mm_hashes: NotRequired[List[str]] + """The hashes of the multi-modal data.""" + mm_placeholders: MultiModalPlaceholderDict """ For each modality, information about the placeholder tokens in diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index aaa44c930e324..61a3f5fd6d841 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -4,7 +4,9 @@ from vllm.logger import init_logger from vllm.utils 
import cdiv from vllm.v1.core.kv_cache_utils import (BlockHashType, FreeKVCacheBlockQueue, - KVCacheBlock, hash_block_tokens, + KVCacheBlock, + generate_block_hash_extra_keys, + hash_block_tokens, hash_request_tokens) from vllm.v1.request import Request @@ -83,10 +85,12 @@ def get_computed_blocks(self, request: Request) -> List[KVCacheBlock]: computed_blocks = [] - # TODO(rickyx): potentially we could cache this so we don't have to - # recompute it every time. - block_hashes = hash_request_tokens(self.block_size, - request.all_token_ids) + # The block hashes for the request may already be computed + # if the request was preempted and resumed. + if not request.kv_block_hashes: + request.set_kv_block_hashes( + hash_request_tokens(self.block_size, request)) + block_hashes = request.kv_block_hashes for block_hash in block_hashes: # block_hashes is a chain of block hashes. If a block hash is not @@ -242,14 +246,16 @@ def allocate_slots( num_computed_tokens = len(computed_blocks) * self.block_size num_full_blocks = (num_computed_tokens + num_tokens) // self.block_size - self._cache_full_blocks( - request=request, - blk_start_idx=len(computed_blocks), - # The new full blocks are the full blocks that are not computed. - full_blocks=self.req_to_blocks[request.request_id] - [len(computed_blocks):num_full_blocks], - prev_block=computed_blocks[-1] if computed_blocks else None, - ) + new_full_blocks = self.req_to_blocks[ + request.request_id][len(computed_blocks):num_full_blocks] + if new_full_blocks: + self._cache_full_blocks( + request=request, + blk_start_idx=len(computed_blocks), + # The new full blocks are the full blocks that are not computed. + full_blocks=new_full_blocks, + prev_block=computed_blocks[-1] if computed_blocks else None, + ) return new_blocks @@ -376,6 +382,8 @@ def _cache_full_blocks( full_blocks: The list of blocks to update hash metadata. prev_block: The previous block in the chain. """ + num_cached_block_hashes = len(request.kv_block_hashes) + # Update the new blocks with the block hashes through the chain. prev_block_hash_value = None if prev_block is not None: @@ -387,17 +395,35 @@ def _cache_full_blocks( for i, blk in enumerate(full_blocks): blk_idx = blk_start_idx + i - block_tokens = request.all_token_ids[blk_idx * - self.block_size:(blk_idx + - 1) * - self.block_size] - assert len(block_tokens) == self.block_size, ( - f"Expected {self.block_size} tokens, got {len(block_tokens)} " - f"at {blk_idx}th block for request " - f"{request.request_id}({request})") - - # Compute the hash of the current block. - block_hash = hash_block_tokens(prev_block_hash_value, block_tokens) + if blk_idx < num_cached_block_hashes: + # The block hash may already be computed in + # "get_computed_blocks" if the tokens are not generated by + # this request (either the prompt tokens or the previously + # generated tokens with preemption). In this case we simply + # reuse the block hash. + block_hash = request.kv_block_hashes[blk_idx] + else: + # Otherwise compute the block hash and cache it in the request + # in case it will be preempted in the future. + start_token_idx = blk_idx * self.block_size + end_token_idx = (blk_idx + 1) * self.block_size + block_tokens = request.all_token_ids[ + start_token_idx:end_token_idx] + assert len(block_tokens) == self.block_size, ( + f"Expected {self.block_size} tokens, got " + f"{len(block_tokens)} at {blk_idx}th block for request " + f"{request.request_id}({request})") + + # Generate extra keys for multi-modal inputs. 
Note that since + # we reach to this branch only when the block is completed with + # generated tokens, we only need to consider the last mm input. + extra_keys, _ = generate_block_hash_extra_keys( + request, start_token_idx, end_token_idx, -1) + + # Compute the hash of the current block. + block_hash = hash_block_tokens(prev_block_hash_value, + block_tokens, extra_keys) + request.append_kv_block_hashes(block_hash) # Update and added the full block to the cache. blk.block_hash = block_hash diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 0ba338aa5a3d2..d80ea128c7749 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -1,20 +1,25 @@ """KV-Cache Utilities.""" from collections.abc import Sequence from dataclasses import dataclass -from typing import List, NamedTuple, Optional, Tuple +from typing import Any, List, NamedTuple, Optional, Tuple from vllm.logger import init_logger +from vllm.v1.request import Request logger = init_logger(__name__) class BlockHashType(NamedTuple): - """Hash value of a block and the token IDs in the block. - The reason we keep a tuple of token IDs is to make sure no hash - collision happens when the hash value is the same. + """Hash value of a block (int), the token IDs in the block, and extra keys. + The reason we keep a tuple of token IDs and extra keys is to make sure + no hash collision happens when the hash value is the same. """ + # Hash value of the block in an integer. hash_value: int + # Token IDs in the block. token_ids: Tuple[int, ...] + # Extra keys for the block. + extra_keys: Optional[Any] = None @dataclass @@ -159,8 +164,80 @@ def get_all_free_blocks(self) -> List[KVCacheBlock]: return ret -def hash_block_tokens(parent_block_hash: Optional[int], - curr_block_token_ids: Sequence[int]) -> BlockHashType: +def generate_block_hash_extra_keys( + request: Request, start_token_idx: int, end_token_idx: int, + start_mm_idx: int) -> Tuple[Optional[Tuple[Any, ...]], int]: + """Generate extra keys for the block hash. The extra keys can come from + the multi-modal inputs and request specific metadata (e.g., LoRA ID). + For multi-modal inputs, the extra keys are (mm_hash, start_offset) that + indicate a mm input contained in the block and its starting offset in + the block tokens. + + Args: + request: The request object. + start_token_idx: The start token index of the block. + end_token_idx: The end token index of the block. + start_mm_idx: The start multi-modal index of the block. + + Returns: + A tuple of extra keys and the next multi-modal index. + """ + + mm_positions, mm_hashes = request.mm_positions, request.mm_hashes + if not mm_positions: + return None, start_mm_idx + + if mm_positions and len(mm_positions) != len(mm_hashes): + raise ValueError( + "The number of multi-modal positions and hashes must match. This " + "is likely because you do not enable MM preprocessor hashing. " + "Please set mm_cache_preprocessor=True.") + + # Note that we assume mm_positions is sorted by offset. + # We do not need to check all mm inputs if the start token index is out of + # range. This usually happens in the late prefill phase and decoding phase. + if mm_positions[-1]["offset"] + mm_positions[-1][ + "length"] < start_token_idx: + return None, start_mm_idx + + # Support start_mm_idx == -1 to indicate the last mm input. 
+ if start_mm_idx < 0: + assert -start_mm_idx <= len(mm_positions) + start_mm_idx = len(mm_positions) + start_mm_idx + + extra_keys = [] + curr_mm_idx = start_mm_idx + while mm_positions and curr_mm_idx < len(mm_positions): + assert mm_hashes[curr_mm_idx] is not None + offset = mm_positions[curr_mm_idx]["offset"] + length = mm_positions[curr_mm_idx]["length"] + if end_token_idx > offset: + if start_token_idx > offset + length: + # This block has passed the current mm input. + curr_mm_idx += 1 + continue + + # The block contains the current mm input. + mm_start = max(0, start_token_idx - offset) + extra_keys.append((mm_hashes[curr_mm_idx], mm_start)) + if end_token_idx >= offset + length: + # If this block contains the end of the current mm input, + # move to the next mm input as this block may also contain + # the next mm input. + curr_mm_idx += 1 + else: + # Otherwise this block is done with mm inputs. + break + else: + # This block has not reached the current mm input. + break + return tuple(extra_keys), curr_mm_idx + + +def hash_block_tokens( + parent_block_hash: Optional[int], + curr_block_token_ids: Sequence[int], + extra_keys: Optional[Tuple[Any, ...]] = None) -> BlockHashType: """Computes a hash value corresponding to the contents of a block and the contents of the preceding block(s). The hash value is used for prefix caching. We use LRU cache for this function to avoid recomputing @@ -174,27 +251,39 @@ def hash_block_tokens(parent_block_hash: Optional[int], if this is the first block. curr_block_token_ids: A list of token ids in the current block. The current block is assumed to be full. + extra_keys: Extra keys for the block. Returns: The hash value of the block and the token ids in the block. The entire tuple is used as the hash key of the block. """ return BlockHashType(hash((parent_block_hash, *curr_block_token_ids)), - tuple(curr_block_token_ids)) + tuple(curr_block_token_ids), extra_keys) def hash_request_tokens(block_size: int, - token_ids: Sequence[int]) -> List[BlockHashType]: + request: Request) -> List[BlockHashType]: """Computes hash values of a chain of blocks given a sequence of token IDs. The hash value is used for prefix caching. Args: block_size: The size of each block. - token_ids: A sequence of token ids in the request. + request: The request object. Returns: The list of computed hash values. """ + token_ids = request.all_token_ids + mm_positions, mm_hashes = request.mm_positions, request.mm_hashes + if mm_positions and len(mm_positions) != len(mm_hashes): + raise ValueError( + "The number of multi-modal positions and hashes must match.") + + # TODO: Extend this to support other features such as LoRA. + need_extra_keys = bool(mm_positions) + extra_keys = None + curr_mm_idx = 0 + ret = [] parent_block_hash_value = None for start in range(0, len(token_ids), block_size): @@ -203,8 +292,14 @@ def hash_request_tokens(block_size: int, # Do not hash the block if it is not full. if len(block_token_ids) < block_size: break + + # Add extra keys if the block is a multi-modal block. 
+ if need_extra_keys: + extra_keys, curr_mm_idx = generate_block_hash_extra_keys( + request, start, end, curr_mm_idx) + block_hash = hash_block_tokens(parent_block_hash_value, - block_token_ids) + block_token_ids, extra_keys) ret.append(block_hash) parent_block_hash_value = block_hash.hash_value return ret diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index 178532e477dae..08e7c0fd4dc9b 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -516,6 +516,7 @@ class NewRequestData: prompt_token_ids: List[int] prompt: Optional[str] mm_inputs: List["MultiModalKwargs"] + mm_hashes: List[str] mm_positions: List["PlaceholderRange"] sampling_params: SamplingParams block_ids: List[int] @@ -533,6 +534,7 @@ def from_request( prompt_token_ids=request.prompt_token_ids, prompt=request.prompt, mm_inputs=request.mm_inputs, + mm_hashes=request.mm_hashes, mm_positions=request.mm_positions, sampling_params=request.sampling_params, block_ids=block_ids, diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index b36de5f66917c..41fb4b25d45bb 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -60,9 +60,13 @@ def __init__( self.client_aborted_requests: List[str] = [] # Processor (converts Inputs --> EngineCoreRequests). - self.processor = Processor(vllm_config.model_config, - vllm_config.lora_config, self.tokenizer, - input_registry) + self.processor = Processor( + model_config=vllm_config.model_config, + cache_config=vllm_config.cache_config, + lora_config=vllm_config.lora_config, + tokenizer=self.tokenizer, + input_registry=input_registry, + ) # Detokenizer (converts EngineCoreOutputs --> RequestOutput). self.detokenizer = Detokenizer( diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 56d4dc67e4a0e..497d5db5b4c99 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -65,7 +65,8 @@ def __init__( self._last_logging_time = time.time() - self.mm_input_mapper_server = MMInputMapperServer() + self.mm_input_mapper_server = MMInputMapperServer( + vllm_config.model_config) def _initialize_kv_caches(self, cache_config: CacheConfig) -> Tuple[int, int]: @@ -98,9 +99,8 @@ def add_request(self, request: EngineCoreRequest): # MM mapper, so anything that has a hash must have a HIT cache # entry here as well. 
assert request.mm_inputs is not None - request.mm_inputs, request.mm_hashes = ( - self.mm_input_mapper_server.process_inputs( - request.mm_inputs, request.mm_hashes)) + request.mm_inputs = self.mm_input_mapper_server.process_inputs( + request.mm_inputs, request.mm_hashes) req = Request.from_engine_core_request(request) diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index 15dedbd0f9529..bea8c5502f612 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -55,9 +55,12 @@ def __init__( self.tokenizer.ping() # Processor (convert Inputs --> EngineCoreRequests) - self.processor = Processor(vllm_config.model_config, - vllm_config.lora_config, self.tokenizer, - input_registry, mm_registry) + self.processor = Processor(model_config=vllm_config.model_config, + cache_config=vllm_config.cache_config, + lora_config=vllm_config.lora_config, + tokenizer=self.tokenizer, + input_registry=input_registry, + mm_registry=mm_registry) # Detokenizer (converts EngineCoreOutputs --> RequestOutput) self.detokenizer = Detokenizer( diff --git a/vllm/v1/engine/mm_input_mapper.py b/vllm/v1/engine/mm_input_mapper.py index 6cdeba6f3f71e..e53ba092ede04 100644 --- a/vllm/v1/engine/mm_input_mapper.py +++ b/vllm/v1/engine/mm_input_mapper.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional import PIL from blake3 import blake3 @@ -42,6 +42,8 @@ def __init__( model_config) self.mm_registry.init_mm_limits_per_prompt(model_config) + # Init cache + self.use_cache = model_config.mm_cache_preprocessor self.mm_cache = LRUDictCache[str, MultiModalKwargs](MM_CACHE_SIZE) # DEBUG: Set to None to disable @@ -61,7 +63,7 @@ def process_inputs( mm_hashes: Optional[List[str]], mm_processor_kwargs: Optional[Dict[str, Any]], precomputed_mm_inputs: Optional[List[MultiModalKwargs]], - ) -> Tuple[List[MultiModalKwargs], Optional[List[str]]]: + ) -> List[MultiModalKwargs]: if precomputed_mm_inputs is None: image_inputs = mm_data["image"] if not isinstance(image_inputs, list): @@ -70,26 +72,21 @@ def process_inputs( else: num_inputs = len(precomputed_mm_inputs) - # Check if hash is enabled - use_hash = mm_hashes is not None - if use_hash: + # Sanity + if self.use_cache: assert mm_hashes is not None - assert num_inputs == len( - mm_hashes), "num_inputs = {} len(mm_hashes) = {}".format( - num_inputs, len(mm_hashes)) + assert num_inputs == len(mm_hashes) # Process each image input separately, so that later we can schedule # them in a fine-grained manner. 
# Apply caching (if enabled) and reuse precomputed inputs (if provided) - ret_hashes: Optional[List[str]] = [] if use_hash else None ret_inputs: List[MultiModalKwargs] = [] for input_id in range(num_inputs): if self.mm_debug_cache_hit_ratio_steps is not None: self.cache_hit_ratio(self.mm_debug_cache_hit_ratio_steps) - mm_hash = None mm_input = None - if use_hash: + if self.use_cache: assert mm_hashes is not None mm_hash = mm_hashes[input_id] mm_input = self.mm_cache.get(mm_hash) @@ -106,7 +103,7 @@ def process_inputs( mm_processor_kwargs=mm_processor_kwargs, ) - if use_hash: + if self.use_cache: # Add to cache assert mm_hash is not None self.mm_cache.put(mm_hash, mm_input) @@ -114,18 +111,15 @@ def process_inputs( self.mm_cache_hits += 1 mm_input = None # Avoids sending mm_input to Server - if use_hash: - assert mm_hash is not None - assert ret_hashes is not None - ret_hashes.append(mm_hash) ret_inputs.append(mm_input) - return ret_inputs, ret_hashes + return ret_inputs class MMInputMapperServer: - def __init__(self, ): + def __init__(self, model_config): + self.use_cache = model_config.mm_cache_preprocessor self.mm_cache = LRUDictCache[str, MultiModalKwargs](MM_CACHE_SIZE) def process_inputs( @@ -135,6 +129,9 @@ def process_inputs( ) -> List[MultiModalKwargs]: assert len(mm_inputs) == len(mm_hashes) + if not self.use_cache: + return mm_inputs + full_mm_inputs = [] for mm_input, mm_hash in zip(mm_inputs, mm_hashes): assert mm_hash is not None diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index 679bf8e25e9ca..732757d6b0ac2 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -1,7 +1,7 @@ import time from typing import Any, Dict, Mapping, Optional, Tuple, Union -from vllm.config import LoRAConfig, ModelConfig +from vllm.config import CacheConfig, LoRAConfig, ModelConfig from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs, PromptType, SingletonInputsAdapter) from vllm.inputs.parse import is_encoder_decoder_inputs @@ -23,6 +23,7 @@ class Processor: def __init__( self, model_config: ModelConfig, + cache_config: CacheConfig, lora_config: Optional[LoRAConfig], tokenizer: BaseTokenizerGroup, input_registry: InputRegistry = INPUT_REGISTRY, @@ -45,8 +46,9 @@ def __init__( self.mm_input_mapper_client = MMInputMapperClient(model_config) # Multi-modal hasher (for images) - self.mm_hasher = MMHasher( - ) if model_config.mm_cache_preprocessor else None + self.use_hash = model_config.mm_cache_preprocessor or \ + cache_config.enable_prefix_caching + self.mm_hasher = MMHasher() # TODO: run in an ThreadpoolExecutor or BackgroundProcess. # This ideally should releases the GIL, so we should not block the @@ -77,7 +79,7 @@ def process_inputs( # Compute MM hashes (if enabled) mm_hashes = None - if self.mm_hasher is not None: + if self.use_hash: mm_hashes = self.mm_hasher.hash(prompt) # Process inputs. 
@@ -118,7 +120,7 @@ def process_inputs( # Apply MM mapper mm_inputs = None if len(decoder_inputs.multi_modal_data) > 0: - mm_inputs, mm_hashes = self.mm_input_mapper_client.process_inputs( + mm_inputs = self.mm_input_mapper_client.process_inputs( decoder_inputs.multi_modal_data, mm_hashes, decoder_inputs.mm_processor_kwargs, precomputed_mm_inputs) diff --git a/vllm/v1/request.py b/vllm/v1/request.py index 1737d096e811d..f4783ae366ef0 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -1,5 +1,5 @@ import enum -from typing import List, Optional, Union +from typing import TYPE_CHECKING, List, Optional, Union from vllm.inputs import DecoderOnlyInputs, SingletonInputsAdapter, token_inputs from vllm.lora.request import LoRARequest @@ -9,6 +9,9 @@ from vllm.v1.engine import EngineCoreRequest from vllm.v1.utils import ConstantList +if TYPE_CHECKING: + from vllm.v1.core.kv_cache_utils import BlockHashType + class Request: @@ -45,6 +48,7 @@ def __init__( self._all_token_ids: List[int] = self.prompt_token_ids.copy() self.num_computed_tokens = 0 + # Multi-modal input metadata. mm_positions = self.inputs.multi_modal_placeholders if mm_positions: # FIXME(woosuk): Support other modalities. @@ -56,6 +60,12 @@ def __init__( if self.inputs.multi_modal_inputs: self.mm_inputs = self.inputs.multi_modal_inputs + self.mm_hashes: List[str] = self.inputs.multi_modal_hashes + + # Cache the computed kv block hashes of the request to avoid + # recomputing. + self._kv_block_hashes: List[BlockHashType] = [] + @classmethod def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request": return cls( @@ -65,6 +75,7 @@ def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request": prompt=request.prompt, multi_modal_data=None, multi_modal_inputs=request.mm_inputs, + multi_modal_hashes=request.mm_hashes, multi_modal_placeholders=request.mm_placeholders, mm_processor_kwargs=None, ), @@ -121,6 +132,17 @@ def get_num_encoder_tokens(self, input_id: int) -> int: num_tokens = self.mm_positions[input_id]["length"] return num_tokens + @property + def kv_block_hashes(self) -> ConstantList["BlockHashType"]: + # Prevent directly appending to the kv_block_hashes. + return ConstantList(self._kv_block_hashes) + + def set_kv_block_hashes(self, value: List["BlockHashType"]) -> None: + self._kv_block_hashes = value + + def append_kv_block_hashes(self, block_hash: "BlockHashType") -> None: + self._kv_block_hashes.append(block_hash) + class RequestStatus(enum.IntEnum): """Status of a request."""
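
Note (illustrative only, not part of the patch): the sketch below is a minimal, self-contained model of the hashing scheme the change above implements. Full blocks are hashed along a chain, and any image whose placeholder tokens overlap a block contributes an extra key of (mm_hash, offset of the block within that image), so identical placeholder token IDs backed by different images do not collide in the prefix cache. The names MMSpan, BlockHash, extra_keys_for_block, and hash_block are hypothetical stand-ins for the corresponding vLLM structures (PlaceholderRange, BlockHashType, generate_block_hash_extra_keys, hash_block_tokens).

from typing import List, NamedTuple, Optional, Tuple


class MMSpan(NamedTuple):
    offset: int   # index of the first placeholder token of this image
    length: int   # number of placeholder tokens


class BlockHash(NamedTuple):
    hash_value: int
    token_ids: Tuple[int, ...]
    extra_keys: Optional[Tuple] = None


def extra_keys_for_block(start: int, end: int, spans: List[MMSpan],
                         hashes: List[str]) -> Optional[Tuple]:
    """Collect (mm_hash, block start offset within the image) for every
    image whose placeholder range overlaps token range [start, end)."""
    keys = []
    for span, mm_hash in zip(spans, hashes):
        if span.offset < end and start < span.offset + span.length:
            keys.append((mm_hash, max(0, start - span.offset)))
    return tuple(keys) if keys else None


def hash_block(parent_hash: Optional[int], token_ids: List[int],
               extra_keys: Optional[Tuple]) -> BlockHash:
    # The integer hash chains on the parent block; the full named tuple
    # (token IDs plus extra keys) is what the prefix cache keys on.
    return BlockHash(hash((parent_hash, *token_ids)), tuple(token_ids),
                     extra_keys)


if __name__ == "__main__":
    block_size = 4
    # Two text tokens, a 4-token image placeholder, two more text tokens.
    tokens = [1, 2, -1, -1, -1, -1, 3, 4]
    spans = [MMSpan(offset=2, length=4)]

    for image_hash in ("aaa", "bbb"):
        parent = None
        for start in range(0, len(tokens), block_size):
            keys = extra_keys_for_block(start, start + block_size, spans,
                                        [image_hash])
            blk = hash_block(parent, tokens[start:start + block_size], keys)
            parent = blk.hash_value
            print(blk.token_ids, blk.extra_keys)
    # Identical token IDs but different image hashes yield different cache
    # keys, so the second prompt does not reuse the first prompt's blocks.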