[V1] Prefix caching for vision language models (#11187)
Signed-off-by: Cody Yu <[email protected]>
comaniac authored Dec 18, 2024
1 parent c77eb8a commit bf8717e
Showing 14 changed files with 342 additions and 98 deletions.
88 changes: 86 additions & 2 deletions tests/v1/core/test_prefix_caching.py
@@ -2,16 +2,23 @@
import pytest

from vllm.inputs import token_inputs
from vllm.multimodal.inputs import PlaceholderRange
from vllm.sampling_params import SamplingParams
from vllm.utils import cdiv
from vllm.v1.core.kv_cache_manager import KVCacheManager, Request
from vllm.v1.core.kv_cache_utils import KVCacheBlock, hash_block_tokens


def make_request(request_id, prompt_token_ids):
def make_request(request_id,
prompt_token_ids,
mm_positions=None,
mm_hashes=None):
return Request(
request_id=request_id,
inputs=token_inputs(prompt_token_ids=prompt_token_ids),
inputs=token_inputs(prompt_token_ids=prompt_token_ids,
multi_modal_placeholders={"image": mm_positions}
if mm_positions else None,
multi_modal_hashes=mm_hashes),
sampling_params=SamplingParams(max_tokens=17),
eos_token_id=100,
arrival_time=0,
@@ -38,6 +45,7 @@ def test_prefill():
all_token_ids = common_token_ids + unique_token_ids
req0 = make_request("0", all_token_ids)
computed_blocks = manager.get_computed_blocks(req0)
assert len(req0.kv_block_hashes) == 3
assert not computed_blocks
blocks = manager.allocate_slots(req0, 55, computed_blocks)
assert [b.block_id for b in blocks] == [0, 1, 2, 3, 4]
@@ -61,6 +69,7 @@ def test_prefill():
unique_token_ids = [3] * 5
req1 = make_request("1", common_token_ids + unique_token_ids)
computed_blocks = manager.get_computed_blocks(req1)
assert len(req1.kv_block_hashes) == 3
assert [b.block_id for b in computed_blocks] == [0, 1, 2]
num_new_tokens = 53 - 3 * 16
blocks = manager.allocate_slots(req1, num_new_tokens, computed_blocks)
@@ -90,6 +99,7 @@ def test_prefill():
unique_token_ids = [3] * 6
req2 = make_request("2", common_token_ids + unique_token_ids)
computed_block = manager.get_computed_blocks(req2)
assert len(req2.kv_block_hashes) == 3
assert [b.block_id for b in computed_block] == [0, 1, 2]
num_new_tokens = 53 - 3 * 16
blocks = manager.allocate_slots(req2, num_new_tokens, computed_blocks)
@@ -416,3 +426,77 @@ def test_cache_blocks():
)
assert len(manager.cached_block_hash_to_block) == 3
assert blocks[0].block_hash is not None


def test_mm_prefix_caching():
"""
This tests that the multi-modal prefix caching is correct.
"""
manager = KVCacheManager(
block_size=16,
num_gpu_blocks=10,
max_model_len=8192,
sliding_window=None,
enable_caching=True,
num_preallocate_tokens=16,
)

# Common prompt tokens (T is text tokens and P is image placeholder tokens)
# [T,...,T, P0,...,P0], [P0,...,P0,T,...,T,P1,...,P1], [P1,...,P1]
common_token_ids = list(range(10)) + [-1] * 6
common_token_ids += [-1] * 4 + list(range(10, 20)) + [-1] * 2
common_token_ids += [-1] * 16

common_mm_positions = [
PlaceholderRange(offset=11, length=10),
PlaceholderRange(offset=30, length=18),
]
common_mm_hashes = ["aaa", "bbb"]

# A unique image plus some text tokens.
unique_token_ids = [-1] * 7 + [100] * 4
all_token_ids = common_token_ids + unique_token_ids
mm_positions = common_mm_positions + [
PlaceholderRange(offset=48, length=7)
]
mm_hashes = common_mm_hashes + ["ccc"]
req0 = make_request("0",
all_token_ids,
mm_positions=mm_positions,
mm_hashes=mm_hashes)
computed_blocks = manager.get_computed_blocks(req0)

# Completed block should have hashes with extra keys.
assert not computed_blocks
assert len(req0.kv_block_hashes) == 3
assert req0.kv_block_hashes[0].extra_keys == (("aaa", 0), )
assert req0.kv_block_hashes[1].extra_keys == (("aaa", 5), ("bbb", 0))
assert req0.kv_block_hashes[2].extra_keys == (("bbb", 2), )

blocks = manager.allocate_slots(req0, 59, computed_blocks)
assert [b.block_id for b in blocks] == [0, 1, 2, 3, 4]
req0.num_computed_tokens = 59

# Append slots without allocating a new block.
for _ in range(5):
req0.append_output_token_ids(8)
new_blocks = manager.append_slots(req0, 5)
assert new_blocks is not None and len(new_blocks) == 0

# The just completed block should have hashes with extra keys.
assert len(req0.kv_block_hashes) == 4
assert req0.kv_block_hashes[3].extra_keys == (("ccc", 0), )

# Cache hit.
unique_token_ids = [-1] * 7 + [200] * 5
all_token_ids = common_token_ids + unique_token_ids
mm_positions = common_mm_positions + [
PlaceholderRange(offset=48, length=7)
]
mm_hashes = common_mm_hashes + ["ccc"]
req1 = make_request("1",
all_token_ids,
mm_positions=mm_positions,
mm_hashes=mm_hashes)
computed_blocks = manager.get_computed_blocks(req1)
assert len(computed_blocks) == 3
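
The `extra_keys` assertions in this test follow from simple offset arithmetic: each 16-token block records, for every image placeholder it overlaps, the image hash plus the offset within that placeholder at which the block starts. A minimal sketch of that arithmetic, using a hypothetical helper rather than the actual `generate_block_hash_extra_keys`:

```python
# Hypothetical helper illustrating the extra-key arithmetic asserted above;
# not the actual vLLM implementation.
def block_extra_keys(block_idx, block_size, mm_ranges, mm_hashes):
    start, end = block_idx * block_size, (block_idx + 1) * block_size
    keys = []
    for (offset, length), mm_hash in zip(mm_ranges, mm_hashes):
        if offset < end and start < offset + length:  # block overlaps the image
            # How far into the placeholder does this block start?
            keys.append((mm_hash, max(start - offset, 0)))
    return tuple(keys)

# Placeholders from the test: "aaa" covers tokens 11-20, "bbb" 30-47, "ccc" 48-54.
ranges = [(11, 10), (30, 18), (48, 7)]
hashes = ["aaa", "bbb", "ccc"]
for i in range(4):  # 16-token blocks 0..3
    print(block_extra_keys(i, 16, ranges, hashes))
# -> (('aaa', 0),), (('aaa', 5), ('bbb', 0)), (('bbb', 2),), (('ccc', 0),)
```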
15 changes: 0 additions & 15 deletions tests/v1/engine/test_engine_args.py
@@ -31,14 +31,6 @@ def test_prefix_caching_from_cli():
assert engine_args.enable_prefix_caching


def test_defaults():
engine_args = EngineArgs(model="facebook/opt-125m")

# Assert V1 defaults
assert (engine_args.enable_prefix_caching
), "V1 turns on prefix caching by default"


def test_defaults_with_usage_context():
engine_args = EngineArgs(model="facebook/opt-125m")
vllm_config: VllmConfig = engine_args.create_engine_config(
@@ -52,10 +44,3 @@ def test_defaults_with_usage_context():
UsageContext.OPENAI_API_SERVER)
assert vllm_config.scheduler_config.max_num_seqs == 1024
assert vllm_config.scheduler_config.max_num_batched_tokens == 2048


def test_prefix_cache_disabled_with_multimodel():
engine_args = EngineArgs(model="llava-hf/llava-1.5-7b-hf")

vllm_config = engine_args.create_engine_config(UsageContext.LLM_CLASS)
assert not vllm_config.cache_config.enable_prefix_caching
27 changes: 14 additions & 13 deletions vllm/engine/arg_utils.py
@@ -205,6 +205,7 @@ def __post_init__(self):
# by user.
if self.enable_prefix_caching is None:
self.enable_prefix_caching = bool(envs.VLLM_USE_V1)

# Override max_num_seqs if it's not set by user.
if self.max_num_seqs is None:
self.max_num_seqs = 256 if not envs.VLLM_USE_V1 else 1024
@@ -1026,11 +1027,11 @@ def create_engine_config(self,
device_config = DeviceConfig(device=self.device)
model_config = self.create_model_config()

if model_config.is_multimodal_model:
if self.enable_prefix_caching:
logger.warning(
"--enable-prefix-caching is currently not "
"supported for multimodal models and has been disabled.")
if (model_config.is_multimodal_model and not envs.VLLM_USE_V1
and self.enable_prefix_caching):
logger.warning("--enable-prefix-caching is currently not "
"supported for multimodal models in v0 and "
"has been disabled.")
self.enable_prefix_caching = False

cache_config = CacheConfig(
@@ -1249,11 +1250,14 @@ def _override_v1_engine_args(self, usage_context: UsageContext) -> None:
# When no user override, set the default values based on the usage
# context.
# TODO(woosuk): Tune the default values for different hardware.
if self.max_num_batched_tokens is None:
if usage_context == UsageContext.LLM_CLASS:
self.max_num_batched_tokens = 8192
elif usage_context == UsageContext.OPENAI_API_SERVER:
self.max_num_batched_tokens = 2048
default_max_num_batched_tokens = {
UsageContext.LLM_CLASS: 8192,
UsageContext.OPENAI_API_SERVER: 2048,
}
if (self.max_num_batched_tokens is None
and usage_context in default_max_num_batched_tokens):
self.max_num_batched_tokens = default_max_num_batched_tokens[
usage_context]
logger.warning(
"Setting max_num_batched_tokens to %d for %s usage context.",
self.max_num_batched_tokens, usage_context.value)
@@ -1263,9 +1267,6 @@ def _override_v1_engine_config(self, engine_config: VllmConfig) -> None:
Override the EngineConfig's configs based on the usage context for V1.
"""
assert envs.VLLM_USE_V1, "V1 is not enabled"
if engine_config.model_config.is_multimodal_model:
# TODO (ywang96): Enable APC by default when VLM supports it.
assert not engine_config.cache_config.enable_prefix_caching


@dataclass
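
Note that the warning and forced disable now apply only to V0 multimodal models; under V1, a multimodal model keeps prefix caching enabled, which is why `test_prefix_cache_disabled_with_multimodel` was removed above. A hedged usage sketch of the new behavior, assuming `VLLM_USE_V1=1` and reusing the API from the removed test:

```python
# Sketch only: assumes VLLM_USE_V1=1 takes effect before vllm is imported.
import os
os.environ["VLLM_USE_V1"] = "1"

from vllm.engine.arg_utils import EngineArgs
from vllm.usage.usage_lib import UsageContext

engine_args = EngineArgs(model="llava-hf/llava-1.5-7b-hf")
vllm_config = engine_args.create_engine_config(UsageContext.LLM_CLASS)

# Under V1, prefix caching is no longer force-disabled for multimodal models.
assert vllm_config.cache_config.enable_prefix_caching
```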
20 changes: 20 additions & 0 deletions vllm/inputs/data.py
@@ -162,6 +162,11 @@ class TokenInputs(TypedDict):
Placeholder ranges for the multi-modal data.
"""

multi_modal_hashes: NotRequired[List[str]]
"""
The hashes of the multi-modal data.
"""

mm_processor_kwargs: NotRequired[Dict[str, Any]]
"""
Optional multi-modal processor kwargs to be forwarded to the
@@ -177,6 +182,7 @@ def token_inputs(
prompt: Optional[str] = None,
multi_modal_data: Optional["MultiModalDataDict"] = None,
multi_modal_inputs: Optional["MultiModalKwargs"] = None,
multi_modal_hashes: Optional[List[str]] = None,
multi_modal_placeholders: Optional["MultiModalPlaceholderDict"] = None,
mm_processor_kwargs: Optional[Dict[str, Any]] = None,
) -> TokenInputs:
@@ -191,6 +197,8 @@ def token_inputs(
inputs["multi_modal_data"] = multi_modal_data
if multi_modal_inputs is not None:
inputs["multi_modal_inputs"] = multi_modal_inputs
if multi_modal_hashes is not None:
inputs["multi_modal_hashes"] = multi_modal_hashes
if multi_modal_placeholders is not None:
inputs["multi_modal_placeholders"] = multi_modal_placeholders
if mm_processor_kwargs is not None:
@@ -295,6 +303,18 @@ def multi_modal_inputs(self) -> Union[Dict, "MultiModalKwargs"]:

assert_never(inputs)

@cached_property
def multi_modal_hashes(self) -> List[str]:
inputs = self.inputs

if inputs["type"] == "token":
return inputs.get("multi_modal_hashes", [])

if inputs["type"] == "multimodal":
return inputs.get("mm_hashes", [])

assert_never(inputs)

@cached_property
def multi_modal_placeholders(self) -> "MultiModalPlaceholderDict":
inputs = self.inputs
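
A brief usage sketch of the new `multi_modal_hashes` field, using only the imports and call pattern already exercised by the test above (the hash value and placeholder offsets are illustrative):

```python
from vllm.inputs import token_inputs
from vllm.multimodal.inputs import PlaceholderRange

# Illustrative values; in practice the hash identifies the image content.
inputs = token_inputs(
    prompt_token_ids=list(range(32)),
    multi_modal_placeholders={"image": [PlaceholderRange(offset=11, length=10)]},
    multi_modal_hashes=["aaa"],
)

# Like the other optional fields, the hashes are stored only when provided.
assert inputs["multi_modal_hashes"] == ["aaa"]
```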
3 changes: 3 additions & 0 deletions vllm/multimodal/inputs.py
@@ -215,6 +215,9 @@ class MultiModalInputsV2(TypedDict):
mm_kwargs: MultiModalKwargs
"""Keyword arguments to be directly passed to the model after batching."""

mm_hashes: NotRequired[List[str]]
"""The hashes of the multi-modal data."""

mm_placeholders: MultiModalPlaceholderDict
"""
For each modality, information about the placeholder tokens in
74 changes: 50 additions & 24 deletions vllm/v1/core/kv_cache_manager.py
@@ -4,7 +4,9 @@
from vllm.logger import init_logger
from vllm.utils import cdiv
from vllm.v1.core.kv_cache_utils import (BlockHashType, FreeKVCacheBlockQueue,
KVCacheBlock, hash_block_tokens,
KVCacheBlock,
generate_block_hash_extra_keys,
hash_block_tokens,
hash_request_tokens)
from vllm.v1.request import Request

@@ -83,10 +85,12 @@ def get_computed_blocks(self, request: Request) -> List[KVCacheBlock]:

computed_blocks = []

# TODO(rickyx): potentially we could cache this so we don't have to
# recompute it every time.
block_hashes = hash_request_tokens(self.block_size,
request.all_token_ids)
# The block hashes for the request may already be computed
# if the request was preempted and resumed.
if not request.kv_block_hashes:
request.set_kv_block_hashes(
hash_request_tokens(self.block_size, request))
block_hashes = request.kv_block_hashes

for block_hash in block_hashes:
# block_hashes is a chain of block hashes. If a block hash is not
@@ -242,14 +246,16 @@ def allocate_slots(
num_computed_tokens = len(computed_blocks) * self.block_size
num_full_blocks = (num_computed_tokens + num_tokens) // self.block_size

self._cache_full_blocks(
request=request,
blk_start_idx=len(computed_blocks),
# The new full blocks are the full blocks that are not computed.
full_blocks=self.req_to_blocks[request.request_id]
[len(computed_blocks):num_full_blocks],
prev_block=computed_blocks[-1] if computed_blocks else None,
)
new_full_blocks = self.req_to_blocks[
request.request_id][len(computed_blocks):num_full_blocks]
if new_full_blocks:
self._cache_full_blocks(
request=request,
blk_start_idx=len(computed_blocks),
# The new full blocks are the full blocks that are not computed.
full_blocks=new_full_blocks,
prev_block=computed_blocks[-1] if computed_blocks else None,
)

return new_blocks

@@ -376,6 +382,8 @@ def _cache_full_blocks(
full_blocks: The list of blocks to update hash metadata.
prev_block: The previous block in the chain.
"""
num_cached_block_hashes = len(request.kv_block_hashes)

# Update the new blocks with the block hashes through the chain.
prev_block_hash_value = None
if prev_block is not None:
Expand All @@ -387,17 +395,35 @@ def _cache_full_blocks(
for i, blk in enumerate(full_blocks):
blk_idx = blk_start_idx + i

block_tokens = request.all_token_ids[blk_idx *
self.block_size:(blk_idx +
1) *
self.block_size]
assert len(block_tokens) == self.block_size, (
f"Expected {self.block_size} tokens, got {len(block_tokens)} "
f"at {blk_idx}th block for request "
f"{request.request_id}({request})")

# Compute the hash of the current block.
block_hash = hash_block_tokens(prev_block_hash_value, block_tokens)
if blk_idx < num_cached_block_hashes:
# The block hash may already be computed in
# "get_computed_blocks" if the tokens are not generated by
# this request (either the prompt tokens or the previously
# generated tokens with preemption). In this case we simply
# reuse the block hash.
block_hash = request.kv_block_hashes[blk_idx]
else:
# Otherwise compute the block hash and cache it in the request
# in case it will be preempted in the future.
start_token_idx = blk_idx * self.block_size
end_token_idx = (blk_idx + 1) * self.block_size
block_tokens = request.all_token_ids[
start_token_idx:end_token_idx]
assert len(block_tokens) == self.block_size, (
f"Expected {self.block_size} tokens, got "
f"{len(block_tokens)} at {blk_idx}th block for request "
f"{request.request_id}({request})")

# Generate extra keys for multi-modal inputs. Note that since
# we reach to this branch only when the block is completed with
# generated tokens, we only need to consider the last mm input.
extra_keys, _ = generate_block_hash_extra_keys(
request, start_token_idx, end_token_idx, -1)

# Compute the hash of the current block.
block_hash = hash_block_tokens(prev_block_hash_value,
block_tokens, extra_keys)
request.append_kv_block_hashes(block_hash)

# Update and added the full block to the cache.
blk.block_hash = block_hash
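
The important behaviors in the reworked `_cache_full_blocks` are that block hashes already computed in `get_computed_blocks` (for prompt tokens, or after preemption and resume) are reused rather than recomputed, and that each newly completed block is hashed together with its multimodal extra keys and appended back onto the request. A deliberately simplified, hypothetical sketch of that flow (not the actual vLLM code, which uses `hash_block_tokens` and `generate_block_hash_extra_keys`):

```python
import hashlib
from typing import Callable, List, Optional, Tuple

def cache_full_block_hashes(
    block_size: int,
    all_token_ids: List[int],
    cached_hashes: List[str],          # request-level cache, reused on resume
    num_full_blocks: int,
    extra_keys_fn: Callable[[int, int], Optional[Tuple]],
) -> List[str]:
    # Chain each new block hash off the previous block's hash.
    prev_hash: Optional[str] = cached_hashes[-1] if cached_hashes else None
    for blk_idx in range(len(cached_hashes), num_full_blocks):
        start, end = blk_idx * block_size, (blk_idx + 1) * block_size
        block_tokens = tuple(all_token_ids[start:end])
        extra_keys = extra_keys_fn(start, end)  # e.g. (("aaa", 0),) for an image
        payload = repr((prev_hash, block_tokens, extra_keys)).encode()
        prev_hash = hashlib.sha256(payload).hexdigest()
        # Appending back means a preempted-and-resumed request skips rehashing.
        cached_hashes.append(prev_hash)
    return cached_hashes

# Example: two full 16-token blocks, the first overlapping image "aaa".
hashes: List[str] = []
cache_full_block_hashes(
    16, list(range(32)), hashes, 2,
    lambda s, e: (("aaa", 0),) if s == 0 else None)
print(hashes)
```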