diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index f0020562c3c3a..c605d02a36b70 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -1232,12 +1232,13 @@ def _override_v1_engine_config(self, engine_config: VllmConfig) -> None:
         Override the EngineConfig's configs based on the usage context for V1.
         """
         assert envs.VLLM_USE_V1, "V1 is not enabled"
-        # TODO (ywang96): Enable APC by default when VLM supports it.
         if engine_config.model_config.is_multimodal_model:
-            logger.warning(
-                "Prefix caching is currently not supported for multimodal "
-                "models and has been disabled.")
-            engine_config.cache_config.enable_prefix_caching = False
+            # TODO (ywang96): Enable APC by default when VLM supports it.
+            assert not engine_config.cache_config.enable_prefix_caching
+
+            # NOTE: Multimodal models support chunked prefill by design;
+            # it is therefore always enabled in V1.
+            engine_config.scheduler_config.enable_chunked_prefill = True
 
 
 @dataclass
diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py
index 7956a98b21569..4eb13081387fd 100644
--- a/vllm/model_executor/models/qwen2_vl.py
+++ b/vllm/model_executor/models/qwen2_vl.py
@@ -63,8 +63,10 @@
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.image import cached_get_image_processor
 from vllm.multimodal.inputs import (MultiModalData, MultiModalDataDict,
-                                    MultiModalKwargs, NestedTensors)
-from vllm.multimodal.utils import cached_get_tokenizer
+                                    MultiModalKwargs, NestedTensors,
+                                    PlaceholderRange)
+from vllm.multimodal.utils import (cached_get_tokenizer,
+                                   consecutive_placeholder_ranges)
 from vllm.platforms import _Backend
 from vllm.sequence import IntermediateTensors, PoolerOutput, SequenceData
 from vllm.transformers_utils.config import uses_mrope
@@ -73,7 +75,8 @@
 from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
 from .utils import (PPMissingLayer, get_vit_attn_backend,
                     is_pp_missing_parameter,
-                    make_empty_intermediate_tensors_factory, maybe_prefix)
+                    make_empty_intermediate_tensors_factory, maybe_prefix,
+                    merge_multimodal_embeddings)
 
 logger = init_logger(__name__)
 
@@ -803,11 +807,18 @@ def dummy_data_for_qwen2_vl(
 
     dummy_image = Image.new("RGB", (max_resized_width, max_resized_height),
                             color=0)
-
-    return DummyData(dummy_seqdata, {
+    dummy_multimodal_data = {
+        "image": dummy_image if num_images == 1 else [dummy_image] * num_images
+    }
+    size_per_image = max_llm_image_tokens // num_images
+    dummy_mm_placeholders = {
         "image":
-        dummy_image if num_images == 1 else [dummy_image] * num_images
-    })
+        consecutive_placeholder_ranges(num_items=num_images,
+                                       item_size=size_per_image,
+                                       initial_offset=1)
+    }
+    return DummyData(dummy_seqdata, dummy_multimodal_data,
+                     dummy_mm_placeholders)
 
 
 def _get_llm_num_vision_tokens(
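For intuition, the dummy placeholders above are packed back to back, with
`initial_offset=1` leaving position 0 for a non-placeholder token. A minimal
sketch with made-up numbers (not taken from any real model config):

    # Assume max_llm_image_tokens = 1024 and num_images = 4.
    size_per_image = 1024 // 4  # 256 placeholder positions per dummy image
    # consecutive_placeholder_ranges(num_items=4, item_size=256,
    #                                initial_offset=1) then yields ranges at
    # offsets 1, 257, 513 and 769, each of length 256.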
@@ -839,10 +850,11 @@ def _get_llm_num_vision_tokens(
     return llm_num_vision_tokens
 
 
-def _expand_pad_tokens(inputs: list, token_id: int, make_batched_fn: Callable,
-                       data_type_key: str, image_processor: Any,
-                       prompt_token_ids: List[int], min_pixels: Optional[int],
-                       max_pixels: Optional[int]) -> List[int]:
+def _expand_pad_tokens(
+        inputs: list, token_id: int, make_batched_fn: Callable,
+        data_type_key: str, image_processor: Any, prompt_token_ids: List[int],
+        min_pixels: Optional[int],
+        max_pixels: Optional[int]) -> Tuple[List[int], List[PlaceholderRange]]:
     """
     Expand pad tokens for multi-modal inputs (e.g., images or videos).
@@ -858,6 +870,8 @@ def _expand_pad_tokens(inputs: list, token_id: int, make_batched_fn: Callable,
 
     Returns:
         List[int]: The list of token IDs for the multi-modal inputs.
+        List[PlaceholderRange]: The list of PlaceholderRange objects with
+            the positions of the pad tokens in the prompt token ids.
     """
     indices = [
         idx for idx, token in enumerate(prompt_token_ids) if token == token_id
@@ -866,6 +880,7 @@ def _expand_pad_tokens(inputs: list, token_id: int, make_batched_fn: Callable,
     ]
     assert len(indices) == len(inputs)
 
     prompt_token_ids_with_data = []
+    placeholder_ranges = []
     for cnt, data in enumerate(inputs):
         num_tokens = _get_llm_num_vision_tokens(
             [data] if data_type_key == "image" else data,
@@ -881,9 +896,12 @@ def _expand_pad_tokens(inputs: list, token_id: int, make_batched_fn: Callable,
             non_data_tokens = prompt_token_ids[indices[cnt - 1] +
                                                1:indices[cnt]]
             prompt_token_ids_with_data.extend(non_data_tokens)
+        placeholder_ranges.append(
+            PlaceholderRange(offset=len(prompt_token_ids_with_data),
+                             length=num_tokens))
         prompt_token_ids_with_data.extend(token_id for _ in range(num_tokens))
     prompt_token_ids_with_data.extend(prompt_token_ids[indices[-1] + 1:])
-    return prompt_token_ids_with_data
+    return prompt_token_ids_with_data, placeholder_ranges
 
 
 def input_processor_for_qwen2_vl(
@@ -929,7 +947,7 @@ def input_processor_for_qwen2_vl(
     prompt_token_ids = inputs["prompt_token_ids"]
 
     # Expand image pad tokens.
-
+    multi_modal_placeholders = {}
     if image_inputs is not None:
         if isinstance(image_inputs, dict):
             prompt_token_ids_with_image = []
@@ -945,6 +963,7 @@ def input_processor_for_qwen2_vl(
 
             image_counter = 0
             pad_token_counter = 0
+            placeholder_ranges = []
             for idx, token in enumerate(prompt_token_ids):
                 if idx in image_indices:
                     grid_thw = image_inputs["image_grid_thw"][image_counter]
@@ -952,6 +971,10 @@ def input_processor_for_qwen2_vl(
                     grid_t, grid_h, grid_w = grid_thw
                     num_pad_tokens = (grid_t * grid_h * grid_w //
                                       image_processor.merge_size //
                                       image_processor.merge_size)
+                    placeholder_ranges.append(
+                        PlaceholderRange(
+                            offset=len(prompt_token_ids_with_image),
+                            length=num_pad_tokens))
                     prompt_token_ids_with_image.extend([token] * num_pad_tokens)
                     image_counter += 1
@@ -966,14 +989,17 @@ def input_processor_for_qwen2_vl(
 
             prompt_token_ids = prompt_token_ids_with_image
         else:
-            prompt_token_ids = _expand_pad_tokens(image_inputs,
-                                                  hf_config.image_token_id,
-                                                  make_batched_images,
-                                                  "image",
-                                                  image_processor,
-                                                  prompt_token_ids,
-                                                  min_pixels=min_pixels,
-                                                  max_pixels=max_pixels)
+            prompt_token_ids, placeholder_ranges = _expand_pad_tokens(
+                image_inputs,
+                hf_config.image_token_id,
+                make_batched_images,
+                "image",
+                image_processor,
+                prompt_token_ids,
+                min_pixels=min_pixels,
+                max_pixels=max_pixels)
+
+        multi_modal_placeholders["image"] = placeholder_ranges
 
     if video_inputs is not None:
         if isinstance(video_inputs, dict):
             prompt_token_ids_with_video = []
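The placeholder bookkeeping above follows one invariant: a range's offset is
the length of the expanded prompt at the moment its pad block is appended. A
toy stand-in for the pattern (illustrative only, not vLLM code):

    from typing import List, Tuple

    def expand(prompt: List[int], pad_id: int, sizes: List[int]):
        """Expand each occurrence of pad_id to sizes[i] copies of itself."""
        out: List[int] = []
        ranges: List[Tuple[int, int]] = []  # (offset, length) pairs
        i = 0
        for tok in prompt:
            if tok == pad_id:
                ranges.append((len(out), sizes[i]))
                out.extend([pad_id] * sizes[i])
                i += 1
            else:
                out.append(tok)
        return out, ranges

    # expand([1, 9, 2, 9, 3], pad_id=9, sizes=[2, 3]) returns
    # tokens [1, 9, 9, 2, 9, 9, 9, 3] and ranges [(1, 2), (4, 3)].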
@@ -990,6 +1016,7 @@ def input_processor_for_qwen2_vl(
 
             video_counter = 0
             pad_token_counter = 0
+            placeholder_ranges = []
             for idx, token in enumerate(prompt_token_ids):
                 if idx in video_indices:
                     grid_thw = video_inputs["video_grid_thw"][video_counter]
@@ -997,6 +1024,10 @@ def input_processor_for_qwen2_vl(
                     grid_t, grid_h, grid_w = grid_thw
                     num_pad_tokens = (grid_t * grid_h * grid_w //
                                       image_processor.merge_size //
                                       image_processor.merge_size)
+                    placeholder_ranges.append(
+                        PlaceholderRange(
+                            offset=len(prompt_token_ids_with_video),
+                            length=num_pad_tokens))
                     prompt_token_ids_with_video.extend([token] * num_pad_tokens)
                     video_counter += 1
@@ -1011,14 +1042,17 @@ def input_processor_for_qwen2_vl(
 
             prompt_token_ids = prompt_token_ids_with_video
         else:
-            prompt_token_ids = _expand_pad_tokens(video_inputs,
-                                                  hf_config.video_token_id,
-                                                  make_batched_videos,
-                                                  "video",
-                                                  image_processor,
-                                                  prompt_token_ids,
-                                                  min_pixels=min_pixels,
-                                                  max_pixels=max_pixels)
+            prompt_token_ids, placeholder_ranges = _expand_pad_tokens(
+                video_inputs,
+                hf_config.video_token_id,
+                make_batched_videos,
+                "video",
+                image_processor,
+                prompt_token_ids,
+                min_pixels=min_pixels,
+                max_pixels=max_pixels)
+
+        multi_modal_placeholders["video"] = placeholder_ranges
 
     prompt = inputs.get("prompt")
     if prompt is None:
@@ -1028,6 +1062,7 @@ def input_processor_for_qwen2_vl(
         prompt_token_ids=prompt_token_ids,
         prompt=prompt,
         multi_modal_data=multi_modal_data,
+        multi_modal_placeholders=multi_modal_placeholders,
     )
 
 
@@ -1214,6 +1249,14 @@ def _process_image_input(self,
         pixel_values = image_input["pixel_values"].type(self.visual.dtype)
         image_embeds = self.visual(pixel_values,
                                    grid_thw=image_input["image_grid_thw"])
+
+        # Use grid information to get embedding sizes of each data item.
+        merge_size = self.config.vision_config.spatial_merge_size
+        image_grids = [
+            torch.prod(image_grid) // merge_size // merge_size
+            for image_grid in image_input["image_grid_thw"]
+        ]
+        image_embeds = image_embeds.split(image_grids)
         return image_embeds
 
     def _process_video_input(self,
@@ -1225,18 +1268,15 @@ def _process_video_input(self,
                                         self.visual.dtype)
         video_embeds = self.visual(pixel_values_videos,
                                    grid_thw=video_input["video_grid_thw"])
-        return video_embeds
 
-    def _merge_multimodal_embeddings(
-        self,
-        input_ids: torch.Tensor,
-        inputs_embeds: torch.Tensor,
-        multimodal_embeddings: torch.Tensor,
-        placeholder_token_id: int,
-    ) -> torch.Tensor:
-        mask = (input_ids == placeholder_token_id)
-        inputs_embeds[mask, :] = multimodal_embeddings
-        return inputs_embeds
+        # Use grid information to get embedding sizes of each data item.
+        merge_size = self.config.vision_config.spatial_merge_size
+        video_grids = [
+            torch.prod(video_grid) // merge_size // merge_size
+            for video_grid in video_input["video_grid_thw"]
+        ]
+        video_embeds = video_embeds.split(video_grids)
+        return video_embeds
 
     def get_multimodal_embeddings(
             self, **kwargs) -> Optional[List[Tuple[NestedTensors, str]]]:
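The per-item split above depends only on `grid_thw` and
`spatial_merge_size`. A minimal sketch with made-up grid and hidden sizes
(not from any real config):

    import torch

    # Two images whose (t, h, w) patch grids are (1, 16, 16) and (1, 8, 8),
    # with spatial_merge_size = 2 and a hypothetical hidden size of 1536.
    grid_thw = torch.tensor([[1, 16, 16], [1, 8, 8]])
    merge_size = 2

    sizes = (grid_thw.prod(-1) // (merge_size * merge_size)).tolist()  # [64, 16]
    embeds = torch.randn(sum(sizes), 1536)  # flattened encoder output

    per_item = embeds.split(sizes)  # shapes: (64, 1536) and (16, 1536)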
@@ -1246,16 +1286,15 @@ def get_multimodal_embeddings(
 
         if image_input is None and video_input is None:
             return None
 
-        # We make a tuple of each embedding with its modality string. This is a
-        # temporary workaround for models to handle mixed modalities when
-        # get_multimodal_embeddings and get_input_embeddings are called
-        # separately.
-        # TODO(ywang96): Add support for mixed-modality inference for v1.
-        multimodal_embeddings: List[Tuple[NestedTensors, str]] = []
-
         if image_input is not None:
             image_embeds = self._process_image_input(image_input)
-            multimodal_embeddings.append((image_embeds, "image"))
+            return image_embeds
+
+        # We add a modality key along with the NestedTensors as a
+        # temporary solution to differentiate embeddings from modalities
+        # other than `image`.
+        # TODO(ywang96): Add support for mixed-modality inference for v1.
+        multimodal_embeddings: List[Tuple[NestedTensors, str]] = []
         if video_input is not None:
             video_embeds = self._process_video_input(video_input)
             multimodal_embeddings.append((video_embeds, "video"))
@@ -1270,21 +1309,16 @@ def get_input_embeddings(
     ) -> torch.Tensor:
         inputs_embeds = self.model.get_input_embeddings(input_ids)
         if multimodal_embeddings is not None:
-            for embeddings, modality in multimodal_embeddings:
-                if modality == "image":
-                    inputs_embeds = self._merge_multimodal_embeddings(
-                        input_ids,
-                        inputs_embeds,
-                        embeddings,
-                        placeholder_token_id=self.config.image_token_id,
-                    )
-                if modality == "video":
-                    inputs_embeds = self._merge_multimodal_embeddings(
-                        input_ids,
-                        inputs_embeds,
-                        embeddings,
-                        placeholder_token_id=self.config.video_token_id,
-                    )
+            if len(multimodal_embeddings[0]) == 2:
+                for embeddings, modality in multimodal_embeddings:
+                    if modality == "video":
+                        inputs_embeds = merge_multimodal_embeddings(
+                            input_ids, inputs_embeds, embeddings,
+                            self.config.video_token_id)
+            else:
+                inputs_embeds = merge_multimodal_embeddings(
+                    input_ids, inputs_embeds, multimodal_embeddings,
+                    self.config.image_token_id)
         return inputs_embeds
 
     def forward(
diff --git a/vllm/multimodal/utils.py b/vllm/multimodal/utils.py
index d4333b7519b47..c898ca4e6573e 100644
--- a/vllm/multimodal/utils.py
+++ b/vllm/multimodal/utils.py
@@ -535,11 +535,13 @@ def repeat_and_pad_placeholder_tokens(
     return new_prompt, new_token_ids, placeholder_ranges
 
 
-def consecutive_placeholder_ranges(num_items: int,
-                                   item_size: int) -> List[PlaceholderRange]:
+def consecutive_placeholder_ranges(
+        num_items: int,
+        item_size: int,
+        initial_offset: int = 0) -> List[PlaceholderRange]:
     """Returns a list of consecutive PlaceholderRanges of a fixed size"""
     return [
-        PlaceholderRange(offset=i * item_size, length=item_size)
-        for i in range(num_items)
+        PlaceholderRange(offset=initial_offset + i * item_size,
+                         length=item_size) for i in range(num_items)
     ]
diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py
index f8375cea2a24e..1203d35fc985f 100644
--- a/vllm/v1/core/scheduler.py
+++ b/vllm/v1/core/scheduler.py
@@ -73,12 +73,12 @@ def __init__(
         # has the Transformer architecture (e.g., ViT).
         # FIXME(woosuk): Below are placeholder values. We need to calculate the
         # actual values from the configurations.
-        self.max_num_encoder_input_tokens = 8192
+        self.max_num_encoder_input_tokens = 16384
         # NOTE(woosuk): For the models without encoder (e.g., text-only models),
         # the encoder cache will not be initialized and used, regardless of
         # the cache size. This is because the memory space for the encoder cache
         # is preallocated in the profiling run.
-        self.encoder_cache_manager = EncoderCacheManager(cache_size=8192)
+        self.encoder_cache_manager = EncoderCacheManager(cache_size=16384)
 
     def schedule(self) -> "SchedulerOutput":
         # NOTE(woosuk) on the scheduling algorithm:
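For reference, a quick usage sketch of the extended helper (values are
illustrative):

    from vllm.multimodal.utils import consecutive_placeholder_ranges

    # Two items of 3 placeholder tokens each, shifted right by one position:
    ranges = consecutive_placeholder_ranges(num_items=2,
                                            item_size=3,
                                            initial_offset=1)
    # -> ranges at offsets 1 and 4, each of length 3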