From a0f7d53beb176034546c6deb328a3d49e94e1f6d Mon Sep 17 00:00:00 2001
From: Cyrus Leung
Date: Thu, 19 Dec 2024 21:22:00 +0800
Subject: [PATCH] [Bugfix] Cleanup Pixtral HF code (#11333)

Signed-off-by: DarkLight1337
---
 vllm/model_executor/models/pixtral.py | 155 +++-----------------------
 1 file changed, 14 insertions(+), 141 deletions(-)

diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py
index f05ea195e043d..6676dd16e005f 100644
--- a/vllm/model_executor/models/pixtral.py
+++ b/vllm/model_executor/models/pixtral.py
@@ -10,12 +10,12 @@
 from PIL import Image
 from transformers import PixtralVisionConfig
 from transformers.models.pixtral.image_processing_pixtral import (
-    _num_image_tokens)
+    _num_image_tokens as _get_pixtral_hf_num_image_tokens)
 from transformers.models.pixtral.modeling_pixtral import (
     PixtralRotaryEmbedding, apply_rotary_pos_emb, position_ids_in_meshgrid)
 
 from vllm.attention import AttentionMetadata
-from vllm.config import ModelConfig, VllmConfig
+from vllm.config import VllmConfig
 from vllm.distributed import divide, get_tensor_model_parallel_world_size
 from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
                          InputContext, token_inputs)
@@ -27,7 +27,6 @@
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.models.utils import merge_multimodal_embeddings
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
 from vllm.multimodal.inputs import NestedTensors, PlaceholderRange
@@ -35,11 +34,10 @@
                                    consecutive_placeholder_ranges,
                                    resolve_visual_encoder_outputs)
 from vllm.sequence import IntermediateTensors, SequenceData
-from vllm.transformers_utils.processor import cached_get_processor
-from vllm.utils import is_list_of
 
 from .interfaces import SupportsMultiModal, SupportsPP
-from .utils import init_vllm_registered_model, maybe_prefix
+from .utils import (init_vllm_registered_model, maybe_prefix,
+                    merge_multimodal_embeddings)
 
 try:
     from xformers import ops as xops
@@ -699,37 +697,14 @@ def get_pixtral_hf_num_patches(*, image_size: int, patch_size: int) -> int:
     return grid_length * grid_length
 
 
-def get_max_pixtral_hf_image_feature_size(
-        hf_config: PixtralVisionConfig) -> int:
-    return get_pixtral_hf_num_patches(image_size=hf_config.image_size,
-                                      patch_size=hf_config.patch_size)
-
-
 def get_max_pixtral_hf_image_tokens(hf_config: PixtralVisionConfig) -> int:
-    return get_max_pixtral_hf_image_feature_size(hf_config)
+    grid_length = get_pixtral_hf_patch_grid_length(
+        image_size=hf_config.image_size,
+        patch_size=hf_config.patch_size,
+    )
 
-
-def dummy_seq_data_for_pixtral_hf(
-        hf_config: PixtralVisionConfig,
-        seq_len: int,
-        num_images: int,
-        *,
-        image_token_id: int,
-        image_feature_size_override: Optional[int] = None,
-        mm_key: str = "image"):
-    if image_feature_size_override is None:
-        image_feature_size = get_max_pixtral_hf_image_feature_size(hf_config)
-    else:
-        image_feature_size = image_feature_size_override
-
-    return SequenceData.from_prompt_token_counts(
-        (image_token_id, image_feature_size * num_images),
-        (0, seq_len - image_feature_size * num_images),
-    ), {
-        mm_key:
-        consecutive_placeholder_ranges(num_items=num_images,
-                                       item_size=image_feature_size)
-    }
+    # Consider the image_break_token
+    return (grid_length + 1) * grid_length
 
 
 def dummy_image_for_pixtral_hf(
@@ -763,116 +738,14 @@ def get_pixtral_hf_image_feature_size(hf_config: PixtralVisionConfig,
         image_width = int(numpy.ceil(image_width / ratio))
         image_height = int(numpy.ceil(image_height / ratio))
 
-    num_height_tokens, num_width_tokens = _num_image_tokens(
-        (image_height, image_width), (patch_height, patch_width))
+    num_height_tokens, num_width_tokens = _get_pixtral_hf_num_image_tokens(
+        (image_height, image_width),
+        (patch_height, patch_width),
+    )
 
     return num_width_tokens, num_height_tokens
 
 
-def input_processor_for_pixtral_hf(
-        model_config: ModelConfig,
-        hf_config: PixtralVisionConfig,
-        inputs: DecoderOnlyInputs,
-        *,
-        image_token_id: int,
-        image_feature_size_override: Optional[Union[int, List[int]]] = None,
-) -> DecoderOnlyInputs:
-    assert image_feature_size_override is None, (
-        "image_feature_size_override is not supported for Pixtral")
-
-    multi_modal_data = inputs.get("multi_modal_data")
-    if multi_modal_data is None or "image" not in multi_modal_data:
-        return inputs
-
-    processor = cached_get_processor(model_config.model)
-
-    image_data = multi_modal_data["image"]
-    if isinstance(image_data, Image.Image):
-        image_data = [image_data]
-    elif not is_list_of(image_data, Image.Image):
-        raise TypeError(f"Invalid image type: {type(image_data)}")
-
-    new_prompt = inputs.get("prompt")
-    new_token_ids = inputs["prompt_token_ids"]
-
-    image_token = processor.image_token
-    image_break_token = processor.image_break_token
-    image_end_token = processor.image_end_token
-
-    # Update new_prompt if present
-    if new_prompt:
-        parts = new_prompt.split(image_token)
-        assert len(parts) - 1 == len(image_data)
-        new_parts = [parts[0]]  # Start with the part before any image tokens
-
-        for image, next_part in zip(image_data, parts[1:]):
-            w, h = image.size
-            (num_width_tokens,
-             num_height_tokens) = get_pixtral_hf_image_feature_size(
-                 hf_config, image_width=w, image_height=h)
-
-            replace_tokens = [image_token] * num_width_tokens + [
-                image_break_token
-            ]
-            replace_tokens = replace_tokens * num_height_tokens
-            replace_tokens[-1] = image_end_token
-
-            new_parts.append("".join(replace_tokens))
-            new_parts.append(next_part)
-
-        new_prompt = "".join(new_parts)
-
-    # Update new_token_ids
-    convert_tokens_to_ids = processor.tokenizer.convert_tokens_to_ids
-    image_token_id = convert_tokens_to_ids(image_token)
-    image_break_id = convert_tokens_to_ids(image_break_token)
-    image_end_id = convert_tokens_to_ids(image_end_token)
-    placeholder_token_id = -999
-    # Find all image token indices at once
-    placeholder_indices = [
-        idx for idx, token_id in enumerate(new_token_ids)
-        if token_id == image_token_id
-    ]
-    assert len(placeholder_indices) == len(image_data)
-    replace_tokens_list = []
-    for placeholder_idx, image in zip(placeholder_indices, image_data):
-        new_token_ids[placeholder_idx] = placeholder_token_id
-
-        w, h = image.size
-        (num_width_tokens,
-         num_height_tokens) = get_pixtral_hf_image_feature_size(hf_config,
-                                                                image_width=w,
-                                                                image_height=h)
-
-        replace_tokens = [image_token_id] * num_width_tokens + [image_break_id]
-        replace_tokens = replace_tokens * num_height_tokens
-        replace_tokens[-1] = image_end_id
-        replace_tokens_list.append(replace_tokens)
-
-    reverse_offsets: List[int] = []
-    # Backward iteration for replacement without affecting known indices
-    for placeholder_idx, replace_tokens in zip(reversed(placeholder_indices),
-                                               reversed(replace_tokens_list)):
-        reverse_offsets.append(
-            len(new_token_ids) - placeholder_idx + len(replace_tokens))
-        new_token_ids[placeholder_idx:placeholder_idx + 1] = replace_tokens
-
-    placeholder_ranges: List[PlaceholderRange] = []
-    for reverse_offset, replace_tokens in zip(reversed(reverse_offsets),
-                                              replace_tokens_list):
-        placeholder_ranges.append(
-            PlaceholderRange(
-                offset=len(new_token_ids) - reverse_offset,
-                length=len(replace_tokens),
-            ))
-
-    # NOTE: Create a defensive copy of the original inputs
-    return token_inputs(prompt_token_ids=new_token_ids,
-                        prompt=new_prompt,
-                        multi_modal_data=multi_modal_data,
-                        multi_modal_placeholders={"image": placeholder_ranges})
-
-
 class PixtralHFMLP(nn.Module):
 
     def __init__(
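Reviewer note, not part of the patch: a minimal standalone sketch of the
arithmetic behind the consolidated get_max_pixtral_hf_image_tokens. The
helper names and the example values (image_size=1024, patch_size=16, the
defaults of HF's PixtralVisionConfig) are illustrative assumptions.

    def patch_grid_length(*, image_size: int, patch_size: int) -> int:
        # Mirrors get_pixtral_hf_patch_grid_length: patches per image side.
        return image_size // patch_size

    def max_image_tokens(*, image_size: int, patch_size: int) -> int:
        g = patch_grid_length(image_size=image_size, patch_size=patch_size)
        # Each of the g rows emits g image tokens plus one trailing
        # image_break (or, on the last row, image_end) token, hence
        # (g + 1) * g rather than the bare patch count g * g.
        return (g + 1) * g

    print(max_image_tokens(image_size=1024, patch_size=16))  # (64 + 1) * 64 = 4160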
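A second sketch, likewise hypothetical: the per-image row layout that the
bound above accounts for, mirroring the replace_tokens construction visible
in the removed input_processor_for_pixtral_hf. The token ids are placeholder
values, not Pixtral's real vocabulary ids.

    from typing import List

    def expand_image_placeholder(num_width_tokens: int,
                                 num_height_tokens: int,
                                 image_token_id: int = 1,
                                 image_break_id: int = 2,
                                 image_end_id: int = 3) -> List[int]:
        # One row = width-many image tokens followed by an image_break token;
        # the break after the final row is replaced by image_end.
        row = [image_token_id] * num_width_tokens + [image_break_id]
        tokens = row * num_height_tokens
        tokens[-1] = image_end_id
        return tokens

    # A 3x2 token grid yields (3 + 1) * 2 = 8 tokens in total.
    assert expand_image_placeholder(3, 2) == [1, 1, 1, 2, 1, 1, 1, 3]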