
Commit

[Model] Support Pixtral models in the HF Transformers format (vllm-pr…
mgoin authored Oct 18, 2024
1 parent 67a7e5e commit 3921a2f
Showing 7 changed files with 503 additions and 12 deletions.
2 changes: 1 addition & 1 deletion docs/source/models/supported_models.rst
@@ -437,7 +437,7 @@ Text Generation
  * - :code:`PixtralForConditionalGeneration`
    - Pixtral
    - T + I\ :sup:`+`
    - :code:`mistralai/Pixtral-12B-2409`
    - :code:`mistralai/Pixtral-12B-2409`, :code:`mistral-community/pixtral-12b` etc.
    -
    - ✅︎
  * - :code:`QWenLMHeadModel`
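The new table entry above documents the HF-Transformers-format checkpoint. A minimal sketch of loading it with vLLM's offline API, using the model name from the table; the max_model_len value and prompt are illustrative assumptions, not part of this commit:

from vllm import LLM, SamplingParams

# Load the HF-format Pixtral checkpoint listed in the table above.
llm = LLM(model="mistral-community/pixtral-12b", max_model_len=8192)

# Text-only smoke test; an image example follows in the next file.
outputs = llm.generate(["<s>[INST]Say hello in one sentence.[/INST]"],
                       SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)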
17 changes: 17 additions & 0 deletions examples/offline_inference_vision_language.py
@@ -277,6 +277,22 @@ def run_qwen2_vl(question: str, modality: str):
    return llm, prompt, stop_token_ids


# Pixtral HF-format
def run_pixtral_hf(question: str, modality: str):
    assert modality == "image"

    model_name = "mistral-community/pixtral-12b"

    llm = LLM(
        model=model_name,
        max_model_len=8192,
    )

    prompt = f"<s>[INST]{question}\n[IMG][/INST]"
    stop_token_ids = None
    return llm, prompt, stop_token_ids


# LLama 3.2
def run_mllama(question: str, modality: str):
    assert modality == "image"
@@ -347,6 +363,7 @@ def run_glm4v(question: str, modality: str):
"NVLM_D": run_nvlm_d,
"qwen_vl": run_qwen_vl,
"qwen2_vl": run_qwen2_vl,
"pixtral_hf": run_pixtral_hf,
"mllama": run_mllama,
"molmo": run_molmo,
"glm4v": run_glm4v,
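For context, a hedged sketch of how the llm, prompt, and stop_token_ids returned by run_pixtral_hf above would be used for a single image, following the pattern of the surrounding example script; the asset name and sampling settings are illustrative assumptions:

from vllm import SamplingParams
from vllm.assets.image import ImageAsset

llm, prompt, stop_token_ids = run_pixtral_hf("What is shown in this image?",
                                             "image")
# Bundled demo image; any PIL image works here.
image = ImageAsset("cherry_blossom").pil_image.convert("RGB")

outputs = llm.generate(
    {"prompt": prompt, "multi_modal_data": {"image": image}},
    SamplingParams(temperature=0.2, max_tokens=64,
                   stop_token_ids=stop_token_ids),
)
print(outputs[0].outputs[0].text)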
2 changes: 2 additions & 0 deletions vllm/model_executor/layers/activation.py
@@ -264,6 +264,8 @@ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
    lambda: nn.ReLU(),
    "relu2":
    lambda: ReLUSquaredActivation(),
    "silu":
    lambda: nn.SiLU(),
    "quick_gelu":
    lambda: QuickGELU(),
})
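The two added lines register "silu" so the string name in a model config can be resolved to nn.SiLU. A self-contained sketch of the same name-to-constructor lookup pattern; the LazyDict and get_act_fn plumbing in activation.py is not reproduced here, so treat this as an illustration rather than vLLM's API:

from typing import Callable, Dict

import torch
import torch.nn as nn

# Names map to zero-argument constructors so modules are built only on demand.
_ACTIVATIONS: Dict[str, Callable[[], nn.Module]] = {
    "gelu": lambda: nn.GELU(),
    "relu": lambda: nn.ReLU(),
    "silu": lambda: nn.SiLU(),  # the entry added by this commit
}

def get_activation(name: str) -> nn.Module:
    if name not in _ACTIVATIONS:
        raise ValueError(f"Unknown activation: {name!r}")
    return _ACTIVATIONS[name]()

act = get_activation("silu")
print(act(torch.tensor([-1.0, 0.0, 1.0])))  # SiLU(x) = x * sigmoid(x)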
74 changes: 70 additions & 4 deletions vllm/model_executor/models/llava.py
@@ -5,7 +5,8 @@
import torch
import torch.nn as nn
from PIL import Image
from transformers import CLIPVisionConfig, LlavaConfig, SiglipVisionConfig
from transformers import (CLIPVisionConfig, LlavaConfig, PixtralVisionConfig,
                          SiglipVisionConfig)

from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig
@@ -22,6 +23,10 @@
                   dummy_seq_data_for_clip, get_max_clip_image_tokens,
                   input_processor_for_clip)
from .interfaces import SupportsMultiModal, SupportsPP
from .pixtral import (PixtralHFVisionModel, dummy_image_for_pixtral_hf,
                      dummy_seq_data_for_pixtral_hf,
                      get_max_pixtral_hf_image_tokens,
                      input_processor_for_pixtral_hf)
from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
                     dummy_seq_data_for_siglip, get_max_siglip_image_tokens,
                     input_processor_for_siglip)
@@ -31,8 +36,13 @@

class LlavaImagePixelInputs(TypedDict):
    type: Literal["pixel_values"]
    data: torch.Tensor
    """Shape: `(batch_size * num_images, num_channels, height, width)`"""
    data: Union[torch.Tensor, List[torch.Tensor]]
    """
    Shape: `(batch_size * num_images, num_channels, height, width)`
    Note that `height` or `width` may be different per batch and image,
    in which case the data is passed as a list instead of a batched tensor.
    """


class LlavaImageEmbeddingInputs(TypedDict):
@@ -77,6 +87,8 @@ def get_max_llava_image_tokens(ctx: InputContext):
        num_image_tokens = get_max_clip_image_tokens(vision_config)
    elif isinstance(vision_config, SiglipVisionConfig):
        num_image_tokens = get_max_siglip_image_tokens(vision_config)
    elif isinstance(vision_config, PixtralVisionConfig):
        num_image_tokens = get_max_pixtral_hf_image_tokens(vision_config)
    else:
        msg = f"Unsupported vision config: {type(vision_config)}"
        raise NotImplementedError(msg)
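As a rough sanity check on the new branch, the worst case for a ViT-style encoder is a square image at the maximum resolution, one token per patch. The real get_max_pixtral_hf_image_tokens helper in pixtral.py may add break tokens or resizing, and the config values below are assumptions taken from the mistral-community/pixtral-12b vision config, so treat this as an approximation only:

def approx_max_image_tokens(image_size: int, patch_size: int) -> int:
    # One token per patch on each side of a square image at max resolution.
    patches_per_side = image_size // patch_size
    return patches_per_side * patches_per_side

# Assumed Pixtral-12B vision settings: 1024px max side, 16px patches.
print(approx_max_image_tokens(image_size=1024, patch_size=16))  # 4096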
@@ -120,6 +132,17 @@ def dummy_data_for_llava(ctx: InputContext, seq_len: int,

        mm_data = dummy_image_for_siglip(vision_config, num_images)
        return seq_data, mm_data
    elif isinstance(vision_config, PixtralVisionConfig):
        seq_data = dummy_seq_data_for_pixtral_hf(
            vision_config,
            seq_len,
            num_images,
            image_token_id=hf_config.image_token_index,
            image_feature_size_override=image_feature_size,
        )

        mm_data = dummy_image_for_pixtral_hf(vision_config, num_images)
        return seq_data, mm_data

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)
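For intuition, the dummy helpers above exist so vLLM can profile memory with a worst-case fake sample: a sequence filled with image placeholder tokens plus a blank image at maximum resolution. A minimal sketch of that idea; the helper name, token ids, and sizes are illustrative assumptions, not the actual pixtral.py implementations:

from typing import List, Tuple

from PIL import Image

def make_dummy_sample(seq_len: int, image_token_id: int,
                      image_feature_size: int,
                      image_size: int) -> Tuple[List[int], Image.Image]:
    # Worst-case token sequence: placeholder tokens, then padding to seq_len.
    token_ids = [image_token_id] * image_feature_size
    token_ids += [0] * max(0, seq_len - image_feature_size)
    # Blank RGB image at the maximum resolution the encoder accepts.
    dummy_image = Image.new("RGB", (image_size, image_size), color=0)
    return token_ids, dummy_image

tokens, image = make_dummy_sample(seq_len=8192, image_token_id=10,
                                  image_feature_size=4096, image_size=1024)
print(len(tokens), image.size)  # 8192 (1024, 1024)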
@@ -163,6 +186,15 @@ def input_processor_for_llava(ctx: InputContext, inputs: DecoderOnlyInputs):
            image_token_id=hf_config.image_token_index,
            image_feature_size_override=image_feature_size,
        )
    elif isinstance(vision_config, PixtralVisionConfig):
        # We ignore image_feature_size_override since we have non-uniform
        # image sizes for Pixtral
        return input_processor_for_pixtral_hf(
            model_config,
            vision_config,
            inputs,
            image_token_id=hf_config.image_token_index,
        )

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)
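Conceptually, the input processor rewrites the prompt so each image placeholder is expanded to the number of tokens the vision tower will emit for that image; for Pixtral-HF that count depends on each image's true size, which is why image_feature_size_override is ignored above. A standalone sketch of the expansion step; token ids and the patch arithmetic are assumptions, not vLLM's actual implementation:

from typing import List

def expand_image_placeholders(prompt_ids: List[int], image_token_id: int,
                              feature_sizes: List[int]) -> List[int]:
    # Replace each occurrence of image_token_id with that image's token count.
    out: List[int] = []
    image_idx = 0
    for tok in prompt_ids:
        if tok == image_token_id:
            out.extend([image_token_id] * feature_sizes[image_idx])
            image_idx += 1
        else:
            out.append(tok)
    return out

# One 512x512 image at 16px patches -> 32 * 32 = 1024 placeholder tokens.
expanded = expand_image_placeholders([1, 7, 10, 7, 2], image_token_id=10,
                                     feature_sizes=[(512 // 16) ** 2])
print(len(expanded))  # 4 other tokens + 1024 = 1028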
@@ -189,6 +221,9 @@ def _init_vision_tower(hf_config: LlavaConfig):
            vision_config,
            num_hidden_layers_override=num_hidden_layers,
        )
    elif isinstance(vision_config, PixtralVisionConfig):
        # TODO: allow layer override?
        return PixtralHFVisionModel(vision_config)

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)
@@ -210,6 +245,15 @@ def __init__(self,
        self.config = config
        self.multimodal_config = multimodal_config

        # NOTE: These are special cases for Pixtral-12B in the HF-format
        # https://huggingface.co/mistral-community/pixtral-12b/blob/main/config.json # noqa
        if (config.text_config.architectures is None
                and config.text_config.model_type == "mistral"):
            config.text_config.architectures = ["MistralForCausalLM"]
        if (config.projector_hidden_act is None
                and config.vision_config.hidden_act == "gelu"):
            config.projector_hidden_act = "gelu"

        # TODO: Optionally initializes this for supporting embeddings.
        self.vision_tower = _init_vision_tower(config)
        self.multi_modal_projector = LlavaMultiModalProjector(
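The two special cases above exist because the HF-format checkpoint's config.json leaves these fields unset. A hedged way to inspect them, mirroring the same checks; it assumes a transformers release that knows the pixtral model type, and the field values depend on whatever the hub config currently contains:

from transformers import AutoConfig

config = AutoConfig.from_pretrained("mistral-community/pixtral-12b")

# Mirror the defaulting logic applied in LlavaForConditionalGeneration above.
if (config.text_config.architectures is None
        and config.text_config.model_type == "mistral"):
    print("text_config.architectures unset -> default to MistralForCausalLM")
if (getattr(config, "projector_hidden_act", None) is None
        and config.vision_config.hidden_act == "gelu"):
    print("projector_hidden_act unset -> default to 'gelu'")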
@@ -246,6 +290,7 @@ def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[LlavaImageInputs]:
        pixel_values = kwargs.pop("pixel_values", None)
        image_sizes = kwargs.pop("image_sizes", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
@@ -256,6 +301,26 @@ def _parse_and_validate_image_input(
                raise ValueError("Incorrect type of pixel values. "
                                 f"Got type: {type(pixel_values)}")

            # Case for models like PixtralHF that have dynamic image sizes
            # so we need to produce a list of tensors
            if image_sizes is not None:
                images = pixel_values
                if isinstance(images, torch.Tensor):
                    # if passed as batch take all images
                    NN, N, B, C, W, H = images.shape
                    images = images.reshape(NN * N * B, C, W, H)
                    images = [images[i] for i in range(images.size(0))]
                elif isinstance(images, list):
                    # if passed as list flatten lists of tensors
                    while isinstance(images, list) and len(images) == 1:
                        images = images[0]

                # TODO: Add validation based on image_sizes
                return LlavaImagePixelInputs(
                    type="pixel_values",
                    data=images,
                )

            return LlavaImagePixelInputs(
                type="pixel_values",
                data=self._validate_pixel_values(
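To make the dynamic-size branch above concrete, a small standalone demo of the two input shapes it accepts: a fully batched six-dimensional tensor that is flattened into a list of per-image tensors, and a singly nested list that is unwrapped until the differently sized tensors are exposed. Dimension names follow the NN, N, B, C, W, H unpacking in the diff; the sizes are made up:

import torch

# Batched case: (NN, N, B, C, W, H) -> flat list of (C, W, H) tensors.
batched = torch.randn(1, 2, 1, 3, 64, 48)
NN, N, B, C, W, H = batched.shape
flat = batched.reshape(NN * N * B, C, W, H)
images = [flat[i] for i in range(flat.size(0))]
print(len(images), images[0].shape)  # 2 torch.Size([3, 64, 48])

# List case: nested singleton lists are unwrapped so images of different
# sizes can coexist, which is the point of passing a list for Pixtral-HF.
nested = [[torch.randn(3, 64, 48), torch.randn(3, 32, 80)]]
while isinstance(nested, list) and len(nested) == 1:
    nested = nested[0]
print([tuple(t.shape) for t in nested])  # [(3, 64, 48), (3, 32, 80)]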
@@ -286,7 +351,8 @@ def _select_image_features(self, image_features: torch.Tensor, *,

    def _image_pixels_to_features(
        self,
        vision_tower: Union[CLIPVisionModel, SiglipVisionModel],
        vision_tower: Union[CLIPVisionModel, SiglipVisionModel,
                            PixtralHFVisionModel],
        pixel_values: torch.Tensor,
    ) -> torch.Tensor:

