diff --git a/vllm/model_executor/models/idefics3.py b/vllm/model_executor/models/idefics3.py index e5d2edbd81eb1..17e772e7faa32 100644 --- a/vllm/model_executor/models/idefics3.py +++ b/vllm/model_executor/models/idefics3.py @@ -60,7 +60,8 @@ class Idefics3ImagePixelInputs(TypedDict): type: Literal["pixel_values"] data: torch.Tensor """ - Shape: `(batch_size * num_images, num_channels, height, width)` + Shape: `(batch_size * num_images * num_patches, + num_channels, height, width)` """ pixel_attention_mask: Optional[torch.BoolTensor] @@ -520,13 +521,17 @@ def _parse_and_validate_image_input( raise ValueError("Incorrect type of pixel values. " f"Got type: {type(pixel_values)}") - return Idefics3ImagePixelInputs(type="pixel_values", - data=self._validate_pixel_values( - flatten_bn(pixel_values, - concat=True)), - pixel_attention_mask=flatten_bn( - pixel_attention_mask, - concat=True)) + if isinstance(pixel_values, list): + pixel_values = torch.cat(pixel_values, dim=1) + pixel_attention_mask = torch.cat(pixel_attention_mask, dim=1) + else: + pixel_values = flatten_bn(pixel_values) + pixel_attention_mask = flatten_bn(pixel_attention_mask) + + return Idefics3ImagePixelInputs( + type="pixel_values", + data=self._validate_pixel_values(pixel_values), + pixel_attention_mask=pixel_attention_mask) raise AssertionError("This line should be unreachable.")