Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Bugfix] Fix for Spec model TP + Chunked Prefill #10232

Merged
merged 19 commits into from
Nov 26, 2024

Conversation

andoorve
Copy link
Collaborator

@andoorve andoorve commented Nov 11, 2024

Fixes the issue I raised here: #9291. Chunked prefill + spec decoding + TP on the spec model fails for me with KeyError: 'num_seq_groups' when I used the following command.

vllm serve meta-llama/Llama-3.1-405B-Instruct-FP8 --tensor-parallel-size 8 --max-num-seqs 32  --block-size 32  --speculative-model meta-llama/Llama-3.1-8B-Instruct  --num-speculative-tokens 8 --gpu-memory-utilization  0.98 --use-v2-block-manager --distributed-executor-backend ray --enable-chunked-prefill --max-num-batched-tokens 4096 --max-model-len 32768

This fix makes it so the proposer only runs once on the non driver processes when no_spec is on to match the driver.

One thing that is still confusing is I would expect this issue to show up without chunked prefill as well. Unsure why it doesn't show up in that case. Would be good to get an opinion from someone more familiar with spec decode path.

FIX #10276

Copy link

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add ready label to the PR
  • Enable auto-merge.

🚀

@andoorve
Copy link
Collaborator Author

@NickLucche @sroy745

@andoorve andoorve force-pushed the andoorve/spec-fix-chunked branch from f1ff8aa to 6863d1f Compare November 11, 2024 20:33
@sroy745
Copy link
Collaborator

sroy745 commented Nov 12, 2024

Hi,
Thanks for the fix.

Based on our DM discussions my understanding is that the main issue seems to be that even when all the sequences are prompts (only prefill) we have num_lookahead_slots as > 0. I added some logs in this pr (https://github.com/vllm-project/vllm/pull/10186/files) and the output if I run with and without chunked-prefill enabled is the following

Without chunked prefill

num_lookahead_slots in _schedule_default 0
prefills in _schedule_default_prefill 1
decodes in _schedule_default_prefill 0

With chunked prefill

num_lookahead_slots in _schedule_chunked_prefill 4
prefills in _schedule_chunked_prefill 1
decodes in _schedule_chunked_prefill 0

In without chunked-prefill run if it is a complete prefill batch num_lookahead_slots is set to 0 but it is not the case for the chunked-prefill run. I wonder if we should fix __schedule_chunked_prefill to set num_lookahead_slots to 0 if it is a complete prefill batch and add an assertion in spec_decode_worker for that?

@NickLucche
Copy link
Contributor

I wonder if we should fix __schedule_chunked_prefill to set num_lookahead_slots to 0 if it is a complete prefill batch and add an assertion in spec_decode_worker for that

I like that, I think this would be more in line with the expected semantics (no speculation on prefills-only).

Thanks for looking into it!!

@andoorve andoorve force-pushed the andoorve/spec-fix-chunked branch 2 times, most recently from d5f6392 to 10f69a4 Compare November 12, 2024 19:16
@andoorve
Copy link
Collaborator Author

As discussed over DM, moving this up to the scheduler level is a cleaner fix, moved the check there. @NickLucche @sroy745 PTAL if this logic looks good, then I'll mark this ready!

@andoorve andoorve self-assigned this Nov 13, 2024
Copy link
Collaborator

@sroy745 sroy745 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the PR! Added a couple of comments about tests. Logic LGTM

Thanks

vllm/core/scheduler.py Outdated Show resolved Hide resolved
vllm/core/scheduler.py Show resolved Hide resolved
vllm/spec_decode/spec_decode_worker.py Outdated Show resolved Hide resolved
num_batched_tokens=budget.num_batched_tokens,
blocks_to_swap_in=swapped_in.blocks_to_swap_in,
blocks_to_swap_out=running_scheduled.blocks_to_swap_out,
blocks_to_copy=running_scheduled.blocks_to_copy +
swapped_in.blocks_to_copy,
ignored_seq_groups=prefills.ignored_seq_groups +
swapped_in.infeasible_seq_groups,
num_lookahead_slots=running_scheduled.num_lookahead_slots,
num_lookahead_slots=num_lookahead_slots,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@varun-sundar-rabindranath could you also review this part to see if this will break multi-step scheduling with chunked prefill?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the Tag. I believe it will affect performance.
multi-step + chunked-prefill allows for having look-ahead slots even when all the sequences are prefills. The sequences are processed as prefills in step 1 and are processed as decodes in steps 2 - n.
Setting the lookahead_slots to 0, will force single stepping for the all-prefills case. I can get some profiles.

@andoorve is there a way to make this update only if spec decode is enabled ? I believe that would be safer.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @varun-sundar-rabindranath I think that should be possible, thanks for the feedback! Let me see how we can do that

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@varun-sundar-rabindranath @comaniac Can you check whether this condition makes sense?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey @andoorve - The condition looks good 👍

@mergify mergify bot added the documentation Improvements or additions to documentation label Nov 13, 2024
@andoorve andoorve added the bug Something isn't working label Nov 15, 2024
@andoorve andoorve force-pushed the andoorve/spec-fix-chunked branch from 5893379 to 0b300d2 Compare November 15, 2024 03:06
@andoorve andoorve marked this pull request as ready for review November 19, 2024 19:28
@andoorve
Copy link
Collaborator Author

Waiting on reviews from @sroy745 @varun-sundar-rabindranath @comaniac

@@ -653,6 +659,9 @@ def _run_non_driver_rank(self) -> bool:

if not data["no_spec"]:
self.scorer_worker.execute_model()
data = broadcast_tensor_dict(src=self._driver_rank)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Extra broadcast is not ideal. But since this is a bugfix and should be low impact perf wise may have to live with it.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi as discussed offline I will try to run a benchmark with and without this change to try and measure the impact of this if any.

Copy link
Collaborator Author

@andoorve andoorve Nov 20, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried it with Llama 405B + 8B spec model @ 32k sequence length w/ speculative_tokens = 8 and TP 8 on H100.

The results below are for "prompt_tokens":30,"total_tokens":1054,"completion_tokens":1024 averaged over 3 runs.
Before Change: 9.321 s
After Change: 9.268 s

Speedup: 99.4%

Therefore, it slows down the regular speculative path very slightly.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cc: @njhill - Can we consider this acceptable as it is a necessary bugfix?

Copy link
Collaborator

@sroy745 sroy745 Nov 21, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for running the comparison. Following up on our conversation yesterday I was wondering if we can piggyback on the existing broadcast that we do. For the cp + sd case we need to broadcast the additional information that there are prefills in the speculative batch and if so run the proposer worker. For that I was wondering if we can check for prompts in the input batch and based on that we set the 'run_spec_proposer' in the initial broadcast. Looking through the proposer code I think its doing something similar when deciding which sequences to use for decoding vs prefill (https://sourcegraph.com/github.com/vllm-project/vllm/-/blob/vllm/spec_decode/top1_proposer.py?L131) .

@NickLucche can we do the is_prompt check and use that to set the 'run_spec_proposer' field in the initial broadcast itself?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know if it's sufficient to check for only prompts specifically. It doesn't go down that path when we try a single request scenario.

Copy link
Collaborator

@sroy745 sroy745 Nov 21, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for taking a look. Can you please share some more details on what the issue is for the single request case?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well, we only needed to add this extra broadcast after trying multiple requests. When we were simply sending single requests, this broadcast never came into play at all. So the condition is not simply "all prompts".

@@ -413,6 +413,39 @@ def cannot_append_second_group2(seq_group, num_lookahead_slots):
assert out.num_batched_tokens == max_num_batched_tokens


def test_chunked_prefill_spec_prefill():
"""Verify preempt works with chunked prefill requests"""
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit - please update the comment to "Verify that the num_lookahead_slots is set appropriately for an all prefill batch depending on whether multi-step scheduling is enabled or not"?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed

assert out.num_prefill_groups == 1
assert seq_group.is_prefill()
assert out.num_batched_tokens == max_num_batched_tokens
assert out.num_lookahead_slots == 0
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

wondering if we can parameterize this test to run for both multi_step = True/False?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added

per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size: int, seed: int):
"""Verify spec decode works well with smaller tp for draft models.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit - Verify spec decode works well with draft models for tp > 1.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed this, thanks for the catch

@@ -653,6 +659,9 @@ def _run_non_driver_rank(self) -> bool:

if not data["no_spec"]:
self.scorer_worker.execute_model()
data = broadcast_tensor_dict(src=self._driver_rank)
Copy link
Collaborator

@sroy745 sroy745 Nov 21, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for running the comparison. Following up on our conversation yesterday I was wondering if we can piggyback on the existing broadcast that we do. For the cp + sd case we need to broadcast the additional information that there are prefills in the speculative batch and if so run the proposer worker. For that I was wondering if we can check for prompts in the input batch and based on that we set the 'run_spec_proposer' in the initial broadcast. Looking through the proposer code I think its doing something similar when deciding which sequences to use for decoding vs prefill (https://sourcegraph.com/github.com/vllm-project/vllm/-/blob/vllm/spec_decode/top1_proposer.py?L131) .

@NickLucche can we do the is_prompt check and use that to set the 'run_spec_proposer' field in the initial broadcast itself?

# the other for decodes. The variable indicates to the non-driver
# worker that there are prefills as part of the speculative batch
# and hence it needs to run an extra prefill forward pass.
run_spec_proposer_for_prefill=atleast_one_prompt,
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be great if we got a sanity check from @NickLucche or someone!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

on it, sorry for the late ack

@sroy745 sroy745 added the ready ONLY add when PR is ready to merge/full CI is needed label Nov 22, 2024
Copy link
Member

@njhill njhill left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the great work @andoorve @sroy745 @NickLucche!

@andoorve
Copy link
Collaborator Author

Rebased, waiting for tests to pass to push

@andoorve
Copy link
Collaborator Author

I think there's a real error here:


[2024-11-25T20:26:10Z]         if all_prompt:
--
  | [2024-11-25T20:26:10Z] >           assert num_lookahead_slots == 0, (
  | [2024-11-25T20:26:10Z]                 "Prompt only runs should have num_lookahead_slots equal to 0. "
  | [2024-11-25T20:26:10Z]                 "This should never happen, please file a bug at "
  | [2024-11-25T20:26:10Z]                 "https://github.com/vllm-project/vllm/issues")
  | [2024-11-25T20:26:10Z] E           AssertionError: Prompt only runs should have num_lookahead_slots equal to 0. This should never happen, please file a bug at https://github.com/vllm-project/vllm/issues

FAILED spec_decode/test_spec_decode_worker.py::test_empty_input_batch[typical_acceptance_sampler-0-5]

@sroy745 maybe this is why the None part was necessary with all_prompt?

andoorve and others added 19 commits November 26, 2024 05:52
This reverts commit 6863d1f.

Signed-off-by: andoorve <[email protected]>
Signed-off-by: andoorve <[email protected]>
Signed-off-by: andoorve <[email protected]>
Signed-off-by: andoorve <[email protected]>
Signed-off-by: andoorve <[email protected]>
Signed-off-by: andoorve <[email protected]>
Signed-off-by: andoorve <[email protected]>
Signed-off-by: andoorve <[email protected]>
Signed-off-by: andoorve <[email protected]>
Signed-off-by: Sourashis Roy <[email protected]>
Signed-off-by: andoorve <[email protected]>
Signed-off-by: andoorve <[email protected]>
@andoorve andoorve force-pushed the andoorve/spec-fix-chunked branch from e59bb79 to 01b43aa Compare November 26, 2024 05:52
@NickLucche
Copy link
Contributor

Hey @andoorve thanks for the quick fix on the last hiccup!

I did take a look at test_empty_input_batch but I couldn't find any real use-case where we're sending empty batches to signal or set something.
Were you able to find some? Could be useful for reference to write it here imho.

PS we're still missing the DCO check to get the green lights

@andoorve
Copy link
Collaborator Author

@NickLucche No I didn't find any - I just included that quick fix based on what @sroy745 did previously.

DCO is on one of @sroy745's commits but he wasn't able to get it to work. I think when we squash and merge though it should be good to go and signed off by both.

@njhill Would you mind merging when you get a chance?

@njhill njhill merged commit db66e01 into vllm-project:main Nov 26, 2024
51 checks passed
@andoorve andoorve deleted the andoorve/spec-fix-chunked branch November 26, 2024 17:28
afeldman-nm pushed a commit to neuralmagic/vllm that referenced this pull request Nov 26, 2024
Signed-off-by: andoorve <[email protected]>
Signed-off-by: Sourashis Roy <[email protected]>
Co-authored-by: Sourashis Roy <[email protected]>
afeldman-nm pushed a commit to neuralmagic/vllm that referenced this pull request Dec 2, 2024
Signed-off-by: andoorve <[email protected]>
Signed-off-by: Sourashis Roy <[email protected]>
Co-authored-by: Sourashis Roy <[email protected]>
Signed-off-by: Andrew Feldman <[email protected]>
sleepwalker2017 pushed a commit to sleepwalker2017/vllm that referenced this pull request Dec 13, 2024
Signed-off-by: andoorve <[email protected]>
Signed-off-by: Sourashis Roy <[email protected]>
Co-authored-by: Sourashis Roy <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
bug Something isn't working documentation Improvements or additions to documentation ready ONLY add when PR is ready to merge/full CI is needed
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[Bug]: Speculative Decoding + TP on Spec Worker + Chunked Prefill does not work.
6 participants