
[V1][VLM] V1 support for selected single-image models. #11632

Merged: 43 commits, Dec 31, 2024

Conversation

ywang96 (Member) commented on Dec 30, 2024:

This PR mainly adds V1 support for a number of single-image models, since the code changes are already sizable enough for a review.

To summarize, this PR:

  • Adds V1 support via the merged multi-modal processor for aria, blip2, chameleon, and fuyu.
  • Consolidates common code related to dummy data generation for the merged multi-modal processor.
  • Fixes a few issues for aria (missing dummy data, incomplete input mapper, etc.).
  • Adds a small optimization for llava-next to run a single batched projection instead of per-image projections (a minimal sketch follows below).
  • Fixes some type errors in the Pixtral model file.
  • Fixes V1 encoder profiling to correctly respect max_num_seqs and limit_mm_per_prompt.

All models have been tested with offline_inference_vision_language.py on both V0 and V1.
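
As a rough illustration of the llava-next optimization mentioned above, the sketch below projects the stacked patch features of all images in one call and then splits the result back per image. The names (`projector`, `image_features`) are placeholders for illustration, not the exact vLLM identifiers:

```python
import torch

def project_batched(projector: torch.nn.Module,
                    image_features: list[torch.Tensor]) -> list[torch.Tensor]:
    # Before (conceptually): [projector(f) for f in image_features],
    # i.e. one small matmul per image. After: concatenate all patch
    # features, project once, and split back by each image's patch count.
    patch_counts = [f.shape[0] for f in image_features]
    projected = projector(torch.cat(image_features, dim=0))
    return list(projected.split(patch_counts, dim=0))
```

With a linear or MLP projector this turns many small matmuls into one larger, better-utilized one, which matters when a prompt carries several images.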

ywang96 and others added 9 commits December 29, 2024 13:23

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small and essential subset of tests to quickly catch errors. You can run other CI tests on top of it by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add the `ready` label to the PR
  • Enable auto-merge.

🚀

@mergify mergify bot added the documentation label (Improvements or additions to documentation) on Dec 30, 2024
@@ -633,7 +633,7 @@ See [this page](#generative-models) for more information on how to use generativ
- `llava-hf/llava-v1.6-mistral-7b-hf`, `llava-hf/llava-v1.6-vicuna-7b-hf`, etc.
-
- ✅︎
-
- ✅︎
ywang96 (Member, Author) commented:
Llava-next was already supported on V1, so this is just a doc update.

Comment on lines -405 to +415
-            repeat_count=repeat_count[placeholder_token_idx],
+            repeat_count=curr_repeat_count,
             pad_token_left=pad_token_left,
             pad_token_right=pad_token_right,
         )
+        offset = len(new_token_ids)
+        if pad_token_left is not None:
+            offset += 1
         placeholder_ranges.append({
-            "offset": len(new_token_ids),
-            "length": len(replacement_ids)
+            "offset": offset,
+            "length": curr_repeat_count,
ywang96 (Member, Author) commented:
This was previously counting padding tokens as part of the placeholder tokens, which is not accurate.
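
A minimal sketch of the corrected bookkeeping with hypothetical token values (the identifiers mirror the diff above; the numbers are made up):

```python
new_token_ids = [101, 102]  # tokens emitted before this placeholder group
pad_token_left = 9000       # e.g. an image-start token; may also be None
curr_repeat_count = 4       # number of repeated placeholder tokens

# The placeholder range must start after the left pad token (if any) and
# cover only the repeated placeholder tokens, never the padding around them.
offset = len(new_token_ids)
if pad_token_left is not None:
    offset += 1

placeholder_range = {"offset": offset, "length": curr_repeat_count}
assert placeholder_range == {"offset": 3, "length": 4}
```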

@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_aria)
@INPUT_REGISTRY.register_input_processor(input_processor)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_aria)
ywang96 (Member, Author) commented on Dec 30, 2024:
The code for dummy data generation was entirely missing and I'm not sure why, so I added it in this PR since it's required for V1. cc @xffxff, who originally added this model
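
For context, a dummy-data builder used for profiling looks roughly like the following. This is a simplified sketch with hypothetical token ids and sizes, not the exact vLLM signature:

```python
from PIL import Image

IMAGE_TOKEN_ID = 9       # hypothetical placeholder token id
TOKENS_PER_IMAGE = 256   # hypothetical image token count
DUMMY_IMAGE_SIZE = 980   # hypothetical maximum supported image size

def build_dummy_data(seq_len: int, num_images: int):
    # Fill the prompt with image placeholder tokens (padded to seq_len) and
    # attach maximally sized dummy images, so that memory profiling
    # exercises the worst-case multi-modal input.
    token_ids = [IMAGE_TOKEN_ID] * (TOKENS_PER_IMAGE * num_images)
    token_ids += [0] * (seq_len - len(token_ids))
    images = [Image.new("RGB", (DUMMY_IMAGE_SIZE, DUMMY_IMAGE_SIZE))] * num_images
    return token_ids, {"image": images}
```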

Comment on lines 561 to 564
image_size2tokens = {
    int(math.sqrt(k) * hf_config.vision_config.patch_size): v
    for k, v in hf_config.projector_patch_to_query_dict.items()
}
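
As an aside, a worked example of what this comprehension computes, with hypothetical Aria-like values (sqrt(1225) * 14 = 490, sqrt(4900) * 14 = 980):

```python
import math

projector_patch_to_query_dict = {1225: 128, 4900: 256}  # hypothetical
patch_size = 14

image_size2tokens = {
    int(math.sqrt(k) * patch_size): v
    for k, v in projector_patch_to_query_dict.items()
}
assert image_size2tokens == {490: 128, 980: 256}
```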
A collaborator commented:

This seems to be a fixed value; perhaps we can move it to the AriaMoELMConfig initialization in vllm/transformers_utils/configs/aria.py to avoid repeated calculation?

ywang96 (Member, Author) replied:
Yep, I can do that.

ywang96 (Member, Author) added:
I realized we actually don't need this `int(math.sqrt(k) * hf_config.vision_config.patch_size)` calculation at all, since we only care about the values here, so I will just simplify this.
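
Assuming only the maximum image-token count is needed downstream, the simplification could be as small as the following one-liner (a sketch, not necessarily the exact committed code):

```python
max_image_tokens = max(hf_config.projector_patch_to_query_dict.values())
```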

@DarkLight1337 DarkLight1337 self-requested a review as a code owner December 30, 2024 16:48
DarkLight1337 and others added 23 commits December 30, 2024 17:48
DarkLight1337 (Member) left a comment:

I have verified that the models work on both V0 and V1. Let's see if the tests pass.

@Isotr0py Isotr0py enabled auto-merge (squash) December 31, 2024 17:21
@github-actions github-actions bot added the ready label (ONLY add when PR is ready to merge/full CI is needed) on Dec 31, 2024
@Isotr0py Isotr0py merged commit e7c7c5e into vllm-project:main Dec 31, 2024
56 checks passed
bjmsong pushed a commit to bjmsong/vllm that referenced this pull request Jan 2, 2025
[V1][VLM] V1 support for selected single-image models. (vllm-project#11632)

Labels: documentation (Improvements or additions to documentation), ready (ONLY add when PR is ready to merge/full CI is needed)