Skip to content

Commit

Permalink
initial
Browse files Browse the repository at this point in the history
Signed-off-by: Roger Wang <[email protected]>
  • Loading branch information
ywang96 committed Dec 22, 2024
1 parent 29c7489 commit c0dd548
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 3 deletions.
12 changes: 11 additions & 1 deletion vllm/model_executor/models/llava.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,7 +477,17 @@ def _process_image_input(self,

assert self.vision_tower is not None
image_features = self._process_image_pixels(image_input)
return self.multi_modal_projector(image_features)

if isinstance(image_features, torch.Tensor):
return self.multi_modal_projector(image_features)

feature_sizes = [
image_feature.shape[0] for image_feature in image_features
]

image_embeds = self.multi_modal_projector(torch.cat(image_features))
image_embeds = torch.split(image_embeds, feature_sizes)
return image_embeds

def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
image_input = self._parse_and_validate_image_input(**kwargs)
Expand Down
10 changes: 8 additions & 2 deletions vllm/model_executor/models/pixtral.py
Original file line number Diff line number Diff line change
Expand Up @@ -1009,9 +1009,13 @@ def forward(
for img in pixel_values
]

patch_embeds = [
p.flatten(2).permute(0, 2, 1) for p in patch_embeds_list
]
embed_sizes = [p.shape[1] for p in patch_embeds]

# flatten to a single sequence
patch_embeds = torch.cat(
[p.flatten(2).permute(0, 2, 1) for p in patch_embeds_list], dim=1)
patch_embeds = torch.cat(patch_embeds, dim=1)
patch_embeds = self.ln_pre(patch_embeds)

# positional embeddings
Expand Down Expand Up @@ -1042,6 +1046,8 @@ def forward(
out = resolve_visual_encoder_outputs(out, feature_sample_layers, None,
self.config.num_hidden_layers)

# squeeze dim 0 and split into separate tensors for each image
out = torch.split(torch.squeeze(out), embed_sizes)
return out

# (TODO) Add prefix argument for filtering out weights to be loaded
Expand Down

0 comments on commit c0dd548

Please sign in to comment.