diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py
index 0662d90e79b92..cb4642edbd4fa 100644
--- a/vllm/model_executor/models/llava.py
+++ b/vllm/model_executor/models/llava.py
@@ -477,7 +477,17 @@ def _process_image_input(self,
 
         assert self.vision_tower is not None
         image_features = self._process_image_pixels(image_input)
-        return self.multi_modal_projector(image_features)
+
+        if isinstance(image_features, torch.Tensor):
+            return self.multi_modal_projector(image_features)
+
+        feature_sizes = [
+            image_feature.shape[0] for image_feature in image_features
+        ]
+
+        image_embeds = self.multi_modal_projector(torch.cat(image_features))
+        image_embeds = torch.split(image_embeds, feature_sizes)
+        return image_embeds
 
     def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
         image_input = self._parse_and_validate_image_input(**kwargs)
diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py
index f3d66c2313198..31d9b4dd684c9 100644
--- a/vllm/model_executor/models/pixtral.py
+++ b/vllm/model_executor/models/pixtral.py
@@ -1009,9 +1009,13 @@ def forward(
             for img in pixel_values
         ]
 
+        patch_embeds = [
+            p.flatten(2).permute(0, 2, 1) for p in patch_embeds_list
+        ]
+        embed_sizes = [p.shape[1] for p in patch_embeds]
+
         # flatten to a single sequence
-        patch_embeds = torch.cat(
-            [p.flatten(2).permute(0, 2, 1) for p in patch_embeds_list], dim=1)
+        patch_embeds = torch.cat(patch_embeds, dim=1)
         patch_embeds = self.ln_pre(patch_embeds)
 
         # positional embeddings
@@ -1042,6 +1046,8 @@ def forward(
         out = resolve_visual_encoder_outputs(out, feature_sample_layers, None,
                                              self.config.num_hidden_layers)
 
+        # squeeze dim 0 and split into separate tensors for each image
+        out = torch.split(torch.squeeze(out), embed_sizes)
        return out
 
     # (TODO) Add prefix argument for filtering out weights to be loaded
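
For reference, a minimal standalone sketch of the concatenate-project-split pattern that the `llava.py` hunk introduces: per-image sizes are recorded, all features go through a single projector call, and `torch.split` restores one tensor per image. The `nn.Linear` projector and random tensors below are hypothetical stand-ins for `multi_modal_projector` and the real vision-tower outputs, not code from the PR.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for the model's multi_modal_projector.
projector = nn.Linear(8, 4)

# Two images with different numbers of patch features (3 and 5 here).
image_features = [torch.randn(3, 8), torch.randn(5, 8)]

# Record per-image sizes, project everything in one batched call,
# then split the result back into one tensor per image.
feature_sizes = [f.shape[0] for f in image_features]
image_embeds = projector(torch.cat(image_features))
image_embeds = torch.split(image_embeds, feature_sizes)

assert [tuple(e.shape) for e in image_embeds] == [(3, 4), (5, 4)]
```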