Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Kernel] Remove hard-dependencies of Speculative decode to CUDA workers #10587

Merged
merged 5 commits into from
Nov 27, 2024

Conversation

xuechendi
Copy link
Contributor

@xuechendi xuechendi commented Nov 23, 2024

This PR is mainly target to remove hard dependency for CUDA in speculative decoding

Initiated this proposal in #10131

Original PR => use current_platform to provide Dynamic WorkerCls as BaseClass for spec_decode workers
This PR provides second solution => Use WorkerWrapperBase and ModerRunnerWrapperBase to lazily create spec decoder workers.

Test scripts:

pytest -v tests/spec_decode/e2e/test_mlp_correctness.py::test_mlp_e2e_greedy_correctness 
pytest -v tests/spec_decode/e2e/test_medusa_correctness.py::test_medusa_e2e_greedy_correctness
pytest -v tests/spec_decode/e2e/test_ngram_correctness.py::test_ngram_e2e_greedy_correctness
pytest -v tests/spec_decode/e2e/test_eagle_correctness.py::test_eagle_e2e_greedy_correctness

Copy link

mergify bot commented Nov 23, 2024

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @xuechendi.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Nov 23, 2024
Copy link

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add ready label to the PR
  • Enable auto-merge.

🚀

@xuechendi
Copy link
Contributor Author

@comaniac @youkaichao , based on #10555 , I submitted an alternative impl to derive from WorkerWrapperBase and ModelRunnerWrapperBase.

Please check if this makes sense to you.

@youkaichao
Copy link
Member

Looks much better than previous one.

I'll hand it over to @comaniac

Copy link
Collaborator

@comaniac comaniac left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall LGTM

Comment on lines +173 to +175
if current_platform.is_cuda_alike():
draft_worker_kwargs[
"model_runner_cls"] = TP1DraftModelRunner
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So in non-CUDA platforms this key will be missing. Is this expected?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@comaniac , yes, TP1DraftModelRunner is a wrapper on ModelRunner, this wrapper will add extra data processing for
previous_hidden_states and other inputs before calling actual model_runner.execute_model()

In CPU enabling, we don't need extra work on previous_hidden_states to make it work, so CPUTP1DraftModelRunner is not necessary. However, in order to improve Spec Decode performance on other HWs, adding HWTP1DraftModelRunner might still be useful.

@xuechendi
Copy link
Contributor Author

@comaniac , I have verified this PR on GPU and CPU in my local env.

@comaniac comaniac added the ready ONLY add when PR is ready to merge/full CI is needed label Nov 25, 2024
if vllm_config.speculative_config:
parallel_config.worker_cls = \
"vllm.spec_decode.spec_decode_worker.create_spec_worker"
parallel_config.actual_worker_cls = \
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we rename the field to sd_worker_cls ?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe add a field sd_worker_cls to ParallelConfig , and default to None.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will do

Copy link
Member

@youkaichao youkaichao left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM now, thanks for the hard work!

Copy link

mergify bot commented Nov 26, 2024

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @xuechendi.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

Comment on lines +12 to +13
if current_platform.is_cuda_alike():
from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we have to put this check in the module import? It would be better if this was only lazy imported within sampler_output

Comment on lines +24 to +25
if current_platform.is_cuda_alike():
from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ditto but within create_worker

@youkaichao youkaichao merged commit 0a71900 into vllm-project:main Nov 27, 2024
49 of 51 checks passed
@youkaichao
Copy link
Member

@xuechendi please address @mgoin 's comments, maybe in a followup PR.

@xuechendi
Copy link
Contributor Author

@xuechendi please address @mgoin 's comments, maybe in a followup PR.

Got it, will do

afeldman-nm pushed a commit to neuralmagic/vllm that referenced this pull request Dec 2, 2024
sleepwalker2017 pushed a commit to sleepwalker2017/vllm that referenced this pull request Dec 13, 2024
@xuechendi xuechendi deleted the spec_decode_generalize branch December 19, 2024 21:49
BKitor pushed a commit to BKitor/vllm that referenced this pull request Dec 30, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ready ONLY add when PR is ready to merge/full CI is needed
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants