
[platforms] absorb worker cls difference into platforms folder #10555

Merged
merged 12 commits into vllm-project:main on Nov 22, 2024

Conversation

youkaichao
Member

@youkaichao youkaichao commented Nov 21, 2024

part of #9268

Every platform should specify the worker class inside its own code.

In addition, the default value is "auto", which allows users to specify custom classes for extensibility (something I'm working on as part of RLHF support).
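
As a rough illustration of the pattern (a minimal sketch with simplified names, not the PR's exact code): each platform resolves "auto" to its native worker, and any other value is treated as a user-supplied dotted class path.

    from dataclasses import dataclass
    import importlib


    @dataclass
    class ParallelConfig:  # simplified stand-in for vLLM's config object
        worker_cls: str = "auto"  # "auto" = let the platform decide


    class CudaPlatformSketch:
        @classmethod
        def check_and_update_config(cls, parallel_config: ParallelConfig) -> None:
            # The platform fills in its native worker only when the user
            # left the default sentinel in place.
            if parallel_config.worker_cls == "auto":
                parallel_config.worker_cls = "vllm.worker.worker.Worker"


    def resolve_worker_cls(path: str):
        # Import "pkg.module.Class" and return the class object, so a
        # custom worker can be loaded without vLLM knowing it up front.
        module_name, _, class_name = path.rpartition(".")
        return getattr(importlib.import_module(module_name), class_name)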

Signed-off-by: youkaichao <[email protected]>
Signed-off-by: youkaichao <[email protected]>
Signed-off-by: youkaichao <[email protected]>

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which covers a small, essential subset of CI tests to catch errors quickly. You can run additional CI tests on top of these by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add the ready label to the PR
  • Enable auto-merge.

🚀

@youkaichao youkaichao changed the title [platforms] refactor worker class specification [platforms] absorb worker cls difference into platforms folder Nov 21, 2024
Collaborator

@comaniac comaniac left a comment


LGTM

@comaniac comaniac added the ready label (ONLY add when PR is ready to merge/full CI is needed) Nov 21, 2024
Signed-off-by: youkaichao <[email protected]>
Signed-off-by: youkaichao <[email protected]>
@njhill
Member

njhill commented Nov 22, 2024

Thanks @youkaichao! I am also reviewing this now...

Member

@njhill njhill left a comment


@youkaichao I think this change is probably ok if we're going to continue iterating on the architecture, but it doesn't seem like the right end design to me.

I've already been thinking that we need to overhaul the executor hierarchy/abstractions a bit (hoping we can do this as part of v1), and that may be part of why this doesn't sit right.

In particular, we have one or more executor classes per platform, so the executors are already abstracting over the platforms in a way. But then there's a parallel Platform abstraction. I think we should get rid of the platform-specific executors, i.e. no ray_*pu_executor.pys. Possibly the platform-specific aspects could be a mix-in.

It also feels a bit wrong to me to update the config objects in-place since these might be created/ "owned" by the user.

Also, wdyt about changing this field to custom_worker_cls: Optional[str] = None? Since it's a very specialized option, I would consider it more as overriding vLLM's native behaviour, so it's not so much an "auto" thing.
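
For concreteness, the suggested alternative would read roughly like this (a sketch; the field name comes from the comment above, the surrounding structure is assumed):

    from dataclasses import dataclass
    from typing import Optional


    @dataclass
    class ParallelConfig:  # simplified stand-in, as in the PR description
        # None = use vLLM's native worker for the platform; a dotted
        # class path overrides it. Contrast with worker_cls = "auto".
        custom_worker_cls: Optional[str] = None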

Review comments were left on:
  • vllm/executor/cpu_executor.py
  • vllm/executor/ray_gpu_executor.py (two threads)
  • vllm/platforms/cuda.py
  • vllm/executor/ray_hpu_executor.py
  • vllm/executor/ray_tpu_executor.py (two threads)
@youkaichao
Member Author

I think we should get rid of the platform-specific executors. I.e. no ray_*pu_executor.pys. Possibly the platform-specific aspects could be a mix-in.

I strongly agree. We should have only {single worker executor, ray executor, mp executor}, and they should be able to initialize various workers (see the sketch below).

I think this PR is a tiny step towards that direction.
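
A sketch of that end state (illustrative names only, not vLLM's actual classes): a single platform-agnostic executor that instantiates whichever worker class the config names, so no per-platform ray_*_executor.py files are needed.

    import importlib


    class GenericExecutorSketch:
        """Illustrative only: one executor shared by all platforms. It
        dynamically imports whatever worker class the config names
        instead of hard-coding a platform-specific worker."""

        def __init__(self, worker_cls_path: str):
            module_name, _, class_name = worker_cls_path.rpartition(".")
            worker_cls = getattr(importlib.import_module(module_name), class_name)
            # Real vLLM workers take config arguments; omitted in this sketch.
            self.worker = worker_cls()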

@youkaichao
Member Author

It also feels a bit wrong to me to update the config objects in-place since these might be created/ "owned" by the user.

What's the concern here? The problem is that we don't have a scratch space for the model to store information, and right now we use vllm_config heavily to store per-model global state.

@njhill
Member

njhill commented Nov 22, 2024

It also feels a bit wrong to me to update the config objects in-place since these might be created/ "owned" by the user.

what's the concern here? the problem is we don't have a scratch space for the model to store some information, and right now we use vllm_config a lot, to store per-model basis global variables.

Yeah, it's kind of a more general point than this PR ... like you say, two things are being conflated a bit. Ideally the config should be treated as read-only, I think (it could have been passed in by the user), and model-global mutable state should be kept separate.

Perhaps something like:

from dataclasses import dataclass

@dataclass
class ModelState:
    config: VllmConfig

    # ... mutable per-model state, kept separate from the read-only config

@youkaichao
Member Author

@dataclass
class ModelState:
    config: VllmConfig

    # ...

@njhill this makes sense, but you need to figure out where to store it. All of the classes, including the engine, executor, worker, model runner, and model, need to access it. And you cannot use module-level global state, because people can create multiple LLM objects in the same process.
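
An illustrative sketch of the constraint being described (class names simplified): the config object has to be threaded through every layer by hand, which a module-level global cannot replicate once two engines live in one process.

    class Worker:
        def __init__(self, vllm_config: dict):
            self.vllm_config = vllm_config


    class Executor:
        def __init__(self, vllm_config: dict):
            self.vllm_config = vllm_config
            self.worker = Worker(vllm_config)


    class Engine:
        def __init__(self, vllm_config: dict):
            self.vllm_config = vllm_config
            self.executor = Executor(vllm_config)


    # Two engines in the same process each carry their own config; a
    # module-level global could only hold one of them at a time.
    engine_a = Engine({"model": "model-a"})
    engine_b = Engine({"model": "model-b"})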

@youkaichao
Member Author

wdyt about changing this field to be custom_worker_cls: Optional[str] = None?

I use a single worker_cls because it will be printed out during init.

Printing:

worker_cls = "vllm.worker.worker.Worker"
custom_worker_cls = "whatever.user.provide"

looks less clear than:

worker_cls = "auto"

and

worker_cls = "whatever.user.provide"

Member

@njhill njhill left a comment


We can address the other suggestions as follow-on refactoring

@youkaichao youkaichao merged commit a111d01 into vllm-project:main Nov 22, 2024
48 of 51 checks passed
@youkaichao youkaichao deleted the worker_cls branch November 22, 2024 05:00
tlrmchlsmth pushed a commit to neuralmagic/vllm that referenced this pull request Nov 23, 2024
…project#10555)

Signed-off-by: youkaichao <[email protected]>
Co-authored-by: Nick Hill <[email protected]>
Signed-off-by: Tyler Michael Smith <[email protected]>
hijkzzz pushed a commit to OpenRLHF/OpenRLHF that referenced this pull request Nov 28, 2024
`worker_module_name` and `worker_class_name` are no longer supported.

Refer to vllm-project/vllm#10555

Signed-off-by: Hollow Man <[email protected]>
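
A hypothetical migration sketch for downstream code: the OpenRLHF module path below is illustrative, and `worker_cls` as an LLM keyword argument is assumed from this PR.

    # Before (no longer supported after vllm-project/vllm#10555):
    #   worker_module_name="openrlhf.trainer.ray.vllm_worker_wrap",
    #   worker_class_name="WorkerWrap",
    from vllm import LLM  # assumes a vLLM version that includes this PR

    llm = LLM(
        model="facebook/opt-125m",
        worker_cls="openrlhf.trainer.ray.vllm_worker_wrap.WorkerWrap",  # illustrative path
    )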
mfournioux pushed a commit to mfournioux/vllm that referenced this pull request Nov 28, 2024
sleepwalker2017 pushed a commit to sleepwalker2017/vllm that referenced this pull request Dec 13, 2024
Labels
ready (ONLY add when PR is ready to merge/full CI is needed)
3 participants