[V1][Core] Autotune encoder cache budget #11895

Open · wants to merge 10 commits into main
Conversation

ywang96 (Member) commented Jan 9, 2025

This PR refactors the multimodal encoder cache budget from hardcoded values to values autotuned from the underlying model and scheduler configurations.


github-actions bot commented Jan 9, 2025

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small but essential subset of tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add ready label to the PR
  • Enable auto-merge.

🚀

vllm/v1/core/encoder_cache_manager.py (resolved)
Comment on lines 90 to 100
# NOTE: We need the encoder cache to be able to compute & hold ONE
# ADDITIONAL multimodal item, which is required only when:
# - Two requests in the current batch share the same prefix with such item
# as part of the prefix.
# - AND the prefix length is divisible by the block size, triggering the
# recomputation of the last block.
# - AND part of the item's embeddings falls in this last block.

# This can be improved when we have a global encoder cache that does
# not associate items to request id only.
num_items += 1
Member
I think this is only applicable to the else block?

ywang96 (Member, Author) commented Jan 9, 2025

This is applicable to all cases; in fact, it was in the if block that I discovered this issue, which wasn't addressed prior to this PR.

Here's a concrete example:
Suppose max_num_batched_tokens=8192, two identical requests each have length 16032 after processing, and their image spans start_index=7333 to end_index=16020 (thus length=8687). Also suppose encoder_cache_budget=8687, for the sake of showing how the issue happens when we don't add budget for one additional item.

Time 0: Request 0 gets scheduled for 8192 tokens. Since start_index=7333 < 8192 < end_index=16020 and the cache is empty, image 0 gets processed and the resulting embeddings are cached, so the entire space budget is used up.

Time 1:

  • Request 0 gets scheduled for the remaining 16032 - 8192 = 7840 tokens. An important note here is that scheduling is synchronous, so we treat these tokens as already computed once scheduled.
  • The issue happens when we try to schedule Request 1, since there is still space in the batch. Because the requests are identical, the number of computed tokens for Request 1 is 16032 from the get-go, which triggers a recompute of the last block of 16 tokens, leaving num_computed_tokens=16016. However, note that the image ends at 16020 > 16016, so image 1 is needed here, but the space budget is already used up since image 0 is still in the cache.
  • This then triggers the check here:
    if num_encoder_tokens > encoder_budget:
        # The encoder budget is exhausted. We can only schedule the
        # decoder tokens up until the encoder input.
        # NOTE(woosuk): We assume that the encoder tokens should be
        # processed altogether, as the encoder usually uses
        # bidirectional attention.
        num_new_tokens = start_pos - num_computed_tokens
        break
    which sets num_new_tokens to 7333 (start_pos) - 16016 (num_computed_tokens) = -8683 and then crashes the server, since we cannot have a non-positive num_new_tokens. (The sketch below lays out this arithmetic end to end.)
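
To spell out the arithmetic, here is a minimal standalone sketch that just re-derives the numbers from the example above (the names only mirror the scheduler's variables; this is illustrative, not the actual vLLM code):

    # Numbers taken from the example above.
    max_num_batched_tokens = 8192
    request_len = 16032                 # both requests are identical
    start_pos, end_pos = 7333, 16020    # image span; 16020 - 7333 = 8687 tokens
    block_size = 16
    encoder_cache_budget = 8687         # exactly one image fits

    # Time 0: Request 0 is scheduled for a full batch. The image straddles the
    # batch boundary (7333 < 8192 < 16020), so it is computed and cached,
    # consuming the entire encoder budget.

    # Time 1: Request 1 hits the full prefix cache, but the last block is
    # recomputed, so:
    num_computed_tokens = request_len - block_size       # 16016
    # The image ends at 16020 > 16016, so image 1 is still needed, yet the
    # budget is exhausted. The fallback path then computes:
    num_new_tokens = start_pos - num_computed_tokens
    print(num_new_tokens)               # 7333 - 16016 = -8683, non-positive -> crash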

Collaborator

Both cases would need this.

Also for this comment

    # This can be improved when we have a global encoder cache that does
    # not associate items to request id only.

This cannot address the issue fundamentally, because we also need to guarantee that the item is always available in the encoder cache when we schedule the request. For example, consider an item used by request A and request B. Request A has finished, so its prefix and mm items are cached. However, due to the encoder cache budget, one item from request A is evicted before request B arrives. This would result in the same problem.

I guess this can somehow be avoided if we could guarantee that all prefix-cached mm items are always available in the encoder cache as well, but fundamentally this has to be solved by supporting num_tokens=0 in the model runner.

ywang96 (Member, Author)

    but fundamentally this has to be solved by supporting num_tokens=0 in the model runner.

That's a good callout! I've adjusted the comment accordingly.

vllm/v1/worker/gpu_model_runner.py (outdated, resolved)
vllm/v1/worker/gpu_model_runner.py (outdated, resolved)
vllm/v1/core/encoder_cache_manager.py (outdated, resolved)
Comment on lines 80 to 88
# In case that the biggest possible multimodal item takes space more
# than the batch size, then it needs to be cached and chunk prefilled.
if max_tokens_per_mm_item > max_num_batched_tokens:
    num_items = 1

# In case that the biggest possible multimodal item takes space less
# the batch size, then all items will be full prefilled except one.
else:
    num_items = cdiv(max_num_batched_tokens, max_tokens_per_mm_item)
Collaborator

The comment seems a bit confusing to me. I tried to rephrase it based on my understanding, but please help clarify:

num_items == 1:

# The biggest possible multimodal item cannot be prefilled in a batch,
# so it must be cached and chunked prefill.

num_items > 1:

# A batch can cover all (except the last one) multimodal items.

Meanwhile, I don't fully understand what you meant by "cached" and "chunked prefill" though. I suppose they are orthogonal to the number of items?

ywang96 (Member, Author) commented Jan 9, 2025

I will clarify this. During profiling we always take the worst case (i.e., all requests have the biggest possible multimodal item), so what I meant by "cached" and "chunked prefill" is that each multimodal item will always be needed across two engine steps, since a single batch cannot cover the entirety of it.
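
As a rough illustration of that worst-case reasoning, here is a small sketch with made-up numbers (cdiv here is simply ceiling division, mirroring the helper used in the snippet above; this is not the PR's code itself):

    def cdiv(a: int, b: int) -> int:
        # Ceiling division, matching the helper used in the snippet above.
        return -(-a // b)

    max_num_batched_tokens = 8192

    def worst_case_num_items(max_tokens_per_mm_item: int) -> int:
        if max_tokens_per_mm_item > max_num_batched_tokens:
            # The item cannot be prefilled in one batch: it is computed and
            # cached in one engine step and finished in the next, so the cache
            # only ever needs to hold that single item.
            return 1
        # Otherwise a batch is, in the worst case, packed with such items, all
        # fully prefilled except possibly the last one.
        return cdiv(max_num_batched_tokens, max_tokens_per_mm_item)

    print(worst_case_num_items(12000))  # 1  (item larger than the batch)
    print(worst_case_num_items(3000))   # 3  (ceil(8192 / 3000))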

Collaborator

That makes sense. Thanks!

ywang96 (Member, Author)

Clarified via 2a4b1d5


# requests * max number of multimodal items per request.
max_mm_items_per_req = max(
    MULTIMODAL_REGISTRY.get_mm_limits_per_prompt(model_config).values())
num_items = min(num_items, max_num_reqs * max_mm_items_per_req)
Collaborator

It may be better to have a warning if num_items < max_num_reqs * max_mm_items_per_req, because it means we are overriding user configurations.

ywang96 (Member, Author) commented Jan 9, 2025

Here we're actually not overriding user configurations, because the user doesn't get to specify the encoder cache budget (nor could they before this PR, since it was hardcoded).

What we are doing here is simply having the encoder budget calculation respect max_num_reqs. (Consider max_num_reqs=1: the encoder cache then only needs to be able to compute & hold items for one request every step.)
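
For instance, a tiny sketch with hypothetical numbers showing the cap taking effect (this mirrors the snippet above but is not the actual code):

    max_num_batched_tokens = 8192
    max_tokens_per_mm_item = 3000
    max_num_reqs = 1
    max_mm_items_per_req = 1            # e.g. one image allowed per prompt

    # Worst-case number of items per step from the batch size alone.
    num_items = -(-max_num_batched_tokens // max_tokens_per_mm_item)   # ceil-div -> 3

    # With a single request allowing a single item, the encoder never has to
    # compute & hold more than one item per step, so the cap shrinks the count.
    num_items = min(num_items, max_num_reqs * max_mm_items_per_req)    # 1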

Collaborator

Oh, I thought MULTIMODAL_REGISTRY.get_mm_limits_per_prompt(model_config) can be configured by users using mm_limit? Is that a different config?

ywang96 (Member, Author) commented Jan 9, 2025

Yes, that's configured by the user, but we're not overwriting this value.

Also keep in mind that, technically speaking, this limit is only still needed today because V0 doesn't support chunked prefill for multimodal models, so the sequence (and thus all multimodal items in it) needs to be prefilled as a whole, and profiling has to be done accordingly.

In V1, chunked prefill is supported natively, so this limit doesn't affect how we schedule requests at all; it only affects how engine profiling is done at this specific check. Technically we don't need it anymore, but we still want to keep the argument so that users have a way to set a cap themselves.

Collaborator

I see. Thanks for the clarification!

Comment on lines 738 to 744
# Create dummy batch of multimodal inputs.
dummy_request_data = self.input_registry.dummy_data_for_profiling(
    model_config=self.model_config,
    seq_len=self.max_num_tokens,
    mm_registry=self.mm_registry,
)
dummy_mm_data = dummy_request_data.multi_modal_data
ywang96 (Member, Author) commented Jan 9, 2025

Note this is just a reordering for better readability.

comaniac (Collaborator) left a comment

LGTM. Should we have a unit test for this feature?

ywang96 (Member, Author) commented Jan 10, 2025

    LGTM. Should we have a unit test for this feature?

Yeah, I'd like @WoosukKwon to take a look at this design before I spend time writing the test for it.
