diff --git a/vllm/model_executor/models/blip2.py b/vllm/model_executor/models/blip2.py index 7d7639b4a92ce..d2592016aff34 100644 --- a/vllm/model_executor/models/blip2.py +++ b/vllm/model_executor/models/blip2.py @@ -16,6 +16,7 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import NestedTensors from vllm.multimodal.utils import consecutive_placeholder_ranges from vllm.sequence import IntermediateTensors, SequenceData @@ -609,6 +610,25 @@ def _process_image_input(self, return self.language_projection(query_output) + def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]: + image_input = self._parse_and_validate_image_input(**kwargs) + if image_input is None: + return None + vision_embeddings = self._process_image_input(image_input) + return vision_embeddings + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[NestedTensors] = None, + ) -> torch.Tensor: + inputs_embeds = self.language_model.get_input_embeddings(input_ids) + if multimodal_embeddings is not None: + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, multimodal_embeddings, + BLIP2_IMAGE_TOKEN_ID) + return inputs_embeds + def forward( self, input_ids: torch.Tensor, @@ -616,6 +636,7 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, **kwargs: object, ) -> Union[SamplerOutput, IntermediateTensors]: """Run forward pass for BLIP-2. @@ -648,32 +669,24 @@ def forward( See also: :class:`Blip2ImageInputs` """ + if intermediate_tensors is not None: - input_ids = None inputs_embeds = None - else: - image_input = self._parse_and_validate_image_input(**kwargs) - - if image_input is not None: - vision_embeddings = self._process_image_input(image_input) - inputs_embeds = self.language_model.model.get_input_embeddings( - input_ids) - - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, vision_embeddings, - BLIP2_IMAGE_TOKEN_ID) - - input_ids = None - else: - inputs_embeds = None - - hidden_states = self.language_model.model( - input_ids, - positions, - kv_caches, - attn_metadata, - intermediate_tensors=intermediate_tensors, - inputs_embeds=inputs_embeds) + + # NOTE: In v1, inputs_embeds is always generated at model runner, this + # condition is for v0 compatibility. 
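Annotation: a minimal sketch of what the placeholder-token merge used by the new BLIP-2 get_input_embeddings does (the real helper is merge_multimodal_embeddings in vllm/model_executor/models/utils.py). The token id, shapes and tensors below are invented for illustration only.

import torch

def merge_at_placeholders(input_ids, inputs_embeds, mm_embeds, placeholder_id):
    # Rows at placeholder positions are overwritten by the multimodal rows,
    # in order; the counts must match.
    mask = input_ids == placeholder_id
    assert int(mask.sum()) == mm_embeds.shape[0]
    out = inputs_embeds.clone()
    out[mask] = mm_embeds.to(dtype=inputs_embeds.dtype)
    return out

hidden = 8
image_token_id = 99                       # stand-in for BLIP2_IMAGE_TOKEN_ID
input_ids = torch.tensor([5, 99, 99, 99, 6])
text_embeds = torch.randn(5, hidden)      # language_model.get_input_embeddings(input_ids)
query_output = torch.randn(3, hidden)     # projected Q-Former output
print(merge_at_placeholders(input_ids, text_embeds, query_output, image_token_id).shape)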
+ elif inputs_embeds is None: + vision_embeddings = self.get_multimodal_embeddings(**kwargs) + inputs_embeds = self.get_input_embeddings(input_ids, + vision_embeddings) + input_ids = None + + hidden_states = self.language_model.model(input_ids, + positions, + kv_caches, + attn_metadata, + intermediate_tensors, + inputs_embeds=inputs_embeds) return hidden_states diff --git a/vllm/model_executor/models/chameleon.py b/vllm/model_executor/models/chameleon.py index 5a6d6432112f0..a40c321ce0a58 100644 --- a/vllm/model_executor/models/chameleon.py +++ b/vllm/model_executor/models/chameleon.py @@ -29,6 +29,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_weight_attrs from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import NestedTensors from vllm.multimodal.utils import (cached_get_tokenizer, consecutive_placeholder_ranges, repeat_and_pad_placeholder_tokens) @@ -38,7 +39,7 @@ from .interfaces import SupportsMultiModal, SupportsPP from .utils import (is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) + maybe_prefix, merge_multimodal_embeddings) # These configs are not part of the model config but the preprocessor # and processor files, so we hardcode them in the model file for now. @@ -987,6 +988,29 @@ def _parse_and_validate_image_input( data=self._validate_pixel_values(pixel_values), ) + def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]: + image_input = self._parse_and_validate_image_input(**kwargs) + if image_input is None: + return None + assert self.model.vqmodel is not None + image_tokens = self.model.get_image_tokens(image_input["data"].to( + self.config.torch_dtype)) + vision_embeddings = self.model.get_input_embeddings(image_tokens) + return vision_embeddings + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[NestedTensors] = None, + ) -> torch.Tensor: + + inputs_embeds = self.model.get_input_embeddings(input_ids) + if multimodal_embeddings is not None: + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, multimodal_embeddings, + self.model.vocabulary_mapping.image_token_id) + return inputs_embeds + def forward( self, input_ids: torch.Tensor, @@ -994,27 +1018,27 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, **kwargs, ) -> Union[torch.Tensor, IntermediateTensors]: if intermediate_tensors is not None: + inputs_embeds = None + + # NOTE: In v1, inputs_embeds is always generated at model runner, this + # condition is for v0 compatibility. 
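Annotation: for the Chameleon path factored out above, the image is quantized into discrete codes by the VQ model and those codes go through the same embedding table as text tokens before being merged at the image-token positions. A toy version with an ordinary nn.Embedding and invented ids:

import torch
import torch.nn as nn

vocab_size, hidden = 32, 4
embed = nn.Embedding(vocab_size, hidden)         # shared text/image embedding table
image_token_id = 7                               # placeholder id in the prompt

with torch.no_grad():
    image_codes = torch.tensor([11, 12, 13])     # pretend VQ output (get_image_tokens)
    vision_embeddings = embed(image_codes)       # get_multimodal_embeddings analogue
    input_ids = torch.tensor([1, 7, 7, 7, 2])
    inputs_embeds = embed(input_ids)             # text side of get_input_embeddings
    inputs_embeds[input_ids == image_token_id] = vision_embeddings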
+ elif inputs_embeds is None: + vision_embeddings = self.get_multimodal_embeddings(**kwargs) + inputs_embeds = self.get_input_embeddings(input_ids, + vision_embeddings) input_ids = None - else: - image_input = self._parse_and_validate_image_input(**kwargs) - - if image_input is not None: - assert self.model.vqmodel is not None - image_tokens = self.model.get_image_tokens( - image_input["data"].to(self.config.torch_dtype)) - image_token_id = self.model.vocabulary_mapping.image_token_id - special_image_mask = input_ids == image_token_id - image_tokens = image_tokens.to(input_ids.device, - input_ids.dtype) - input_ids = input_ids.masked_scatter(special_image_mask, - image_tokens) - - hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + + hidden_states = self.model(input_ids, + positions, + kv_caches, + attn_metadata, + intermediate_tensors, + inputs_embeds=inputs_embeds) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py index 5bcbce7180ca4..6c50882d83c3b 100644 --- a/vllm/model_executor/models/chatglm.py +++ b/vllm/model_executor/models/chatglm.py @@ -33,7 +33,8 @@ from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import MultiModalData, MultiModalKwargs +from vllm.multimodal.inputs import (MultiModalData, MultiModalKwargs, + NestedTensors) from vllm.multimodal.utils import cached_get_tokenizer from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors, SequenceData) @@ -545,6 +546,30 @@ def _parse_and_validate_image_input( """) return GLMImagePixelInputs(pixel_values=pixel_values) + def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]: + image_input = self._parse_and_validate_image_input(**kwargs) + if image_input["pixel_values"] is None: + return None + pixel_values = image_input["pixel_values"].to( + dtype=self.config.torch_dtype) + vision_embeddings = self.vision(pixel_values) + return vision_embeddings + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[NestedTensors] = None, + ) -> torch.Tensor: + inputs_embeds = self.embedding(input_ids) + if multimodal_embeddings is not None: + inputs_embeds = merge_glm_vision_embeddings( + input_ids=input_ids, + inputs_embeds=inputs_embeds, + vision_embeddings=multimodal_embeddings, + boi_token_id=self.config.boi_token_id, + eoi_token_id=self.config.eoi_token_id) + return inputs_embeds + def forward( self, input_ids: torch.Tensor, @@ -552,26 +577,17 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, **kwargs: object, ) -> torch.Tensor: - if intermediate_tensors is None: - inputs_embeds = self.embedding(input_ids) - image_input = self._parse_and_validate_image_input(**kwargs) - - if image_input["pixel_values"] is not None: - pixel_values = image_input["pixel_values"].to( - dtype=inputs_embeds.dtype) - image_embeds = self.vision(pixel_values) - - boi_token_id = self.config.boi_token_id - eoi_token_id = self.config.eoi_token_id - - inputs_embeds = merge_glm_vision_embeddings( - input_ids=input_ids, - inputs_embeds=inputs_embeds, - vision_embeddings=image_embeds, - boi_token_id=boi_token_id, - eoi_token_id=eoi_token_id) + + # NOTE: In v1, inputs_embeds 
is always generated at model runner, this + # condition is for v0 compatibility. + if intermediate_tensors is None and inputs_embeds is None: + vision_embeddings = self.get_multimodal_embeddings(**kwargs) + inputs_embeds = self.get_input_embeddings(input_ids, + vision_embeddings) + input_ids = None else: inputs_embeds = intermediate_tensors["hidden_states"] diff --git a/vllm/model_executor/models/fuyu.py b/vllm/model_executor/models/fuyu.py index 7b46907ac83ab..6e86900326c4b 100644 --- a/vllm/model_executor/models/fuyu.py +++ b/vllm/model_executor/models/fuyu.py @@ -35,6 +35,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs from vllm.multimodal.image import cached_get_image_processor +from vllm.multimodal.inputs import NestedTensors from vllm.multimodal.utils import (cached_get_tokenizer, consecutive_placeholder_ranges) from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors, @@ -302,6 +303,25 @@ def _process_image_input( vision_embeddings, _ = self.vision_embed_tokens(image_input["data"]) return vision_embeddings + def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]: + image_input = self._parse_and_validate_image_input(**kwargs) + if image_input is None: + return None + vision_embeddings = self._process_image_input(image_input) + return vision_embeddings + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[NestedTensors] = None, + ) -> torch.Tensor: + inputs_embeds = self.language_model.get_input_embeddings(input_ids) + if multimodal_embeddings is not None: + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, multimodal_embeddings, + _IMAGE_TOKEN_ID) + return inputs_embeds + def forward( self, input_ids: torch.Tensor, @@ -309,24 +329,19 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, **kwargs: object, ): if intermediate_tensors is not None: - input_ids = None inputs_embeds = None - else: - image_input = self._parse_and_validate_image_input(**kwargs) - - if image_input is not None: - vision_embeddings = self._process_image_input(image_input) - inputs_embeds = self.language_model.model.embed_tokens( - input_ids) - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, vision_embeddings, - self.image_token_id) - - else: - inputs_embeds = None + + # NOTE: In v1, inputs_embeds is always generated at model runner, this + # condition is for v0 compatibility. 
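Annotation: the same three-way dispatch recurs in every forward() touched by this patch, so it may help to see it once in isolation. This is a distilled sketch with invented names, not a vLLM API:

def resolve_inputs(model, input_ids, intermediate_tensors=None,
                   inputs_embeds=None, **mm_kwargs):
    if intermediate_tensors is not None:
        # Non-first pipeline-parallel rank: hidden states come from the
        # previous rank, so neither ids nor embeddings are needed here.
        inputs_embeds = None
    elif inputs_embeds is None:
        # v0 path: the model builds its own embeddings from multimodal kwargs.
        # In v1 the model runner passes precomputed inputs_embeds instead.
        mm_embeds = model.get_multimodal_embeddings(**mm_kwargs)
        inputs_embeds = model.get_input_embeddings(input_ids, mm_embeds)
        input_ids = None
    return input_ids, inputs_embeds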
+ elif inputs_embeds is None: + vision_embeddings = self.get_multimodal_embeddings(**kwargs) + inputs_embeds = self.get_input_embeddings(input_ids, + vision_embeddings) + input_ids = None hidden_states = self.language_model( input_ids=input_ids, diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py index 9b4a97abf9b51..1545ce332309f 100644 --- a/vllm/model_executor/models/interfaces.py +++ b/vllm/model_executor/models/interfaces.py @@ -2,7 +2,7 @@ Protocol, Type, Union, overload, runtime_checkable) import torch -from typing_extensions import TypeIs +from typing_extensions import TypeIs, TypeVar from vllm.logger import init_logger from vllm.utils import supports_kw @@ -10,10 +10,14 @@ from .interfaces_base import is_embedding_model if TYPE_CHECKING: + from vllm.attention import AttentionMetadata + from vllm.multimodal.inputs import NestedTensors # noqa: F401 from vllm.sequence import IntermediateTensors logger = init_logger(__name__) +T = TypeVar("T", default="NestedTensors") + @runtime_checkable class SupportsMultiModal(Protocol): @@ -28,6 +32,36 @@ class SupportsMultiModal(Protocol): MRO of your model class. """ + def get_multimodal_embeddings(self, **kwargs) -> Optional[T]: + """ + Returns multimodal embeddings generated from multimodal kwargs + to be merged with text embeddings. + """ + ... + + # Only for models that support v0 chunked prefill + # TODO(ywang96): Remove this overload once v0 is deprecated + @overload + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[T] = None, + attn_metadata: Optional["AttentionMetadata"] = None, + ) -> torch.Tensor: + ... + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[T] = None, + ) -> torch.Tensor: + """ + Returns the input embeddings merged from the text embeddings from + input_ids and the multimodal embeddings generated from multimodal + kwargs. + """ + ... 
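Annotation: a self-contained toy that satisfies the two methods added to the protocol above and shows the intended call order. The encoder, embedding table, and token id are placeholders, not anything from vLLM.

import torch
from typing import Optional

class ToyMultiModalModel:
    IMAGE_TOKEN_ID = 9                     # hypothetical placeholder id

    def __init__(self, vocab: int = 16, hidden: int = 4):
        self.table = torch.randn(vocab, hidden)

    def get_multimodal_embeddings(self, pixel_values=None, **kwargs):
        if pixel_values is None:           # no multimodal data in this batch
            return None
        return torch.randn(pixel_values.shape[0], self.table.shape[1])  # fake encoder

    def get_input_embeddings(self, input_ids: torch.Tensor,
                             multimodal_embeddings: Optional[torch.Tensor] = None):
        embeds = self.table[input_ids]
        if multimodal_embeddings is not None:
            embeds[input_ids == self.IMAGE_TOKEN_ID] = multimodal_embeddings
        return embeds

model = ToyMultiModalModel()
mm = model.get_multimodal_embeddings(pixel_values=torch.zeros(2, 3, 8, 8))
inputs_embeds = model.get_input_embeddings(torch.tensor([1, 9, 9, 2]), mm)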
+ # We can't use runtime_checkable with ClassVar for issubclass checks # so we need to treat the class as an instance and use isinstance instead diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py index 47ac00b6afe9b..b1c0065afbf30 100644 --- a/vllm/model_executor/models/internvl.py +++ b/vllm/model_executor/models/internvl.py @@ -26,6 +26,7 @@ InternVisionPatchModel) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs +from vllm.multimodal.inputs import NestedTensors from vllm.multimodal.utils import cached_get_tokenizer from vllm.sequence import IntermediateTensors from vllm.utils import is_list_of @@ -641,6 +642,26 @@ def _get_visual_token_mask(self, input_ids: torch.Tensor) -> torch.Tensor: visual_token_mask = None return visual_token_mask + def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]: + image_input = self._parse_and_validate_image_input(**kwargs) + if image_input is None: + return None + vision_embeddings = self._process_image_input(image_input) + return vision_embeddings + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[NestedTensors] = None, + ) -> torch.Tensor: + inputs_embeds = self.language_model.get_input_embeddings(input_ids) + if multimodal_embeddings is not None: + assert self.img_context_token_id is not None + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, multimodal_embeddings, + self.img_context_token_id) + return inputs_embeds + def forward( self, input_ids: torch.Tensor, @@ -648,26 +669,22 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, **kwargs: object, ) -> Union[SamplerOutput, IntermediateTensors]: + + visual_token_mask = None if intermediate_tensors is not None: input_ids = None inputs_embeds = None - visual_token_mask = None - else: - image_input = self._parse_and_validate_image_input(**kwargs) - if image_input is not None: - inputs_embeds = self.language_model.model.get_input_embeddings( - input_ids) - vision_embeddings = self._process_image_input(image_input) - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, vision_embeddings, - self.img_context_token_id) - visual_token_mask = self._get_visual_token_mask(input_ids) - input_ids = None - else: - inputs_embeds = None - visual_token_mask = None + + # NOTE: In v1, inputs_embeds is always generated at model runner, this + # condition is for v0 compatibility. 
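Annotation: for context on the visual_token_mask handling in this InternVL hunk, the mono variant's mask is a per-position indicator of the image-context token, roughly as sketched below. The id is invented; the real one is assigned from the tokenizer, and the helper returns None for non-mono models.

import torch

img_context_token_id = 151667            # illustrative value only
input_ids = torch.tensor([1, 151667, 151667, 2])
visual_token_mask = (input_ids == img_context_token_id).reshape(-1, 1)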
+ elif inputs_embeds is None: + vision_embeddings = self.get_multimodal_embeddings(**kwargs) + inputs_embeds = self.get_input_embeddings(input_ids, + vision_embeddings) + input_ids = None forward_kwargs = { "input_ids": input_ids, @@ -677,6 +694,13 @@ def forward( "intermediate_tensors": intermediate_tensors, "inputs_embeds": inputs_embeds, } + if self.img_context_token_id is not None: + visual_token_mask = self._get_visual_token_mask(input_ids) + + # We always overwrite it back to None after computing visual token + # mask so that this doesn't need to depend on encoder output + self.img_context_token_id = None + if self.is_mono: forward_kwargs.update({"visual_token_mask": visual_token_mask}) diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index 05c6cc62efcd7..e7757b3c7d405 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -478,7 +478,7 @@ def _process_image_input(self, image_features = self._process_image_pixels(image_input) return self.multi_modal_projector(image_features) - def process_mm_inputs(self, **kwargs): + def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]: image_input = self._parse_and_validate_image_input(**kwargs) if image_input is None: return None @@ -488,12 +488,12 @@ def process_mm_inputs(self, **kwargs): def get_input_embeddings( self, input_ids: torch.Tensor, - vision_embeddings: Optional[NestedTensors] = None, + multimodal_embeddings: Optional[NestedTensors] = None, ) -> torch.Tensor: inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if vision_embeddings is not None: + if multimodal_embeddings is not None: inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, vision_embeddings, + input_ids, inputs_embeds, multimodal_embeddings, self.config.image_token_index) return inputs_embeds @@ -544,10 +544,11 @@ def forward( """ if intermediate_tensors is not None: inputs_embeds = None + + # NOTE: In v1, inputs_embeds is always generated at model runner, this + # condition is for v0 compatibility. 
elif inputs_embeds is None: - vision_embeddings = self.process_mm_inputs(**kwargs) - # always pass the input via `inputs_embeds` - # to make sure the computation graph is consistent + vision_embeddings = self.get_multimodal_embeddings(**kwargs) inputs_embeds = self.get_input_embeddings(input_ids, vision_embeddings) input_ids = None diff --git a/vllm/model_executor/models/llava_next.py b/vllm/model_executor/models/llava_next.py index abeebb45fc4a7..e113f5862830d 100644 --- a/vllm/model_executor/models/llava_next.py +++ b/vllm/model_executor/models/llava_next.py @@ -19,6 +19,7 @@ from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import NestedTensors from vllm.sequence import IntermediateTensors, PoolerOutput from vllm.utils import is_list_of @@ -565,6 +566,30 @@ def _process_image_input( for i, patch_features_batch in enumerate(patch_embeddings) ] + def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]: + image_input = self._parse_and_validate_image_input(**kwargs) + if image_input is None: + return None + vision_embeddings = self._process_image_input(image_input) + return vision_embeddings + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[NestedTensors] = None, + ) -> torch.Tensor: + + if multimodal_embeddings is None: + return self.language_model.get_input_embeddings(input_ids) + + inputs_embeds = embed_multimodal( + input_ids, + self.config.image_token_index, + self.language_model.model.get_input_embeddings, + multimodal_embeddings, + ) + return inputs_embeds + def forward( self, input_ids: torch.Tensor, @@ -572,6 +597,7 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, **kwargs: object, ) -> Union[torch.Tensor, IntermediateTensors]: """Run forward pass for LlaVA-NeXT. @@ -620,24 +646,14 @@ def forward( """ if intermediate_tensors is not None: inputs_embeds = None - else: - image_input = self._parse_and_validate_image_input(**kwargs) - - if image_input is not None: - inputs_embeds = embed_multimodal( - input_ids, - self.config.image_token_index, - self.language_model.model.get_input_embeddings, - lambda _: self._process_image_input(image_input), - ) - else: - inputs_embeds = self.language_model.model.get_input_embeddings( - input_ids) - # always pass the input via `inputs_embeds` - # to make sure the computation graph is consistent - # for `torch.compile` integration - input_ids = None + # NOTE: In v1, inputs_embeds is always generated at model runner, this + # condition is for v0 compatibility. 
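Annotation: LLaVA-NeXT keeps using embed_multimodal, whose signature changes in the utils.py hunk at the end of this diff to take precomputed embeddings instead of a callback. A toy version of the masked assembly it performs, with invented shapes:

import torch

def toy_embed_multimodal(input_ids, mm_token_id, get_text_embeds, mm_embeds):
    is_mm = input_ids == mm_token_id
    is_text = ~is_mm
    text_embeds = get_text_embeds(input_ids[is_text])
    merged = torch.empty(input_ids.shape[0], text_embeds.shape[1],
                         dtype=text_embeds.dtype)
    merged[is_text] = text_embeds
    merged[is_mm] = mm_embeds             # already computed by the caller
    return merged

table = torch.randn(16, 4)
out = toy_embed_multimodal(torch.tensor([1, 9, 9, 2]), 9,
                           lambda ids: table[ids], torch.randn(2, 4))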
+ elif inputs_embeds is None: + vision_embeddings = self.get_multimodal_embeddings(**kwargs) + inputs_embeds = self.get_input_embeddings(input_ids, + vision_embeddings) + input_ids = None hidden_states = self.language_model.model(input_ids, positions, @@ -645,7 +661,6 @@ def forward( attn_metadata, intermediate_tensors, inputs_embeds=inputs_embeds) - return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/llava_next_video.py b/vllm/model_executor/models/llava_next_video.py index e2880c76cf43d..b130791808924 100644 --- a/vllm/model_executor/models/llava_next_video.py +++ b/vllm/model_executor/models/llava_next_video.py @@ -18,6 +18,7 @@ from vllm.model_executor.models.clip import CLIPVisionModel from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import NestedTensors from vllm.multimodal.utils import (cached_get_tokenizer, repeat_and_pad_placeholder_tokens) from vllm.sequence import IntermediateTensors @@ -388,6 +389,25 @@ def _process_video_pixels(self, inputs: LlavaNextVideoPixelInputs): raise ValueError( f"Unsupported type of video input {type(video_pixels)}") + def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]: + video_input = self._parse_and_validate_video_input(**kwargs) + if video_input is None: + return None + vision_embeddings = self._process_video_pixels(video_input) + return vision_embeddings + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[NestedTensors] = None, + ) -> torch.Tensor: + inputs_embeds = self.language_model.get_input_embeddings(input_ids) + if multimodal_embeddings is not None: + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, multimodal_embeddings, + self.config.video_token_index) + return inputs_embeds + def forward( self, input_ids: torch.Tensor, @@ -395,6 +415,7 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, **kwargs: object, ) -> Union[torch.Tensor, IntermediateTensors]: """Run forward pass for LlaVA-NeXT-Video. @@ -404,22 +425,15 @@ def forward( pixel_values_videos: Pixels in each frames for each input videos. """ if intermediate_tensors is not None: - input_ids = None inputs_embeds = None - else: - video_input = self._parse_and_validate_video_input(**kwargs) - if video_input is not None: - video_embeddings = self._process_video_pixels(video_input) - inputs_embeds = self.language_model \ - .model.get_input_embeddings(input_ids) - - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, video_embeddings, - self.config.video_token_index) - - input_ids = None - else: - inputs_embeds = None + + # NOTE: In v1, inputs_embeds is always generated at model runner, this + # condition is for v0 compatibility. 
+ elif inputs_embeds is None: + vision_embeddings = self.get_multimodal_embeddings(**kwargs) + inputs_embeds = self.get_input_embeddings(input_ids, + vision_embeddings) + input_ids = None hidden_states = self.language_model.model(input_ids, positions, diff --git a/vllm/model_executor/models/llava_onevision.py b/vllm/model_executor/models/llava_onevision.py index 705ca1e4ab6e6..3166737d61582 100644 --- a/vllm/model_executor/models/llava_onevision.py +++ b/vllm/model_executor/models/llava_onevision.py @@ -21,6 +21,7 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import NestedTensors from vllm.multimodal.utils import (cached_get_tokenizer, repeat_and_pad_placeholder_tokens) from vllm.sequence import IntermediateTensors @@ -824,6 +825,49 @@ def apply_pooling(self, image_features, stride=2): image_feature = image_feature.view(batch_frames, -1, dim) return image_feature + def get_multimodal_embeddings( + self, **kwargs) -> Optional[List[Tuple[NestedTensors, str]]]: + modalities = self._parse_and_validate_multimodal_inputs(**kwargs) + if not modalities: + return None + + # We make a tuple of each embedding with its modality string. This is a + # temporary workaround for models to handle mixed modalities when + # get_multimodal_embeddings and get_input_embeddings are called + # separately. + # TODO(ywang96): Add support for mixed-modality inference for v1. + multimodal_embeddings: List[Tuple[NestedTensors, str]] = [] + + if "images" in modalities: + image_input = modalities["images"] + vision_embeddings = self._process_image_input(image_input) + multimodal_embeddings.append((vision_embeddings, "image")) + if "videos" in modalities: + video_input = modalities["videos"] + video_embeddings = self._process_video_pixels(video_input) + multimodal_embeddings.append((video_embeddings, "video")) + + return multimodal_embeddings + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[List[Tuple[NestedTensors, + str]]] = None, + ) -> torch.Tensor: + inputs_embeds = self.language_model.get_input_embeddings(input_ids) + if multimodal_embeddings is not None: + for embeddings, modality in multimodal_embeddings: + if modality == "image": + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, embeddings, + self.config.image_token_index) + if modality == "video": + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, embeddings, + self.config.video_token_index) + return inputs_embeds + def forward( self, input_ids: torch.Tensor, @@ -831,6 +875,7 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, **kwargs: object, ) -> Union[torch.Tensor, IntermediateTensors]: """Run forward pass for LlaVA-Onevision. @@ -840,28 +885,15 @@ def forward( pixel_values_videos: Pixels in each frames for each input videos. 
""" if intermediate_tensors is not None: - input_ids = None inputs_embeds = None - else: - modalities = self._parse_and_validate_multimodal_inputs(**kwargs) - if modalities: - inputs_embeds = self.language_model.model.get_input_embeddings( - input_ids) - if "images" in modalities: - image_input = modalities["images"] - vision_embeddings = self._process_image_input(image_input) - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, vision_embeddings, - self.config.image_token_index) - if "videos" in modalities: - video_input = modalities["videos"] - video_embeddings = self._process_video_pixels(video_input) - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, video_embeddings, - self.config.video_token_index) - input_ids = None - else: - inputs_embeds = None + + # NOTE: In v1, inputs_embeds is always generated at model runner, this + # condition is for v0 compatibility. + elif inputs_embeds is None: + multimodal_embeddings = self.get_multimodal_embeddings(**kwargs) + inputs_embeds = self.get_input_embeddings(input_ids, + multimodal_embeddings) + input_ids = None hidden_states = self.language_model.model(input_ids, positions, diff --git a/vllm/model_executor/models/molmo.py b/vllm/model_executor/models/molmo.py index ee7b560fe1ee4..acedddd84d7cb 100644 --- a/vllm/model_executor/models/molmo.py +++ b/vllm/model_executor/models/molmo.py @@ -3,7 +3,7 @@ from array import array from dataclasses import dataclass from functools import lru_cache, partial -from typing import Iterable, List, Mapping, Optional, Tuple, TypedDict, Union +from typing import Iterable, List, Mapping, Optional, Tuple, TypedDict import torch from einops import rearrange @@ -36,6 +36,7 @@ ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs +from vllm.multimodal.inputs import NestedTensors from vllm.multimodal.utils import cached_get_tokenizer from vllm.platforms import _Backend from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors, @@ -756,6 +757,12 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): make_empty_intermediate_tensors_factory( ["hidden_states", "residual"], config.hidden_size)) + def get_input_embeddings( + self, + input_ids: torch.Tensor, + ) -> torch.Tensor: + return self.embed_tokens(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -1098,19 +1105,16 @@ def _process_image_input( return image_features - def _merge_multimodal_embeddings( - self, - inputs_embeds: torch.Tensor, - image_features: torch.Tensor, - image_input_idx: torch.Tensor, - seq_len: Union[torch.Tensor, List[torch.Tensor]], - ) -> torch.Tensor: + def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]: + image_input = self._parse_and_validate_image_input(**kwargs) + if image_input is None: + return None + image_features = self._process_image_input(image_input) + image_input_idx = image_input["image_input_idx"] + seq_len = image_input["seq_len"] batch_size, num_image, num_patch = image_features.shape[:3] assert image_input_idx.shape == (batch_size, num_image, num_patch) - image_features = image_features.to(inputs_embeds.device) - seq_len = seq_len.to(inputs_embeds.device) - # insert the image feature into the embedding. 
image_features = image_features.view(batch_size, num_image * num_patch, -1) @@ -1130,12 +1134,24 @@ def _merge_multimodal_embeddings( image_input_idx = image_input_idx + offset.to(image_input_idx.dtype) image_input_idx = image_input_idx.flatten()[:, None] mat = image_input_idx == torch.arange( - seq_len.sum().item(), device=inputs_embeds.device)[None, :] + seq_len.sum().item(), device=image_features.device)[None, :] mat = mat.to(image_features.dtype) - inputs_embeds = inputs_embeds + torch.einsum('nd,nm->md', - image_features, mat) + # Note: In the original implementation from AI2, the final + # vision_embeddings will always be the same length + # as the input embeddings, which is not very efficient. + # TODO(ywang96): see if this can be optimized. + vision_embeddings = torch.einsum('nd,nm->md', image_features, mat) + return vision_embeddings + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[NestedTensors] = None, + ) -> torch.Tensor: + inputs_embeds = self.model.get_input_embeddings(input_ids) + if multimodal_embeddings is not None: + inputs_embeds = inputs_embeds + multimodal_embeddings return inputs_embeds def forward( @@ -1145,39 +1161,27 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, **kwargs: object, ) -> SamplerOutput: + if intermediate_tensors is not None: inputs_embeds = None - else: - image_input = self._parse_and_validate_image_input(**kwargs) - - if image_input is not None: - inputs_embeds = self.model.embed_tokens(input_ids) - image_features = self._process_image_input(image_input) - - inputs_embeds = self._merge_multimodal_embeddings( - inputs_embeds, - image_features, - image_input["image_input_idx"], - image_input["seq_len"], - ) - else: - inputs_embeds = self.model.embed_tokens(input_ids) - # always pass the input via `inputs_embeds` - # to make sure the computation graph is consistent - # for `torch.compile` integration - input_ids = None - - hidden_states = self.model( - input_ids=input_ids, - positions=positions, - kv_caches=kv_caches, - attn_metadata=attn_metadata, - intermediate_tensors=intermediate_tensors, - inputs_embeds=inputs_embeds, - ) + # NOTE: In v1, inputs_embeds is always generated at model runner, this + # condition is for v0 compatibility.
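Annotation: the Molmo merge retained above is additive rather than mask-based. image_input_idx gives each patch feature a destination position, a one-hot matrix is built from it, and the einsum scatters the features into a sequence-length tensor that get_input_embeddings adds to the token embeddings. A toy version with made-up sizes:

import torch

hidden, seq_len = 4, 6
image_features = torch.randn(2, hidden)        # (n, d): two patch features
image_input_idx = torch.tensor([2, 3])         # destination positions in the sequence

# (n, m) one-hot matrix: row i selects where feature i lands
mat = (image_input_idx[:, None] == torch.arange(seq_len)[None, :]).to(image_features.dtype)
# (m, d): features scattered to their positions, zeros elsewhere
vision_embeddings = torch.einsum('nd,nm->md', image_features, mat)

token_embeds = torch.randn(seq_len, hidden)
inputs_embeds = token_embeds + vision_embeddings   # the additive merge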
+ elif inputs_embeds is None: + vision_embeddings = self.get_multimodal_embeddings(**kwargs) + inputs_embeds = self.get_input_embeddings(input_ids, + vision_embeddings) + input_ids = None + + hidden_states = self.model(input_ids, + positions, + kv_caches, + attn_metadata, + intermediate_tensors, + inputs_embeds=inputs_embeds) return hidden_states diff --git a/vllm/model_executor/models/paligemma.py b/vllm/model_executor/models/paligemma.py index dd5256eb87ab3..2e5b6bee784e7 100644 --- a/vllm/model_executor/models/paligemma.py +++ b/vllm/model_executor/models/paligemma.py @@ -13,6 +13,7 @@ from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import NestedTensors from vllm.multimodal.utils import cached_get_tokenizer from vllm.sequence import IntermediateTensors @@ -240,36 +241,45 @@ def _process_image_input( return self.multi_modal_projector(image_features) + def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]: + image_input = self._parse_and_validate_image_input(**kwargs) + if image_input is None: + return None + vision_embeddings = self._process_image_input(image_input) + # https://github.com/huggingface/transformers/blob/main/src/transformers/models/paligemma/modeling_paligemma.py#L294 # noqa + vision_embeddings = vision_embeddings * (self.config.hidden_size**-0.5) + return vision_embeddings + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[NestedTensors] = None, + ) -> torch.Tensor: + inputs_embeds = self.language_model.get_input_embeddings(input_ids) + if multimodal_embeddings is not None: + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, multimodal_embeddings, + self.config.image_token_index) + return inputs_embeds + def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, **kwargs: object) -> Union[SamplerOutput, IntermediateTensors]: if intermediate_tensors is not None: - input_ids = None inputs_embeds = None - else: - parsed_image_input = self._parse_and_validate_image_input(**kwargs) - - if parsed_image_input is not None: - vision_embeddings = self._process_image_input( - parsed_image_input) - # https://github.com/huggingface/transformers/blob/main/src/transformers/models/paligemma/modeling_paligemma.py#L294 # noqa - vision_embeddings = vision_embeddings * ( - self.config.hidden_size**-0.5) - - inputs_embeds = self.language_model.model.get_input_embeddings( - input_ids) - - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, vision_embeddings, - self.config.image_token_index) - - input_ids = None - else: - inputs_embeds = None + + # NOTE: In v1, inputs_embeds is always generated at model runner, this + # condition is for v0 compatibility. 
+ elif inputs_embeds is None: + vision_embeddings = self.get_multimodal_embeddings(**kwargs) + inputs_embeds = self.get_input_embeddings(input_ids, + vision_embeddings) + input_ids = None hidden_states = self.language_model.model(input_ids, positions, diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index 2e583bb08e87a..4cb874a13e0c1 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -676,7 +676,7 @@ def _process_image_input( return image_embeds - def process_mm_inputs(self, **kwargs): + def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]: image_input = self._parse_and_validate_image_input(**kwargs) if image_input is None: return None @@ -686,12 +686,12 @@ def process_mm_inputs(self, **kwargs): def get_input_embeddings( self, input_ids: torch.Tensor, - vision_embeddings: Optional[NestedTensors] = None, + multimodal_embeddings: Optional[NestedTensors] = None, ) -> torch.Tensor: inputs_embeds = self.embed_tokens(input_ids) - if vision_embeddings is not None: + if multimodal_embeddings is not None: inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, vision_embeddings, + input_ids, inputs_embeds, multimodal_embeddings, self.image_token_id) return inputs_embeds @@ -703,12 +703,14 @@ def forward(self, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, **kwargs: object): + if intermediate_tensors is not None: inputs_embeds = None + + # NOTE: In v1, inputs_embeds is always generated at model runner, this + # condition is for v0 compatibility elif inputs_embeds is None: - vision_embeddings = self.process_mm_inputs(**kwargs) - # always pass the input via `inputs_embeds` - # to make sure the computation graph is consistent + vision_embeddings = self.get_multimodal_embeddings(**kwargs) inputs_embeds = self.get_input_embeddings(input_ids, vision_embeddings) input_ids = None diff --git a/vllm/model_executor/models/qwen2_audio.py b/vllm/model_executor/models/qwen2_audio.py index 0c2374c3c3fc9..a0605fee82aca 100644 --- a/vllm/model_executor/models/qwen2_audio.py +++ b/vllm/model_executor/models/qwen2_audio.py @@ -42,10 +42,12 @@ from vllm.model_executor.models.qwen2 import Qwen2Model from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs +from vllm.multimodal.inputs import NestedTensors from vllm.multimodal.utils import consecutive_placeholder_ranges from vllm.sequence import IntermediateTensors, SequenceData from .interfaces import SupportsMultiModal, SupportsPP +from .utils import merge_multimodal_embeddings logger = init_logger(__name__) @@ -371,6 +373,25 @@ def _process_audio_input(self, return masked_audio_features + def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]: + audio_input = self._parse_and_validate_audio_input(**kwargs) + if audio_input is None: + return None + masked_audio_features = self._process_audio_input(audio_input) + return masked_audio_features + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[NestedTensors] = None, + ) -> torch.Tensor: + inputs_embeds = self.language_model.get_input_embeddings(input_ids) + if multimodal_embeddings is not None: + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, multimodal_embeddings, + self.config.audio_token_index) + return inputs_embeds + def forward( self, input_ids: torch.Tensor, @@ -378,33 +399,27 
@@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, **kwargs: object, ) -> Union[torch.Tensor, IntermediateTensors]: + if intermediate_tensors is not None: - input_ids = None inputs_embeds = None - else: - audio_input = self._parse_and_validate_audio_input(**kwargs) - if audio_input is None: - inputs_embeds = None - else: - inputs_embeds = self.language_model.embed_tokens(input_ids) - masked_audio_features = self._process_audio_input(audio_input) - # merge llm embeddings and audio features - mask = (input_ids == self.config.audio_token_index) - inputs_embeds[mask, :] = masked_audio_features - - input_ids = None - - hidden_states = self.language_model( - input_ids=input_ids, - positions=positions, - kv_caches=kv_caches, - attn_metadata=attn_metadata, - intermediate_tensors=intermediate_tensors, - inputs_embeds=inputs_embeds, - ) + # NOTE: In v1, inputs_embeds is always generated at model runner, this + # condition is for v0 compatibility. + elif inputs_embeds is None: + multimodal_embeddings = self.get_multimodal_embeddings(**kwargs) + inputs_embeds = self.get_input_embeddings(input_ids, + multimodal_embeddings) + input_ids = None + + hidden_states = self.language_model(input_ids, + positions, + kv_caches, + attn_metadata, + intermediate_tensors, + inputs_embeds=inputs_embeds) return hidden_states def compute_logits(self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index 531608a877f2f..7956a98b21569 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -63,7 +63,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.image import cached_get_image_processor from vllm.multimodal.inputs import (MultiModalData, MultiModalDataDict, - MultiModalKwargs) + MultiModalKwargs, NestedTensors) from vllm.multimodal.utils import cached_get_tokenizer from vllm.platforms import _Backend from vllm.sequence import IntermediateTensors, PoolerOutput, SequenceData @@ -1238,6 +1238,55 @@ def _merge_multimodal_embeddings( inputs_embeds[mask, :] = multimodal_embeddings return inputs_embeds + def get_multimodal_embeddings( + self, **kwargs) -> Optional[List[Tuple[NestedTensors, str]]]: + + image_input = self._parse_and_validate_image_input(**kwargs) + video_input = self._parse_and_validate_video_input(**kwargs) + if image_input is None and video_input is None: + return None + + # We make a tuple of each embedding with its modality string. This is a + # temporary workaround for models to handle mixed modalities when + # get_multimodal_embeddings and get_input_embeddings are called + # separately. + # TODO(ywang96): Add support for mixed-modality inference for v1. 
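Annotation: to make the tuple workaround concrete, this is roughly how get_input_embeddings below consumes the (embeddings, modality) pairs, merging each modality at its own placeholder id. Token ids and shapes here are invented:

import torch

def merge_by_modality(input_ids, inputs_embeds, multimodal_embeddings, placeholder_ids):
    out = inputs_embeds.clone()
    for embeddings, modality in multimodal_embeddings:
        mask = input_ids == placeholder_ids[modality]
        out[mask] = embeddings.to(out.dtype)
    return out

hidden = 4
input_ids = torch.tensor([1, 100, 100, 2, 200, 3])
inputs_embeds = torch.randn(6, hidden)
mm = [(torch.randn(2, hidden), "image"), (torch.randn(1, hidden), "video")]
merged = merge_by_modality(input_ids, inputs_embeds, mm, {"image": 100, "video": 200})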
+ multimodal_embeddings: List[Tuple[NestedTensors, str]] = [] + + if image_input is not None: + image_embeds = self._process_image_input(image_input) + multimodal_embeddings.append((image_embeds, "image")) + if video_input is not None: + video_embeds = self._process_video_input(video_input) + multimodal_embeddings.append((video_embeds, "video")) + + return multimodal_embeddings + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[List[Tuple[NestedTensors, + str]]] = None, + ) -> torch.Tensor: + inputs_embeds = self.model.get_input_embeddings(input_ids) + if multimodal_embeddings is not None: + for embeddings, modality in multimodal_embeddings: + if modality == "image": + inputs_embeds = self._merge_multimodal_embeddings( + input_ids, + inputs_embeds, + embeddings, + placeholder_token_id=self.config.image_token_id, + ) + if modality == "video": + inputs_embeds = self._merge_multimodal_embeddings( + input_ids, + inputs_embeds, + embeddings, + placeholder_token_id=self.config.video_token_id, + ) + return inputs_embeds + def forward( self, input_ids: torch.Tensor, @@ -1245,6 +1294,7 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, **kwargs: object, ) -> Union[torch.Tensor, IntermediateTensors]: """Run forward pass for Qwen2-VL. @@ -1266,42 +1316,26 @@ def forward( video_grid_thw: Tensor `(n_videos, 3)` of video 3D grid in LLM. `None` if no videos are passed. """ + if intermediate_tensors is not None: - input_ids = None inputs_embeds = None - else: - image_input = self._parse_and_validate_image_input(**kwargs) - video_input = self._parse_and_validate_video_input(**kwargs) - - if image_input is None and video_input is None: - inputs_embeds = None - else: - if uses_mrope(self.config): - assert positions.ndim == 2 and positions.size(0) == 3, ( - "multimodal section rotary embedding requires " - f"(3, seq_len) positions, but got {positions.size()}") - - inputs_embeds = self.model.embed_tokens(input_ids) - - if image_input is not None: - image_embeds = self._process_image_input(image_input) - inputs_embeds = self._merge_multimodal_embeddings( - input_ids, - inputs_embeds, - image_embeds, - placeholder_token_id=self.config.image_token_id, - ) - - if video_input is not None: - video_embeds = self._process_video_input(video_input) - inputs_embeds = self._merge_multimodal_embeddings( - input_ids, - inputs_embeds, - video_embeds, - placeholder_token_id=self.config.video_token_id, - ) - input_ids = None + # NOTE: In v1, inputs_embeds is always generated at model runner, this + # condition is for v0 compatibility. + elif inputs_embeds is None: + multimodal_embeddings = self.get_multimodal_embeddings(**kwargs) + + # We need to check for usage of mrope here in case there is + # multimodal data. + # TODO (ywang96): move this to model runner in V1. 
+ if multimodal_embeddings is not None and uses_mrope(self.config): + assert positions.ndim == 2 and positions.size(0) == 3, ( + "multimodal section rotary embedding requires " + f"(3, seq_len) positions, but got {positions.size()}") + + inputs_embeds = self.get_input_embeddings(input_ids, + multimodal_embeddings) + input_ids = None hidden_states = self.model( input_ids=input_ids, diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index 512adbc7db35e..b61deccde45b7 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -449,10 +449,36 @@ def _process_audio_input( return result - def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, + def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]: + audio_input = self._parse_and_validate_audio_input(**kwargs) + if audio_input is None: + return None + audio_embeddings = self._process_audio_input(audio_input) + return audio_embeddings + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[NestedTensors] = None, + attn_metadata: Optional[AttentionMetadata] = None, + ) -> torch.Tensor: + inputs_embeds = self.language_model.get_input_embeddings(input_ids) + if multimodal_embeddings is not None: + + # TODO(ywang96): use merge_multimodal_embeddings after + # v0 is deprecated + merge_multimodal_embeddings_from_map( + inputs_embeds, multimodal_embeddings, + attn_metadata.multi_modal_placeholder_index_maps["audio"]) + return inputs_embeds + + def forward(self, + input_ids: torch.Tensor, + positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, - intermediate_tensors: Optional[torch.Tensor], + intermediate_tensors: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, **kwargs) -> Union[torch.Tensor, IntermediateTensors]: """Run forward pass for Ultravox @@ -466,30 +492,28 @@ def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, Args: audio_features: A batch of audio inputs [B, N, 80, M]. """ + if intermediate_tensors is not None: - input_ids = None inputs_embeds = None - else: - audio_input = self._parse_and_validate_audio_input(**kwargs) - if audio_input is not None: - audio_embeddings = self._process_audio_input(audio_input) - inputs_embeds = self.language_model.model.get_input_embeddings( - input_ids) - - merge_multimodal_embeddings_from_map( - inputs_embeds, audio_embeddings, - attn_metadata.multi_modal_placeholder_index_maps["audio"]) - input_ids = None - else: - inputs_embeds = None - - hidden_states = self.language_model.model( - input_ids=input_ids, - positions=positions, - kv_caches=kv_caches, - attn_metadata=attn_metadata, - intermediate_tensors=intermediate_tensors, - inputs_embeds=inputs_embeds) + + # NOTE: In v1, inputs_embeds is always generated at model runner, this + # condition is for v0 compatibility. 
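Annotation: Ultravox is the one model here that still merges through the v0 placeholder index map rather than a placeholder token id, which is why its get_input_embeddings temporarily keeps an attn_metadata argument. Stripped of the map bookkeeping, the merge is an index scatter; the destination indices below are invented:

import torch

hidden = 4
inputs_embeds = torch.zeros(6, hidden)        # language_model.get_input_embeddings(...)
audio_embeddings = torch.randn(3, hidden)     # get_multimodal_embeddings(...)
dest_indices = torch.tensor([1, 2, 3])        # flattened placeholder positions from the map
inputs_embeds[dest_indices] = audio_embeddings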
+ elif inputs_embeds is None: + multimodal_embeddings = self.get_multimodal_embeddings(**kwargs) + + # TODO(ywang96): remove attn_metadata from get_input_embeddings + # after v0 is deprecated + inputs_embeds = self.get_input_embeddings(input_ids, + multimodal_embeddings, + attn_metadata) + input_ids = None + + hidden_states = self.language_model.model(input_ids, + positions, + kv_caches, + attn_metadata, + intermediate_tensors, + inputs_embeds=inputs_embeds) return hidden_states def compute_logits(self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index dcfd2cb7d2622..4c13cbc953273 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -356,8 +356,7 @@ def embed_multimodal( input_ids: torch.Tensor, multimodal_token_id: int, get_text_embeds: Callable[[torch.Tensor], torch.Tensor], - get_multimodal_embeds: Callable[[torch.Tensor], Union[torch.Tensor, - List[torch.Tensor]]], + multimodal_embeds: Union[torch.Tensor, List[torch.Tensor]], ) -> torch.Tensor: """ Embed token IDs and multimodal inputs and combine their embeddings. @@ -374,8 +373,6 @@ def embed_multimodal( is_text = ~is_multimodal text_embeds = get_text_embeds(input_ids[is_text]) - multimodal_embeds = get_multimodal_embeds(input_ids[is_multimodal]) - merged_embeds = torch.empty( (input_ids.shape[0], text_embeds.shape[1]), dtype=text_embeds.dtype, diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 13cbc8fa39c03..1fa47f553dfd6 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -363,7 +363,8 @@ def _execute_encoder(self, scheduler_output: "SchedulerOutput"): # 2. A list (length: num_images) of tensors, each of shape # [feature_size, hidden_size] in case when the feature size is # dynamic depending on input images. - encoder_outputs = self.model.process_mm_inputs(**batched_mm_inputs) + encoder_outputs = self.model.get_multimodal_embeddings( + **batched_mm_inputs) # Cache the encoder outputs. for (req_id, input_id), output in zip(req_input_ids, encoder_outputs):
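Annotation: the gpu_model_runner hunk only renames the call, but the comment it sits under spells out a shape contract worth keeping in mind: get_multimodal_embeddings may return one batched tensor or, when feature sizes vary per image, a list of per-item tensors. A sketch of normalizing either form into per-item entries for the encoder cache (the helper name is hypothetical):

import torch

def split_encoder_outputs(encoder_outputs):
    # Either a (num_items, feature_size, hidden) tensor or a list of
    # (feature_size_i, hidden) tensors; return one tensor per item.
    if isinstance(encoder_outputs, torch.Tensor):
        return list(encoder_outputs.unbind(0))
    return list(encoder_outputs)

batched = torch.randn(2, 5, 4)                    # fixed feature size
ragged = [torch.randn(3, 4), torch.randn(7, 4)]   # dynamic feature size
assert len(split_encoder_outputs(batched)) == 2
assert len(split_encoder_outputs(ragged)) == 2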