[Kernel] Remove hard-dependencies of Speculative decode to CUDA workers #10587
Conversation
This pull request has merge conflicts that must be resolved before it can be merged.
👋 Hi! Thank you for contributing to the vLLM project. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.
Force-pushed from 8579790 to 9b49363
@comaniac @youkaichao, based on #10555, I submitted an alternative implementation deriving from WorkerWrapperBase and ModelRunnerWrapperBase. Please check if this makes sense to you.
Looks much better than the previous one. I'll hand it over to @comaniac.
Overall LGTM
```python
if current_platform.is_cuda_alike():
    draft_worker_kwargs[
        "model_runner_cls"] = TP1DraftModelRunner
```
So in non-CUDA platforms this key will be missing. Is this expected?
@comaniac, yes. TP1DraftModelRunner is a wrapper on ModelRunner; it adds extra processing of previous_hidden_states and other inputs before calling the actual model_runner.execute_model(). For the CPU enabling we don't need any extra work on previous_hidden_states to make it work, so a CPUTP1DraftModelRunner is not necessary. However, to improve spec decode performance on other hardware, adding an HWTP1DraftModelRunner might still be useful.
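To make the comment above concrete, here is an illustrative sketch (not vLLM's actual code) of what such a platform-specific draft model runner could look like: a thin wrapper that massages previous_hidden_states before delegating to the wrapped runner. The class and method names (`HWTP1DraftModelRunner`, `BaseModelRunner`, the dict-based inputs) are hypothetical simplifications.

```python
class BaseModelRunner:
    """Stand-in for a platform's plain model runner."""

    def execute_model(self, model_input: dict) -> dict:
        # Placeholder for real model execution.
        return {"logits": model_input}


class HWTP1DraftModelRunner:
    """Hypothetical wrapper adding extra input processing for spec decode,
    analogous to what TP1DraftModelRunner does on CUDA-alike platforms."""

    def __init__(self, base_runner: BaseModelRunner):
        self.base_runner = base_runner

    def _prepare_hidden_states(self, model_input: dict) -> dict:
        # Example of hardware-specific preprocessing: drop empty entries
        # from previous_hidden_states before the draft model consumes them.
        if "previous_hidden_states" in model_input:
            model_input = dict(model_input)
            model_input["previous_hidden_states"] = [
                h for h in model_input["previous_hidden_states"]
                if h is not None
            ]
        return model_input

    def execute_model(self, model_input: dict) -> dict:
        # Delegate to the wrapped runner after preprocessing.
        return self.base_runner.execute_model(
            self._prepare_hidden_states(model_input))
```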
@comaniac, I have verified this PR on GPU and CPU in my local env.
vllm/platforms/cpu.py (outdated)
```python
if vllm_config.speculative_config:
    parallel_config.worker_cls = \
        "vllm.spec_decode.spec_decode_worker.create_spec_worker"
    parallel_config.actual_worker_cls = \
```
Can we rename the field to `sd_worker_cls`?
Maybe add a field `sd_worker_cls` to `ParallelConfig`, defaulting to `None`.
Will do
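The reviewer's suggestion could look roughly like the sketch below. Only the field name `sd_worker_cls` comes from the thread; the surrounding dataclass is an illustrative simplification, not vLLM's actual ParallelConfig.

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class ParallelConfig:
    # Existing field (simplified): the worker class used by the executor.
    worker_cls: str = "auto"
    # Proposed field: the worker class that speculative decoding uses to
    # create the actual platform worker. None means spec decode is not
    # in use, so non-spec-decode setups are unaffected.
    sd_worker_cls: Optional[str] = None
```

A None default keeps the change backward-compatible: platforms that never touch speculative decoding need no edits.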
LGTM now, thanks for the hard work!
This pull request has merge conflicts that must be resolved before it can be merged.
Force-pushed from 51e26f2 to 933e5e0
Force-pushed from 933e5e0 to f58440a
```python
if current_platform.is_cuda_alike():
    from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
```
Do we have to put this check at module import time? It would be better if this were only lazily imported within `sampler_output`.
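The lazy-import pattern the reviewer suggests looks roughly like the sketch below: defer the platform-specific import to the call site that needs it, so merely importing the module never fails on other platforms. The module name `cuda_only_module` and the helper `_is_cuda_alike` are placeholders, not vLLM's real layout.

```python
def _is_cuda_alike() -> bool:
    # Stand-in for current_platform.is_cuda_alike(); hardcoded here so the
    # sketch is runnable anywhere.
    return False

def sampler_output(model_runner_cls=None):
    if _is_cuda_alike():
        # The CUDA-only import now happens only on CUDA-alike platforms,
        # instead of at module import time.
        from cuda_only_module import TP1DraftModelRunner  # hypothetical
        model_runner_cls = TP1DraftModelRunner
    # ... rest of the method would use model_runner_cls as before ...
    return model_runner_cls
```

On a non-CUDA platform the guarded import is simply never executed, which is the whole point of moving it inside the method.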
```python
if current_platform.is_cuda_alike():
    from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
```
Ditto, but within `create_worker`.
@xuechendi please address @mgoin's comments, maybe in a follow-up PR.
Got it, will do.
Signed-off-by: Chendi Xue <[email protected]>
…project#10587) Signed-off-by: Chendi Xue <[email protected]> Signed-off-by: Andrew Feldman <[email protected]>
This PR mainly targets removing the hard dependency on CUDA in speculative decoding.
This proposal was initiated in #10131.
Original PR: use current_platform to provide a dynamic WorkerCls as the base class for spec_decode workers.
This PR provides a second solution: use WorkerWrapperBase and ModelRunnerWrapperBase to lazily create spec decode workers.
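The lazy-creation idea behind the wrapper approach can be sketched as follows: the wrapper stores the worker class as a dotted string and only resolves and instantiates it when initialization is requested, so platform-specific classes are never imported on other platforms. This is a simplified illustration, not vLLM's actual WorkerWrapperBase.

```python
import importlib

class WorkerWrapperBase:
    """Simplified sketch: hold a worker class path, instantiate it lazily."""

    def __init__(self, worker_cls: str):
        # e.g. "vllm.worker.worker.Worker"; stored as a string so nothing
        # platform-specific is imported yet.
        self.worker_cls = worker_cls
        self.worker = None

    def init_worker(self, *args, **kwargs):
        # Resolve the dotted path and construct the real worker only now.
        module_name, cls_name = self.worker_cls.rsplit(".", 1)
        module = importlib.import_module(module_name)
        self.worker = getattr(module, cls_name)(*args, **kwargs)

    def __getattr__(self, name):
        # Delegate any other attribute access to the real worker.
        return getattr(self.worker, name)
```

Because the import happens inside `init_worker`, a CPU (or other non-CUDA) platform can hand the wrapper a CPU worker path and the CUDA modules are never touched.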
Test scripts: