Consolidate dummy data code
Signed-off-by: DarkLight1337 <[email protected]>
DarkLight1337 committed Dec 31, 2024
1 parent bc976a7 commit f79f79a
Showing 9 changed files with 95 additions and 86 deletions.
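In short, each model file previously built its own dummy images or audio, either inline or via a local dummy_image_for_* helper; this commit consolidates that logic into shared _get_dummy_images / _get_dummy_audios / _get_dummy_videos helpers on BaseMultiModalProcessor. A minimal before/after sketch of the pattern, using the names from the aria.py hunks below:

    # Before: each model carried a local helper like this (removed below).
    def dummy_image_for_aria(vision_config, num_images):
        max_image_size = vision_config.image_size
        image = Image.new("RGB", (max_image_size, max_image_size), color=0)
        return {"image": [image] * num_images}

    # After: _get_dummy_mm_inputs() calls the shared base-class helper instead.
    mm_data = {
        "image":
        self._get_dummy_images(width=max_image_size,
                               height=max_image_size,
                               num_images=num_images)
    }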
23 changes: 9 additions & 14 deletions vllm/model_executor/models/aria.py
@@ -4,7 +4,6 @@
 
 import torch
 import torch.nn as nn
-from PIL import Image
 from torch.nn.init import trunc_normal_
 from transformers import BatchFeature, PretrainedConfig
 
@@ -453,17 +452,6 @@ def get_max_aria_image_tokens(ctx: InputContext):
     return max(image_size2tokens.values())
 
 
-def dummy_image_for_aria(
-    vision_config: AriaVisionConfig,
-    num_images: int,
-):
-    max_image_size = vision_config.image_size
-    image = Image.new("RGB", (max_image_size, max_image_size), color=0)
-    images = [image] * num_images
-
-    return {"image": images}
-
-
 class AriaMultiModalProcessor(BaseMultiModalProcessor):
 
     def _get_mm_fields_config(
@@ -501,16 +489,23 @@ def _get_dummy_mm_inputs(
     ) -> ProcessorInputs:
         hf_config = self.ctx.get_hf_config()
         vision_config: AriaVisionConfig = hf_config.vision_config
+
+        max_image_size = vision_config.image_size
         num_images = mm_counts.get("image", 0)
 
-        data = dummy_image_for_aria(vision_config, num_images)
+        mm_data = {
+            "image":
+            self._get_dummy_images(width=max_image_size,
+                                   height=max_image_size,
+                                   num_images=num_images)
+        }
 
         hf_processor = self._get_hf_processor()
         image_token: str = hf_processor.image_token  # type: ignore
 
         return ProcessorInputs(
             prompt_text=image_token * num_images,
-            mm_data=data,
+            mm_data=mm_data,
         )


18 changes: 0 additions & 18 deletions vllm/model_executor/models/blip.py
@@ -4,7 +4,6 @@
 
 import torch
 import torch.nn as nn
-from PIL import Image
 from transformers import Blip2VisionConfig, BlipVisionConfig
 
 from vllm.attention.layer import MultiHeadAttention
@@ -28,23 +27,6 @@ def get_blip_num_patches(*, image_size: int, patch_size: int) -> int:
     return grid_length * grid_length
 
 
-def dummy_image_for_blip(
-    hf_config: Union[BlipVisionConfig, Blip2VisionConfig],
-    num_images: int,
-    *,
-    image_width_override: Optional[int] = None,
-    image_height_override: Optional[int] = None,
-):
-    width = height = hf_config.image_size
-    if image_width_override is not None:
-        width = image_width_override
-    if image_height_override is not None:
-        height = image_height_override
-
-    image = Image.new("RGB", (width, height), color=0)
-    return {"image": image if num_images == 1 else [image] * num_images}
-
-
 # Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/blip/modeling_blip.py#L164 # noqa
 class BlipVisionEmbeddings(nn.Module):
 
13 changes: 10 additions & 3 deletions vllm/model_executor/models/blip2.py
@@ -23,7 +23,7 @@
                                         PromptReplacement)
 from vllm.sequence import IntermediateTensors
 
-from .blip import BlipVisionModel, dummy_image_for_blip
+from .blip import BlipVisionModel
 from .interfaces import SupportsMultiModal, SupportsPP
 from .utils import (AutoWeightsLoader, init_vllm_registered_model,
                     maybe_prefix, merge_multimodal_embeddings)
@@ -459,13 +459,20 @@ def _get_dummy_mm_inputs(
     ) -> ProcessorInputs:
         hf_config = self.ctx.get_hf_config(Blip2Config)
         vision_config = hf_config.vision_config
+
+        max_image_size = vision_config.image_size
         num_images = mm_counts.get("image", 0)
 
-        data = dummy_image_for_blip(vision_config, num_images)
+        mm_data = {
+            "image":
+            self._get_dummy_images(width=max_image_size,
+                                   height=max_image_size,
+                                   num_images=num_images)
+        }
 
         return ProcessorInputs(
             prompt_text="",
-            mm_data=data,
+            mm_data=mm_data,
         )


29 changes: 8 additions & 21 deletions vllm/model_executor/models/chameleon.py
@@ -3,9 +3,8 @@
                     Tuple, TypedDict, Union)
 
 import torch
-import torch.nn as nn
 import torch.nn.functional as F
-from PIL import Image
+from torch import nn
 from transformers import (BatchFeature, ChameleonConfig, ChameleonProcessor,
                           ChameleonVQVAEConfig)
@@ -59,23 +58,6 @@ def get_max_chameleon_image_tokens(ctx: InputContext):
     return CHAMELEON_IMAGE_SEQ_LENGTH
 
 
-def dummy_image_for_chameleon(
-    num_images: int,
-    *,
-    image_width_override: Optional[int] = None,
-    image_height_override: Optional[int] = None,
-):
-    width = CHAMELEON_CROP_SIZE_WIDTH
-    height = CHAMELEON_CROP_SIZE_HEIGHT
-    if image_width_override is not None:
-        width = image_width_override
-    if image_height_override is not None:
-        height = image_height_override
-
-    image = Image.new("RGB", (width, height), color=0)
-    return {"image": image if num_images == 1 else [image] * num_images}
-
-
 class ChameleonMultiModalProcessor(BaseMultiModalProcessor):
 
     def _get_hf_processor(self) -> ChameleonProcessor:
@@ -114,11 +96,16 @@ def _get_dummy_mm_inputs(
     ) -> ProcessorInputs:
         num_images = mm_counts.get("image", 0)
 
-        data = dummy_image_for_chameleon(num_images)
+        mm_data = {
+            "image":
+            self._get_dummy_images(width=CHAMELEON_CROP_SIZE_WIDTH,
+                                   height=CHAMELEON_CROP_SIZE_HEIGHT,
+                                   num_images=num_images)
+        }
 
         return ProcessorInputs(
             prompt_text="<image>" * num_images,
-            mm_data=data,
+            mm_data=mm_data,
         )
 
     def apply(
19 changes: 10 additions & 9 deletions vllm/model_executor/models/fuyu.py
@@ -20,8 +20,6 @@
 
 import torch
 import torch.nn as nn
-import torch.utils.checkpoint
-from PIL import Image
 from transformers import BatchFeature, FuyuProcessor
 
 from vllm.attention import AttentionMetadata
@@ -161,15 +159,18 @@ def _get_prompt_replacements(
             )
         ]
 
-    def _get_dummy_mm_inputs(self, mm_counts):
+    def _get_dummy_mm_inputs(
+        self,
+        mm_counts: Mapping[str, int],
+    ) -> ProcessorInputs:
         num_images = mm_counts.get("image", 0)
 
-        image = Image.new(
-            "RGB",
-            (MAX_IMAGE_FEATURE_SIZE_WIDTH, MAX_IMAGE_FEATURE_SIZE_HEIGHT),
-            color=0,
-        )
-        mm_data = {"image": image if num_images == 1 else [image] * num_images}
+        mm_data = {
+            "image":
+            self._get_dummy_images(width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
+                                   height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
+                                   num_images=num_images)
+        }
 
         return ProcessorInputs(
             prompt_text="",
14 changes: 8 additions & 6 deletions vllm/model_executor/models/qwen2_audio.py
@@ -23,7 +23,6 @@
 from typing import (Any, Iterable, List, Mapping, Optional, Set, Tuple,
                     TypedDict, Union)
 
-import numpy as np
 import torch
 import torch.nn as nn
 from transformers import BatchFeature
@@ -181,16 +180,19 @@ def _get_dummy_mm_inputs(
         mm_counts: Mapping[str, int],
     ) -> ProcessorInputs:
         feature_extractor = self._get_feature_extractor()
+
         sampling_rate = feature_extractor.sampling_rate
         audio_len = feature_extractor.chunk_length * sampling_rate
+        num_audios = mm_counts.get("audio", 0)
 
-        audio_count = mm_counts.get("audio", 0)
-        audio = np.zeros(audio_len)
-        data = {"audio": [audio] * audio_count}
+        mm_data = {
+            "audio":
+            self._get_dummy_audios(length=audio_len, num_audios=num_audios)
+        }
 
         return ProcessorInputs(
-            prompt_text="<|AUDIO|>" * audio_count,
-            mm_data=data,
+            prompt_text="<|AUDIO|>" * num_audios,
+            mm_data=mm_data,
         )


17 changes: 9 additions & 8 deletions vllm/model_executor/models/qwen2_vl.py
@@ -30,7 +30,6 @@
 import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange, repeat
-from PIL import Image
 from transformers import BatchFeature
 from transformers.models.qwen2_vl import (Qwen2VLImageProcessor,
                                           Qwen2VLProcessor)
@@ -891,27 +890,29 @@ def _get_dummy_mm_inputs(
         self,
         mm_counts: Mapping[str, int],
     ) -> ProcessorInputs:
-        num_images = mm_counts.get("image", 0)
         hf_processor = self._get_hf_processor()
+        image_token: str = hf_processor.image_token
         image_processor = _get_image_processor(hf_processor)
 
-        data = {}
-        image_token: str = hf_processor.image_token
         resized_height, resized_width = smart_resize(
             height=9999999,
             width=9999999,
             factor=image_processor.patch_size * image_processor.merge_size,
             min_pixels=image_processor.min_pixels,
             max_pixels=image_processor.max_pixels,
         )
+        num_images = mm_counts.get("image", 0)
 
-        dummy_image = Image.new("RGB", (resized_width, resized_height),
-                                color=0)
-        data["image"] = [dummy_image] * num_images
+        mm_data = {
+            "image":
+            self._get_dummy_images(width=resized_width,
+                                   height=resized_height,
+                                   num_images=num_images)
+        }
 
         return ProcessorInputs(
             prompt_text=image_token * num_images,
-            mm_data=data,
+            mm_data=mm_data,
         )


13 changes: 8 additions & 5 deletions vllm/model_executor/models/ultravox.py
@@ -192,16 +192,19 @@ def _get_dummy_mm_inputs(
         mm_counts: Mapping[str, int],
     ) -> ProcessorInputs:
         feature_extractor = self._get_feature_extractor()
+
         sampling_rate = feature_extractor.sampling_rate
         audio_len = feature_extractor.chunk_length * sampling_rate
+        num_audios = mm_counts.get("audio", 0)
 
-        audio_count = mm_counts.get("audio", 0)
-        audio = np.zeros(audio_len)
-        data = {"audio": [audio] * audio_count}
+        mm_data = {
+            "audio":
+            self._get_dummy_audios(length=audio_len, num_audios=num_audios)
+        }
 
         return ProcessorInputs(
-            prompt_text="<|audio|>" * audio_count,
-            mm_data=data,
+            prompt_text="<|audio|>" * num_audios,
+            mm_data=mm_data,
         )


35 changes: 33 additions & 2 deletions vllm/multimodal/processing.py
@@ -8,9 +8,10 @@
 from typing import Any, NamedTuple, Optional, Protocol, TypeVar, Union
 
 import numpy as np
+import numpy.typing as npt
 import torch
 from blake3 import blake3
-from PIL.Image import Image
+from PIL import Image
 from transformers import BatchFeature, ProcessorMixin
 
 from vllm.inputs import DummyData, InputProcessingContext
@@ -513,7 +514,7 @@ def _serialize_item(self, obj: object) -> bytes:
             return obj.encode("utf-8")
         if isinstance(obj, bytes):
             return obj
-        if isinstance(obj, Image):
+        if isinstance(obj, Image.Image):
             return obj.tobytes()
 
         # Convertible to NumPy arrays
@@ -1007,6 +1008,36 @@ def apply(
             mm_placeholders=mm_placeholders,
         )
 
+    def _get_dummy_audios(
+        self,
+        *,
+        length: int,
+        num_audios: int,
+    ) -> list[npt.NDArray]:
+        audio = np.zeros((length, ))
+        return [audio] * num_audios
+
+    def _get_dummy_images(
+        self,
+        *,
+        width: int,
+        height: int,
+        num_images: int,
+    ) -> list[Image.Image]:
+        image = Image.new("RGB", (width, height), color=0)
+        return [image] * num_images
+
+    def _get_dummy_videos(
+        self,
+        *,
+        width: int,
+        height: int,
+        num_frames: int,
+        num_videos: int,
+    ) -> list[npt.NDArray]:
+        video = np.zeros((num_frames, width, height, 3))
+        return [video] * num_videos
+
     @abstractmethod
     def _get_dummy_mm_inputs(
         self,
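For reference, here is a minimal sketch of how a model-specific processor consumes these new base-class helpers from its _get_dummy_mm_inputs override. It mirrors the model updates above; the image size constant and the commented-out video branch are placeholders, since no model in this commit uses _get_dummy_videos yet.

    def _get_dummy_mm_inputs(
        self,
        mm_counts: Mapping[str, int],
    ) -> ProcessorInputs:
        num_images = mm_counts.get("image", 0)

        # Placeholder size; real models pass a model-specific value such as
        # vision_config.image_size (see the aria.py and blip2.py hunks above).
        max_image_size = 336

        mm_data = {
            "image":
            self._get_dummy_images(width=max_image_size,
                                   height=max_image_size,
                                   num_images=num_images),
            # A video model could call the analogous helper; this is an
            # assumption, as no model in this commit exercises it:
            # "video":
            # self._get_dummy_videos(width=max_image_size,
            #                        height=max_image_size,
            #                        num_frames=16,
            #                        num_videos=mm_counts.get("video", 0)),
        }

        return ProcessorInputs(
            prompt_text="<image>" * num_images,
            mm_data=mm_data,
        )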
