[V1][Core] Autotune encoder cache budget #11895

Open · wants to merge 10 commits into base: main

Changes from 5 commits
8 changes: 0 additions & 8 deletions vllm/config.py
@@ -1379,14 +1379,6 @@ class SchedulerConfig:

is_multimodal_model: bool = False

# FIXME(woosuk & ywang96): Below are placeholder values. We need to
# calculate the actual values from the configurations.
# Multimodal encoder run compute budget, only used in V1
max_num_encoder_input_tokens = 16384

# Multimodal encoder cache size, only used in V1
encoder_cache_size = 16384

# Whether to perform preemption by swapping or
# recomputation. If not specified, we determine the mode as follows:
# We use recomputation by default since it incurs lower overhead than
29 changes: 24 additions & 5 deletions vllm/multimodal/registry.py
@@ -253,11 +253,8 @@ def get_max_tokens_per_item_by_modality(
model_config: "ModelConfig",
) -> Mapping[str, int]:
"""
Get the maximum number of tokens per data item from each modality
for profiling the memory usage of a model.

Note:
This is currently directly used only in V1.
Get the maximum number of tokens per data item from each modality based
on underlying model configuration.
"""
if self.has_processor(model_config):
tokenizer = cached_get_tokenizer(model_config.tokenizer)
@@ -270,6 +267,28 @@ def get_max_tokens_per_item_by_modality(
for key, plugin in self._plugins.items()
}

def get_max_tokens_per_item_by_nonzero_modality(
self,
model_config: "ModelConfig",
) -> Mapping[str, int]:
"""
Get the maximum number of tokens per data item from each modality based
on underlying model configuration, excluding modalities that user
explicitly disabled via `limit_mm_per_prompt`.

Note:
This is currently directly used only in V1 for profiling the memory
usage of a model.
"""
limits_per_plugin = self._limits_by_model[model_config]

return {
key: max_tokens_per_mm_item
for key, max_tokens_per_mm_item in
self.get_max_tokens_per_item_by_modality(model_config).items()
if limits_per_plugin[key] > 0
}

def get_max_tokens_by_modality(
self,
model_config: "ModelConfig",
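
For illustration, a standalone sketch (hypothetical values, not vLLM code) of the filtering this new helper performs: modalities whose per-prompt limit is zero are dropped from the profiling map.

    # Standalone illustration of get_max_tokens_per_item_by_nonzero_modality
    # (hypothetical values; the real method reads them from the model config).
    max_tokens_per_item = {"image": 8192, "video": 16384, "audio": 512}
    limits_per_prompt = {"image": 4, "video": 0, "audio": 1}  # user-set limit_mm_per_prompt

    nonzero = {
        modality: tokens
        for modality, tokens in max_tokens_per_item.items()
        if limits_per_prompt[modality] > 0
    }
    print(nonzero)  # {'image': 8192, 'audio': 512} -- "video" is excluded from profiling
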
86 changes: 85 additions & 1 deletion vllm/v1/core/encoder_cache_manager.py
@@ -1,7 +1,15 @@
from typing import Dict, List, Set, Tuple
from typing import TYPE_CHECKING, Dict, List, Set, Tuple

from vllm.logger import init_logger
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.utils import cdiv
from vllm.v1.request import Request

if TYPE_CHECKING:
from vllm.config import ModelConfig, SchedulerConfig

logger = init_logger(__name__)


class EncoderCacheManager:

@@ -46,3 +54,79 @@ def get_freed_ids(self) -> List[Tuple[str, int]]:
freed = self.freed
self.freed = []
return freed


def compute_encoder_cache_budget(
model_config: "ModelConfig",
scheduler_config: "SchedulerConfig",
) -> int:
"""Compute the encoder cache budget based on the model and scheduler
configurations.

Args:
model_config: Model configuration.
scheduler_config: Scheduler configuration.

Returns:
The encoder cache budget, in unit of number of tokens
in the input sequence.
"""

encoder_cache_budget = 0

# TODO: handle encoder-decoder models once we support them.
if not model_config.is_multimodal_model:
return encoder_cache_budget

max_tokens_by_modality_dict = MULTIMODAL_REGISTRY.get_max_tokens_per_item_by_nonzero_modality( # noqa: E501
model_config)

if not max_tokens_by_modality_dict:
logger.warning(
"All non-text modalities supported by the model have been "
"explicitly disabled via limit_mm_per_prompt. Encoder cache will "
"not be initialized.")
return encoder_cache_budget

modality, max_tokens_per_mm_item = max(max_tokens_by_modality_dict.items(),
key=lambda item: item[1])

max_num_batched_tokens = scheduler_config.max_num_batched_tokens
max_num_reqs = scheduler_config.max_num_seqs

# In case that the biggest possible multimodal item takes space more
# than the batch size, then it needs to be cached and chunk prefilled.
if max_tokens_per_mm_item > max_num_batched_tokens:
num_items = 1

# In case that the biggest possible multimodal item takes space less
# the batch size, then all items will be full prefilled except one.
else:
num_items = cdiv(max_num_batched_tokens, max_tokens_per_mm_item)

Collaborator:

The comment seems a bit confusing to me. I tried to rephrase it based on my understanding, but please help clarify:

num_items == 1:

# The biggest possible multimodal item cannot be prefilled in a batch,
# so it must be cached and chunked prefill.

num_items > 1:

# A batch can cover all (except the last one) multimodal items.

Meanwhile, I don't fully understand what you meant by "cached" and "chunked prefill", though. I suppose they are orthogonal to the number of items?

Member Author (@ywang96), Jan 9, 2025:

I will clarify this. During profiling we always take the worst case (i.e., requests will all have the biggest possible multimodal item), so what I meant by "cached" and "chunked prefill" is that each multimodal item will always be needed in two engine steps, since the batch cannot cover the entirety of it.

Collaborator:

That makes sense. Thanks!

Member Author (@ywang96):

Clarified via 2a4b1d5

# NOTE: We need the encoder cache to be able to compute & hold ONE
# ADDITIONAL multimodal item, and is required only when:
# - Two requests in the current batch share the same prefix with such item
# as part of the prefix.
# - AND the prefix length is divisible by the block size, triggering the
# recomputation of the last block.
# - AND the part of the embeddings of the item is in this last block.

# This can be improved when we have a global encoder cache that does
# not associate items to request id only.
num_items += 1

Member:

I think this is only applicable to the else block?

Member Author (@ywang96), Jan 9, 2025:

This is applicable to all cases; in fact, the if block is how I discovered this issue, which wasn't addressed prior to this PR.

Here's a concrete example:
Suppose max_num_batched_tokens=8192, two identical requests have length 16032 after processing, and their image has start_index=7333 and end_index=16020 (thus length=8687). Also suppose encoder_cache_budget=8687, for the sake of showing how the issue happens when we don't add budget for one additional item.

Time 0: Request 0 gets scheduled for 8192 tokens. Since start_index=7333 < 8192 < end_index=16020 and the cache is empty, image 0 gets processed and the resulting embeddings are cached, so the entire space budget is used up.

Time 1:

  • Request 0 gets scheduled for the remaining 16032 - 8192 = 7840 tokens. An important note here is that scheduling is synchronous, so we treat these tokens as already computed once scheduled.
  • The issue happens when we try to schedule Request 1, since there is still space in the batch. Because the requests are identical, the number of computed tokens for Request 1 is 16032 from the get-go, which triggers a recompute of the last 16 tokens. However, note that the image ends at 16020 > 16016, so image 1 is needed here, but the space budget is already used up since image 0 is still in the cache.
  • This then triggers the check here
    if num_encoder_tokens > encoder_budget:
    # The encoder budget is exhausted. We can only schedule the
    # decoder tokens up until the encoder input.
    # NOTE(woosuk): We assume that the encoder tokens should be
    # processed altogether, as the encoder usually uses
    # bidirectional attention.
    num_new_tokens = start_pos - num_computed_tokens
    break
    and sets num_new_tokens to 7333 (start_pos) - 16016 (num_computed_tokens) = -8683, which then crashes the server, since we cannot have a non-positive num_new_tokens.
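
For reference, the arithmetic of this failure case can be reproduced with a short standalone sketch (numbers taken from the example above, not vLLM code):

    # Reproduce the failure arithmetic from the example above (hypothetical values).
    max_num_batched_tokens = 8192
    request_len = 16032
    image_start, image_end = 7333, 16020            # the image spans 8687 tokens
    encoder_cache_budget = image_end - image_start  # 8687: room for exactly one image

    # Time 0: Request 0 is scheduled for 8192 tokens; image 0 is encoded and
    # cached, consuming the whole budget.
    # Time 1: Request 1 shares the full prefix, so its computed-token count is
    # 16032 and the last block of 16 tokens must be recomputed.
    num_computed_tokens = request_len - 16          # 16016
    # Image 1 is needed (16016 < image_end), but the budget is exhausted, so the
    # scheduler schedules decoder tokens only up to the encoder input:
    num_new_tokens = image_start - num_computed_tokens
    print(num_new_tokens)                           # -8683 -> non-positive, crashes the server
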

Collaborator:

Both cases would need this.

Also, regarding this comment:

    # This can be improved when we have a global encoder cache that does
    # not associate items to request id only.

This cannot address the issue fundamentally, because we also need to guarantee the item is always available in the encoder cache when we schedule the request. For example, consider an item used by both request A and request B. Request A has finished, so its prefix and mm items are cached. However, due to the encoder cache budget, one item from request A is evicted before request B arrives. This would result in the same problem.

I guess this can somehow be avoided if we could guarantee that all prefix-cached mm items are always available in the encoder cache as well, but fundamentally this has to be solved by supporting num_tokens=0 in the model runner.

Member Author (@ywang96):

but fundamentally this has to be solved by supporting num_tokens=0 in the model runner.

That's a good callout! I've adjusted the comment accordingly.

# Number of items needed cannot be bigger than max number of running
# requests * max number of multimodal items per request.
max_mm_items_per_req = max(
MULTIMODAL_REGISTRY.get_mm_limits_per_prompt(model_config).values())

num_items = min(num_items, max_num_reqs * max_mm_items_per_req)
encoder_cache_budget = num_items * max_tokens_per_mm_item

logger.info(
"Encoder cache will be initialized with a budget of %s tokens,"
" and profiled with %s %s items of the maximum feature size.",
encoder_cache_budget, num_items, modality)

return encoder_cache_budget
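
To make the sizing concrete, here is a standalone sketch of the computation above with hypothetical numbers (not vLLM code; cdiv is ordinary ceiling division):

    # Standalone sketch of compute_encoder_cache_budget with hypothetical numbers.
    from math import ceil

    max_tokens_per_mm_item = 8687    # largest possible single multimodal item
    max_num_batched_tokens = 8192    # per-step token budget of the scheduler
    max_num_reqs = 64                # max number of running requests
    max_mm_items_per_req = 1         # from limit_mm_per_prompt

    if max_tokens_per_mm_item > max_num_batched_tokens:
        # The item cannot fit in one batch: it is cached and chunk-prefilled.
        num_items = 1
    else:
        # A batch covers all items except possibly the last one.
        num_items = ceil(max_num_batched_tokens / max_tokens_per_mm_item)

    # Headroom for one additional item (see the review discussion above).
    num_items += 1
    num_items = min(num_items, max_num_reqs * max_mm_items_per_req)

    encoder_cache_budget = num_items * max_tokens_per_mm_item
    print(encoder_cache_budget)      # 17374 tokens for these numbers
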
24 changes: 16 additions & 8 deletions vllm/v1/core/scheduler.py
@@ -3,10 +3,11 @@
from typing import (TYPE_CHECKING, Deque, Dict, Iterable, List, Optional, Set,
Tuple, Union)

from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
from vllm.config import CacheConfig, LoRAConfig, ModelConfig, SchedulerConfig
from vllm.logger import init_logger
from vllm.sampling_params import SamplingParams
from vllm.v1.core.encoder_cache_manager import EncoderCacheManager
from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager,
compute_encoder_cache_budget)
from vllm.v1.core.kv_cache_manager import KVCacheManager
from vllm.v1.engine import EngineCoreOutput
from vllm.v1.outputs import ModelRunnerOutput
@@ -24,6 +25,7 @@ class Scheduler:
def __init__(
self,
scheduler_config: SchedulerConfig,
model_config: ModelConfig,
cache_config: CacheConfig,
lora_config: Optional[LoRAConfig],
) -> None:
@@ -68,16 +70,22 @@ def __init__(
self.running_reqs_data: Dict[str, RunningRequestData] = {}

# Encoder-related.
# Calculate encoder cache size if applicable
# NOTE: For now we use the same budget for both compute and space.
# This can be changed when we make encoder cache for embedding caching
# across requests.
encoder_cache_budget = compute_encoder_cache_budget(
model_config, scheduler_config)

# NOTE(woosuk): Here, "encoder" includes the vision encoder (and
# projector if needed). Currently, we assume that the encoder also
# has the Transformer architecture (e.g., ViT).
self.max_num_encoder_input_tokens = self.scheduler_config.max_num_encoder_input_tokens #noqa: E501
# NOTE(woosuk): For the models without encoder (e.g., text-only models),
# the encoder cache will not be initialized and used, regardless of
# the cache size. This is because the memory space for the encoder cache
# is preallocated in the profiling run.
self.max_num_encoder_input_tokens = encoder_cache_budget
# NOTE: For the models without encoder (e.g., text-only models),
# the encoder cache will not be initialized because cache size is 0
# for these models.
self.encoder_cache_manager = EncoderCacheManager(
cache_size=self.scheduler_config.encoder_cache_size)
cache_size=encoder_cache_budget)

def schedule(self) -> "SchedulerOutput":
# NOTE(woosuk) on the scheduling algorithm:
9 changes: 6 additions & 3 deletions vllm/v1/engine/core.py
@@ -58,9 +58,12 @@ def __init__(
vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks

# Setup scheduler.
self.scheduler = Scheduler(vllm_config.scheduler_config,
vllm_config.cache_config,
vllm_config.lora_config)
self.scheduler = Scheduler(
scheduler_config=vllm_config.scheduler_config,
model_config=vllm_config.model_config,
cache_config=vllm_config.cache_config,
lora_config=vllm_config.lora_config,
)

self._last_logging_time = time.time()

55 changes: 15 additions & 40 deletions vllm/v1/worker/gpu_model_runner.py
@@ -19,6 +19,7 @@
LayerBlockType, cdiv, is_pin_memory_available)
from vllm.v1.attention.backends.flash_attn import (FlashAttentionBackend,
FlashAttentionMetadata)
from vllm.v1.core.encoder_cache_manager import compute_encoder_cache_budget
from vllm.v1.engine.mm_input_mapper import MMInputMapperClient
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.sample.metadata import SamplingMetadata
@@ -87,8 +88,8 @@ def __init__(
self.mm_input_mapper_profiling = MMInputMapperClient(self.model_config)
self.mm_input_mapper_profiling.use_cache = False

self.max_num_encoder_input_tokens = self.scheduler_config.max_num_encoder_input_tokens # noqa: E501
self.encoder_cache_size = self.scheduler_config.encoder_cache_size
self.encoder_cache_budget = compute_encoder_cache_budget(
self.model_config, self.scheduler_config)

# Lazy initialization
# self.model: nn.Module # Set after load_model
@@ -720,53 +721,27 @@ def profile_run(self) -> None:
]

# Profile with multimodal encoder & encoder cache.
if self.is_multimodal_model:

# Create dummy batch of multimodal inputs.
dummy_request_data = self.input_registry.dummy_data_for_profiling(
model_config=self.model_config,
seq_len=self.max_num_tokens,
mm_registry=self.mm_registry,
)
dummy_mm_data = dummy_request_data.multi_modal_data
# TODO: handle encoder-decoder models once we support them.
if self.is_multimodal_model and self.encoder_cache_budget > 0:

# NOTE: Currently model is profiled with a single non-text
# modality with the max possible input tokens even when
# it supports multiple.
max_tokens_by_modality_dict = self.mm_registry.get_max_tokens_per_item_by_modality( # noqa: E501
max_tokens_by_modality_dict = self.mm_registry.get_max_tokens_per_item_by_nonzero_modality( # noqa: E501
self.model_config)

dummy_data_modality, max_tokens_per_mm_item = max(
max_tokens_by_modality_dict.items(), key=lambda item: item[1])

# Check how many items of this modality can be supported by
# the encoder cache budget.
encoder_cache_budget = min(self.max_num_encoder_input_tokens,
self.encoder_cache_size)
max_num_mm_items_encoder_budget = encoder_cache_budget // \
max_tokens_per_mm_item

# TODO: Allow users to set encoder_cache_budget in case this
# happens.
assert max_num_mm_items_encoder_budget > 0, (
f"Encoder cache budget={encoder_cache_budget} is too small to "
f"support the maximum possible size of multimodal embeddings"
f"={max_tokens_per_mm_item}.")

# Check how many items of this modality can be supported by
# the decoder budget.
max_mm_items_per_req = max(
self.mm_registry.get_mm_limits_per_prompt(
self.model_config).values())

# NOTE: We do not consider max_num_batched_tokens on purpose
# because the multimodal embeddings can be generated in advance
# and chunked prefilled.
max_num_mm_items_decoder_budget = self.max_num_reqs * \
max_mm_items_per_req

max_num_mm_items = min(max_num_mm_items_encoder_budget,
max_num_mm_items_decoder_budget)
max_num_mm_items = self.encoder_cache_budget // max_tokens_per_mm_item # noqa: E501

# Create dummy batch of multimodal inputs.
dummy_request_data = self.input_registry.dummy_data_for_profiling(
model_config=self.model_config,
seq_len=self.max_num_tokens,
mm_registry=self.mm_registry,
)
dummy_mm_data = dummy_request_data.multi_modal_data
Member Author (@ywang96), Jan 9, 2025:

Note this is just a reordering for better readability.


# Dummy data definition in V0 may contain multiple multimodal items
# (e.g, multiple images) for a single request, therefore here we