Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Model] Add Idefics3 support #9767

Merged
merged 29 commits into from
Nov 6, 2024
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
dc48857
Init
jeejeelee Oct 28, 2024
f7664d1
Merge branch 'vllm-project:main' into support-idefics3
jeejeelee Oct 29, 2024
e2e9811
Modify model code
jeejeelee Oct 29, 2024
b605765
Merge branch 'vllm-project:main' into support-idefics3
jeejeelee Oct 29, 2024
1200475
Update code
B-201 Oct 31, 2024
41f8d76
Merge branch 'support-idefics3' of https://github.com/jeejeelee/vllm …
B-201 Oct 31, 2024
074545c
Merge branch 'vllm-project:main' into support-idefics3
jeejeelee Oct 31, 2024
6cd4bfa
Fix code format
jeejeelee Oct 31, 2024
9cb4e32
Update code
B-201 Oct 31, 2024
6620b7c
Delete dirty code
jeejeelee Oct 31, 2024
367f31e
Merge branch 'vllm-project:main' into support-idefics3
jeejeelee Nov 1, 2024
63265c4
Merge branch 'vllm-project:main' into support-idefics3
jeejeelee Nov 1, 2024
dc409c1
Merge branch 'vllm-project:main' into support-idefics3
jeejeelee Nov 2, 2024
81168cf
Support multi-image
B-201 Nov 3, 2024
b353561
Add unit test
B-201 Nov 4, 2024
9e1d3cf
Disable yapf
B-201 Nov 4, 2024
835d9ba
Fix code format
B-201 Nov 4, 2024
56f0572
Merge branch 'vllm-project:main' into support-idefics3
jeejeelee Nov 5, 2024
469a96e
Fix format
jeejeelee Nov 5, 2024
3285b91
Fix code format
B-201 Nov 5, 2024
d632fa8
Update docs
jeejeelee Nov 5, 2024
84d7428
Update docs
jeejeelee Nov 5, 2024
ab8eb7c
Integrate test code
B-201 Nov 5, 2024
041c034
Fix code format
B-201 Nov 5, 2024
9a562fc
Merge branch 'vllm-project:main' into support-idefics3
jeejeelee Nov 6, 2024
ba7e85c
Update example & test
B-201 Nov 6, 2024
10415d3
Fix code format
B-201 Nov 6, 2024
e5bb291
Fix model registry
jeejeelee Nov 6, 2024
6f7fd31
Merge branch 'vllm-project:main' into support-idefics3
jeejeelee Nov 6, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions examples/offline_inference_vision_language.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,22 @@ def run_glm4v(question: str, modality: str):
return llm, prompt, stop_token_ids


# Idefics3-8B-Llama3
def run_idefics3(question: str, modality: str):
assert modality == "image"
model_name = ("HuggingFaceM4/Idefics3-8B-Llama3")

llm = LLM(model=model_name,
max_model_len=2048,
max_num_seqs=2,
enforce_eager=True)
prompt = (
f"<|begin_of_text|>User:<image>{question}<end_of_utterance>\nAssistant:"
)
stop_token_ids = None
return llm, prompt, stop_token_ids


model_example_map = {
"llava": run_llava,
"llava-next": run_llava_next,
Expand All @@ -372,6 +388,7 @@ def run_glm4v(question: str, modality: str):
"mllama": run_mllama,
"molmo": run_molmo,
"glm4v": run_glm4v,
"idefics3": run_idefics3,
}


Expand Down
25 changes: 24 additions & 1 deletion vllm/model_executor/models/idefics2_vision_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
# limitations under the License.
"""PyTorch Idefics2 model."""

from typing import Optional
from typing import Optional, Iterable, Tuple

import torch
from torch import nn
Expand All @@ -26,6 +26,7 @@
from xformers import ops as xops

from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
Expand Down Expand Up @@ -331,3 +332,25 @@ def forward(
encoder_outputs = self.encoder(hidden_states)
last_hidden_state = self.post_layernorm(encoder_outputs)
return last_hidden_state

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
]
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
param = params_dict[name.replace(weight_name, param_name)]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
Loading
Loading