qwen2vl
Signed-off-by: Roger Wang <[email protected]>
ywang96 committed Dec 1, 2024
1 parent 39dd4f2 commit 6d0df5a
Showing 4 changed files with 111 additions and 74 deletions.
11 changes: 6 additions & 5 deletions vllm/engine/arg_utils.py
@@ -1232,12 +1232,13 @@ def _override_v1_engine_config(self, engine_config: VllmConfig) -> None:
Override the EngineConfig's configs based on the usage context for V1.
"""
assert envs.VLLM_USE_V1, "V1 is not enabled"
- # TODO (ywang96): Enable APC by default when VLM supports it.
if engine_config.model_config.is_multimodal_model:
- logger.warning(
- "Prefix caching is currently not supported for multimodal "
- "models and has been disabled.")
- engine_config.cache_config.enable_prefix_caching = False
+ # TODO (ywang96): Enable APC by default when VLM supports it.
+ assert not engine_config.cache_config.enable_prefix_caching

+ # NOTE: multimodal models support chunked prefill by design,
+ # thus always enabled in V1.
+ engine_config.scheduler_config.enable_chunked_prefill = True


@dataclass
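For context, a minimal sketch of the V1 override behavior as it reads after this change. This is a simplified stand-in: SimpleNamespace objects replace vLLM's real VllmConfig/ModelConfig/CacheConfig/SchedulerConfig, and the assert assumes prefix caching was already turned off earlier during argument parsing.

from types import SimpleNamespace

def override_v1_engine_config(engine_config) -> None:
    # Multimodal models: APC is expected to be off already; chunked prefill
    # is always enabled because these models support it by design.
    if engine_config.model_config.is_multimodal_model:
        assert not engine_config.cache_config.enable_prefix_caching
        engine_config.scheduler_config.enable_chunked_prefill = True

engine_config = SimpleNamespace(
    model_config=SimpleNamespace(is_multimodal_model=True),
    cache_config=SimpleNamespace(enable_prefix_caching=False),
    scheduler_config=SimpleNamespace(enable_chunked_prefill=False),
)
override_v1_engine_config(engine_config)
assert engine_config.scheduler_config.enable_chunked_prefill is True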
160 changes: 97 additions & 63 deletions vllm/model_executor/models/qwen2_vl.py
@@ -63,8 +63,10 @@
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.image import cached_get_image_processor
from vllm.multimodal.inputs import (MultiModalData, MultiModalDataDict,
- MultiModalKwargs, NestedTensors)
- from vllm.multimodal.utils import cached_get_tokenizer
+ MultiModalKwargs, NestedTensors,
+ PlaceholderRange)
+ from vllm.multimodal.utils import (cached_get_tokenizer,
+ consecutive_placeholder_ranges)
from vllm.platforms import _Backend
from vllm.sequence import IntermediateTensors, PoolerOutput, SequenceData
from vllm.transformers_utils.config import uses_mrope
@@ -73,7 +75,8 @@
from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
from .utils import (PPMissingLayer, get_vit_attn_backend,
is_pp_missing_parameter,
- make_empty_intermediate_tensors_factory, maybe_prefix)
+ make_empty_intermediate_tensors_factory, maybe_prefix,
+ merge_multimodal_embeddings)

logger = init_logger(__name__)

@@ -747,6 +750,7 @@ def get_max_qwen2_vl_mm_tokens(ctx: InputContext,
_get_max_image_info(image_processor, data_type_key=data_type_key,
mm_count=1, min_pixels=min_pixels,
max_pixels=max_pixels)
print("max_llm_image_tokens", max_llm_image_tokens)
return max_llm_image_tokens


@@ -803,11 +807,18 @@ def dummy_data_for_qwen2_vl(

dummy_image = Image.new("RGB", (max_resized_width, max_resized_height),
color=0)

- return DummyData(dummy_seqdata, {
+ dummy_multimodal_data = {
+ "image": dummy_image if num_images == 1 else [dummy_image] * num_images
+ }
+ size_per_image = max_llm_image_tokens // num_images
+ dummy_mm_placeholders = {
"image":
- dummy_image if num_images == 1 else [dummy_image] * num_images
- })
+ consecutive_placeholder_ranges(num_items=num_images,
+ item_size=size_per_image,
+ initial_offset=1)
+ }
+ return DummyData(dummy_seqdata, dummy_multimodal_data,
+ dummy_mm_placeholders)


def _get_llm_num_vision_tokens(
@@ -839,10 +850,11 @@ def _get_llm_num_vision_tokens(
return llm_num_vision_tokens


- def _expand_pad_tokens(inputs: list, token_id: int, make_batched_fn: Callable,
- data_type_key: str, image_processor: Any,
- prompt_token_ids: List[int], min_pixels: Optional[int],
- max_pixels: Optional[int]) -> List[int]:
+ def _expand_pad_tokens(
+ inputs: list, token_id: int, make_batched_fn: Callable,
+ data_type_key: str, image_processor: Any, prompt_token_ids: List[int],
+ min_pixels: Optional[int],
+ max_pixels: Optional[int]) -> Tuple[List[int], List[PlaceholderRange]]:
"""
Expand pad tokens for multi-modal inputs (e.g., images or videos).
@@ -858,6 +870,8 @@ def _expand_pad_tokens(inputs: list, token_id: int, make_batched_fn: Callable,
Returns:
List[int]: The list of token IDs for the multi-modal inputs.
List[PlaceholderRange]]: The list of PlaceholderRange objects with
the positions of the pad token in the prompt token ids.
"""
indices = [
idx for idx, token in enumerate(prompt_token_ids) if token == token_id
@@ -866,6 +880,7 @@ def _expand_pad_tokens(inputs: list, token_id: int, make_batched_fn: Callable,
assert len(indices) == len(inputs)

prompt_token_ids_with_data = []
placeholder_ranges = []
for cnt, data in enumerate(inputs):
num_tokens = _get_llm_num_vision_tokens(
[data] if data_type_key == "image" else data,
@@ -881,9 +896,12 @@ def _expand_pad_tokens(inputs: list, token_id: int, make_batched_fn: Callable,
non_data_tokens = prompt_token_ids[indices[cnt - 1] +
1:indices[cnt]]
prompt_token_ids_with_data.extend(non_data_tokens)
placeholder_ranges.append(
PlaceholderRange(offset=len(prompt_token_ids_with_data),
length=num_tokens))
prompt_token_ids_with_data.extend(token_id for _ in range(num_tokens))
prompt_token_ids_with_data.extend(prompt_token_ids[indices[-1] + 1:])
- return prompt_token_ids_with_data
+ return prompt_token_ids_with_data, placeholder_ranges
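As a standalone illustration of the pattern above, here is a simplified sketch of pad-token expansion with placeholder bookkeeping. It is not the vLLM implementation: it uses fixed per-item token counts instead of the image-processor math, and a local dataclass stands in for vLLM's PlaceholderRange.

from dataclasses import dataclass
from typing import List, Tuple

@dataclass
class PlaceholderRange:
    # Local stand-in for vllm.multimodal.inputs.PlaceholderRange.
    offset: int
    length: int

def expand_pad_tokens(prompt_token_ids: List[int], token_id: int,
                      tokens_per_item: List[int]
                      ) -> Tuple[List[int], List[PlaceholderRange]]:
    # Replace each occurrence of `token_id` with N copies of itself and record
    # where each expanded block starts in the new token list.
    indices = [i for i, t in enumerate(prompt_token_ids) if t == token_id]
    assert len(indices) == len(tokens_per_item)
    expanded: List[int] = []
    ranges: List[PlaceholderRange] = []
    prev = 0
    for idx, num_tokens in zip(indices, tokens_per_item):
        expanded.extend(prompt_token_ids[prev:idx])   # tokens before the pad
        ranges.append(PlaceholderRange(offset=len(expanded), length=num_tokens))
        expanded.extend([token_id] * num_tokens)      # the expanded pad block
        prev = idx + 1
    expanded.extend(prompt_token_ids[prev:])
    return expanded, ranges

# One image pad token (id 99) expanded to 4 tokens:
ids, ranges = expand_pad_tokens([1, 2, 99, 3], token_id=99, tokens_per_item=[4])
assert ids == [1, 2, 99, 99, 99, 99, 3]
assert ranges == [PlaceholderRange(offset=2, length=4)]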


def input_processor_for_qwen2_vl(
@@ -929,7 +947,7 @@ def input_processor_for_qwen2_vl(
prompt_token_ids = inputs["prompt_token_ids"]

# Expand image pad tokens.

multi_modal_placeholders = {}
if image_inputs is not None:
if isinstance(image_inputs, dict):
prompt_token_ids_with_image = []
@@ -945,13 +963,18 @@ def input_processor_for_qwen2_vl(

image_counter = 0
pad_token_counter = 0
placeholder_ranges = []
for idx, token in enumerate(prompt_token_ids):
if idx in image_indices:
grid_thw = image_inputs["image_grid_thw"][image_counter]
grid_t, grid_h, grid_w = grid_thw
num_pad_tokens = (grid_t * grid_h * grid_w //
image_processor.merge_size //
image_processor.merge_size)
placeholder_ranges.append(
PlaceholderRange(
offset=len(prompt_token_ids_with_image),
length=num_pad_tokens))
prompt_token_ids_with_image.extend([token] *
num_pad_tokens)
image_counter += 1
@@ -966,14 +989,17 @@

prompt_token_ids = prompt_token_ids_with_image
else:
- prompt_token_ids = _expand_pad_tokens(image_inputs,
- hf_config.image_token_id,
- make_batched_images,
- "image",
- image_processor,
- prompt_token_ids,
- min_pixels=min_pixels,
- max_pixels=max_pixels)
+ prompt_token_ids, placeholder_ranges = _expand_pad_tokens(
+ image_inputs,
+ hf_config.image_token_id,
+ make_batched_images,
+ "image",
+ image_processor,
+ prompt_token_ids,
+ min_pixels=min_pixels,
+ max_pixels=max_pixels)

+ multi_modal_placeholders["image"] = placeholder_ranges

if video_inputs is not None:
if isinstance(video_inputs, dict):
Expand All @@ -990,13 +1016,18 @@ def input_processor_for_qwen2_vl(

video_counter = 0
pad_token_counter = 0
placeholder_ranges = []
for idx, token in enumerate(prompt_token_ids):
if idx in video_indices:
grid_thw = video_inputs["video_grid_thw"][video_counter]
grid_t, grid_h, grid_w = grid_thw
num_pad_tokens = (grid_t * grid_h * grid_w //
image_processor.merge_size //
image_processor.merge_size)
placeholder_ranges.append(
PlaceholderRange(
offset=len(prompt_token_ids_with_image),
length=num_pad_tokens))
prompt_token_ids_with_video.extend([token] *
num_pad_tokens)
video_counter += 1
@@ -1011,14 +1042,17 @@

prompt_token_ids = prompt_token_ids_with_video
else:
- prompt_token_ids = _expand_pad_tokens(video_inputs,
- hf_config.video_token_id,
- make_batched_videos,
- "video",
- image_processor,
- prompt_token_ids,
- min_pixels=min_pixels,
- max_pixels=max_pixels)
+ prompt_token_ids, placeholder_ranges = _expand_pad_tokens(
+ video_inputs,
+ hf_config.video_token_id,
+ make_batched_videos,
+ "video",
+ image_processor,
+ prompt_token_ids,
+ min_pixels=min_pixels,
+ max_pixels=max_pixels)

+ multi_modal_placeholders["video"] = placeholder_ranges

prompt = inputs.get("prompt")
if prompt is None:
@@ -1028,6 +1062,7 @@
prompt_token_ids=prompt_token_ids,
prompt=prompt,
multi_modal_data=multi_modal_data,
multi_modal_placeholders=multi_modal_placeholders,
)


@@ -1214,6 +1249,14 @@ def _process_image_input(self,
pixel_values = image_input["pixel_values"].type(self.visual.dtype)
image_embeds = self.visual(pixel_values,
grid_thw=image_input["image_grid_thw"])

# Use grid information to get embedding sizes of each data item
merge_size = self.config.vision_config.spatial_merge_size
image_grids = [
torch.prod(image_grid) // merge_size // merge_size
for image_grid in image_input["image_grid_thw"]
]
image_embeds = image_embeds.split(image_grids)
return image_embeds
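A small, self-contained sketch of the grid-based split introduced above, assuming Qwen2-VL-style (t, h, w) grids, a spatial merge size of 2, and an illustrative hidden size of 1536 (the concrete numbers are assumptions for demonstration only):

import torch

merge_size = 2  # spatial_merge_size from the vision config
image_grid_thw = torch.tensor([[1, 4, 4], [1, 8, 4]])  # grids for two images

# Embedding rows per image after 2x2 spatial merging: t * h * w // merge_size**2
sizes = [int(torch.prod(grid)) // merge_size // merge_size for grid in image_grid_thw]
assert sizes == [4, 8]

# The vision tower returns one concatenated tensor; splitting it by the
# per-image sizes yields one embedding tensor per data item.
image_embeds = torch.randn(sum(sizes), 1536)
per_image = image_embeds.split(sizes)
assert [chunk.shape[0] for chunk in per_image] == sizes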

def _process_video_input(self,
@@ -1225,18 +1268,15 @@ def _process_video_input(self,
self.visual.dtype)
video_embeds = self.visual(pixel_values_videos,
grid_thw=video_input["video_grid_thw"])
- return video_embeds

- def _merge_multimodal_embeddings(
- self,
- input_ids: torch.Tensor,
- inputs_embeds: torch.Tensor,
- multimodal_embeddings: torch.Tensor,
- placeholder_token_id: int,
- ) -> torch.Tensor:
- mask = (input_ids == placeholder_token_id)
- inputs_embeds[mask, :] = multimodal_embeddings
- return inputs_embeds
+ # Use grid information to get embedding sizes of each data item
+ merge_size = self.config.vision_config.spatial_merge_size
+ video_grids = [
+ torch.prod(video_grid) // merge_size // merge_size
+ for video_grid in video_input["video_grid_thw"]
+ ]
+ video_embeds = video_embeds.split(video_grids)
+ return video_embeds

def get_multimodal_embeddings(
self, **kwargs) -> Optional[List[Tuple[NestedTensors, str]]]:
@@ -1246,16 +1286,15 @@ def get_multimodal_embeddings(
if image_input is None and video_input is None:
return None

- # We make a tuple of each embedding with its modality string. This is a
- # temporary workaround for models to handle mixed modalities when
- # get_multimodal_embeddings and get_input_embeddings are called
- # separately.
- # TODO(ywang96): Add support for mixed-modality inference for v1.
- multimodal_embeddings: List[Tuple[NestedTensors, str]] = []

if image_input is not None:
image_embeds = self._process_image_input(image_input)
- multimodal_embeddings.append((image_embeds, "image"))
+ return image_embeds

+ # We add a modality key along with the Nested tensor as a
+ # temporary solution to differentiate embeddings from modalities
+ # other than `image`.
+ # TODO(ywang96): Add support for mixed-modality inference for v1.
+ multimodal_embeddings: List[Tuple[NestedTensors, str]] = []
if video_input is not None:
video_embeds = self._process_video_input(video_input)
multimodal_embeddings.append((video_embeds, "video"))
@@ -1270,21 +1309,16 @@ def get_input_embeddings(
) -> torch.Tensor:
inputs_embeds = self.model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
- for embeddings, modality in multimodal_embeddings:
- if modality == "image":
- inputs_embeds = self._merge_multimodal_embeddings(
- input_ids,
- inputs_embeds,
- embeddings,
- placeholder_token_id=self.config.image_token_id,
- )
- if modality == "video":
- inputs_embeds = self._merge_multimodal_embeddings(
- input_ids,
- inputs_embeds,
- embeddings,
- placeholder_token_id=self.config.video_token_id,
- )
+ if len(multimodal_embeddings[0]) == 2:
+ for embeddings, modality in multimodal_embeddings:
+ if modality == "video":
+ inputs_embeds = merge_multimodal_embeddings(
+ input_ids, inputs_embeds, multimodal_embeddings,
+ self.config.video_token_id)
+ else:
+ inputs_embeds = merge_multimodal_embeddings(
+ input_ids, inputs_embeds, multimodal_embeddings,
+ self.config.image_token_id)
return inputs_embeds

def forward(
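The helper removed above makes the shape of the replacement clear. Below is a minimal sketch of that masked-scatter merge, written as a simplified stand-in for the merge_multimodal_embeddings utility now imported from .utils; the real function handles nested tensors and error checking, and the signature here is an assumption for illustration:

import torch

def merge_multimodal_embeddings(input_ids: torch.Tensor,
                                inputs_embeds: torch.Tensor,
                                multimodal_embeddings: torch.Tensor,
                                placeholder_token_id: int) -> torch.Tensor:
    # Rows whose token id equals the placeholder id receive the multimodal
    # embeddings; text-token rows are left untouched.
    mask = input_ids == placeholder_token_id
    inputs_embeds[mask, :] = multimodal_embeddings
    return inputs_embeds

hidden_size = 8
input_ids = torch.tensor([5, 99, 99, 7])      # 99 is the placeholder token id
inputs_embeds = torch.zeros(4, hidden_size)
mm_embeds = torch.ones(2, hidden_size)        # one row per placeholder position
out = merge_multimodal_embeddings(input_ids, inputs_embeds, mm_embeds, 99)
assert out[1].sum() == hidden_size and out[0].sum() == 0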
10 changes: 6 additions & 4 deletions vllm/multimodal/utils.py
@@ -535,11 +535,13 @@ def repeat_and_pad_placeholder_tokens(
return new_prompt, new_token_ids, placeholder_ranges


- def consecutive_placeholder_ranges(num_items: int,
- item_size: int) -> List[PlaceholderRange]:
+ def consecutive_placeholder_ranges(
+ num_items: int,
+ item_size: int,
+ initial_offset: int = 0) -> List[PlaceholderRange]:
"""Returns a list of consecutive PlaceholderRanges of a fixed size"""

return [
- PlaceholderRange(offset=i * item_size, length=item_size)
- for i in range(num_items)
+ PlaceholderRange(offset=initial_offset + i * item_size,
+ length=item_size) for i in range(num_items)
]
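Given the definition above, a short usage sketch of the new initial_offset parameter; the expected values follow directly from the function body, and the import path is the one used elsewhere in this diff:

from vllm.multimodal.utils import consecutive_placeholder_ranges

# Two items of 100 placeholder tokens each, starting after one leading token:
ranges = consecutive_placeholder_ranges(num_items=2, item_size=100, initial_offset=1)
# -> [PlaceholderRange(offset=1, length=100), PlaceholderRange(offset=101, length=100)]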
4 changes: 2 additions & 2 deletions vllm/v1/core/scheduler.py
@@ -73,12 +73,12 @@ def __init__(
# has the Transformer architecture (e.g., ViT).
# FIXME(woosuk): Below are placeholder values. We need to calculate the
# actual values from the configurations.
- self.max_num_encoder_input_tokens = 8192
+ self.max_num_encoder_input_tokens = 16384
# NOTE(woosuk): For the models without encoder (e.g., text-only models),
# the encoder cache will not be initialized and used, regardless of
# the cache size. This is because the memory space for the encoder cache
# is preallocated in the profiling run.
- self.encoder_cache_manager = EncoderCacheManager(cache_size=8192)
+ self.encoder_cache_manager = EncoderCacheManager(cache_size=16384)

def schedule(self) -> "SchedulerOutput":
# NOTE(woosuk) on the scheduling algorithm:
